from __future__ import annotations

import binascii
import struct
import zlib
from collections import Counter, deque
from typing import Any


# Fixed 8-byte prefix of a PNG stream: checked by `_read_png` on input and
# emitted by `_write_rgba_png` on output.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def ensure_png_alpha(
    image_bytes: bytes,
    *,
    chroma_key: tuple[int, int, int] | None = None,
    chroma_threshold: float = 96.0,
) -> tuple[bytes, dict[str, Any]]:
    """Convert common Qwen checkerboard RGB PNG backgrounds into transparent alpha.

    Args:
        image_bytes: Encoded PNG data.
        chroma_key: Optional RGB colour to key out explicitly. When given, the
            explicit chroma-key removal is used instead of the edge flood fill.
        chroma_threshold: Maximum RGB distance from ``chroma_key`` at which a
            pixel is made transparent (only used with ``chroma_key``).

    Returns:
        ``(png_bytes, metadata)``. When nothing can or needs to be done the
        original ``image_bytes`` are returned unchanged and ``metadata`` has
        ``"applied": False`` plus a ``"reason"``. Otherwise ``png_bytes`` is a
        freshly encoded 8-bit RGBA PNG and ``metadata`` describes the method.

    Only 8-bit, non-interlaced truecolour images (colour type 2 or 6) are
    decoded; every other layout is reported instead of being processed.
    """
    try:
        png = _read_png(image_bytes)
    except ValueError as exc:
        return image_bytes, {"applied": False, "alpha_capable": False, "reason": str(exc)[:120]}

    width = int(png["width"])
    height = int(png["height"])
    color_type = png["color_type"]
    source_has_alpha = color_type == 6

    # Grey+alpha and tRNS-based transparency are left untouched.
    if not source_has_alpha and (color_type == 4 or png["has_trns"]):
        return image_bytes, {
            "applied": False,
            "alpha_capable": True,
            "reason": "png_already_has_alpha",
            "width": png["width"],
            "height": png["height"],
        }

    # The scanline decoders assume 8-bit samples and no Adam7 interlacing.
    # This guard must also cover RGBA input, otherwise 16-bit or interlaced
    # RGBA files would be misread as 8-bit rows.
    if color_type not in (2, 6) or png["bit_depth"] != 8 or png["interlace"] != 0:
        return image_bytes, {
            "applied": False,
            "alpha_capable": source_has_alpha,
            "reason": "unsupported_png_format",
            "width": png["width"],
            "height": png["height"],
            "color_type": png["color_type"],
            "bit_depth": png["bit_depth"],
        }

    try:
        pixels: list[tuple[int, int, int, int]]
        if source_has_alpha:
            pixels = _decode_rgba_pixels(png)
        else:
            pixels = [(red, green, blue, 255) for red, green, blue in _decode_rgb_pixels(png)]
    except (ValueError, IndexError, zlib.error) as exc:
        # Corrupt or truncated IDAT data: report it like any other unreadable
        # input instead of letting the decoder exception escape.
        return image_bytes, {
            "applied": False,
            "alpha_capable": source_has_alpha,
            "reason": f"png_decode_failed: {exc}"[:120],
            "width": png["width"],
            "height": png["height"],
        }

    if chroma_key is not None:
        return _remove_explicit_chroma_background(
            pixels=pixels,
            width=width,
            height=height,
            source_has_alpha=source_has_alpha,
            chroma_key=chroma_key,
            chroma_threshold=chroma_threshold,
        )

    if source_has_alpha:
        opaque_rgb_pixels = [(red, green, blue) for red, green, blue, alpha in pixels if alpha > 0]
        # An RGBA image that already contains transparent pixels and shows no
        # background-like colours (on its visible border or globally) is
        # returned as-is.
        if (
            any(alpha == 0 for *_rgb, alpha in pixels)
            and not _opaque_edge_background_palette(pixels, width, height)
            and not _global_chroma_background_palette(opaque_rgb_pixels)
        ):
            return image_bytes, {
                "applied": False,
                "alpha_capable": True,
                "reason": "png_already_has_alpha",
                "width": png["width"],
                "height": png["height"],
            }

    return _remove_connected_edge_background(
        pixels=pixels,
        width=width,
        height=height,
        source_has_alpha=source_has_alpha,
    )


