[ot][spam][crazy][data] transformer model 'attention' improvement
i made two spamjournal threads regarding 'data science'-y things: automated reverse engineering, and perceiver model notes. in the automated reverse engineering one i linked a paper: https://arxiv.org/abs/2112.05682 . this paper clearly describes an optimization to these models that seems like it should have been obvious: an algebraic transformation of the 'attention' mechanism that requires some understanding of how the operators are implemented, and drops the memory requirements by orders of magnitude. implementing this might help people with fewer resources than governments and large corporations train models like the automated reverse engineering one. i _think_ it basically means you can process a bigger batch size, a bigger model, or longer input/output text on a smaller gpu or tpu (or in cpu ram). i'd like to try to implement it in both perceiver and t5.
i've glanced through this paper before but don't really remember it well. i've got it open again, here are notes:

- 'attention' is a transformation from a query q and vectors of keys and values k_i and v_i, to a linear combination of the value vectors.
- the combination is roughly ('s' being an intermediate value, i could have made an error writing this):
  s_i = dot(q, k_i)
  s'_i = e^s_i / sum_j(e^s_j)
  output = sum_i(v_i * s'_i)
  if that's right, i think it could be written this way:
  s_i = e^dot(q, k_i)
  output = sum_i(v_i * s_i) / sum_i(s_i)
  i think the revelation of the paper is that the normalising sum can be moved outside the weighted sum over values and applied at the end. not sure. it adds more than just that, though. (i sketch this form below.)
- this is their first example demonstration:
  s_i = dot(q, k_i)
  s_i = e^s_i
  output = dot(v, s) / sum(s)
- it looks like their solution is to compute the output sequentially. since nothing depends on the sum once it is moved outside, each k_i and v_i can be discarded as it is processed.
- the paper states that 'plain attention' has a constant query, whereas 'self attention' has a query for each element of v and k.
- Warning! the e^s_i / sum_j(e^s_j) operation is usually stabilised by subtracting the maximum exponent, so errors don't propagate.
- the paper handles this by performing a running normalisation on e^s_i. i'm just going to type that explanation in: "We initialize the vector v and scalar s with 0, and m with -inf. As before, given a key value pair k_i, v_i, we compute s_i = dot(q, k_i), but then the algorithm differs slightly from Section 2. We first compute m_i = max(m, s_i) and update v = v * e^(m - m_i) + v_i * e^(s_i - m_i) and s = s * e^(m - m_i) + e^(s_i - m_i) and m = m_i. After processing all keys and queries, we divide v / s to get the final result." All the funny math notation is kind of strange to me, but i infer they basically found a way to avoid dividing after exponentiation by shuffling appropriate subtractions in before it. There's a code implementation on github that is likely clearer, and some code on the next page or two of the paper.
- since the approach involves sequential accumulation, it doesn't map straight onto gpu parallelism. the paper proposes chunking for gpu/tpu devices (page 2). they describe their algorithm in terms of their example code (page 3).
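to check my reading of those middle bullets, here's a tiny sketch of the single-query form with the division deferred to the end. it's my own restatement in jax.numpy, not the paper's code, and it skips the max-subtraction so it's numerically fragile:

import jax.numpy as jnp

def single_query_attention_deferred(q, keys, values):
    # q: [features], keys: [n, features], values: [n, v_features]
    # accumulate e^dot(q, k_i) * v_i and e^dot(q, k_i) one pair at a time,
    # dividing by the running sum only once at the end.
    acc_v = jnp.zeros(values.shape[-1])
    acc_s = 0.0
    for k_i, v_i in zip(keys, values):
        w = jnp.exp(jnp.dot(q, k_i))
        acc_v = acc_v + w * v_i
        acc_s = acc_s + w
    return acc_v / acc_s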
- their example gpu code is based around an attention() function on line 42 that takes the query, key, and value as function parameters, as well as a chunk size.
- this engages the concept of 'heads'. i _think_ a 'head' is basically a chunk of the input data, already, not sure.
- their attention() function breaks the query into chunks of the passed size, each chunk associated with all keys and all values, and passes each one to _query_chunk_attention() ...
here's the pdf of this paper: https://ipfs.io/ipfs/bafkreifaqccxlq5hzka6677tpjyq2ngfybjn7fwzgddzogjc3t465k... and here it is run through pdf2txt, as a txt file: https://ipfs.io/ipfs/bafkreicu42wf6r53vjdh4ujk2r6fhdlmnsk7xbt6yr2vyxppg2wysw...
I'll review the definition of self-attention from the paper.

s_i = dot(q, k_i)                   # take the dot product of the query with every key
s'_i = exp(s_i) / sum(exp(s_j), j)  # take the softmax of the result
attn = sum(v_i * s'_i, i)           # attention is the dot of that softmax with the value vectors

This is done in parallel. Why does it take O(n^2) memory?
- the paper says the calculation of s_i is O(n) for a single query, but for _self attention_, as used in transformer models, there is a different query for each sequence position: hence n parallel calculations of attention (the naive form is sketched below).
So the bulk of this simple paper is converting from O(n) memory for a single query to O(1) for a single query, by simply unparallelising the calculation. Okay, I was missing that; I might review this email content again to try to comprehend what's up.
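For contrast, here's how I picture the ordinary parallel form, where the [n, n] score matrix is what costs the memory. My own sketch, single head, no scaling, not the paper's code:

import jax.numpy as jnp

def naive_self_attention(q, k, v):
    # q, k: [n, d]; v: [n, v_features]
    scores = q @ k.T                                        # [n, n] -- the O(n^2) part
    scores = scores - scores.max(axis=-1, keepdims=True)    # the usual softmax stabilisation
    weights = jnp.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v                                      # [n, v_features]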
Here's the explanation of the first chunk of their example code:

  "In the outer loop (line 56f), we split the queries into chunks of constant size, resulting in a linear number of iterations. In each iteration of the outer loop, we call _query_chunk_attention, which itself processes the keys and values in chunks (lines 23 to 33)."

42:def attention(query, key, value, precision=jax.lax.Precision.HIGHEST,
43:              query_chunk_size=1024):
44:  """Memory-efficient multi-head dot product attention."""
45:  num_q, num_heads, q_features = query.shape
46:
47:  def chunk_scanner(chunk_idx, _):
48:    query_chunk = lax.dynamic_slice(
49:      query, (chunk_idx, 0, 0),
50:      slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features))
51:    return (chunk_idx + query_chunk_size,
52:            _query_chunk_attention(query_chunk, key, value, precision=precision))
53:
54:  _, res = jax.lax.scan(
55:    chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:  return res.reshape(num_q, num_heads, value.shape[-1])

I think they miswrote the reference to line 56, unsure. jax.lax.scan is like folding: it calls chunk_scanner repeatedly, passing its first output back in, and stacking its second output into the result. This is from help(jax.lax.scan):
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
So basically line 54 calculates num_q / query_chunk_size, calls chunk_scanner that many times, and stacks the results into one tensor. Stacking combines many tensors into a single one with an additional dimension, each adjacent to the next. A tensor is an n-dimensional matrix.

Looking at chunk_scanner, by inspection I can see that chunk_idx is the offset into the queries at which to start slicing data out. jax.lax.dynamic_slice appears to be a simple slice operation: the first indices are the lower-indexed corner of the slice, and the second set of indices are the size of the slice. I'm noting the dimensions of the query are num_q, num_heads, and q_features. So we can infer that the data has already been split into "heads" prior to this call. It also looks like the individual query data is 1-dimensional here. The slice operation can then be interpreted as simply chunking the queries into query_chunk_size'd pieces, while keeping the keys and values whole and unchunked. Each query chunk is then passed to _query_chunk_attention. The use of jax.lax and dynamic_slice rather than ordinary indexing is likely because the chunk offset is a traced value inside jax's compilation, where normal Python slicing wouldn't work, although I don't really know.
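As a sanity check of what dynamic_slice does here (my own toy example, not from the paper):

import jax
import jax.numpy as jnp

x = jnp.arange(24).reshape(6, 2, 2)   # pretend [num_q=6, heads=2, features=2]
chunk = jax.lax.dynamic_slice(x, (2, 0, 0), slice_sizes=(3, 2, 2))
# chunk equals x[2:5, :, :]: a 3-long chunk of the first axis,
# with the other axes kept whole.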
And here's their explanation of the inner loops:

  "In each iteration of the outer loop, we call _query_chunk_attention, which itself processes the keys and values in chunks (lines 23 to 33). The chunks are processed sequentially and each chunk is summarized independently (lines 14 to 21). Assuming a chunk size of sqrt(n) for the keys and values, we hence obtain sqrt(n) summaries, giving rise to the O(sqrt(n)) memory complexity. After the summaries are computed, they need to be rescaled (lines 35 to 38) along the lines of Section 3, before we return the values divided by the weights (line 42). The result of each iteration of the outer loop is directly written to the output tensor res (line 56), so that no additional memory is consumed across iterations. (A multi-stage summarization approach could achieve O(log n) but would complicate the implementation.) [... In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.]"

01:import functools, jax, math
02:from jax import numpy as jnp
03:
04:def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):
05:  """Multi-head dot product attention with a limited number of queries."""
06:  num_kv, num_heads, k_features = key.shape
07:  v_features = value.shape[-1]
08:  key_chunk_size = min(key_chunk_size, num_kv)
09:  query = query / jnp.sqrt(k_features)
10:
11:  @functools.partial(jax.checkpoint, prevent_cse=False)
12:  def summarize_chunk(query, key, value):
13:    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
14:    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
15:    max_score = jax.lax.stop_gradient(max_score)
16:    exp_weights = jnp.exp(attn_weights - max_score)
17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)
18:    return (exp_values, exp_weights.sum(axis=-1),
19:            max_score.reshape((query.shape[0], num_heads)))
20:
21:  def chunk_scanner(chunk_idx):
22:    key_chunk = jax.lax.dynamic_slice(
23:      key, (chunk_idx, 0, 0),
24:      slice_sizes=(key_chunk_size, num_heads, k_features))
25:    value_chunk = jax.lax.dynamic_slice(
26:      value, (chunk_idx, 0, 0),
27:      slice_sizes=(key_chunk_size, num_heads, v_features))
28:    return summarize_chunk(query, key_chunk, value_chunk)
29:
30:  chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
32:
33:  global_max = jnp.max(chunk_max, axis=0, keepdims=True)
34:  max_diffs = jnp.exp(chunk_max - global_max)
35:  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
36:  chunk_weights *= max_diffs
37:
38:  all_values = chunk_values.sum(axis=0)
39:  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
40:  return all_values / all_weights
41:

It looks like the line numbers got offset a little bit in printing.
Drill down, dive in. Selection appears to have stopped working for me, so it's all typing the text over; happens sometimes.

04:def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):

As earlier, the query here is a chunk of queries, a subset of all of them.

05:  """Multi-head dot product attention with a limited number of queries."""
06:  num_kv, num_heads, k_features = key.shape
07:  v_features = value.shape[-1]

# i commented here on the meaning of the dimensions and then my body and browser spazzed and marked the thread as spam. when i recovered the thread the comments were gone, sending to preserve.

11:  @functools.partial(jax.checkpoint, prevent_cse=False)

This decorates the function to be passed through jax.checkpoint for further transformation after definition. It's pretty likely that this informs jax's backpropagation code to abandon its intermediate data, not caching it in memory, but rather marking it to be recalculated when needed. This is mentioned elsewhere in the paper.
About my life: I have nothing going on right now. Most of my equipment has broken in some way. I live a life full of spasmodic muscle contractions and sudden dissociated cognitive changes that leaves me with a lot of suffering. I love projects that I can continue on without much suffering, even if they are simple, mundane. Historically, I've liked working on very complex stuff, and always looked for something new that had never been done before, to add my cool-looking part to. [more information exists]
08:  key_chunk_size = min(key_chunk_size, num_kv)

It's the first dimension of the keys and values that will be split.

09:  query = query / jnp.sqrt(k_features)

# i typed a lot of comments on lines but they disappeared again. i plan to return to line 09 above because i'm not sure why it's there. i skipped the inner functions to start with, and am working on copying over lines 30 and 31. sending to preserve.
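a guess at line 09 while i'm here: i believe this is just the standard scaled dot-product trick from the original transformer work, rather than something specific to this paper. if the d query and key features are roughly independent with zero mean and unit variance, then

Var(dot(q, k_i)) = Var(sum_j q_j * k_ij) = sum_j Var(q_j * k_ij) ≈ d

so dividing by sqrt(d) (here sqrt(k_features)) keeps the scores near unit scale and the softmax away from saturation. dividing the query once up front is the same as dividing every dot product.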
30:  chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

from help(jax.lax.map):

def map(f, xs):
  return np.stack([f(x) for x in xs])

so it just passes each element of xs into chunk_scanner, and again stacks the results.
21:  def chunk_scanner(chunk_idx):

chunk_idx comes from the elements of jnp.arange() on line 31, which generates a vector of ascending integers starting at 0.
22:    key_chunk = jax.lax.dynamic_slice(
23:      key, (chunk_idx, 0, 0),
24:      slice_sizes=(key_chunk_size, num_heads, k_features))

Note also on line 31 that a step size is passed to jnp.arange(), so chunk_idx takes offsets separated by key_chunk_size. Lines 22-24 break key into key_chunk_size'd chunks along the first dimension. The second and third dimensions (heads and features) remain whole.

25:    value_chunk = jax.lax.dynamic_slice(
26:      value, (chunk_idx, 0, 0),
27:      slice_sizes=(key_chunk_size, num_heads, v_features))

Lines 25-27 do the same thing for value_chunk, value, and v_features, breaking the value tensor into key_chunk_size'd chunks along the first dimension.

28:    return summarize_chunk(query, key_chunk, value_chunk)

summarize_chunk must do the actual attention calculation for the chunk. I think this code would be clearer if query were called query_chunk, which is what is passed in here.
11:  @functools.partial(jax.checkpoint, prevent_cse=False)

I think checkpointing relates to limiting memory used by gradient backpropagation during training of a model. I think it means the gradients through this function can be recalculated when needed, by storing its arguments instead of every intermediate value.
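A tiny sketch of how I understand jax.checkpoint (my own toy example, not from the paper):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.checkpoint, prevent_cse=False)
def f(x):
    # the intermediate jnp.exp(x) is not kept around for the backward pass;
    # it is recomputed when the gradient of f is taken.
    return jnp.sum(jnp.exp(x) * x)

grad_f = jax.grad(f)
print(grad_f(jnp.ones(4)))  # same result as without the decorator, less stored memory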
12:  def summarize_chunk(query, key, value):
13:    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)

An einsum is a way of doing matrix multiplication for n-dimensional matrices, by specifying which axes of the tensors are dot'd with which other axes during the calculation. Any letters can be used in the string, so long as the developer uses them consistently. Here they have picked 'q', 'k', 'h', and 'd' to represent 'query', 'key', 'head', and maybe 'data'. The multiplication says to treat the 'd' dimension as the inner, contracted part, producing an output that is dimensioned by the query, head, and key axes. It's confusing to me.
From the paper, this would be the first step of attention: s_i = dot(q, k_i). Since there are multiple queries here, this is done for every query, rather than just one. The 'd' dimension, the feature count, must be the dimension along which the dot product is taken, implying that queries and keys have the same size in their final dimension. The other dimensions, heads, queries, and keys, must simply be batching: doing this dot product many times.
Ok, let's see. Each query is dotted with each key, for each head, into a [queries, heads, keys] tensor of scalar dot products: pre-softmax attention weights. This is done with einsum('qhd,khd->qhk').
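To convince myself of that, here's a tiny check of the einsum against explicit loops, on my own toy data (not from the paper):

import jax.numpy as jnp

q = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)   # [queries, heads, d]
k = jnp.arange(5 * 3 * 4, dtype=jnp.float32).reshape(5, 3, 4)   # [keys, heads, d]

a = jnp.einsum('qhd,khd->qhk', q, k)
# same thing with explicit loops: dot each query with each key, per head
b = jnp.stack([
    jnp.stack([
        jnp.stack([jnp.dot(q[qi, h], k[ki, h]) for ki in range(5)])
        for h in range(3)])
    for qi in range(2)])
assert jnp.allclose(a, b)   # both are [queries, heads, keys]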
These next bits starting at line 14 (and i'm trying to remember there's a /sqrt(count) line i mean to return to) must be part of the strategy to iteratively calculate the precise softmax (exp(i) / sum(exp(i))) by doing subtraction in the exponent rather than division outside it. Here's text from Section 3 of the paper:

  "In practice, the softmax is implemented by subtracting the maximum score from all scores. This does not change the result of the softmax, but avoids this numerical problem. Our incremental computation of the sum of exponentiated scores (and the values times the scores) does not immediately allow for the same trick, as the maximum may depend on the last score in the sequence. But the subtraction cannot be delayed either, since the scores must be exponentiated before they can be added to the cumulative sum. To resolve this problem, we introduce an additional scalar, which keeps track of the maximum score that the incremental algorithm has seen so far, and we renormalize the sums of exponentiated values as needed: We initialize the vector v and scalar s with 0, and m with -inf. As before, given a key value pair k_i, v_i, we compute s_i = dot(q, k_i), but then the algorithm differs slightly from Section 2. We first compute m_i = max(m, s_i), and update v = v * exp(m - m_i) + v_i * exp(s_i - m_i) and s = s * exp(m - m_i) + exp(s_i - m_i) and m = m_i. After processing all keys and queries, we divide v / s to get the final result."

- I want to drill down into their equations more, but it would make sense to use the variable names from the code example starting at line 14.
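Here's that Section 3 procedure typed out as a sketch for a single query, using the paper's variable names v, s, m; this is my own transcription, not their code:

import jax.numpy as jnp

def stable_streaming_attention(q, keys, values):
    # q: [d]; keys: [n, d]; values: [n, v_features]
    v = jnp.zeros(values.shape[-1])
    s = 0.0
    m = -jnp.inf
    for k_i, v_i in zip(keys, values):
        s_i = jnp.dot(q, k_i)
        m_i = jnp.maximum(m, s_i)
        # rescale the running sums so everything stays relative to the new max
        v = v * jnp.exp(m - m_i) + v_i * jnp.exp(s_i - m_i)
        s = s * jnp.exp(m - m_i) + jnp.exp(s_i - m_i)
        m = m_i
    return v / s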
14:    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)

-1 is the last axis. This appears to make max_score be a tensor of shape [key, head, 1] containing the maximum pre-softmax attention weight for each query, over the entire chunk of keys. This might be the calculation m_i = max(m, s_i), with m initialised to -inf? I think it says more about this in the description of the lines posted earlier.

15:    max_score = jax.lax.stop_gradient(max_score)

I think this informs jax that calculations on the returned max_score do not impact backpropagation gradients and don't need their gradients calculated and stored.

16:    exp_weights = jnp.exp(attn_weights - max_score)

This must be exp(s_i - m_i) in the initialisation of v = v * exp(m - m_i) + v_i * exp(s_i - m_i). I typed most of line 17 and then my finger hit 'delete' by accident, so I'm sending to preserve.
The below is a draft, but I see I propagated some dimensions wrongly and plan to rewrite it, and maybe some previous posts, to help with clarity.

17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

Another einsum, multiplying the values by the weights. [managed to continue after writing this :) -> It might take me a bit to engage this einsum. Hope to return quickly and just combine the available parts. For example, I could copy a definition of einsum in, or compare the line to the equation from the paper, or look at the paper's description of this chunk of code.]

This must be the rest of v = v * exp(m - m_i) + v_i * exp(s_i - m_i), specifically the "v_i *" part, since exp_weights is exp(s_i - m_i) already. Since m is initialised to -inf, the left term of the sum is zero and disregarded.

In jnp.einsum('vhf,qhv->qhf', value, exp_weights), we can see how this multiplication is performed: exp_weights' dimensions are described as query, heads, value, but its dimensions will have propagated from max_score in the last email: [queries, heads, 1]. It looks like I unfortunately miswrote 'query' as 'key' in the dimensions in that email, as can be verified from the einsum that attn_weights came from. So exp_weights is [query, heads, 1], the exponentiation of the difference from a max among keys, and value is [value, heads, features]. I'm guessing the einsum says to broadcast each exp_weight to all the values, not certain. If I need to discern it more I can run test data or write the parts of the definition down and combine them textually. exp_values becomes a [query, head, feature] tensor containing the multiplication of the values with the exponentiated attn weights.

18:    return (exp_values, exp_weights.sum(axis=-1),
19:            max_score.reshape((query.shape[0], num_heads)))

The output is exp_values and exp_weights. exp_weights is the exponentiation of the scores minus their max score, for each query.
13: attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
attn_weights is a [queries, heads, keys] tensor consisting of the dot product between the query and key features.
14:    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
15:    max_score = jax.lax.stop_gradient(max_score)
-1 is the last axis. This appears to make max_score be a tensor of shape [key, head, 1] containing the maximum pre-softmax attention weight for each query, over the entire chunk of keys.
I'm thinking the axis names here are wrong. max_score is a tensor of shape [queries, heads, 1], but it does contain the maximum pre-softmax attention weight for each query, over that query's chunk of keys.
This might be the calculation m_i = max(m, s_i), with m initialised to -inf?
16: exp_weights = jnp.exp(attn_weights - max_score)
Here we can see s_i - m_i: exp_weights is the exponentiated attention weights reduced by their per-query maximum. max_score's one-sized dimension is 'broadcast' to all the other keys in attn_weights. exp_weights is then a tensor of the same shape as attn_weights: [queries, heads, keys].
17:    exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)

I think the values are still [values, heads, features], whereas the exp_weights are [queries, heads, keys]. I think the values are multiplied by the per-key weights, producing a vector of weighted features for each query and each head: a dot product aligning the key and value dimensions. This must be "v_i * exp(s_i - m_i)" in "v = v * exp(m - m_i) + v_i * exp(s_i - m_i)" from section 3, numerical stability, in the paper.

18:    return (exp_values, exp_weights.sum(axis=-1),
19:            max_score.reshape((query.shape[0], num_heads)))

A tuple of three values is returned from summarize_chunk.
- it looks like the first element, exp_values, is the attention vector combining the values with the keys and query, prior to the division by the denominator of the softmax: the "v" accumulator in the numerical stability portion of the paper.
- the second component, exp_weights.sum(axis=-1), appears to be the sum of the exponentiated weights along the key dimension, one sum for each query and each head, of shape [queries, heads]. this is likely for recombining into the denominator of the softmax.
- the third component calls reshape on max_score. I think max_score was the maximum taken across the keys of the unexponentiated attention weights, i.e. the dot products of the query features and the key features. It looks like the reshape just removes the final dummy dimension that was there for broadcasting the subtraction.

Now let's go back to where this function was called, and look at the use of these return values.
30:  chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

chunk_values is exp_values, which I think was the local values dotted with the exponentiated attention weights minus their local max. It looks like chunk_weights is those exponentiated weights, summed. And it looks like chunk_max are those local maximum values.

33:  global_max = jnp.max(chunk_max, axis=0, keepdims=True)

[struggling some to continue. it looks like these tensors get recombined into one each by jax.lax.map. chunk_max has its max taken along axis 0, which is likely the axis that jax.lax.map adds when stacking. So this would extend the existing per-chunk maxima to find the maxima across all the split keys and values.]

34:  max_diffs = jnp.exp(chunk_max - global_max)

This likely calculates the scale needed for each chunk that would change it to be relative to the global max rather than its local max. A scale, because it's inside an exponent.

35:  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)

jnp.expand_dims appears to simply wrap every value in max_diffs in a new one-sized dimension, maybe for the multiplication to broadcast across the feature dimension of the values. So line 35 multiplies all the output values by the calculated scale. This looks like it turns exp(s_i - m_i) into exp(s_i - m_i + m_i - global_max), i.e. exp(s_i - global_max).

36:  chunk_weights *= max_diffs

This likely performs the same operation for the summed weights, which have not been dotted with the value vectors.

38:  all_values = chunk_values.sum(axis=0)

Here we finally form the summation of the combined values across the key and value chunks. They were already dotted across the keys dimension within each chunk, so this likely combines the values of adjacent chunks as if one large dot product had been taken.

39:  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

This appears to perform the same for the weights, which were summed over the keys of each chunk to form the softmax denominator. Additionally a one-sized dimension is added to wrap each value.

40:  return all_values / all_weights

Here is where the softmax operation must finally complete. all_values and all_weights have been arithmetically shifted inside exp() operations, to be relative to their maxima. When the division is performed, the shift is analogous to a scaling value applied to both the numerator and denominator, so the result is the same, but with much higher precision due to less extreme values. I think! Whoohoo!

Let's quickly glance at where that value comes out after _query_chunk_attention returns it. On line 52, it's returned through chunk_scanner, and then must be stacked with other query chunks back on line 55:

54:  _, res = jax.lax.scan(
55:    chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:  return res.reshape(num_q, num_heads, value.shape[-1])

The final call to reshape likely squashes that stack into a single [queries, heads, features] output tensor.

Maybe rather than connecting more verbiage from the paper it would make sense to try to run this now, and compare an implementation with and without chunking, to verify it's correct.
This is the transcription of the code I'm first trying with. I haven't tested it yet. I need to generate some data with appropriate dimensions, and I'm somewhat new to jax.numpy.
I tweaked chunked_attention.py so it would work for me. I propagated the key chunking parameter, which the original code gives no way to set, and specified lax's namespace as jax where it was missing, I think on line 14.
import chunked_tweaks as chunked
import jax.numpy as jnp

queries, keys, values = jnp.mgrid[:64, :8, :16]
chunked.attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)
This outputs a [64, 8, 16] array of vectors each ascending from 0 through 15. OK, now I'll try writing conventional attention as described in the paper on my own, and see if the output matches. Then maybe I'll see if I can check with random data.
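Roughly the unchunked form I mean to write, for anyone following along; my own sketch, untested as of this email:

import jax.numpy as jnp

def plain_attention(query, key, value):
    # query: [q, heads, d]; key: [kv, heads, d]; value: [kv, heads, v_features]
    query = query / jnp.sqrt(query.shape[-1])
    scores = jnp.einsum('qhd,khd->qhk', query, key)        # [q, heads, kv]
    scores = scores - scores.max(axis=-1, keepdims=True)   # usual softmax stabilisation
    weights = jnp.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return jnp.einsum('qhk,khf->qhf', weights, value)      # [q, heads, v_features]

If it's right, jnp.allclose(plain_attention(queries, keys, values), chunked.attention(queries, keys, values, query_chunk_size=4, key_chunk_size=4)) should hold.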
I've got my local code working with the mgrid data. I had made at least two bugs: an incorrect einsum, and dotting with the weights rather than the exponent. The mgrid data doesn't test the softmax, since each vector has the same maximum. Time to figure out how to make random tensors in jax.
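For the record, random tensors in jax are made with an explicit key; this is standard jax.random usage as I understand it:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
qkey, kkey, vkey = jax.random.split(key, 3)
queries = jax.random.normal(qkey, (64, 8, 16))
keys    = jax.random.normal(kkey, (64, 8, 16))
values  = jax.random.normal(vkey, (64, 8, 16))
# unlike the mgrid data, these give every vector a different maximum,
# so the softmax and max-subtraction paths actually get exercised.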
karl: if you come here looking for the code to continue debugging it, maybe try writing it again :) since it's so hard to remember without repetition
The reason my barebones attention got a different answer than the paper's chunked attention was that I hadn't included the division by the square root of the feature count, which I had intended to return to but had not. With it included, the outputs are the same (still unsure why that scaling is there); the script is attached. Next I'm comparing the output of huggingface's PerceiverSelfAttention class with my script and the chunked attention. Its output is different, maybe due to an additional post-processing step? It also includes the square root denominator.
The first issue I have working with PerceiverSelfAttention is sorting out huggingface's permutations of the query, key, and value matrices. The dot products aren't producing the same weights, indicating I'm not providing the data in the right shape. They reorganise the matrices to handle multiple channels, and split into heads a certain way. I have trouble intuiting the relation between torch.matmul and einsum, as regards a matrix of dot products of feature vectors.
So basically a matmul is an einsum that contracts the last axis of the first operand with the second-to-last axis of the second operand. Matrix multiplications really are sequences of dot products! Linear algebra is slowly and painfully coming back to me. Attached is a transcription of huggingface's perceiver attention that works with the same example data. The 'keys/queries/values' axis ends up being the sequence axis. They permute the matrices to move the heads dimension into the leading batch dimensions, so the dot products can be done with a normal matmul rather than einsum and its string parsing.
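Writing the relation down so I don't lose it; the same identity holds for torch.matmul, sketched here in jax.numpy to keep one library across these notes:

import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
b = jnp.arange(2 * 4 * 5, dtype=jnp.float32).reshape(2, 4, 5)
# matmul contracts the last axis of a with the second-to-last axis of b,
# batching over the leading axes:
assert jnp.allclose(jnp.matmul(a, b), jnp.einsum('...ij,...jk->...ik', a, b))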
While making these small tests I discovered that the paper's specific jax implementation is very fast: noticeably faster than my example that didn't use jax.lax, on my raspberry pi cpu. Huggingface's perceiver model implementation is not in jax; rather, it's in torch. There are implementations of this memory efficient attention paper in pytorch, at https://github.com/moskomule/memory_efficient_attention.pytorch and https://github.com/AminRezaei0x443/memory-efficient-attention . I'd like to try them both out on the same example data [might only try one out though].
If the goal of a realtime free open source cellphone app, a general app to comprehend the meaning of information, is met outside us, I'd want to add a goal of rational contextual logic and inference here, so that systems can be built that understand things and act on them with good ideas that include everyone, rather than just copying big attributes and stuff.
I'm making slow progress on running the pytorch implementations of the paper through the same test. Attached is the start of a wrapper. The two linked implementations use the same package name for now, so the plan is to put both wrappers in the same file and check whether the package has one implementation's function, to tell which one is installed.
I drafted a wrapper for moskomule's implementation; it looks like I still have something wrong.
The moskomule attention implementation matches for me if I perform the square root scaling myself outside it. The project does not appear to install properly; the test and the library file just run side by side in the same folder.
the AminRezaei0x443 implementation also produces the same data, attached again. the aminrezaei implementation does the square root, provides for optional mask and bias tensors, is on pypi, and has both a jax and a torch implementation, so it seems the way to go. next i'll be timing it compared to the paper's implementation that i noted as speedy. just on my raspberry pi, though. i'm guessing it's roughly the same on good hardware with large models, where the core batches dominate everything. sometimes i mostly engage stuff i bump into. maybe it would be good just to quickly run through the source and verify that aminrezaei does checkpointing and lax mapping like in the paper.
chunked_lib actually fails on aminrezaei; i still had the moskomule script adjacent to it when running it. should be resolvable.
well, it works now. the aminrezaei implementation prints out each chunk index as it scans it.
timing script: aminrezaei comes out a little faster than the paper's here for me, not a serious test. i went into aminrezaei's source and removed the per-chunk print statement, which may have influenced speed.
hm, on a second run the other was faster, by about the same amount, so maybe they're comparable. anyway, next step is to add aminrezaei to huggingface's perceiver implementation. maybe in a base class.
i found the 'reformer' at https://huggingface.co/transformers/v2.9.1/model_doc/reformer.html . this architecture appears to be about 2 years old. it says it is comparable to a transformer, but much more memory efficient, and uses a different kind of 'attention'.
the current mainstream model for very long sequences appears to be bigbird, and there is a pretrained model for long document summarization: https://github.com/google-research/bigbird https://huggingface.co/docs/transformers/model_doc/bigbird on to perceiver. i'm thinking of actually adding a configuration directive to huggingface for using efficient attention, and opening a pull request if there isn't one already, to see what they say.
- https://github.com/xloem/transformers/commit/3f2b78f787dd1d204d842f26cdc026c... Added configuration parameters and import for memory efficient attention. - There is an existing setup for feedforward chunking. I'll likely try to copy any reasonable patterns it establishes.
- https://github.com/xloem/transformers/commit/7575b8286dd5c2b328d3c34d9b66dab... A draft of calling memory_efficient_attention from the perceiver model, when configuration parameters are set. - Untested. Maybe I can copy google's example again, like before, somehow, and run the same test with the configuration settings set, and walk through it to make sure it uses the new code.
commit 437a050dea6595361b563d9c68d62617ed7dc59d (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sat Jan 29 01:10:20 2022 +0000

    set gptj up to compare attentions; resolved some small bugs; started a test script
github/xloem/transformers repo, using the return_weights branch of the github/xloem/memory-efficient-attention repo

commit e72d4a3536d1799fc40a85d83a7999e8a39563fc (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sat Jan 29 07:58:00 2022 +0000

    more bugfixes. output differs, makes sense to step through in parallel again and compare.
apparently i missent this log the first time.

the two main outputs are the same now, but it looks like i implemented the new 'output_attentions' feature wrong. there's a likelihood (hard for me to tell so far) that it should be the _post_-softmax weights, not the _pre_-softmax weights as i said in the public issue i opened to move things forward responsibly. thinking about the public error (and engaging issues interacting with my system, like the loss of the first send of this email) can stimulate my spasms, which prolongs the public presentation of the possibly-false information :rolls_eyes:.

commit 788efe5c9a99cc4b432cc215d0dbb1175632d73a (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 09:42:20 2022 +0000

    typo fix resolves exception; also missing line in temporary debugging code. looks like return_attentions is returning the wrong thing.
Updated https://github.com/AminRezaei0x443/memory-efficient-attention/issues/1 for the new commit:

  I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support `output_attentions` yet. This is a flag passed in transformers, where the attention weights are preserved and returned to the user if it is set.

  I looked a little at implementing this in the torch backend, and I note the scan() function provides for only a single tensor return value. It seems to me that scan() would be most clearly replaced by a for loop, but it could also be modified to handle tuples, or return_weights could be handled via accessing nonlocal data in some way instead of returning them through the chunk scanner. I'm also not sure how the output would best be passed to the user.

  Edit: Draft implementation 01/28 at https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main..... . I ended up extending the scan function for parity between implementations.

  Edit 2: Turns out it's the post-softmax attention weights, not the pre-softmax attention weights. I've updated this post and the draft implementation for this output: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main.....
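For reference, the kind of generalisation I mean is roughly this; purely an illustrative sketch of the idea following scan's reference semantics, not the library's actual code, and the name scan_tuples is mine:

import jax.numpy as jnp

def scan_tuples(f, init, xs, length=None):
    # like the reference loop for scan, but if f's second output is a tuple,
    # stack each element of the tuple separately instead of requiring a
    # single tensor per step.
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    if isinstance(ys[0], tuple):
        return carry, tuple(jnp.stack(items) for items in zip(*ys))
    return carry, jnp.stack(ys)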
Here's the commit hash. In my test, all the correct attention weights are 1.0. So maybe I'll run the huge pretrained model through my mempickle project that lets it run on low-end systems, using this new code i'm writing, and verify that all the weights are output correctly before opening a pull request.

commit 84724e1de4721ea0333d6bdbb91e8bce74fbeac2 (HEAD -> return_weights, origin/return_weights)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 10:14:37 2022 +0000

    return the post-softmax weights rather than the pre-softmax weights
Rather than waiting for GPT-J to download, I just provided enough input to see more attentions, which worked fine. A further crash needs further debugging.
11:18 UTC Looks like I got a mask shape wrong. [reinstalling the package with the change; reinstalling is what's working for me atm for changing things.]

11:19 UTC Running the test without debugging to quickly scan for crashes. [this step is unnecessary]

[we saw a movie called 'jolt' and karl liked it a lot. he watched/liked it because of the idea of jolting one's body to stop [spasms]. he considers making a jolt device sometimes. there were also unliked things. the movie also overlaps government mind control, although karl says it could have been done better. there's a time in the movie where [psychiatrist-like-character] says [i need to tell you something, wait]; it is apparently quite important to get such information across, on both sides. you need to wait and quickly hear, and you need to run after them to tell them. lots of other stuff.]
11:22 no crashes, so pdb'ing in to test the outputs, since i never had the code compare them. [due to multitasking inhibition, likely/possibly a branch of efficiency inhibition.]
11:24 now the attentions match :) now i'll probably open a pull request with memory-efficient-attention. plan to describe it as a draft looking for comments.

commit 0e48ea180a50d201f107c8c69f9493db78cf813b (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 11:26:31 2022 +0000

    fix for incorrect mask shape
https://github.com/AminRezaei0x443/memory-efficient-attention/pull/2

Provide a flag for the user to receive attention weights

This is my draft code for #1. I saw this feature in the transformers library and wanted to implement it here. I'm curious what you think about this feature and implementation.

The code is simply slightly instrumented so that the final attention weights can be returned to the user. Tests are augmented to test this use. In utils, the `scan` function is expanded to handle tuples. A change to `dynamic_slice` crept in from dev, to use slices rather than index_slice. I've retained this change because it looks like it would execute faster to me, but it can be removed.

Rebased and squashed from 84724e1de4721ea0333d6bdbb91e8bce74fbeac .
Meanwhile, in the transformers library, I've compartmentalised calls to Amin Rezaei's code into the modeling_utils source file, so I can do most of the work whether or not the pull request is accepted. If it's not replied to in a timely manner, I can copy the MIT-licensed implementation (which is a direct copy from the paper) into huggingface's repository later.
I've realised the return_attentions addition I attempted to make to memory-efficient-attention may have actually completely countered the memory savings of the research paper, by allocating a matrix sized by queries x keys for the entirety of the execution. If true, then my pull request could be confusing and harmful to the developer. I should re-review the paper to understand how much memory is saved, and whether or not my feature is appropriate in the algorithm. If not, it would simply be disabled in the transformers library if chunking is used.
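rough numbers to make the concern concrete (my own back-of-envelope, not from the paper): returning the full weights means materialising a [queries, keys, heads] tensor for the whole call. at a sequence length of 16384 with 16 heads in fp32 that is 16384 * 16384 * 16 * 4 bytes = 16 GiB, which dwarfs the O(sqrt(n)) working memory the chunked algorithm was designed to keep.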