[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 04:07:30 PST 2022


14:    max_score = jnp.max(attn_weights, axis = -1, keepdims = True)

-1 is the last axis. This appears to make max_score a tensor of
shape [query, head, 1] containing the maximum pre-softmax attention
weight for each query, over the entire chunk of keys.
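
To sanity-check the shape reasoning, here's a tiny standalone
example (the sizes are made up, not from the source):

    import jax.numpy as jnp

    # hypothetical pre-softmax weights for one chunk: [query=2, head=3, key=4]
    attn_weights = jnp.arange(24.0).reshape(2, 3, 4)
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    print(max_score.shape)  # (2, 3, 1): one max per (query, head), over keys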

This might be the running-max calculation m_i = max(m, s_i), with m
initialised to -inf? I think the description of these lines posted
earlier says more about this.
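
If that guess is right, the running max on its own would behave like
this (the per-chunk maxima here are invented for illustration):

    import jax.numpy as jnp

    m = -jnp.inf                 # m initialised to -inf
    for s_i in [1.5, 4.0, 2.2]:  # hypothetical per-chunk maxima
        m = jnp.maximum(m, s_i)  # m_i = max(m, s_i)
    print(m)                     # 4.0, the overall max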

15:    max_score = jax.lax.stop_gradient(max_score)

I think this informs jax that calculations on the returned max_score
do not impact backpropagation gradients, so their gradients don't
need to be calculated and stored.
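
A minimal demonstration of stop_gradient on its own, separate from
the attention code:

    import jax

    def f(x):
        c = jax.lax.stop_gradient(x)  # c is treated as a constant by autodiff
        return c * x

    # d/dx (c * x) = c rather than 2x, since no gradient flows through c
    print(jax.grad(f)(3.0))  # 3.0, not 6.0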

16:    exp_weights = jnp.exp(attn_weights - max_score)

This must be exp(s_i - m_i) in the update
v = v * exp(m - m_i) + v_i * exp(s_i - m_i).
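
Putting the pieces together, my reading of that recurrence as a
standalone sketch (the function and variable names are mine, not from
the source; v_i here is a chunk's exp-weighted values, computed
relative to its own max s_i):

    import jax.numpy as jnp

    def combine(m, v, s_i, v_i):
        # fold one chunk's summary (s_i, v_i) into the running pair (m, v);
        # m_new here is the m_i of the formula above
        m_new = jnp.maximum(m, s_i)
        v_new = v * jnp.exp(m - m_new) + v_i * jnp.exp(s_i - m_new)
        return m_new, v_new

    # with m = -inf and v = 0, the first chunk passes through unchanged
    m, v = -jnp.inf, jnp.zeros(2)
    m, v = combine(m, v, 1.0, jnp.array([0.3, 0.7]))
    print(m, v)  # 1.0 [0.3 0.7]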

I typed most of line 17 and then my finger hit 'delete' by accident,
so I'm sending this to preserve what I have.

