from __future__ import annotations

from pathlib import Path
from typing import Any

from packages.pet_package_schema import ALLOWED_ACTIONS
from services.pet_builder.action_atlas import (
    CELL_HEIGHT,
    CELL_WIDTH,
    _extract_components,
    _fit_component_to_cell,
    _merge_frame_components,
)


def write_contact_sheet_png(*, spritesheet_bytes: bytes, output_path: str | Path) -> dict[str, Any]:
    """Persist an already-encoded spritesheet image and describe it.

    Missing parent directories are created and an existing file at
    ``output_path`` is overwritten. The bytes are written verbatim; nothing
    here checks that they actually are a PNG.

    Returns a manifest entry with the file's base name, its kind, and format.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(spritesheet_bytes)
    entry: dict[str, Any] = {"path": str(target.name)}
    entry["kind"] = "atlas_contact_sheet"
    entry["format"] = "png"
    return entry


def write_animation_previews(
    *,
    action_frames: dict[str, list[bytes]],
    output_dir: str | Path,
    frame_delay_ms: int = 240,
) -> dict[str, Any]:
    """Render one looping GIF preview per action and describe them.

    Actions are visited in ``ALLOWED_ACTIONS`` order. For each action that has
    frames in ``action_frames``, at most the first four encoded frames are
    fitted to a single atlas cell and written to ``<output_dir>/<action>.gif``.
    Actions with no frames are skipped.

    Returns a manifest dict containing the preview directory's name, the
    format, the frame delay, and a mapping of action -> GIF path. Each path is
    relative to the grandparent of ``output_dir``.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    # Preview paths are recorded relative to two levels above the preview dir.
    base_dir = target_dir.parent.parent
    written: dict[str, str] = {}
    for action_name in ALLOWED_ACTIONS:
        clip = action_frames.get(action_name, [])[:4]
        if clip:
            cells = list(map(_fit_frame_to_cell, clip))
            gif_path = target_dir / f"{action_name}.gif"
            gif_path.write_bytes(_animated_gif(cells, delay_ms=frame_delay_ms))
            written[action_name] = str(gif_path.relative_to(base_dir))
    return {
        "path": str(target_dir.name),
        "format": "gif",
        "frame_delay_ms": frame_delay_ms,
        "previews": written,
    }


def _fit_frame_to_cell(image_bytes: bytes) -> list[tuple[int, int, int, int]]:
    """Turn one encoded frame image into the RGBA pixels of a single atlas cell.

    When ``_extract_components`` finds nothing, a fully transparent cell of
    ``CELL_WIDTH * CELL_HEIGHT`` pixels is returned. Otherwise the components
    are merged into one (the merge report is discarded), and that merged
    component is fitted to the cell by ``_fit_component_to_cell``.
    """
    found = _extract_components(image_bytes)
    if found:
        merged, _ = _merge_frame_components(found)
        return _fit_component_to_cell(merged["pixels"], merged["width"], merged["height"])
    blank_pixel = (0, 0, 0, 0)
    return [blank_pixel] * (CELL_WIDTH * CELL_HEIGHT)


def _animated_gif(
    frames: list[list[tuple[int, int, int, int]]],
    *,
    delay_ms: int,
) -> bytes:
    """Encode RGBA frames as an infinitely looping GIF89a animation.

    Every frame is written as a full ``CELL_WIDTH`` x ``CELL_HEIGHT`` image.
    All frames share one 256-entry global palette built by
    ``_indexed_frames``, and palette index 0 is flagged as the transparent
    colour. Frames are presumably exactly ``CELL_WIDTH * CELL_HEIGHT`` pixels
    in row-major order (GIF's pixel order) — TODO confirm against
    ``_fit_component_to_cell``.

    Args:
        frames: one list of ``(r, g, b, a)`` pixels per animation frame.
        delay_ms: per-frame delay in milliseconds. It is stored in hundredths
            of a second (truncated) and clamped to ``1..65535``, the range of
            the 16-bit GIF delay field.

    Returns:
        The complete GIF file as bytes.
    """
    palette, indexed_frames = _indexed_frames(frames)
    # GIF delays are an unsigned 16-bit count of 1/100 s. Clamp so a very long
    # delay cannot overflow ``to_bytes(2, ...)`` below.
    delay_cs = min(0xFFFF, max(1, int(delay_ms / 10)))

    data = bytearray(b"GIF89a")
    # Logical screen descriptor: canvas size = one atlas cell.
    data += CELL_WIDTH.to_bytes(2, "little")
    data += CELL_HEIGHT.to_bytes(2, "little")
    # Packed field: global colour table present, 8-bit colour resolution,
    # unsorted, 2 ** (7 + 1) = 256 entries. Background index 0, no aspect ratio.
    data += bytes([0b11110111, 0, 0])
    for red, green, blue in palette:
        data += bytes([red, green, blue])
    # NETSCAPE2.0 application extension with loop count 0 (= loop forever).
    data += b"\x21\xff\x0bNETSCAPE2.0\x03\x01\x00\x00\x00"

    for indexes in indexed_frames:
        # Graphic control extension: disposal method 2 (restore to background),
        # transparency flag set.
        data += b"\x21\xf9\x04"
        data += bytes([0b00001001])
        data += delay_cs.to_bytes(2, "little")
        data += bytes([0, 0])  # transparent colour index 0, block terminator
        # Image descriptor: whole canvas at (0, 0), no local colour table.
        data += b"\x2c"
        data += (0).to_bytes(2, "little")
        data += (0).to_bytes(2, "little")
        data += CELL_WIDTH.to_bytes(2, "little")
        data += CELL_HEIGHT.to_bytes(2, "little")
        data += bytes([0])
        # LZW minimum code size 8: codes start at 9 bits (clear=256, end=257).
        data += bytes([8])
        lzw = _literal_lzw_stream(indexes)
        # Image data goes out as length-prefixed sub-blocks of <= 255 bytes.
        for offset in range(0, len(lzw), 255):
            chunk = lzw[offset : offset + 255]
            data.append(len(chunk))
            data += chunk
        data.append(0)  # image data block terminator
    data += b";"  # GIF trailer
    return bytes(data)


