Here's the explanation of the first chunk of their example code. In the outer loop (line 56f), we split the queries into chunks of constant size, resulting in a linear number of iterations. In each iteration of the outer loop, we call `_query_chunk_attention`, which itself processes the keys and values in chunks (lines 23 to 33).

```python
42:def attention(query, key, value, precision=jax.lax.Precision.HIGHEST,
43:              query_chunk_size=1024):
44:  """Memory-efficient multi-head dot product attention."""
45:  num_q, num_heads, q_features = query.shape
46:
47:  def chunk_scanner(chunk_idx, _):
48:    query_chunk = lax.dynamic_slice(
49:        query, (chunk_idx, 0, 0),
50:        slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features))
51:    return (chunk_idx + query_chunk_size,
52:            _query_chunk_attention(query_chunk, key, value, precision=precision))
53:
54:  _, res = jax.lax.scan(
55:      chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:  return res.reshape(num_q, num_heads, value.shape[-1])
```

I think the reference to line 56 is a mistake, since the loop itself is the `scan` on lines 54–55, but I'm not sure. `jax.lax.scan` is like folding: it calls `chunk_scanner` repeatedly, passing its first output back in as the carry, and stacking its second output into the result. This is from `help(jax.lax.scan)`:
```python
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
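To see the fold-and-stack behaviour concretely, here is a small runnable sketch using that pure-Python model of `scan` (this is my own toy example, not the real `jax.lax.scan` API; the lambda is made up):

```python
import numpy as np

def scan(f, init, xs, length=None):
    # Pure-Python model of jax.lax.scan's semantics, per its docstring.
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Fold a running counter through 4 steps, stacking each step's output.
carry, ys = scan(lambda c, _: (c + 1, np.array([c])), init=0, xs=None, length=4)
# carry is the final fold state; ys stacks the 4 per-step outputs
# along a new leading dimension, giving shape (4, 1).
```

This mirrors what `attention` does: the carry is the running `chunk_idx`, and `ys` is the stack of per-chunk attention results that line 56 then reshapes.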
So basically lines 54–55 compute `ceil(num_q / query_chunk_size)`, call `chunk_scanner` that many times, and stack the results into one tensor. Stacking combines many tensors into a single tensor with an additional leading dimension, each input adjacent to the next. (A tensor is an n-dimensional array.) Looking at `chunk_scanner`, by inspection I can see that `chunk_idx` is the offset into the queries at which to start slicing. `jax.lax.dynamic_slice` appears to be a simple slice operation: the first set of indices is the lower-indexed corner of the slice, and `slice_sizes` gives the size of the slice along each dimension. The dimensions of `query` are `num_q`, `num_heads`, and `q_features`, so we can infer that the data has already been split into "heads" prior to this call, and that each individual query is 1-dimensional here. The slice operation can then be interpreted as chunking the queries into pieces of `query_chunk_size`, while keeping the keys and values whole and unchunked. Each query chunk is then passed to `_query_chunk_attention`. The use of `jax.lax.scan` and `dynamic_slice` rather than ordinary Python loops and indexing is likely so that the whole loop can be traced and compiled by JAX's JIT as a single compact loop rather than unrolled Python, although I don't really know.
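To check my reading of the corner-plus-sizes convention, here is a NumPy stand-in for `dynamic_slice` applied to a toy `query` tensor (the shapes are made up for illustration, and this sketch omits the index clamping the real `jax.lax.dynamic_slice` performs at array boundaries):

```python
import numpy as np

# Toy stand-ins for (num_q, num_heads, q_features); values are assumptions.
num_q, num_heads, q_features = 8, 2, 4
query = np.arange(num_q * num_heads * q_features, dtype=np.float32).reshape(
    num_q, num_heads, q_features)
query_chunk_size = 3

def dynamic_slice(x, start_indices, slice_sizes):
    # Simplified model of jax.lax.dynamic_slice: take the block whose
    # lower-indexed corner is start_indices and whose extent is slice_sizes.
    return x[tuple(slice(s, s + z) for s, z in zip(start_indices, slice_sizes))]

# The same call shape as lines 48-50, for the first chunk (chunk_idx = 0):
chunk = dynamic_slice(query, (0, 0, 0),
                      (min(query_chunk_size, num_q), num_heads, q_features))
# chunk covers the first query_chunk_size queries, all heads, all features.
```

Slicing only the first dimension is what makes this "chunking the queries": the heads and feature dimensions are always taken whole.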