[ot][spam][crazy]
Undescribed Horrific Abuse, One Victim & Survivor of Many
gmkarl at gmail.com
Tue Oct 3 16:53:16 PDT 2023
I used HuggingFace's LlamaModel, which is just the Llama architecture.
I ignored the model's own embedding map and instead passed in my own
embeddings, which I generated with a trainable linear module from the
input model weights and data.
Similarly, I used a trainable linear layer on the output to generate
only 1 float per pass, and used it in a causal manner. (You can train
on entire sequences, and then infer 1 float at a time.)
I've trimmed the below code for conciseness so it may have an
inconsistency if i made a trimming mistake.
import os
import torch, transformers
class make_one_transformer(torch.nn.Module):
    """Llama-architecture transformer over sequences of raw floats.

    Instead of token ids, each scalar of the input is projected to the
    hidden size by a trainable linear layer (``self.embeddings``) and fed
    to ``transformers.LlamaModel`` through ``inputs_embeds``.  A second
    trainable linear layer (``self.output_head``) maps every hidden state
    to ``output_size`` floats.

    Attributes:
        name: Checkpoint base name; weights are read from ``f'{name}.pt'``.
        input_size: Used as ``max_position_embeddings`` of the Llama config.
        output_size: Number of floats emitted per position.
        config: The ``transformers.LlamaConfig`` the model was built from.
        iteration: Counter restored from the checkpoint's ``'iteration'``
            entry, or 0 when no checkpoint was loaded.
    """

    def __init__(self, name, input_size, output_size=1,
                 complexity=None, load=True):
        """Build the model and optionally restore it from ``f'{name}.pt'``.

        Args:
            name: Checkpoint base name (``.pt`` is appended).
            input_size: Maximum sequence length (position embeddings).
            output_size: Floats produced per position by the output head.
            complexity: Size knob from which the layer count, hidden size
                and intermediate size are derived.  Defaults to
                ``max(output_size, input_size * 16)``.
            load: When true and the checkpoint file exists, load it.
        """
        super().__init__()
        self.name = name
        self.input_size = input_size
        self.output_size = output_size
        if complexity is None:
            complexity = max(output_size, input_size * 16)

        # ratios from the default
        layers = max(complexity // 1024, 1)
        # hidden_size is rounded down to a multiple of 2*layers so it splits
        # evenly across the `layers` attention heads with an even head size;
        # it is floored at one such unit so a tiny explicit `complexity`
        # can never produce a zero-width model.
        head_unit = 2 * layers
        hidden_size = max(complexity // 8 // head_unit * head_unit, head_unit)
        intermediate_size = int(complexity // 2.9)

        self.config = transformers.LlamaConfig(
            num_attention_heads=layers,
            num_hidden_layers=layers,
            num_key_value_heads=layers,
            vocab_size=output_size,
            max_position_embeddings=input_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.model = transformers.LlamaModel(self.config)
        # One scalar in -> one hidden vector out; replaces the token
        # embedding table.
        self.embeddings = torch.nn.Linear(
            in_features=1, out_features=self.config.hidden_size)
        self.output_head = torch.nn.Linear(
            self.config.hidden_size, self.output_size, bias=False)

        checkpoint_path = f'{name}.pt'
        if load and os.path.exists(checkpoint_path):
            # map_location='cpu' lets a checkpoint saved from a GPU load on a
            # CPU-only host; load_state_dict copies the values into the
            # existing parameters regardless of where they were loaded.
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            # 'iteration' is stored next to the weights but is not a
            # parameter, so it must be removed before load_state_dict.
            self.iteration = state_dict.pop('iteration')
            self.load_state_dict(state_dict)
        else:
            self.iteration = 0

    def forward(self, input):
        """Return ``output_size`` floats for every position of ``input``.

        Args:
            input: Float tensor of scalars; a trailing axis of size 1 is
                added before embedding.  Presumably shaped (batch, seq) --
                confirm against callers.

        Returns:
            Tensor shaped like ``input`` with one extra trailing axis of
            size ``output_size``.
        """
        # linear layer maps each scalar of the input to the hidden size
        inputs_embeds = self.embeddings(input[..., None])
        hidden = self.model(inputs_embeds=inputs_embeds).last_hidden_state
        return self.output_head(hidden)

    def generate(self, input, length):
        """Autoregressively extend ``input`` and return the new values.

        Each step runs the whole sequence through the model, takes the
        prediction at the final position and appends it to the sequence.

        Args:
            input: Float tensor of scalars, sequence on the last axis.
            length: Number of generation steps.

        Returns:
            The last ``length`` values of the extended sequence (an empty
            slice when ``length`` is 0).
        """
        # NOTE: no key/value cache is used, so every step recomputes the
        # full sequence.  The original author was not totally sure what
        # past_key_values needs, but it looks like it could be passed
        # straight from the model outputs -- try that and debug.
        for _ in range(length):
            inputs_embeds = self.embeddings(input[..., None])
            hidden = self.model(inputs_embeds=inputs_embeds).last_hidden_state
            # Only the final position predicts the next value; since we do
            # have an output size, the output head plays the lm_head role.
            output = self.output_head(hidden[..., -1, :])
            # torch.cat takes a sequence of tensors, not positional tensors.
            input = torch.cat([input, output], dim=-1)
        # Explicit start index: input[..., -0:] would return everything.
        return input[..., input.shape[-1] - length:]
# this model no lm_head !
# the above joke retained for humor was made before output_head was added
More information about the cypherpunks
mailing list