So, the JAX/Flax Hugging Face T5 output doesn't include loss the way the Hugging Face T5 documentation implies; the PyTorch output does. Here's the loss computation from the Hugging Face PyTorch T5 code. For me this is line 1643 of my old checkout of github.com/huggingface/transformers, src/transformers/models/modeling_t5.py:

    if labels is not None:
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb...

CrossEntropyLoss is a very common function in transformer models: it takes a vector of logits (unnormalized log-probabilities) for each option, plus the index of the correct option, and returns a measure of how far the model is from picking the correct one. If you look it up, what it does is not too complex and relatively intuitive: exponentiate the logits, divide by their sum to get probabilities (softmax), then return the negative log of the probability it assigned to the correct option.
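As a sketch of what that PyTorch call computes (a minimal NumPy reimplementation of the semantics, not the library code), here's cross-entropy with the same ignore_index=-100 masking:

```python
import numpy as np

def cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label != ignore_index.

    logits: (N, vocab) unnormalized scores; labels: (N,) correct indices.
    Sketches the semantics of torch.nn.CrossEntropyLoss(ignore_index=-100).
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    mask = labels != ignore_index
    # log-softmax, computed stably by subtracting each row's max first
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # negative log-probability of the correct option, averaged over kept rows
    picked = log_probs[np.arange(len(labels)), np.where(mask, labels, 0)]
    return -(picked * mask).sum() / mask.sum()

# two positions over a 3-word vocab; the second position is masked out (-100)
loss = cross_entropy([[2.0, 0.5, 0.1], [9.9, 9.9, 9.9]], [0, -100])
print(loss)  # ≈ 0.3169, the -log softmax probability of option 0 in row 0
```

The -100 positions are exactly the padding labels T5 uses, so they contribute nothing to the mean.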