import memory_efficient_attention

def moskomule_attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Placeholder for the "moskomule" attention variant.

    The signature mirrors ``attention`` in this module, but no implementation
    is available here. The function fails loudly instead of silently
    returning ``None`` so that callers cannot mistake it for a working
    implementation.

    NOTE(review): which array framework and tensor layout this variant is
    meant to use cannot be determined from this module -- confirm before
    filling in the body.

    Args:
        queries: Query array (unused).
        keys: Key array (unused).
        values: Value array (unused).
        query_chunk_size: Number of queries per chunk (unused).
        key_chunk_size: Number of keys per chunk (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("moskomule_attention is not implemented")


def attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Compute scaled dot-product attention one chunk at a time.

    Queries are processed ``query_chunk_size`` rows at a time and, for each
    query chunk, keys/values are visited ``key_chunk_size`` rows at a time.
    A running maximum, numerator and denominator are carried across key
    chunks, so the softmax stays numerically stable and the full
    (queries x keys) score matrix is never materialised. Per head, the result
    equals ``softmax(q @ k.T / sqrt(d)) @ v``.

    Chunk sizes do not need to divide the sequence lengths; the final chunk
    is simply shorter.

    NOTE(review): the (sequence, heads, features) layout and the 1/sqrt(d)
    scaling are choices made by this implementation -- confirm they match
    the intended convention.

    Args:
        queries: Array of shape (num_queries, num_heads, head_dim).
        keys: Array of shape (num_keys, num_heads, head_dim).
        values: Array of shape (num_keys, num_heads, value_dim).
        query_chunk_size: Maximum number of queries handled at once.
        key_chunk_size: Maximum number of keys/values handled at once.

    Returns:
        Array of shape (num_queries, num_heads, value_dim).

    Raises:
        ValueError: If a chunk size is smaller than 1 or there are no keys.
    """
    # Function-scope imports: the module itself does not import these names.
    import math

    import jax.numpy as jnp

    if query_chunk_size < 1 or key_chunk_size < 1:
        raise ValueError("query_chunk_size and key_chunk_size must be >= 1")

    queries, keys, values = (jnp.asarray(x) for x in (queries, keys, values))
    num_queries, num_heads, head_dim = queries.shape
    num_keys = keys.shape[0]
    value_dim = values.shape[-1]
    if num_keys == 0:
        # A softmax over zero keys is undefined; fail instead of returning NaN.
        raise ValueError("attention requires at least one key")

    # Scale once up front rather than scaling every score chunk.
    scaled_queries = queries / math.sqrt(head_dim)
    dtype = scaled_queries.dtype

    if num_queries == 0:
        return jnp.zeros((0, num_heads, value_dim), dtype=dtype)

    def _attend(query_chunk):
        """Attend one chunk of queries to every key, chunk by chunk."""
        rows = query_chunk.shape[0]
        running_max = jnp.full((rows, num_heads), -jnp.inf, dtype=dtype)
        numerator = jnp.zeros((rows, num_heads, value_dim), dtype=dtype)
        denominator = jnp.zeros((rows, num_heads), dtype=dtype)

        for start in range(0, num_keys, key_chunk_size):
            stop = start + key_chunk_size
            scores = jnp.einsum("qhd,khd->qhk", query_chunk, keys[start:stop])

            # Re-express the accumulated sums relative to the new maximum so
            # that every exponent stays <= 0 (no overflow).
            new_max = jnp.maximum(running_max, scores.max(axis=-1))
            rescale = jnp.exp(running_max - new_max)
            weights = jnp.exp(scores - new_max[..., None])

            numerator = numerator * rescale[..., None] + jnp.einsum(
                "qhk,khf->qhf", weights, values[start:stop]
            )
            denominator = denominator * rescale + weights.sum(axis=-1)
            running_max = new_max

        return numerator / denominator[..., None]

    chunk_outputs = [
        _attend(scaled_queries[start:start + query_chunk_size])
        for start in range(0, num_queries, query_chunk_size)
    ]
    return jnp.concatenate(chunk_outputs, axis=0)


if __name__ == '__main__':
    # jax is used below but is not imported at module level, so bring it
    # into scope here; otherwise the demo fails with a NameError.
    import jax

    # One (3, 64, 8, 16) normal sample is unpacked along its leading axis
    # into three (64, 8, 16) arrays.
    queries, keys, values = jax.random.normal(jax.random.PRNGKey(0), (3, 64, 8, 16))
    out = attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
    print(out)
