so maybe something like this would work:

pip3 install git+https://github.com/xloem/GPTb

import torch
from GPTB import GPTBLMHeadModel
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

config = GPT2Config()  # pass settings here, or pull the config from a pretrained model and tweak it
config.rebias = True  # additional parameter for GPTB
model = GPTBLMHeadModel(config)
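
or, for the pretrained route, something like this should do it (assuming GPTB accepts a stock GPT2Config unchanged):

config = GPT2Config.from_pretrained("gpt2")  # start from the stock gpt2 config
config.rebias = True  # then flip on the GPTB extra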

model.train()
optimizer = torch.optim.AdamW(model.parameters())  # any optimizer works; AdamW is a placeholder choice
optimizer.zero_grad()

past_hidden_states = None
past_logits = None
for batch_of_tokens in data:  # shape of batch_of_tokens is (batchsize, 1)
    if past_logits is not None:
        # past_logits were predicted before seeing this step's tokens,
        # so score them against the tokens that actually arrived
        loss = torch.nn.functional.cross_entropy(
            past_logits.view(-1, config.vocab_size), batch_of_tokens.view(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # detach so the next backward() doesn't reach into the graph that
        # was just freed; assumes past_hidden_states is a plain tensor.
        # this truncates backprop to one step at a time.
        past_hidden_states = past_hidden_states.detach()
    past_logits, past_hidden_states, extra = model(batch_of_tokens, past_hidden_states=past_hidden_states)
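
here `data` is just whatever iterable yields one token column per step; a throwaway sketch with random tokens, purely to show the shapes:

batchsize, seqlen = 4, 128
tokens = torch.randint(0, config.vocab_size, (batchsize, seqlen))  # fake token ids, illustration only
data = [tokens[:, i:i + 1] for i in range(seqlen)]  # each element has shape (batchsize, 1)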