Updated https://github.com/AminRezaei0x443/memory-efficient-attention/issues/1 for the new commit.

I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support `output_attentions` yet. This is a flag passed in transformers; when it is set, the attention weights are preserved and returned to the user (example below).

I looked a little at implementing this in the torch backend, and I note that the `scan()` function provides for only a single tensor return value. It seems to me that `scan()` would be most clearly replaced by a for loop, but it could also be modified to handle tuples, or `return_weights` could be handled by accessing nonlocal data in some way instead of returning the weights through the chunk scanner (sketch below). I'm also not sure how the output would best be passed back to the user.

Edit: Draft implementation 01/28 at https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main..... . I ended up extending the scan function for parity between the implementations.

Edit 2: Turns out it's the postsoftmax attention weights, not the presoftmax attention weights. I've updated this post and the draft implementation for this output: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main.....
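
For concreteness, this is what the flag does on the transformers side (model name is just for illustration):

```python
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tok("hello world", return_tensors="pt")
out = model(**inputs, output_attentions=True)

# out.attentions is a tuple with one tensor per layer, each of shape
# (batch, num_heads, seq_len, seq_len) -- the postsoftmax weights.
print(len(out.attentions), out.attentions[0].shape)
```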
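
And here's a minimal sketch of what extending `scan()` to tuples could look like. This assumes the jax.lax.scan-style `(f, init, xs)` contract that the torch backend mimics; it's not the library's actual code:

```python
import torch

def scan(f, init, xs):
    # f: (carry, x) -> (carry, y), where y may be a single tensor or a
    # tuple of tensors (e.g. (out_chunk, weights_chunk) when
    # return_weights is set).
    carry = init
    ys = []
    for i in range(xs.shape[0]):
        carry, y = f(carry, xs[i])
        ys.append(y)
    if isinstance(ys[0], tuple):
        # Stack each tuple slot separately along a new leading dimension.
        return carry, tuple(torch.stack(col) for col in zip(*ys))
    return carry, torch.stack(ys)
```

With something like this, the chunk scanner can return `(values_chunk, weights_chunk)` and the caller unpacks both stacked results, keeping the single-tensor path unchanged.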