#!/usr/bin/env python3
"""Standalone parser for CAST `.CST` files - CVia (and optional pads).

This port implements just enough of the original `CAST_CST_FILE` binary
layout (from the provided diff) to extract `CVia` entries and CPads.

It detects file endianness, parses the separated-data blocks, reads the
netlist names, then prints a table sorted by net name. Coordinates are
kept as integers (mils) by default; use `--scale` to convert units.

Usage:
  python3 scripts/generate_cvia_list.py magi7_41w0904.CST
  python3 scripts/generate_cvia_list.py magi7_41w0904.CST --scale 1000 --pads
"""
from __future__ import annotations
import sys
import argparse
import struct
from pathlib import Path
from typing import Tuple, List, Dict, Optional


class BinaryReader:
    """Sequential reader over an in-memory byte buffer with a fixed byte order.

    Every ``read_*`` helper consumes bytes at the current cursor and advances
    it. ``endian`` is a ``struct`` byte-order prefix (``'<'`` or ``'>'``) that
    is prepended to every format string.
    """

    def __init__(self, data: bytes, endian: str = '<'):
        self._b = data
        self._i = 0
        self.endian = endian

    def tell(self) -> int:
        """Return the current cursor offset in bytes."""
        return self._i

    def read(self, n: int) -> bytes:
        """Return the next ``n`` bytes and advance the cursor.

        Raises:
            ValueError: if ``n`` is negative.
            EOFError: if fewer than ``n`` bytes remain.
        """
        if n < 0:
            # A negative size would otherwise move the cursor backwards.
            raise ValueError('Negative read size')
        end = self._i + n
        if end > len(self._b):
            raise EOFError('Unexpected end of file')
        chunk = self._b[self._i:end]
        self._i = end
        return chunk

    def unpack(self, fmt: str):
        """Unpack ``fmt`` (without byte-order prefix) at the cursor; return a tuple."""
        # The size must be computed with the same prefix used for unpacking:
        # an unprefixed format is sized in native mode, which inserts
        # alignment padding and can differ from the packed on-disk layout.
        full_fmt = self.endian + fmt
        return struct.unpack(full_fmt, self.read(struct.calcsize(full_fmt)))

    # helpers matching the C++ ReadVals behaviour
    def read_uint16(self) -> int:
        """Read an unsigned 16-bit integer."""
        return self.unpack('H')[0]

    def read_int16(self) -> int:
        """Read a signed 16-bit integer."""
        return self.unpack('h')[0]

    def read_int8(self) -> int:
        """Read a signed 8-bit integer."""
        return self.unpack('b')[0]

    def read_uint8(self) -> int:
        """Read an unsigned 8-bit integer."""
        return self.unpack('B')[0]

    def read_int32(self) -> int:
        """Read a signed 32-bit integer."""
        return self.unpack('i')[0]

    def read_str_with_size(self, size_type: str) -> str:
        """Read a length-prefixed string and return it stripped of whitespace.

        Args:
            size_type: ``struct`` code of the length prefix, 'b' (int8) or
                'h' (int16).

        Raises:
            ValueError: if the length prefix is negative.
            EOFError: if the buffer ends before the string does.
        """
        size = self.unpack(size_type)[0]
        if size < 0:
            raise ValueError('Negative string size')
        if size == 0:
            return ''
        raw = self.read(size)
        try:
            return raw.decode('utf-8').strip()
        except UnicodeDecodeError:
            # latin-1 maps every byte value, so this fallback cannot fail.
            return raw.decode('latin-1', 'replace').strip()


def detect_endian(data: bytes) -> str:
    """Guess the byte order of a CST file from its first separated-data header.

    The header is expected to start with ``uint16 count; int16 sigs[2]`` where
    ``sigs`` is ``(-1, 0)``. Those two values are ``FF FF 00 00`` in either
    byte order, so they only validate the header and cannot select an order.
    The choice is made on the int16 name-length prefix that follows the
    signature (the layout ``parse_separated`` reads by default): an order is
    plausible when that length is non-negative and fits in the remaining data.

    Little-endian is preferred whenever it is plausible, and is also returned
    when the data is too short to inspect or when neither order is plausible,
    so that the parser reports the precise problem.

    Args:
        data: Complete file contents.

    Returns:
        ``'<'`` or ``'>'``, usable as a ``struct`` byte-order prefix.

    Raises:
        ValueError: if the data is shorter than 6 bytes or the signature is
            missing.
    """
    if len(data) < 6:
        raise ValueError('File too small')
    # The signature reads identically in both byte orders; check it once.
    _count, sig0, sig1 = struct.unpack('<Hhh', data[0:6])
    if sig0 != -1 or sig1 != 0:
        raise ValueError('Unable to detect endianness')
    if len(data) < 8:
        return '<'
    remaining = len(data) - 8
    for endian in ('<', '>'):
        (name_len,) = struct.unpack(endian + 'h', data[6:8])
        if 0 <= name_len <= remaining:
            return endian
    return '<'


