13: attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
attn_weights is a [queries, heads, keys] tensor containing the dot products between the query and key features, one score per (query, head, key) triple.
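As a quick shape check (the sizes below are made up, not from the source), the einsum contracts the shared feature dimension d and leaves one score per (query, head, key) triple:

```python
import jax
import jax.numpy as jnp

q_len, k_len, heads, dim = 4, 8, 2, 16  # made-up sizes
rng_q, rng_k = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(rng_q, (q_len, heads, dim))  # [queries, heads, features]
key = jax.random.normal(rng_k, (k_len, heads, dim))    # [keys, heads, features]

# Contract the feature axis d; each entry is one query's dot product with one key,
# computed independently per head.
attn_weights = jnp.einsum('qhd,khd->qhk', query, key)
print(attn_weights.shape)  # (4, 2, 8) -> [queries, heads, keys]
```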
14: max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
15: max_score = jax.lax.stop_gradient(max_score)
-1 is the last axis, so max_score is a tensor of shape [queries, heads, 1] containing the maximum pre-softmax attention weight for each query, over that query's chunk of keys.
This might be the running-max calculation m_i = max(m_{i-1}, s_i), with m initialised to -inf?
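This isn't from the source, but a toy check of that hypothesis: taking a running max over chunks, starting from -inf, gives the same result as jnp.max over the whole key axis at once.

```python
import jax.numpy as jnp

# Made-up scores for a single query and head: 8 keys split into 2 chunks of 4.
s = jnp.array([0.3, -1.2, 2.5, 0.1, 1.7, -0.4, 3.1, 0.9])

# Running max over chunks, as in m_i = max(m_{i-1}, max(s_i)), with m_0 = -inf.
m = -jnp.inf
for chunk in jnp.split(s, 2):
    m = jnp.maximum(m, jnp.max(chunk))

print(m == jnp.max(s))  # True: the chunked running max equals the global max
```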
16: exp_weights = jnp.exp(attn_weights - max_score)
Here we can see s_i - m_i: exp_weights holds the attention weights exponentiated after subtracting their per-query maximum. max_score's size-1 last dimension is broadcast across the keys dimension of attn_weights, so exp_weights is a tensor of the same shape as attn_weights: [queries, heads, keys].
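A minimal sketch of lines 14-16 with made-up shapes (the sizes are assumptions, not from the source), showing the broadcast and the fact that each row's maximum entry becomes exp(0) = 1:

```python
import jax
import jax.numpy as jnp

# Made-up [queries, heads, keys] scores, standing in for attn_weights.
attn_weights = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 8))

# Per-query maximum over the keys axis; stop_gradient as in line 15.
max_score = jax.lax.stop_gradient(jnp.max(attn_weights, axis=-1, keepdims=True))
print(max_score.shape)  # (4, 2, 1) -> [queries, heads, 1]

# The size-1 trailing axis broadcasts across the keys axis of attn_weights.
exp_weights = jnp.exp(attn_weights - max_score)
print(exp_weights.shape)              # (4, 2, 8), same as attn_weights
print(jnp.max(exp_weights, axis=-1))  # all ones: each row's maximum maps to exp(0) = 1
```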