26 Jan 2022
6:29 p.m.
While running these small tests I discovered that the paper's JAX implementation is very fast: noticeably faster, on my Raspberry Pi CPU, than my example that didn't use jax.lax. Hugging Face's Perceiver model implementation is not in JAX; rather, it's in PyTorch. There are PyTorch implementations of this memory-efficient attention paper at https://github.com/moskomule/memory_efficient_attention.pytorch and https://github.com/AminRezaei0x443/memory-efficient-attention . I'd like to try them both out on the same example data [though I might only try one].
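To keep the core idea handy for comparing these implementations, here is a minimal NumPy sketch of the chunk-wise attention trick the paper describes. This is my own illustration, not the paper's JAX code or either repo's API: it processes keys/values in chunks with running max/sum accumulators (an online softmax) so the full n-by-n score matrix is never materialized, and it should match naive attention numerically.

```python
import numpy as np

def naive_attention(q, k, v):
    # Standard softmax attention: materializes the full (n, n) score matrix.
    s = q @ k.T
    s = s - s.max(axis=-1, keepdims=True)
    w = np.exp(s)
    return (w @ v) / w.sum(axis=-1, keepdims=True)

def chunked_attention(q, k, v, chunk=16):
    # Memory-efficient variant: iterate over key/value chunks, carrying
    # a running max (m), running softmax denominator (l), and running
    # weighted value sum (o), rescaling them when the max grows.
    n = q.shape[0]
    m = np.full((n, 1), -np.inf)       # running max of scores per query
    l = np.zeros((n, 1))               # running sum of exp(score - m)
    o = np.zeros((n, v.shape[1]))      # running weighted sum of values
    for start in range(0, k.shape[0], chunk):
        kc = k[start:start + chunk]
        vc = v[start:start + chunk]
        s = q @ kc.T                   # only (n, chunk) scores in memory
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)      # rescale old accumulators
        w = np.exp(s - m_new)
        l = l * scale + w.sum(axis=-1, keepdims=True)
        o = o * scale + w @ vc
        m = m_new
    return o / l
```

Whichever PyTorch repo I end up testing, checking its output against a reference like `naive_attention` on the same example data seems like the right sanity check before timing anything.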