# Module setup: builds a default text-generation pipeline at import time.
# NOTE: importing this module therefore loads the model weights (and, if they
# are not cached locally, presumably downloads them), which is slow; the
# message below warns about this.
print('loading ... did not mmap model weights or shrink model or train for task, so it is slow to load ...')
import transformers

# Shared default model/tokenizer, used by Prompter and its subclasses when no
# pipeline/model/tokenizer is passed explicitly.
DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-560m') # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl
print('loaded')
# Larger bloomz alternatives. NOTE(review): the mt0 models listed in the size
# notes are encoder-decoder models, so they presumably need a different
# pipeline task than 'text-generation' — confirm before switching.
#DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-1b1') # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl
#DEFAULT_PIPELINE = transformers.pipeline('text-generation', 'bigscience/bloomz-1b7') # bloomz: 560m, 1.1b, 1.7b, 3b, 7.1b, 176b ; mt0: small (300m), base, large, xl, xxl

class Prompter:
    '''Run a causal language model on queries that share a fixed text prefix.

    The prefix is run through the model once at construction time and its
    key/value cache (``past_key_values``) is kept. Each call then only feeds
    the tokens that come after the prefix.

    A single trailing space on the prefix is split off (see
    ``_convert_extra_tail``) and re-inserted in front of each query before
    tokenizing. This way the space is tokenized together with the query text
    rather than as a lone token at the end of the cached prefix.

    Args:
        prefix: fixed prompt text shared by every query.
        pipeline: a ``transformers`` pipeline supplying ``model`` and
            ``tokenizer``; defaults to the module-level ``DEFAULT_PIPELINE``.
        model: overrides ``pipeline.model``.
        tokenizer: overrides ``pipeline.tokenizer``.

    Raises:
        ValueError: if the prefix still ends with a space after one trailing
            space has been split off (e.g. it ends in two spaces, or is ``' '``).
    '''
    def __init__(self, prefix, pipeline = None, model = None, tokenizer = None):
        import torch
        if pipeline is None:
            pipeline = DEFAULT_PIPELINE
        if model is None:
            model = pipeline.model
        if tokenizer is None:
            tokenizer = pipeline.tokenizer
        self.model = model
        self.tokenizer = tokenizer
        self._prefix, self._extra_prefix = self._convert_extra_tail(prefix)
        if self._prefix.endswith(' '):
            raise ValueError(f'unsupported prefix {prefix!r}: it must not end with a space once a single trailing space is split off')
        self._prefix_token_ids = self.tokenizer(self._prefix, return_tensors='pt').input_ids
        # Inference only: eval() disables dropout for the cached forward pass,
        # and no_grad() keeps the stored cache from holding an autograd graph
        # alive for the lifetime of this object.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(self._prefix_token_ids, use_cache=True)
        # NOTE(review): this cache object is reused by every later call. That is
        # safe for tuple-style caches, but newer transformers versions may return
        # a mutable Cache object that forward passes extend in place — confirm
        # against the installed version.
        self._cache_key_values = outputs.past_key_values

    def _convert_extra_tail(self, text):
        '''Split one trailing space off ``text``.

        Returns ``(head, tail)`` with ``head + tail == text``. ``tail`` is
        ``' '`` when ``text`` ends in exactly one space and has at least one
        other character before it. Otherwise ``tail`` is ``''`` and ``head``
        is ``text`` unchanged.
        '''
        # len >= 2 is enough to look at text[-2]; the old `> 2` test wrongly
        # left two-character strings such as 'a ' unsplit.
        if len(text) >= 2 and text[-1] == ' ' and text[-2] != ' ':
            return text[:-1], text[-1:]
        return text, ''

    def __call__(self, suffix, **kwparams):
        '''Run the model on ``suffix`` after the cached prefix.

        Args:
            suffix: non-empty text that follows the prefix.
            **kwparams: extra keyword arguments forwarded to the model call.

        Returns:
            ``(logits, past_key_values)``: the logits at the last position of
            batch item 0, and ``outputs.past_key_values``. The latter is
            presumably ``None`` because the call passes ``use_cache=False``
            (TODO confirm).

        Raises:
            ValueError: if ``suffix`` is empty.
        '''
        import torch
        if not suffix:
            raise ValueError('suffix must be a non-empty string')
        # Tokenize the whole prompt, then drop the prefix's tokens (they are
        # already in the cache). This assumes the prefix tokenizes the same way
        # on its own as it does inside the full prompt.
        full_ids = self.tokenizer(self._prefix + self._extra_prefix + suffix, return_tensors='pt').input_ids
        token_ids = full_ids[:, self._prefix_token_ids.shape[1]:]
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(input_ids = token_ids, past_key_values = self._cache_key_values, use_cache = False, **kwparams)
        assert outputs.logits.shape[0] == 1  # internal invariant: a single prompt is processed
        logits = outputs.logits[0, -1]
        return logits, outputs.past_key_values

