AminRezaei commented on my PR with some requests I only just reviewed. My guess is that what's important is resolving the top bug below; I expect it will then be reasonably easy to meet the requests. The implementation choice I made, a general callback() function, lets users hack access to the raw attention weights so long as the function isn't jitted (see the sketch after the log below). But maybe it is not the way to go; I'm unsure, and it hadn't been discussed on the PR before.

github/xloem/memory-efficient-attention, most recent commit at the top:

commit 57d63f6b78063142d978be547edab3531c5ae24f (HEAD -> callbacks, origin/callbacks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 23:08:27 2022 +0000

    changed the implementation to use a single callback with optional pure data. jits now, but tests are failing, which is the next issue to address.

commit df7accf0a18a5190e657371128d290a1c7562d37
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 21:59:42 2022 +0000

    fixed dimension errors in tests. however, jax.jit appears to refuse to compile callback arguments, so a different approach may be appropriate

commit ee0a939d906fb5a9e1a4470b0e6de313345e999b
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 15:14:47 2022 +0000

    working on adding mask/bias chunk callbacks. presently have dimension error that throws when tests are run.

commit 1e45f724d55c938f991a483fc4ca9a4ac413b981 (origin/sparse-jax-masks, sparse-jax-masks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 10:16:24 2022 +0000

    bugfix: masks and biases can now be passed with O(n) shapes
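For reference, here is a minimal sketch of the kind of callback hook I mean. This is illustrative only, not the library's actual interface; chunked_attention, chunk_size, and the callback signature are all invented for the example. Run eagerly, the callback receives the concrete per-chunk attention weights; under jax.jit it would only ever see abstract tracers.

    import jax
    import jax.numpy as jnp

    # Hypothetical chunked attention with a per-chunk callback hook.
    # Not the library's real API; names and signature are made up here.
    def chunked_attention(query, key, value, chunk_size=128, callback=None):
        outputs = []
        for start in range(0, query.shape[0], chunk_size):
            q = query[start:start + chunk_size]
            # Raw attention weights for this chunk of queries.
            scores = q @ key.T / jnp.sqrt(q.shape[-1])
            weights = jax.nn.softmax(scores, axis=-1)
            if callback is not None:
                # Eagerly, `weights` is a concrete array the user can inspect.
                # Under jax.jit it would be a tracer with no concrete values.
                callback(start, weights)
            outputs.append(weights @ value)
        return jnp.concatenate(outputs, axis=0)

    # Eager use: the callback sees real values.
    q = k = v = jnp.ones((256, 64))
    out = chunked_attention(q, k, v, callback=lambda i, w: print(i, w.shape))

As for the jit refusal noted in commit df7accf: jax.jit rejects a plain Python callable passed as a traced argument, since it isn't a valid JAX type. Marking it static, e.g. jax.jit(chunked_attention, static_argnames=('chunk_size', 'callback')), lets the function compile, but the callback then runs only at trace time on tracers, which is why the raw-attention hack only works un-jitted.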