11: @functools.partial(jax.checkpoint, prevent_cse=False) As I understand it, checkpointing (also called rematerialization) limits the memory used by backpropagation during training. Instead of keeping every intermediate value this function computes in the forward pass, JAX stores only the function's inputs and recomputes the intermediates when the backward pass needs them, trading extra compute for lower memory.
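To check that reading, here is a minimal sketch, not the code being annotated: `layer`, `layer_remat`, `make_loss`, and the toy shapes are all hypothetical names made up for illustration. It wraps a small function with the same decorator and compares its gradient against the unwrapped version. If checkpointing only changes what is stored and not what is computed, the two gradients should agree.

```python
import functools

import jax
import jax.numpy as jnp


def layer(w, x):
    # Hypothetical toy layer. Its intermediates (the matmul result and the
    # tanh output) are what autodiff would normally keep for the backward pass.
    return jnp.sin(jnp.tanh(x @ w))


# Same decorator form as the annotated line: the intermediates are not saved,
# and the forward computation is redone from the inputs during backprop.
@functools.partial(jax.checkpoint, prevent_cse=False)
def layer_remat(w, x):
    return layer(w, x)


def make_loss(fn):
    # Build a scalar loss around either version of the layer.
    def loss(w, x):
        return jnp.sum(fn(w, x) ** 2)
    return loss


w = jnp.full((4, 4), 0.1)
x = jnp.arange(8.0).reshape(2, 4)

# Gradients with respect to w, with and without checkpointing.
grad_plain = jax.grad(make_loss(layer))(w, x)
grad_remat = jax.grad(make_loss(layer_remat))(w, x)

same = bool(jnp.allclose(grad_plain, grad_remat))
print("gradients match:", same)
```

On `prevent_cse=False`: by default `jax.checkpoint` adds safeguards so that XLA's common-subexpression elimination does not merge the recomputation back into the original forward pass, which would undo the memory saving. The JAX documentation notes that these safeguards have a cost and are unnecessary in some settings, such as inside `jax.lax.scan`, where the flag can be set to `False`. Whether that applies to the code being annotated is an assumption on my part.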