And here's their explanation of the inner loops:

In each iteration of the outer loop, we call _query_chunk_attention, which itself processes the keys and values in chunks (lines 23 to 33). The chunks are processed sequentially and each chunk is summarized independently (lines 14 to 21). Assuming a chunk size of sqrt(n) for the keys and values, we hence obtain sqrt(n) summaries, giving rise to the O(sqrt(n)) memory complexity. After the summaries are computed, they need to be rescaled (lines 35 to 38) along the lines of Section 3, before we return the values divided by the weights (line 42). The result of each iteration of the outer loop is directly written to the output tensor res (line 56), so that no additional memory is consumed across iterations. (A multi-stage summarization approach could achieve O(log n) but would complicate the implementation.) [... In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.]

01: import functools, jax, math
02: from jax import numpy as jnp
03:
04: def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):
05:   """Multi-head dot product attention with a limited number of queries."""
06:   num_kv, num_heads, k_features = key.shape
07:   v_features = value.shape[-1]
08:   key_chunk_size = min(key_chunk_size, num_kv)
09:   query = query / jnp.sqrt(k_features)
10:
11:   @functools.partial(jax.checkpoint, prevent_cse=False)
12:   def summarize_chunk(query, key, value):
13:     attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
14:     max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
15:     max_score = jax.lax.stop_gradient(max_score)
16:     exp_weights = jnp.exp(attn_weights - max_score)
17:     exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)
18:     return (exp_values, exp_weights.sum(axis=-1),
19:             max_score.reshape((query.shape[0], num_heads)))
20:
21:   def chunk_scanner(chunk_idx):
22:     key_chunk = jax.lax.dynamic_slice(
23:         key, (chunk_idx, 0, 0),
24:         slice_sizes=(key_chunk_size, num_heads, k_features))
25:     value_chunk = jax.lax.dynamic_slice(
26:         value, (chunk_idx, 0, 0),
27:         slice_sizes=(key_chunk_size, num_heads, v_features))
28:     return summarize_chunk(query, key_chunk, value_chunk)
29:
30:   chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:       chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
32:
33:   global_max = jnp.max(chunk_max, axis=0, keepdims=True)
34:   max_diffs = jnp.exp(chunk_max - global_max)
35:   chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
36:   chunk_weights *= max_diffs
37:
38:   all_values = chunk_values.sum(axis=0)
39:   all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
40:   return all_values / all_weights

It looks like the line numbers got offset a little in printing: in the listing above, the return of the values divided by the weights lands on line 40 rather than line 42, and line 56 (the write into the output tensor res) belongs to the outer loop over query chunks, which isn't reproduced here.
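To make the quoted explanation of the outer loop concrete, here is a rough sketch of what that loop could look like, assuming _query_chunk_attention from the listing above is in scope. This is my reconstruction from the description (scan over query chunks, each result written directly into the stacked output), not the paper's exact Figure 1 code; the function name attention, the query_chunk_size default, and the divisibility assert are my own choices.

import math
import jax
from jax import numpy as jnp

def attention(query, key, value, precision=jax.lax.Precision.HIGHEST,
              query_chunk_size=1024):
    """Sketch of the outer loop: scan over query chunks so that only one
    chunk's attention intermediates are materialized at a time."""
    num_q, num_heads, q_features = query.shape
    query_chunk_size = min(query_chunk_size, num_q)
    # Assumption for this sketch: the queries divide evenly into chunks.
    assert num_q % query_chunk_size == 0

    def chunk_scanner(chunk_idx, _):
        # Slice out the next block of queries.
        query_chunk = jax.lax.dynamic_slice(
            query, (chunk_idx, 0, 0),
            slice_sizes=(query_chunk_size, num_heads, q_features))
        # Each chunk's result becomes one slice of the stacked scan output,
        # so no extra per-iteration buffers accumulate.
        return (chunk_idx + query_chunk_size,
                _query_chunk_attention(query_chunk, key, value,
                                       precision=precision))

    _, res = jax.lax.scan(chunk_scanner, init=0, xs=None,
                          length=num_q // query_chunk_size)
    return res.reshape(num_q, num_heads, value.shape[-1])

A quick sanity check against a naive O(n^2)-memory reference (shapes chosen arbitrarily for illustration):

q = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 2, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 2, 16))

logits = jnp.einsum('qhd,khd->qhk', q, k) / jnp.sqrt(16)
reference = jnp.einsum('qhk,khf->qhf', jax.nn.softmax(logits, axis=-1), v)

print(jnp.max(jnp.abs(attention(q, k, v, query_chunk_size=4) - reference)))
# Should print a value on the order of float32 rounding error.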