class Classifier(Prompter):
    '''Choose one of a fixed set of answers by comparing next-token probabilities.

    The prompt is ``prefix + infix + suffix``. Each candidate answer is scored
    by the model's probability for the *first* token of the answer only, and
    the scores are renormalized over the candidates.

    NOTE: because only first tokens are compared, candidates that share a
    first token receive identical scores. It is also assumed that the
    tokenizer adds no BOS token when encoding an answer (otherwise
    ``input_ids[0]`` would be that special token) — TODO confirm for
    non-BLOOM tokenizers.

    Args:
        prefix: fixed prompt text, cached once (see ``Prompter``).
        suffix: text appended after every query's infix.
        *default_answers: candidate answers used when a call passes none.
        bad_answer: answer returned instead of raising when confidence is too low.
        **kwparams: forwarded to ``Prompter`` (``pipeline``, ``model``, ``tokenizer``).
    '''
    def __init__(self, prefix, suffix, *default_answers, bad_answer=None, **kwparams):
        super().__init__(prefix, **kwparams)
        self.default_answers = default_answers
        self.bad_answer = bad_answer
        self.suffix = suffix

    def __call__(self, infix, *answers, bad_answer = None, confidence = 0.65, return_prob = False):
        '''Classify ``infix`` by choosing among candidate answers.

        Args:
            infix: text placed between the cached prefix and ``self.suffix``.
            *answers: candidate answers; defaults to the constructor's
                ``default_answers``.
            bad_answer: returned instead of raising when confidence is too
                low; defaults to the constructor's ``bad_answer``.
            confidence: minimum renormalized probability the best answer needs.
            return_prob: if true, also return the probability as a float.

        Returns:
            The chosen answer, or ``(answer, prob)`` when ``return_prob`` is
            true. When ``bad_answer`` is returned and it is one of the
            candidates, ``prob`` is its own probability. Otherwise ``prob`` is
            the best candidate's probability.

        Raises:
            ValueError: if there are no candidate answers, or if confidence is
                too low and no ``bad_answer`` is available.
        '''
        # A single trailing space is moved onto the answers so that it is
        # tokenized together with each answer's first token.
        suffix, extra_suffix = self._convert_extra_tail(infix + self.suffix)
        if not answers:
            answers = self.default_answers
        if not answers:
            raise ValueError('no candidate answers given')
        if not bad_answer:
            bad_answer = self.bad_answer
        answer_ids = [self.tokenizer(extra_suffix + answer).input_ids[0] for answer in answers]
        logits, _ = super().__call__(suffix)
        probs = logits.softmax(dim=-1)
        answer_relprobs = probs[answer_ids]
        answer_probs = answer_relprobs / answer_relprobs.sum()
        best_idx = int(answer_probs.argmax())
        best_prob = answer_probs[best_idx]
        answer = answers[best_idx]
        if best_prob < confidence:
            if not bad_answer:
                raise ValueError(f"insufficient confidence {answer}={round(best_prob.item()*100)}%")
            answer = bad_answer
            try:
                # Report the bad answer's own probability when it is a candidate.
                best_prob = answer_probs[answers.index(bad_answer)]
            except ValueError:
                pass  # not a candidate: keep the best candidate's probability
        if return_prob:
            return answer, best_prob.item()
        return answer