def _remove_explicit_chroma_background(
    *,
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
    source_has_alpha: bool,
    chroma_key: tuple[int, int, int],
    chroma_threshold: float,
) -> tuple[bytes, dict[str, Any]]:
    """Key out pixels near ``chroma_key`` and re-encode the image as RGBA.

    A pixel becomes fully transparent when its alpha is already 0 or its RGB
    distance to ``chroma_key`` is at most ``chroma_threshold``. Pixels that are
    kept but still match `_is_chroma_residue_pixel` are only counted.

    Returns:
        ``(png_bytes, metadata)`` where ``metadata["applied"]`` is true when
        at least one pixel ended up transparent.
    """
    output = bytearray()
    cleared = 0
    residue = 0
    for red, green, blue, source_alpha in pixels:
        rgb = (red, green, blue)
        if source_alpha == 0 or _rgb_distance(rgb, chroma_key) <= chroma_threshold:
            cleared += 1
            output.extend((0, 0, 0, 0))
            continue
        if _is_chroma_residue_pixel(rgb, chroma_key):
            residue += 1
        output.extend((red, green, blue, source_alpha))

    encoded = _write_rgba_png(width, height, bytes(output))
    report: dict[str, Any] = {
        "applied": cleared > 0,
        "alpha_capable": True,
        "method": "explicit_chroma_key",
        "width": width,
        "height": height,
        "transparent_pixels": cleared,
        "source_has_alpha": source_has_alpha,
        "chroma_key": list(chroma_key),
        "chroma_threshold": chroma_threshold,
        "chroma_residue_pixels": residue,
    }
    return encoded, report


def _remove_connected_edge_background(
    *,
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
    source_has_alpha: bool,
) -> tuple[bytes, dict[str, Any]]:
    """Clear background reachable from the border, then re-encode as RGBA.

    Steps:
      1. Flood-fill background-coloured pixels inward from the image border.
      2. If a global chroma palette is detected, also clear pixels that are
         within distance 96 of a palette colour or look like its shadow, and
         then strip the remaining chroma halo next to cleared pixels.
      3. Emit RGBA bytes: cleared pixels and pixels whose alpha was already 0
         become ``(0, 0, 0, 0)``; everything else keeps its original alpha.

    Returns:
        ``(png_bytes, metadata)`` where ``metadata["applied"]`` is true when
        the output contains at least one fully transparent pixel.
    """
    colours = [(r, g, b) for r, g, b, _ in pixels]
    mask = _find_connected_edge_background(colours, width, height)
    chroma_palette = _global_chroma_background_palette(colours)
    if chroma_palette:
        for position, colour in enumerate(colours):
            if mask[position]:
                continue
            for key in chroma_palette:
                if _rgb_distance(colour, key) <= 96 or _is_chroma_shadow_pixel(colour, key):
                    mask[position] = True
                    break
        _remove_connected_chroma_halo(colours, mask, width, height, chroma_palette)

    encoded = bytearray()
    cleared = 0
    for position, (r, g, b, a) in enumerate(pixels):
        if mask[position] or a == 0:
            cleared += 1
            encoded.extend((0, 0, 0, 0))
        else:
            encoded.extend((r, g, b, a))

    png_bytes = _write_rgba_png(width, height, bytes(encoded))
    report: dict[str, Any] = {
        "applied": cleared > 0,
        "alpha_capable": True,
        "method": "edge_background_flood_fill",
        "width": width,
        "height": height,
        "transparent_pixels": cleared,
        "source_has_alpha": source_has_alpha,
    }
    return png_bytes, report


