import math

import torch

def attention(queries, keys, values):
    """Scaled dot-product attention for inputs shaped
    (batch_size, seq_len, num_heads, head_dim)."""
    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    queries = queries.permute(0, 2, 1, 3)
    keys = keys.permute(0, 2, 1, 3)
    values = values.permute(0, 2, 1, 3)

    # Take the dot product between the queries and keys to get the raw attention scores.
    attention_scores = torch.matmul(queries, keys.transpose(-1, -2))

    batch_size, num_heads, seq_len, q_head_dim = queries.shape
    _, _, _, v_head_dim = values.shape
    hiddens = num_heads * v_head_dim

    # Scale by sqrt(d_k) so the score variance is independent of the head size.
    attention_scores = attention_scores / math.sqrt(q_head_dim)

    # Normalize the attention scores to probabilities
    attention_probs = torch.softmax(attention_scores, dim=-1)

    # Weighted sum of the value vectors: (batch, heads, seq, v_head_dim).
    context_layer = torch.matmul(attention_probs, values)

    # Move the heads back next to the sequence axis and merge them:
    # (batch, heads, seq, v_head_dim) -> (batch, seq, heads * v_head_dim).
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
    context_layer = context_layer.view(*new_context_layer_shape)

    return context_layer

if __name__ == '__main__':
    import jax
    import numpy as np

    # torch.from_numpy expects a NumPy array, so convert the JAX output
    # explicitly (the older DeviceArray.to_py() helper is gone in recent JAX).
    sample = jax.random.normal(jax.random.PRNGKey(0), (3, 1, 64, 8, 16))
    queries, keys, values = torch.from_numpy(np.array(sample))
    out = attention(queries, keys, values)
    print(out)
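
    # Optional sanity check (a sketch assuming PyTorch >= 2.0, where
    # torch.nn.functional.scaled_dot_product_attention is available). The
    # built-in kernel expects (batch, heads, seq, head_dim) inputs, so permute
    # before the call and merge the heads afterwards to match the layout
    # returned by attention() above.
    import torch.nn.functional as F

    reference = F.scaled_dot_product_attention(
        queries.permute(0, 2, 1, 3),
        keys.permute(0, 2, 1, 3),
        values.permute(0, 2, 1, 3),
    )
    reference = reference.permute(0, 2, 1, 3).reshape(out.shape)
    print(torch.allclose(out, reference, atol=1e-5))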
