import jax
import jax.numpy as jnp

def attention(query, key, value, precision=jax.lax.Precision.HIGHEST):
    """Memory-efficient multi-head dot product attention."""
    num_q, num_heads, q_features = query.shape

    # query is queries, heads, features
    # key is keyvalues, heads, features
    # value is keyvalues, heads, features

    # 1. weights = dot(query, key)
    attn_weights = jnp.einsum('qhf,khf->qhk', query / jnp.sqrt(q_features), key,
                              precision=precision)
    # weights shape is now [queries, heads, keys]
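    # scaling the query by 1/sqrt(features) keeps the dot products at roughly
    # unit scale regardless of feature size, so the softmax doesn't saturate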

    # 2. softmax of the weights across the keys
    # softmax can be calculated as exp(a) / sum(exp(a))
    #   where a has its maximum values subtracted first for numerical stability.
    #   the max is taken across keys because each (query, head) pair is
    #   normalized independently.
    max_weights = jnp.max(attn_weights, axis=-1, keepdims=True)
    exp_weights = jnp.exp(attn_weights - max_weights)
    exp_weights /= jnp.sum(exp_weights, axis=-1, keepdims=True)
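    # (equivalently: exp_weights = jax.nn.softmax(attn_weights, axis=-1),
    #  which applies the same max-subtraction trick internally)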

    # 3. dot the weights with the values, contracting away the keyvalue dim,
    #    giving a separate result for each query and head
    attn_out = jnp.einsum('qhk,khf->qhf', exp_weights, value, precision=precision)

    return attn_out

if __name__ == '__main__':
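    # one [3, 64, 8, 16] normal draw, unpacked into three [64, 8, 16] arrays:
    # 64 tokens, 8 heads, and 16 features each for queries, keys, and values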
    queries, keys, values = jax.random.normal(jax.random.PRNGKey(0), (3, 64, 8, 16))
    out = attention(queries, keys, values)
    print(out)
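    # sanity check: the result should agree with a direct formulation that
    # uses jax.nn.softmax over the key axis (same math, library softmax);
    # precision is matched so the comparison tolerance can stay tight
    ref_weights = jax.nn.softmax(
        jnp.einsum('qhf,khf->qhk', queries / jnp.sqrt(keys.shape[-1]), keys,
                   precision=jax.lax.Precision.HIGHEST),
        axis=-1)
    ref_out = jnp.einsum('qhk,khf->qhf', ref_weights, values,
                         precision=jax.lax.Precision.HIGHEST)
    print('matches softmax reference:', jnp.allclose(out, ref_out, atol=1e-5))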
