22:     key_chunk = jax.lax.dynamic_slice(
23:         key, (chunk_idx, 0, 0),
24:         slice_sizes=(key_chunk_size, num_heads, k_features))

Note also, on line 31, that a step size is passed into jnp.arange(), so chunk_idx takes values at offsets separated by key_chunk_size. Lines 22-24 break key into key_chunk_size-sized chunks along the first dimension; the second and third dimensions (heads and features) remain full.

25:     value_chunk = jax.lax.dynamic_slice(
26:         value, (chunk_idx, 0, 0),
27:         slice_sizes=(key_chunk_size, num_heads, v_features))

Lines 25-27 do the same thing for value: value_chunk is a key_chunk_size-sized slice of the value tensor along the first dimension, with v_features in place of k_features.

28:     return summarize_chunk(query, key_chunk, value_chunk)

summarize_chunk does the actual attention calculation for the chunk. I think this code would be clearer if query were called query_chunk, since that is what is passed in here.
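To see the slicing pattern in isolation, here is a minimal, self-contained sketch of how dynamic_slice carves a (sequence, heads, features) tensor into chunks along the first axis. The shapes and names (num_kv, key_chunk_size, etc.) are made up for illustration; the chunk starts come from jnp.arange with a step of key_chunk_size, just as on line 31:

```python
import jax
import jax.numpy as jnp

# Hypothetical small shapes, chosen only to illustrate the pattern.
num_kv, num_heads, k_features = 8, 2, 4
key_chunk_size = 4

key = jnp.arange(num_kv * num_heads * k_features, dtype=jnp.float32).reshape(
    num_kv, num_heads, k_features)

# As on line 31: chunk offsets separated by key_chunk_size -> [0, 4].
chunk_starts = jnp.arange(0, num_kv, key_chunk_size)

def get_key_chunk(chunk_idx):
    # Slice key_chunk_size rows starting at chunk_idx; heads and
    # features stay full, mirroring lines 22-24.
    return jax.lax.dynamic_slice(
        key, (chunk_idx, 0, 0),
        slice_sizes=(key_chunk_size, num_heads, k_features))

chunks = [get_key_chunk(i) for i in chunk_starts]
assert all(c.shape == (key_chunk_size, num_heads, k_features) for c in chunks)
# Concatenating the chunks along axis 0 reconstructs the original tensor.
assert jnp.array_equal(jnp.concatenate(chunks, axis=0), key)
```

The value tensor is sliced the same way, so each iteration sees one aligned (key_chunk, value_chunk) pair while the query stays whole.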