
# block tree structure

# blocks written to are leaves
# depth is log2 of the number of leaves

# when a flush is made, all blocks are written, along with enough index nodes that every leaf can be reached within `depth` lookups.


# consider we have an existing tree
# with, say, m flushes containing n leaves (or m leaves; we'll likely just call the count n).


# each flush records which leaves it contains.

# additionally, as of the latest flush, each leaf has a current depth.

# when we reflush, we need to provide a new index for any leaves that become too deep.

# which leaves are too deep?

# we could simply walk all of them to find out; that would be a straightforward, consistent first approach.

class Simple:
    """Naive write-back store whose flushes form a chain of overlay layers.

    ``write`` queues chunks in memory, and ``flush`` turns the queue into a new
    ``Simple.Flush`` layered over the previous tip. ``leaves`` walks the
    current contents as ``(depth, offset, data)`` tuples in ascending,
    non-overlapping offset order. ``depth`` counts the index hops taken from
    the newest flush to reach the data. Offsets that were never written are
    holes and produce no leaf.
    """

    class Chunk:
        """A contiguous run of bytes starting at absolute offset ``start``."""

        def __init__(self, offset, data):
            self.start = offset
            self.data = data
            self.end = self.start + len(self.data)  # exclusive

    class Flush:
        """One flushed layer: its own leaves plus index ranges into older ones.

        Attributes:
            prev_flush: the flush this one was layered over, or ``None``.
            data: the raw writes exactly as passed in (they may overlap).
            chunks: ``data`` resolved into sorted, non-overlapping leaves;
                where writes overlap, the later write wins.
            start, end: half-open extent covered by this flush together with
                the flushes it indexes.
            index: ``(start, end, flush)`` ranges consulted, in order, for any
                offset not covered by ``chunks``; assumed sorted and disjoint.
        """

        def __init__(self, *writes, prev_flush=None):
            self.prev_flush = prev_flush
            self.data = writes
            self.chunks = self._coalesce(writes)
            # Exclusive end of each chunk, kept parallel to ``chunks`` so that
            # leaves() can bisect straight to the first relevant chunk.
            self._chunk_ends = [chunk.end for chunk in self.chunks]

            # Extent of this flush's own writes. An empty flush inherits the
            # previous extent, or covers nothing when there is no history.
            if writes:
                start = min(write.start for write in writes)
                end = max(write.end for write in writes)
            elif prev_flush is not None:
                start, end = prev_flush.start, prev_flush.end
            else:
                start = end = 0

            if prev_flush is None:
                self.start = start
                self.end = end
                self.index = []
                return

            self.start = min(start, prev_flush.start)
            self.end = max(end, prev_flush.end)
            self.index = [(prev_flush.start, prev_flush.end, prev_flush)]

            # TODO: compute the leaf count and per-leaf depths here, so that a
            # flush can add index entries for leaves that have become too deep.

        @staticmethod
        def _coalesce(writes):
            """Return *writes* as sorted, disjoint chunks (later writes win)."""
            from bisect import bisect_left, bisect_right

            pieces = []
            starts = []  # piece start offsets, parallel to ``pieces``
            ends = []  # piece end offsets, parallel to ``pieces``
            for write in writes:
                if write.start >= write.end:
                    continue  # a zero-length write changes nothing
                # pieces[lo:hi] are exactly the pieces overlapping this write.
                # lo is the first piece ending after write.start, and hi is
                # the first piece starting at or after write.end.
                lo = bisect_right(ends, write.start)
                hi = bisect_left(starts, write.end)
                replacement = [write]
                if lo < hi:
                    first, last = pieces[lo], pieces[hi - 1]
                    if first.start < write.start:
                        # Keep the part of the first piece before the write.
                        head = first.data[:write.start - first.start]
                        replacement.insert(0, Simple.Chunk(first.start, head))
                    if last.end > write.end:
                        # Keep the part of the last piece after the write.
                        tail = last.data[write.end - last.start:]
                        replacement.append(Simple.Chunk(write.end, tail))
                pieces[lo:hi] = replacement
                starts[lo:hi] = [piece.start for piece in replacement]
                ends[lo:hi] = [piece.end for piece in replacement]
            return pieces

        def leaves(self, start=None, end=None, depth=0):
            """Yield ``(depth, offset, data)`` leaves covering ``[start, end)``.

            *start* and *end* default to this flush's extent. Leaves come out
            in ascending, non-overlapping offset order. Data written in this
            flush is reported at *depth*. Everything else is looked up through
            ``index`` at ``depth + 1`` or deeper. Offsets covered by neither
            are holes and yield nothing.
            """
            from bisect import bisect_right

            if start is None:
                start = self.start
            if end is None:
                end = self.end
            offset = start
            # First chunk ending after ``start``; chunks are sorted and disjoint.
            position = bisect_right(self._chunk_ends, start)
            while offset < end and position < len(self.chunks):
                chunk = self.chunks[position]
                if chunk.start >= end:
                    break
                if offset < chunk.start:
                    # Gap before this chunk: older flushes may hold data here.
                    yield from self._indexed_leaves(offset, chunk.start, depth)
                    offset = chunk.start
                stop = min(end, chunk.end)
                # Slice with offsets relative to the chunk, not absolute ones.
                yield (depth, offset,
                       chunk.data[offset - chunk.start:stop - chunk.start])
                offset = stop
                position += 1
            if offset < end:
                # Tail after the last relevant chunk.
                yield from self._indexed_leaves(offset, end, depth)

        def _indexed_leaves(self, start, end, depth):
            """Yield leaves for ``[start, end)`` from the indexed older flushes."""
            for index_start, index_end, flush in self.index:
                lo = max(start, index_start)
                hi = min(end, index_end)
                if lo < hi:
                    yield from flush.leaves(lo, hi, depth + 1)

    def __init__(self, latest=None):
        """Create a store whose tip is *latest* (an existing Flush, or None)."""
        self.tip = latest
        self.pending = []

    def write(self, offset, data):
        """Queue *data* to be stored at *offset* on the next flush."""
        self.pending.append(self.Chunk(offset, data))

    def flush(self):
        """Layer all pending writes as a new tip; a no-op if none are pending."""
        if not self.pending:
            return
        self.tip = self.Flush(*self.pending, prev_flush=self.tip)
        self.pending = []

    def leaves(self, start=None, end=None):
        """Iterate the flushed contents; empty if nothing has been flushed."""
        if self.tip is None:
            return iter(())
        return self.tip.leaves(start, end)

if __name__ == '__main__':
    # Randomised self-check: mirror every write into a plain bytearray and,
    # after each flush, confirm that the store's leaves reproduce it exactly
    # (offsets not covered by any leaf must still be zero).
    import random
    SIZE = 4096
    FLUSHES = 1024
    WRITES_PER_FLUSH = 1024
    store = Simple()
    comparison = bytearray(SIZE)
    # Back the whole range in the first flush.
    store.write(0, bytes(SIZE))
    for flushes in range(FLUSHES):
        for _ in range(WRITES_PER_FLUSH):
            start = random.randint(0, SIZE)
            end = random.randint(0, SIZE)
            start, end = (start, end) if start <= end else (end, start)
            data = random.randbytes(end - start)
            comparison[start:end] = data
            store.write(start, data)
        store.flush()
        last_offset = 0
        max_depth = 0
        for depth, offset, data in store.leaves():
            assert comparison[last_offset:offset] == bytes(offset - last_offset)
            last_offset = offset + len(data)
            assert comparison[offset:last_offset] == data
            max_depth = max(depth, max_depth)
        assert comparison[last_offset:] == bytes(len(comparison) - last_offset)
        # Report the loop index; the original referenced an undefined `flush`.
        print(flushes, max_depth, 'OK')
