[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 04:52:13 PST 2022


The below is a draft, but I see I propagated some dimensions wrongly
and plan to rewrite it, and maybe some previous posts, to help with
clarity.

17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

Another einsum, multiplying the values by the weights.

[managed to continue after writing this :) -> It might take me a bit
to engage this einsum.  Hope to return quickly and just combine the
available parts.  For example, I could copy a definition of einsum in,
or compare the line to the equation from the paper, or look at the
paper's description of this chunk of code.]

This must be the rest of v = v * exp(m - m_i) + v_i * exp(s_i - m_i),
specifically the "v_i *" part, since exp_weights is already
exp(s_i - m_i).  Since m is initialised to -inf, exp(m - m_i) is zero,
so the left term of the sum vanishes and is disregarded.
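A minimal NumPy sketch of that running update for a single query, one key at a time, assuming scalar scores s_i and value vectors v_i. The function name running_softmax_average and the weight total s_sum (the paper's s*) are hypothetical illustrations, not the paper's code. It shows that the first step's left term is exp(-inf) = 0, and that v / s_sum at the end equals ordinary softmax-weighted averaging of the values.

```python
import numpy as np

def running_softmax_average(scores, values):
    """Accumulate v = v*exp(m - m_i) + v_i*exp(s_i - m_i) one key at a time.

    scores: shape [keys]; values: shape [keys, features].
    Returns the softmax-weighted average of the values.
    """
    m = -np.inf                        # running max, initialised to -inf
    v = np.zeros(values.shape[-1])     # running weighted sum of values
    s_sum = 0.0                        # running sum of weights (hypothetical name)
    for s_i, v_i in zip(scores, values):
        m_i = max(m, s_i)              # updated running max
        scale = np.exp(m - m_i)        # exp(-inf) == 0.0 on the first step
        v = v * scale + v_i * np.exp(s_i - m_i)
        s_sum = s_sum * scale + np.exp(s_i - m_i)
        m = m_i
    return v / s_sum

rng = np.random.default_rng(0)
scores = rng.normal(size=5)
values = rng.normal(size=(5, 3))

# Reference: plain softmax over the scores, then a weighted sum of values.
weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ values
print(np.allclose(running_softmax_average(scores, values), reference))  # True
```

With only one key the loop runs once, the left term is zero, and the result is just that key's value vector.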

In jnp.einsum('vhf,qhv->qhf', value, exp_weights), we can see how this
multiplication is performed: exp_weights' dimensions are labelled
query, heads, value.  Its dimensions come from attn_weights minus
max_score, from the last email: attn_weights is [queries, heads, keys]
and max_score is [queries, heads, 1], so the subtraction broadcasts
max_score across the key axis.  It looks like I unfortunately miswrote
'query' as 'key' in a dimension in that email, as can be verified from
the einsum that attn_weights came from.

So exp_weights is [query, heads, keys], the exponentiation of each
score minus that query's max over the keys, and value is [value,
heads, features], with one value per key, so the 'v' axis of
exp_weights is the key axis.  The einsum multiplies each exp_weight by
its matching value vector and sums over 'v', since 'v' appears in both
inputs but not in the output.  This can be checked by running test
data or by writing the parts of the definition down and combining them
textually.

exp_values becomes a [query, head, feature] tensor holding, for each
query and head, the sum of the value vectors weighted by the
exponentiated attn weights.

18:    return (exp_values, exp_weights.sum(axis=-1),
19:      max_score.reshape((query.shape[0], num_heads)))

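The output is a triple: exp_values; exp_weights summed over the key
axis, a [query, head] tensor of weight totals; and max_score reshaped
from [query, head, 1] to [query, head].  exp_weights is the
exponentiation of the scores minus their max score, for each query.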
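A NumPy sketch of the whole chunk summary covered so far, returning the same triple. It is a simplification under stated assumptions: no precision argument, no gradient checkpointing, and the query is not pre-scaled by 1/sqrt(k_features) as the paper does before this function. It shows the three output shapes and that, for a single chunk, dividing exp_values by the weight totals recovers ordinary softmax attention.

```python
import numpy as np

def summarize_chunk(query, key, value):
    """NumPy sketch of the chunk summary (simplified from the JAX original)."""
    num_heads = query.shape[1]
    attn_weights = np.einsum('qhd,khd->qhk', query, key)       # [q, h, k]
    max_score = np.max(attn_weights, axis=-1, keepdims=True)    # [q, h, 1]
    exp_weights = np.exp(attn_weights - max_score)              # [q, h, k]
    exp_values = np.einsum('vhf,qhv->qhf', value, exp_weights)  # [q, h, f]
    return (exp_values, exp_weights.sum(axis=-1),               # [q, h]
            max_score.reshape((query.shape[0], num_heads)))     # [q, h]

rng = np.random.default_rng(2)
q_len, heads, kv_len, d, f = 3, 2, 5, 4, 6   # hypothetical sizes
query = rng.normal(size=(q_len, heads, d))
key = rng.normal(size=(kv_len, heads, d))
value = rng.normal(size=(kv_len, heads, f))

exp_values, weight_sums, max_score = summarize_chunk(query, key, value)
print(exp_values.shape, weight_sums.shape, max_score.shape)
# (3, 2, 6) (3, 2) (3, 2)

# For a single chunk, normalising by the weight totals gives softmax attention.
scores = np.einsum('qhd,khd->qhk', query, key)
softmax = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
reference = np.einsum('qhk,khf->qhf', softmax, value)
print(np.allclose(exp_values / weight_sums[..., None], reference))  # True
```

Returning the unnormalised sums together with max_score, rather than the normalised result, is what lets summaries from several key chunks be rescaled to a common max and added together afterwards.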

