[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 03:23:40 PST 2022


22:    key_chunk = jax.lax.dynamic_slice(
23:      key, (chunk_idx, 0, 0),
24:      slice_sizes=(key_chunk_size, num_heads, k_features))

Note also on line 31 that a step size is passed into jnp.arange(), so
chunk_idx takes values at offsets separated by key_chunk_size.  Lines
22-24 break key into key_chunk_size-sized chunks along the first
dimension.  The second and third dimensions (heads and features) remain
full.
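
As a minimal sketch of that indexing (the sizes below are made up, and
the surrounding loop over chunk_idx isn't quoted in this message):

  import jax
  import jax.numpy as jnp

  num_kv, num_heads, k_features = 8, 2, 4  # hypothetical sizes
  key_chunk_size = 2
  key = jnp.zeros((num_kv, num_heads, k_features))

  # One starting offset per chunk: [0, 2, 4, 6].
  chunk_offsets = jnp.arange(0, num_kv, key_chunk_size)

  # Slicing at one of those offsets yields one chunk; the heads and
  # features dimensions stay full-width.
  key_chunk = jax.lax.dynamic_slice(
    key, (int(chunk_offsets[1]), 0, 0),
    slice_sizes=(key_chunk_size, num_heads, k_features))
  assert key_chunk.shape == (key_chunk_size, num_heads, k_features)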

25:    value_chunk = jax.lax.dynamic_slice(
26:      value, (chunk_idx, 0, 0),
27:      slice_sizes=(key_chunk_size, num_heads, v_features))

Lines 25-27 do the same thing with value_chunk, value, and v_features,
breaking the value tensor into key_chunk_size-sized chunks along the
first dimension.

28:    return summarize_chunk(query, key_chunk, value_chunk)

summarize_chunk must do the actual attention calculation for the chunk.

I think this code would be clearer if query were called query_chunk,
which is what is passed in here.
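
Following that suggested naming, here is a minimal sketch of what a
summarize_chunk of this shape could compute.  This is my assumption
about the unquoted definition, not the actual code; the einsum
subscripts and the returned triple are guesses at the standard
numerically stable pattern:

  import jax
  import jax.numpy as jnp

  def summarize_chunk(query_chunk, key_chunk, value_chunk):
    # Attention scores between each query position and this chunk's
    # keys: (q, heads, d) x (k, heads, d) -> (q, heads, k).
    attn_weights = jnp.einsum('qhd,khd->qhk', query_chunk, key_chunk)
    # Subtract the per-query max before exponentiating so the
    # exponentials stay numerically stable.
    max_score = jax.lax.stop_gradient(
      jnp.max(attn_weights, axis=-1, keepdims=True))
    exp_weights = jnp.exp(attn_weights - max_score)
    # Unnormalized weighted sum of this chunk's values.
    exp_values = jnp.einsum('khf,qhk->qhf', value_chunk, exp_weights)
    # Return the partial numerator, partial denominator, and the max,
    # so the caller can renormalize when combining chunks.
    return exp_values, exp_weights.sum(axis=-1), jnp.squeeze(max_score, -1)

To combine chunks exactly, the caller would rescale each chunk's
numerator and denominator by exp(chunk_max - global_max) before
summing, then divide, which recovers the usual softmax.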

