# Optional PyTorch patch module (see the message below: without it the program
# may crash on a Raspberry Pi). Only a *missing* module is tolerated; a bare
# `except:` would also hide Ctrl-C, SystemExit and genuine errors raised while
# the patch module itself runs.
try:
    import patch_pytorch
except ImportError:
    print('failed to find patch_pytorch, may crash on rasbpi. github/xloem/mempickle')
import math
# Word tables used to spell numbers. thousands_names[i] is the scale word for
# the i-th group of three digits (index 0 is the empty units scale).
thousands_names = ' thousand million billion'.split(' ')
numeral_names = 'zero one two three four five six seven eight nine'.split(' ')
tens_names = 'zero ten twenty thirty forty fifty sixty seventy eighty ninety'.split(' ')
teens_names = 'ten eleven twelve thirteen fourteen fifteen sixteen seventeen eighteen nineteen'.split(' ')


def _group_to_words(value):
    """Spell a three-digit group 1 <= value <= 999 as a list of words (no scale word)."""
    hundred, rest = divmod(value, 100)
    ten, one = divmod(rest, 10)
    words = []
    if hundred:
        words += [numeral_names[hundred], 'hundred']
    if ten == 1:
        # 10-19 are single irregular words.
        words.append(teens_names[one])
    else:
        if ten:
            words.append(tens_names[ten])
        if one:
            words.append(numeral_names[one])
    return words


# can we convert between words and numbers
def number_to_word(num):
    """Spell an integer in English words, e.g. 1001 -> 'one thousand one'.

    Args:
        num: anything accepted by ``int()``; floats are truncated toward zero.

    Returns:
        The lower-case spelling, words separated by single spaces. Negative
        values are prefixed with 'negative '. Hyphens and 'and' are not used
        ('forty two', 'one hundred five').

    Raises:
        ValueError: if ``abs(num)`` is too large for the scale words in
            ``thousands_names`` (i.e. >= 1000 ** len(thousands_names)).
    """
    num = int(num)
    if num == 0:
        return 'zero'
    prefix = 'negative ' if num < 0 else ''
    num = abs(num)
    limit = 1000 ** len(thousands_names)
    if num >= limit:
        raise ValueError(f'number_to_word: magnitude must be below {limit}, got {num}')

    # Peel off three-digit groups from the least significant end using integer
    # arithmetic only (no float log10 digit counting).
    groups = []
    scale = 0
    while num:
        num, value = divmod(num, 1000)
        if value:
            words = _group_to_words(value)
            if scale:
                words.append(thousands_names[scale])
            groups.append(' '.join(words))
        scale += 1
    return prefix + ' '.join(reversed(groups))


import transformers, torch

class Model(transformers.PerceiverPreTrainedModel):
    """Perceiver text-to-text model: byte/token ids in, per-position vocabulary logits out.

    Pipeline (all built in __init__):
      * PerceiverTextPreprocessor embeds the input ids.
      * PerceiverModel runs the Perceiver encoder over them.
      * PerceiverBasicDecoder queries the latents with trainable position
        encodings at config.max_position_embeddings output positions.
      * PerceiverEmbeddingDecoder turns the decoder output into logits using
        the preprocessor's embedding layer (see forward), so input and output
        vocabularies share one embedding matrix.
    """
    def __init__(self, config):
        """Build the sub-modules from *config*.

        Reads config.d_latents, config.max_position_embeddings, config.d_model,
        config.qk_channels and config.num_decoder_heads. NOTE(review):
        num_decoder_heads is set explicitly by the script in this file, which
        suggests PerceiverConfig does not provide it by default — confirm
        before constructing this class from a plain PerceiverConfig.
        """
        super().__init__(config)
        self.input_preprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor(config)
        self.decoder = transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder(
            config,
            output_num_channels = config.d_latents,
            # One decoder query per output position.
            output_index_dims = config.max_position_embeddings,
            num_channels = config.d_model,
            qk_channels = config.qk_channels,
            v_channels = config.d_model,
            num_heads = config.num_decoder_heads,
            use_query_residual = False,
            final_project = False,
            trainable_position_encoding_kwargs = dict(
                # Query encodings are sized to the preprocessor's output channel count.
                num_channels = self.input_preprocessor.num_channels,
                index_dims = config.max_position_embeddings
            ),
        )
        # The same preprocessor/decoder objects are registered both on self and
        # inside self.perceiver, so they appear under two names in the module
        # tree. Presumably the saved checkpoint was written with this layout;
        # verify before renaming or removing either attribute.
        self.perceiver = transformers.PerceiverModel(
            config,
            decoder = self.decoder,
            input_preprocessor = self.input_preprocessor,
        )
        self.output_postprocessor = transformers.models.perceiver.modeling_perceiver.PerceiverEmbeddingDecoder(config)

        # Standard transformers hook run at the end of a PreTrainedModel's __init__.
        self.post_init()
    def forward(self, inputs=None, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, labels=None):#, return_dict=None, input_ids=None):
        """Run the model and return a plain tuple (this model never returns a ModelOutput).

        Args:
            inputs: token ids for PerceiverTextPreprocessor (the script in this
                file passes a (1, max_position_embeddings) long tensor).
            attention_mask, head_mask, output_attentions, output_hidden_states:
                forwarded unchanged to PerceiverModel.
            labels: optional target ids. They are flattened with view(-1) and
                compared position-by-position against the logits with
                CrossEntropyLoss, whose default ignore_index is -100.

        Returns:
            (logits, *rest) without labels, or (loss, logits, *rest) with
            labels, where *rest* is outputs[1:] of the underlying
            PerceiverModel call. The last dimension of logits is treated as
            config.vocab_size by the loss computation.
        """
        # return_dict is forced to False, so `outputs` is a tuple and is indexed
        # positionally below.
        outputs = self.perceiver(
                inputs=inputs,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=False,#return_dict,
        )

        # Project the decoder output (outputs[0]) onto the vocabulary with the
        # input embedding layer, tying input and output embeddings.
        logits = self.output_postprocessor(
                #outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
                outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
        )

        loss = None
        if labels is not None:
            # Flatten every position of every batch item into one classification problem.
            loss = torch.nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), labels.view(-1))
        
        # Replace the raw decoder output with the vocabulary logits and keep the rest.
        output = (logits,) + outputs[1:] # outputs[2:]
        if loss is None:
            return output
        else:
            return ((loss,) + output)

