from __future__ import annotations

import binascii
import struct
import zlib
from collections import deque
from dataclasses import dataclass
from typing import Any

from packages.pet_package_schema import ALLOWED_ACTIONS, DEFAULT_FRAMES_PER_ACTION


# Pixel dimensions of a single atlas cell (one frame of one action).
CELL_WIDTH = 192
CELL_HEIGHT = 208
# Number of frame columns in every action row; taken from the package schema default.
FRAMES_PER_ACTION = DEFAULT_FRAMES_PER_ACTION
# Leading bytes checked by the PNG reader and emitted first by the PNG writer.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# Upper bound applied to the per-frame scale computed by row normalization.
MAX_ROW_NORMALIZE_UPSCALE = 1.35


@dataclass
class ActionAtlasResult:
    """Outcome of composing an action atlas.

    ``image_bytes`` is the encoded RGBA PNG of the atlas. ``qa`` is the report
    dict built alongside it: the overall ``ok`` flag, the first failure
    reason (or ``None``), the cell geometry and the per-row frame reports.
    """

    image_bytes: bytes
    qa: dict[str, Any]


def compose_action_atlas(action_images: dict[str, bytes]) -> ActionAtlasResult:
    """Build the action atlas from one source image per action.

    Every action in ``ALLOWED_ACTIONS`` owns one atlas row. The action's image
    is split into connected foreground components, and the first
    ``FRAMES_PER_ACTION`` of them are fitted into cells and pasted left to
    right. A row with fewer components than required (including a missing or
    empty image) marks the result as failed with ``"not_enough_frames"``;
    whatever components were found are still pasted.

    Args:
        action_images: PNG bytes keyed by action name.

    Returns:
        The encoded atlas PNG together with the QA report dict.
    """
    atlas_width = FRAMES_PER_ACTION * CELL_WIDTH
    atlas_height = len(ALLOWED_ACTIONS) * CELL_HEIGHT
    canvas = [(0, 0, 0, 0)] * (atlas_width * atlas_height)
    row_reports: list[dict[str, Any]] = []
    pasted_total = 0
    short_row_seen = False

    for row_index, action in enumerate(ALLOWED_ACTIONS):
        strip_bytes = action_images.get(action)
        found = _extract_components(strip_bytes) if strip_bytes else []
        if len(found) < FRAMES_PER_ACTION:
            short_row_seen = True
        chosen = found[:FRAMES_PER_ACTION]
        pasted_total += len(chosen)

        per_frame = []
        for column, part in enumerate(chosen):
            cell = _fit_component_to_cell(part["pixels"], part["width"], part["height"])
            _paste_cell(canvas, atlas_width, row_index, column, cell)
            report = {
                "frame": column,
                "source_bbox": part["bbox"],
                "foreground_pixels": part["foreground_pixels"],
            }
            report.update(_lower_foreground_metrics(part))
            per_frame.append(report)

        row_reports.append(
            {
                "action": action,
                "component_count": len(found),
                "selected_frames": len(chosen),
                "frames": per_frame,
            }
        )

    # The only failure this composer can report is a row that came up short.
    qa = {
        "ok": not short_row_seen,
        "failure_reason": "not_enough_frames" if short_row_seen else None,
        "cell_width": CELL_WIDTH,
        "cell_height": CELL_HEIGHT,
        "frames_per_action": FRAMES_PER_ACTION,
        "action_count": len(ALLOWED_ACTIONS),
        "selected_frame_count": pasted_total,
        "rows": row_reports,
    }
    encoded = _write_rgba_png(atlas_width, atlas_height, canvas)
    return ActionAtlasResult(image_bytes=encoded, qa=qa)


