import bisect

class Chunk:
    """A half-open range ``[start, end)`` together with the payload stored for it.

    Attributes:
        start, end: bounds of the range; ``len(chunk)`` is ``end - start``.
        data: the payload (subclasses decide what it holds).
        height: 0 marks a leaf; anything else marks a chunk with children.
        leaf_count: number of leaves accounted to this chunk.
        age: generation tag carried along with the chunk.

    Because ``__len__`` is defined, an instance with ``start == end`` is
    falsy; compare against ``None`` explicitly when testing for presence.
    """

    def __init__(self, start, end, data, height=0, leaf_count=1, age=0):
        self.start, self.end = start, end
        self.data = data
        self.height, self.leaf_count, self.age = height, leaf_count, age

    def __len__(self):
        """Return the size of the covered range."""
        span = self.end - self.start
        return span

    def is_leaf(self):
        """Return True when this chunk has height 0."""
        return self.height == 0

class Flush(Chunk):
    """One generation of writes layered on top of the previous generation.

    ``data`` is a list of :class:`Flush.Entry` views.  ``add`` and ``read``
    locate entries with ``bisect``, so the list is kept ordered by offset
    with no two entries overlapping.  An entry views either a leaf
    :class:`Chunk` (raw bytes) or an older ``Flush``.

    Attributes (beyond those of ``Chunk``):
        max_height: ``leaf_count.bit_length()`` (at least 1); ``add`` expands
            or refuses to merge entries that would make the tree taller.
    """

    class Entry(Chunk):
        """A view of the sub-range ``[start, end)`` of another chunk.

        ``data`` is the viewed chunk and ``age`` is copied from it.  ``path``
        is the list of chunks this view was reached through, ending with the
        viewed chunk itself.  For a non-leaf chunk, ``leaf_count`` and
        ``height`` are recomputed from the children that overlap the range.
        """

        def __init__(self, start, end, chunk, path=()):
            # NOTE: the default used to be a shared mutable list; it was only
            # ever copied, but an immutable default removes the hazard.
            super().__init__(start, end, chunk, height=0, leaf_count=1, age=chunk.age)
            self.path = list(path)
            self.path.append(self.data)
            if not chunk.is_leaf():
                self.leaf_count = 0
                self.height = 1
                # Building the child views recurses, so this walks the whole
                # subtree under the viewed range.
                for entry in self.flush_entries():
                    self.leaf_count += entry.leaf_count
                    self.height = max(self.height, entry.height + 1)

        def flush_entries(self):
            """Yield views of the viewed flush's children, clipped to this range."""
            assert not self.data.is_leaf()
            return (
                Flush.Entry(max(entry.start, self.start), min(entry.end, self.end), entry.data, self.path)
                for entry in self.data.data
                if entry.start < self.end and entry.end > self.start
            )

        def chunk_data(self):
            """Return the slice of the viewed leaf's payload covered by this view."""
            assert self.is_leaf()
            return self.data.data[self.start - self.data.start : self.end - self.data.start]

    def __init__(self, prev_flush=None):
        """Create an empty flush, or one layered over ``prev_flush``.

        With ``prev_flush`` the new flush covers the same range, is one
        generation older (``age + 1``) and starts out holding a single view of
        the whole previous flush.
        """
        if prev_flush is None:
            super().__init__(None, None, [], height=1, leaf_count=0)
            self.max_height = 1
            return
        super().__init__(prev_flush.start, prev_flush.end, [], height=1, leaf_count=0, age=prev_flush.age + 1)
        # bit_length() is 0 for an empty predecessor; a limit below 1 would
        # make even a plain leaf "too deep" on the first write.
        self.max_height = max(1, prev_flush.leaf_count.bit_length())
        if prev_flush.data:
            # An empty predecessor has nothing to view (and no range to view
            # it over), so only wrap predecessors that actually hold entries.
            prev_entry = Flush.Entry(self.start, self.end, prev_flush)
            self.add(prev_entry)

    def add(self, *adjacents):
        """Insert entries, replacing whatever they overlap.

        Args:
            *adjacents: ``Flush.Entry`` objects ordered by offset; the first
                one's ``start`` and the last one's ``end`` delimit the region
                being replaced.

        Existing entries that are only partly covered are trimmed, over-deep
        entries are expanded into their children, single-child branches are
        spliced out, and neighbouring entries are merged where possible.
        Calling with no entries is a no-op.
        """
        adjacents = list(adjacents)
        if not adjacents:
            return
        if self.start is None:
            self.start = adjacents[0].start
            self.end = adjacents[-1].end
        else:
            self.start = min(self.start, adjacents[0].start)
            self.end = max(self.end, adjacents[-1].end)

        # expand adjacents that are too deep
        # idx is only advanced past entries that fit, so children spliced in
        # below are themselves re-examined.
        idx = 0
        while idx < len(adjacents):
            entry = adjacents[idx]
            if entry.height + 1 > self.max_height:
                subadjacents = []
                shallow_start = entry.start
                shallow_end = shallow_start
                for subentry in entry.flush_entries():
                    if subentry.height + 2 > self.max_height:
                        # Deep child: keep it as its own entry, after flushing
                        # the run of shallow children that preceded it as one
                        # narrower view of the same chunk.
                        if shallow_end != shallow_start:
                            subadjacents.append(Flush.Entry(shallow_start, shallow_end, entry.data))
                        subadjacents.append(subentry)
                        shallow_start = subentry.end
                    shallow_end = subentry.end
                if shallow_end != shallow_start:
                    subadjacents.append(Flush.Entry(shallow_start, shallow_end, entry.data))
                adjacents[idx:idx+1] = subadjacents
            else:
                idx += 1

        # first idx with end > start
        start_idx = bisect.bisect_right([entry.end for entry in self.data], adjacents[0].start)
        # first idx with start >= end
        end_idx = bisect.bisect_left([entry.start for entry in self.data], adjacents[-1].end, start_idx)
        replaced = self.data[start_idx:end_idx]
        if replaced:
            if replaced[0].start < adjacents[0].start:
                # Keep the uncovered left part of the first replaced entry.
                adjacents.insert(
                    0,
                    Flush.Entry(
                        replaced[0].start, adjacents[0].start, replaced[0].data
                    )
                )
                if start_idx > 0:
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    start_idx -= 1
                    replaced.insert(0, self.data[start_idx])
                    adjacents.insert(0, self.data[start_idx])
            if replaced[-1].end > adjacents[-1].end:
                # Keep the uncovered right part of the last replaced entry.
                adjacents.append(
                    Flush.Entry(
                        adjacents[-1].end, replaced[-1].end, replaced[-1].data
                    )
                )
                if end_idx < len(self.data):
                    # the trimmed entry may have fewer leaves and itself merge with its neighbor
                    replaced.append(self.data[end_idx])
                    adjacents.append(self.data[end_idx])
                    end_idx += 1

        # Snapshot the outgoing leaf total now: a neighbour pulled in above is
        # the same object in both lists, and the merges below may change its
        # leaf_count, which would otherwise corrupt the accounting.
        replaced_leaf_count = sum(old.leaf_count for old in replaced)

        for idx, entry in reversed(list(enumerate(adjacents))):
            if entry.leaf_count == 0:
                # no leaves left in this branch, remove
                adjacents.pop(idx)
                continue
            count = 0
            subentry = entry
            while count <= 1 and subentry is not None and not subentry.data.is_leaf():#type(subentry.data) is Flush:
                # make branches shallower by splicing out roots with only one child
                parent_entry = subentry
                count = 0
                subentry = None
                for subentry in parent_entry.flush_entries():
                    count += 1
                    if count > 1:
                        subentry = parent_entry
                        break
            if subentry is not entry:
                # some internodes were removed
                assert subentry is not None # can likely remove assignment to None above if this removed
                adjacents[idx] = subentry

        # Walk neighbouring pairs right-to-left over a snapshot, so popping the
        # right element never disturbs the indices still to be visited.
        for idx, (left_adjacent, right_adjacent) in reversed(list(enumerate(zip(adjacents[:-1], adjacents[1:])))):

            # merge writes
            if (
                left_adjacent.age == self.age and
                right_adjacent.age == self.age and
                left_adjacent.end == right_adjacent.start
            ):
                # The merged chunk keeps this flush's age so that views later
                # trimmed from it are still recognised as current writes.
                left_adjacent.data = Chunk(
                    left_adjacent.start,
                    right_adjacent.end,
                    left_adjacent.chunk_data() + right_adjacent.chunk_data(),
                    age=self.age
                )
                left_adjacent.end = right_adjacent.end
                adjacents.pop(idx+1)
                continue

            # merge branches with shared parents
            shared_path = [
                left_parent for left_parent, right_parent
                in zip(left_adjacent.path, right_adjacent.path)
                if left_parent is right_parent
            ]
            if len(shared_path) > 0 and left_adjacent.height + len(left_adjacent.path) - len(shared_path) < self.max_height and right_adjacent.height + len(right_adjacent.path) - len(shared_path) < self.max_height:
                if left_adjacent.end != right_adjacent.start:
                    between_entry = Flush.Entry(
                        left_adjacent.end,
                        right_adjacent.start,
                        shared_path[-1]
                    )
                    if between_entry.leaf_count > 0:
                        # the shared root contains leaves in between that have been removed
                        continue
                # NOTE(review): debug trace written to stdout on every merge;
                # kept as-is, consider routing through logging.
                print(f'Merging {len(left_adjacent.path)}:{left_adjacent.height}, {len(right_adjacent.path)}:{right_adjacent.height} -> {len(shared_path)}:{left_adjacent.height + len(left_adjacent.path) - len(shared_path)}')
                left_adjacent.end = right_adjacent.end
                left_adjacent.leaf_count += right_adjacent.leaf_count
                left_adjacent.height += len(left_adjacent.path) - len(shared_path)
                left_adjacent.path = shared_path
                left_adjacent.data = shared_path[-1]
                adjacents.pop(idx+1)

        self.leaf_count += sum(adjacent.leaf_count for adjacent in adjacents)
        self.leaf_count -= replaced_leaf_count
        self.max_height = max(1, self.leaf_count.bit_length())
        self.data[start_idx:end_idx] = adjacents
        self.height = max((entry.height for entry in self.data), default=0) + 1
        #self.check_leaf_count(self.start, self.end)

    def write(self, offset, data):
        """Store ``data`` at ``offset`` as a new leaf tagged with this flush's age."""
        chunk = Chunk(offset, offset + len(data), data, age=self.age)
        entry = Flush.Entry(offset, offset + len(data), chunk)
        return self.add(entry)

    def read(self, start, max_end = float('inf')):
        """Return one contiguous run of bytes beginning at ``start``.

        The run stops at the end of the entry (or gap) containing ``start`` or
        at ``max_end``, whichever comes first, so callers loop to read longer
        ranges.  Gaps read as zero bytes; past the last entry up to 4096 zero
        bytes are returned.
        """
        # first idx with end > start
        idx = bisect.bisect_right([entry.end for entry in self.data], start)
        if idx == len(self.data):
            # Nothing stored at or after start: zero fill, but never beyond
            # the caller's limit.
            return bytes(max(0, min(4096, max_end - start)))
        entry = self.data[idx]
        if entry.start > start:
            # start falls in a gap before the next entry.
            end = min(max_end, entry.start)
            return bytes(end - start)
        end = min(max_end, entry.end)
        if entry.data.is_leaf():
            datastart = start - entry.data.start
            dataend = end - entry.data.start
            return entry.data.data[datastart : dataend]
        else:
            return entry.data.read(start, end)

    # Disabled debugging aid: recomputes leaf counts and compares them with
    # the stored ones.
    #def check_leaf_count(self, start, end):
    #    leaf_count = 0
    #    wrapper = Flush.Entry(start, end, self)
    #    for entry in wrapper.flush_entries():
    #        if type(entry.data) is Flush:
    #            entry_leaf_count = entry.data.check_leaf_count(entry.start, entry.end)
    #            assert entry_leaf_count == entry.leaf_count
    #            leaf_count += entry_leaf_count
    #        else:
    #            leaf_count += entry.data.leaf_count
    #    assert leaf_count == wrapper.leaf_count
    #    if start == self.start and end == self.end:
    #        assert leaf_count == self.leaf_count
    #    return leaf_count



