from collections import namedtuple
import bisect

class Chunk:
    """A piece of data tagged with the offset range ``start``..``end``.

    Besides the payload, every chunk carries three bookkeeping values that
    the ``Flush`` structure in this file reads and maintains:

    * ``height``     -- nesting depth below this node (0 by default),
    * ``leaf_count`` -- number of leaves represented (1 by default),
    * ``age``        -- generation counter of the chunk (0 by default).

    NOTE: because ``__len__`` is defined, an instance whose ``start`` equals
    its ``end`` is falsy; compare against ``None`` explicitly instead of
    relying on truthiness.
    """

    def __init__(self, start, end, data, height=0, leaf_count=1, age=0):
        # Offset range covered by this chunk.
        self.start, self.end = start, end
        # Payload; subclasses in this file store other containers here.
        self.data = data
        # Tree bookkeeping, see the class docstring.
        self.height, self.leaf_count, self.age = height, leaf_count, age

    def __len__(self):
        """Return the size of the covered range, ``end - start``."""
        span = self.end - self.start
        return span
#
#class ChunkSlice:
#    def __init__(self, chunk, start, end):
#        super().__init__(start, end)
#        self.chunk = chunk

class Flush(Chunk):
    """One generation ("flush") of writes, optionally layered on an older one.

    ``self.data`` is a list of ``Flush.Entry`` objects kept ordered by
    offset; ``add()`` and ``read()`` rely on that ordering for their
    ``bisect`` lookups.  Each entry's payload is either a plain ``Chunk``
    (bytes written during some generation) or an older ``Flush`` that is
    consulted recursively.

    Inherited bookkeeping is maintained as follows:

    * ``height``     -- 1 + the largest height among the entries,
    * ``leaf_count`` -- sum of the entries' leaf counts,
    * ``age``        -- 0 for a fresh store, ``prev_flush.age + 1`` otherwise.
    """

    # Number of zero bytes returned by read() beyond the last entry when the
    # caller did not bound the request with ``max_end``.
    ZERO_FILL = 4096

    class Entry(Chunk):
        """A window ``start``..``end`` onto a ``Chunk`` or ``Flush`` payload.

        The window may be narrower than the payload's own range; this is how
        partially overwritten data is represented without copying it.
        """

        def __init__(self, start, end, chunk):
            """Create a window onto ``chunk``.

            The entry inherits ``chunk.age``.  For a plain ``Chunk`` payload
            it counts as a single leaf of height 0.  For a ``Flush`` payload
            the leaf count and height are recomputed from the payload's
            entries that intersect the window.
            """
            super().__init__(start, end, chunk, height=0, leaf_count=1, age=chunk.age)
            # Exact type check on purpose: Flush and Entry both subclass
            # Chunk, so isinstance() could not tell the payload kinds apart.
            if type(chunk) is Flush:
                self.leaf_count = 0
                self.height = 0
                for entry in self.flush_entries():
                    self.leaf_count += entry.leaf_count
                    self.height = max(self.height, entry.height + 1)

        def flush_entries(self):
            """Yield the payload Flush's entries clipped to this window.

            Only entries that overlap the window are produced; each one is a
            fresh ``Entry`` whose bounds are clamped to ``self.start`` and
            ``self.end``.
            """
            assert type(self.data) is Flush
            return (
                Flush.Entry(max(entry.start, self.start), min(entry.end, self.end), entry.data)
                for entry in self.data.data
                if entry.start < self.end and entry.end > self.start
            )

        def chunk_data(self):
            """Return the slice of the payload Chunk's data seen by this window."""
            assert type(self.data) is Chunk
            base = self.data.start
            return self.data.data[self.start - base : self.end - base]

    def __init__(self, prev_flush=None):
        """Create an empty store, or a new generation on top of ``prev_flush``.

        With ``prev_flush`` given, the new flush covers the same range, has
        ``age == prev_flush.age + 1`` and initially holds ``prev_flush`` as a
        single entry.  Entries of ``prev_flush`` whose height is at least
        ``prev_flush.leaf_count.bit_length()`` are then re-added directly,
        replacing the matching part of that single entry (presumably to keep
        the nesting depth bounded -- the threshold grows with the logarithm
        of the leaf count).
        """
        if prev_flush is None:
            super().__init__(None, None, [], height=1, leaf_count=0)
            return

        # height=1 matches what add() would compute; add() overwrites it as
        # soon as the first entry is inserted.
        super().__init__(
            prev_flush.start, prev_flush.end, [],
            height=1, leaf_count=0, age=prev_flush.age + 1,
        )
        if prev_flush.start is None:
            # The previous generation never received data.  Wrapping it would
            # store an entry with None bounds, and the next write() would
            # fail comparing None with an int inside bisect.
            return

        max_height = prev_flush.leaf_count.bit_length()
        self.add(prev_flush)
        for entry in prev_flush.data:
            if entry.height >= max_height:
                # TODO: reduce depth when only one subentry is used, i.e.
                # descend through Flush payloads that contribute a single
                # sub-entry before re-adding.
                self.add(entry)

    def add(self, *adjacents):
        """Insert one or more items, displacing whatever they overlap.

        ``adjacents`` are ``Chunk``/``Flush`` objects (wrapped into entries
        here) or ready-made ``Flush.Entry`` objects, given in offset order.
        Only the first item's ``start`` and the last item's ``end`` decide
        which existing entries are displaced, so the items are assumed to
        cover that range between them.

        Existing entries that overlap or touch the new range are removed;
        the parts of the outermost ones that stick out are kept as clipped
        entries.  Neighbouring ``Chunk`` entries written in this generation
        (``age == self.age``) that end up contiguous are coalesced into one.
        Calling with no arguments does nothing.
        """
        if not adjacents:
            return
        adjacents = [
            adjacent if type(adjacent) is Flush.Entry
            else Flush.Entry(adjacent.start, adjacent.end, adjacent)
            for adjacent in adjacents
        ]
        new_start = adjacents[0].start
        new_end = adjacents[-1].end
        if self.start is None:
            self.start, self.end = new_start, new_end
        else:
            self.start = min(self.start, new_start)
            self.end = max(self.end, new_end)

        # first idx with end >= start
        start_idx = bisect.bisect_left([entry.end for entry in self.data], new_start)
        # first idx with start > end
        end_idx = bisect.bisect_right([entry.start for entry in self.data], new_end, start_idx)
        replaced = self.data[start_idx:end_idx]
        if replaced:
            head, tail = replaced[0], replaced[-1]
            if head.start < new_start:
                # keep the part of the first displaced entry left of the new range
                adjacents.insert(0, Flush.Entry(head.start, new_start, head.data))
            if tail.end > new_end:
                # keep the part of the last displaced entry right of the new range
                adjacents.append(Flush.Entry(new_end, tail.end, tail.data))

        # merge writes; walk right-to-left so deleting idx + 1 never shifts a
        # pair that is still to be visited
        for idx in range(len(adjacents) - 2, -1, -1):
            left, right = adjacents[idx], adjacents[idx + 1]
            if (
                type(left.data) is Chunk and
                type(right.data) is Chunk and
                left.age == self.age and
                right.age == self.age and
                left.end == right.start
            ):
                # The merged chunk carries this generation's age so that any
                # clipped remainder of it created later is still recognised
                # as current and can be merged again.
                left.data = Chunk(
                    left.start,
                    right.end,
                    left.chunk_data() + right.chunk_data(),
                    age=self.age,
                )
                left.end = right.end
                del adjacents[idx + 1]

        self.leaf_count += sum(adjacent.leaf_count for adjacent in adjacents)
        self.leaf_count -= sum(old.leaf_count for old in replaced)
        self.data[start_idx:end_idx] = adjacents
        self.height = max(entry.height for entry in self.data) + 1
        #self.check_leaf_count(self.start, self.end)

    def write(self, offset, data):
        """Record ``data`` at ``offset`` as a chunk of the current generation."""
        chunk = Chunk(offset, offset + len(data), data, age=self.age)
        return self.add(chunk)

    def read(self, start, max_end=float('inf')):
        """Return one contiguous run of data beginning at ``start``.

        The run stops at the end of the entry (or gap) that contains
        ``start`` and never extends past ``max_end``; callers wanting more
        call again with an advanced offset.  Gaps between entries read as
        zero bytes.  Past the last entry, ``ZERO_FILL`` zero bytes are
        returned, or fewer when ``max_end`` is closer than that.
        """
        # first idx with end > start
        idx = bisect.bisect_right([entry.end for entry in self.data], start)
        if idx == len(self.data):
            # Nothing stored at or after ``start``: zero-fill, honouring the
            # caller's limit (min() yields the int when max_end is infinite).
            wanted = min(self.ZERO_FILL, max_end - start)
            return bytes(max(wanted, 0))
        entry = self.data[idx]
        if entry.start > start:
            # ``start`` lies in a hole before the next entry
            end = min(max_end, entry.start)
            return bytes(end - start)
        end = min(max_end, entry.end)
        if type(entry.data) is Flush:
            # older generation: delegate, limited to this entry's window
            return entry.data.read(start, end)
        if type(entry.data) is Chunk:
            # offsets are relative to the payload chunk, not to the window
            datastart = start - entry.data.start
            dataend = end - entry.data.start
            return entry.data.data[datastart : dataend]
        # any other payload type is not readable
        return None

    # Debug helper, currently disabled: recomputes leaf counts recursively
    # and compares them against the cached values.
    #def check_leaf_count(self, start, end):
    #    leaf_count = 0
    #    wrapper = Flush.Entry(start, end, self)
    #    for entry in wrapper.flush_entries():
    #        if type(entry.data) is Flush:
    #            entry_leaf_count = entry.data.check_leaf_count(entry.start, entry.end)
    #            assert entry_leaf_count == entry.leaf_count
    #            leaf_count += entry_leaf_count
    #        else:
    #            leaf_count += entry.data.leaf_count
    #    assert leaf_count == wrapper.leaf_count
    #    if start == self.start and end == self.end:
    #        assert leaf_count == self.leaf_count
    #    return leaf_count



