import math
import operator
# Lookup tables for number_to_word.  thousands_names[k] is the scale word for
# the k-th group of three digits (short scale, i.e. 10**(3*k)); entry 0 is
# empty because the units group has no scale word.
thousands_names = (
    ' thousand million billion trillion quadrillion quintillion'
    ' sextillion septillion octillion nonillion decillion'
).split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')

# Convert integers to English words; the model below learns to map those
# words back to digits.


def _group_to_words(value):
    """Return the words for a three-digit group 1 <= value <= 999 (no scale word).

    E.g. 115 -> ['one', 'hundred', 'fifteen'], 40 -> ['forty'].
    """
    hundreds, rest = divmod(value, 100)
    tens, ones = divmod(rest, 10)
    words = []
    if hundreds:
        words += [numeral_names[hundreds], 'hundred']
    if tens == 1:
        words.append(teens_names[ones])  # 10..19 have irregular names
    else:
        if tens:
            words.append(tens_names[tens])
        if ones:
            words.append(numeral_names[ones])
    return words


def number_to_word(num):
    """Spell out an integer in English words (short scale, no 'and', no hyphens).

    Examples: 0 -> 'zero', 115 -> 'one hundred fifteen',
    -2019 -> 'negative two thousand nineteen'.

    Args:
        num: an integer, or any integer-like object supporting ``__index__``
            (e.g. a 0-d integer torch tensor from ``torch.randperm``).

    Returns:
        The words separated by single spaces.

    Raises:
        TypeError: if ``num`` is not integer-like (e.g. a float).
        ValueError: if ``abs(num)`` needs a scale word beyond the last entry
            of ``thousands_names``.
    """
    num = operator.index(num)
    if num == 0:
        return 'zero'
    prefix = ''
    if num < 0:
        prefix = 'negative '
        num = -num
    if num >= 1000 ** len(thousands_names):
        raise ValueError(
            f'number_to_word: magnitude {num} exceeds the largest supported '
            f'scale ({thousands_names[-1]!r})'
        )
    # Exact integer arithmetic (instead of math.log10) so large values are
    # never mis-counted by float rounding.
    groups = []  # spelled-out non-zero groups, least significant first
    scale = 0
    while num:
        num, value = divmod(num, 1000)
        if value:
            words = _group_to_words(value)
            if scale:
                words.append(thousands_names[scale])
            groups.append(' '.join(words))
        scale += 1
    return prefix + ' '.join(reversed(groups))


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    """Perceiver text-to-text model: byte embeddings in, per-position vocabulary logits out.

    A ``PerceiverTextPreprocessor`` embeds the input ids; a ``PerceiverBasicDecoder``
    with ``config.max_position_embeddings`` trainable query positions reads the
    latents; a ``PerceiverEmbeddingDecoder`` projects each query onto the input
    embedding matrix to produce vocabulary logits (the layout closely follows
    Hugging Face's ``PerceiverForMaskedLM``).

    Besides the usual ``PerceiverConfig`` fields, ``config`` must provide a
    ``num_decoder_heads`` attribute (head count of the decoder cross-attention).
    """

    def __init__(self, config):
        super().__init__(config)
        self.input_preprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels = config.d_latents,
            # Fixed number of output queries: every forward pass emits this many positions.
            output_index_dims = config.max_position_embeddings,
            num_channels = config.d_model,
            qk_channels = config.qk_channels,
            v_channels = config.d_model,
            num_heads = config.num_decoder_heads,
            use_query_residual = False,
            final_project = False,
            trainable_position_encoding_kwargs = dict(
                num_channels = self.input_preprocessor.num_channels,
                index_dims = config.max_position_embeddings
            ),
        )
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder = self.decoder,
            input_preprocessor = self.input_preprocessor,
        )
        self.output_postprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverEmbeddingDecoder(config)

        self.post_init()

    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):
        """Run the model and, if ``labels`` is given, the token-level cross-entropy loss.

        Args:
            inputs: token ids for the text preprocessor; assumed an integer
                (batch, seq) tensor.
            attention_mask: assumed (batch, seq), 1 for real tokens and 0 for padding.
            head_mask, output_attentions, output_hidden_states: forwarded
                unchanged to ``PerceiverModel``.
            labels: optional (batch, label_len) target ids; entries equal to
                -100 are ignored.  Labels shorter than the decoder output are
                padded with -100.

        Returns:
            ``(logits,) + outputs[1:]`` from the underlying ``PerceiverModel``
            tuple, prefixed with ``loss`` when ``labels`` is given.  ``logits``
            is (batch, num_decoder_queries, vocab_size).
        """
        outputs = self.perceiver(
                inputs=inputs,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=False,  # tuple output; outputs[0] is the decoder output
        )

        # Project decoder outputs onto the (shared) input embedding matrix.
        logits = self.output_postprocessor(
                outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        loss = None
        if labels is not None:
            # CrossEntropyLoss requires integer class indices.
            labels = labels.long()
            # The decoder always emits logits.shape[1] positions, usually more
            # than the label length; pad with the ignore index so the
            # flattened shapes agree and padding contributes no loss.
            missing = logits.shape[1] - labels.shape[1]
            if missing > 0:
                labels = torch.nn.functional.pad(labels, (0, missing), value=-100)
            loss = torch.nn.CrossEntropyLoss()(
                logits.reshape(-1, self.config.vocab_size), labels.reshape(-1)
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

# Hyperparameters: stock PerceiverConfig defaults, plus the `num_decoder_heads`
# field that Model.__init__ reads for its decoder, set to the encoder's
# cross-attention head count.
word2num_config = transformers.PerceiverConfig()
word2num_config.num_decoder_heads = word2num_config.num_cross_attention_heads

print('Constructing model ...')
model = Model(config=word2num_config)

import torch

def _encode_padded(strings, batch_size):
    """Encode strings as a batched byte tensor paired with a padding mask.

    Each string is encoded as ISO-8859-1 (one byte per character) and
    right-padded with spaces to the longest string.  Returns a ``torch.long``
    tensor of shape ``(len(strings) // batch_size, batch_size, 2, max_len)``:
    index 0 on axis 2 holds the byte codes, index 1 a 1/0 mask marking real
    characters versus padding.
    """
    max_len = max(len(s) for s in strings)
    rows = [
        [
            list(s.ljust(max_len).encode('iso-8859-1')),
            [1] * len(s) + [0] * (max_len - len(s)),
        ]
        for s in strings
    ]
    return torch.tensor(rows, dtype=torch.long).view(-1, batch_size, 2, max_len)


def numbers_to_numword_tensors(numbers, batch_size):
    """Encode integers as batches of their decimal digits and their English words.

    Args:
        numbers: a non-empty sequence of integers (Python ints or 0-d integer
            tensors, e.g. the elements of ``torch.randperm``).
        batch_size: examples per batch; must be positive and divide ``len(numbers)``.

    Returns:
        ``(number_tensor, word_tensor)``, each a ``torch.long`` tensor of shape
        ``(len(numbers) // batch_size, batch_size, 2, max_len)`` where index 0
        on axis 2 is byte codes and index 1 the padding mask.  ``max_len`` is
        computed separately for digits and for words.

    Raises:
        TypeError: if an element is not integer-like.
        ValueError: if ``numbers`` is empty or ``batch_size`` does not divide it.
    """
    # Plain ints: str() of a tensor element would give 'tensor(5)', not '5'.
    numbers = [operator.index(number) for number in numbers]
    if not numbers:
        raise ValueError('numbers_to_numword_tensors: numbers must not be empty')
    if batch_size <= 0 or len(numbers) % batch_size:
        raise ValueError(
            f'numbers_to_numword_tensors: batch_size={batch_size} must be positive '
            f'and divide len(numbers)={len(numbers)}'
        )
    digits = [str(number) for number in numbers]
    words = [number_to_word(number) for number in numbers]
    return _encode_padded(digits, batch_size), _encode_padded(words, batch_size)

# Dataset: the integers 0 .. total-1 in random order, grouped into equal batches.
total = 2000
num_batches = 16
batch_size = total // num_batches
total = total - (total % batch_size)  # drop any remainder so every batch is full
# Number of whole batches held out for testing.  The returned tensors are
# indexed by batch along axis 0, so this must count batches, not examples.
tt_split = 1
print('Generating data ...')
# .tolist() hands plain Python ints to the encoder.
all_numbers, all_words = numbers_to_numword_tensors(torch.randperm(total).tolist(), batch_size)
train_numbers = all_numbers[:-tt_split]
test_numbers = all_numbers[-tt_split:]
train_words = all_words[:-tt_split]
test_words = all_words[-tt_split:]

# One end of the model takes (or emits) a number's digits; the other end emits
# (or takes) the number written out in words.  This loop trains the
# words -> digits direction: word bytes are the inputs, digit bytes the labels.

print('Starting training ...')
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.0001)
for idx, (number_batch, word_batch) in enumerate(zip(train_numbers, train_words)):
    optim.zero_grad()

    # Each batch is expected to be indexed [example, row, position]: row 0
    # holds byte codes, row 1 the 1/0 padding mask.
    number_data = number_batch[:, 0]
    number_mask = number_batch[:, 1]
    word_data = word_batch[:, 0]
    word_mask = word_batch[:, 1]

    # Embedding lookups and CrossEntropyLoss need integer ids.  clone() so the
    # masking below never writes back into train_numbers (.long() returns the
    # same tensor when it is already long).
    inputs = word_data.long()
    attention_mask = word_mask
    labels = number_data.long().clone()
    labels[number_mask == 0] = -100  # CrossEntropyLoss's default ignore_index
    # The decoder emits config.max_position_embeddings positions regardless of
    # the label length, so pad the labels to match with the ignore index.
    labels = torch.nn.functional.pad(
        labels, (0, model.config.max_position_embeddings - labels.shape[1]), value=-100
    )

    outputs = model(inputs=inputs, attention_mask=attention_mask, labels=labels)
    # The model returns (loss, logits, *remaining Perceiver outputs).
    loss, logits = outputs[0], outputs[1]
    # Show the first example's most likely byte at each digit position.
    predicted = logits.detach()[0, :number_data.shape[1]].argmax(dim=-1).clamp(0, 255)
    print(f'{idx} {loss.item():.4f} {bytes(predicted.tolist())!r} ', end='\r', flush=True)
    loss.backward()
    optim.step()
print()  # move past the last '\r'-terminated progress line
