i've forked memory-efficient-attention in an attempt to add a return_weights parameter. i think the torch implementation of this would be simplified by using a for loop rather than a scan function parameterised by a callback; a rough sketch of what i mean is at the end of this note.

https://github.com/xloem/memory-efficient-attention/commits/return_weights

    Author: xloem <0xloem@gmail.com>
    Date:   Thu Jan 27 14:50:32 2022 +0000

        wip: needs a change so return_weights output is migrated through scan()

the reason for this is that transformers has a return_weights configuration, where the pre-softmax weights of attention passes are returned to the user from the library. supporting that means getting inside attention somehow.

i'm feeling pressure to limit the scope of this expanding work. ideas for reducing the steps for this part include:
- simply disabling return_weights in transformers when efficient attention is engaged
- writing a transformers-specific implementation of efficient attention

but i'll probably open an issue in the repository and plan to move forward on a pull request.
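to make the for-loop idea concrete, here's a minimal sketch of chunked attention written as a plain loop over key/value chunks, with an optional flag for handing back the pre-softmax weights. this is not the library's actual torch code: the names (chunked_attention, chunk_size, return_weights) and the single-head layout are assumptions of mine, and the real implementation handles more than this (multi-head shapes, query chunking, masks, and so on). the point is only that per-chunk state, including the extra weights output, is easier to accumulate in local lists than to thread through a scan() carry.

```python
import math
import torch

def chunked_attention(query, key, value, chunk_size=128, return_weights=False):
    """Sketch only: single-head attention, chunked over the key dimension.

    query: [batch, q_len, dim], key/value: [batch, k_len, dim]
    """
    scale = 1.0 / math.sqrt(query.shape[-1])
    q = query * scale

    exp_values = []   # per-chunk exp(s - m) @ v
    exp_sums = []     # per-chunk sum of exp(s - m)
    chunk_maxes = []  # per-chunk max logit, needed to combine chunks stably
    logits = []       # per-chunk pre-softmax weights, kept only if requested

    # plain for loop over key/value chunks; the scan()-with-callback version
    # would thread the same per-chunk state through a carried tuple instead
    for start in range(0, key.shape[1], chunk_size):
        k = key[:, start:start + chunk_size]
        v = value[:, start:start + chunk_size]

        s = torch.einsum('bqd,bkd->bqk', q, k)      # [batch, q_len, chunk]
        if return_weights:
            logits.append(s)

        m = s.amax(dim=-1, keepdim=True)            # [batch, q_len, 1]
        p = torch.exp(s - m)                        # stable softmax numerator
        exp_values.append(torch.einsum('bqk,bkd->bqd', p, v))
        exp_sums.append(p.sum(dim=-1, keepdim=True))
        chunk_maxes.append(m)

    # renormalise every chunk against the global max before summing
    global_max = torch.amax(torch.cat(chunk_maxes, dim=-1), dim=-1, keepdim=True)
    numerator = torch.zeros_like(exp_values[0])
    denominator = torch.zeros_like(exp_sums[0])
    for ev, es, m in zip(exp_values, exp_sums, chunk_maxes):
        alpha = torch.exp(m - global_max)
        numerator = numerator + ev * alpha
        denominator = denominator + es * alpha

    out = numerator / denominator
    if return_weights:
        # pre-softmax weights over the full key length, roughly what a
        # transformers-style return_weights option would hand to the caller
        return out, torch.cat(logits, dim=-1)
    return out
```

in the existing scan()-based code the equivalent change is what the commit message calls migrating the return_weights output through scan(): the extra per-chunk weights have to be added to whatever the callback returns and carried out of the scan, which is the part the for-loop version sidesteps.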