[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 10:54:45 PST 2022


17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

I think the values are still [values, heads, features], whereas the
exp_weights are [queries, heads, keys].  I think the values are
multiplied by the per-key exponentiated weights and summed, producing a
vector of weighted features for each query and each head: a dot product
that aligns the value dimension of the values with the key dimension of
the weights.
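
A quick shape check of that reading (a throwaway sketch, with made-up
sizes rather than anything from the actual code):

    import jax.numpy as jnp

    # Made-up chunk sizes, just to check the contraction and output shape.
    num_kv, num_q, num_heads, feat = 128, 64, 8, 32

    value = jnp.ones((num_kv, num_heads, feat))         # [values, heads, features]
    exp_weights = jnp.ones((num_q, num_heads, num_kv))  # [queries, heads, keys]

    # 'vhf,qhv->qhf' contracts the shared key/value index v,
    # leaving one feature vector per query and head.
    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights)
    print(exp_values.shape)  # (64, 8, 32), i.e. [queries, heads, features]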

This must be the "v_i * exp(s_i - m_i)" term in "v = v * exp(m - m_i) +
v_i * exp(s_i - m_i)" from Section 3 (numerical stability) of the paper.
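
To convince myself of that correspondence, here is a toy single-head
version of the same term (my own sketch, not the paper's code; the
names and sizes are made up):

    import jax
    import jax.numpy as jnp

    # Toy single-head chunk: 4 queries, 6 keys/values, 3 features.
    rng = jax.random.PRNGKey(0)
    kq, kk, kv = jax.random.split(rng, 3)
    query = jax.random.normal(kq, (4, 3))
    key   = jax.random.normal(kk, (6, 3))
    value = jax.random.normal(kv, (6, 3))

    scores = query @ key.T                               # s_i, [queries, keys]
    max_score = jnp.max(scores, axis=-1, keepdims=True)  # m_i, one max per query
    exp_weights = jnp.exp(scores - max_score)            # exp(s_i - m_i)
    exp_values = exp_weights @ value                     # sum_i v_i * exp(s_i - m_i)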

18:    return (exp_values, exp_weights.sum(axis=-1),
19:            max_score.reshape((query.shape[0], num_heads)))

A tuple of three values is returned from summarize_chunk.

- it looks like the first element, exp_values, is the attention vector
combining the values with the keys and query, prior to dividing by the
denominator of the softmax.  "v*" in the numerical stability portion
of the paper.

- the second component, exp_weights.sum(axis=-1), appears to be the
sum of the exponentiated scores along the key dimension, one sum for
each query and each head, of shape [queries, heads].  this is likely
used later to reconstruct the denominator of the softmax.

- the third component calls reshape on max_score.  I think max_score
was the maximum taken across the keys of the unexponentiated attention
scores, i.e. the dot products of the query features and the key
features.  It looks like the reshape just removes the trailing dummy
dimension that was added for the broadcast subtraction.
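
Putting the three pieces together, my reading of summarize_chunk comes
out roughly like this (a reconstruction from the lines above, not a
copy of the real source; the score einsum spec is a guess, and I take
num_heads from query.shape[1]):

    import jax.numpy as jnp

    def summarize_chunk(query, key, value, precision=None):
        # Dot products of query features with key features, per head:
        # [queries, heads, features] x [keys, heads, features]
        #   -> [queries, heads, keys]
        attn_weights = jnp.einsum('qhf,khf->qhk', query, key,
                                  precision=precision)

        # m_i: maximum over the keys for each query and head, with a
        # trailing dummy dimension so it broadcasts in the subtraction.
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)

        # exp(s_i - m_i): the numerically stabilised exponentials.
        exp_weights = jnp.exp(attn_weights - max_score)

        # Line 17: sum over keys of v_i * exp(s_i - m_i)
        #   -> [queries, heads, features]
        exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights,
                                precision=precision)

        # Lines 18-19: the chunk summary --
        #   exp_values               unnormalised weighted values ("v*")
        #   exp_weights.sum(axis=-1) softmax-denominator contribution, [queries, heads]
        #   max_score, reshaped      m_i with the dummy dimension dropped, [queries, heads]
        num_heads = query.shape[1]
        return (exp_values, exp_weights.sum(axis=-1),
                max_score.reshape((query.shape[0], num_heads)))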

Now let's go back to where this function was called, and look at the
use of these return values.

