#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# ROOT is the parent of the directory that contains this script (resolved path).
# It is placed at the front of sys.path before the `packages.*` / `services.*`
# imports below are executed; it is also the base of the default --output path.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from packages.pet_package_schema import ALLOWED_ACTIONS
from services.ai.apimart_image_client import APIMartImageGenerationClient
from services.ai.qwen_client import QwenPhotoAnalyzer, QwenReviewClient
from services.media_pipeline.photo_inputs import prepare_photo_inputs
from services.media_pipeline.png_transparency import ensure_png_alpha
from services.pet_builder.action_atlas import (
    CELL_HEIGHT,
    CELL_WIDTH,
    FRAMES_PER_ACTION,
    _extract_components,
    _fit_component_to_cell,
    _merge_frame_components,
    _write_rgba_png,
)
from services.pet_builder.action_semantics import score_candidate_frame
from services.pet_builder.photo_generation_worker import (
    ACTION_FRAME_CHROMA_KEY,
    ACTION_FRAME_CHROMA_THRESHOLD,
    _candidate_visual_detail_score,
    _frame_rejection_reason,
    _safe_frame_report,
    _safe_transparency_report,
    _slugify,
    _transparency_rejection_reason,
)
from services.pet_builder.qa_media import write_animation_previews


def main() -> int:
    """Run a single-action APIMart probe and write its QA artifacts.

    Steps, in order:
      1. Parse CLI arguments and create the probe directory tree under ``--output``.
      2. Analyze the supplied photos and generate a canonical pet image.
      3. For each of ``FRAMES_PER_ACTION`` frames, generate up to
         ``--candidates-per-frame`` candidates (clamped to 1..3), run each
         through ``_process_candidate`` and keep the highest-scoring one.
      4. Compose a contact sheet, write animation previews, and ask the
         reviewer client to review the contact sheet.
      5. Write and print ``run-summary.json``.

    Returns:
        0 when the reviewer payload reports ``passed`` as truthy, 2 otherwise.

    Raises:
        ValueError: if a generation response carries no image bytes, or if no
            candidate for some frame is accepted.
    """
    parser = argparse.ArgumentParser(description="Probe one APIMart gpt-image-2 action row from authorized pet photos.")
    parser.add_argument("--pet-name", required=True)
    parser.add_argument("--notes", default="")
    parser.add_argument("--photo", action="append", required=True, help="Authorized local photo path. Repeatable.")
    parser.add_argument("--action", default="tail_wag", choices=ALLOWED_ACTIONS)
    parser.add_argument("--output", default=str(ROOT / "outputs" / "experiments" / "apimart-action-probes"))
    parser.add_argument("--candidates-per-frame", type=int, default=1)
    args = parser.parse_args()

    action = args.action
    # Normalise the free-text inputs once instead of stripping at every call site.
    pet_name = args.pet_name.strip()
    notes = args.notes.strip()
    # argparse already yields an int; clamp to the supported 1..3 range.
    candidates_per_frame = max(1, min(args.candidates_per_frame, 3))
    probe_dir = Path(args.output) / f"probe-apimart-{_slugify(args.pet_name)}-{action}-{_utc_stamp()}"
    qa_dir = probe_dir / "qa"
    raw_dir = qa_dir / "raw-candidates"
    candidate_dir = qa_dir / "candidates"
    frame_dir = qa_dir / "frames"
    preview_dir = qa_dir / "previews"
    for path in (raw_dir, candidate_dir, frame_dir, preview_dir):
        path.mkdir(parents=True, exist_ok=True)

    analyzer = QwenPhotoAnalyzer.from_env()
    image_generator = APIMartImageGenerationClient.from_env()
    action_reviewer = QwenReviewClient.from_env()

    images = prepare_photo_inputs([Path(photo) for photo in args.photo])
    analysis = analyzer.analyze_pet_photos(pet_name=pet_name, notes=notes, images=images)
    _write_json(qa_dir / "analysis.json", analysis)

    canonical_generation = image_generator.generate_canonical_pet(
        pet_name=pet_name,
        notes=notes,
        analysis=analysis,
        images=images,
        size="1024*1024",
    )
    canonical_reference_bytes = _require_image_bytes(canonical_generation, "canonical")
    # Keep both the untouched reference and the alpha-processed variant on disk.
    (qa_dir / "canonical-reference.png").write_bytes(canonical_reference_bytes)
    canonical_alpha_bytes, canonical_transparency = ensure_png_alpha(canonical_reference_bytes)
    (qa_dir / "canonical.png").write_bytes(canonical_alpha_bytes)

    selected_frames: list[bytes] = []
    frame_reports: list[dict[str, Any]] = []
    candidate_failures: list[dict[str, Any]] = []
    selected_generations: list[dict[str, Any]] = []

    for frame_index in range(FRAMES_PER_ACTION):
        accepted_candidates: list[dict[str, Any]] = []
        for candidate_index in range(candidates_per_frame):
            # Action frames are conditioned on the canonical image only:
            # no source photos and no previous frame are passed.
            generation = image_generator.generate_action_frame(
                action=action,
                frame_index=frame_index,
                candidate_index=candidate_index,
                pet_name=pet_name,
                notes=notes,
                analysis=analysis,
                images=[],
                canonical_image_bytes=canonical_reference_bytes,
                previous_frame_image_bytes=None,
                size="1024*1024",
            )
            raw_bytes = _require_image_bytes(generation, f"{action}-{frame_index}-{candidate_index}")
            (raw_dir / f"{action}-{frame_index}-{candidate_index}.png").write_bytes(raw_bytes)
            candidate = _process_candidate(
                action=action,
                frame_index=frame_index,
                candidate_index=candidate_index,
                raw_bytes=raw_bytes,
                candidate_path=candidate_dir / f"{action}-{frame_index}-{candidate_index}.png",
            )
            if candidate["ok"]:
                # Copy before dropping the raw bytes so the client's response object is not mutated.
                generation = dict(generation)
                generation.pop("image_bytes", None)
                accepted_candidates.append({**candidate, "generation": generation})
            else:
                candidate_failures.append(candidate["failure"])
                # Rewritten after every failure so the file survives a later abort.
                _write_json(qa_dir / "candidate-failures.json", {"failures": candidate_failures})
        if not accepted_candidates:
            raise ValueError(f"no acceptable APIMart candidate for {action} frame {frame_index}")
        selected = max(accepted_candidates, key=lambda item: item["score"])
        selected_frames.append(selected["image_bytes"])
        selected_generations.append(selected["generation"])
        frame_reports.append(
            {
                "frame": frame_index,
                "selected_candidate_index": selected["candidate_index"],
                "score": selected["score"],
                "frame_report": _safe_frame_report(selected["frame_report"]),
                "transparency": _safe_transparency_report(selected["transparency"]),
            }
        )
        (frame_dir / f"{action}-{frame_index}.png").write_bytes(selected["image_bytes"])
        # Rewritten after every frame so partial progress is visible on disk.
        _write_json(qa_dir / "frame-qa.json", {"action": action, "frames": frame_reports})

    action_contact = _compose_action_contact_sheet(selected_frames)
    (qa_dir / "action-contact.png").write_bytes(action_contact)
    write_animation_previews(action_frames={action: selected_frames}, output_dir=preview_dir)
    review = action_reviewer.review_action_frames(
        pet_name=pet_name,
        action=action,
        image_bytes=action_contact,
        image_mime="image/png",
    )
    _write_json(qa_dir / "review.json", review)
    # The "review" entry may be absent, None, or not a mapping; treat all of
    # those as a failed review instead of crashing before the summary is written.
    review_payload = review.get("review")
    review_passed = isinstance(review_payload, dict) and bool(review_payload.get("passed"))
    summary = {
        "status": "ok" if review_passed else "review_failed",
        "provider": "apimart",
        "image_model": image_generator.model,
        "action": action,
        "probe_dir": str(probe_dir.resolve()),
        "canonical_request_id": canonical_generation.get("request_id"),
        "action_request_ids": [item.get("request_id") for item in selected_generations],
        "candidate_failures": len(candidate_failures),
        "canonical_transparency": _safe_transparency_report(canonical_transparency),
        "review": review_payload,
    }
    _write_json(probe_dir / "run-summary.json", summary)
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0 if summary["status"] == "ok" else 2


