import transformers
import torch

# Fallback pipeline: Generated takes its model and/or tokenizer from this when
# none (and no pipeline) are passed.  Creating it loads the model weights as a
# side effect of importing this module.
# NOTE(review): `dtype=` is the newer transformers spelling of `torch_dtype=`;
# older transformers releases may need `torch_dtype=` — confirm against the installed version.
# NOTE(review): the mt0 models listed in the trailing comment are encoder-decoder
# (mT5-based) and presumably need a 'text2text-generation' pipeline rather than
# 'text-generation' — confirm before swapping them in.
DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-560m', dtype=torch.bfloat16) # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl

# A simplified variant would go here, presumably one that associates text with its
# key/value cache (past_key_values) so generation can continue from that text.

class Generated:
    """Text paired with its token ids and the model's logits for the next token.

    A ``Generated`` is built either from text (tokenized and run through the
    model on construction) or by calling an existing instance with a suffix,
    which produces a new instance for the extended text.  Only the logits at
    the final position are kept, so each instance can rank candidate next tokens.
    """

    def __init__(self, text, token_ids = None, outputs = None, pipeline = None, model = None, tokenizer = None, **kwparams):
        """Store ``text`` and compute (or accept) its token ids and next-token logits.

        Args:
            text: the text this object stands for; ``str(self)`` returns it.
            token_ids: optional 1-D tensor of token ids for ``text``; when omitted,
                ``text`` is tokenized with ``tokenizer``.
            outputs: optional model output for ``token_ids`` exposing a 3-D
                ``logits`` tensor with batch size 1 (presumably
                ``(1, seq_len, vocab_size)``); when omitted the model is run.
            pipeline: transformers pipeline supplying ``model`` / ``tokenizer``
                when those are not given.  Defaults to the module-level
                ``DEFAULT_PIPELINE``, which is only looked up if ``model`` or
                ``tokenizer`` is missing.
            model: model to run; defaults to ``pipeline.model``.
            tokenizer: tokenizer to use; defaults to ``pipeline.tokenizer``.
            **kwparams: extra keyword arguments passed to every model call made
                by this object and by objects derived from it via ``__call__``.
        """
        if model is None or tokenizer is None:
            if pipeline is None:
                pipeline = DEFAULT_PIPELINE
            if model is None:
                model = pipeline.model
            if tokenizer is None:
                tokenizer = pipeline.tokenizer
        self.model = model
        self.tokenizer = tokenizer
        self.kwparams = kwparams
        self._text = text
        if token_ids is None:
            token_ids = self.tokenizer(text, return_tensors='pt').input_ids
            assert token_ids.shape[0] == 1
            token_ids = token_ids[0]
        else:
            assert len(token_ids.shape) == 1
        self._token_ids = token_ids
        if outputs is None:
            self.model.eval()
            outputs = self.model(self._token_ids[None], use_cache=True, **self.kwparams)
        assert len(outputs.logits.shape) == 3
        assert outputs.logits.shape[0] == 1
        # Keep only the logits after the last token.  Cloning drops the view into
        # the full logits tensor so the rest of ``outputs`` can be freed.
        # TODO: also keep outputs.past_key_values so __call__ can feed just the
        # suffix ids instead of re-running the whole sequence.
        self._logits = outputs.logits[0, -1].detach().clone()

    def __str__(self):
        """Return the text this object represents."""
        return self._text

    def next_id_probs(self, dtype = torch.bfloat16):
        """Yield ``(token_id, probability)`` for each entry of the stored logits, most likely first.

        Probabilities are the softmax of the final-position logits.  The softmax
        and the sort run in float32, so low-precision logits (e.g. from a
        bfloat16 model) are normalized accurately.  It also means probabilities
        that differ only below ``dtype`` precision are still ordered correctly.
        The yielded probabilities are then cast to ``dtype``, which is bfloat16
        by default (the previous output type).

        Both items of each pair are 0-d tensors.
        """
        probs, ids = self._logits.softmax(dim=-1, dtype=torch.float32).sort(descending=True)
        yield from zip(ids, probs.to(dtype))

    def _decode_suffix(self, token_id, offset):
        """Return ``(suffix_ids, suffix_text)`` for appending the 0-d ``token_id``.

        ``suffix_text`` is the decoding of this object's ids plus the token, with
        the first ``offset`` characters (normally ``len(str(self))``) removed.
        """
        suffix_ids = token_id[..., None]
        token_ids = torch.cat((self._token_ids, suffix_ids), dim=-1)
        return suffix_ids, self.tokenizer.decode(token_ids)[offset:]

    def next_str_probs(self):
        """Yield ``(suffix_text, probability)`` per candidate next token, most likely first."""
        offset = len(str(self))
        for token_id, prob in self.next_id_probs():
            _, suffix = self._decode_suffix(token_id, offset)
            yield suffix, prob

    def next_obj_probs(self, **kwparams):
        """Yield ``(Generated, probability)`` per candidate next token, most likely first.

        Each object is ``self`` extended by one token via ``__call__``, so the
        model runs once per yielded item.  ``kwparams`` are forwarded to
        ``__call__`` (and from there to the model).
        """
        offset = len(str(self))
        for token_id, prob in self.next_id_probs():
            suffix_ids, suffix = self._decode_suffix(token_id, offset)
            yield self(suffix, suffix_ids, **kwparams), prob

    def __call__(self, suffix, suffix_ids = None, **kwparams):
        """Return a new ``Generated`` for ``str(self)`` extended by a suffix.

        The model is re-run on the whole extended token sequence.

        Args:
            suffix: text to append.  May be empty or None when ``suffix_ids`` is
                given; the new text is then the decoding of the extended ids.
            suffix_ids: optional token ids for the suffix (tensor or sequence of
                ints).  When given, they are appended to this object's ids as-is
                rather than re-tokenizing, because one text can have several
                tokenizations.
            **kwparams: model keyword arguments, merged over ``self.kwparams``.

        Raises:
            AssertionError: if neither a suffix nor suffix ids are given, or if
                both are given and the ids do not decode to ``str(self) + suffix``.
        """
        kwparams = {**self.kwparams, **kwparams}
        if suffix_ids is not None:
            # Normalize lists / 0-d tensors and test emptiness explicitly.  A
            # tensor's truth value is ambiguous for several elements and is
            # False for tensor([0]), i.e. token id 0.
            suffix_ids = torch.as_tensor(suffix_ids, dtype=self._token_ids.dtype, device=self._token_ids.device).reshape(-1)
            if suffix_ids.numel() == 0:
                suffix_ids = None
        assert suffix or suffix_ids is not None # for now, could return self or pass str(self) with other kwparams if neither is given
        if suffix_ids is not None:
            token_ids = torch.cat((self._token_ids, suffix_ids))
            if suffix:
                text = str(self) + suffix
                # Compare decoded text rather than re-encoding: the same text can
                # have multiple encodings.
                assert self.tokenizer.decode(token_ids) == text
            else:
                # e.g. a token that adds no new decoded characters
                text = self.tokenizer.decode(token_ids)
        else:
            text = str(self) + suffix
            token_ids = self.tokenizer(text, return_tensors='pt').input_ids
            assert token_ids.shape[0] == 1
            token_ids = token_ids[0]
        self.model.eval()
        outputs = self.model(token_ids[None], use_cache=False, **kwparams)
        return Generated(text = text, token_ids = token_ids, outputs = outputs, model = self.model, tokenizer = self.tokenizer, **kwparams)