# Hand-built architecture settings. NOTE: this config is NOT passed to
# from_pretrained below; it is replaced by the checkpoint's own config.
config = transformers.PerceiverConfig()
config.num_decoder_heads = config.num_cross_attention_heads
config.num_self_attends_per_block = 3#6#13#26#6
config.max_position_embeddings = 96
config.d_model = 96#384#768#128
config.d_latents = 160#640#1280#256
config.vocab_size = 256
config.qk_channels = 256#8 * 32
config.v_channels = config.d_latents
print('Constructing model ...', flush=True)
model = Model.from_pretrained('words2nums')
# Everything below (notably max_position_embeddings) must match what the
# weights were saved with, so take the config from the loaded model.
config = model.config
#config.chunk_size_query = 16
#config.chunk_size_key = 16
## maybe: per-process vmem for low-end systems; https://github.com/xloem/mempickle
#import pytorch_tensormap
#mmap_params = pytorch_tensormap.PyTorchMap()
#mmap_params.write(model.state_dict())
#model.load_state_dict(mmap_params.read(writeable = True))

import torch


def _encode_word(word, length):
    """Encode *word* as model inputs: ``(ids, mask)``, each of shape (1, length).

    Each character becomes its ISO-8859-1 byte value (0-255) as a long id.
    Unsigned bytes are used deliberately: an int8 reinterpretation would turn
    bytes >= 0x80 into negative embedding indices. Text longer than *length*
    is truncated. Padding positions hold id 32 (space) and mask 0.0; real
    characters have mask 1.0.

    Raises:
        UnicodeEncodeError: if *word* contains characters outside ISO-8859-1.
    """
    encoded = word.encode('iso-8859-1')[:length]
    ids = torch.full((1, length), 32, dtype=torch.long)
    mask = torch.zeros((1, length), dtype=torch.float32)
    ids[0, :len(encoded)] = torch.tensor(list(encoded), dtype=torch.long)
    mask[0, :len(encoded)] = 1.0
    return ids, mask


def _decode_prediction(logits):
    """Return the argmax byte prediction for batch item 0, cut at the first b'.'."""
    predicted = logits[0].argmax(dim=1).to(torch.uint8).cpu().numpy().tobytes()
    # partition keeps the whole prediction when no '.' terminator was produced;
    # predicted[:predicted.find(b'.')] would silently drop the last byte then.
    return predicted.partition(b'.')[0]


model.eval()
while True:
    try:
        word = input('input word number: ')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / end of piped input or Ctrl-C at the prompt ends the session cleanly.
        print()
        break
    try:
        word_data, word_mask = _encode_word(word, config.max_position_embeddings)
    except UnicodeEncodeError:
        print('input must be ISO-8859-1 (Latin-1) text')
        continue
    # Latin-1 is one byte per character, so len(word) is the encoded length.
    if len(word) > config.max_position_embeddings:
        print(f'input truncated to {config.max_position_embeddings} characters')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits, _ = model(inputs=word_data, attention_mask=word_mask)
    print(_decode_prediction(logits))