def _process_candidate(
    *,
    action: str,
    frame_index: int,
    candidate_index: int,
    raw_bytes: bytes,
    candidate_path: Path,
) -> dict[str, Any]:
    """Post-process one raw candidate image and decide whether it is usable.

    The raw bytes are first run through ``ensure_png_alpha`` with the action
    frame chroma key. If that is rejected with
    ``"chroma_key_background_missing"``, the conversion is repeated without
    chroma arguments. The processed bytes are written to ``candidate_path``
    whether or not the candidate is accepted.

    Returns:
        ``{"ok": False, "failure": {...}}`` when the transparency check or the
        frame check rejects the candidate; otherwise ``{"ok": True, ...}``
        carrying the processed bytes, both reports and a combined score.
    """

    def _rejected(reason: str, report_key: str, report: Any) -> dict[str, Any]:
        # Shared shape of a rejection record; only the attached report differs.
        return {
            "ok": False,
            "failure": {
                "action": action,
                "frame_index": frame_index,
                "candidate_index": candidate_index,
                "reason": reason,
                report_key: report,
            },
        }

    processed_bytes, alpha_report = ensure_png_alpha(
        raw_bytes,
        chroma_key=ACTION_FRAME_CHROMA_KEY,
        chroma_threshold=ACTION_FRAME_CHROMA_THRESHOLD,
    )
    alpha_problem = _transparency_rejection_reason(alpha_report)
    if alpha_problem == "chroma_key_background_missing":
        # Second attempt without the chroma-key arguments.
        processed_bytes, alpha_report = ensure_png_alpha(raw_bytes)
        alpha_problem = _transparency_rejection_reason(alpha_report)

    # Persist the processed image before any accept/reject decision.
    candidate_path.write_bytes(processed_bytes)

    if alpha_problem:
        return _rejected(alpha_problem, "transparency", _safe_transparency_report(alpha_report))

    shape_report = _frame_report_for_candidate(processed_bytes)
    shape_problem = _frame_rejection_reason(shape_report)
    if shape_problem:
        return _rejected(shape_problem, "frame_report", _safe_frame_report(shape_report))

    semantic_score = score_candidate_frame(
        action=action,
        frame_report=shape_report,
        candidate_index=candidate_index,
    )
    detail_score = _candidate_visual_detail_score(processed_bytes)
    return {
        "ok": True,
        "candidate_index": candidate_index,
        "image_bytes": processed_bytes,
        "transparency": alpha_report,
        "frame_report": shape_report,
        "score": semantic_score + detail_score,
    }


