[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 02:53:00 PST 2022


Drill down, dive in.  Selection appears to have stopped working for
me, so I'm typing the text over by hand; happens sometimes.

04:def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):

As earlier, the query here is a chunk of queries, a subset of all of them.
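
A hedged sketch of what that chunking looks like in JAX (the sizes and
the slicing call here are illustrative assumptions, not the paper's
code):

import jax.numpy as jnp
from jax import lax

# Hypothetical full set of queries: 16384 positions, 8 heads, 64 features.
query_all = jnp.zeros((16384, 8, 64))

# A query chunk is a slice along the position axis; the chunk size of
# 1024 is an illustrative choice.
query = lax.dynamic_slice(query_all, (0, 0, 0), (1024, 8, 64))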

05:  """Multi-head dot product attention with a limited number of queries."""
06: num_kv, num_heads, k_features = key.shape
07: v_features = value.shape[-1]

# I commented here on the meaning of the dimensions, and then my body
and browser spazzed and marked the thread as spam.  When I recovered
the thread the comments were gone; sending now to preserve the rest.
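
To restore the gist of those lost comments: key has shape (num_kv,
num_heads, k_features), i.e. one row per key/value position, per
attention head, with k_features features for each key; v_features is
the per-head feature size of the values.  A minimal sketch with
assumed sizes (not the paper's):

import jax.numpy as jnp

# Illustrative shapes: 4096 key/value positions, 8 heads,
# 64 key features and 64 value features per head.
key = jnp.zeros((4096, 8, 64))
value = jnp.zeros((4096, 8, 64))

num_kv, num_heads, k_features = key.shape  # 4096, 8, 64
v_features = value.shape[-1]               # 64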

12:@functools.partial(jax.checkpoint, prevent_cse=False)

This decorates the function, passing it through jax.checkpoint (also
exposed as jax.remat) for further transformation after definition.  It
tells JAX's backpropagation machinery to discard the function's
intermediate values rather than caching them in memory, marking them
to be recalculated on the backward pass when needed; this trade of
extra compute for memory is mentioned elsewhere in the paper.  The
prevent_cse=False argument relaxes jax.checkpoint's default guard
against XLA common-subexpression elimination, which the JAX docs note
is unnecessary (and costly) when the checkpointed function runs inside
constructs like lax.scan.
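
A minimal sketch of the same decoration pattern on a toy function (the
function body here is illustrative, not from the paper):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.checkpoint, prevent_cse=False)
def f(x):
    y = jnp.sin(x)        # intermediate: not stored for the backward pass
    return jnp.sum(y * y)

# Under grad, y is recomputed during backprop instead of being read
# from a cached forward-pass value.
g = jax.grad(f)(jnp.ones(8))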

