# NOTE(review): working notes on converting a DeepMind Haiku Perceiver
# checkpoint into a HuggingFace-style PyTorch state_dict. The Haiku weights
# live in a nested structure and are named only by module type ('linear',
# 'layer_norm', ...), so matching them to HF parameter names requires
# comparing module construction/usage order in Google's source against the
# HuggingFace implementation — the names alone are not enough.


def haiku2torch(haiku_params):
    """Map Haiku Perceiver parameters onto HuggingFace ``state_dict`` keys.

    Only the unambiguous top-level embedding tables are converted so far;
    the cross-attention block mappings are still being worked out (see the
    TODO list below).

    Args:
        haiku_params: mapping of Haiku parameter paths to arrays. It is not
            mutated — a shallow copy is taken before keys are popped.

    Returns:
        dict mapping resolved HuggingFace ``state_dict`` keys to the
        corresponding Haiku arrays.

    Raises:
        KeyError: if ``'embed'`` or ``'trainable_position_encoding'`` is
            missing from ``haiku_params``.
    """
    # Shallow copy so the pops below don't mutate the caller's dict.
    haiku_params = {**haiku_params}
    state_dict = {}

    # Resolved mappings: top-level token and position embedding tables.
    state_dict['perceiver.input_preprocessor.embeddings.weight'] = haiku_params.pop('embed')
    state_dict['perceiver.input_preprocessor.position_embeddings.weight'] = haiku_params.pop('trainable_position_encoding')

    # TODO(review): unresolved mappings. The Haiku side exposes e.g.
    #   haiku_params['perceiver_encoder/~/cross_attention/attention/linear']['w']
    # but which of query/key/value/output that 'linear' corresponds to must be
    # confirmed against the module construction order in the DeepMind source
    # before wiring it up. Target keys still to be filled in:
    #   perceiver.encoder.cross_attention.attention.self.layernorm1.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.layernorm2.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.query.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.key.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.value.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.output.dense.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.layernorm.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.mlp.dense1.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.mlp.dense2.{weight,bias}
    #   perceiver.embeddings.latents   # source Haiku key still unknown
    return state_dict