Figuring out Step 3b: Decoding

decoder_outputs = self.decoder(
    query=decoder_query,                 # looks like just trainable parameters
    z=sequence_output,                   # these are the encoded hidden states
    query_mask=extended_attention_mask,  # huh, the mask is maybe applied to the query
)

# PerceiverBasicDecoder.forward()
layer_outputs = decoding_cross_attention(
    query,
    attention_mask=query_mask,
    inputs=z,
)
logits = final_layer(layer_outputs[0])
return logits

# PerceiverBasicDecoder.__init__()
decoding_cross_attention = PerceiverLayer(
    is_cross_attention=True,
    kv_dim=d_latents,
    **kwparams  # the rest of the dimensionality configuration is taken from
                # the call constructing the BasicDecoder
)
final_layer = nn.Linear(num_channels, output_num_channels)

So, basically the decoder is a cross-attention layer just like the first layer in the encoder. The "query" is used for the "hidden states" parameter, and the latents z are ferried along to the "inputs" parameter, as if this were an encoder. Just like the encoder, trainable parameters fill the hidden-states side, and the data to attend over is passed along as the "inputs". It would be helpful for me at some point to put time into learning the q/k/v terminology inside the attention layers; that would make these layers less confusing.

Step 3: Encoded hidden state -> PerceiverBasicDecoder cross attention with embedding parameters -> PerceiverBasicDecoder linear redimensioning -> Decoder outputs
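
To make that pipeline concrete, here is a minimal sketch of a decoder with the same shape, assuming made-up dimensions and substituting torch.nn.MultiheadAttention for the library's PerceiverLayer; the class name ToyPerceiverDecoder and its parameters are my own, not the transformers API. A trainable query cross-attends over the encoded latents z, and a linear layer then redimensions the result to the output channels.

import torch
import torch.nn as nn

class ToyPerceiverDecoder(nn.Module):
    def __init__(self, num_queries=2048, num_channels=256, d_latents=512, output_num_channels=3):
        super().__init__()
        # trainable parameters playing the role of decoder_query (the "hidden states" side)
        self.decoder_query = nn.Parameter(torch.randn(num_queries, num_channels))
        # cross-attention: queries come from decoder_query, keys/values come from z
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=num_channels, num_heads=1,
            kdim=d_latents, vdim=d_latents, batch_first=True,
        )
        # the "linear redimensioning" to the desired output channel count
        self.final_layer = nn.Linear(num_channels, output_num_channels)

    def forward(self, z):
        # z: (batch, num_latents, d_latents), the encoded hidden states
        batch_size = z.shape[0]
        query = self.decoder_query.unsqueeze(0).expand(batch_size, -1, -1)
        attended, _ = self.cross_attention(query, z, z)  # q = query, k = v = z
        return self.final_layer(attended)               # (batch, num_queries, output_num_channels)

decoder = ToyPerceiverDecoder()
logits = decoder(torch.randn(1, 280, 512))  # 280 latents of width 512
print(logits.shape)                         # torch.Size([1, 2048, 3])

The point is just the data flow: the only places the decoder learns anything about the output shape are the trainable query and the final linear layer; everything else it reads out of z via cross-attention.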