def main():
    """Randomized self-check of ``Flush`` against a plain ``bytearray``.

    Runs 1024 generations; each applies 1 to 16 random writes (at most 128
    bytes each) to both the store and the reference buffer, verifies that they
    read back identically, then layers a fresh ``Flush`` over the store and
    verifies again.  The random seed is fixed so runs are reproducible.
    """
    import random

    random.seed(0)
    SIZE = 4096
    expected = bytearray(SIZE)
    store = Flush()

    def compare(store, comparison):
        """Assert that reading ``store`` run by run reproduces ``comparison``."""
        total = len(comparison)
        position = 0
        while position < total:
            # read() returns one run at a time; clip the final run to the buffer
            piece = store.read(position)[:total - position]
            assert piece == comparison[position:position + len(piece)]
            position += len(piece)
        return True

    for generation in range(1024):
        write_count = random.randint(1, 16)
        for _ in range(write_count):
            offset = random.randint(0, SIZE - 1)
            length = min(SIZE - offset, random.randint(1, 128))
            payload = random.getrandbits(length * 8).to_bytes(length, 'little')
            store.write(offset, payload)
            expected[offset:offset + length] = payload
        compare(store, expected)
        print('OK', len(store.data), 'x', store.height, '/', store.max_height, 'count =', store.leaf_count, 'flushes =', generation)
        # start the next generation on top of the current one
        store = Flush(prev_flush=store)
        compare(store, expected)

# Run the randomized self-check only when executed as a script.
if __name__ == '__main__':
    main()
    # To profile the self-check, run it through cProfile instead:
    #import cProfile
    #cProfile.run('main()')
