```python
def summarize_chunk(query, key, value):
    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
```

An einsum is a way of doing matrix multiplication for n-dimensional matrices by specifying which axes of the tensors are dotted with which other axes during the calculation. Any letters the developer wants can be used in the string, so long as they are used consistently. Here they have picked 'q', 'k', 'h', and 'd' to represent 'query', 'key', 'head', and maybe 'data'. The multiplication says to treat the data dimension as the inner part, producing an output that is dimensioned by the query, head, and key axes. It's confusing to me.
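To make that concrete, here's a minimal sketch (the shapes and random inputs are invented for illustration): 'd' appears in both inputs but not in the output, so it is the axis that gets summed (dotted) away, while 'q', 'h', and 'k' survive as output axes.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, just for illustration.
q_len, k_len, heads, d = 4, 6, 2, 8
kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (q_len, heads, d))  # [q, h, d]
key = jax.random.normal(kk, (k_len, heads, d))    # [k, h, d]

# 'qhd,khd->qhk': d is summed over; q, h, k index the output.
attn_weights = jnp.einsum('qhd,khd->qhk', query, key)
print(attn_weights.shape)  # (4, 2, 6) == (q_len, heads, k_len)
```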
From the paper, this would be the first step of attention: s_i = dot(q, k_i). Since there are multiple queries here, this is done for every query, rather than just one. The d dimension, the feature count, must be the dimension along which the dot product is taken, which again implies that queries and keys have the same size in their final dimension. The other dimensions, heads, queries, and keys, must then simply be batching: the same dot product done many times.
Ok, let's see. Each query is dotted with each key, for each head, into a [queries, heads, keys] tensor of scalar dot products: pre-softmax attention weights. This is done with einsum('qhd,khd->qhk').
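A quick sanity check of that reading, as a sketch with made-up shapes: the einsum should match three nested loops computing one scalar dot product per (query, head, key) triple.

```python
import jax
import jax.numpy as jnp

# Invented shapes, small enough to loop over.
q_len, k_len, heads, d = 3, 5, 2, 4
kq, kk = jax.random.split(jax.random.PRNGKey(42))
query = jax.random.normal(kq, (q_len, heads, d))
key = jax.random.normal(kk, (k_len, heads, d))

fast = jnp.einsum('qhd,khd->qhk', query, key)

# The same computation as explicit loops: s_i = dot(q, k_i),
# repeated for every query and every head.
slow = jnp.zeros((q_len, heads, k_len))
for qi in range(q_len):
    for h in range(heads):
        for ki in range(k_len):
            slow = slow.at[qi, h, ki].set(jnp.dot(query[qi, h], key[ki, h]))

print(jnp.allclose(fast, slow, atol=1e-5))  # True
```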