def parse_separated(reader: BinaryReader, sep_values: List[int], item_reader, header_name_size_type: str = 'h'):
    """Read one separated-data block and return ``(name, items)``.

    Layout consumed from ``reader``: a uint16 item count, two int16 signature
    words that must be ``(-1, 0)``, a length-prefixed name, then ``count``
    items. Between consecutive items the file must contain the separator
    formed by packing each value of ``sep_values`` as an int16 in the
    reader's byte order.

    Args:
        reader: Source positioned at the start of the block header.
        sep_values: int16 values making up the inter-item separator.
        item_reader: Callable taking the reader and returning one item.
        header_name_size_type: ``struct`` code of the name's length prefix.

    Raises:
        ValueError: on a bad signature or a separator mismatch.
    """
    count = reader.read_uint16()
    signature = (reader.read_int16(), reader.read_int16())
    if signature != (-1, 0):
        raise ValueError('Unexpected separated-data signature')
    name = reader.read_str_with_size(header_name_size_type)
    if count <= 0:
        return name, []

    # The first item has no separator in front of it.
    items = [item_reader(reader)]
    separator = b''.join(struct.pack(reader.endian + 'h', value) for value in sep_values)
    width = len(separator)
    for _ in range(count - 1):
        if reader.read(width) != separator:
            raise ValueError('Separator mismatch')
        items.append(item_reader(reader))
    return name, items


# Item readers matching types from the diff
def read_cdev(reader: BinaryReader) -> Dict:
    """Read one CDEV record from ``reader``.

    Fields are consumed in file order: an int8-length-prefixed name, the
    position as two int16 values (x then y), a uint8 layer and an int32 shape.

    Returns:
        A dict with keys 'name', 'pos' (an ``(x, y)`` tuple), 'layer', 'shape'.
    """
    record = {'name': reader.read_str_with_size('b')}
    # Tuple elements are evaluated left to right, so x is read before y.
    record['pos'] = (reader.read_int16(), reader.read_int16())
    record['layer'] = reader.read_uint8()
    record['shape'] = reader.read_int32()
    return record


def read_ctp(reader: BinaryReader) -> Dict:
    """Read one CTP record from ``reader``.

    Fields are consumed in file order: an int8-length-prefixed name, a uint8
    layer, then the position as two int16 values (x then y).

    Returns:
        A dict with keys 'name', 'pos' (an ``(x, y)`` tuple) and 'layer'.
    """
    name = reader.read_str_with_size('b')
    layer = reader.read_uint8()
    # Tuple elements are evaluated left to right, so x is read before y.
    pos = (reader.read_int16(), reader.read_int16())
    return dict(name=name, pos=pos, layer=layer)


def read_cvia(reader: BinaryReader) -> Dict:
    """Read one CVia record from ``reader``.

    Fields are consumed in file order: int32 net, int32 width, int16 x,
    int16 y, int32 shape.

    Returns:
        A dict with keys 'net', 'width', 'x', 'y' and 'shape'.
    """
    # Dict displays evaluate their values top to bottom, which is file order.
    return {
        'net': reader.read_int32(),
        'width': reader.read_int32(),
        'x': reader.read_int16(),
        'y': reader.read_int16(),
        'shape': reader.read_int32(),
    }


def read_cpad(reader: BinaryReader) -> Dict:
    """Read one CPad record from ``reader``.

    Fields are consumed in file order: five int16 values (dev, pin, net, x,
    y) followed by an int32 shape.

    Returns:
        A dict with keys 'dev', 'pin', 'net', 'x', 'y' and 'shape'.
    """
    record = {key: reader.read_int16() for key in ('dev', 'pin', 'net', 'x', 'y')}
    record['shape'] = reader.read_int32()
    return record


def parse_cst(path: Path, include_pads: bool = False) -> Tuple[List[Dict], List[Dict], List[str], List[Dict]]:
    """Parse a CST file and return ``(vias, pads, nets, devs)``.

    The file is read in this order: the device block, a uint16 net count
    followed by that many int8-length-prefixed net names, the test-point
    block, the via block and, only when ``include_pads`` is true, the pad
    block.

    Args:
        path: Location of the CST file.
        include_pads: When false the pad block is not read and ``pads`` is
            an empty list.
    """
    raw = path.read_bytes()
    reader = BinaryReader(raw, detect_endian(raw))

    # m_devs (SEPARATED_DATA<CDEV>); the block name is not needed.
    _, devs = parse_separated(reader, [-0x7fff], read_cdev)

    # Net table: uint16 count, then one ReadStr<int8_t> per net.
    net_count = reader.read_uint16()
    nets: List[str] = [reader.read_str_with_size('b') for _ in range(net_count)]

    # m_tps is discarded, but it must be consumed to reach the via block.
    parse_separated(reader, [-0x76b6], read_ctp)

    # m_vias
    _, vias = parse_separated(reader, [-0x6e60], read_cvia)

    pads = parse_separated(reader, [-0x50d9], read_cpad)[1] if include_pads else []
    return vias, pads, nets, devs