class Generator(Prompter):
    '''Greedy text generation after a cached prefix, reporting per-token probabilities.

    Each query's prompt is ``prefix + infix + suffix``. Generation is greedy
    (always the most probable next token). It stops when the most recently
    generated tokens equal a stop sequence, or after ``max_tokens`` tokens.

    NOTE: ``model.generate`` is not used here. Per the author's notes, its API
    did not readily expose per-token probabilities (they might be obtainable
    through a callback).

    Args:
        prefix: fixed prompt text, cached once (see ``Prompter``).
        suffix: text appended after every query's infix.
        *params: accepted and ignored.
        **kwparams: forwarded to ``Prompter`` (``pipeline``, ``model``, ``tokenizer``).
    '''
    def __init__(self, prefix, suffix = '', *params, **kwparams):
        super().__init__(prefix, **kwparams)
        self.suffix = suffix

    def _leading_text(self, token_ids):
        '''Return the text that ``token_ids[0]`` contributes to ``decode(token_ids)``.

        Computed as the difference between decoding all of ``token_ids`` and
        decoding ``token_ids[1:]``, so spacing that depends on the following
        tokens is kept.
        '''
        full = self.tokenizer.decode(token_ids)
        rest = self.tokenizer.decode(token_ids[1:]) if len(token_ids) > 1 else ''
        return full[:len(full) - len(rest)]

    def iterate(self, infix, eos_token_ids = None, max_tokens = 64, **kwparams):
        '''Greedily generate after ``prefix + infix + suffix``, one token at a time.

        Args:
            infix: query text (must be non-empty).
            eos_token_ids: stop sequence of token ids. Generation ends, without
                emitting it, once the most recent tokens equal this sequence.
                Defaults to ``[tokenizer.eos_token_id]``. An empty sequence
                matches immediately, so nothing is yielded.
            max_tokens: maximum number of tokens to generate.
            **kwparams: accepted but currently unused.

        Yields:
            ``(text, prob)`` per generated token: the token's decoded text and
            the model's probability for it, as a float. The trailing space
            split off the query is removed from the first yielded text.
            Tokens are yielded ``len(eos_token_ids) - 1`` steps late, so a stop
            sequence is never emitted.

        Raises:
            ValueError: if ``infix`` is empty (raised when iteration starts).
        '''
        import torch
        if not infix:
            raise ValueError('infix must be a non-empty string')
        suffix, extra_suffix = self._convert_extra_tail(infix + self.suffix)
        if eos_token_ids is None:
            eos_token_ids = [self.tokenizer.eos_token_id]
        eos_token_ids = torch.as_tensor(eos_token_ids, dtype=torch.long)
        window_len = len(eos_token_ids)
        if window_len == 0:
            # An empty stop sequence matches immediately: nothing is generated.
            return
        # Sliding windows holding the last `window_len` generated tokens and
        # their probabilities. A token is yielded only when it is about to
        # leave the window, so a completed stop sequence is never emitted.
        window = torch.empty(window_len, dtype=torch.long)
        window_probs = torch.empty(window_len)
        token_ids = self.tokenizer(self._prefix + self._extra_prefix + suffix, return_tensors='pt').input_ids[:, self._prefix_token_ids.shape[1]:]
        past_key_values = self._cache_key_values
        first = True
        for ct in range(max_tokens):
            # Inference only: avoid building an autograd graph across steps.
            with torch.no_grad():
                outputs = self.model(token_ids, past_key_values = past_key_values)
            logits = outputs.logits[0, -1]
            best_token_id = logits.argmax()
            best_prob = logits.softmax(dim=-1)[best_token_id]
            window = torch.cat((window[1:], best_token_id[None]))
            window_probs = torch.cat((window_probs[1:], best_prob[None]))
            if ct + 1 >= window_len:  # the window now holds only generated tokens
                if (window == eos_token_ids).all():
                    return
                text = self._leading_text(window)
                if first and text.startswith(extra_suffix):
                    text = text[len(extra_suffix):]
                first = False
                yield text, window_probs[0].item()
            # NOTE: stopping on any lone eos_token_id was tried and disabled;
            # per the author it happens spuriously for wrong data, and that
            # does not show in the confidence if the output started well.
            token_ids = best_token_id[None, None]
            past_key_values = outputs.past_key_values

        # max_tokens was reached without the stop sequence. Flush only the
        # generated tokens still waiting in the window: window_len - 1 of them,
        # or all max_tokens when fewer were generated. The oldest window entry
        # was already yielded above, and unfilled entries are uninitialised.
        pending = min(max_tokens, window_len - 1)
        for idx in range(window_len - pending, window_len):
            text = self._leading_text(window[idx:])
            if first and text.startswith(extra_suffix):
                text = text[len(extra_suffix):]
            first = False
            yield text, window_probs[idx].item()

    def __call__(self, suffix, eos_token_ids = None, max_tokens = 64, confidence = 0, return_prob = False, **kwparams):
        '''Generate a full continuation of ``suffix`` and return it as one string.

        Args:
            suffix: the query text (passed to ``iterate`` as ``infix``).
            eos_token_ids: stop sequence, passed to ``iterate``.
            max_tokens: token limit, passed to ``iterate``.
            confidence: minimum allowed product of the token probabilities,
                checked after every token.
            return_prob: if true, also return that product.
            **kwparams: passed to ``iterate``.

        Returns:
            The generated text, or ``(text, total_prob)`` when ``return_prob``
            is true.

        Raises:
            ValueError: as soon as the running probability drops below
                ``confidence``.
        '''
        total_prob = 1
        text = ''
        for piece, prob in self.iterate(suffix, eos_token_ids, max_tokens = max_tokens, **kwparams):
            total_prob *= prob
            text += piece
            if total_prob < confidence:
                raise ValueError(f"insufficient confidence {text.strip()}={round(float(total_prob)*100)}%")
        if return_prob:
            return text, total_prob
        return text