def _read_png(image_bytes: bytes) -> dict[str, Any]:
    """Parse the chunk structure of a PNG and collect what the decoders need.

    Walks the chunks after the signature, reading the IHDR fields, joining all
    IDAT payloads and noting whether a tRNS chunk is present. Parsing stops at
    IEND. Chunk CRCs are skipped, not verified.

    Returns:
        Dict with ``width``, ``height``, ``bit_depth``, ``color_type``,
        ``interlace``, ``idat`` (joined compressed data) and ``has_trns``.

    Raises:
        ValueError: ``"not_png"`` when the signature is missing,
            ``"invalid_ihdr"`` when the IHDR payload is not 13 bytes,
            ``"missing_ihdr"`` when no IHDR chunk was found, and
            ``"unsupported_png_compression"`` when the compression or filter
            method is not 0.
    """
    if not image_bytes.startswith(PNG_SIGNATURE):
        raise ValueError("not_png")
    position = len(PNG_SIGNATURE)
    header: tuple[int, ...] | None = None
    idat_parts: list[bytes] = []
    has_trns = False
    # Each chunk is: 4-byte length, 4-byte type, payload, 4-byte CRC.
    while position + 12 <= len(image_bytes):
        (length,) = struct.unpack_from(">I", image_bytes, position)
        chunk_type = image_bytes[position + 4 : position + 8]
        data_start = position + 8
        data_end = data_start + length
        chunk_data = image_bytes[data_start:data_end]
        position = data_end + 4
        if chunk_type == b"IHDR":
            # A short or oversized IHDR would make struct.unpack raise
            # struct.error, which callers catching ValueError would miss.
            if len(chunk_data) != 13:
                raise ValueError("invalid_ihdr")
            header = struct.unpack(">IIBBBBB", chunk_data)
        elif chunk_type == b"IDAT":
            idat_parts.append(chunk_data)
        elif chunk_type == b"tRNS":
            has_trns = True
        elif chunk_type == b"IEND":
            break
    if header is None:
        raise ValueError("missing_ihdr")
    width, height, bit_depth, color_type, compression, filter_method, interlace = header
    if compression != 0 or filter_method != 0:
        raise ValueError("unsupported_png_compression")
    return {
        "width": width,
        "height": height,
        "bit_depth": bit_depth,
        "color_type": color_type,
        "interlace": interlace,
        "idat": b"".join(idat_parts),
        "has_trns": has_trns,
    }


def _decode_rgb_pixels(png: dict[str, Any]) -> list[tuple[int, int, int]]:
    """Inflate and unfilter the IDAT data as 3-bytes-per-pixel scanlines.

    Reads ``width``, ``height`` and ``idat`` from ``png`` and returns the
    pixels in row-major order as ``(red, green, blue)`` tuples.
    """
    channels = 3
    columns = int(png["width"])
    line_count = int(png["height"])
    row_bytes = columns * channels
    inflated = zlib.decompress(png["idat"])

    scanlines: list[bytearray] = []
    prior = bytearray(row_bytes)  # the row above the first scanline is all zeros
    offset = 0
    for _ in range(line_count):
        # Each scanline is one filter-type byte followed by the filtered samples.
        filter_type = inflated[offset]
        start = offset + 1
        offset = start + row_bytes
        current = bytearray(inflated[start:offset])
        _unfilter_row(current, prior, filter_type, channels)
        scanlines.append(current)
        prior = current

    return [
        (line[at], line[at + 1], line[at + 2])
        for line in scanlines
        for at in range(0, len(line), 3)
    ]


def _decode_rgba_pixels(png: dict[str, Any]) -> list[tuple[int, int, int, int]]:
    """Inflate and unfilter the IDAT data as 4-bytes-per-pixel scanlines.

    Reads ``width``, ``height`` and ``idat`` from ``png`` and returns the
    pixels in row-major order as ``(red, green, blue, alpha)`` tuples.
    """
    channels = 4
    columns = int(png["width"])
    line_count = int(png["height"])
    row_bytes = columns * channels
    inflated = zlib.decompress(png["idat"])

    scanlines: list[bytearray] = []
    prior = bytearray(row_bytes)  # the row above the first scanline is all zeros
    offset = 0
    for _ in range(line_count):
        # Each scanline is one filter-type byte followed by the filtered samples.
        filter_type = inflated[offset]
        start = offset + 1
        offset = start + row_bytes
        current = bytearray(inflated[start:offset])
        _unfilter_row(current, prior, filter_type, channels)
        scanlines.append(current)
        prior = current

    return [
        (line[at], line[at + 1], line[at + 2], line[at + 3])
        for line in scanlines
        for at in range(0, len(line), 4)
    ]


