import math
import timeit

import jax
import numpy as np
import torch

def jax_to_torch(array):
    """Convert a JAX array to a CPU ``torch.Tensor`` with the same shape and dtype.

    ``np.array`` copies the data to host memory and returns a *writable*
    NumPy array. ``np.asarray`` on a JAX array would return a read-only view,
    which makes ``torch.from_numpy`` emit a non-writable-buffer warning.
    The resulting tensor shares memory with that private copy only, never
    with the JAX buffer.
    """
    return torch.from_numpy(np.array(array))


if __name__ == '__main__':
    # Leading axis of size 3 is presumably (query, key, value); each slice is
    # (1, 64, 8, 16) — axis meanings are assumed, confirm against the
    # attention implementation this harness targets.
    jax_qkvs = jax.random.normal(jax.random.PRNGKey(0), (3, 1, 64, 8, 16))
    # Iterating a JAX array walks its leading axis. The loop variable is named
    # ``qkv`` so it does not shadow the ``jax`` module.
    torch_qkvs = [jax_to_torch(qkv) for qkv in jax_qkvs]
    # Disabled: `attention`, `queries`, `keys` and `values` are not defined above.
    #out = attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
    #print(out)
