from __future__ import annotations

from typing import Any


# Actions whose rows must show visible frame-to-frame change; a row for one of
# these is flagged "static_action_row" when its bbox size and area barely vary.
STRICT_ACTIONS = {"sleep", "walk", "tail_wag"}
# Minimum spread (max - min of the leg/foot foreground center x across a row)
# a walk row needs before its leg motion counts as visible.
LEG_MOTION_THRESHOLD_PX = 8.0
# Minimum foreground pixel counts before a frame's lower-body / foot foreground
# center x is trusted as a leg-contact sample.
LOWER_FOREGROUND_MIN_PIXELS = 64
FOOT_FOREGROUND_MIN_PIXELS = 32


def evaluate_action_semantics(frame_qa: dict[str, Any]) -> dict[str, Any]:
    """Judge whether each animation row in ``frame_qa`` visibly performs its action.

    ``frame_qa["rows"]`` is iterated. Each row supplies ``action``, an optional
    ``loop_mode`` and a ``frames`` list; non-dict frame entries are ignored.
    Geometry metrics from ``_row_metrics`` are tested against action-specific
    heuristics (static rows, missing walk/sleep/tail motion, upright poses,
    weak leg motion, repeated walk contact phase, hidden tail).

    Returns a dict with:
      * ``ok``: True only if every row passed;
      * ``failure_reasons``: sorted, de-duplicated report-level categories;
      * ``rows``: per-row ``action`` / ``ok`` / ``reasons`` / ``metrics``,
        plus ``loop_mode`` when it is non-empty.
    """
    row_results: list[dict[str, Any]] = []
    every_row_ok = True
    categories: set[str] = set()

    for entry in frame_qa.get("rows", []):
        action = str(entry.get("action", ""))
        loop_mode = str(entry.get("loop_mode", ""))
        usable_frames = [item for item in entry.get("frames", []) if isinstance(item, dict)]
        metrics = _row_metrics(usable_frames)
        findings = _semantic_findings(action, loop_mode, usable_frames, metrics)

        row_ok = not findings
        every_row_ok = every_row_ok and row_ok
        categories.update(category for _reason, category in findings)

        result: dict[str, Any] = {
            "action": action,
            "ok": row_ok,
            "reasons": [reason for reason, _category in findings],
            "metrics": metrics,
        }
        if loop_mode:
            result["loop_mode"] = loop_mode
        row_results.append(result)

    return {"ok": every_row_ok, "failure_reasons": sorted(categories), "rows": row_results}


def _semantic_findings(
    action: str,
    loop_mode: str,
    frames: list[dict[str, Any]],
    metrics: dict[str, Any],
) -> list[tuple[str, str]]:
    """Return ``(row_reason, report_category)`` for every failed check, in check order."""
    weak = "weak_action_semantics"
    is_walk = action == "walk"
    is_sleep = action == "sleep"
    is_tail_wag = action == "tail_wag"

    legs_static = metrics["visible_leg_contact_motion"] < LEG_MOTION_THRESHOLD_PX
    width_static = metrics["width_variation"] < 8
    center_static = metrics["center_motion"] < 8
    area_static = metrics["area_variation_ratio"] < 0.015

    findings: list[tuple[str, str]] = []
    # Walk rows with visible leg motion are not "static" even if the bbox is steady.
    if action in STRICT_ACTIONS and metrics["bbox_motion"] < 10 and area_static and (not is_walk or legs_static):
        findings.append(("static_action_row", "static_action_rows"))
    if is_walk and width_static and center_static and legs_static:
        findings.append(("walk_pose_not_visible", weak))
    if is_sleep and metrics["height_variation"] < 8 and width_static and area_static:
        findings.append(("sleep_pose_not_visible", weak))
    if is_tail_wag and width_static and center_static:
        findings.append(("tail_motion_not_visible", weak))
    if is_sleep and (metrics["median_aspect_ratio"] < 1.05 or metrics["low_wide_frame_count"] < 3):
        findings.append(("sleep_pose_too_upright", weak))
    if is_walk and (metrics["median_aspect_ratio"] < 0.9 or metrics["side_view_frame_count"] < 3):
        findings.append(("walk_pose_too_upright", weak))
    if is_walk and metrics["visible_leg_contact_frame_count"] >= 3 and legs_static:
        findings.append(("walk_leg_motion_too_weak", weak))
    # Two-frame-copy loops repeat frames by construction, so skip the phase check.
    if is_walk and loop_mode != "two_frame_copy" and _walk_repeats_first_and_third_contact_phase(frames):
        findings.append(("walk_repeated_contact_phase", weak))
    if is_tail_wag and metrics["tail_visible_frame_count"] < 3:
        findings.append(("tail_not_visible_in_most_frames", weak))
    return findings