class Many(Prompter):
    '''Generate ``count`` alternative greedy continuations of one query.

    This mirrors ``Generator`` with two differences:

    - the tokens of each continuation are joined, and whole strings are yielded;
    - candidate ``i`` starts with the ``i``-th most likely first token and then
      continues greedily.

    TODO: this only branches on the first token. A most-probable-sequence-first
    search over sub-branches would be more general; for example, many good
    answers may share a first token such as "a".

    Args:
        prefix: fixed prompt text, cached once (see ``Prompter``).
        suffix: text appended after every query's infix.
        *params: accepted and ignored.
        **kwparams: forwarded to ``Prompter`` (``pipeline``, ``model``, ``tokenizer``).
    '''
    def __init__(self, prefix, suffix = '', *params, **kwparams):
        super().__init__(prefix, **kwparams)
        self.suffix = suffix

    def _leading_text(self, token_ids):
        '''Return the text that ``token_ids[0]`` contributes to ``decode(token_ids)``.'''
        full = self.tokenizer.decode(token_ids)
        rest = self.tokenizer.decode(token_ids[1:]) if len(token_ids) > 1 else ''
        return full[:len(full) - len(rest)]

    def _complete(self, prompt_ids, first_rank, eos_token_ids, extra_suffix, max_tokens):
        '''Continue the prompt greedily, forcing the ``first_rank``-th most likely first token.

        Returns ``(text, prob)``. ``text`` is the generated text, excluding the
        stop sequence and with ``extra_suffix`` stripped from its start. ``prob``
        is the product of the chosen tokens' probabilities.
        '''
        import torch
        window_len = len(eos_token_ids)
        if window_len == 0:
            # An empty stop sequence matches immediately.
            return '', 1
        # Sliding windows over the last `window_len` generated tokens and their
        # probabilities. A token is accepted only when it is about to leave
        # the window, so the stop sequence is never included.
        window = torch.empty(window_len, dtype=torch.long)
        window_probs = torch.empty(window_len)
        pieces = []
        prob = 1
        token_ids = prompt_ids
        past_key_values = self._cache_key_values

        def accept(piece, piece_prob):
            # Strip the query's split-off trailing space from the first piece only.
            nonlocal prob
            if not pieces and piece.startswith(extra_suffix):
                piece = piece[len(extra_suffix):]
            pieces.append(piece)
            prob *= piece_prob

        for ct in range(max_tokens):
            with torch.no_grad():
                outputs = self.model(token_ids, past_key_values = past_key_values)
            logits = outputs.logits[0, -1]
            if ct == 0:
                token_id = logits.sort(descending=True).indices[first_rank]
            else:
                token_id = logits.argmax()
            token_prob = logits.softmax(dim=-1)[token_id]
            window = torch.cat((window[1:], token_id[None]))
            window_probs = torch.cat((window_probs[1:], token_prob[None]))
            if ct + 1 >= window_len:  # the window now holds only generated tokens
                if (window == eos_token_ids).all():
                    return ''.join(pieces), prob
                accept(self._leading_text(window), window_probs[0].item())
            token_ids = token_id[None, None]
            past_key_values = outputs.past_key_values

        # max_tokens was reached: accept only the generated tokens still waiting
        # in the window (the oldest entry was already accepted above, and
        # unfilled entries are uninitialised).
        pending = min(max_tokens, window_len - 1)
        for idx in range(window_len - pending, window_len):
            accept(self._leading_text(window[idx:]), window_probs[idx].item())
        return ''.join(pieces), prob

    def iterate(self, count, infix, eos_token_ids = None, max_tokens = 64, **kwparams):
        '''Yield ``count`` candidate continuations as ``(text, prob)`` pairs.

        Args:
            count: number of candidates; candidate ``i`` starts with the
                ``i``-th most likely first token.
            infix: query text (must be non-empty).
            eos_token_ids: stop sequence of token ids; defaults to
                ``[tokenizer.eos_token_id]``.
            max_tokens: maximum number of tokens per candidate.
            **kwparams: accepted but currently unused.

        Raises:
            ValueError: if ``infix`` is empty (raised when iteration starts).
        '''
        import torch
        if not infix:
            raise ValueError('infix must be a non-empty string')
        suffix, extra_suffix = self._convert_extra_tail(infix + self.suffix)
        if eos_token_ids is None:
            eos_token_ids = [self.tokenizer.eos_token_id]
        eos_token_ids = torch.as_tensor(eos_token_ids, dtype=torch.long)
        # The prompt is the same for every candidate, so tokenize it once.
        prompt_ids = self.tokenizer(self._prefix + self._extra_prefix + suffix, return_tensors='pt').input_ids[:, self._prefix_token_ids.shape[1]:]
        for count_idx in range(count):
            yield self._complete(prompt_ids, count_idx, eos_token_ids, extra_suffix, max_tokens)
    __call__ = iterate
        