def _unfilter_row(row: bytearray, previous: bytearray, filter_type: int, bytes_per_pixel: int) -> None:
    """Undo a PNG scanline filter on ``row`` in place.

    ``previous`` is the already-reconstructed row above. Filter types 0-4 are
    handled (0 leaves the row untouched); any other type raises ``ValueError``
    as soon as the first byte is processed.
    """
    if filter_type == 0:
        return
    for position, encoded in enumerate(row):
        has_left = position >= bytes_per_pixel
        # Neighbours used by the predictors; bytes left of the first pixel count as 0.
        a = row[position - bytes_per_pixel] if has_left else 0
        b = previous[position]
        c = previous[position - bytes_per_pixel] if has_left else 0
        if filter_type == 1:
            predictor = a
        elif filter_type == 2:
            predictor = b
        elif filter_type == 3:
            predictor = (a + b) // 2
        elif filter_type == 4:
            predictor = _paeth(a, b, c)
        else:
            raise ValueError(f"unsupported_png_filter_{filter_type}")
        row[position] = (encoded + predictor) & 0xFF


def _paeth(left: int, up: int, upper_left: int) -> int:
    """Return the Paeth predictor for the three neighbouring byte values.

    Picks whichever of ``left``, ``up`` and ``upper_left`` is closest to
    ``left + up - upper_left``; ties are resolved in that order.
    """
    base = left + up - upper_left
    gap_left = abs(base - left)
    gap_up = abs(base - up)
    gap_diagonal = abs(base - upper_left)
    if gap_left <= min(gap_up, gap_diagonal):
        return left
    return up if gap_up <= gap_diagonal else upper_left


def _find_connected_edge_background(
    pixels: list[tuple[int, int, int]],
    width: int,
    height: int,
) -> list[bool]:
    """Flood-fill background-coloured pixels inward from the image border.

    Returns a row-major list of ``width * height`` flags; a flag is true when
    the pixel matches the edge palette and is 4-connected to the border
    through other matching pixels. All flags are false when no edge palette
    is found.
    """
    palette = _edge_background_palette(pixels, width, height)
    mask = [False] * (width * height)
    if not palette:
        return mask

    pending: deque[int] = deque()

    def claim(position: int) -> None:
        # Mark a not-yet-visited background pixel and schedule its neighbours.
        if not mask[position] and _is_background_pixel(pixels[position], palette):
            mask[position] = True
            pending.append(position)

    # Seed with the top/bottom rows, then the left/right columns.
    last_row_start = (height - 1) * width
    for column in range(width):
        claim(column)
        claim(last_row_start + column)
    for row in range(height):
        claim(row * width)
        claim(row * width + width - 1)

    while pending:
        position = pending.popleft()
        row, column = divmod(position, width)
        if column > 0:
            claim(position - 1)
        if column < width - 1:
            claim(position + 1)
        if row > 0:
            claim(position - width)
        if row < height - 1:
            claim(position + width)
    return mask


def _edge_background_palette(pixels: list[tuple[int, int, int]], width: int, height: int) -> list[tuple[int, int, int]]:
    """Return up to six of the most frequent quantized background colours on the border.

    Only border pixels accepted by `_is_supported_background_pixel` are counted.
    """
    last_row_start = (height - 1) * width
    border: list[tuple[int, int, int]] = []
    # Top and bottom rows, interleaved per column.
    for column in range(width):
        border += (pixels[column], pixels[last_row_start + column])
    # Left and right columns, interleaved per row.
    for row in range(height):
        row_start = row * width
        border += (pixels[row_start], pixels[row_start + width - 1])
    tally = Counter(
        _quantize(colour) for colour in border if _is_supported_background_pixel(colour)
    )
    return [colour for colour, _count in tally.most_common(6)]