def score_candidate_frame(*, action: str, frame_report: dict[str, Any], candidate_index: int) -> float:
    """Score one candidate frame for ``action``; higher scores rank first.

    The score rewards foreground pixel count (capped at 50 points). It penalizes
    distance from the action's target width/height ratio (``_target_ratio``),
    each extra connected component, and a bbox fill ratio below 0.35. It also
    subtracts ``candidate_index * 0.01`` so earlier candidates win ties.

    A missing or empty ``source_bbox`` is treated as ``[0, 0, 0, 0]``. A present
    bbox without exactly 4 entries returns -1.0 immediately.
    """
    bbox = frame_report.get("source_bbox") or [0, 0, 0, 0]
    if len(bbox) != 4:
        return -1.0
    # Bbox edges are inclusive, hence the +1.
    width = max(0, int(bbox[2]) - int(bbox[0]) + 1)
    height = max(0, int(bbox[3]) - int(bbox[1]) + 1)
    pixels = max(0, int(frame_report.get("foreground_pixels") or 0))
    component_count = max(1, int(frame_report.get("component_count") or 1))
    # Only a *missing* fill ratio defaults to 1.0 (no penalty). An explicit 0.0
    # means an empty bbox and must be penalized; `or 1.0` silently skipped it.
    raw_fill_ratio = frame_report.get("bbox_fill_ratio")
    fill_ratio = 1.0 if raw_fill_ratio is None else float(raw_fill_ratio)

    score = 0.0
    aspect_ratio = width / max(height, 1)
    score += min(pixels / 10000.0, 50.0)
    score -= abs(aspect_ratio - _target_ratio(action)) * 10.0
    # Real fur, whiskers, and fine toes can split into a few foreground
    # components after chroma-key cleanup. Hard rejection handles truly broken
    # fragmentation; ranking should not strongly prefer flat vector art.
    score -= max(0, component_count - 1) * 4.0
    if fill_ratio < 0.35:
        score -= (0.35 - fill_ratio) * 50.0
    score -= candidate_index * 0.01
    return score


def _target_ratio(action: str) -> float:
    """Return the preferred bbox width/height ratio used when scoring ``action`` frames.

    ``sleep``, ``walk`` and ``tail_wag`` target 3.0, ``sit`` targets 0.8, and
    every other action targets 1.0.
    """
    if action == "sit":
        return 0.8
    return 3.0 if action in {"sleep", "walk", "tail_wag"} else 1.0


