[ot][spam][crazy][data] transformer model 'attention' improvement
k
gmkarl at gmail.com
Tue Jan 25 05:00:33 PST 2022
> 13: attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
attn_weights is a [queries, heads, keys] tensor holding the dot
products between the query and key feature vectors, one per
query/key pair within each head.
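As a sanity check on that shape claim, here is a minimal sketch of line
13 with toy dimensions (the sizes and random inputs are my own; only the
einsum spec comes from the quoted code):

```python
import jax
import jax.numpy as jnp

q_len, k_len, heads, dim = 4, 6, 2, 8
query = jax.random.normal(jax.random.PRNGKey(0), (q_len, heads, dim))
key = jax.random.normal(jax.random.PRNGKey(1), (k_len, heads, dim))

# 'qhd,khd->qhk': contract over the feature axis d, keep heads aligned
attn_weights = jnp.einsum('qhd,khd->qhk', query, key)
print(attn_weights.shape)  # (4, 2, 6) == (queries, heads, keys)
```

Each entry attn_weights[q, h, k] is exactly the dot product of
query[q, h, :] with key[k, h, :].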
> 14: max_score = jnp.max(attn_weights, axis = -1, keepdims = True)
> 15: max_score = jax.lax.stop_gradient(max_score)
> -1 is the last axis. This appears to make max_score be a tensor of
> shape [key, head, 1] containing the maximum pre-softmax attention
> weight for each query, over the entire chunk of keys.
I think the axis names in the quoted text are wrong: max_score is a
tensor of shape [queries, heads, 1], not [key, head, 1]. It does,
however, contain the maximum pre-softmax attention weight for each
query, taken over that query's chunk of keys.
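A small shape check of lines 14-15, with the quoted calls applied to a
toy [queries, heads, keys] tensor of my own choosing:

```python
import jax
import jax.numpy as jnp

# [queries=2, heads=3, keys=4], values chosen only to be distinct
attn_weights = jnp.arange(24.0).reshape(2, 3, 4)

max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
max_score = jax.lax.stop_gradient(max_score)
print(max_score.shape)  # (2, 3, 1): one max per (query, head)
```

stop_gradient leaves the values unchanged; it only blocks gradients from
flowing through the max when the whole expression is differentiated.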
> This might be the calculation m_i = max(m, s_i), with m initialised to
> -inf?
> 16: exp_weights = jnp.exp(attn_weights - max_score)
Here we can see s_i - m_i: exp_weights is the attention weights,
shifted down by their per-query maximum and then exponentiated.
max_score's one-sized last dimension is 'broadcast' across the keys
dimension of attn_weights, so the same maximum is subtracted from
every key's score for that query. exp_weights is then a tensor of the
same shape as attn_weights: [queries, heads, keys].
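A sketch of the broadcast subtraction in line 16 (the concrete scores
are mine, purely illustrative): after the shift, each query's largest
score becomes 0, so its exponential is exactly 1, which is the point of
the max-subtraction trick for numerical stability.

```python
import jax.numpy as jnp

# toy [queries=1, heads=2, keys=3] scores
attn_weights = jnp.array([[[1.0, 3.0, 2.0],
                           [0.0, -1.0, 4.0]]])

max_score = jnp.max(attn_weights, axis=-1, keepdims=True)  # (1, 2, 1)
exp_weights = jnp.exp(attn_weights - max_score)  # broadcast over keys
print(exp_weights.shape)  # (1, 2, 3), same as attn_weights
```

Every value in exp_weights lies in (0, 1], so the later sum of
exp_weights cannot overflow the way exp of the raw scores could.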
More information about the cypherpunks
mailing list