14: max_score = jnp.max(attn_weights, axis = -1, keepdims = True)

-1 is the last axis, so this appears to make max_score a tensor of shape [query, head, 1] containing the maximum pre-softmax attention weight for each query, taken over the entire chunk of keys. This might be the calculation m_i = max(m, s_i), with m initialised to -inf? I think the description of the lines posted earlier says more about this.

15: max_score = jax.lax.stop_gradient(max_score)

I think this informs JAX that calculations on the returned max_score do not affect the backpropagation gradients, so their gradients don't need to be calculated and stored. That should be safe here because subtracting the max is only a numerical-stability shift: softmax is invariant to it, so its contribution to the gradient cancels out.

16: exp_weights = jnp.exp(attn_weights - max_score)

This must be the exp(s_i - m_i) term in the update v = v * exp(m - m_i) + v_i * exp(s_i - m_i). Subtracting the max keeps every exponent at or below zero, so the exp can't overflow.

I typed most of line 17 and then my finger hit 'delete' by accident, so I'm sending this to preserve it.
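To check my understanding of lines 14-16, here's a minimal sketch of one online-softmax step over a chunk of keys. The chunk_update wrapper, the carry names (m, l, acc), and the shapes are my own guesses, not from the posted code; the point is just to make the v = v * exp(m - m_i) + v_i * exp(s_i - m_i) idea concrete.

```python
import jax
import jax.numpy as jnp

def chunk_update(carry, attn_weights, value_chunk):
    # One online-softmax step over a chunk of keys (my sketch, not the
    # posted code). attn_weights holds the pre-softmax scores for this
    # chunk, shape [query, head, key]; value_chunk is [key, head, dim].
    # The carry holds the running max m, the running normaliser l, and
    # the unnormalised output acc.
    m, l, acc = carry

    # Line 14: per-query max over this chunk's keys (the last axis),
    # then combined with the running max: m_i = max(m, s_i).
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    m_new = jnp.maximum(m, max_score)
    # Line 15: the max is only a numerical-stability shift (softmax is
    # shift-invariant), so its gradient can be safely dropped.
    m_new = jax.lax.stop_gradient(m_new)

    # Line 16: exp(s_i - m_i); every exponent is <= 0, so no overflow.
    exp_weights = jnp.exp(attn_weights - m_new)

    # Rescale the old running sums by exp(m - m_i) and fold in the chunk:
    # v = v * exp(m - m_i) + v_i * exp(s_i - m_i).
    correction = jnp.exp(m - m_new)
    l_new = l * correction + jnp.sum(exp_weights, axis=-1, keepdims=True)
    acc_new = acc * correction + jnp.einsum('qhk,khd->qhd',
                                            exp_weights, value_chunk)
    return m_new, l_new, acc_new
```

With m started at -inf and l, acc at zero, the exp(m - m_new) correction is 0 on the first chunk, so the first step just fills in the accumulators. At the end of the loop the normalised output would be acc / l, but I'm leaving that out since we haven't got to those lines yet.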