def format_output(vias: List[Dict], pads: List[Dict], nets: List[str], devs: List[Dict], scale: Optional[float]) -> str:
    """Render vias and pads as a two-column text table sorted by net name.

    Args:
        vias: Via records with 'net', 'x', 'y' and 'width' keys.
        pads: Pad records with 'net', 'x' and 'y' keys; 'dev' and 'pin' are
            looked up with ``dict.get`` and may be absent.
        nets: Net names; a record's 'net' value indexes this list. An index
            outside the list is printed as the number itself.
        devs: Device records; a pad's 'dev' value indexes this list. A
            missing or empty name is printed as the index.
        scale: Divisor applied to coordinates. ``None`` (or 0) leaves them
            as integers.

    Returns:
        The table as a single string. Rows are sorted by net name
        (case-insensitive), then by type, and consecutive rows are paired
        side by side (0&1, 2&3, ...). The header and its dashed rule are
        printed above each column rather than being paired like data rows.
    """
    gap = ' ' * 4
    # minimal-width columns: 1-char Type, small Net, compact X/Y, single-char info keys
    hdr = f"{'T':1} {'Net':10} {'X':>7} {'Y':>7} {'Info'}"
    rule = '-' * len(hdr)

    def net_name(idx: int) -> str:
        if 0 <= idx < len(nets):
            return nets[idx]
        return str(idx)

    def dev_name(idx) -> str:
        # idx may be None when a pad record has no 'dev' key.
        if isinstance(idx, int) and 0 <= idx < len(devs):
            entry = devs[idx]
            name = entry.get('name', '') if isinstance(entry, dict) else ''
            if name:
                return str(name)
        return str(idx)

    def coord(value) -> str:
        return f"{value / scale:7.4f}" if scale else f"{value:7d}"

    # One tuple per row: (net name, type, x, y, info text).
    entries = []
    for v in vias:
        width = v['width']
        info = '' if width is None else f"w={width}"
        entries.append((net_name(v['net']), 'VIA', v['x'], v['y'], info))
    for p in pads:
        info = f"d={dev_name(p.get('dev'))} p={p.get('pin')}"
        entries.append((net_name(p['net']), 'PAD', p['x'], p['y'], info))
    entries.sort(key=lambda e: (e[0].lower(), e[1]))

    cells = [
        f"{kind[0]} {net[:10]:10} {coord(x)} {coord(y)} {info}"
        for net, kind, x, y, info in entries
    ]

    # Pair consecutive rows so nearby entries appear together.
    left, right = cells[0::2], cells[1::2]
    if not right:
        # Zero or one row: a single column needs no padding.
        return '\n'.join([hdr, rule] + left)

    left_width = max(len(hdr), max(len(cell) for cell in left))
    lines = [
        hdr.ljust(left_width) + gap + hdr,
        rule.ljust(left_width) + gap + rule,
    ]
    for i, cell in enumerate(left):
        if i < len(right):
            lines.append(cell.ljust(left_width) + gap + right[i])
        else:
            lines.append(cell)
    return '\n'.join(lines)


