Re: [ot][spam][crazy][data] transformer model 'attention' improvement
"The straight-forward implementation of the attention operation above requires us to first compute and remember s_i for all i, leading to a O(n) time and memory complexity for each query. Transformers use _self-attention_, which issues a separate query for each position in the sequence, so the overall time and space complexity is O(n^2)." s_i is the dot product between the keys and a single query. So "s" in the paper would be "attn_weights" in the source. I think. It's looking likely that my contribution was in error, but it's seeming very hard for me to cognitively verify this.
[my touchscreen (no mouse atm) also just stopped responding, and i'm still using x for email]
[I rebooted x but it didn't come back up, and i'm sending this from a phone.] The next step for me here is to verify logically that implementing the feature completely counters the memory savings. Given the rest of these threads it seems likely I can do this, but even if I can't, it makes sense to assume the conclusion is true and move forward: close the issues and pull request (they could be replaced by one noting the reason the "feature" is absent) and change transformers to disable attention output when chunking is engaged. Commented on the PR:
I've converted this to a draft because I think this "feature" may be ridiculous, making the memory usage O(n^2) again by retaining weights across chunks
I'm pretty sure the "n" in time * memory = O(n^2) relates to the key/query count, which are the same in self-attention. The batch dimension is used by the user to set memory bounds.
[and the phrase "self-attention" when applied to an exhaustive "cross-product" of information flow parts is reminiscent of the forgotten automation norm of having the system attend to its own effectiveness. missing in transformer models.] Lemme try that paragraph again: "The straight-forward implementation of the attention operation above requires us to first compute and remember s_i for all i, leading to a O(n) time and memory complexity for each query. Transformers use _self-attention_, which issues a separate query for each position in the sequence, so the overall time and space complexity is O(n^2)." s_i = dot(q_i,k) s_i' = softmax(s_i) attention = sum_i(v_i * s_i) so originally three ndvectors come in: q, k, v. They q and k have a dimension of length n. To find the dot of each q_i with k, an ndmatrix s is formed that has a dimension of both q and k. so it's the q,k dimension of s (attn_weights in the source) that makes it O(n^2). I think I can move that information over to the source to show myself that my addition does indeed prevent the memory savings.
in memory-efficient-attention/attention_torch.py, around line 15, attn_weights has dimensions q,h,k, so it's holding the chunked portion that takes the O(n^2) down to O(n). it's biased and masked and maxed, and then turned into exp_weights, which has the same size but different values. exp_weights is multiplied with values to make exp_values, dotting the value dimension from values with the key dimension from exp_weights, giving an output of dimensions q,h,f, which is O(n). in the original implementation, exp_weights and attn_weights are then discarded, leaving the memory at O(n). my addition of saving all the exp_weights or attn_weights does indeed prevent the improvement. whoo! i need to figure out how to retain and continue to act on my skills and knowledge, pretty badly.
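Not the repository's actual code, just a sketch of the same idea in torch so I can see the memory effect: keys/values are visited one chunk at a time with a running max and denominator, and the only way to hand the attention weights back to the caller is to retain every chunk, which rebuilds exactly the O(n^2) tensor the chunking avoids:

import torch

def chunked_attention(q, k, v, key_chunk=64, return_weights=False):
    # q: (n_q, d); k, v: (n_k, d). keys/values are visited chunk by chunk
    # with an online softmax, so the live weights are O(n_q * key_chunk)
    # instead of O(n_q * n_k).
    n_q, n_k = q.shape[0], k.shape[0]
    acc = torch.zeros(n_q, v.shape[-1], dtype=q.dtype)
    denom = torch.zeros(n_q, 1, dtype=q.dtype)
    running_max = torch.full((n_q, 1), -float('inf'), dtype=q.dtype)
    saved = []
    for start in range(0, n_k, key_chunk):
        s = q @ k[start:start + key_chunk].T            # (n_q, <=key_chunk)
        new_max = torch.maximum(running_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)
        e = torch.exp(s - new_max)
        acc = acc * correction + e @ v[start:start + key_chunk]
        denom = denom * correction + e.sum(dim=-1, keepdim=True)
        running_max = new_max
        if return_weights:
            saved.append(s)   # keeping every chunk's scores rebuilds an O(n_q * n_k) tensor
    out = acc / denom
    if return_weights:
        # exactly the allocation the chunking was meant to avoid
        return out, torch.softmax(torch.cat(saved, dim=-1), dim=-1)
    return out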
closed my issue and pr with: I see that for this to have any meaning it would need to be implemented as either a callback or as retaining only a user-provided range, and I don't imagine use cases have been provided for that yet. Apologies for the noise and thank you for so much great work.
next: transformers again
commit 1aa9022c6cd96a93e3c83801ce4fddb5b650ba45 (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 21:41:34 2022 +0000

    draft of chunking without returning attentions. currently fails due to passing wrong dimensions or data.
commit 1556018635522a442745afdc2094b4e58df24670 (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 22:09:54 2022 +0000

    unexpected crashes resolved for gptj torch, output unreviewed
doing some work on getting the current state of transformers code i have working again with my test perceiver model that converts numbers. these are my notes on the correct dimension shapes, for the code that ostensibly worked. i plan to compare these with the broken commit below them (see the sanity check sketch after the commit).

correct context layer shape is 1,256,8,20
attention_probs.shape = 1,8,256,96
values.shape = 1,8,96,20
queries.shape = 1,8,256,32
keys.shape = 1,8,96,32

commit ca60cd579c82191d4e6696534af32e96b850015e (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Sun Jan 30 23:17:41 2022 +0000

    commented out old perceiver code and drafted a call of the new attentions function that does both chunked and nonchunked. currently crashes due to dimension error.
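A quick sanity check I can run against those shapes, assuming the standard scaled-dot-product layout (batch, heads, seq, feature) and ignoring the 1/sqrt(d) scale:

import torch

queries = torch.randn(1, 8, 256, 32)
keys    = torch.randn(1, 8,  96, 32)
values  = torch.randn(1, 8,  96, 20)

attention_probs = torch.softmax(torch.einsum('bhqd,bhkd->bhqk', queries, keys), dim=-1)
assert attention_probs.shape == (1, 8, 256, 96)

context = torch.einsum('bhqk,bhkf->bhqf', attention_probs, values)
assert context.shape == (1, 8, 256, 20)

# the "correct context layer shape" 1,256,8,20 is the same data with heads moved
assert context.permute(0, 2, 1, 3).shape == (1, 256, 8, 20)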
I am suffering psychologically too much associated with this, and need time to heal. I have opened a pull request at https://github.com/huggingface/transformers/pull/15425 with my current work. It should be possible to finetune the cutting-edge GPTJ-6B model more effectively on lower end systems with this change. The chunking configuration parameters need to be set. I have not verified that the GPTJ output is correct. The model is large for my raspberry pi, even when used from disk.
Amin Rezaei commented on my work on their github, and pointed out that the paper advises the technique is only useful on models with incredibly large input sizes, not any of the ones I added it to. Briefly thinking about that, it could come down to the size of the O(n^2) data. For example, GPT-J has a total model size of around 22GB, and is trained to predict tokens from text up to 2k tokens long. 2k^2 is about 4M floats, which is much smaller than the total model size. However, when training a model, a larger algebra graph can be allocated for each float, in order to calculate and use the gradients for backpropagation. Running a test to see what the memory requirements really are makes sense here, since a usable implementation is readily available now. Or at least finding and comprehending the text in the paper where it states the input sizes at which the savings matter. It shows how far out in left field I am. But it was also a great opportunity to work on and make something near these powerful things.
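A rough back-of-the-envelope for that intuition, assuming GPT-J-6B's published shape (28 layers, 16 heads, 2048-token context) and fp32 attention weights retained per layer for the backward pass; the real figure depends on dtype, batch size, and gradient checkpointing:

# attention-weight memory if every layer's (heads x n x n) weights stay alive for backprop
n_layers, n_heads, n_ctx, bytes_per_float = 28, 16, 2048, 4
per_layer = n_heads * n_ctx * n_ctx * bytes_per_float    # 268435456 bytes = 256 MiB
total = n_layers * per_layer                             # ~7 GiB
print(per_layer / 2**20, 'MiB per layer,', total / 2**30, 'GiB total')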
i'm working on the extant issue below atm. also huggingface replied to the PR i made when i was losin' it, and mentioned two other efficient attention implementations; they looked approximation-based. also they said their repo is specifically anti-DRY, which is not something anybody expects to hear. there's at least one fork of it though.

commit 172ae5d668bec9180516e2238f195b56d11a9799 (HEAD -> memory-efficient-attention, xloem/memory-efficient-attention)
Author: xloem <0xloem@gmail.com>
Date:   Tue Feb 1 20:47:43 2022 +0000

    removed cruft and added memory-efficient-attention dependency. a remaining issue exists where masks and biases still allocate O(n^2) memory.
re the masks and biases, basically the chunking code assumes they are dense matrices, but by changing the chunking code you can pass only the data needed. i'm presently doing that. it may end up that the optimization is not reasonable on models that store a dense mask or bias as an on-disk weight.
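Roughly what I mean by passing only the data needed, sketched as a hypothetical helper (not the library's actual API): hand each chunk only the slice of the mask or bias it needs, and leave any broadcast (size-1) dimensions alone so they never get densified:

def mask_chunk(mask, q_start, q_size, k_start, k_size):
    # mask: (..., n_q or 1, n_k or 1). return only the slice this chunk
    # needs; size-1 dimensions stay size-1 and keep broadcasting.
    q_sl = slice(0, 1) if mask.shape[-2] == 1 else slice(q_start, q_start + q_size)
    k_sl = slice(0, 1) if mask.shape[-1] == 1 else slice(k_start, k_start + k_size)
    return mask[..., q_sl, k_sl]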
oops! it may be that when masks and biases are expanded to be dense no additional memory is actually allocated. uhhh !
arright! this is hard! maybe we can expand some test tensor to be really big and see if memory allocation changes
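For torch, something like this (a minimal check, not tied to any model):

import torch

base = torch.zeros(1, 1024)                 # 1024 floats of real storage
big = base.expand(1024, 1024)               # looks like 1024x1024

print(big.shape)                            # torch.Size([1024, 1024])
print(big.data_ptr() == base.data_ptr())    # True: same storage, no copy
print(big.storage().size())                 # 1024: expand is a view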
- torch tensors do not allocate new memory when expanded, and are documented as views with that property

next: jax tensors, and model logic
so far the ways to make expanded or repeated jax tensors have all made copies for me. a test froze up this system for hours ;p be well human race
Good morning, spamthread. I commented on the PR. I believe the DRY concern relates to researchers being able to quickly review implementations without having to switch files, unsure. Here's the PR log:

2 days ago, xloem:

# What does this PR do?

This begins the implementation of a central `attention()` function in modeling_utils.py that calls out to https://github.com/AminRezaei0x443/memory-efficient-attention if configuration parameters are set, to allocate configurably down to O(n) memory rather than O(n^2) memory at the expense of parallel execution. I'm afraid the new memory-efficient-attention still needs to be added as a dependency and some development cruft removed from the source. I believe it is important to reuse existing projects, so that people's work can be more effective and valued, but I also believe memory-efficient-attention code is MIT licensed if copying it is preferred. The GPTJ and Perceiver models are altered to call out to the new attention function. Working on this has been very hard for me, so I am contributing what I have now. If others have better work on this, feel free to accept them before mine.

- [ ] I have commented out the rest of the PR form here, to return to as I find capacity.

--

2 days ago, LysandreJik:
Hey @xloem, thanks a lot for your hard work on this! It is cool to support the attention mechanism as visible in https://github.com/AminRezaei0x443/memory-efficient-attention. However, the `transformers` library does not really work with central components to be shared among many models, so we do not design layers in the `modeling_utils.py` file.
This comes from the following two pieces of the "Why should I/shouldn't I use Transformers" from the README:
4. Easily customize a model or an example to your needs:
   * We provide examples for each architecture to reproduce the results published by its original authors.
   * **Model internals are exposed as consistently as possible.**
   * **Model files can be used independently of the library for quick experiments.**
and
This library is not a modular toolbox of building blocks for neural nets. The code in the model files is not refactored with additional abstractions on purpose so that researchers can quickly iterate on each of the models without diving into additional abstractions/files.
You'll see that we have other implementations of efficient attention mechanisms spread across the codebase, and each of them is linked to a single model. Recent examples of this are YOSO and Nyströmformer, which were released in v4.16 last week.
cc @sgugger @patrickvonplaten for knowledge
--

20 hours ago, patrickvonplaten: (this is a name i've seen before and recognise as a community member, didn't know they worked for huggingface)
@xloem, would you be interested in adding this new attention mechanism simply as a new model - linked to its official paper: https://arxiv.org/abs/2112.05682v2 (maybe called something like `O1Transformer` ?)
--

13 hours ago, xloem (me):

Thanks so much for your replies and advice. I hadn't noticed the readme line. Do you know of any similar libraries to transformers that do indeed collect building blocks? The transformers library is far smaller than the torch and flax libraries that models already depend on, but I imagine you know the experience of research work better than I do to make that call. I've been finding a little more energy to work on this. The owner of the dependency repo found there's still an outstanding issue that masks and biases are built O(n^2), so more work is needed.

--

7 minutes ago, xloem (me):

Apologies for submitting this PR in such a poor state. Unlike YOSO and Nystromformer, this implementation is exact: the output is theoretically identical to that generated by the conventional softmax implementation. Although the research could be implemented as a standalone module, I'm really hoping to actually help users on low-memory systems use the code to fine tune current pretrained models. @patrickvonplaten how does that hope settle with you? After rereading the concerns, my plan is to move the code into the model files to provide for the statements in the readme. @LysandreJik, does this sound reasonable? I also try to stick to just one sourcefile when developing and learning. My usual approach is to open a python interpreter and do `print(inspect.getsource(modulefunc))` to quickly see abstraction implementations, which is indeed quite unideal. Abstraction can _really_ empower a domain: research and development accelerates when abstraction is made a norm.
so, torch tensors are views, but jax tensors are copies.
- my current work was torch only, so it is << O(n^2) if and only if the passed matrices are not full and dense
- the jax code in memory-efficient-attention has a bug: it can't be << O(n^2) if a mask or bias is passed
I already drafted a fix for memory-efficient-attention before questioning whether it was needed, so I'll see if I can test and contribute it.
https://github.com/AminRezaei0x443/memory-efficient-attention/pull/4
commit 1e45f724d55c938f991a483fc4ca9a4ac413b981
now let's review perceiver and/or gpt-j and see if the masks and biases are O(n)able. it's 2022-02-02 10:31 UTC. i wrote something like this in a state of mind: a guess is that existing workers also value grassrootsness in their hearts, and could be organising the research such that people who demonstrate having more time than money are the people who can use it.
- in perceiver, the user-provided attention vector is expanded with 1-length dimensions and passed on. so perceiver has an O(n) attention mask. i didn't note a model-associated bias. my code generates a bias to accommodate feature matching between the two codebases, which will need an improvement if kept. it's 10:46 UTC. gptj next.
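What that looks like in torch terms (an illustration, not perceiver's actual code): the mask keeps its 1-length dimensions and the add broadcasts, so the mask itself never becomes an (n_q, n_k) tensor:

import torch

b, h, n_q, n_k = 1, 8, 256, 96
scores = torch.randn(b, h, n_q, n_k)
attention_mask = torch.randint(0, 2, (b, n_k))     # user-provided, O(n)

# (b, n_k) -> (b, 1, 1, n_k): still O(n) storage
additive = (1.0 - attention_mask.view(b, 1, 1, n_k).float()) * -1e9
masked_scores = scores + additive   # broadcasting; the mask is never expanded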
- gptj uses a pregenerated constant causal mask that is O(n^2). since it is simply a constant function of sequence index it could be made via a callback or inside a loop.
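for example, a hypothetical per-chunk generator (not gpt-j's existing code) that builds only the slice of the causal mask a given query/key chunk needs:

import torch

def causal_mask_chunk(q_start, q_size, k_start, k_size, device=None):
    # build only the (q_size, k_size) slice of the causal mask on demand,
    # instead of materializing the full (n, n) matrix up front
    q_idx = torch.arange(q_start, q_start + q_size, device=device)[:, None]
    k_idx = torch.arange(k_start, k_start + k_size, device=device)[None, :]
    return k_idx <= q_idx    # True where position q may attend to position k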
AminRezaei commented on my PR and made some requests I only just reviewed. I'm guessing what's important is resolving the top bug below; I expect it will then be reasonably easy to meet the requests. The implementation choice I made, of a general callback function, lets users hack access to the raw attentions if the function isn't jitted (see the sketch after the log below). But maybe it is not the way to go, unsure. Hadn't discussed it on the pr before.

github/xloem/memory-efficient-attention, most recent at the top:

commit 57d63f6b78063142d978be547edab3531c5ae24f (HEAD -> callbacks, origin/callbacks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 23:08:27 2022 +0000

    changed the implementation to use a single callback with optional pure data. jits now, but tests are failing, which is the next issue to address.

commit df7accf0a18a5190e657371128d290a1c7562d37
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 21:59:42 2022 +0000

    fixed dimension errors in tests. however, jax.jit appears to refuse to compile callback arguments, so a different approach may be appropriate

commit ee0a939d906fb5a9e1a4470b0e6de313345e999b
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 15:14:47 2022 +0000

    working on adding mask/bias chunk callbacks. presently have dimension error that throws when tests are run.

commit 1e45f724d55c938f991a483fc4ca9a4ac413b981 (origin/sparse-jax-masks, sparse-jax-masks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 10:16:24 2022 +0000

    bugfix: masks and biases can now be passed with O(n) shapes
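To illustrate what I mean by hacking access to the raw attentions (the names and signature here are hypothetical, not the library's actual API):

captured = []

def weights_calc_fn(query_offset, key_offset, weights, calc_data):
    # side effect: stash each chunk's weights as it is computed. this only
    # works when the surrounding attention call is not jit-compiled; under
    # jax.jit the function is traced once and `weights` would be a tracer,
    # not concrete values.
    captured.append((int(query_offset), int(key_offset), weights))
    return weights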
Appts today. Below includes implementation of the mask and bias features as the owner requested. Latest at top. I still have two changes to make:
- modify the if conditions i added to handle all cases, including unworkable ones
- fix a bug when the chunk size is not a factor of the total size. existing code is written to handle this but has an error; lacking test coverage.

commit a3c18ed22088272bbf09bc6f5308e4fd3c1e1add (HEAD -> callbacks, origin/callbacks)
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 11:52:19 2022 +0000

    add mask_calc_fn, bias_calc_fn, added and moved things to support this

commit c20b1a3ea22528acd4620445400718c399fed51c
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 10:35:09 2022 +0000

    rename chunk_callback to weights_calc_fn for similarity to pr comment

commit c4c72b592ab6d2aa0d04456705b5abd1cd83b7a0
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 10:15:11 2022 +0000

    consolidate datasets() into data() as mentioned in PR comment. also reduced count.

commit acc0a8afb5208eec14f090c3008790b21971bd85
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 01:56:52 2022 +0000

    bugfix: callback was not being tested. it now passes.

commit e36dd43bb1efd1f5e650f50c9573e39b7d0027a7
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 00:22:38 2022 +0000

    tests pass when random data is generated with the same distribution
Rebased for PR. https://github.com/AminRezaei0x443/memory-efficient-attention/pull/4 contains excess comments demonstrating craziness.

commit ab6170cedec07a6d7554916c859d36329f1a4125 (HEAD -> sparse-jax-masks, origin/sparse-jax-masks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 10:16:24 2022 +0000

    feature: calc_fn and sparse broadcasting for bias, mask, weights

rebased from 66706d510f78dfff682aa041a5614165de4d5c06

These are the missing commits:

commit 66706d510f78dfff682aa041a5614165de4d5c06
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 14:04:15 2022 +0000

    wrapped newly long function signatures

commit 20f2ccdd5e0122d2bceb063719047952d565c705 (origin/callbacks, callbacks)
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 13:56:09 2022 +0000

    check for unexpected mask and bias shapes
So the generalisation improvements for this got into that repository, but I never moved forward on my pull request to huggingface. I wasn't sure how to make the work reusable in the face of the preference against abstraction or generalisation of model components. Still, copy-pasting is a real and useful way to reuse things, and people need things to copy-paste.

Some time ago I stumbled on _another_ approach to making attention memory efficient without changing its underlying structure, top-k attention: https://github.com/ag1988/top_k_attention . Basically it keeps only the k most impactful scores per query and performs just those multiplies (sketched below). It sounds like this approach could easily be mutated to use a dynamic 'k' that preserves precision. That might be near the ideas of pruning or distillation, too.

I also stumbled on just one library that collects together efficient attention improvements: https://github.com/idiap/fast-transformers . Last updated mid 2021, tests currently not passing, pytorch-only. It already has a top-k attention implementation: https://fast-transformers.github.io/api_docs/fast_transformers/attention/exa... . I wonder if there are other such libraries somewhere. That fast-transformers repo might be a better place for unifying attention improvements. It doesn't load pretrained models, but could be improved to.

Still, it might be most relevant to clean up the perceiver and gpt-j attention improvements and either make a fork or submit a PR that uses them. I kind of wanted to try running gpt-j on long text and see how the actual memory usage went. This might mean shelling into a remote server or using colab to simplify the large memory allocation involved, and waiting some time for the large model to download.
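A minimal sketch of the top-k idea in torch (mine, not the repository's implementation; it still forms the full score matrix before selecting, so it only trims the softmax and value multiplies, not the O(n^2) scores themselves):

import torch

def topk_attention(q, k, v, num_keep=64):
    # q: (b, n_q, d); k: (b, n_k, d); v: (b, n_k, d_v)
    scores = q @ k.transpose(-1, -2)                          # (b, n_q, n_k)
    top_scores, top_idx = scores.topk(min(num_keep, scores.shape[-1]), dim=-1)
    probs = torch.softmax(top_scores, dim=-1)                 # (b, n_q, num_keep)
    # gather the values belonging to the kept keys
    idx = top_idx.unsqueeze(-1).expand(-1, -1, -1, v.shape[-1])
    v_sel = v.unsqueeze(1).expand(-1, q.shape[1], -1, -1).gather(2, idx)
    return torch.einsum('bqk,bqkd->bqd', probs, v_sel)        # (b, n_q, d_v)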