def _row_metrics(frames: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarize how one row's frames vary in bbox geometry and leg/foot position.

    Geometry metrics use only frames whose ``source_bbox`` is a 4-element list
    ``[x0, y0, x1, y1]`` with inclusive edges (width = x1 - x0 + 1). Malformed
    boxes are skipped rather than raising ``IndexError``; this matches the
    length checks in ``score_candidate_frame`` and
    ``_walk_repeats_first_and_third_contact_phase``. When no usable box exists,
    every metric is zero.

    Leg-contact metrics use the lower-body and foot foreground center x of
    frames whose pixel counts reach ``LOWER_FOREGROUND_MIN_PIXELS`` and
    ``FOOT_FOREGROUND_MIN_PIXELS`` respectively. "Motion" is the max - min
    spread of those centers. The ``visible_leg_contact_*`` values take the
    larger of the two signals.
    """
    boxes = [
        box
        for box in (frame.get("source_bbox") for frame in frames)
        if isinstance(box, list) and len(box) == 4
    ]
    if not boxes:
        return {
            "bbox_motion": 0,
            "center_motion": 0,
            "width_variation": 0,
            "height_variation": 0,
            "area_variation_ratio": 0.0,
            "median_aspect_ratio": 0.0,
            "min_aspect_ratio": 0.0,
            "max_aspect_ratio": 0.0,
            "low_wide_frame_count": 0,
            "side_view_frame_count": 0,
            "tail_visible_frame_count": 0,
            "lower_foreground_center_frame_count": 0,
            "lower_foreground_center_motion": 0,
            "foot_foreground_center_frame_count": 0,
            "foot_foreground_center_motion": 0,
            "visible_leg_contact_frame_count": 0,
            "visible_leg_contact_motion": 0,
        }
    widths = [max(0, int(box[2]) - int(box[0]) + 1) for box in boxes]
    heights = [max(0, int(box[3]) - int(box[1]) + 1) for box in boxes]
    centers_x = [(int(box[0]) + int(box[2])) / 2 for box in boxes]
    centers_y = [(int(box[1]) + int(box[3])) / 2 for box in boxes]
    areas = [width * height for width, height in zip(widths, heights)]
    aspect_ratios = [width / max(height, 1) for width, height in zip(widths, heights)]
    # Only trust a leg/foot center when enough foreground pixels back it.
    lower_centers = [
        float(frame["lower_foreground_center_x"])
        for frame in frames
        if frame.get("lower_foreground_center_x") is not None
        and int(frame.get("lower_foreground_pixel_count") or 0) >= LOWER_FOREGROUND_MIN_PIXELS
    ]
    foot_centers = [
        float(frame["foot_foreground_center_x"])
        for frame in frames
        if frame.get("foot_foreground_center_x") is not None
        and int(frame.get("foot_foreground_pixel_count") or 0) >= FOOT_FOREGROUND_MIN_PIXELS
    ]
    lower_motion = max(lower_centers) - min(lower_centers) if lower_centers else 0
    foot_motion = max(foot_centers) - min(foot_centers) if foot_centers else 0
    # `areas` is non-empty here; the max(..., 1) below only guards zero-area boxes.
    max_area = max(areas)
    min_area = min(areas)
    return {
        "bbox_motion": max(max(widths) - min(widths), max(heights) - min(heights)),
        "center_motion": max(max(centers_x) - min(centers_x), max(centers_y) - min(centers_y)),
        "width_variation": max(widths) - min(widths),
        "height_variation": max(heights) - min(heights),
        "area_variation_ratio": (max_area - min_area) / max(max_area, 1),
        "median_aspect_ratio": _median(aspect_ratios),
        "min_aspect_ratio": min(aspect_ratios),
        "max_aspect_ratio": max(aspect_ratios),
        "low_wide_frame_count": sum(1 for ratio in aspect_ratios if ratio >= 1.05),
        "side_view_frame_count": sum(1 for ratio in aspect_ratios if ratio >= 0.9),
        "tail_visible_frame_count": sum(1 for ratio in aspect_ratios if ratio >= 0.75),
        "lower_foreground_center_frame_count": len(lower_centers),
        "lower_foreground_center_motion": lower_motion,
        "foot_foreground_center_frame_count": len(foot_centers),
        "foot_foreground_center_motion": foot_motion,
        "visible_leg_contact_frame_count": max(len(lower_centers), len(foot_centers)),
        "visible_leg_contact_motion": max(lower_motion, foot_motion),
    }


def _walk_repeats_first_and_third_contact_phase(frames: list[dict[str, Any]]) -> bool:
    """Return True when walk frames 0 and 2 look like the same contact pose.

    Returns False with fewer than three frames, or when either frame lacks a
    4-element ``source_bbox`` list.

    If every bbox edge moves by at most 4 px, the frames are a repeat unless
    both contact centers (see ``_walk_contact_center_x``) are known and differ
    by more than 3 px. Otherwise both contact centers must be known, and the
    frames count as a repeat only if the contacts are within 3 px, the bbox
    center x is within 8 px, and width and height are each within 12 px.
    """
    if len(frames) < 3:
        return False
    frame_a, frame_c = frames[0], frames[2]
    box_a = frame_a.get("source_bbox")
    box_c = frame_c.get("source_bbox")
    if not (isinstance(box_a, list) and isinstance(box_c, list)):
        return False
    if len(box_a) != 4 or len(box_c) != 4:
        return False

    contact_a = _walk_contact_center_x(frame_a)
    contact_c = _walk_contact_center_x(frame_c)
    contact_gap = (
        None
        if contact_a is None or contact_c is None
        else abs(float(contact_a) - float(contact_c))
    )

    largest_edge_shift = max(abs(int(p) - int(q)) for p, q in zip(box_a, box_c))
    if largest_edge_shift <= 4:
        # Near-identical boxes: a repeat unless contacts are known to differ.
        return contact_gap is None or contact_gap <= 3.0
    if contact_gap is None:
        # Boxes differ and there is no contact evidence: treat as distinct.
        return False

    ax0, ay0, ax1, ay1 = (int(v) for v in box_a)
    cx0, cy0, cx1, cy1 = (int(v) for v in box_c)
    # The +1 of inclusive widths/heights cancels in these differences.
    width_gap = abs((ax1 - ax0) - (cx1 - cx0))
    height_gap = abs((ay1 - ay0) - (cy1 - cy0))
    center_gap = abs((ax0 + ax1) / 2 - (cx0 + cx1) / 2)
    return contact_gap <= 3.0 and center_gap <= 8.0 and width_gap <= 12 and height_gap <= 12


def _walk_contact_center_x(frame: dict[str, Any]) -> float | None:
    """Return the best-supported leg-contact center x for ``frame``, or None.

    The foot foreground center is preferred when its pixel count reaches
    ``FOOT_FOREGROUND_MIN_PIXELS``. Otherwise the lower-body foreground center
    is used if its count reaches ``LOWER_FOREGROUND_MIN_PIXELS``.
    """
    foot_x = frame.get("foot_foreground_center_x")
    if foot_x is not None:
        foot_pixels = int(frame.get("foot_foreground_pixel_count") or 0)
        if foot_pixels >= FOOT_FOREGROUND_MIN_PIXELS:
            return float(foot_x)

    lower_x = frame.get("lower_foreground_center_x")
    if lower_x is not None:
        lower_pixels = int(frame.get("lower_foreground_pixel_count") or 0)
        if lower_pixels >= LOWER_FOREGROUND_MIN_PIXELS:
            return float(lower_x)

    return None


def _median(values: list[float]) -> float:
    """Return the median of ``values``, or 0.0 when it is empty.

    For an even count the result is the mean of the two middle values.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    half, odd = divmod(len(ordered), 2)
    return ordered[half] if odd else (ordered[half - 1] + ordered[half]) / 2
