import memory_efficient_attention

def moskomule_attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Delegate to ``memory_efficient_attention.efficient_attention``.

    Args:
        queries: Query input, forwarded unchanged as the first positional
            argument.
        keys: Key input, forwarded unchanged as the second positional
            argument.
        values: Value input, forwarded unchanged as the third positional
            argument.
        query_chunk_size: Accepted for signature parity with ``attention``;
            it is not forwarded to the backend.
        key_chunk_size: Passed to the backend as ``chunk_size``.

    Returns:
        Whatever ``efficient_attention`` returns for these arguments.
    """
    # Fixed backend options: checkpointing is on, out-of-place mode is off.
    backend_options = {
        "chunk_size": key_chunk_size,
        "checkpointing": True,
        "out_of_place": False,
    }
    result = memory_efficient_attention.efficient_attention(
        queries, keys, values, **backend_options
    )
    return result

def attention(queries, keys, values, query_chunk_size, key_chunk_size):
    """Compute attention through the ``memory_efficient_attention`` backend.

    Each input has its axes 1 and 2 swapped with ``permute(0, 2, 1, 3)``
    before being handed to ``moskomule_attention``. The four-index
    permutation implies 4-D inputs exposing a ``.permute`` method.
    NOTE(review): presumably this converts (batch, seq, heads, dim) to
    (batch, heads, seq, dim) -- confirm against the backend's expected
    layout, and whether the result should be permuted back for callers.

    Args:
        queries: Query tensor.
        keys: Key tensor.
        values: Value tensor.
        query_chunk_size: Forwarded to ``moskomule_attention``.
        key_chunk_size: Forwarded to ``moskomule_attention``.

    Returns:
        The result of ``moskomule_attention`` on the permuted inputs.

    Raises:
        RuntimeError: If the imported ``memory_efficient_attention`` module
            does not expose ``efficient_attention``.
    """
    # hasattr() needs the attribute name as a string; a bare identifier here
    # would be evaluated as an (undefined) variable and raise NameError.
    if not hasattr(memory_efficient_attention, 'efficient_attention'):
        # Fail loudly rather than silently returning None to the caller.
        raise RuntimeError(
            "memory_efficient_attention does not provide 'efficient_attention'"
        )

    queries = queries.permute(0, 2, 1, 3)
    keys = keys.permute(0, 2, 1, 3)
    values = values.permute(0, 2, 1, 3)
    return moskomule_attention(queries, keys, values, query_chunk_size, key_chunk_size)

if __name__ == '__main__':
    # Smoke test. ``attention`` calls ``.permute`` on its inputs, which is a
    # torch.Tensor method, so the demo data is generated with torch. Imported
    # locally because it is only needed when running this file as a script.
    import torch

    torch.manual_seed(0)  # fixed seed for reproducible demo inputs
    # Unpacking along the leading axis yields three (1, 64, 8, 16) tensors.
    queries, keys, values = torch.randn(3, 1, 64, 8, 16)
    out = attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
    print(out)
