from __future__ import annotations

import hashlib
import shutil
import subprocess
import struct
import tempfile
import zlib
from pathlib import Path
from typing import Callable


# Signature of a pluggable HEIC converter: it is called as
# ``converter(source_path, requested_output_path)`` and returns the path of
# the file it actually wrote.
HeicConverter = Callable[[Path, Path], Path]


def prepare_photo_inputs(
    photo_paths: list[str | Path],
    *,
    temp_root: str | Path | None = None,
    heic_converter: HeicConverter | None = None,
) -> list[dict]:
    """Prepare each raw photo in ``photo_paths`` and return one record per photo.

    Args:
        photo_paths: Locations of the photos, as strings or ``Path`` objects.
        temp_root: Optional directory for converted HEIC output. A falsy value
            lets the per-photo helper pick its default location.
        heic_converter: Optional replacement for the built-in HEIC converter.

    Returns:
        The prepared-input dictionaries, in the same order as ``photo_paths``.

    Raises:
        ValueError: Propagated from the per-photo preparation step.
    """
    scratch_dir = Path(temp_root) if temp_root else None
    return [
        _prepare_one_photo_input(
            Path(raw_path),
            temp_root=scratch_dir,
            heic_converter=heic_converter,
        )
        for raw_path in photo_paths
    ]


def _prepare_one_photo_input(
    path: Path,
    *,
    temp_root: Path | None,
    heic_converter: HeicConverter | None,
) -> dict:
    if not path.exists():
        raise ValueError("photo file does not exist")
    _validate_raw_photo_path(path)

    source_bytes = path.read_bytes()
    source_mime = _source_mime_for_path(path)
    source_sha256 = hashlib.sha256(source_bytes).hexdigest()
    converted = False
    prepared_bytes = source_bytes
    prepared_mime = source_mime

    if source_mime == "image/heic":
        converted = True
        root = temp_root or Path(tempfile.gettempdir()) / "ai-pet-prepared-inputs"
        root.mkdir(parents=True, exist_ok=True)
        out_path = root / f"{source_sha256}.png"
        converter = heic_converter or _convert_heic_with_quicklook
        prepared_path = converter(path, out_path)
        if not prepared_path.exists() or prepared_path.stat().st_size == 0:
            raise ValueError("HEIC conversion failed")
        prepared_bytes = prepared_path.read_bytes()
        prepared_mime = _source_mime_for_path(prepared_path)
        _validate_prepared_image_is_visible(prepared_bytes, prepared_mime)

    prepared_sha256 = hashlib.sha256(prepared_bytes).hexdigest()
    return {
        "name": path.name,
        "input_kind": "raw_pet_photo",
        "source_mime": source_mime,
        "mime": prepared_mime,
        "bytes": prepared_bytes,
        "sha256": source_sha256,
        "source_sha256": source_sha256,
        "prepared_sha256": prepared_sha256,
        "converted": converted,
    }


def _validate_raw_photo_path(path: Path) -> None:
    lower_name = path.name.lower()
    parts = {part.lower() for part in path.parts}
    if "public" in parts and any("合成样例" in part for part in path.parts):
        raise ValueError("public synthetic fixture cannot be used as raw pet photo input")
    if lower_name.startswith("spritesheet.") or lower_name.startswith("atlas."):
        raise ValueError("spritesheet or atlas files are generated assets, not raw pet photos")


def _source_mime_for_path(path: Path) -> str:
    suffix = path.suffix.lower()
    if suffix in {".heic", ".heif"}:
        return "image/heic"
    if suffix == ".png":
        return "image/png"
    if suffix == ".webp":
        return "image/webp"
    return "image/jpeg"