class Multiple(Generated):
    """Enumerate completions of a prompt, highest probability first.

    Iterating yields ``(completion, probability)`` pairs.  ``completion`` is the
    text generated after the prompt, up to a candidate whose text ends with one
    of the end markers (``eos_text`` / ``extra_eos_texts``), with that marker
    removed.  ``probability`` is the product of the per-token probabilities
    along the way.  The search is best-first over a queue of partially
    generated candidates.
    """

    def __init__(self, text, eos_text, *extra_eos_texts, params = (), **kwparams):
        """Set up the prompt and the end markers.

        Args:
            text: the prompt.
            eos_text, *extra_eos_texts: a candidate whose text ends with any of
                these strings is reported as a finished completion.
            params: extra positional arguments forwarded to ``Generated.__init__``
                after ``text`` (token_ids, outputs, pipeline, model, tokenizer).
            **kwparams: keyword arguments forwarded to ``Generated.__init__``.
        """
        super().__init__(text, *params, **kwparams)
        self.eos_texts = [eos_text, *extra_eos_texts]

    def __iter__(self):
        """Yield ``(completion, probability)`` pairs, most probable first.

        Prints a ``'\\r'``-terminated progress line while searching.  Runs until
        every candidate is exhausted, which for a real vocabulary is effectively
        never, so callers are expected to stop consuming.
        """
        offset = len(str(self))
        # Each entry is (max_prob, base_prob, prev_base, base, it):
        #   base      -- a candidate (Generated) reached from the prompt
        #   base_prob -- probability of reaching ``base`` from the prompt
        #   max_prob  -- bound on the probability of the candidates ``it`` has
        #                not produced yet (``it`` yields in decreasing order)
        #   prev_base -- the candidate ``base`` was extended from (None for the prompt)
        #   it        -- iterator over (child candidate, token probability)
        queue = [(1, 1, None, self, self.next_obj_probs())]
        while queue:
            # Stable sort + pop(): take the highest (max_prob, base_prob); on exact
            # ties the most recently queued entry wins.
            queue.sort(key = lambda tup: tup[:2])
            max_prob, base_prob, prev_base, base, it = queue.pop()
            base_text = str(base)
            # The prompt itself is never a completion, even if it ends with an end
            # marker; and each candidate is reported once, for the first marker it matches.
            if base is not self:
                eos_text = next((eos for eos in self.eos_texts if base_text.endswith(eos)), None)
                if eos_text is not None:
                    yield base_text[offset:len(base_text) - len(eos_text)], base_prob
                    continue
            if max_prob == base_prob:
                # progress line, e.g. on a candidate's first expansion
                print(round(float(base_prob*100),4), base_text.strip(), end='\r', flush=True)
            step = next(it, None)
            if step is None:
                # every next token of ``base`` has been tried; drop it (a bare
                # next() would turn StopIteration into RuntimeError here)
                continue
            next_base, next_prob = step
            next_prob = next_prob * base_prob
            queue.append((next_prob, next_prob, base, next_base, next_base.next_obj_probs()))
            queue.append((next_prob, base_prob, prev_base, base, it)) # try this one again when prob drops