# Idea: it could store the logits and the answers whenever the correct answer is known, and then tune an adapter on all of them; this is likely not hard.

# it could also condense them to remove all the prefixes using an adapter

class Converter(Generator):
    '''Few-shot text transformation: turn ``form1`` into a second form by example.

    With ``p1``/``p2`` being ``prefix1``/``prefix2`` with a leading space added
    (when non-empty), the cached prompt is
    ``"{p1} {key1}{p2} {val1}{p1} {key2}{p2} {val2}...{p1} "`` and the
    per-query suffix is ``"{p2} "``. Generation stops when the model starts
    the next ``p1`` label.

    Args:
        prefix1: label placed before each input form.
        prefix2: label placed before each output form.
        kwparams: dict of keyword arguments forwarded to ``Generator``/``Prompter``
            (``pipeline``, ``model``, ``tokenizer``); defaults to none.
        **examples: example mappings ``input_form=output_form``.
    '''
    def __init__(self, prefix1, prefix2, kwparams=None, **examples):
        if kwparams is None:
            kwparams = {}
        if prefix1:
            prefix1 = ' ' + prefix1
        if prefix2:
            prefix2 = ' ' + prefix2
        super().__init__(
            ''.join((f'{prefix1} {key}{prefix2} {val}' for key, val in examples.items()))
            + f'{prefix1} ',
            f'{prefix2} ',
            **kwparams
        )
        # Stop sequence: the tokens of the next "{p1} " label, minus the last
        # token (meant to drop the trailing space in case it tokenizes
        # differently before a word; this may not be needed).
        # NOTE(review): with an empty prefix1 this is encode(' ')[:-1], which
        # is likely an empty list, i.e. no usable stop sequence — confirm.
        self._eos_token_ids = self.tokenizer.encode(f'{prefix1} ')[:-1]

    def __call__(self, form1, confidence = 0.65, return_prob = False):
        '''Convert ``form1``; see ``Generator.__call__`` for the returns and the ``ValueError``.'''
        return super().__call__(form1, self._eos_token_ids, confidence = confidence, return_prob = return_prob)

