from __future__ import annotations

import struct
import zlib
from pathlib import Path
from typing import Any

from packages.pet_package_schema import DEFAULT_FRAMES_PER_ACTION


# Size in pixels of one animation cell in the sprite sheet.
CELL_WIDTH = 192
CELL_HEIGHT = 208
# Frames per animation state, i.e. the number of columns in the sheet
# (taken from the package schema).
FRAMES = DEFAULT_FRAMES_PER_ACTION
# Animation states that get rendered, one sheet row each, in this order.
P0_STATES = ("idle", "sleep", "walk", "look", "sit", "tail_wag")

# Named color sets as 8-bit RGBA tuples. draw_frame() reads every role key
# ("body", "body2", "chest", "outline", "eye", "nose"), so each palette must
# define all six. "body2" is the secondary fill (tail, legs, inner ears) and
# "chest" also fills the muzzle.
PALETTES: dict[str, dict[str, tuple[int, int, int, int]]] = {
    "gray": {
        "body": (118, 126, 130, 255),
        "body2": (91, 101, 106, 255),
        "chest": (236, 235, 226, 255),
        "outline": (58, 55, 52, 255),
        "eye": (79, 108, 83, 255),
        "nose": (201, 132, 132, 255),
    },
    "golden": {
        "body": (190, 170, 130, 255),
        "body2": (142, 127, 101, 255),
        "chest": (245, 237, 218, 255),
        "outline": (71, 59, 47, 255),
        "eye": (69, 116, 78, 255),
        "nose": (207, 135, 132, 255),
    },
    "brown": {
        "body": (142, 106, 72, 255),
        "body2": (99, 73, 50, 255),
        "chest": (232, 219, 192, 255),
        "outline": (63, 49, 40, 255),
        "eye": (80, 101, 70, 255),
        "nose": (193, 126, 112, 255),
    },
    "black": {
        "body": (61, 63, 66, 255),
        "body2": (38, 40, 43, 255),
        "chest": (225, 225, 214, 255),
        "outline": (31, 30, 29, 255),
        "eye": (103, 143, 91, 255),
        "nose": (190, 119, 125, 255),
    },
    "white": {
        "body": (228, 226, 214, 255),
        "body2": (181, 178, 164, 255),
        "chest": (247, 242, 226, 255),
        "outline": (74, 67, 58, 255),
        "eye": (76, 119, 86, 255),
        "nose": (208, 134, 134, 255),
    },
}


def palette_from_analysis(analysis: dict[str, Any]) -> dict[str, tuple[int, int, int, int]]:
    """Choose one of the PALETTES entries from free-text color descriptions.

    The "base_color" and "accent_color" values (missing keys count as empty)
    are joined and lower-cased. The text is then searched for keywords in a
    fixed priority order:

    1. warm/light tones -> "golden"
    2. brown tones -> "brown"
    3. "black" -> "black"
    4. "white" with no mention of gray/grey -> "white"

    Anything else falls back to "gray". Matching is by substring, so e.g.
    "golden" matches "gold". The shared PALETTES dict is returned, not a copy.
    """
    base = analysis.get("base_color", "")
    accent = analysis.get("accent_color", "")
    description = f"{base} {accent}".lower()

    def mentions(*words: str) -> bool:
        return any(word in description for word in words)

    if mentions("gold", "yellow", "cream", "orange"):
        name = "golden"
    elif mentions("brown", "chocolate", "tan"):
        name = "brown"
    elif mentions("black"):
        name = "black"
    elif mentions("white") and not mentions("gray", "grey"):
        name = "white"
    else:
        name = "gray"
    return PALETTES[name]


def write_pixel_spritesheet(path: str | Path, analysis: dict[str, Any]) -> None:
    """Render all P0 animation states into one RGBA sheet and save it as a PNG.

    The sheet has one row per entry in P0_STATES (in order) and one column per
    frame (FRAMES columns). Every cell is CELL_WIDTH x CELL_HEIGHT pixels.
    Colors come from palette_from_analysis(analysis).
    """
    colors = palette_from_analysis(analysis)
    sheet_w, sheet_h = CELL_WIDTH * FRAMES, CELL_HEIGHT * len(P0_STATES)
    # Zero-filled RGBA buffer: every pixel starts fully transparent.
    canvas = bytearray(sheet_w * sheet_h * 4)
    cells = ((r, s, f) for r, s in enumerate(P0_STATES) for f in range(FRAMES))
    for r, s, f in cells:
        draw_frame(canvas, r, f, s, colors)
    write_png(Path(path), sheet_w, sheet_h, canvas)