def compose_action_frame_atlas(
    action_frames: dict[str, list[bytes]],
    *,
    allow_stable_pet_parts: bool = False,
) -> ActionAtlasResult:
    """Build the action atlas from pre-split frame images.

    Each action in ``ALLOWED_ACTIONS`` owns one row, and only its first
    ``FRAMES_PER_ACTION`` frame images are used. Every frame is reduced to a
    single component by ``_merge_frame_components``. All usable frames of a
    row are then scaled with shared row-normalized scales and pasted into the
    cell matching their frame index.

    The result is marked failed, recording the first reason encountered, when
    a row has too few frames (``"not_enough_frames"``), a frame has no
    foreground (``"missing_frame_foreground"``), or a frame's foreground is
    fragmented (``"fragmented_frame_foreground"``). Fragmented frames are
    still pasted; frames without foreground leave their cell transparent.

    Args:
        action_frames: Per-action list of frame PNG bytes.
        allow_stable_pet_parts: Forwarded to ``_merge_frame_components``.

    Returns:
        The encoded atlas PNG together with the QA report dict.
    """
    atlas_width = FRAMES_PER_ACTION * CELL_WIDTH
    atlas_height = len(ALLOWED_ACTIONS) * CELL_HEIGHT
    canvas = [(0, 0, 0, 0)] * (atlas_width * atlas_height)
    row_reports: list[dict[str, Any]] = []
    pasted_total = 0
    # Failure reasons in the order they were hit; the first one is reported.
    failures: list[str] = []

    for row_index, action in enumerate(ALLOWED_ACTIONS):
        supplied = list(action_frames.get(action, []))
        if len(supplied) < FRAMES_PER_ACTION:
            failures.append("not_enough_frames")
        chosen = supplied[:FRAMES_PER_ACTION]

        # Pass 1: analyse every frame; pair its report with the component to
        # draw (None when the frame had no foreground at all).
        staged: list[tuple[dict[str, Any], dict[str, Any] | None]] = []
        for frame_index, frame_bytes in enumerate(chosen):
            found = _extract_components(frame_bytes)
            if not found:
                failures.append("missing_frame_foreground")
                placeholder = {
                    "frame": frame_index,
                    "component_count": 0,
                    "source_bbox": None,
                    "foreground_pixels": 0,
                }
                staged.append((placeholder, None))
                continue
            merged, merge_report = _merge_frame_components(
                found,
                action=action,
                allow_stable_pet_parts=allow_stable_pet_parts,
            )
            if merge_report["fragmented_foreground"]:
                failures.append("fragmented_frame_foreground")
            bbox_area = max(1, merged["width"] * merged["height"])
            report = {
                "frame": frame_index,
                "component_count": len(found),
                "accepted_component_count": merge_report["accepted_component_count"],
                "fragmented_foreground": merge_report["fragmented_foreground"],
                "stable_pet_parts_merged": merge_report.get("stable_pet_parts_merged", False),
                "source_bbox": merged["bbox"],
                "foreground_pixels": merged["foreground_pixels"],
                "bbox_fill_ratio": merged["foreground_pixels"] / bbox_area,
            }
            report.update(_lower_foreground_metrics(merged))
            staged.append((report, merged))

        # Pass 2: scale the drawable frames together and paste them.
        drawable = [component for _report, component in staged if component is not None]
        scale_source = iter(_row_normalized_fit_scales(drawable))
        mode = "row_normalized_area" if len(drawable) > 1 else "fit_to_cell"
        frame_reports = []
        for report, component in staged:
            if component is None:
                frame_reports.append(report)
                continue
            cell, fit_report = _fit_component_to_cell_with_report(
                component["pixels"],
                component["width"],
                component["height"],
                scale=next(scale_source, None),
                fit_mode=mode,
            )
            _paste_cell(canvas, atlas_width, row_index, int(report["frame"]), cell)
            pasted_total += 1
            frame_reports.append({**report, **fit_report})

        row_reports.append(
            {
                "action": action,
                "input_frames": len(supplied),
                "selected_frames": len(chosen),
                "frames": frame_reports,
            }
        )

    qa = {
        "ok": not failures,
        "failure_reason": failures[0] if failures else None,
        "cell_width": CELL_WIDTH,
        "cell_height": CELL_HEIGHT,
        "frames_per_action": FRAMES_PER_ACTION,
        "action_count": len(ALLOWED_ACTIONS),
        "selected_frame_count": pasted_total,
        "rows": row_reports,
    }
    encoded = _write_rgba_png(atlas_width, atlas_height, canvas)
    return ActionAtlasResult(image_bytes=encoded, qa=qa)


