[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 03:26:38 PST 2022


11:  @functools.partial(jax.checkpoint, prevent_cse=False)

I think checkpointing relates to limiting the memory used by gradient
backpropagation during training of a model.  I think it means that,
instead of storing every intermediate activation for the backward pass,
only the function's arguments are stored, and the activations are
recomputed from them when the gradients are needed -- trading extra
compute for lower memory use.
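A minimal sketch of what that quoted decorator does, using a toy layer of my own invention (the shapes and functions here are illustrative, not from the original model):

```python
import functools

import jax
import jax.numpy as jnp


# Wrapping the function with jax.checkpoint tells autodiff to save only
# the function's inputs, and to recompute its intermediate activations
# during the backward pass instead of storing them.
@functools.partial(jax.checkpoint, prevent_cse=False)
def layer(w, x):
    return jnp.tanh(w @ x)


def loss(w, x):
    return jnp.sum(layer(w, x) ** 2)


w = jnp.ones((4, 4))
x = jnp.ones((4,))

# Gradients come out the same as without checkpointing; only the
# memory/compute trade-off during backprop changes.
grad = jax.grad(loss)(w, x)
print(grad.shape)  # (4, 4)
```

`prevent_cse=False` relaxes a safeguard against common-subexpression elimination, which the JAX docs suggest is safe (and faster) under `jax.jit`/`jax.pmap` style transforms.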


More information about the cypherpunks mailing list