Rebased for PR. https://github.com/AminRezaei0x443/memory-efficient-attention/pull/4 contains excess comments demonstrating craziness.

commit ab6170cedec07a6d7554916c859d36329f1a4125 (HEAD -> sparse-jax-masks, origin/sparse-jax-masks)
Author: xloem <0xloem@gmail.com>
Date:   Wed Feb 2 10:16:24 2022 +0000

    feature: calc_fn and sparse broadcasting for bias, mask, weights

    rebased from 66706d510f78dfff682aa041a5614165de4d5c06

These are the missing commits:

commit 66706d510f78dfff682aa041a5614165de4d5c06
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 14:04:15 2022 +0000

    wrapped newly long function signatures

commit 20f2ccdd5e0122d2bceb063719047952d565c705 (origin/callbacks, callbacks)
Author: xloem <0xloem@gmail.com>
Date:   Thu Feb 3 13:56:09 2022 +0000

    check for unexpected mask and bias shapes