[ot][spam][crazy][data] transformer model 'attention' improvement
k
gmkarl at gmail.com
Sat Jan 22 00:58:13 PST 2022
i've glanced through this paper before but don't really remember it
well. i've got it open again, here are notes:
- 'attention' is a transformation between a query q, vectors of key
and value k_n and v_n, to a linear transformation of the value vector
- the linear transformation is roughly ('s' being an intermediate
value, i could have made an error writing this):
s_i = dot(q, k_i)
s_i = e^s_i / sum_j(e^s_j)
output = sum_i(v_i * s_i)
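as a sanity check on my reading, here's a tiny pure-python sketch of
that transformation for a single query (the function name and the
list-of-lists vector representation are mine, not the paper's):

```python
import math

def attention(q, keys, values):
    # one score per key: s_i = dot(q, k_i)
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    # softmax: s_i = e^s_i / sum_j(e^s_j)
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # output = sum_i(v_i * s_i), a weighted average of the value vectors
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
```

so the output always lies inside the convex hull of the value vectors,
since the weights are positive and sum to 1.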
if that's right, i think it could be written this way:
s_i = e^dot(q, k_i)
output = sum_i(v_i * s_i) / sum_j(s_j)
i think the revelation of the paper is that the division by the
normalizing sum can be pulled outside the weighted sum and done once
at the end. not sure. it adds more than just that, though.
- this is their first example demonstration:
s_i = dot(q, k_i)
s_i = e^s_i
output = dot(v, s) / sum(s)
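to convince myself that dividing last really gives the same answer,
here's a quick comparison of the two orderings (function names are
mine):

```python
import math

def attention_normalized(q, keys, values):
    # normalize the weights first, then take the weighted sum
    exps = [math.exp(sum(a * b for a, b in zip(q, k))) for k in keys]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[d] for e, v in zip(exps, values)) for d in range(dim)]

def attention_divide_last(q, keys, values):
    # accumulate unnormalized weighted values; divide by sum(s) once at the end
    exps = [math.exp(sum(a * b for a, b in zip(q, k))) for k in keys]
    dim = len(values[0])
    acc = [sum(e * v[d] for e, v in zip(exps, values)) for d in range(dim)]
    total = sum(exps)
    return [a / total for a in acc]
```

both compute the same thing; the second form is what makes the
sequential trick below possible.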
- it looks like their solution is to compute the output sequentially.
since the normalizing sum is only needed at the very end, each k_i and
v_i can be discarded as soon as it is processed, so memory use stays
constant in the sequence length.
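that sequential reading can be sketched like this — the pairs come
from an iterator, so each k_i, v_i can be dropped after its turn (my
sketch of the idea, not the paper's actual code):

```python
import math

def attention_streaming(q, kv_pairs):
    # kv_pairs yields (k_i, v_i) one at a time; we keep only a running
    # weighted-value accumulator and the running sum of weights
    acc = None
    total = 0.0
    for k, v in kv_pairs:
        w = math.exp(sum(a * b for a, b in zip(q, k)))
        if acc is None:
            acc = [w * x for x in v]
        else:
            acc = [a + w * x for a, x in zip(acc, v)]
        total += w
    return [a / total for a in acc]
```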
- the paper states that 'plain attention' has a constant query,
whereas 'self attention' has a query for each element of v and k.
- Warning! the e^s_i / sum(e^s_i) operation is usually stabilized by
subtracting the maximum score before exponentiating, so the exponents
stay small and e^s_i can't overflow
- the paper handles this by performing running normalisation on e^s_i
I'm just going to type that explanation in:
We initialize the vector v and scalar s with 0, and m with -inf. As
before, given key value pair k_i, v_i, we compute s_i = dot(q, k_i),
but then the algorithm differs slightly from Section 2. We first
compute m_i = max(m, s_i) and update v = v * e^(m - m_i) + v_i *
e^(s_i - m_i) and s = s * e^(m - m_i) + e^(s_i - m_i) and m = m_i.
After processing all keys and queries, we divide v / s to get final
results.
All the funny math notation is kind of strange to me, but i infer they
basically found a way to keep the exponents small: subtract a running
maximum before exponentiating, and rescale the accumulators whenever
that maximum changes, so nothing overflows before the final division.
There's a code implementation too on github that is likely clearer,
and some code on the next page or two of the paper.
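the quoted procedure, as i understand it, in the same sketch style as
before (my own transcription, so possibly buggy in ways the paper's
github code is not):

```python
import math

def attention_streaming_stable(q, kv_pairs):
    # running accumulator v, running sum s, running max m = -inf
    acc = None
    s = 0.0
    m = float('-inf')
    for k, v in kv_pairs:
        score = sum(a * b for a, b in zip(q, k))  # s_i = dot(q, k_i)
        m_new = max(m, score)                     # m_i = max(m, s_i)
        scale = math.exp(m - m_new)               # rescales everything accumulated so far
        w = math.exp(score - m_new)               # e^(s_i - m_i), never overflows
        if acc is None:
            acc = [w * x for x in v]
        else:
            acc = [a * scale + w * x for a, x in zip(acc, v)]
        s = s * scale + w
        m = m_new
    return [a / s for a in acc]                   # divide v / s at the end
```

with scores around 800, a plain math.exp would overflow, but here the
largest exponent ever taken is 0.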
- Since the approach involves sequential accumulation, it doesn't map
well onto parallel gpu hardware as-is. The paper proposes chunking for
gpu/tpu devices (page 2). They describe their algorithm in terms of
their example code (page 3)
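i'd guess the chunked version works roughly like this: compute a
max-stabilized partial (accumulator, sum, max) per chunk — those are
independent, so chunks can run in parallel — then merge the partials.
this is my own sketch of the idea, not the paper's jax code:

```python
import math

def chunk_partial(q, keys, values):
    # independent per-chunk work: stabilized accumulator, weight sum, chunk max
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    ws = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    acc = [sum(w * v[d] for w, v in zip(ws, values)) for d in range(dim)]
    return acc, sum(ws), m

def merge(p1, p2):
    # rescale both partials to a common max, then add
    acc1, s1, m1 = p1
    acc2, s2, m2 = p2
    m = max(m1, m2)
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
    acc = [a * c1 + b * c2 for a, b in zip(acc1, acc2)]
    return acc, s1 * c1 + s2 * c2, m

def attention_chunked(q, chunks):
    # chunks is a list of (keys, values) pairs
    partials = [chunk_partial(q, ks, vs) for ks, vs in chunks]
    total = partials[0]
    for p in partials[1:]:
        total = merge(total, p)
    acc, s, _ = total
    return [a / s for a in acc]
```

since merge is associative, the chunk results can be combined in any
order, which is what makes this friendly to parallel devices.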
More information about the cypherpunks
mailing list