2 Feb
2022
9:58 a.m.
So, torch tensors are views, but jax tensors are copies.
- My current work was torch-only, so it is << O(n^2) if and only if the passed matrices are not full and dense.
- The jax code in memory-efficient-attention has a bug: it can't be << O(n^2) if a mask or bias is passed.
I already drafted a fix for memory-efficient-attention before questioning whether it was needed, so I'll see if I can test and contribute it.
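To put rough numbers on the mask/bias point, here is a back-of-the-envelope sketch. The sequence length, chunk sizes, float32 item size, and the helper name `dense_bytes` are all illustrative assumptions, not values taken from memory-efficient-attention. It only counts bytes and does not run any attention code.

```python
def dense_bytes(rows: int, cols: int, itemsize: int = 4) -> int:
    """Bytes needed to hold a fully materialized rows x cols array."""
    return rows * cols * itemsize


# Illustrative values only (assumptions, not library defaults).
n = 65_536                       # sequence length
q_chunk, k_chunk = 1_024, 4_096  # hypothetical query/key chunk sizes

# A dense n x n mask or bias handed in by the caller.
full_bias = dense_bytes(n, n)

# One block of attention scores in a chunked computation.
chunk_scores = dense_bytes(q_chunk, k_chunk)

print(f"dense bias:  {full_bias / 2**30:.0f} GiB")
print(f"score chunk: {chunk_scores / 2**20:.0f} MiB")
```

Under these assumptions the materialized bias alone is about a thousand times larger than a single score block, so chunking the scores cannot bring total memory below O(n^2) once a dense mask or bias exists.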