[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 02:26:04 PST 2022


Here's the explanation of the first chunk of their example code.

In the outer loop (line 56f), we split the queries into chunks of
constant size, resulting in a linear number of iterations. In each
iteration of the outer loop, we call _query_chunk_attention, which
itself processes the keys and values in chunks (lines 23 to 33).
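Their lines 23 to 33 aren't quoted here, so what follows is only my rough NumPy sketch of the idea, not their code. Each key/value chunk keeps its own max score, its summed exponentiated scores, and its exp-weighted values. At the end the chunks are rescaled to a shared max and merged, which gives exactly the same answer as ordinary softmax attention. The function names are made up, and I've left out the usual 1/sqrt(q_features) scaling of the queries for brevity.

```python
import numpy as np

def reference_attention(q, k, v):
    # Plain softmax attention.
    # q: (num_q, heads, d), k: (num_kv, heads, d), v: (num_kv, heads, dv)
    s = np.einsum("qhd,khd->qhk", q, k)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("qhk,khf->qhf", w, v)

def chunked_kv_attention(q, k, v, key_chunk_size):
    # Hypothetical sketch: walk the keys/values in chunks, keeping
    # per-chunk partial sums, then merge them using a shared max.
    values, weights, maxes = [], [], []
    for start in range(0, k.shape[0], key_chunk_size):
        kc = k[start:start + key_chunk_size]
        vc = v[start:start + key_chunk_size]
        s = np.einsum("qhd,khd->qhk", q, kc)
        m = s.max(axis=-1)                      # (num_q, heads)
        e = np.exp(s - m[..., None])            # (num_q, heads, chunk)
        values.append(np.einsum("qhk,khf->qhf", e, vc))
        weights.append(e.sum(axis=-1))
        maxes.append(m)
    maxes = np.stack(maxes)                     # (num_chunks, num_q, heads)
    global_max = maxes.max(axis=0)
    scale = np.exp(maxes - global_max)          # rescale each chunk to the shared max
    num = (np.stack(values) * scale[..., None]).sum(axis=0)
    den = (np.stack(weights) * scale).sum(axis=0)
    return num / den[..., None]

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 2, 4))    # (num_q, num_heads, q_features)
k = rng.standard_normal((10, 2, 4))
v = rng.standard_normal((10, 2, 3))
out = chunked_kv_attention(q, k, v, key_chunk_size=4)
print(np.allclose(out, reference_attention(q, k, v)))  # True
```

The rescaling step is what makes chunking safe: each chunk subtracts its own max for numerical stability, and multiplying by exp(chunk_max - global_max) puts every chunk back on the same footing before summing.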

42:def attention(query, key, value, precision=jax.lax.Precision.HIGHEST,
43:        query_chunk_size=1024):
44:  """Memory-efficient multi-head dot product attention."""
45:  num_q, num_heads, q_features = query.shape
46:
47:  def chunk_scanner(chunk_idx, _):
48:    query_chunk = lax.dynamic_slice(
49:      query, (chunk_idx, 0, 0),
50:      slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features))
51:    return (chunk_idx + query_chunk_size,
52:      _query_chunk_attention(query_chunk, key, value, precision=precision))
53:
54:  _, res = jax.lax.scan(
55:    chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:  return res.reshape(num_q, num_heads, value.shape[-1])

I think they miswrote the line reference as 56, but I'm unsure; the
outer loop looks like the jax.lax.scan call on lines 54-55.
jax.lax.scan is like a fold: it calls chunk_scanner repeatedly,
passing its first output (the carry) back in as the next first
argument, and stacking its second outputs into the result.  This is
from help(jax.lax.scan):
>  def scan(f, init, xs, length=None):
>        if xs is None:
>          xs = [None] * length
>        carry = init
>        ys = []
>        for x in xs:
>          carry, y = f(carry, x)
>          ys.append(y)
>        return carry, np.stack(ys)
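To see what chunk_idx actually is on each call, here's that same reference loop run in plain Python with a stand-in chunk_scanner. It returns a list instead of a stacked array, and instead of doing attention it just reports the offset it was handed. The sizes are made up.

```python
import math

def scan(f, init, xs, length=None):
    # Plain-Python restatement of the reference semantics quoted above.
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

num_q, query_chunk_size = 12, 4

def chunk_scanner(chunk_idx, _):
    # Hypothetical stand-in: advance the offset (the carry) and emit the
    # offset this call received instead of an attention result.
    return chunk_idx + query_chunk_size, chunk_idx

final, offsets = scan(chunk_scanner, init=0, xs=None,
                      length=math.ceil(num_q / query_chunk_size))
print(offsets)  # [0, 4, 8]
print(final)    # 12
```

So the carry is just a running query offset that grows by query_chunk_size each step, and xs=None means there's no per-step input at all.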

So basically lines 54-55 compute ceil(num_q / query_chunk_size), call
chunk_scanner that many times, and stack the results into one tensor.
Stacking combines many tensors of the same shape into a single one
with an additional leading dimension, along which they sit one after
another.  A tensor is an n-dimensional array.
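A small NumPy illustration of the stacking, and of what the reshape on line 56 then does to it. The sizes are arbitrary, and each fake chunk output is filled with its chunk number so the order is visible.

```python
import numpy as np

num_chunks, query_chunk_size, num_heads, v_features = 3, 4, 2, 5

# One fake output per chunk, each (query_chunk_size, num_heads, v_features).
chunks = [np.full((query_chunk_size, num_heads, v_features), i)
          for i in range(num_chunks)]
res = np.stack(chunks)           # new leading axis indexes the chunk
print(res.shape)                 # (3, 4, 2, 5)

# Like line 56: merge the chunk axis with the within-chunk query axis.
flat = res.reshape(num_chunks * query_chunk_size, num_heads, v_features)
print(flat.shape)                # (12, 2, 5)
print(flat[:, 0, 0])             # [0 0 0 0 1 1 1 1 2 2 2 2]
```

Because reshape reads in row-major order, chunk 0's queries come first, then chunk 1's, and so on, so the flattened rows line up with the original query order.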

Looking at chunk_scanner, by inspection I can see that chunk_idx is
the offset into the queries where each chunk starts.
jax.lax.dynamic_slice appears to be a simple slice operation: the
first tuple of indices is the lower-indexed corner of the slice, and
the second (slice_sizes) is the size of the slice along each
dimension.
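Here's a NumPy imitation of that slice. dynamic_slice_np is a made-up name, not a JAX function. One detail beyond a plain slice: lax.dynamic_slice clamps the start indices so the requested block always fits inside the array, and the sketch copies that behaviour.

```python
import numpy as np

def dynamic_slice_np(operand, start_indices, slice_sizes):
    # Hypothetical NumPy stand-in for lax.dynamic_slice.
    slices = []
    for start, size, dim in zip(start_indices, slice_sizes, operand.shape):
        start = min(max(start, 0), dim - size)  # clamp so the block fits
        slices.append(slice(start, start + size))
    return operand[tuple(slices)]

query = np.arange(10 * 2 * 3).reshape(10, 2, 3)  # (num_q, num_heads, q_features)

chunk = dynamic_slice_np(query, (4, 0, 0), (4, 2, 3))
print(chunk.shape)                          # (4, 2, 3)
print(np.array_equal(chunk, query[4:8]))    # True

# Start 8 would run past the end, so it is clamped to 10 - 4 = 6.
tail = dynamic_slice_np(query, (8, 0, 0), (4, 2, 3))
print(np.array_equal(tail, query[6:10]))    # True
```

The clamping suggests the quoted code expects num_q to be a multiple of query_chunk_size (or no larger than it). With num_q = 10 and chunks of 4, scan would produce 3 chunks of 4 rows (12 rows total, the last one overlapping), and the reshape to num_q rows on line 56 would fail.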

I'm noting the dimensions of the query are num_q, num_heads, and
q_features.  So we can infer that the data has already been split
into "heads" before this call.  It also looks like each individual
query, for a given head, is a 1-dimensional vector of q_features
values.
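As a guess at how the data might have been split into heads before this call, here's one common way, assuming a projection produces all heads side by side in each row (num_heads * q_features columns). This step isn't shown in their code. The example shows that indexing one query and one head leaves a 1-D vector.

```python
import numpy as np

num_q, num_heads, q_features = 6, 2, 3

# Hypothetical projected queries: each row holds every head's features in turn.
projected = np.arange(num_q * num_heads * q_features).reshape(num_q, num_heads * q_features)
query = projected.reshape(num_q, num_heads, q_features)

print(query.shape)        # (6, 2, 3)
print(query[0, 1])        # [3 4 5]
print(query[0, 1].ndim)   # 1
```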

The slice operation can then be interpreted as simply chunking the
queries into groups of query_chunk_size, while keeping the keys and
values whole.  Each query chunk is then passed to
_query_chunk_attention.
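Chunking only the queries can't change the result, because each query's softmax runs over the keys alone and never looks at the other queries. A quick NumPy check of that, with made-up helper names. This version only covers the outer loop, with keys and values left whole.

```python
import numpy as np

def attention_rows(q, k, v):
    # Plain softmax attention; q: (n, heads, d), k: (m, heads, d), v: (m, heads, dv).
    s = np.einsum("qhd,khd->qhk", q, k)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("qhk,khf->qhf", w, v)

def query_chunked(q, k, v, query_chunk_size):
    # Outer loop only: every query chunk attends to all keys and values.
    outs = [attention_rows(q[i:i + query_chunk_size], k, v)
            for i in range(0, q.shape[0], query_chunk_size)]
    return np.concatenate(outs, axis=0)

rng = np.random.default_rng(1)
q = rng.standard_normal((12, 2, 4))
k = rng.standard_normal((7, 2, 4))
v = rng.standard_normal((7, 2, 3))
print(np.allclose(query_chunked(q, k, v, 4), attention_rows(q, k, v)))  # True
```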

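The use of jax.lax.scan and dynamic_slice rather than a plain Python
loop and ordinary slicing is probably so that JAX compiles the loop
once instead of unrolling it, although I don't really know.  Note
that scan runs its iterations one after another, not in parallel;
handling one chunk at a time is what keeps memory down.  Inside scan,
chunk_idx is a traced value rather than a Python int, and
dynamic_slice accepts traced start indices where an ordinary slice
would need static ones.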


More information about the cypherpunks mailing list