[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 05:00:33 PST 2022


> 13:    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)

attn_weights is a [queries, heads, keys] tensor. For each head, it holds
the dot product of every query's feature vector with every key's feature
vector; the shared feature axis d is summed away.
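To check the shape claim, here is a small standalone sketch. The sizes and
random inputs are made up for illustration, and the `precision` argument
from the quoted line is left out, so JAX's default is used. It compares the
einsum against an explicit per-head Q_h @ K_h^T.

```python
import jax.numpy as jnp
from jax import random

# Made-up toy sizes: 3 queries, 2 keys, 4 heads, 5 features per head.
num_q, num_k, num_heads, d = 3, 2, 4, 5
kq, kk = random.split(random.PRNGKey(0))
query = random.normal(kq, (num_q, num_heads, d))
key = random.normal(kk, (num_k, num_heads, d))

# Same contraction as the quoted line 13 (minus precision=): the feature
# axis d is summed over, the query, head and key axes are kept.
attn_weights = jnp.einsum('qhd,khd->qhk', query, key)

# Explicit equivalent: for each head h, Q_h @ K_h^T, stacked on axis 1.
explicit = jnp.stack(
    [query[:, h, :] @ key[:, h, :].T for h in range(num_heads)], axis=1)

print(attn_weights.shape)  # (3, 4, 2)
```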

> 14:    max_score = jnp.max(attn_weights, axis = -1, keepdims = True)
> 15:    max_score = jax.lax.stop_gradient(max_score)

> -1 is the last axis. This appears to make max_score be a tensor of
> shape [key, head, 1] containing the maximum pre-softmax attention
> weight for each query, over the entire chunk of keys.

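I think the axis names in that quote are wrong. max_score has shape
[queries, heads, 1]: jnp.max reduces over the last (keys) axis, and
keepdims=True keeps that axis with size 1 instead of dropping it. It does
contain the maximum pre-softmax attention weight for each query and head,
taken over the current chunk of keys.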
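A toy run of the quoted lines 14-15 with made-up scores. It shows the
[queries, heads, 1] shape, one maximum per query, and that
jax.lax.stop_gradient keeps the value while making autodiff treat it as a
constant.

```python
import jax
import jax.numpy as jnp

# Made-up scores with shape [queries=2, heads=1, keys=3].
attn_weights = jnp.array([[[1.0, 5.0, 2.0]],
                          [[-3.0, -1.0, -2.0]]])

# Line 14: max over the last (keys) axis; keepdims=True keeps it as size 1.
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
# Line 15: same values, but treated as a constant when differentiating.
max_score = jax.lax.stop_gradient(max_score)

print(max_score.shape)     # (2, 1, 1), i.e. [queries, heads, 1]
print(max_score[:, 0, 0])  # one maximum per query: 5 and -1

# Differentiating through stop_gradient gives a zero gradient.
g = jax.grad(lambda s: jax.lax.stop_gradient(jnp.max(s)))(attn_weights)
```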

> This might be the calculation m_i = max(m, s_i), with m initialised to
> -inf?
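On that question, within one chunk: folding m_i = max(m, s_i) over the
chunk's keys, starting from m = -inf, ends at the same value jnp.max
produces for the chunk in one step. A sketch of the fold with
jax.lax.scan, using made-up scores s for a single query and head. This
only shows the within-chunk case.

```python
import jax
import jax.numpy as jnp

# Made-up scores s_i for one query and one head over a chunk of 4 keys.
s = jnp.array([0.5, 2.0, -1.0, 1.5])

def step(m, s_i):
    # m_i = max(m_{i-1}, s_i); carry the new max and also record it.
    m_i = jnp.maximum(m, s_i)
    return m_i, m_i

# Start from m = -inf, in the same dtype as s so the carry type is stable.
m0 = jnp.array(-jnp.inf, dtype=s.dtype)
m_final, running = jax.lax.scan(step, m0, s)

print(running)              # running maxima: 0.5, 2.0, 2.0, 2.0
print(m_final, jnp.max(s))  # both 2.0
```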

> 16:    exp_weights = jnp.exp(attn_weights - max_score)

Here we can see the e^(s_i - m_i) term: exp_weights is the exponential of
the attention weights after subtracting each query's maximum. max_score's
size-1 key dimension is broadcast across all the keys in attn_weights, so
exp_weights is a tensor of the same shape as attn_weights:
[queries, heads, keys]. Since every exponent is at most 0, no weight
exceeds 1 and exp cannot overflow.
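A sketch of the quoted line 16 with made-up scores. It shows the
broadcast, that each query's largest weight becomes exactly 1, that
dividing by the per-query sum matches jax.nn.softmax over these keys, and
why subtracting the maximum matters in float32.

```python
import jax
import jax.numpy as jnp

# Made-up scores, shape [queries=2, heads=1, keys=3]. The second query's
# values are large enough that exp() alone overflows float32.
attn_weights = jnp.array([[[1.0, 5.0, 2.0]],
                          [[1000.0, 1001.0, 999.0]]])

max_score = jnp.max(attn_weights, axis=-1, keepdims=True)  # [2, 1, 1]
# The size-1 key axis of max_score broadcasts across all 3 keys.
exp_weights = jnp.exp(attn_weights - max_score)            # [2, 1, 3]

print(exp_weights.shape)         # (2, 1, 3)
print(exp_weights.max(axis=-1))  # each query's largest weight is exp(0) = 1

# Dividing by the per-query sum gives an ordinary softmax over these keys.
probs = exp_weights / exp_weights.sum(axis=-1, keepdims=True)

# Without the max subtraction, exp(1000.) overflows to inf in float32.
print(jnp.isinf(jnp.exp(attn_weights[1])).any())  # True
```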


More information about the cypherpunks mailing list