class ExampleClassifier(Converter):
    '''Few-shot classifier whose candidate answers are the example outputs.

    Builds the same prompt as ``Converter``. Instead of generating, it scores
    the distinct example values with a ``Classifier`` that shares this
    converter's prompt text.

    NOTE: the ``Classifier`` is a separate ``Prompter``, so the prefix is run
    through the model a second time at construction.

    Args:
        prefix1: label placed before each input form.
        prefix2: label placed before each output form.
        kwparams: dict forwarded to both the converter and the classifier
            (``pipeline``, ``model``, ``tokenizer``); defaults to none.
        classifier_kwparams: dict forwarded only to the ``Classifier``
            (e.g. ``bad_answer``); defaults to none.
        **examples: example mappings ``input_form=output_form``.
    '''
    def __init__(self, prefix1, prefix2, kwparams=None, classifier_kwparams=None, **examples):
        if kwparams is None:
            kwparams = {}
        if classifier_kwparams is None:
            classifier_kwparams = {}
        super().__init__(prefix1, prefix2, kwparams, **examples)
        self.examples = examples
        # dict.fromkeys de-duplicates while keeping first-seen order. A set's
        # order for strings changes between runs (hash randomization), which
        # made tie-breaking between answers nondeterministic.
        answers = dict.fromkeys(examples.values())
        self.classifier = Classifier(self._prefix + self._extra_prefix, self.suffix, *answers, **kwparams, **classifier_kwparams)

    def __call__(self, form1, confidence = 0.65, return_prob = False):
        '''Classify ``form1``; see ``Classifier.__call__`` for the returns and the ``ValueError``.'''
        return self.classifier(form1, confidence = confidence, return_prob = return_prob)

#class Pluralizer(Generator):
#    def __init__(self, **kwparams):
#        super().__init__(' unknown: dirt plural: dirts unknown: apple plural: apples unknown: ', **kwparams) # space gets migrated by base class, to pass more common token
#        self._eos_token_ids = self.tokenizer.encode(' unknown: ')[:-1] # the -1 removes the space, in case terminating tokens are encoded differently, may not be needed
#        assert ':' in self.tokenizer.decode(self._eos_token_ids)
#    def __call__(self, word_with_unknown_multiplicity, confidence = 0.65, return_prob = False):
#        return super().__call__(word_with_unknown_multiplicity + ' plural:', self._eos_token_ids, confidence = confidence, return_prob = return_prob)
#        #plural = ''
#        #total_prob = 1
#        #for word, prob in super().__call__(word_with_unknown_multiplicity + ' plural:', self._eos_token_ids):
#        #    total_prob *= prob # not terminating early because an incomplete result can be incorrect but have high confidence if it starts something else correct
#        #    plural += word
#        #    if total_prob < confidence:
#        #        raise ValueError(f"insufficient confidence {plural.strip()}={round(float(total_prob)*100)}%")
#        #if return_prob:
#        #    return plural.strip(), total_prob
#        #else:
#        #    return plural.strip()