if __name__ == '__main__':
    # Randomized self-test: mirror every write into a flat bytearray and
    # check that the layered store reads back the same bytes.
    import random

    random.seed(2)
    SIZE = 4096
    comparison = bytearray(SIZE)
    store = Flush()

    def compare(store, comparison):
        """Read ``store`` front to back and assert it matches ``comparison``."""
        total = len(comparison)
        offset = 0
        while offset < total:
            # read() returns a single run; trim anything beyond the mirror
            data = store.read(offset)[:total - offset]
            assert data == comparison[offset:offset + len(data)]
            offset += len(data)
        #store.check_leaf_count(store.start, store.end)
        return True

    for flushes in range(1024):
        write_count = random.randint(1, 16)
        for _ in range(write_count):
            # two random offsets, ordered, then the range is halved
            first = random.randint(0, SIZE)
            second = random.randint(0, SIZE)
            start, end = min(first, second), max(first, second)
            end = (end + start) // 2
            size = end - start
            # random payload of exactly ``size`` bytes
            data = random.randint(0, (1 << (size * 8)) - 1).to_bytes(size, 'little')
            store.write(start, data)
            comparison[start:end] = data
            # (call compare(store, comparison) here to verify after every write)
        compare(store, comparison)
        print('OK', len(store.data), 'x', store.height, 'count =', store.leaf_count, 'flushes =', flushes)
        # start the next generation on top of the current one and re-verify
        store = Flush(prev_flush=store)
        compare(store, comparison)