def draw_frame(
    pixels: bytearray,
    row: int,
    frame: int,
    state: str,
    palette: dict[str, tuple[int, int, int, int]],
) -> None:
    """Paint one animation frame of the pet into its cell of the sprite sheet.

    The target cell sits at column ``frame`` and row ``row``, each cell being
    CELL_WIDTH x CELL_HEIGHT pixels. The figure is built from filled
    rectangles painted in sequence, and later rectangles overwrite earlier
    ones, so the statement order below is part of the drawing.

    Args:
        pixels: Whole-sheet RGBA buffer, modified in place through ``rect``.
        row: Sheet row of the cell. write_pixel_spritesheet passes the index
            of ``state`` in P0_STATES.
        frame: Frame index within the state's animation; selects the column.
        state: Animation state name. "idle", "sleep", "walk", "look", "sit"
            and "tail_wag" change the pose; any other string draws the
            default pose.
        palette: Color roles. Must contain "outline", "body", "body2",
            "chest", "eye" and "nose" (RGBA tuples).
    """
    # Top-left corner of this frame's cell within the sheet.
    x0 = frame * CELL_WIDTH
    y0 = row * CELL_HEIGHT
    # Per-state animation parameters, as pixel offsets (y grows downward).
    # idle/look: lift the whole figure 3 px on frames 1 and 2.
    bob = -3 if state in {"idle", "look"} and frame in {1, 2} else 0
    # idle/look/sit: eyes closed on frame 2.
    blink = state in {"idle", "look", "sit"} and frame == 2
    sleep = state == "sleep"
    # walk: shift the whole figure 5 px right on odd frames.
    walk_offset = (frame % 2) * 5 if state == "walk" else 0
    # tail_wag: tail raised 12 px on frames 1 and 3; 4 px otherwise (all states).
    tail_lift = 12 if state == "tail_wag" and frame in {1, 3} else 4

    # Figure origin inside the cell. rect() clips only to the sheet edges, not
    # to the cell, so every offset below must stay inside the cell or it will
    # bleed into a neighbouring frame.
    base_x = x0 + 62 + walk_offset
    base_y = y0 + 82 + bob
    outline = palette["outline"]
    body = palette["body"]
    body2 = palette["body2"]
    chest = palette["chest"]
    eye = palette["eye"]
    nose = palette["nose"]

    # Tail: painted before the torso, which then covers its left end.
    rect(pixels, base_x + 62, base_y + 36 - tail_lift, 34, 12, outline)
    rect(pixels, base_x + 66, base_y + 36 - tail_lift + 4, 28, 8, body2)
    # Torso: outline, body fill, then a lighter chest patch.
    rect(pixels, base_x + 10, base_y + 42, 66, 42, outline)
    rect(pixels, base_x + 14, base_y + 46, 58, 34, body)
    rect(pixels, base_x + 30, base_y + 52, 26, 28, chest)

    # Legs: while walking, both legs move 5 px toward each other on frames 1 and 3.
    leg_shift = 5 if state == "walk" and frame in {1, 3} else 0
    rect(pixels, base_x + 18 + leg_shift, base_y + 78, 14, 16, outline)
    rect(pixels, base_x + 52 - leg_shift, base_y + 78, 14, 16, outline)
    rect(pixels, base_x + 22 + leg_shift, base_y + 80, 8, 12, body2)
    rect(pixels, base_x + 54 - leg_shift, base_y + 80, 8, 12, body2)

    # Head: lowered 6 px while sleeping.
    # NOTE(review): this repeats the `sleep` test computed above; `if not sleep` would read the same.
    head_y = base_y + 8 if state != "sleep" else base_y + 14
    rect(pixels, base_x + 16, head_y + 14, 54, 44, outline)
    rect(pixels, base_x + 20, head_y + 18, 46, 36, body)
    # Ears: outline blocks with a body2 inner fill.
    rect(pixels, base_x + 18, head_y + 4, 14, 18, outline)
    rect(pixels, base_x + 54, head_y + 4, 14, 18, outline)
    rect(pixels, base_x + 22, head_y + 10, 8, 10, body2)
    rect(pixels, base_x + 56, head_y + 10, 8, 10, body2)

    # Eyes: thin outline-colored lines when closed (sleeping or blinking),
    # otherwise 8x8 eye-colored squares.
    if sleep or blink:
        rect(pixels, base_x + 30, head_y + 34, 10, 3, outline)
        rect(pixels, base_x + 48, head_y + 34, 10, 3, outline)
    else:
        rect(pixels, base_x + 30, head_y + 30, 8, 8, eye)
        rect(pixels, base_x + 50, head_y + 30, 8, 8, eye)
    # Nose, then muzzle.
    # NOTE(review): the muzzle is painted after the nose and covers the nose's
    # lower rows (y offsets 44-46), leaving only 3 of its 6 rows visible.
    # Confirm this is intended.
    rect(pixels, base_x + 42, head_y + 41, 8, 6, nose)
    rect(pixels, base_x + 34, head_y + 44, 22, 8, chest)

    # look: an extra eye-colored square near the head's right edge.
    # Presumably a side-glance cue; TODO confirm intent.
    if state == "look":
        rect(pixels, base_x + 63, head_y + 26, 6, 6, eye)
    # sit: a block over the lower body that hides most of each leg
    # (presumably a seated haunch).
    if state == "sit":
        rect(pixels, base_x + 16, base_y + 70, 56, 20, outline)
        rect(pixels, base_x + 20, base_y + 72, 48, 16, body)