def _convert_heic_with_quicklook(source_path: Path, output_path: Path) -> Path:
    """Convert a HEIC file to PNG with ``qlmanage``, falling back to ``sips``.

    ``qlmanage`` is asked for a thumbnail (``-t``) of size 512 written into
    ``output_path.parent``. The code expects the result to be named
    ``<source file name>.png`` and moves it to ``output_path``.

    The ``sips`` converter is used whenever ``qlmanage`` does not yield a
    non-empty file. That includes the tool being unavailable or timing out,
    which previously escaped as an exception and skipped the fallback.

    Args:
        source_path: The HEIC/HEIF file to convert.
        output_path: Where the PNG should end up.

    Returns:
        ``output_path`` once it holds the converted image.

    Raises:
        ValueError: From the ``sips`` fallback when it also fails.
    """
    generated_path = output_path.parent / f"{source_path.name}.png"
    try:
        result = subprocess.run(
            ["qlmanage", "-t", "-s", "512", "-o", str(output_path.parent), str(source_path)],
            timeout=30,
            capture_output=True,
            text=True,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        # Missing/unlaunchable binary or a hung process: treat it like any
        # other QuickLook failure and try the next converter.
        return _convert_heic_with_sips_png(source_path, output_path)

    produced = (
        result.returncode == 0
        and generated_path.exists()
        and generated_path.stat().st_size > 0
    )
    if not produced:
        return _convert_heic_with_sips_png(source_path, output_path)

    if generated_path != output_path:
        # Clear any earlier output so the move cannot trip over it.
        if output_path.exists():
            output_path.unlink()
        shutil.move(str(generated_path), str(output_path))
    return output_path


def _convert_heic_with_sips_png(source_path: Path, output_path: Path) -> Path:
    result = subprocess.run(
        ["sips", "-s", "format", "png", "-Z", "1536", str(source_path), "--out", str(output_path)],
        timeout=30,
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0 or not output_path.exists() or output_path.stat().st_size == 0:
        raise ValueError("HEIC conversion failed")
    return output_path


def _validate_prepared_image_is_visible(image_bytes: bytes, image_mime: str) -> None:
    if image_mime != "image/png":
        return
    stats = _png_luminance_stats(image_bytes)
    if stats["max_luminance"] <= 5 and stats["mean_luminance"] <= 2:
        raise ValueError("HEIC conversion produced a blank image")


def _png_luminance_stats(image_bytes: bytes) -> dict[str, float]:
    if not image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        raise ValueError("converted HEIC output is not a PNG")
    position = 8
    width = height = bit_depth = color_type = interlace = None
    idat_parts: list[bytes] = []
    while position + 12 <= len(image_bytes):
        length = struct.unpack(">I", image_bytes[position : position + 4])[0]
        kind = image_bytes[position + 4 : position + 8]
        data = image_bytes[position + 8 : position + 8 + length]
        position += 12 + length
        if kind == b"IHDR":
            width, height, bit_depth, color_type, _compression, _filter, interlace = struct.unpack(">IIBBBBB", data)
        elif kind == b"IDAT":
            idat_parts.append(data)
        elif kind == b"IEND":
            break
    if None in {width, height, bit_depth, color_type, interlace}:
        raise ValueError("converted HEIC PNG is missing header data")
    if bit_depth not in {8, 16} or color_type not in {0, 2, 4, 6} or interlace != 0:
        raise ValueError("converted HEIC PNG format is unsupported")

    channels_by_color_type = {0: 1, 2: 3, 4: 2, 6: 4}
    color_channels_by_color_type = {0: 1, 2: 3, 4: 1, 6: 3}
    bytes_per_sample = int(bit_depth) // 8
    channels = channels_by_color_type[int(color_type)]
    color_channels = color_channels_by_color_type[int(color_type)]
    bytes_per_pixel = channels * bytes_per_sample
    stride = int(width) * bytes_per_pixel
    raw = zlib.decompress(b"".join(idat_parts))
    previous = bytearray(stride)
    offset = 0
    max_luminance = 0.0
    luminance_sum = 0.0
    pixel_count = 0
    for _ in range(int(height)):
        filter_type = raw[offset]
        offset += 1
        row = bytearray(raw[offset : offset + stride])
        offset += stride
        _unfilter_png_row(row, previous, filter_type, bytes_per_pixel)
        previous = row
        for index in range(0, len(row), bytes_per_pixel):
            if color_channels == 1:
                red = green = blue = row[index]
            else:
                red = row[index]
                green = row[index + bytes_per_sample]
                blue = row[index + bytes_per_sample * 2]
            luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
            max_luminance = max(max_luminance, luminance)
            luminance_sum += luminance
            pixel_count += 1
    return {
        "max_luminance": max_luminance,
        "mean_luminance": luminance_sum / max(pixel_count, 1),
    }


def _unfilter_png_row(row: bytearray, previous: bytearray, filter_type: int, bytes_per_pixel: int) -> None:
    if filter_type == 0:
        return
    for index in range(len(row)):
        left = row[index - bytes_per_pixel] if index >= bytes_per_pixel else 0
        up = previous[index]
        upper_left = previous[index - bytes_per_pixel] if index >= bytes_per_pixel else 0
        if filter_type == 1:
            row[index] = (row[index] + left) & 0xFF
        elif filter_type == 2:
            row[index] = (row[index] + up) & 0xFF
        elif filter_type == 3:
            row[index] = (row[index] + ((left + up) // 2)) & 0xFF
        elif filter_type == 4:
            row[index] = (row[index] + _png_paeth(left, up, upper_left)) & 0xFF
        else:
            raise ValueError(f"unsupported PNG filter in converted HEIC output: {filter_type}")


def _png_paeth(left: int, up: int, upper_left: int) -> int:
    estimate = left + up - upper_left
    left_distance = abs(estimate - left)
    up_distance = abs(estimate - up)
    upper_left_distance = abs(estimate - upper_left)
    if left_distance <= up_distance and left_distance <= upper_left_distance:
        return left
    if up_distance <= upper_left_distance:
        return up
    return upper_left