def _indexed_frames(
    frames: list[list[tuple[int, int, int, int]]],
) -> tuple[list[tuple[int, int, int]], list[list[int]]]:
    """Map RGBA frames onto one shared 256-entry GIF palette.

    Palette index 0 is reserved exclusively for fully transparent pixels
    (``alpha == 0``); the GIF writer flags index 0 as transparent. Every other
    pixel is quantized with ``_quantize_color`` and gets its own slot starting
    at 1. This means an opaque black pixel no longer shares the transparent
    slot, so it is not rendered as a hole.

    Returns:
        ``(palette, indexed_frames)``. The palette is padded to exactly 256
        RGB triples. ``indexed_frames`` holds one list of palette indexes per
        input frame.
    """
    transparent_index = 0
    # Slot 0 holds a placeholder colour for the transparent index. It is
    # deliberately absent from ``palette_index`` so opaque pixels never map to it.
    palette: list[tuple[int, int, int]] = [(0, 0, 0)]
    palette_index: dict[tuple[int, int, int], int] = {}
    indexed_frames: list[list[int]] = []
    for frame in frames:
        indexes: list[int] = []
        for red, green, blue, alpha in frame:
            if alpha == 0:
                indexes.append(transparent_index)
                continue
            color = _quantize_color(red, green, blue)
            index = palette_index.get(color)
            if index is None:
                index = len(palette)
                palette.append(color)
                palette_index[color] = index
            indexes.append(index)
        indexed_frames.append(indexes)

    # Defensive: with the 216-colour quantizer the palette holds at most
    # 1 + 216 entries, so neither the truncation nor the index clamp triggers.
    palette = palette[:256]
    palette.extend([(0, 0, 0)] * (256 - len(palette)))
    return palette, [[min(index, 255) for index in indexes] for indexes in indexed_frames]


def _quantize_color(red: int, green: int, blue: int) -> tuple[int, int, int]:
    """Snap each channel down to the 6-level web-safe grid (0, 51, ..., 255).

    Negative inputs are clamped to 0 before flooring and results are capped at
    255, so at most 6 ** 3 = 216 distinct colours can be produced.
    """
    return (
        min(255, max(0, int(red)) // 51 * 51),
        min(255, max(0, int(green)) // 51 * 51),
        min(255, max(0, int(blue)) // 51 * 51),
    )


def _literal_lzw_stream(indexes: list[int]) -> bytes:
    """Encode palette indexes as a GIF LZW stream without real compression.

    Every index is emitted as its own literal code, so the indexes are
    expected to be in 0..255. A clear code (256) precedes each run of up to
    200 literals. Resetting that often keeps the decoder's dictionary far
    below 512 entries, so every code stays 9 bits wide. The stream ends with
    the end-of-information code (257).
    """
    run_length = 200
    stream: list[int] = []
    for start in range(0, len(indexes), run_length):
        stream += [256, *indexes[start : start + run_length]]
    stream.append(257)
    return _pack_codes_9bit(stream)


def _pack_codes_9bit(codes: list[int]) -> bytes:
    """Pack codes into an LSB-first bitstream using 9 bits per code.

    Bits are emitted least-significant first, which is the order GIF's LZW
    data uses. For codes in 0..511, a trailing partial byte has its unused
    high bits set to zero.
    """
    packed = bytearray()
    pending_bits = 0  # bits not yet flushed, lowest bit first
    pending_len = 0   # number of meaningful bits in pending_bits
    for value in codes:
        pending_bits |= int(value) << pending_len
        pending_len += 9
        # Flush every complete byte accumulated so far.
        while pending_len >= 8:
            packed.append(pending_bits & 0xFF)
            pending_bits >>= 8
            pending_len -= 8
    if pending_len > 0:
        packed.append(pending_bits & 0xFF)
    return bytes(packed)