if __name__ == '__main__':
    # Demo: print completions of a prompt, most probable first, until interrupted
    # (e.g. Ctrl-C).  Each line shows the probability mass not yet yielded and the
    # completion's own probability (both as percentages), then the completion.
    intro = "This example is to the point where it is outpacing damaged parts of Karl's cognition that he needs to heal."
    print(intro)
    print("Is there any possibility to add training wheels to it, so the user learns a little relevance?")
    # thinking user could learn to provide 2 missing concepts when prompted with 1 example
    # ideally to keep growing so it is user-created content rather than ai-created
    try:
        # inference only: no autograd graph is needed
        with torch.no_grad():
            # other prompts tried:
            #multiple = Multiple('Places you might find in a city:', '</s>', ',', '.')
            #multiple = Multiple('Things you might find in a magic closet:', '</s>', ',', '.')
            multiple = Multiple('Some striking environments for a fantasy, sci-fi, or historical scene:', '</s>', ',')#, '.')
            remaining = 1
            for completion, prob in multiple:
                remaining -= float(prob)
                # trailing spaces overwrite the '\r'-terminated progress line
                print(f'{round(float(remaining)*100,4)} + {round(float(prob)*100,4)}: {completion}                                                  ')
    finally:
        print(intro)
        print("Is there any possibility to add training wheels to it, so the user learns a little relevance, and must learn to be able to do what the script does?")
