30:   chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:       chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

chunk_values is exp_values, which I think was the local values dotted with the attention weights after exponentiating them minus their local max. It looks like chunk_weights is those exponentiated weights, summed over each chunk's keys. And it looks like chunk_max holds those local maximum values.

32:
33:   global_max = jnp.max(chunk_max, axis=0, keepdims=True)

[Struggling some to continue. It looks like these tensors get stacked back into single arrays by jax.lax.map, with a new leading chunk axis. chunk_max has its max taken along axis 0, which is likely that added axis. So this would combine the per-chunk maxima, finding the overall maximum across the split keys and values.]

34:   max_diffs = jnp.exp(chunk_max - global_max)

This likely calculates the scale needed to make each chunk relative to the global max, rather than its local max. A scale, because it's inside an exponent: the additive shift m_i - global_max becomes a multiplicative factor once exponentiated.

35:   chunk_values *= jnp.expand_dims(max_diffs, axis=-1)

jnp.expand_dims appears to wrap each value in max_diffs in a new size-one dimension, presumably so the multiplication broadcasts across the feature dimension of the values. So line 35 multiplies all the output values by the calculated scale. This looks like it turns exp(s_i - m_i) into exp(s_i - m_i + m_i - global_max), i.e. exp(s_i - global_max).

36:   chunk_weights *= max_diffs

This likely performs the same rescaling on the weights, which are not yet dotted with the value vectors.

37:
38:   all_values = chunk_values.sum(axis=0)

Here we finally sum the rescaled values across the key and value chunks. Each chunk was already dotted across its own keys, so this combines adjacent chunks as if one large dot product had been taken.

39:   all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

This appears to do the same summation for the weights, which I think will form the softmax denominator. The added size-one dimension lets them broadcast against all_values in the division below.

40:   return all_values / all_weights

Here is where the softmax operation must finally complete. Everything in all_values and all_weights has been shifted inside exp() to be relative to the global max. That shift is the same factor exp(-global_max) in both the numerator and the denominator, so it cancels in the division: the result is the same as an unshifted softmax, but without the overflow that extreme exponents would cause. I think! Whoohoo!

Let's quickly glance at where that value goes after _query_chunk_attention returns it. On line 52 it's returned through chunk_scanner, and then must be stacked with the other query chunks back on line 55:

54:   _, res = jax.lax.scan(
55:       chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:   return res.reshape(num_q, num_heads, value.shape[-1])

jax.lax.scan stacks the per-query-chunk results along a new leading axis, and the final call to reshape likely squashes that stack into a single [queries, heads, features] output tensor. Maybe rather than connecting more verbiage from the paper, it would make sense to try to run this now, and compare an implementation with and without chunking, to verify it's correct.
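Here's a minimal sketch of that comparison, with everything not in the paper's listing flagged as my own: reference_attention and chunked_attention are hypothetical names, the shapes are [length, heads, features], and the chunked version keeps only the key/value chunking discussed above (no query chunking, no checkpointing, and it assumes num_kv divides evenly by key_chunk_size).

import jax
import jax.numpy as jnp

def reference_attention(query, key, value):
  # Plain attention: materializes the full [num_q, num_heads, num_kv]
  # score matrix and softmaxes it in one shot.
  query = query / jnp.sqrt(query.shape[-1])
  scores = jnp.einsum('qhd,khd->qhk', query, key)
  weights = jax.nn.softmax(scores, axis=-1)
  return jnp.einsum('qhk,khd->qhd', weights, value)

def chunked_attention(query, key, value, key_chunk_size=128):
  # Chunked attention mirroring the recombination walked through above:
  # per-chunk max subtraction, rescaling to the global max, final division.
  # Assumes num_kv is an exact multiple of key_chunk_size.
  query = query / jnp.sqrt(query.shape[-1])
  num_kv, num_heads, k_features = key.shape
  v_features = value.shape[-1]

  def chunk_scanner(chunk_idx):
    key_chunk = jax.lax.dynamic_slice(
        key, (chunk_idx, 0, 0), (key_chunk_size, num_heads, k_features))
    value_chunk = jax.lax.dynamic_slice(
        value, (chunk_idx, 0, 0), (key_chunk_size, num_heads, v_features))
    attn_weights = jnp.einsum('qhd,khd->qhk', query, key_chunk)
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    exp_weights = jnp.exp(attn_weights - max_score)  # exp(s_i - m_i)
    exp_values = jnp.einsum('qhk,khd->qhd', exp_weights, value_chunk)
    return exp_values, exp_weights.sum(axis=-1), max_score.squeeze(-1)

  chunk_values, chunk_weights, chunk_max = jax.lax.map(
      chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

  # Rescale each chunk from its local max to the global max, then combine.
  global_max = jnp.max(chunk_max, axis=0, keepdims=True)
  max_diffs = jnp.exp(chunk_max - global_max)  # exp(m_i - global_max)
  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
  chunk_weights *= max_diffs
  all_values = chunk_values.sum(axis=0)
  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
  return all_values / all_weights

q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_rng, (256, 4, 64))   # [num_q, num_heads, features]
k = jax.random.normal(k_rng, (1024, 4, 64))  # [num_kv, num_heads, features]
v = jax.random.normal(v_rng, (1024, 4, 64))

expected = reference_attention(q, k, v)
actual = chunked_attention(q, k, v, key_chunk_size=128)
print(jnp.abs(expected - actual).max())  # expect ~1e-6 (float32 rounding)

If the recombination above is right, the two outputs should agree to float32 rounding, so a max absolute difference around 1e-6 on inputs like these would be the expected outcome.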