I used HuggingFace's LlamaModel, which is just the Llama architecture without a language-modeling head. I ignored the model's embedding map and passed my own embeddings, which I generated with a trainable linear module from the input model weights and data. Similarly, I used a trainable linear layer on the output to produce one float per pass and used it in a causal manner (you can train on entire sequences, and then infer one float at a time). I've trimmed the code below for conciseness, so it may have an inconsistency if I made a trimming mistake.

```python
import os
import torch, transformers

class make_one_transformer(torch.nn.Module):
    def __init__(self, name, input_size, output_size=1, complexity=None, load=True):
        super().__init__()
        self.name = name
        self.input_size = input_size
        self.output_size = output_size
        if complexity is None:
            complexity = max(output_size, input_size * 16)
        # ratios from the default config
        layers = max(complexity // 1024, 1)
        hidden_size = complexity // 8 // (2 * layers) * 2 * layers
        intermediate_size = int(complexity // 2.9)
        self.config = transformers.LlamaConfig(
            num_attention_heads=layers,
            num_hidden_layers=layers,
            num_key_value_heads=layers,
            vocab_size=output_size,
            max_position_embeddings=input_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.model = transformers.LlamaModel(self.config)
        # trainable embedding: maps each input float to a hidden-size vector
        self.embeddings = torch.nn.Linear(in_features=1, out_features=self.config.hidden_size)
        # trainable output head: maps each hidden state back to output_size floats
        self.output_head = torch.nn.Linear(self.config.hidden_size, self.output_size, bias=False)
        if load and os.path.exists(f'{name}.pt'):
            state_dict = torch.load(f'{name}.pt')
            self.iteration = state_dict.pop('iteration')
            self.load_state_dict(state_dict)
        else:
            self.iteration = 0

    def forward(self, input):
        # linear layer maps the raw floats up to hidden size
        inputs_embeds = self.embeddings(input[..., None])
        output = self.model(inputs_embeds=inputs_embeds).last_hidden_state
        return self.output_head(output)

    def generate(self, input, length):
        # not totally sure what past_key_values needs; it looks like you could
        # pass it straight through from the outputs and debug
        for _ in range(length):
            inputs_embeds = self.embeddings(input[..., None])
            hidden = self.model(inputs_embeds=inputs_embeds).last_hidden_state
            output = self.output_head(hidden)  # since we do have an output size, we'll want an lm_head
            # append the prediction at the last position (assumes output_size == 1)
            input = torch.cat((input, output[..., -1, :]), dim=-1)
        return input[..., -length:]
        # this model no lm_head !
        # the above joke retained for humor was made before output_head was added
```
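For reference, here's a minimal usage sketch of how you could drive it. The batch and sequence sizes, the next-float MSE objective, and the prompt length are all assumptions for illustration, not what I actually trained with:

```python
# minimal usage sketch -- shapes, loss, and the next-float objective are assumptions
model = make_one_transformer('demo', input_size=64)

# training: predict the next float at every position of the sequence (causal, teacher-forced)
seqs = torch.randn(8, 64)                      # (batch, seq) of floats
preds = model(seqs[:, :-1])[..., 0]            # (batch, seq-1) predictions, one float per position
loss = torch.nn.functional.mse_loss(preds, seqs[:, 1:])
loss.backward()

# inference: feed a prompt of floats and roll out 16 more, one float per forward pass
with torch.no_grad():
    continuation = model.generate(torch.randn(8, 32), length=16)   # (batch, 16)
```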