from collections import namedtuple
import bisect

class Chunk:
    """A contiguous byte range ``[start, end)`` together with its payload.

    ``height``, ``leaf_count`` and ``age`` are tree bookkeeping; a plain chunk
    defaults to one leaf of height 0 and age 0.
    """
    def __init__(self, start, end, data, height=0, leaf_count=1, age=0):
        self.start, self.end = start, end
        self.data = data
        self.height, self.leaf_count, self.age = height, leaf_count, age

    def __len__(self):
        """Return the width of the covered range (``end - start``)."""
        width = self.end - self.start
        return width

class Flush(Chunk):
    """One layer of writes stacked on top of earlier flushes.

    ``data`` is a list of ``Flush.Entry`` views kept sorted by position and
    non-overlapping (``add`` and ``read`` locate entries by bisecting their
    ``end``/``start`` offsets).  An entry's ``data`` is either a ``Chunk`` of
    bytes written into this flush or an older ``Flush`` that still supplies
    the bytes for that range.  ``leaf_count`` tracks how many chunk leaves are
    reachable and ``max_height`` (``leaf_count.bit_length()``) is the height
    budget used when hoisting subtrees of the previous flush and when merging
    neighbouring branches.  ``add`` only rewrites this flush's own entry list;
    entries shared with older flushes are replaced, never mutated.
    """

    class Entry(Chunk):
        """A view of ``chunk`` (a ``Chunk`` or ``Flush``) restricted to ``[start, end)``.

        ``path`` is the chain of objects this view was reached through, ending
        with ``chunk`` itself.  For a ``Flush`` the view's ``leaf_count`` and
        ``height`` are recomputed from the part of its subtree overlapping the
        range; a ``Chunk`` view counts as one leaf of height 0.
        """
        def __init__(self, start, end, chunk, path=()):
            super().__init__(start, end, chunk, height=0, leaf_count=1, age=chunk.age)
            self.path = [*path, self.data]
            if type(chunk) is Flush:
                self.leaf_count = 0
                self.height = 0
                for entry in self.flush_entries():
                    self.leaf_count += entry.leaf_count
                    self.height = max(self.height, entry.height + 1)

        def flush_entries(self):
            """Return a generator of views of the wrapped flush's entries, clipped to this range."""
            assert type(self.data) is Flush
            return (
                Flush.Entry(max(entry.start, self.start), min(entry.end, self.end), entry.data, self.path)
                for entry in self.data.data
                if entry.start < self.end and entry.end > self.start
            )

        def chunk_data(self):
            """Return the bytes of the wrapped chunk that fall inside this range."""
            assert type(self.data) is Chunk
            return self.data.data[self.start - self.data.start : self.end - self.data.start]

    def __init__(self, prev_flush=None):
        """Start an empty flush, or a new layer on top of ``prev_flush``.

        A new layer covers ``prev_flush`` as a whole and is one age newer.
        Entries of ``prev_flush`` at least ``max_height`` tall are re-added
        individually, presumably to keep the tree depth near
        ``log2(leaf_count)``.
        """
        if prev_flush is not None:
            leaf_count = prev_flush.leaf_count
            self.max_height = leaf_count.bit_length()
            super().__init__(prev_flush.start, prev_flush.end, [], leaf_count=0, age=prev_flush.age + 1)
            self.add(prev_flush)
            # self.max_height is refreshed by every add()
            for entry in prev_flush.data:
                if entry.height >= self.max_height:
                    self.add(entry)
        else:
            super().__init__(None, None, [], height=1, leaf_count=0)
            # bit_length() of the zero leaf count; add() reads it while merging
            self.max_height = 0

    def add(self, *adjacents):
        """Overlay ``adjacents`` onto this flush, replacing whatever they cover.

        Each argument may be a ``Chunk``, a ``Flush`` or a ``Flush.Entry``; they
        are expected in ascending position order, and everything between the
        first one's ``start`` and the last one's ``end`` is replaced.  Existing
        entries that stick out past either edge are trimmed and kept, then the
        result is compacted (leafless views dropped, single-child roots spliced
        out, adjacent writes and compatible branches merged) before being
        spliced into ``self.data``.  Does nothing when called without arguments.
        """
        if not adjacents:
            return
        adjacents = [
            adjacent if type(adjacent) is Flush.Entry
            else Flush.Entry(adjacent.start, adjacent.end, adjacent)
            for adjacent in adjacents
        ]
        if self.start is None:
            self.start = adjacents[0].start
            self.end = adjacents[-1].end
        else:
            self.start = min(self.start, adjacents[0].start)
            self.end = max(self.end, adjacents[-1].end)

        # first idx with end > start
        start_idx = bisect.bisect_right([entry.end for entry in self.data], adjacents[0].start)
        # first idx with start >= end
        end_idx = bisect.bisect_left([entry.start for entry in self.data], adjacents[-1].end, start_idx)
        replaced = self.data[start_idx:end_idx]
        if replaced:
            if replaced[0].start < adjacents[0].start:
                # keep the part of the first overlapped entry that precedes the new span
                adjacents.insert(
                    0,
                    Flush.Entry(replaced[0].start, adjacents[0].start, replaced[0].data),
                )
                if start_idx > 0:
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    start_idx -= 1
                    replaced.insert(0, self.data[start_idx])
                    adjacents.insert(0, self.data[start_idx])
            if replaced[-1].end > adjacents[-1].end:
                # keep the part of the last overlapped entry that follows the new span
                adjacents.append(
                    Flush.Entry(adjacents[-1].end, replaced[-1].end, replaced[-1].data)
                )
                if end_idx < len(self.data):
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    replaced.append(self.data[end_idx])
                    adjacents.append(self.data[end_idx])
                    end_idx += 1

        self._prune_and_flatten(adjacents)
        self._merge_neighbours(adjacents)

        # replaced entries leave, their (possibly trimmed or merged) successors arrive
        self.leaf_count += sum(adjacent.leaf_count for adjacent in adjacents)
        self.leaf_count -= sum(old.leaf_count for old in replaced)
        self.max_height = self.leaf_count.bit_length()
        self.data[start_idx:end_idx] = adjacents
        # an empty entry list keeps height 1, matching a fresh Flush()
        self.height = max((entry.height for entry in self.data), default=0) + 1
        #self.check_leaf_count(self.start, self.end)

    @staticmethod
    def _prune_and_flatten(adjacents):
        """Drop leafless views and splice out single-child roots, in place."""
        for idx in reversed(range(len(adjacents))):
            entry = adjacents[idx]
            if entry.leaf_count == 0:
                # no leaves left in this branch, remove
                adjacents.pop(idx)
                continue
            # Descend while the current flush view has exactly one child in
            # range; stop at the first view with several children or at a chunk.
            count = 0
            subentry = entry
            while count <= 1 and subentry is not None and type(subentry.data) is Flush:
                parent_entry = subentry
                count = 0
                subentry = None
                for subentry in parent_entry.flush_entries():
                    count += 1
                    if count > 1:
                        subentry = parent_entry
                        break
            if subentry is not entry:
                # some internodes were removed; leaf_count > 0 implies a child at every level
                assert subentry is not None
                adjacents[idx] = subentry

    def _merge_neighbours(self, adjacents):
        """Merge neighbouring views in ``adjacents`` where possible, in place.

        Two touching chunks written in this flush are concatenated into one
        chunk.  Otherwise, if both views were reached through a common object,
        they are replaced by one view of that object spanning both, provided it
        holds no leaves in the gap between them and the merged view stays below
        ``max_height``.  New ``Entry`` objects are created instead of mutating
        the old ones, which may be shared with older flushes.
        """
        # Right to left: each pop only shifts already-visited positions, and a
        # merged view at ``idx`` is offered to its left neighbour next.
        for idx in reversed(range(len(adjacents) - 1)):
            left_adjacent = adjacents[idx]
            right_adjacent = adjacents[idx + 1]

            # merge writes
            if (
                left_adjacent.age == self.age
                and right_adjacent.age == self.age
                and type(left_adjacent.data) is Chunk
                and type(right_adjacent.data) is Chunk
                and left_adjacent.end == right_adjacent.start
            ):
                merged_chunk = Chunk(
                    left_adjacent.start,
                    right_adjacent.end,
                    left_adjacent.chunk_data() + right_adjacent.chunk_data(),
                    age=self.age,
                )
                adjacents[idx] = Flush.Entry(merged_chunk.start, merged_chunk.end, merged_chunk)
                adjacents.pop(idx + 1)
                continue

            # merge branches with shared parents
            shared_parents = [
                left_parent
                for left_parent, right_parent in zip(left_adjacent.path, right_adjacent.path)
                if left_parent is right_parent
            ]
            if not shared_parents:
                continue
            shared_root = shared_parents[-1]
            if left_adjacent.end != right_adjacent.start:
                between_entry = Flush.Entry(left_adjacent.end, right_adjacent.start, shared_root)
                if between_entry.leaf_count > 0:
                    # This flush reads the gap as zeros, but the shared root still
                    # has leaves there; merging would expose them.
                    continue
            merged = Flush.Entry(
                left_adjacent.start, right_adjacent.end, shared_root, shared_parents[:-1]
            )
            if merged.height >= self.max_height:
                continue
            adjacents[idx] = merged
            adjacents.pop(idx + 1)

    def write(self, offset, data):
        """Record ``data`` at ``offset`` as a chunk of this flush's age.

        Empty writes are ignored, since they would only add a zero-width leaf.
        Returns the result of ``add`` (``None``).
        """
        if not data:
            return None
        chunk = Chunk(offset, offset + len(data), data, age=self.age)
        return self.add(chunk)

    def read(self, start, max_end=float('inf')):
        """Return bytes from ``start`` up to the next entry boundary.

        The result may be shorter than the caller wants, so read in a loop,
        advancing by its length.  It never extends past ``max_end``.  Ranges
        with no entry read as zeros (at most 4096 of them past the last entry).
        """
        # first idx with end > start
        idx = bisect.bisect_right([entry.end for entry in self.data], start)
        if idx == len(self.data):
            # bounded by max_end so a nested read cannot spill past its entry
            return bytes(min(4096, max_end - start))
        entry = self.data[idx]
        if entry.start > start:
            end = min(max_end, entry.start)
            return bytes(end - start)
        end = min(max_end, entry.end)
        if type(entry.data) is Flush:
            return entry.data.read(start, end)
        elif type(entry.data) is Chunk:
            datastart = start - entry.data.start
            dataend = end - entry.data.start
            return entry.data.data[datastart : dataend]

    # Debugging aid (disabled): recursively recount leaves and compare them with
    # the cached leaf_count values.
    #def check_leaf_count(self, start, end):
    #    leaf_count = 0
    #    wrapper = Flush.Entry(start, end, self)
    #    for entry in wrapper.flush_entries():
    #        if type(entry.data) is Flush:
    #            entry_leaf_count = entry.data.check_leaf_count(entry.start, entry.end)
    #            assert entry_leaf_count == entry.leaf_count
    #            leaf_count += entry_leaf_count
    #        else:
    #            leaf_count += entry.data.leaf_count
    #    assert leaf_count == wrapper.leaf_count
    #    if start == self.start and end == self.end:
    #        assert leaf_count == self.leaf_count
    #    return leaf_count



