30 Jan 2022
6:50 p.m.
in memory-efficient-attention/attention_torch.py, around line 15, attn_weights has dimension qhk, so it's holding the chunked attention weights, the part that's chunked so the O(n^2) matrix only ever exists as O(n) pieces. it's biased, masked, and maxed, then turned into exp_weights, which has the same size but different values. exp_weights is multiplied with values to make exp_values, dotting the value dimension from values with the key dimension from exp_weights, giving an output of dimension qhf, O(n). in the original implementation, exp_weights and attn_weights are then discarded, leaving the memory at O(n). my addition of saving all the exp_weights or attn_weights does indeed prevent the improvement. whoo! i need to figure out how to retain and continue to act on my skills and knowledge, pretty badly.
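
a rough sketch of that chunk step as i understand it, not the repo's actual code (shapes, names, and chunk handling are my guesses):

import torch

def summarize_chunk(query, key, value):
    # query: (q, h, d), key: (k, h, d), value: (k, h, f) -- one key/value chunk at a time
    attn_weights = torch.einsum('qhd,khd->qhk', query, key)       # qhk, only a chunk-sized slice of the full matrix
    # bias and mask would be added to attn_weights here
    max_score = attn_weights.amax(dim=-1, keepdim=True).detach()   # the "maxed" part, for numerical stability
    exp_weights = torch.exp(attn_weights - max_score)               # same qhk shape, different values
    exp_values = torch.einsum('qhk,khf->qhf', exp_weights, value)   # contract over k -> qhf, O(n)
    # attn_weights and exp_weights go out of scope here; saving them brings back O(n^2)
    return exp_values, exp_weights.sum(dim=-1), max_score.squeeze(dim=-1)

the per-chunk exp_values, weight sums, and max scores then get combined by rescaling against the global max and dividing by the summed weights, so only qhf-sized things survive.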