[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Sat Jan 22 00:58:13 PST 2022


i've glanced through this paper before but don't really remember it
well.  i've got it open again, here are notes:

- 'attention' maps a query q and a sequence of keys k_i and values v_i
to a weighted sum of the value vectors
- the weights are a softmax over the query/key scores, roughly ('s'
being an intermediate value; i could have made an error writing this):

s_i = dot(q, k_i)
s_i = e^s_i / sum_j(e^s_j)
output = sum_i(v_i * s_i)

if that's right, i think it could be written this way:

s_i = e^dot(q, k_i)
output = sum_i(v_i * s_i) / sum_i(s_i)

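i think the key observation of the paper is that the division by the
normalizing sum can be pulled outside the weighted sum over the values,
so both the numerator and the denominator become plain sums.  not sure.
it adds more than just that, though.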
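To check that rearrangement numerically, here is a small plain-Python sketch (the function names are my own, nothing from the paper) computing single-query attention both ways: normalizing the softmax weights first, and dividing by the sum only at the end. Lists stand in for vectors.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_softmax_first(q, keys, values):
    """Normalize the weights first, then take the weighted sum of values."""
    exps = [math.exp(dot(q, k)) for k in keys]
    total = sum(exps)
    weights = [e / total for e in exps]  # softmax over the scores
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

def attention_divide_last(q, keys, values):
    """Same result, with the normalizing sum pulled out of the weighted sum."""
    exps = [math.exp(dot(q, k)) for k in keys]
    dim = len(values[0])
    numer = [sum(e * v[d] for e, v in zip(exps, values)) for d in range(dim)]
    return [x / sum(exps) for x in numer]
```

For q = [1.0, 0.0], keys [[1.0, 0.0], [0.0, 1.0]] and values [[1.0, 2.0], [3.0, 4.0]], both return about [1.538, 2.538], i.e. ((e + 3)/(e + 1), (2e + 4)/(e + 1)).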

- this is their first example demonstration:

s_i = dot(q, k_i)
s_i = e^s_i
output = sum_i(v_i * s_i) / sum_i(s_i)

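- it looks like their solution is to compute the output sequentially.
since the division now happens only once, at the end, the two sums
sum_i(v_i * s_i) and sum_i(s_i) can be accumulated one pair at a time,
and each k_i and v_i can be discarded as soon as it has been processed.
for a single query, the memory needed no longer grows with the
sequence length.
- the paper states that 'plain attention' has a single query, whereas
'self attention' has one query for each element of the sequence (each
k_i, v_i pair).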
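A minimal sketch of that one-pass loop, assuming plain Python lists as vectors (sequential_attention and self_attention are hypothetical names of mine). It keeps only the two running sums, and self-attention is just the same loop run once per query. There is no numerical stabilization here yet, so large scores would overflow math.exp.

```python
import math

def sequential_attention(q, keys, values):
    """One-pass attention for a single query (no overflow protection).

    Accumulates the numerator sum_i v_i * e^s_i and the denominator
    sum_i e^s_i, touching one (k_i, v_i) pair at a time.
    """
    dim = len(values[0])
    v_acc = [0.0] * dim  # running sum of v_i * e^s_i
    s_acc = 0.0          # running sum of e^s_i
    for k, v in zip(keys, values):
        w = math.exp(sum(a * b for a, b in zip(q, k)))  # e^dot(q, k_i)
        v_acc = [acc + w * x for acc, x in zip(v_acc, v)]
        s_acc += w
    # the single division happens only at the very end
    return [x / s_acc for x in v_acc]

def self_attention(queries, keys, values):
    """Self-attention: one query per sequence element, each handled independently."""
    return [sequential_attention(q, keys, values) for q in queries]
```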

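- Warning! the e^s_i / sum_j(e^s_j) operation is usually made
numerically stable by subtracting the largest score from every s_i
before exponentiating, so that e^s_i cannot overflow
- the paper handles this by keeping a running maximum of the scores and
renormalising the accumulated sums as it goes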
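To make the warning concrete, a small sketch (function names are mine): since e^(s_i - m) / sum_j e^(s_j - m) equals e^s_i / sum_j e^s_j for any constant m, choosing m = max(s) leaves the result unchanged while keeping every exponent at or below 0. In Python, math.exp(1000.0) raises OverflowError, so naive_softmax fails on scores that large, while stable_softmax([1000.0, 1000.0]) returns [0.5, 0.5].

```python
import math

def naive_softmax(scores):
    exps = [math.exp(s) for s in scores]  # overflows for large scores
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(scores):
    # Subtract the max score first: every exponent is then <= 0,
    # so math.exp never overflows, and the shift cancels in the ratio.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```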

I'm just going to type the paper's explanation in:

We initialize the vector v and scalar s with 0, and m with -inf. As
before, given key value pair k_i, v_i, we compute s_i = dot(q, k_i),
but then the algorithm differs slightly from Section 2. We first
compute m_i = max(m, s_i) and update

  v = v * e^(m - m_i) + v_i * e^(s_i - m_i)
  s = s * e^(m - m_i) + e^(s_i - m_i)
  m = m_i

After processing all keys and queries, we divide v / s to get final
results.
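Here is my attempt at the quoted update rules as plain Python, a sketch rather than the paper's code (streaming_attention is a made-up name, lists stand in for vectors). It relies on math.exp(-inf) being 0.0, so on the first pair the empty old sums are simply dropped.

```python
import math

def streaming_attention(q, keys, values):
    """Single-query attention in one pass with a running maximum.

    v and s are kept scaled by e^(-m), where m is the largest score seen
    so far, and rescaled by e^(m - m_i) whenever a larger score arrives.
    """
    dim = len(values[0])
    v = [0.0] * dim
    s = 0.0
    m = -math.inf
    for k_i, v_i in zip(keys, values):
        s_i = sum(a * b for a, b in zip(q, k_i))  # s_i = dot(q, k_i)
        m_i = max(m, s_i)
        scale_old = math.exp(m - m_i)    # e^(m - m_i); 0.0 on the first pair
        scale_new = math.exp(s_i - m_i)  # e^(s_i - m_i), always <= 1
        v = [x * scale_old + y * scale_new for x, y in zip(v, v_i)]
        s = s * scale_old + scale_new
        m = m_i
    return [x / s for x in v]
```

With scores as large as 1000, where a plain e^s_i would overflow, it still returns the weighted average: streaming_attention([1000.0], [[1.0], [1.0]], [[2.0], [4.0]]) gives [3.0].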

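All the funny math notation is kind of strange to me, but i infer they
keep a running maximum m of the scores and store both sums scaled by
e^-m: whenever a larger score arrives, the old sums are multiplied by
e^(m - m_i) to re-base them on the new maximum, so every exponent is
<= 0 and nothing overflows.  The single division v / s at the end
cancels the scaling.  There's a code implementation too on github that
is likely clearer, and some code on the next page or two of the paper.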
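One way to check that reading is to write down what the running values hold after each step. This is my own notation, not the paper's: a superscript (i) means "after processing pair i", s without a subscript is the running sum and s_i the score, as in the quote, and m_0 = -inf.

```latex
\begin{aligned}
m_i &= \max(m_{i-1}, s_i), \qquad m_0 = -\infty \\
v^{(i)} &= \sum_{j \le i} v_j\, e^{s_j - m_i}, \qquad
s^{(i)} = \sum_{j \le i} e^{s_j - m_i} \\
v^{(i-1)} e^{m_{i-1} - m_i}
  &= \sum_{j < i} v_j\, e^{s_j - m_{i-1}}\, e^{m_{i-1} - m_i}
   = \sum_{j < i} v_j\, e^{s_j - m_i} \\
\frac{v^{(n)}}{s^{(n)}}
  &= \frac{e^{-m_n} \sum_j v_j\, e^{s_j}}{e^{-m_n} \sum_j e^{s_j}}
   = \sum_j v_j\, \frac{e^{s_j}}{\sum_k e^{s_k}}
\end{aligned}
```

The third line is why multiplying by e^(m - m_i) works: it moves the old terms onto the new maximum, and adding v_i e^(s_i - m_i) completes v^(i). Since m_i >= s_j for every j <= i, no exponent is ever positive, and the last line shows the final division gives exactly the softmax-weighted sum.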

- Since the approach accumulates sequentially, one key/value pair at a
time, it doesn't parallelize well on a gpu.  The paper proposes
processing the keys and values in chunks for gpu/tpu devices (page 2).
They describe their algorithm in terms of their example code (page 3).
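The chunking idea, as i understand it, can be sketched like this: split the keys and values into chunks, summarize each chunk independently (its own max score plus its own rescaled sums, the same quantities as in the running-max update), then merge the summaries by rescaling each to the overall max. This is a plain-Python sketch of mine, not the paper's code; chunk_summary, chunked_attention and chunk_size are hypothetical names, and a real version would use array operations on the accelerator.

```python
import math

def chunk_summary(q, keys, values):
    """Summarize one chunk as (weighted value sum, weight sum, max score),
    with both sums taken relative to the chunk's own max score."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    v = [sum(wi * vi[d] for wi, vi in zip(w, values)) for d in range(dim)]
    return v, sum(w), m

def chunked_attention(q, keys, values, chunk_size=2):
    # Each chunk is summarized independently; this is the part that could
    # run in parallel on a gpu/tpu. Here it is just a list comprehension.
    summaries = [
        chunk_summary(q, keys[i:i + chunk_size], values[i:i + chunk_size])
        for i in range(0, len(keys), chunk_size)
    ]
    # Merge: rescale every chunk's sums to the global max score.
    m_all = max(m for _, _, m in summaries)
    dim = len(values[0])
    v_total = [0.0] * dim
    s_total = 0.0
    for v, s, m in summaries:
        scale = math.exp(m - m_all)
        v_total = [t + scale * x for t, x in zip(v_total, v)]
        s_total += scale * s
    return [x / s_total for x in v_total]
```

Peak memory per query then depends on chunk_size rather than on the whole sequence, while the work inside each chunk stays parallel.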


More information about the cypherpunks mailing list