if __name__ == '__main__':
    # Randomized self-test: mirror every write into a flat bytearray and check
    # that reads through the flush tree agree, before and after each new flush.
    import random
    random.seed(0)
    SIZE = 4096
    comparison = bytearray(SIZE)
    store = Flush()

    def compare(store, comparison):
        """Assert that reading ``store`` end to end reproduces ``comparison``."""
        offset = 0
        while offset < len(comparison):
            data = store.read(offset)[:len(comparison) - offset]
            # an empty read would never advance offset and loop forever
            assert data, f'empty read at offset {offset}'
            assert data == comparison[offset:offset + len(data)]
            offset += len(data)
        #store.check_leaf_count(store.start, store.end)
        return True

    for flushes in range(1024):
        for writes in range(random.randint(1, 16)):
            start = random.randint(0, SIZE)
            end = random.randint(0, SIZE)
            if end < start:
                start, end = end, start
            end = (end + start) // 2
            size = end - start
            data = random.randint(0, (1 << (size * 8)) - 1).to_bytes(size, 'little')
            store.write(start, data)
            comparison[start:end] = data
        compare(store, comparison)
        print('OK', len(store.data), 'x', store.height, 'count =', store.leaf_count, 'flushes =', flushes)
        store = Flush(prev_flush=store)
        compare(store, comparison)