def _frame_report_for_candidate(image_bytes: bytes) -> dict[str, Any]:
    """Build a frame report dict for a single candidate image.

    Components are extracted from ``image_bytes`` and merged into one. The
    report records how many components were found, the merge outcome, the
    merged bounding box and the share of the bounding box covered by
    foreground pixels. When nothing is extracted, a minimal report with zero
    counts and no bounding box is returned. ``"frame"`` is always 0 here.
    """
    found = _extract_components(image_bytes)
    if found:
        merged, merge_info = _merge_frame_components(found)
        # Clamp the denominator so a degenerate box cannot divide by zero.
        box_area = max(1, merged["width"] * merged["height"])
        pixel_count = merged["foreground_pixels"]
        return {
            "frame": 0,
            "component_count": len(found),
            "accepted_component_count": merge_info["accepted_component_count"],
            "fragmented_foreground": merge_info["fragmented_foreground"],
            "source_bbox": merged["bbox"],
            "foreground_pixels": pixel_count,
            "bbox_fill_ratio": pixel_count / box_area,
        }
    return {"frame": 0, "component_count": 0, "source_bbox": None, "foreground_pixels": 0}


def _compose_action_contact_sheet(frames: list[bytes]) -> bytes:
    """Lay the selected frames side by side in one row and return it as PNG bytes.

    The sheet is ``CELL_WIDTH * FRAMES_PER_ACTION`` wide and ``CELL_HEIGHT``
    tall, initialised to ``(0, 0, 0, 0)`` pixels. At most
    ``FRAMES_PER_ACTION`` frames are used; a frame with no extracted
    components leaves its cell untouched.
    """
    sheet_width = CELL_WIDTH * FRAMES_PER_ACTION
    sheet_height = CELL_HEIGHT
    canvas = [(0, 0, 0, 0)] * (sheet_width * sheet_height)

    for slot, frame_png in enumerate(frames[:FRAMES_PER_ACTION]):
        parts = _extract_components(frame_png)
        if parts:
            subject, _ = _merge_frame_components(parts)
            cell = _fit_component_to_cell(subject["pixels"], subject["width"], subject["height"])
            # Horizontal offset of this frame's cell inside each sheet row.
            cell_left = slot * CELL_WIDTH
            for row in range(CELL_HEIGHT):
                src_start = row * CELL_WIDTH
                dst_start = row * sheet_width + cell_left
                for col in range(CELL_WIDTH):
                    canvas[dst_start + col] = cell[src_start + col]

    return _write_rgba_png(sheet_width, sheet_height, canvas)


def _require_image_bytes(generation: dict[str, Any], label: str) -> bytes:
    image_bytes = generation.get("image_bytes")
    if not isinstance(image_bytes, bytes) or not image_bytes:
        raise ValueError(f"{label} did not return image bytes")
    return image_bytes


def _write_json(path: Path, data: Any) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")


def _utc_stamp() -> str:
    return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")


# Script entry point: exit the process with main()'s return value as the status code.
if __name__ == "__main__":
    raise SystemExit(main())
