[ot][spam][crazy][data] transformer model 'attention' improvement
k
gmkarl at gmail.com
Tue Jan 25 03:23:40 PST 2022
22: key_chunk = jax.lax.dynamic_slice(
23: key, (chunk_idx, 0, 0),
24: slice_sizes=(key_chunk_size, num_heads, k_features))
Note also that on line 31 a step size is passed into jnp.arange(), so
the chunk_idx values are offsets separated by key_chunk_size. Lines
22-24 break key into key_chunk_size'd chunks along the first
dimension; the second and third dimensions (heads and features) remain
full.
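To make that concrete, here is a minimal NumPy sketch of the same logic,
with small made-up shapes (the names num_kv, num_heads, k_features and the
sizes are illustrative assumptions, not taken from the original code):

```python
import numpy as np

# Hypothetical shapes, for illustration only.
num_kv, num_heads, k_features = 8, 2, 4
key_chunk_size = 4
key = np.arange(num_kv * num_heads * k_features, dtype=np.float32).reshape(
    num_kv, num_heads, k_features)

# Mirrors the stepped jnp.arange() on line 31: offsets separated by
# key_chunk_size along the first dimension.
chunk_offsets = np.arange(0, num_kv, key_chunk_size)

# NumPy stand-in for the jax.lax.dynamic_slice on lines 22-24: each chunk
# keeps the full heads and features dimensions.
for chunk_idx in chunk_offsets:
    key_chunk = key[chunk_idx:chunk_idx + key_chunk_size, :, :]
    assert key_chunk.shape == (key_chunk_size, num_heads, k_features)
```

With these shapes the offsets come out as [0, 4], i.e. two chunks that
together cover the first dimension exactly.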
25: value_chunk = jax.lax.dynamic_slice(
26: value, (chunk_idx, 0, 0),
27: slice_sizes=(key_chunk_size, num_heads, v_features))
Lines 25-27 do the same thing for value: the value tensor is broken
into key_chunk_size'd chunks along the first dimension, with
v_features in place of k_features.
28: return summarize_chunk(query, key_chunk, value_chunk)
summarize_chunk must do the actual attention calculation for the chunk.
I think this code would be clearer if query were renamed query_chunk,
since that is what is actually passed in here.
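For intuition, here is a hedged NumPy sketch of what a summarize_chunk
along these lines might compute: scaled-dot-product scores against just
this key chunk, with the per-query max factored out so partial results
from different chunks can later be combined stably. The function name,
shapes, and return convention are my assumptions, not the original code:

```python
import numpy as np

def summarize_chunk_sketch(query, key_chunk, value_chunk):
    # query:       (num_q, num_heads, k_features)
    # key_chunk:   (chunk, num_heads, k_features)
    # value_chunk: (chunk, num_heads, v_features)
    # Dot-product scores between every query and this chunk's keys.
    attn_weights = np.einsum('qhd,khd->qhk', query, key_chunk)
    # Subtract the per-query max before exponentiating, for stability.
    max_score = attn_weights.max(axis=-1, keepdims=True)
    exp_weights = np.exp(attn_weights - max_score)
    # Unnormalized weighted sum of this chunk's values.
    exp_values = np.einsum('qhk,khv->qhv', exp_weights, value_chunk)
    # Partial results: the caller rescales by the global max and divides
    # by the summed weights across all chunks to finish the softmax.
    return exp_values, exp_weights.sum(axis=-1), max_score.squeeze(-1)
```

Combining two such summaries (rescale each by exp(chunk_max - global_max),
sum the numerators and denominators, then divide) reproduces ordinary
softmax attention over the full key/value tensors.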
More information about the cypherpunks mailing list