Drill down, dive in. Selection appears to have stopped working for me, so I'm typing all the text over by hand. Happens sometimes.

04:def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):

As earlier, the query here is a chunk of queries, a subset of all of them.

05:  """Multi-head dot product attention with a limited number of queries."""
06:  num_kv, num_heads, k_features = key.shape
07:  v_features = value.shape[-1]

I commented here on the meaning of the dimensions, and then my body and browser spazzed and marked the thread as spam. When I recovered the thread the comments were gone, so I'm sending this to preserve it.

12:@functools.partial(jax.checkpoint, prevent_cse=False)

This decorator passes the function defined right after it through jax.checkpoint once it has been defined, with prevent_cse=False bound as a keyword argument. jax.checkpoint tells jax's backpropagation code not to keep the function's intermediate data cached in memory for the backward pass, and instead to recalculate it when it is needed. This is mentioned elsewhere in the paper.
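
To make the decorator on line 12 concrete, here is a minimal sketch of what it does. The function summarize_chunk_plain, the toy shapes, and the loss are my own stand-ins for illustration, not the paper's code. The sketch only checks that wrapping a function in jax.checkpoint leaves the gradients unchanged; the difference is in what gets stored for the backward pass.

```python
import functools

import jax
import jax.numpy as jnp


def summarize_chunk_plain(query, key, value):
    """Softmax attention over one key/value chunk (hypothetical stand-in).

    Shapes: query [num_q, num_heads, k_features],
            key   [num_kv, num_heads, k_features],
            value [num_kv, num_heads, v_features].
    """
    # attn_weights: [num_q, num_heads, num_kv]
    attn_weights = jnp.einsum("qhd,khd->qhk", query, key)
    # Subtract the per-row max before exponentiating, for numerical stability.
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    exp_weights = jnp.exp(attn_weights - max_score)
    # exp_values: [num_q, num_heads, v_features]
    exp_values = jnp.einsum("qhk,khv->qhv", exp_weights, value)
    return exp_values / exp_weights.sum(axis=-1)[..., None]


# Same effect as writing
#   @functools.partial(jax.checkpoint, prevent_cse=False)
# directly above the function definition.
summarize_chunk_remat = functools.partial(jax.checkpoint, prevent_cse=False)(
    summarize_chunk_plain
)


def loss(fn, query, key, value):
    # Arbitrary scalar loss, only here so there is something to differentiate.
    return jnp.sum(fn(query, key, value) ** 2)


q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (8, 2, 4))    # num_q=8, num_heads=2, k_features=4
key = jax.random.normal(k_rng, (16, 2, 4))     # num_kv=16
value = jax.random.normal(v_rng, (16, 2, 5))   # v_features=5

out = summarize_chunk_remat(query, key, value)
grad_plain = jax.grad(lambda q: loss(summarize_chunk_plain, q, key, value))(query)
grad_remat = jax.grad(lambda q: loss(summarize_chunk_remat, q, key, value))(query)
```

The two gradients should agree up to floating-point noise. With the checkpointed version, the intermediates (attn_weights, exp_weights) are recomputed during the backward pass instead of being held in memory.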