def format_grouped_output(vias: List[Dict], pads: List[Dict], nets: List[str], grid_size_mils: int, scale: Optional[float], devs: Optional[List[Dict]] = None) -> str:
    """Render vias and pads grouped into square grid cells.

    Each entry is assigned to the cell ``(x // grid_size_mils,
    y // grid_size_mils)``. Cells are printed in row-then-column order, and
    inside a cell the entries are sorted by (y, x) and laid out in two
    columns: the first half on the left, the second half on the right.

    Args:
        vias: Via records with 'net', 'x', 'y' and 'width' keys.
        pads: Pad records with 'net', 'x' and 'y' keys; 'dev' and 'pin' are
            looked up with ``dict.get`` and may be absent.
        nets: Net names; a record's 'net' value indexes this list. An index
            outside the list is printed as the number itself.
        grid_size_mils: Edge length of a grid cell, in the same units as the
            coordinates. Must be positive.
        scale: Divisor applied to printed coordinates. ``None`` (or 0)
            leaves them as integers. Grouping always uses the raw values.
        devs: Device records; a pad's 'dev' value indexes this list. A
            missing or empty name is printed as the index. ``None`` means no
            devices.

    Returns:
        The report as a single string.

    Raises:
        ValueError: if ``grid_size_mils`` is not positive.
    """
    if grid_size_mils <= 0:
        raise ValueError('grid_size_mils must be positive')
    dev_list = devs if devs is not None else []
    gap = ' ' * 4

    def net_name(idx: int) -> str:
        if 0 <= idx < len(nets):
            return nets[idx]
        return str(idx)

    def dev_name(idx) -> str:
        # idx may be None when a pad record has no 'dev' key.
        if isinstance(idx, int) and 0 <= idx < len(dev_list):
            entry = dev_list[idx]
            name = entry.get('name', '') if isinstance(entry, dict) else ''
            if name:
                return str(name)
        return str(idx)

    def coord(value) -> str:
        return f"{value / scale:7.4f}" if scale else f"{value:7d}"

    def two_columns(items: List[str]) -> List[str]:
        # First half of the items on the left, second half on the right.
        half = (len(items) + 1) // 2
        left, right = items[:half], items[half:]
        left_width = max((len(s) for s in left), default=0)
        out = []
        for i, text in enumerate(left):
            if i < len(right):
                out.append('  ' + text.ljust(left_width) + gap + right[i])
            else:
                out.append('  ' + text)
        return out

    # One tuple per entry: (x, y, type char, net name, info text).
    entries = []
    for v in vias:
        width = v['width']
        info = '' if width is None else f"w={width}"
        entries.append((v['x'], v['y'], 'V', net_name(v['net']), info))
    for p in pads:
        info = f"d={dev_name(p.get('dev'))} p={p.get('pin')}"
        entries.append((p['x'], p['y'], 'P', net_name(p['net']), info))

    # Bucket entries by grid cell (col, row).
    grid: Dict = {}
    for entry in entries:
        key = (entry[0] // grid_size_mils, entry[1] // grid_size_mils)
        grid.setdefault(key, []).append(entry)

    out_lines = [f"Grid size: {grid_size_mils} mils ({grid_size_mils/1000.0:.3f} in)"]
    # Row first, then column, so the output reads top-to-bottom-ish.
    for col, row in sorted(grid, key=lambda k: (k[1], k[0])):
        out_lines.append(f"\nCell (col={col}, row={row})  X range [{col*grid_size_mils}..{(col+1)*grid_size_mils}) Y range [{row*grid_size_mils}..{(row+1)*grid_size_mils})")
        # Sort by Y then X so nearby items print together.
        cell = sorted(grid[(col, row)], key=lambda e: (e[1], e[0]))
        cell_lines = [
            f"{kind} {net[:10]:10} {coord(x)} {coord(y)}  {info}"
            for x, y, kind, net, info in cell
        ]
        out_lines.extend(two_columns(cell_lines))

    return '\n'.join(out_lines)


def main(argv=None) -> int:
    """Command-line entry point: parse a CST file and print its via/pad table.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use
            ``sys.argv[1:]``.

    Returns:
        0 on success, 2 when the grid size is invalid or the input file does
        not exist, 3 when the file cannot be parsed.
    """
    p = argparse.ArgumentParser(description='Parse CAST .CST and list CVia entries')
    p.add_argument('input', help='Input .CST file')
    p.add_argument('--scale', '-s', type=float, default=None, help='Divide coordinates by this scale (e.g. 1000 to convert mil->inch)')
    p.add_argument('--pads', action='store_true', help='Include CPads in output')
    p.add_argument('--group', action='store_true', help='Group entries into grid cells')
    p.add_argument('--grid-size', type=int, default=1000, help='Grid cell size in mils (default 1000 = 1 inch)')
    args = p.parse_args(argv)

    # The grid size is used as a divisor when grouping, so reject
    # non-positive values here instead of failing later with a traceback.
    if args.group and args.grid_size <= 0:
        print('Grid size must be a positive integer', file=sys.stderr)
        return 2

    path = Path(args.input)
    if not path.exists():
        print('Input file not found', file=sys.stderr)
        return 2

    try:
        vias, pads, nets, devs = parse_cst(path, include_pads=args.pads)
    except Exception as exc:
        # Top-level boundary: report any parse failure and exit non-zero.
        print('Error parsing CST:', exc, file=sys.stderr)
        return 3

    shown_pads = pads if args.pads else []
    if args.group:
        print(format_grouped_output(vias, shown_pads, nets, args.grid_size, args.scale, devs))
    else:
        print(format_output(vias, shown_pads, nets, devs, args.scale))
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == '__main__':
    sys.exit(main())