if __name__ == '__main__':
    # Interactive demo / scratchpad. It builds several prompters on
    # DEFAULT_PIPELINE, prints a few results, then stops in pdb (last lines)
    # so the constructed objects (generator, plural, infinitive, many,
    # tangibility, breadbox, actionobj, ...) can be explored by hand.
    generator = Generator('<s>')
    import torch
    print('Once upon a time,' + generator('Once upon a time,'))
    #plural = Pluralizer()
    plural = Converter('(singular) a single', 'vs. (plural) a few, many, or multiple', dress='dresses', apple='apples')
    infinitive = Converter('unknown:', 'infinitive: to', thinkable='think', brought='bring', running='run')

    # Idea: give objects types, and have verbs require properties. If a type
    # lacks a property, ask the user and add it only to the database? (unsure)
    # NOTE(review): object_types is unused below, and it contradicts the
    # tangibility examples further down, where thought='intangible'.
    object_types = dict(
        thought = 'tangible'
    )
    

    #able = Convert('infinitive:
    # Each Converter call raises ValueError when confidence is below its threshold.
    try:
        print('house', plural('house', return_prob=True))
    except ValueError as e:
        print('not sure what plural of house is', e)
    try:
        print('mess', plural('mess', return_prob=True))
    except ValueError as e:
        print('not sure what plural of mess is', e)
    try:
        print('butterfly', plural('butterfly', return_prob=True))
    except ValueError as e:
        print('not sure what plural of butterfly is', e)
        

    # Property classifiers would be helpful. A converter that tells synonyms
    # apart could help too, e.g.
    # "we were sitting near the rock and it was so loud -> rock-music".
    # Could maybe apply it to all of them to make it clear; unsure.
    many = Many('Places you might find in ', ': ')
    print(*many(8, 'a city'))
    tangibility = ExampleClassifier('', '=', classifier_kwparams=dict(bad_answer='low confidence'), stone='tangible', thought='intangible', bird='tangible', speed='intangible', car='tangible')
    breadbox = ExampleClassifier('', '=', pebbles='small', mailbox='small', dog='big', tree='big', car='big', suitcase='small', bottle='small', table='big', peanuts='small', bench='big', watch='small', brain='small', toilet='big')
    print('memory tangible?', tangibility('memory', return_prob=True))
    # NOTE(review): labelled 'rock' but queries 'stone', which is one of the examples.
    print('rock tangible?', tangibility('stone', return_prob=True))

    #smallness = ExampleCl
    # Prefix tuning could make a lot of sense here, though it does not move
    # toward shrinking the model. [A prefix is needed for each group of
    # queries, whereas an adapter can be shared; prefixes are much tinier and
    # somewhat less smart.]
    #actionobj = Classifier('Respond if the phrase is easy or hard to do. ')
    #actionobj = Classifier('EASY actions: open door , pick up watch , look at cloud , eat cake , talk to storeperson . HARD actions: eat rainbow , build planet , open cloud , pick up house , talk to color . Question: Would ', ' be EASY or HARD?', 'EASY', 'HARD') 
    #actionobj = Classifier(' EASY actions:\n open door\n pick up watch\n look at cloud\n eat cake\n talk to storeperson\n ride horse\n HARD actions:\n eat rainbow\n build planet\n open cloud\n pick up house\n talk to color\n turn into frog\n is this action EASY or HARD?\n ', '\n', 'EASY', 'HARD')
    # Everything has properties, and everything is a context. An easy way to
    # add this is a WHY and a context prefix; this has been tried online with
    # some success.
    #actionobj = Classifier(' These actions are EASY:\n open door\n pick up watch\n look at cloud\n eat cake\n talk to storeperson\n ride horse\n These actions are HARD:\n eat rainbow\n build planet\n open cloud\n pick up house\n talk to color\n turn into frog\n Is this action EASY or HARD? ', '\n', 'EASY', 'HARD')
    actionobj = Classifier(' respond with whether or not a human could do it. can a human ', '? ', 'yes', 'no')

    # actionobj doesn't really work, but classifier context is quite useful, and it would of course work with tuning
    # it might make more sense to condition actions on the ability to generate a result from them.

    # we don't have more time, so having a success here might be helpful. object properties

    #objprop = Classifier(' frog color = green\n car size = 

    #for str in Generator.__call__(actionobj, 'eat cloud. Answer:', eos_token_ids=actionobj.tokenizer.encode('\n')):
    #    print(str)
    #actionobj.model.generate(past=actionobj._cache_key_values, use_cache=True)
    #print('eat cloud', actionobj('eat cloud. Answer:', ' yes', ' no', return_prob=True, confidence=0.5))
    #print('lift house', actionobj('pick up house. Answer:', ' yes', ' no', return_prob=True, confidence=0.5))
    #print('enter house', actionobj('enter house. Answer:', ' yes', ' no', return_prob=True, confidence=0.5))
    #print('eat rainbow', actionobj('eat rainbow', return_prob=True, confidence=0.5))
    #print('pick up house', actionobj('pick up house', return_prob=True, confidence=0.5))
    #print('look at rainbow', actionobj('look at rainbow', return_prob=True, confidence=0.5))
    #print('open box', actionobj('open box', return_prob=True, confidence=0.5))
    # Deliberate breakpoint for interactive exploration. pdb stops on the next
    # statement, which is presumably why the no-op '' line follows it.
    import pdb; pdb.set_trace()
    ''