def rect(
    pixels: bytearray,
    x: int,
    y: int,
    w: int,
    h: int,
    color: tuple[int, int, int, int],
    *,
    width: int | None = None,
    height: int | None = None,
) -> None:
    """Fill the axis-aligned rectangle (x, y, w, h) of an RGBA buffer with *color*.

    The rectangle is clipped to the image bounds. Parts outside the image,
    including negative coordinates, are skipped, and a fully clipped
    rectangle is a no-op.

    Args:
        pixels: Row-major RGBA buffer, 4 bytes per pixel, modified in place.
        x, y: Top-left corner in pixels (may be negative).
        w, h: Rectangle size in pixels.
        color: Exactly four 0-255 channel values (R, G, B, A).
        width, height: Image dimensions in pixels. They default to the full
            sprite sheet (CELL_WIDTH * FRAMES by CELL_HEIGHT * len(P0_STATES)).

    Raises:
        ValueError: If *color* does not have exactly four channels, or a
            channel is outside 0-255.
    """
    if width is None:
        width = CELL_WIDTH * FRAMES
    if height is None:
        height = CELL_HEIGHT * len(P0_STATES)

    rgba = bytes(color)
    # A slice assignment with a mismatched length would silently grow or
    # shrink the bytearray and shift every later pixel, so reject it up front.
    if len(rgba) != 4:
        raise ValueError(f"color must have exactly 4 channels (RGBA), got {len(rgba)}")

    x_start, x_end = max(0, x), min(width, x + w)
    y_start, y_end = max(0, y), min(height, y + h)
    if x_start >= x_end or y_start >= y_end:
        return

    # One pre-built span per row instead of one slice assignment per pixel.
    span = rgba * (x_end - x_start)
    stride = width * 4
    for py in range(y_start, y_end):
        offset = py * stride + x_start * 4
        pixels[offset : offset + len(span)] = span


def write_png(path: str | Path, width: int, height: int, pixels: bytearray) -> None:
    """Encode an 8-bit RGBA pixel buffer and write it to *path* as a PNG file.

    Rows are read top to bottom, ``width * 4`` bytes each, with no padding.
    Every scanline uses PNG filter type 0 (None), and the image data is
    zlib-compressed at level 9.

    Args:
        path: Destination file. A ``str`` is accepted and converted to Path.
        width: Image width in pixels (must be positive).
        height: Image height in pixels (must be positive).
        pixels: Exactly ``width * height * 4`` bytes of RGBA data.

    Raises:
        ValueError: If a dimension is not positive, or the buffer length does
            not equal ``width * height * 4``. Either would otherwise produce
            an invalid or silently truncated PNG.
    """
    # PNG forbids zero-sized images.
    if width <= 0 or height <= 0:
        raise ValueError(f"PNG dimensions must be positive, got {width}x{height}")
    stride = width * 4
    expected = stride * height
    if len(pixels) != expected:
        raise ValueError(
            f"pixel buffer holds {len(pixels)} bytes, expected {expected} "
            f"for a {width}x{height} RGBA image"
        )

    raw = bytearray()
    for y in range(height):
        raw.append(0)  # filter type 0 (None) for this scanline
        raw.extend(pixels[y * stride : (y + 1) * stride])
    compressed = zlib.compress(raw, level=9)

    # IHDR fields:
    #   width, height
    #   bit depth 8
    #   color type 6 (RGBA)
    #   default compression and filter methods
    #   no interlacing
    header = struct.pack(">IIBBBBB", width, height, 8, 6, 0, 0, 0)
    Path(path).write_bytes(
        b"".join(
            (
                b"\x89PNG\r\n\x1a\n",
                chunk(b"IHDR", header),
                chunk(b"IDAT", compressed),
                chunk(b"IEND", b""),
            )
        )
    )


def chunk(kind: bytes, payload: bytes) -> bytes:
    """Frame *payload* as a PNG chunk.

    The layout is a 4-byte big-endian data length, the chunk type, the data,
    and then the CRC-32 of type plus data.
    """
    body = kind + payload
    length_field = struct.pack(">I", len(payload))
    crc_field = struct.pack(">I", zlib.crc32(body) & 0xFFFFFFFF)
    return b"".join((length_field, body, crc_field))