def _opaque_edge_background_palette(
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
) -> list[tuple[int, int, int]]:
    """Return up to six frequent quantized background colours among visible border pixels.

    Works like `_edge_background_palette` but takes RGBA input and ignores
    border pixels whose alpha is 0.
    """
    last_row_start = (height - 1) * width
    border_positions = [
        position
        for column in range(width)
        for position in (column, last_row_start + column)
    ]
    border_positions += [
        position
        for row in range(height)
        for position in (row * width, row * width + width - 1)
    ]

    visible: list[tuple[int, int, int]] = []
    for position in border_positions:
        r, g, b, a = pixels[position]
        if a > 0:
            visible.append((r, g, b))

    tally = Counter(
        _quantize(colour) for colour in visible if _is_supported_background_pixel(colour)
    )
    return [colour for colour, _count in tally.most_common(6)]


def _global_chroma_background_palette(pixels: list[tuple[int, int, int]]) -> list[tuple[int, int, int]]:
    """Return the dominant chroma-key-like colours of the whole image.

    Considers the three most common quantized chroma-like colours and keeps
    those covering at least 1% of ``pixels`` (and never fewer than 4 pixels).
    """
    tally = Counter(_quantize(colour) for colour in pixels if _is_chroma_key_like(colour))
    if not tally:
        return []
    minimum_hits = max(4, int(len(pixels) * 0.01))
    return [colour for colour, hits in tally.most_common(3) if hits >= minimum_hits]


def _remove_connected_chroma_halo(
    pixels: list[tuple[int, int, int]],
    transparent: list[bool],
    width: int,
    height: int,
    palette: list[tuple[int, int, int]],
) -> None:
    """Grow ``transparent`` over halo-coloured pixels bordering cleared ones.

    Mutates ``transparent`` in place. Each pass marks every halo pixel that
    has a 4-neighbour already cleared at the start of the pass. One pass is
    run when the palette contains a green-dominant colour, otherwise up to
    three; the loop stops early once a pass changes nothing.
    """
    passes = 1 if _palette_is_green_chroma(palette) else 3
    for _ in range(passes):
        newly_cleared: list[int] = []
        for position, colour in enumerate(pixels):
            if transparent[position]:
                continue
            if not _is_chroma_halo_pixel(colour, palette):
                continue
            row, column = divmod(position, width)
            touches_cleared = (
                (column > 0 and transparent[position - 1])
                or (column + 1 < width and transparent[position + 1])
                or (row > 0 and transparent[position - width])
                or (row + 1 < height and transparent[position + width])
            )
            if touches_cleared:
                newly_cleared.append(position)
        if not newly_cleared:
            return
        # Apply after the scan so one pass only grows by a single pixel ring.
        for position in newly_cleared:
            transparent[position] = True


def _palette_is_green_chroma(palette: list[tuple[int, int, int]]) -> bool:
    """Return True when any palette colour has green as its strictly largest channel."""
    return any(g > r and g > b for r, g, b in palette)


