The below is a draft; I see I propagated some dimensions wrongly, and I plan to rewrite it, and maybe some previous posts, to help with clarity.

17: exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

Another einsum, multiplying the values by the weights.

[managed to continue after writing this :) -> It might take me a bit to engage with this einsum. Hope to return quickly and just combine the available parts. For example, I could copy a definition of einsum in, or compare the line to the equation from the paper, or look at the paper's description of this chunk of code.]

This must be the rest of v = v * exp(m - m_i) + v_i * exp(s_i - m_i), specifically the "v_i *" part, since exp_weights is exp(s_i - m_i) already. Since m is initialised to -inf, the left term of the sum is zero on the first chunk and is disregarded.

In jnp.einsum('vhf,qhv->qhf', value, exp_weights) we can see how this multiplication is performed. exp_weights' dimensions are labelled query, heads, value here. Its shape propagates from attn_weights rather than from max_score: subtracting the [queries, heads, 1] max_score broadcasts it across attn_weights' key axis, so exp_weights is [queries, heads, keys], the exponentiation of the scores minus their per-query max among keys. (It also looks like I unfortunately miswrote 'query' as 'key' in max_score's dimensions in the last email, as can be verified from the einsum that attn_weights came from.) The key axis can be labelled 'v' because each key position lines up with one row of value, which is [values, heads, features].

So the einsum isn't broadcasting each exp_weight to all the values; the shared 'v' index is contracted away. For each query and head, exp_values[q, h, :] is the sum over v of exp_weights[q, h, v] * value[v, h, :], a weighted sum of the value rows. To discern it further I can run test data (there's a small sketch of that at the end of this post) or write the parts of the definition down and combine them textually.

exp_values becomes a [queries, heads, features] tensor containing the values weighted and summed by the exponentiated attn weights.

18: return (exp_values, exp_weights.sum(axis=-1),
19:         max_score.reshape((query.shape[0], num_heads)))

The output is three tensors: exp_values; exp_weights summed over its key axis, a [queries, heads] tensor holding, for each query, the sum of the exponentiated scores minus their max score; and max_score itself, reshaped to [queries, heads].
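Here's the kind of test-data check I had in mind for the einsum. This is my own toy sketch, not code from the paper; the sizes (2 queries, 1 head, 3 keys/values, 4 features) are made up:

import jax.numpy as jnp
import numpy as np

q, h, v, f = 2, 1, 3, 4  # queries, heads, keys/values, features (made up)
value = jnp.array(np.random.rand(v, h, f))
exp_weights = jnp.array(np.random.rand(q, h, v))

exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights)

# The same computation as an explicit sum: for each query and head,
# a weighted sum over the v value rows, weighted by exp_weights.
manual = jnp.stack([
    jnp.stack([
        sum(exp_weights[qi, hi, vi] * value[vi, hi] for vi in range(v))
        for hi in range(h)])
    for qi in range(q)])

assert exp_values.shape == (q, h, f)
assert jnp.allclose(exp_values, manual)

The contraction over 'v' is what leaves exp_values as [queries, heads, features], with no key axis remaining.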
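And since I keep deferring the "combine the available parts" step: below is a sketch of how I currently read the paper's rescale-and-sum bookkeeping across chunks. The summarize function here is my own toy stand-in that takes precomputed scores instead of query and key, and the chunk sizes and variable names are assumptions, not the actual code:

import jax
import jax.numpy as jnp
import numpy as np

q, h, v, f = 2, 1, 3, 4  # made-up sizes; v is the per-chunk key count

def summarize(scores, value):
    # scores: [q, h, v] raw attention scores for one key chunk
    max_score = jnp.max(scores, axis=-1, keepdims=True)      # [q, h, 1]
    exp_weights = jnp.exp(scores - max_score)                # [q, h, v]
    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights)
    return exp_values, exp_weights.sum(axis=-1), max_score[..., 0]

scores = jnp.array(np.random.rand(q, h, 2 * v))
value = jnp.array(np.random.rand(2 * v, h, f))

# Summarise two key chunks independently...
v1, w1, m1 = summarize(scores[..., :v], value[:v])
v2, w2, m2 = summarize(scores[..., v:], value[v:])

# ...then rescale each chunk's results to the shared max before summing,
# which is the v = v * exp(m - m_i) + v_i * exp(s_i - m_i) bookkeeping.
m = jnp.maximum(m1, m2)                                      # [q, h]
a1 = jnp.exp(m1 - m)[..., None]                              # [q, h, 1]
a2 = jnp.exp(m2 - m)[..., None]
out = (v1 * a1 + v2 * a2) / (w1 * a1[..., 0] + w2 * a2[..., 0])[..., None]

# Should agree with ordinary softmax attention over the full key range.
expected = jnp.einsum('vhf,qhv->qhf', value, jax.nn.softmax(scores, axis=-1))
assert jnp.allclose(out, expected, atol=1e-5)

If that assertion holds, the three returned parts really are enough to reconstruct exact attention chunk by chunk, which I take to be the point of returning them.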