25 Jan
2022
11:26 a.m.
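11: @functools.partial(jax.checkpoint, prevent_cse=False) I think checkpointing (also called rematerialization) is about limiting the memory used by gradient backpropagation while training a model. As I understand it, JAX keeps only this function's inputs rather than every intermediate value from the forward pass, and reruns the function to recompute those intermediates when the backward pass needs them. That trades some extra compute for lower memory use.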
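To check my understanding, here is a minimal sketch of the same decorator on a toy function. The `layer`, `layer_remat`, and `loss` names and the array shapes are made up for illustration and are not from the code being annotated. It only shows that checkpointing should leave the gradient values unchanged; it does not measure the memory saving.

```python
import functools

import jax
import jax.numpy as jnp


def layer(w, x):
    # Hypothetical toy layer (an assumption, not from the original code).
    return jnp.tanh(x @ w)


# Same decorator form as on line 11: on the backward pass the body is
# re-run from its inputs instead of reusing stored intermediates.
@functools.partial(jax.checkpoint, prevent_cse=False)
def layer_remat(w, x):
    return jnp.tanh(x @ w)


def loss(fn, w, x):
    # Simple scalar loss so jax.grad can be applied.
    return jnp.sum(fn(w, x) ** 2)


w = jnp.ones((3, 3)) * 0.1
x = jnp.arange(6.0).reshape(2, 3)

# Gradient with respect to w, with and without checkpointing.
grad_plain = jax.grad(lambda w_: loss(layer, w_, x))(w)
grad_remat = jax.grad(lambda w_: loss(layer_remat, w_, x))(w)

# Checkpointing changes what is stored, not the result.
gradients_match = bool(jnp.allclose(grad_plain, grad_remat))
print(gradients_match)
```

So the decorator looks like a memory/compute trade-off only: the numbers coming out of `jax.grad` should be the same either way.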