17: exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

The values are still [values, heads, features], while exp_weights is [queries, heads, keys]. The einsum multiplies the values by the per-key weights and sums over the key dimension, a contraction aligning the key dimension of exp_weights with the leading dimension of value, producing a vector of weighted features for each query and each head. This must be the "v_i * exp(s_i - m_i)" term in "v = v * exp(m - m_i) + v_i * exp(s_i - m_i)" from section 3 (numerical stability) in the paper.

18: return (exp_values, exp_weights.sum(axis=-1),
19:         max_score.reshape((query.shape[0], num_heads)))

A tuple of three values is returned from summarize_chunk:
- The first element, exp_values, is the attention output combining the values with the keys and query, before dividing by the softmax denominator. This is the "v*" of the numerical-stability portion of the paper.
- The second element, exp_weights.sum(axis=-1), is the sum of the exponentiated scores along the key dimension, one sum for each query and each head, of shape [queries, heads]. This is kept so the softmax denominator can be reassembled later.
- The third element reshapes max_score, which was the maximum taken across the keys of the unexponentiated attention scores, i.e. the dot products of the query features and the key features. The reshape just removes the trailing dummy dimension that was added for the broadcast subtraction.
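To sanity-check this reading, here is a minimal sketch (the shapes, array names, and single-chunk setup are my own, and the precision argument is dropped) showing that dividing exp_values by the summed exp_weights reproduces ordinary softmax attention when everything fits in one chunk:

```python
import jax
import jax.numpy as jnp

# Made-up shapes: 4 queries, 6 keys/values, 2 heads, 8 features per head.
q = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (6, 2, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (6, 2, 8))

# Re-derive the three summarize_chunk outputs for a single chunk.
attn_weights = jnp.einsum('qhd,khd->qhk', q, k)            # raw scores s, [queries, heads, keys]
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)  # per-query, per-head max m
exp_weights = jnp.exp(attn_weights - max_score)            # exp(s - m), [queries, heads, keys]
exp_values = jnp.einsum('vhf,qhv->qhf', v, exp_weights)    # sum_i v_i * exp(s_i - m), [queries, heads, features]

# With only one chunk, dividing by the summed weights is exactly softmax attention.
chunked = exp_values / exp_weights.sum(axis=-1, keepdims=True)
reference = jnp.einsum('qhk,khf->qhf', jax.nn.softmax(attn_weights, axis=-1), v)
print(jnp.allclose(chunked, reference, atol=1e-5))  # True
```

Now let's go back to where this function was called and look at how these return values are used.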