[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 11:52:58 PST 2022


I tweaked chunked_attention.py so it would work for me. I propagated
the key chunking parameter, which the original code gave no way to
set, and qualified lax as jax.lax where the namespace was missing, I
think on line 14.
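For anyone without the file at hand, here is a minimal NumPy sketch of what query and key chunking do. It assumes standard scaled dot-product softmax attention and the [sequence, heads, features] layout used in the session below. It is not the code from chunked_attention.py, and chunked_attention_sketch is a hypothetical name. Each key chunk contributes a partial weighted sum of values, a partial softmax denominator, and its own maximum score. These are rescaled to a shared maximum before being combined, so the result matches unchunked softmax attention.

```python
import numpy as np

def chunked_attention_sketch(q, k, v, query_chunk_size=4, key_chunk_size=4):
    """Hypothetical sketch: softmax(q k^T / sqrt(d)) v, computed chunk by chunk.

    q: [num_q, heads, d], k: [num_kv, heads, d], v: [num_kv, heads, dv]
    """
    num_q, heads, d = q.shape
    num_kv = k.shape[0]
    out = []
    for qs in range(0, num_q, query_chunk_size):
        qc = q[qs:qs + query_chunk_size] / np.sqrt(d)   # assumed 1/sqrt(d) scaling
        nums, dens, maxes = [], [], []
        for ks in range(0, num_kv, key_chunk_size):
            kc = k[ks:ks + key_chunk_size]
            vc = v[ks:ks + key_chunk_size]
            s = np.einsum('qhd,khd->qhk', qc, kc)       # scores for this block only
            m = s.max(axis=-1, keepdims=True)           # per-chunk max, [qc, h, 1]
            e = np.exp(s - m)                           # stable within the chunk
            nums.append(np.einsum('qhk,khf->qhf', e, vc))
            dens.append(e.sum(axis=-1, keepdims=True))
            maxes.append(m)
        # Rescale every chunk's partial sums to the largest max, then combine.
        global_max = np.max(np.stack(maxes), axis=0)
        num = sum(n * np.exp(m - global_max) for n, m in zip(nums, maxes))
        den = sum(dn * np.exp(m - global_max) for dn, m in zip(dens, maxes))
        out.append(num / den)
    return np.concatenate(out, axis=0)
```

Since each chunk's exponentials are taken relative to that chunk's own maximum, large scores don't overflow. Changing key_chunk_size should change how big the score blocks are, not the result, apart from rounding.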

>>> import chunked_tweaks as chunked
>>> import jax.numpy as jnp
>>> queries, keys, values = jnp.mgrid[:64, :8, :16]
>>> chunked.attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)

This outputs a [64, 8, 16] array in which every vector along the last
axis ascends from 0 through 15.
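This input can't tell a correct implementation from a wrong one. With mgrid, every value vector is [0, 1, ..., 15], and attention returns a weighted average of value vectors whose weights sum to one per query. Any such weighting therefore gives back [0, ..., 15]. A small NumPy sketch of this, using random row-normalized weights as an illustrative stand-in for softmax weights:

```python
import numpy as np

# Same inputs as the session above, built with NumPy's mgrid instead of jnp.mgrid.
queries, keys, values = np.mgrid[:64, :8, :16]   # each has shape (64, 8, 16)

# Every value vector is identical: values[j, h] == [0, 1, ..., 15] for all j, h.
assert (values == np.arange(16)).all()

# Arbitrary weights over the 64 keys, normalized so each query's weights sum to 1.
rng = np.random.default_rng(0)
w = rng.random((64, 8, 64))                      # [query, head, key]
w /= w.sum(axis=-1, keepdims=True)

# A convex combination of identical vectors is that vector, whatever the weights.
out = np.einsum('qhk,khf->qhf', w, values.astype(float))
```

So the mgrid run confirms shapes and plumbing, while random data is what actually exercises the attention weights.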

OK, next I'll try writing conventional attention myself, as described
in the paper, and see whether the output matches. Then maybe I'll
check with random data too.
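A minimal NumPy sketch of conventional (unchunked) attention for the same [sequence, heads, features] layout, checked on random data against a plain per-query loop. conventional_attention is a hypothetical name. The 1/sqrt(d) score scaling is an assumption about what chunked_attention.py computes; drop or change it if that file scales differently.

```python
import numpy as np

def conventional_attention(query, key, value):
    """softmax(Q K^T / sqrt(d)) V, computed independently for each head.

    query: [num_q, heads, d], key: [num_kv, heads, d], value: [num_kv, heads, dv]
    """
    scores = np.einsum('qhd,khd->hqk', query, key) / np.sqrt(query.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # for numerical stability only
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum('hqk,khf->qhf', weights, value)

rng = np.random.default_rng(0)
q = rng.standard_normal((64, 8, 16))
k = rng.standard_normal((64, 8, 16))
v = rng.standard_normal((64, 8, 16))
out = conventional_attention(q, k, v)

# Independent reference: one query and one head at a time.
ref = np.empty_like(out)
for i in range(64):
    for h in range(8):
        s = k[:, h] @ q[i, h] / np.sqrt(16)        # scores against all 64 keys
        w = np.exp(s - s.max())
        ref[i, h] = (w / w.sum()) @ v[:, h]
```

Assuming chunked_tweaks is importable and takes the same arguments as in the session above, the comparison against the chunked version would be along the lines of np.allclose(out, np.asarray(chunked.attention(q, k, v, query_chunk_size=4, key_chunk_size=4)), atol=1e-4). The loose tolerance allows for JAX defaulting to float32.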


More information about the cypherpunks mailing list