def _quantize(pixel: tuple[int, int, int]) -> tuple[int, int, int]:
    """Round every channel down to a multiple of 8."""
    return tuple(map(lambda channel: 8 * (channel // 8), pixel))


def _is_low_saturation_gray(pixel: tuple[int, int, int]) -> bool:
    """Return True when the largest and smallest channel differ by at most 24."""
    channel_spread = max(pixel) - min(pixel)
    return channel_spread <= 24


def _is_supported_background_pixel(pixel: tuple[int, int, int]) -> bool:
    """Return True when ``pixel`` could belong to a removable background.

    Near-gray pixels qualify when their mean brightness lies in 0..255;
    other pixels must look like a chroma-key colour.
    """
    if _is_low_saturation_gray(pixel):
        mean_level = sum(pixel) // 3
        return 0 <= mean_level <= 255
    if _is_high_saturation_chroma(pixel):
        return True
    if _is_desaturated_green_chroma(pixel):
        return True
    return any(_rgb_distance(pixel, key) <= 72 for key in _chroma_key_colors())


def _is_chroma_key_like(pixel: tuple[int, int, int]) -> bool:
    """Return True when ``pixel`` resembles a chroma-key backdrop colour.

    A pixel matches when it is a high-saturation chroma colour, a desaturated
    green, or within distance 72 of one of the reference key colours.
    """
    if _is_high_saturation_chroma(pixel) or _is_desaturated_green_chroma(pixel):
        return True
    for key in _chroma_key_colors():
        if _rgb_distance(pixel, key) <= 72:
            return True
    return False


def _is_chroma_halo_pixel(pixel: tuple[int, int, int], palette: list[tuple[int, int, int]]) -> bool:
    """Return True when ``pixel`` carries a tint matching some palette colour.

    Three tint shapes are recognised, each paired with the palette shape that
    enables it: green strictly dominant, green strictly lowest, and blue
    strictly lowest.
    """
    red, green, blue = pixel
    # The pixel-side tests do not depend on the palette entry, so evaluate once.
    tint_green_dominant = green >= 135 and green - red >= 12 and green - blue >= 12
    tint_green_lowest = red >= 135 and blue >= 100 and green + 12 < min(red, blue)
    tint_blue_lowest = green >= 135 and red >= 100 and blue + 12 < min(red, green)

    for key_red, key_green, key_blue in palette:
        if tint_green_dominant and key_green > key_red and key_green > key_blue:
            return True
        if tint_green_lowest and key_red > key_green and key_blue > key_green:
            return True
        if tint_blue_lowest and key_green > key_blue and key_red > key_blue:
            return True
    return False


def _is_chroma_residue_pixel(pixel: tuple[int, int, int], chroma: tuple[int, int, int]) -> bool:
    """Return True when ``pixel`` still shows a tint of the ``chroma`` key colour.

    Keys with two channels above 150 and the third below 120 use a dedicated
    test per channel pair; every other key falls back to
    `_is_chroma_halo_pixel`.
    """
    red, green, blue = pixel
    key_red, key_green, key_blue = chroma
    red_high = key_red > 150
    green_high = key_green > 150
    blue_high = key_blue > 150

    if red_high and blue_high and key_green < 120:
        return red >= 120 and blue >= 80 and green + 30 < red and green + 20 < blue
    if green_high and blue_high and key_red < 120:
        return green >= 120 and blue >= 80 and red + 30 < green and red + 20 < blue
    if green_high and red_high and key_blue < 120:
        return green >= 120 and red >= 120 and blue + 30 < green and blue + 20 < red
    return _is_chroma_halo_pixel(pixel, [chroma])


def _is_high_saturation_chroma(pixel: tuple[int, int, int]) -> bool:
    """Return True when ``pixel`` is a vivid colour in one of the accepted channel ranges.

    Pixels with a channel spread below 80 or a mean brightness below 35 are
    rejected outright; the rest must satisfy at least one of the channel
    threshold rules below.
    """
    red, green, blue = pixel
    if max(pixel) - min(pixel) < 80:
        return False
    if sum(pixel) // 3 < 35:
        return False
    rules = (
        red > 150 and blue > 100 and green < 120,
        green > 150 and blue > 100 and red < 120,
        green > 150 and red < 140 and blue < 140,
        green > 150 and red > 100 and blue < 120,
        blue > 150 and red < 120 and green < 160,
        red > 180 and green > 70 and blue < 80,
    )
    return any(rules)


def _is_desaturated_green_chroma(pixel: tuple[int, int, int]) -> bool:
    """Return True for a muted green: green >= 120 and well above bounded red and blue."""
    red, green, blue = pixel
    if green < 120 or red > 130 or blue > 90:
        return False
    return green - red >= 40 and green - blue >= 60


def _is_background_pixel(pixel: tuple[int, int, int], palette: list[tuple[int, int, int]]) -> bool:
    """Return True when ``pixel`` matches any palette colour as background.

    A match is either a supported background pixel within distance 72 of the
    palette colour, or a shadow of that palette colour.
    """
    for colour in palette:
        if _is_supported_background_pixel(pixel) and _rgb_distance(pixel, colour) <= 72:
            return True
        if _is_chroma_shadow_pixel(pixel, colour):
            return True
    return False


def _is_chroma_shadow_pixel(pixel: tuple[int, int, int], chroma: tuple[int, int, int]) -> bool:
    """Return True when ``pixel`` looks like a darker shade of the ``chroma`` colour.

    ``chroma`` must itself be a high-saturation chroma colour. Its channels
    are split into strong (> 150) and weak (<= 150). The pixel must keep every
    strong channel between 75 and the key value + 16, every weak channel at
    or below 120, a gap of at least 45 between its lowest strong and highest
    weak channel, and a mean brightness of at least 35. For keys with two
    strong channels and a third below 120, those two pixel channels must
    also lie within 90 of each other.
    """
    if not _is_high_saturation_chroma(chroma):
        return False
    red, green, blue = pixel
    key_red, key_green, key_blue = chroma

    strong_values: list[int] = []
    weak_values: list[int] = []
    for value, key_value in zip((red, green, blue), chroma):
        if key_value > 150:
            if value < 75 or value > key_value + 16:
                return False
            strong_values.append(value)
        elif key_value <= 150:
            if value > 120:
                return False
            weak_values.append(value)
    if not strong_values:
        return False

    if min(strong_values) - max(weak_values, default=0) < 45:
        return False
    if (red + green + blue) // 3 < 35:
        return False

    red_strong = key_red > 150
    green_strong = key_green > 150
    blue_strong = key_blue > 150
    if green_strong and blue_strong and key_red < 120:
        return abs(green - blue) <= 90
    if red_strong and blue_strong and key_green < 120:
        return abs(red - blue) <= 90
    if red_strong and green_strong and key_blue < 120:
        return abs(red - green) <= 90
    return True


def _rgb_distance(left: tuple[int, int, int], right: tuple[int, int, int]) -> float:
    """Return the Euclidean distance between the first three channels of two colours."""
    delta_red = left[0] - right[0]
    delta_green = left[1] - right[1]
    delta_blue = left[2] - right[2]
    squared = delta_red * delta_red + delta_green * delta_green + delta_blue * delta_blue
    return squared ** 0.5


def _chroma_key_colors() -> tuple[tuple[int, int, int], ...]:
    """Return the reference chroma-key colours used for distance matching."""
    magenta = (255, 0, 255)
    cyan = (0, 255, 255)
    green = (0, 255, 0)
    yellow = (255, 255, 0)
    blue = (0, 0, 255)
    orange = (255, 127, 0)
    return (magenta, cyan, green, yellow, blue, orange)


def _write_rgba_png(width: int, height: int, rgba: bytes) -> bytes:
    """Encode row-major RGBA bytes as an 8-bit, non-interlaced RGBA PNG.

    Every scanline is written with filter type 0, and the output holds
    exactly three chunks: IHDR, one IDAT and IEND.
    """
    row_bytes = width * 4
    # Prefix each scanline with the filter-type byte 0.
    scanlines = b"".join(
        b"\x00" + rgba[row * row_bytes : (row + 1) * row_bytes] for row in range(height)
    )
    header = struct.pack(">IIBBBBB", width, height, 8, 6, 0, 0, 0)
    return b"".join(
        (
            PNG_SIGNATURE,
            _png_chunk(b"IHDR", header),
            _png_chunk(b"IDAT", zlib.compress(scanlines)),
            _png_chunk(b"IEND", b""),
        )
    )


def _png_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Serialize one PNG chunk: big-endian length, type, payload, CRC-32 of type + payload."""
    body = chunk_type + data
    crc = binascii.crc32(body) & 0xFFFFFFFF
    return b"".join((struct.pack(">I", len(data)), body, struct.pack(">I", crc)))