def extract_action_strip_frames(
    action: str,
    image_bytes: bytes,
    *,
    allow_stable_pet_parts: bool = False,
) -> dict[str, Any]:
    """Split a horizontal action strip into per-frame cell PNGs.

    The strip is cut into ``FRAMES_PER_ACTION`` slots of equal width (the last
    slot also takes any remainder columns). Each slot's foreground is merged
    into one component by ``_merge_frame_components``. A slot is selected only
    when its foreground is not fragmented and passes
    ``_slot_component_is_usable``. Selected components are fitted into a cell
    and encoded as individual PNGs.

    Args:
        action: Action name, forwarded to the component merge step.
        image_bytes: PNG bytes of the whole strip.
        allow_stable_pet_parts: Forwarded to ``_merge_frame_components``.

    Returns:
        A dict with ``ok`` (true only when every slot was selected), the
        encoded ``frames``, one entry in ``reports`` per encoded frame, one
        entry in ``slot_reports`` per slot, and ``failure_reason``
        (``"slot_foreground_invalid"`` when not ok).

    Raises:
        ValueError: Propagated from PNG decoding or component merging.
    """
    image_width, image_height, pixels = _read_rgba_png(image_bytes)
    slot_width = max(1, image_width // FRAMES_PER_ACTION)
    global_component_count = len(_extract_components_from_pixels(pixels, image_width, image_height))
    # Near-full-width components are only tolerated when the strip as a whole
    # already contains at least one component per expected frame.
    allow_near_full_slot_width = global_component_count >= FRAMES_PER_ACTION
    # Each selected component is kept together with the report of the slot it
    # came from, so per-frame reports stay correct when earlier slots were
    # empty or rejected.
    selected: list[tuple[dict[str, Any], dict[str, Any]]] = []
    slot_reports: list[dict[str, Any]] = []
    for frame_index in range(FRAMES_PER_ACTION):
        left = frame_index * slot_width
        right = image_width if frame_index == FRAMES_PER_ACTION - 1 else (frame_index + 1) * slot_width
        slot_pixels = _crop_pixels(pixels, image_width, left, 0, right - left, image_height)
        components = _extract_components_from_pixels(slot_pixels, right - left, image_height)
        if not components:
            slot_reports.append({"frame": frame_index, "component_count": 0, "foreground_pixels": 0})
            continue
        component, component_report = _merge_frame_components(
            components,
            action=action,
            allow_stable_pet_parts=allow_stable_pet_parts,
        )
        # Copy before editing so the extracted component dict is not mutated,
        # then express the bbox in strip coordinates rather than slot-local.
        component = dict(component)
        component["bbox"] = [
            int(component["bbox"][0]) + left,
            int(component["bbox"][1]),
            int(component["bbox"][2]) + left,
            int(component["bbox"][3]),
        ]
        slot_report = {
            "frame": frame_index,
            "component_count": len(components),
            "accepted_component_count": component_report["accepted_component_count"],
            "fragmented_foreground": component_report["fragmented_foreground"],
            "stable_pet_parts_merged": component_report.get("stable_pet_parts_merged", False),
            "foreground_pixels": component["foreground_pixels"],
            "source_bbox": component["bbox"],
            **_lower_foreground_metrics(component),
        }
        slot_reports.append(slot_report)
        if not component_report["fragmented_foreground"] and _slot_component_is_usable(
            component,
            right - left,
            image_height,
            allow_near_full_width=allow_near_full_slot_width,
        ):
            selected.append((component, slot_report))
    frames: list[bytes] = []
    reports: list[dict[str, Any]] = []
    ok = len(selected) == FRAMES_PER_ACTION
    for frame_index, (component, slot_report) in enumerate(selected):
        fitted = _fit_component_to_cell(component["pixels"], component["width"], component["height"])
        frames.append(_write_rgba_png(CELL_WIDTH, CELL_HEIGHT, fitted))
        reports.append(
            {
                # "frame" is the position in the returned ``frames`` list.
                "frame": frame_index,
                "component_count": int(slot_report.get("component_count", 0)),
                "accepted_component_count": int(slot_report.get("accepted_component_count", 0)),
                "fragmented_foreground": bool(slot_report.get("fragmented_foreground", False)),
                "stable_pet_parts_merged": bool(slot_report.get("stable_pet_parts_merged", False)),
                "source_bbox": component["bbox"],
                "foreground_pixels": component["foreground_pixels"],
                **_lower_foreground_metrics(component),
            }
        )
    return {
        "ok": ok,
        "action": action,
        "component_count": sum(int(report.get("component_count", 0)) for report in slot_reports),
        "selected_frames": len(selected),
        "frames": frames,
        "reports": reports,
        "slot_reports": slot_reports,
        "failure_reason": None if ok else "slot_foreground_invalid",
    }


def _extract_components(image_bytes: bytes | None) -> list[dict[str, Any]]:
    if not image_bytes:
        return []
    width, height, pixels = _read_rgba_png(image_bytes)
    return _extract_components_from_pixels(pixels, width, height)


def _merge_frame_components(
    components: list[dict[str, Any]],
    *,
    action: str = "",
    allow_stable_pet_parts: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Collapse one frame's components into a single component plus a verdict.

    Components holding at least ``max(24, 1% of all foreground pixels)`` are
    "significant"; if none qualify, the largest component alone is used. The
    largest significant component is the "main" one.

    The frame is reported as fragmented, and only the main component is
    returned, when another significant component is a cut-off edge fragment,
    is a side sliver (checked only with ``allow_stable_pet_parts``), is
    detached from the main component, or when the main component holds less
    than 68% of the significant pixels. The last two cases are forgiven when
    ``allow_stable_pet_parts`` is set and the components look like stable pet
    parts. Whenever the frame is not fragmented, all significant components
    are merged and returned.

    Returns:
        ``(component, report)`` where ``report`` carries
        ``accepted_component_count``, ``fragmented_foreground`` and
        ``stable_pet_parts_merged``.

    Raises:
        ValueError: If ``components`` is empty.
    """

    def pixel_count(item: dict[str, Any]) -> int:
        return int(item.get("foreground_pixels") or 0)

    total_pixels = sum(max(0, pixel_count(item)) for item in components)
    if not components:
        raise ValueError("components are required")

    threshold = max(24, int(total_pixels * 0.01))
    significant = [item for item in components if pixel_count(item) >= threshold]
    if not significant:
        significant = [max(components, key=pixel_count)]
    main = max(significant, key=pixel_count)
    others = [item for item in significant if item is not main]

    def verdict(*, fragmented: bool, stable_merge: bool = False) -> tuple[dict[str, Any], dict[str, Any]]:
        # Fragmented frames keep only the main component; everything else is
        # merged into a single component.
        report = {
            "accepted_component_count": len(significant),
            "fragmented_foreground": fragmented,
            "stable_pet_parts_merged": stable_merge,
        }
        chosen = main if fragmented else _merge_components(significant)
        return chosen, report

    def stable_parts_ok() -> bool:
        return bool(
            allow_stable_pet_parts
            and _components_look_like_stable_pet_parts(action, significant, main, total_pixels)
        )

    has_edge_fragment = any(
        _component_is_cutoff_edge_fragment(main, item, total_pixels) for item in others
    )
    has_side_sliver = allow_stable_pet_parts and any(
        _component_is_side_sliver_fragment(main, item, total_pixels) for item in others
    )
    if has_edge_fragment or has_side_sliver:
        return verdict(fragmented=True)

    if any(_component_is_detached_from_main(main, item, total_pixels) for item in others):
        if stable_parts_ok():
            return verdict(fragmented=False, stable_merge=True)
        return verdict(fragmented=True)

    significant_total = sum(pixel_count(item) for item in significant)
    main_share = pixel_count(main) / max(1, significant_total)
    if len(significant) > 1 and main_share < 0.68:
        if stable_parts_ok():
            return verdict(fragmented=False, stable_merge=True)
        return verdict(fragmented=True)

    return verdict(fragmented=False)


def _component_is_detached_from_main(main: dict[str, Any], component: dict[str, Any], total_pixels: int) -> bool:
    """Tell whether a sizeable component sits clearly apart from the main one.

    Components smaller than ``max(100, 1% of total_pixels)`` are never
    considered detached. Otherwise the component is detached when the gap
    between the two bounding boxes exceeds 16 pixels on either axis.
    """
    size = int(component.get("foreground_pixels") or 0)
    smallest_relevant = max(100, int(total_pixels * 0.01))
    if size < smallest_relevant:
        return False

    anchor = list(map(int, main.get("bbox", [0, 0, 0, 0])))
    box = list(map(int, component.get("bbox", [0, 0, 0, 0])))
    # Gap = empty pixels between the nearer edges; overlapping boxes give 0.
    horizontal_gap = max(0, max(anchor[0], box[0]) - min(anchor[2], box[2]) - 1)
    vertical_gap = max(0, max(anchor[1], box[1]) - min(anchor[3], box[3]) - 1)
    return max(horizontal_gap, vertical_gap) > 16


def _component_is_cutoff_edge_fragment(main: dict[str, Any], component: dict[str, Any], total_pixels: int) -> bool:
    """Tell whether a small component near the canvas edge is separated from the main one.

    Only components holding between ``max(80, 1%)`` and ``max(1, 8%)`` of
    ``total_pixels`` (inclusive) are considered. Such a component is flagged
    when its bounding box comes within 16 pixels of any edge of its source
    canvas and it is more than 6 pixels away from the main bounding box on
    either axis.
    """
    size = int(component.get("foreground_pixels") or 0)
    smallest = max(80, int(total_pixels * 0.01))
    largest = max(1, int(total_pixels * 0.08))
    if size < smallest or size > largest:
        return False

    anchor = list(map(int, main.get("bbox", [0, 0, 0, 0])))
    box = list(map(int, component.get("bbox", [0, 0, 0, 0])))
    canvas_width = max(1, int(component.get("source_width") or 0))
    canvas_height = max(1, int(component.get("source_height") or 0))
    margin = 16
    hugs_edge = any(
        (
            box[0] <= margin,
            box[2] >= canvas_width - margin - 1,
            box[1] <= margin,
            box[3] >= canvas_height - margin - 1,
        )
    )
    if not hugs_edge:
        return False

    horizontal_gap = max(0, max(anchor[0], box[0]) - min(anchor[2], box[2]) - 1)
    vertical_gap = max(0, max(anchor[1], box[1]) - min(anchor[3], box[3]) - 1)
    return max(horizontal_gap, vertical_gap) > 6


def _component_is_side_sliver_fragment(main: dict[str, Any], component: dict[str, Any], total_pixels: int) -> bool:
    """Tell whether a small component is a thin sliver beside the lower part of the main one.

    Only components holding between ``max(80, 1%)`` and ``max(1, 8%)`` of
    ``total_pixels`` (inclusive), no wider than 24% and no taller than 42% of
    the main bounding box, are considered. The component is flagged when all
    of the following hold: it lies in a side band of the main box, its centre
    is below 55% of the main box height, and it has little horizontal contact
    with the main box.
    """
    size = int(component.get("foreground_pixels") or 0)
    if not max(80, int(total_pixels * 0.01)) <= size <= max(1, int(total_pixels * 0.08)):
        return False

    anchor = list(map(int, main.get("bbox", [0, 0, 0, 0])))
    box = list(map(int, component.get("bbox", [0, 0, 0, 0])))
    anchor_width = max(1, anchor[2] - anchor[0] + 1)
    anchor_height = max(1, anchor[3] - anchor[1] + 1)
    box_width = max(1, box[2] - box[0] + 1)
    box_height = max(1, box[3] - box[1] + 1)
    if box_width > anchor_width * 0.24 or box_height > anchor_height * 0.42:
        return False

    shared_columns = max(0, min(anchor[2], box[2]) - max(anchor[0], box[0]) + 1)
    horizontal_gap = max(0, max(anchor[0], box[0]) - min(anchor[2], box[2]) - 1)
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2

    band = anchor_width * 0.18
    in_side_band = (
        center_x < anchor[0] + band
        or center_x > anchor[2] - band
        or box[2] <= anchor[0] + 6
        or box[0] >= anchor[2] - 6
    )
    if not in_side_band:
        return False
    if center_y <= anchor[1] + anchor_height * 0.55:
        return False
    return horizontal_gap > 6 or shared_columns <= max(6, int(box_width * 0.25))


def _components_look_like_stable_pet_parts(
    action: str,
    components: list[dict[str, Any]],
    main: dict[str, Any],
    total_pixels: int,
) -> bool:
    """Decide whether several components may be treated as parts of one pet.

    Requires a known action, between two and four components, and a main
    component holding at least 72% of their combined pixels. Every other
    component must hold at most 18% of ``total_pixels``, fill at most 92% of
    its own bounding box, and stay within ``max(72, 28% of the main box
    size)`` pixels of the main box on each axis.
    """
    if action not in ALLOWED_ACTIONS or not 2 <= len(components) <= 4:
        return False

    def pixel_count(item: dict[str, Any]) -> int:
        return int(item.get("foreground_pixels") or 0)

    combined = max(1, sum(pixel_count(item) for item in components))
    if pixel_count(main) / combined < 0.72:
        return False

    anchor = list(map(int, main.get("bbox", [0, 0, 0, 0])))
    anchor_width = max(1, anchor[2] - anchor[0] + 1)
    anchor_height = max(1, anchor[3] - anchor[1] + 1)
    widest_gap_x = max(72, int(anchor_width * 0.28))
    widest_gap_y = max(72, int(anchor_height * 0.28))

    for part in components:
        if part is main:
            continue
        if pixel_count(part) / max(1, total_pixels) > 0.18:
            return False
        if _component_fill_ratio(part) > 0.92:
            return False
        box = list(map(int, part.get("bbox", [0, 0, 0, 0])))
        horizontal_gap = max(0, max(anchor[0], box[0]) - min(anchor[2], box[2]) - 1)
        vertical_gap = max(0, max(anchor[1], box[1]) - min(anchor[3], box[3]) - 1)
        if horizontal_gap > widest_gap_x or vertical_gap > widest_gap_y:
            return False
    return True


def _component_fill_ratio(component: dict[str, Any]) -> float:
    """Return foreground pixels divided by bounding-box area.

    Missing or zero dimensions are treated as 1, so the division is always
    defined.
    """
    box_width = max(1, int(component.get("width") or 1))
    box_height = max(1, int(component.get("height") or 1))
    filled = int(component.get("foreground_pixels") or 0)
    area = max(1, box_width * box_height)
    return filled / area


def _merge_components(components: list[dict[str, Any]]) -> dict[str, Any]:
    """Composite several components into one covering their joint bounding box.

    Components are drawn in list order, so where opaque pixels overlap the
    later component wins. Fully transparent pixels are never copied. The
    returned ``foreground_pixels`` counts each covered position once.
    """
    left = min(int(item["bbox"][0]) for item in components)
    top = min(int(item["bbox"][1]) for item in components)
    right = max(int(item["bbox"][2]) for item in components)
    bottom = max(int(item["bbox"][3]) for item in components)
    width = right - left + 1
    height = bottom - top + 1

    canvas = [(0, 0, 0, 0)] * (width * height)
    covered = 0
    for item in components:
        item_width = int(item["width"])
        item_height = int(item["height"])
        shift_x = int(item["bbox"][0]) - left
        shift_y = int(item["bbox"][1]) - top
        source = item["pixels"]
        for local_index in range(item_width * item_height):
            pixel = source[local_index]
            if pixel[3] == 0:
                continue
            local_y, local_x = divmod(local_index, item_width)
            target = (shift_y + local_y) * width + shift_x + local_x
            if canvas[target][3] == 0:
                covered += 1
            canvas[target] = pixel

    return {
        "bbox": [left, top, right, bottom],
        "width": width,
        "height": height,
        "pixels": canvas,
        "foreground_pixels": covered,
    }


def _lower_foreground_metrics(component: dict[str, Any]) -> dict[str, Any]:
    """Measure where the foreground sits horizontally in the lower part of a component.

    Two bands are measured: "lower" starts at 58% of the component height and
    "foot" starts at 78%; both extend to the bottom row. For each band the
    mean x of non-transparent pixels and their count are reported. A
    component without positive dimensions or without a pixel list yields
    ``None`` centres and zero counts.
    """
    width = max(0, int(component.get("width") or 0))
    height = max(0, int(component.get("height") or 0))
    pixels = component.get("pixels")

    measurable = width > 0 and height > 0 and isinstance(pixels, list)
    if measurable:
        lower = _foreground_center_x_from_y(pixels, width, height, int(height * 0.58))
        foot = _foreground_center_x_from_y(pixels, width, height, int(height * 0.78))
    else:
        lower = foot = (None, 0)

    return {
        "lower_foreground_center_x": lower[0],
        "lower_foreground_pixel_count": lower[1],
        "foot_foreground_center_x": foot[0],
        "foot_foreground_pixel_count": foot[1],
    }


def _foreground_center_x_from_y(
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
    start_y: int,
) -> tuple[float | None, int]:
    x_total = 0.0
    count = 0
    for y in range(max(0, start_y), height):
        row_start = y * width
        for x in range(width):
            pixel = pixels[row_start + x]
            if not isinstance(pixel, tuple) or len(pixel) < 4 or int(pixel[3]) == 0:
                continue
            x_total += x
            count += 1
    return (x_total / count) if count else None, count


def _extract_components_from_pixels(
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
) -> list[dict[str, Any]]:
    """Find 4-connected groups of non-transparent pixels.

    A pixel is foreground when its alpha is above zero. Groups with fewer than
    16 pixels are dropped. Each remaining group is returned as a dict holding
    its inclusive ``bbox`` (``[min_x, min_y, max_x, max_y]``), the box
    ``width``/``height``, the ``source_width``/``source_height`` of the
    scanned image, a ``pixels`` crop of the box containing only that group's
    pixels, and ``foreground_pixels``. The list is ordered by the box's top
    edge, then its left edge.
    """
    is_foreground = [pixel[3] > 0 for pixel in pixels]
    seen = [False] * (width * height)
    found: list[dict[str, Any]] = []

    for seed, opaque in enumerate(is_foreground):
        if not opaque or seen[seed]:
            continue

        # Breadth-first flood fill from the seed pixel.
        frontier: deque[int] = deque([seed])
        seen[seed] = True
        members: list[int] = []
        left, top, right, bottom = width, height, 0, 0
        while frontier:
            here = frontier.popleft()
            members.append(here)
            row, column = divmod(here, width)
            left = min(left, column)
            top = min(top, row)
            right = max(right, column)
            bottom = max(bottom, row)
            neighbours = (
                (column - 1, row),
                (column + 1, row),
                (column, row - 1),
                (column, row + 1),
            )
            for near_x, near_y in neighbours:
                if not (0 <= near_x < width and 0 <= near_y < height):
                    continue
                flat = near_y * width + near_x
                if seen[flat] or not is_foreground[flat]:
                    continue
                seen[flat] = True
                frontier.append(flat)

        if len(members) < 16:
            # Too small to be meaningful; treated as noise.
            continue

        box_width = right - left + 1
        box_height = bottom - top + 1
        crop = [(0, 0, 0, 0)] * (box_width * box_height)
        for flat in members:
            source_row, source_column = divmod(flat, width)
            crop[(source_row - top) * box_width + (source_column - left)] = pixels[flat]

        found.append(
            {
                "bbox": [left, top, right, bottom],
                "width": box_width,
                "height": box_height,
                "source_width": width,
                "source_height": height,
                "pixels": crop,
                "foreground_pixels": len(members),
            }
        )

    found.sort(key=lambda item: (item["bbox"][1], item["bbox"][0]))
    return found


def _crop_pixels(
    pixels: list[tuple[int, int, int, int]],
    source_width: int,
    x0: int,
    y0: int,
    width: int,
    height: int,
) -> list[tuple[int, int, int, int]]:
    """Copy a ``width`` x ``height`` window starting at ``(x0, y0)`` out of a row-major pixel list.

    Rows are taken with list slicing, so no bounds checking is performed.
    """
    row_starts = (row * source_width + x0 for row in range(y0, y0 + height))
    return [pixel for start in row_starts for pixel in pixels[start : start + width]]


def _slot_component_is_usable(
    component: dict[str, Any],
    slot_width: int,
    slot_height: int,
    *,
    allow_near_full_width: bool = False,
) -> bool:
    """Check that a slot's component has a plausible size for a single frame.

    Rejected are components with fewer than 900 foreground pixels, components
    narrower or shorter than ``max(12, 12% of the slot)``, components wider
    than the slot or taller than 96% of it, and, unless
    ``allow_near_full_width`` is set, components wider than 98% of the slot.
    """
    box_width = max(0, int(component.get("width", 0)))
    box_height = max(0, int(component.get("height", 0)))
    filled = max(0, int(component.get("foreground_pixels", 0)))

    too_sparse = filled < 900
    too_small = box_width < max(12, slot_width * 0.12) or box_height < max(12, slot_height * 0.12)
    suspiciously_wide = box_width > slot_width * 0.98 and not allow_near_full_width
    overflowing = box_width > slot_width or box_height > slot_height * 0.96
    return not (too_sparse or too_small or suspiciously_wide or overflowing)


def _fit_component_to_cell(pixels: list[tuple[int, int, int, int]], width: int, height: int) -> list[tuple[int, int, int, int]]:
    """Fit a component into a cell with default settings and return only the cell pixels."""
    return _fit_component_to_cell_with_report(pixels, width, height)[0]


def _fit_component_to_cell_with_report(
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
    *,
    scale: float | None = None,
    fit_mode: str = "fit_to_cell",
) -> tuple[list[tuple[int, int, int, int]], dict[str, Any]]:
    """Scale a component and centre it inside a transparent cell.

    The scaled component never exceeds 84% of the cell in either dimension.
    Without an explicit ``scale`` the component is only ever shrunk (the scale
    is capped at 1.0); an explicit ``scale`` may enlarge it up to that 84%
    limit. The effective scale never drops below 0.01.

    Returns:
        ``(cell_pixels, report)`` where ``report`` holds the inclusive
        ``fitted_bbox`` inside the cell, the fitted size, the rounded
        ``fit_scale`` and the ``fit_mode`` label passed in.
    """
    width_ratio = int(CELL_WIDTH * 0.84) / max(width, 1)
    height_ratio = int(CELL_HEIGHT * 0.84) / max(height, 1)
    requested = 1.0 if scale is None else float(scale)
    scale = max(min(requested, width_ratio, height_ratio), 0.01)

    fitted_width = max(1, int(width * scale))
    fitted_height = max(1, int(height * scale))
    scaled = _resize_nearest(pixels, width, height, fitted_width, fitted_height)

    left = (CELL_WIDTH - fitted_width) // 2
    top = (CELL_HEIGHT - fitted_height) // 2
    cell = [(0, 0, 0, 0)] * (CELL_WIDTH * CELL_HEIGHT)
    for local_index in range(fitted_width * fitted_height):
        local_y, local_x = divmod(local_index, fitted_width)
        cell[(top + local_y) * CELL_WIDTH + left + local_x] = scaled[local_index]

    report = {
        "fitted_bbox": [left, top, left + fitted_width - 1, top + fitted_height - 1],
        "fitted_width": fitted_width,
        "fitted_height": fitted_height,
        "fit_scale": round(float(scale), 6),
        "fit_mode": fit_mode,
    }
    return cell, report


def _row_normalized_fit_scales(components: list[dict[str, Any]]) -> list[float]:
    """Compute one scale per component so the frames of a row end up with similar areas.

    The target area is the median of the areas each component would have if it
    were fitted on its own without enlarging. Each component's scale is the
    factor that reaches that target area, limited by the largest scale that
    still fits the cell and by ``MAX_ROW_NORMALIZE_UPSCALE``.
    """
    if not components:
        return []

    dims = [
        (max(1, int(item.get("width") or 1)), max(1, int(item.get("height") or 1)))
        for item in components
    ]

    fitted_areas: list[float] = []
    for box_width, box_height in dims:
        own_scale = min(_max_cell_fit_scale(width=box_width, height=box_height), 1.0)
        fitted_areas.append(box_width * box_height * own_scale * own_scale)
    goal_area = _median(fitted_areas)

    return [
        min(
            (goal_area / max(1.0, float(box_width * box_height))) ** 0.5,
            _max_cell_fit_scale(width=box_width, height=box_height),
            MAX_ROW_NORMALIZE_UPSCALE,
        )
        for box_width, box_height in dims
    ]


def _max_cell_fit_scale(*, width: int, height: int) -> float:
    """Return the largest scale at which ``width`` x ``height`` stays within 84% of the cell."""
    width_ratio = int(CELL_WIDTH * 0.84) / max(width, 1)
    height_ratio = int(CELL_HEIGHT * 0.84) / max(height, 1)
    return min(width_ratio, height_ratio)


def _median(values: list[float]) -> float:
    """Return the median of ``values``, or 0.0 for an empty list.

    An even number of values yields the mean of the two middle ones.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    half, is_odd = divmod(len(ordered), 2)
    if is_odd:
        return ordered[half]
    return (ordered[half - 1] + ordered[half]) / 2


def _resize_nearest(
    pixels: list[tuple[int, int, int, int]],
    width: int,
    height: int,
    target_width: int,
    target_height: int,
) -> list[tuple[int, int, int, int]]:
    """Resample a row-major pixel list to a new size by picking the nearest source pixel.

    When the size is unchanged a shallow copy of ``pixels`` is returned.
    """
    if (width, height) == (target_width, target_height):
        return list(pixels)

    # The source column for each target column is the same on every row.
    source_columns = [
        min(width - 1, int(x * width / target_width)) for x in range(target_width)
    ]
    resampled: list[tuple[int, int, int, int]] = []
    for y in range(target_height):
        row_base = min(height - 1, int(y * height / target_height)) * width
        resampled.extend(pixels[row_base + column] for column in source_columns)
    return resampled


def _paste_cell(
    atlas: list[tuple[int, int, int, int]],
    atlas_width: int,
    row_index: int,
    frame_index: int,
    cell: list[tuple[int, int, int, int]],
) -> None:
    """Copy a full cell into the atlas at the given row and frame column, in place.

    Every cell pixel, transparent or not, overwrites the atlas pixel at the
    matching position.
    """
    origin = row_index * CELL_HEIGHT * atlas_width + frame_index * CELL_WIDTH
    for local_index in range(CELL_WIDTH * CELL_HEIGHT):
        cell_y, cell_x = divmod(local_index, CELL_WIDTH)
        atlas[origin + cell_y * atlas_width + cell_x] = cell[local_index]


def _read_rgba_png(image_bytes: bytes) -> tuple[int, int, list[tuple[int, int, int, int]]]:
    """Decode an 8-bit, non-interlaced RGB or RGBA PNG into a flat pixel list.

    Chunks are walked until ``IEND``; only ``IHDR`` and ``IDAT`` are used and
    chunk CRCs are skipped without being verified.

    Returns:
        ``(width, height, pixels)`` with ``pixels`` in row-major order, each
        entry an ``(r, g, b, a)`` tuple. RGB images get an alpha of 255.

    Raises:
        ValueError: ``"not_png"`` when the signature is missing,
            ``"invalid_png_header"`` when the IHDR payload has the wrong
            size, ``"missing_png_header"`` when no IHDR chunk was found,
            ``"unsupported_png_format"`` for anything but 8-bit
            non-interlaced RGB/RGBA, ``"truncated_png_data"`` when the
            decompressed data is shorter than the header promises, or an
            ``unsupported_png_filter_*`` message from row unfiltering.
        zlib.error: When the IDAT stream cannot be decompressed.
    """
    if not image_bytes.startswith(PNG_SIGNATURE):
        raise ValueError("not_png")
    position = len(PNG_SIGNATURE)
    width = height = bit_depth = color_type = interlace = None
    idat_parts: list[bytes] = []
    # Each chunk is: 4-byte length, 4-byte type, payload, 4-byte CRC.
    while position + 12 <= len(image_bytes):
        length = struct.unpack(">I", image_bytes[position : position + 4])[0]
        kind = image_bytes[position + 4 : position + 8]
        data_start = position + 8
        data_end = data_start + length
        data = image_bytes[data_start:data_end]
        position = data_end + 4
        if kind == b"IHDR":
            # Report a malformed header with the same coded ValueError style
            # as every other format problem instead of leaking struct.error.
            if len(data) != 13:
                raise ValueError("invalid_png_header")
            width, height, bit_depth, color_type, _compression, _filter, interlace = struct.unpack(">IIBBBBB", data)
        elif kind == b"IDAT":
            idat_parts.append(data)
        elif kind == b"IEND":
            break
    if None in {width, height, bit_depth, color_type, interlace}:
        raise ValueError("missing_png_header")
    if bit_depth != 8 or color_type not in {2, 6} or interlace != 0:
        raise ValueError("unsupported_png_format")

    channels = 4 if color_type == 6 else 3
    stride = int(width) * channels
    raw = zlib.decompress(b"".join(idat_parts))
    # Every scanline is one filter-type byte followed by ``stride`` data bytes.
    # Checking up front avoids a bare IndexError part-way through decoding.
    if len(raw) < (stride + 1) * int(height):
        raise ValueError("truncated_png_data")
    rows: list[bytearray] = []
    previous = bytearray(stride)
    offset = 0
    for _ in range(int(height)):
        filter_type = raw[offset]
        offset += 1
        row = bytearray(raw[offset : offset + stride])
        offset += stride
        _unfilter_row(row, previous, filter_type, channels)
        rows.append(row)
        previous = row

    pixels: list[tuple[int, int, int, int]] = []
    for row in rows:
        for x in range(0, len(row), channels):
            if channels == 4:
                pixels.append((row[x], row[x + 1], row[x + 2], row[x + 3]))
            else:
                pixels.append((row[x], row[x + 1], row[x + 2], 255))
    return int(width), int(height), pixels


def _unfilter_row(row: bytearray, previous: bytearray, filter_type: int, bytes_per_pixel: int) -> None:
    """Undo a PNG scanline filter on ``row`` in place.

    ``previous`` is the already unfiltered scanline above. Filter type 0
    leaves the row untouched; types 1-4 add the left, upper, averaged or
    Paeth-selected neighbour value to each byte (modulo 256).

    Raises:
        ValueError: For any other filter type, when the row is not empty.
    """
    if filter_type == 0:
        return
    for position in range(len(row)):
        has_left = position >= bytes_per_pixel
        beside = row[position - bytes_per_pixel] if has_left else 0
        above = previous[position]
        diagonal = previous[position - bytes_per_pixel] if has_left else 0
        if filter_type == 1:
            predictor = beside
        elif filter_type == 2:
            predictor = above
        elif filter_type == 3:
            predictor = (beside + above) // 2
        elif filter_type == 4:
            predictor = _paeth(beside, above, diagonal)
        else:
            raise ValueError(f"unsupported_png_filter_{filter_type}")
        row[position] = (row[position] + predictor) & 0xFF


def _paeth(left: int, up: int, upper_left: int) -> int:
    """Return whichever neighbour is closest to ``left + up - upper_left``.

    Ties are resolved in the order left, up, upper-left.
    """
    prediction = left + up - upper_left
    # min() keeps the first of equally distant candidates, which gives the
    # left, up, upper-left tie-break order.
    candidates = (
        (abs(prediction - left), left),
        (abs(prediction - up), up),
        (abs(prediction - upper_left), upper_left),
    )
    return min(candidates, key=lambda entry: entry[0])[1]


def _write_rgba_png(width: int, height: int, pixels: list[tuple[int, int, int, int]]) -> bytes:
    """Encode row-major RGBA pixels as an 8-bit RGBA PNG.

    Every scanline is written with filter type 0 and the image data goes into
    a single IDAT chunk.
    """
    scanlines = bytearray()
    for row_number in range(height):
        # Filter-type byte for this scanline: 0 means the bytes are stored as-is.
        scanlines.append(0)
        row_begin = row_number * width
        for red, green, blue, alpha in pixels[row_begin : row_begin + width]:
            scanlines.extend((red, green, blue, alpha))

    # IHDR fields: bit depth 8, colour type 6, then compression, filter and
    # interlace all 0.
    header = struct.pack(">IIBBBBB", width, height, 8, 6, 0, 0, 0)
    chunks = (
        _png_chunk(b"IHDR", header),
        _png_chunk(b"IDAT", zlib.compress(bytes(scanlines))),
        _png_chunk(b"IEND", b""),
    )
    return PNG_SIGNATURE + chunks[0] + chunks[1] + chunks[2]


def _png_chunk(kind: bytes, data: bytes) -> bytes:
    """Serialize one PNG chunk: big-endian payload length, type, payload, CRC-32 of type plus payload."""
    body = kind + data
    length_field = struct.pack(">I", len(data))
    crc_field = struct.pack(">I", binascii.crc32(body) & 0xFFFFFFFF)
    return length_field + body + crc_field
