26 Jan 2022
4:52 a.m.
I tweaked chunked_attention.py so it would work for me: I propagated the key chunking parameter, which the original code gives no way to actually use, and qualified lax with the jax namespace where it was missing, I think on line 14.
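For the record, roughly what the tweaked file boils down to. This is a from-memory sketch of the same idea following the paper's listing, not a verbatim copy of chunked_tweaks.py:

import math
import jax
import jax.numpy as jnp

def _query_chunk_attention(query, key, value, key_chunk_size):
    # Attend one chunk of queries to all keys/values, key_chunk_size at a time.
    num_kv, num_heads, k_features = key.shape
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    def summarize_chunk(chunk_idx):
        # lax fully qualified as jax.lax -- the namespace fix mentioned above.
        key_chunk = jax.lax.dynamic_slice(
            key, (chunk_idx, 0, 0), (key_chunk_size, num_heads, k_features))
        value_chunk = jax.lax.dynamic_slice(
            value, (chunk_idx, 0, 0), (key_chunk_size, num_heads, v_features))
        attn_weights = jnp.einsum('qhd,khd->qhk', query, key_chunk)
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        exp_weights = jnp.exp(attn_weights - max_score)
        exp_values = jnp.einsum('khf,qhk->qhf', value_chunk, exp_weights)
        return exp_values, exp_weights.sum(axis=-1), max_score[..., 0]

    # Unnormalized partial results per key chunk, then a numerically stable
    # combine against the global max, as in the paper.
    chunk_values, chunk_weights, chunk_max = jax.lax.map(
        summarize_chunk, jnp.arange(0, num_kv, key_chunk_size))
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    scale = jnp.exp(chunk_max - global_max)
    all_values = (chunk_values * scale[..., None]).sum(axis=0)
    all_weights = (chunk_weights * scale).sum(axis=0)
    return all_values / all_weights[..., None]

def attention(queries, keys, values, query_chunk_size=1024, key_chunk_size=4096):
    # key_chunk_size now threaded through from the top level -- the other tweak.
    num_q, num_heads, q_features = queries.shape

    def chunk_scanner(chunk_idx, _):
        query_chunk = jax.lax.dynamic_slice(
            queries, (chunk_idx, 0, 0),
            (min(query_chunk_size, num_q), num_heads, q_features))
        return (chunk_idx + query_chunk_size,
                _query_chunk_attention(query_chunk, keys, values, key_chunk_size))

    _, res = jax.lax.scan(chunk_scanner, init=0, xs=None,
                          length=math.ceil(num_q / query_chunk_size))
    return res.reshape(num_q, num_heads, values.shape[-1])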
import chunked_tweaks as chunked
import jax.numpy as jnp

queries, keys, values = jnp.mgrid[:64, :8, :16]
chunked.attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
outputs a [64, 8, 16] array whose vectors each ascend from 0 through 15, which checks out: every value vector is 0 through 15, so any weighted average of them is too. OK, now I'll try writing conventional attention as described in the paper on my own and see whether the output matches. Then maybe I'll check with random data.
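Roughly the plan, sketched in advance: plain softmax attention as the paper defines it, plus an allclose check on random inputs. Untested, and plain_attention is my name for it, not the paper's:

import jax
import jax.numpy as jnp
import chunked_tweaks as chunked

def plain_attention(queries, keys, values):
    # Ordinary softmax attention over [len, heads, features] arrays.
    scores = jnp.einsum('qhd,khd->qhk', queries, keys) / jnp.sqrt(queries.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('qhk,khf->qhf', weights, values)

k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
queries = jax.random.normal(k_q, (64, 8, 16))
keys = jax.random.normal(k_k, (64, 8, 16))
values = jax.random.normal(k_v, (64, 8, 16))

expected = plain_attention(queries, keys, values)
actual = chunked.attention(queries, keys, values,
                           query_chunk_size=4, key_chunk_size=4)
print(jnp.allclose(expected, actual, atol=1e-5))  # hoping for True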