[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 03:36:51 PST 2022


12:  def summarize_chunk(query, key, value):
13:    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)

An einsum is a way of doing matrix multiplication for n-dimensional
tensors by specifying which axes of the tensors are dotted with which
other axes during the calculation. Any letters the developer wants can
be used in the string passed in, so long as they are used consistently.
Here they have picked 'q', 'k', 'h', and 'd' to represent 'query',
'key', 'head', and maybe 'data'.  The subscript string says to treat the
data dimension as the inner (contracted) part, producing an output that
is dimensioned by the query, head, and key axes.

It's confusing to me.
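
Maybe it helps to spell it out in plain code.  Here's a small sketch
(shapes and test data invented just for illustration, and the
precision= argument dropped) checking that the einsum is the same as an
ordinary matrix product done separately for each head, contracting over
the 'd' axis:

    import jax.numpy as jnp
    from jax import random

    # invented shapes: 4 queries, 6 keys, 2 heads, 8 features per head
    q_len, k_len, heads, depth = 4, 6, 2, 8
    query = random.normal(random.PRNGKey(0), (q_len, heads, depth))
    key = random.normal(random.PRNGKey(1), (k_len, heads, depth))

    # the einsum from the code above
    attn_weights = jnp.einsum('qhd,khd->qhk', query, key)

    # same thing as a per-head matmul:
    # query[:, h, :] is [q, d], key[:, h, :].T is [d, k], product is [q, k]
    by_hand = jnp.stack(
        [query[:, h, :] @ key[:, h, :].T for h in range(heads)], axis=1)

    assert attn_weights.shape == (q_len, heads, k_len)
    assert jnp.allclose(attn_weights, by_hand, atol=1e-5)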

From the paper, this would be the first step of attention: s_i =
dot(q, k_i).  Since there are multiple queries here, this is done for
every query, rather than just one.  The 'd' dimension, the feature
count, must be the dimension along which the dot product is taken,
again implying that queries and keys have the same size in their
final dimension.  The other dimensions, heads, queries, and keys, must
then simply be batching: doing this dot product many times.

Ok, let's see.  Each query is dotted with each key, for each head,
into a [queries, heads, keys] tensor of scalar dot products:
pre-softmax attention weights.  This is done with
einsum('qhd,khd->qhk').
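
As a sanity check on that reading, here's another sketch (again with
invented shapes) comparing each element of the einsum output against
the paper's s_i = dot(q, k_i), looping over the batching axes by hand:

    import jax.numpy as jnp
    from jax import random

    # invented shapes: [queries, heads, depth] and [keys, heads, depth]
    query = random.normal(random.PRNGKey(0), (4, 2, 8))
    key = random.normal(random.PRNGKey(1), (6, 2, 8))

    attn_weights = jnp.einsum('qhd,khd->qhk', query, key)

    # every entry is one scalar dot product, s_i = dot(q, k_i),
    # repeated over the batching axes (queries and heads)
    for qi in range(4):
        for h in range(2):
            for ki in range(6):
                s_i = jnp.dot(query[qi, h], key[ki, h])
                assert jnp.allclose(attn_weights[qi, h, ki], s_i, atol=1e-5)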

