#!/usr/bin/env python3
from __future__ import annotations

import argparse
import hashlib
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# Repository root: the parent of the directory containing this script. It is pushed
# onto the front of sys.path so the project-local `services.*` imports below resolve
# when this file is executed directly as a script; it is also the base of the default
# --output directory.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from services.ai.qwen_client import (  # noqa: E402
    QwenImageGenerationClient,
    _ensure_qwen_image_model,
    _to_data_url,
)
from services.media_pipeline.png_transparency import ensure_png_alpha  # noqa: E402
from services.pet_builder.action_atlas import (  # noqa: E402
    CELL_HEIGHT,
    CELL_WIDTH,
    FRAMES_PER_ACTION,
    _extract_components,
    _fit_component_to_cell,
    _read_rgba_png,
    _write_rgba_png,
    compose_action_frame_atlas,
)
from services.pet_builder.action_semantics import evaluate_action_semantics  # noqa: E402
from services.pet_builder.qa_media import write_animation_previews  # noqa: E402


# Per-frame English pose hints used by the "compact" prompt mode (inserted as the
# "Requested pose" sentence). main() iterates a prefix of this list, so its length
# also caps how many frames a single probe can request.
WALK_PHASES = [
    "side-view walking step, one front paw forward",
    "side-view walking step, opposite front paw forward",
    "side-view walking step, rear legs alternate, body traveling",
    "side-view walking step that loops back to frame 1",
]

# Per-frame Chinese step descriptions appended in "user-side-steps" mode; the prompt
# builder indexes this list with the frame index modulo its length. English gist:
#   1. One front paw reaches forward-right, the other front paw supports under the
#      body, hind legs follow naturally.
#   2. The extended front paw lands to support, the other front paw starts lifting
#      forward-right, the body shifts slightly right.
#   3. The other front paw reaches forward-right, hind legs alternate forward,
#      keeping a natural four-legged walking posture.
#   4. Front paws return to a pose close to frame 1, body still facing right, ready
#      to loop into the next walk cycle.
USER_SIDE_STEP_DESCRIPTIONS = [
    "前方一只前脚向右前方伸出，另一只前脚在身体下方支撑，后脚自然跟随。",
    "伸出的前脚落地支撑，另一只前脚开始向右前方抬起，身体略向右移动。",
    "另一只前脚向右前方伸出，后脚交替前进，保持自然四脚走路姿态。",
    "前脚回到接近第 1 帧的循环姿态，身体仍朝右，准备接回下一轮走路。",
]

# Negative prompt for the image model. main() only sends it for the "compact" prompt
# mode; other modes send an empty negative prompt.
NEGATIVE_PROMPT = (
    "sitting, loafing, front-facing pose, static pose, pixel art, text, logo, scenery, "
    "transparent background, alpha transparency, checkerboard, motion lines, speed lines"
)


def _parse_walk_probe_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the walk-probe command line; ``argv=None`` falls back to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description="Probe Qwen with a compact walk-only prompt.")
    parser.add_argument("--pet-name", required=True)
    parser.add_argument("--canonical", required=True, help="Canonical PNG from a prior approved/probe build.")
    parser.add_argument("--output", default=str(ROOT / "outputs" / "pet_builds" / "walk_prompt_probes"))
    parser.add_argument("--notes", default="")
    parser.add_argument("--size", default="1024*1024")
    parser.add_argument("--frames", type=int, default=FRAMES_PER_ACTION, help="Number of walk frames to generate, from 1 to 4.")
    parser.add_argument(
        "--prompt-mode",
        default="user-minimal",
        choices=("user-minimal", "user-side-steps", "compact"),
        help="Prompt style for the walk probe. user-minimal uses the exact short Chinese prompt approved by the user.",
    )
    parser.add_argument(
        "--custom-prompt",
        default="",
        help="Optional exact prompt template. Use {frame} to insert the 1-based frame number.",
    )
    parser.add_argument(
        "--skip-transparency",
        action="store_true",
        help="Keep Qwen raw white-background frames and skip local alpha cleanup, atlas QA, and semantic QA.",
    )
    parser.add_argument("--base-color", default="blue-grey and white")
    parser.add_argument("--eye-color", default="copper")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Generate a short walk cycle from a canonical pet image and write probe artifacts.

    For each requested frame the canonical PNG and a text prompt are sent to the Qwen
    image-edit model, and the raw result is saved. Unless ``--skip-transparency`` is
    given, each frame also goes through local alpha cleanup, and the run produces a
    contact strip, an atlas, frame QA, semantic QA and animation previews.

    Everything is written to a fresh timestamped directory under ``--output``, along
    with a ``manifest.json``. A short JSON summary is printed to stdout.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the default)
            reads ``sys.argv`` exactly as before.

    Returns:
        Process exit code: ``0`` on success.

    Raises:
        FileNotFoundError: ``--canonical`` does not name an existing file.
        ValueError: the canonical image file is empty.

    Errors from building a malformed ``--custom-prompt`` template propagate before
    any directory is created or any API call is made.
    """
    args = _parse_walk_probe_args(argv)

    canonical_path = Path(args.canonical)
    # is_file() instead of exists(): a directory would pass exists() and then fail
    # later inside read_bytes() with a less helpful error.
    if not canonical_path.is_file():
        raise FileNotFoundError(canonical_path)
    canonical_bytes = canonical_path.read_bytes()
    if not canonical_bytes:
        raise ValueError("canonical image is empty")

    # Effective mode: a non-blank --custom-prompt overrides --prompt-mode.
    prompt_mode = "custom" if args.custom_prompt.strip() else args.prompt_mode
    # The negative prompt is tuned for the compact English prompt. Gate it on the
    # effective mode so a custom prompt is sent exactly as written, and so the
    # manifest never pairs prompt_mode "custom" with the compact negative prompt.
    negative_prompt = NEGATIVE_PROMPT if prompt_mode == "compact" else ""
    frame_count = max(1, min(args.frames, FRAMES_PER_ACTION))

    # Build every prompt before touching the filesystem or the API so a bad template
    # fails fast instead of leaving half-populated output directories behind.
    prompts = [
        _build_walk_prompt(
            mode=args.prompt_mode,
            custom_prompt=args.custom_prompt,
            pet_name=args.pet_name,
            frame_index=frame_index,
            phase=phase,
            notes=args.notes,
            base_color=args.base_color,
            eye_color=args.eye_color,
        )
        for frame_index, phase in enumerate(WALK_PHASES[:frame_count])
    ]

    # Create the client before any directories, so missing configuration does not
    # leave empty output folders around.
    client = QwenImageGenerationClient.from_env()
    _ensure_qwen_image_model(client.model)

    output_dir = Path(args.output) / f"walk-{prompt_mode}-{_slugify(args.pet_name)}-{_utc_stamp()}"
    frames_dir = output_dir / "frames"
    raw_dir = output_dir / "raw"
    prompts_dir = output_dir / "prompts"
    qa_dir = output_dir / "qa"
    for path in (frames_dir, raw_dir, prompts_dir, qa_dir):
        path.mkdir(parents=True, exist_ok=True)

    # The reference image is the same for every frame; encode it once.
    canonical_data_url = _to_data_url(canonical_bytes, "image/png")

    frame_bytes: list[bytes] = []
    raw_frame_bytes: list[bytes] = []
    frame_results: list[dict[str, Any]] = []
    prompt_sha256s: list[str] = []
    transparency_reports: list[dict[str, Any]] = []
    for frame_index, prompt in enumerate(prompts):
        prompt_sha256 = _sha256(prompt)
        (prompts_dir / f"walk-{frame_index}.txt").write_text(prompt, encoding="utf-8")
        content: list[dict[str, Any]] = [
            {"image": canonical_data_url},
            {"text": prompt},
        ]
        result = client._generate_image(
            content=content,
            prompt=prompt,
            size=args.size,
            model=client.edit_model,
            negative_prompt=negative_prompt,
        )
        raw_image_bytes = result["image_bytes"]
        (raw_dir / f"walk-{frame_index}.png").write_bytes(raw_image_bytes)
        raw_frame_bytes.append(raw_image_bytes)
        if args.skip_transparency:
            transparency = {"applied": False, "alpha_capable": None, "reason": "skipped_by_probe"}
        else:
            alpha_image_bytes, transparency = ensure_png_alpha(
                raw_image_bytes,
                chroma_key=(0, 0, 255),
                chroma_threshold=180,
            )
            (frames_dir / f"walk-{frame_index}.png").write_bytes(alpha_image_bytes)
            frame_bytes.append(alpha_image_bytes)
        frame_results.append(
            {
                "frame": frame_index,
                "request_id": result.get("request_id"),
                "model": result.get("model"),
                "usage": result.get("usage", {}),
                "prompt_sha256": prompt_sha256,
            }
        )
        prompt_sha256s.append(prompt_sha256)
        transparency_reports.append(transparency)

    raw_contact_strip = _compose_raw_contact_strip(raw_frame_bytes)
    (qa_dir / "walk-raw-contact-strip.png").write_bytes(raw_contact_strip)
    if args.skip_transparency:
        walk_frame_qa = {"ok": None, "skipped": True, "reason": "transparency_disabled_by_probe"}
        semantic_qa = {"ok": None, "skipped": True, "reason": "transparency_disabled_by_probe"}
        previews = {}
    else:
        strip_bytes = _compose_walk_strip(frame_bytes)
        (qa_dir / "walk-contact-strip.png").write_bytes(strip_bytes)
        atlas = compose_action_frame_atlas({"walk": frame_bytes})
        (qa_dir / "walk-atlas.png").write_bytes(atlas.image_bytes)
        walk_frame_qa = {
            **atlas.qa,
            "rows": [row for row in atlas.qa.get("rows", []) if row.get("action") == "walk"],
        }
        semantic_qa = evaluate_action_semantics(walk_frame_qa)
        previews = write_animation_previews(action_frames={"walk": frame_bytes}, output_dir=qa_dir / "previews")
    _write_json(qa_dir / "frame-qa.json", walk_frame_qa)
    _write_json(qa_dir / "action-semantic-qa.json", semantic_qa)
    _write_json(
        output_dir / "manifest.json",
        {
            "kind": "qwen_walk_compact_prompt_probe",
            "pet_name": args.pet_name,
            "prompt_mode": prompt_mode,
            "canonical_path": str(canonical_path),
            "model": client.edit_model,
            "size": args.size,
            "negative_prompt": negative_prompt,
            "transparency_mode": "disabled" if args.skip_transparency else "edge_background_flood_fill",
            "prompt_sha256s": prompt_sha256s,
            "frames": frame_results,
            "transparency": transparency_reports,
            "frame_qa": walk_frame_qa,
            "action_semantic_qa": semantic_qa,
            "outputs": {
                "contact_strip": None if args.skip_transparency else "qa/walk-contact-strip.png",
                "raw_contact_strip": "qa/walk-raw-contact-strip.png",
                "atlas": None if args.skip_transparency else "qa/walk-atlas.png",
                "previews": previews,
                "frames": None if args.skip_transparency else "frames",
                "raw": "raw",
                "prompts": "prompts",
            },
        },
    )
    print(
        json.dumps(
            {
                "status": "ok",
                "output_dir": str(output_dir.resolve()),
                "contact_strip": None if args.skip_transparency else str((qa_dir / "walk-contact-strip.png").resolve()),
                "raw_contact_strip": str((qa_dir / "walk-raw-contact-strip.png").resolve()),
                "walk_atlas": None if args.skip_transparency else str((qa_dir / "walk-atlas.png").resolve()),
                "frame_qa": str((qa_dir / "frame-qa.json").resolve()),
                "action_semantic_qa": str((qa_dir / "action-semantic-qa.json").resolve()),
                "semantic_ok": semantic_qa.get("ok"),
                "request_ids": [frame.get("request_id") for frame in frame_results],
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0


def _build_walk_prompt(
    *,
    mode: str,
    custom_prompt: str = "",
    pet_name: str,
    frame_index: int,
    phase: str,
    notes: str,
    base_color: str,
    eye_color: str,
) -> str:
    """Return the text prompt for one walk frame.

    A non-blank ``custom_prompt`` always wins, regardless of ``mode``. It is stripped,
    and ``{frame}`` is filled with the 1-based frame number via ``str.format``.
    Otherwise ``mode`` selects the prompt:

    * ``"user-minimal"``: a fixed short Chinese prompt; only the frame number varies.
    * ``"user-side-steps"``: a Chinese right-facing side-view prompt followed by the
      step description ``USER_SIDE_STEP_DESCRIPTIONS[frame_index % len(...)]``.
    * anything else (``"compact"``): an English prompt built from ``pet_name``,
      ``phase`` and the colours. Up to 80 characters of whitespace-normalised
      ``notes`` are added as an owner cue; the cue is omitted when ``notes`` is blank.

    Raises:
        ValueError: ``custom_prompt`` is not a valid template. Examples are a
            placeholder other than ``{frame}`` or unescaped literal braces.
    """
    template = custom_prompt.strip()
    if template:
        try:
            return template.format(frame=frame_index + 1)
        except (KeyError, IndexError, ValueError, AttributeError, TypeError) as exc:
            # str.format raises a zoo of exception types for bad templates; surface one
            # clear error instead of e.g. a bare KeyError('pose').
            raise ValueError(
                f"invalid custom prompt template {template!r}: only {{frame}} is supported; "
                "escape literal braces as {{ and }}"
            ) from exc

    if mode == "user-minimal":
        return f"基于这张图，生成同一角色的正在往右走路的动画第 {frame_index + 1} 帧。保留角色主要外貌特征。背景纯白色干净无杂物。"

    if mode == "user-side-steps":
        step_description = USER_SIDE_STEP_DESCRIPTIONS[frame_index % len(USER_SIDE_STEP_DESCRIPTIONS)]
        return (
            f"基于这张图，生成同一角色的右侧侧面正在往右走路的动画第 {frame_index + 1} 帧。"
            f"身体横向、面部朝右。保留主要外貌特征。纯白底图片干净无杂物。{step_description}"
        )

    # Collapse whitespace, cap at 80 chars, and drop a trailing space the cut may expose.
    clean_notes = " ".join(notes.split())[:80].rstrip()
    # Without notes the original emitted a dangling "Owner cue only: ." sentence.
    owner_cue = f" Owner cue only: {clean_notes}." if clean_notes else ""
    return (
        "Create one full-body frame for a quiet desktop pet walking loop. "
        f"Pet: {pet_name}. Frame {frame_index + 1}/4. "
        f"Requested pose: {phase}. "
        "Use the attached canonical pet image as the identity source; keep the same face, coat markings, body type, and calm expression. "
        f"Visual keywords: side-view walking cat, alternating paws, same pet identity, {base_color} coat, {eye_color} eyes, "
        "single pet, full body, high-fidelity 2D, flat pure blue #0000FF chroma-key background. "
        "Avoid: sitting, loafing, front-facing pose, pixel art, text, scenery, transparent pixels, motion lines."
        f"{owner_cue}"
    )


def _compose_walk_strip(frames: list[bytes]) -> bytes:
    """Lay the cleaned walk frames out in one transparent, cell-aligned strip.

    The strip is always ``FRAMES_PER_ACTION`` cells of ``CELL_WIDTH`` x ``CELL_HEIGHT``
    wide; frames beyond that are ignored. For each frame the component with the most
    ``foreground_pixels`` (per ``_extract_components``) is fitted into its cell with
    ``_fit_component_to_cell``, and only pixels with non-zero alpha are copied. A frame
    with no components leaves its cell fully transparent.
    """
    canvas_width = CELL_WIDTH * FRAMES_PER_ACTION
    canvas = [(0, 0, 0, 0)] * (canvas_width * CELL_HEIGHT)
    for slot, png_bytes in enumerate(frames[:FRAMES_PER_ACTION]):
        candidates = _extract_components(png_bytes)
        if not candidates:
            continue
        largest = max(candidates, key=lambda candidate: candidate["foreground_pixels"])
        cell = _fit_component_to_cell(largest["pixels"], largest["width"], largest["height"])
        x_offset = slot * CELL_WIDTH
        for row in range(CELL_HEIGHT):
            cell_row_start = row * CELL_WIDTH
            canvas_row_start = row * canvas_width + x_offset
            for col in range(CELL_WIDTH):
                pixel = cell[cell_row_start + col]
                # Fully transparent pixels are skipped so the canvas background shows.
                if pixel[3] != 0:
                    canvas[canvas_row_start + col] = pixel
    return _write_rgba_png(canvas_width, CELL_HEIGHT, canvas)


def _compose_raw_contact_strip(
    frames: list[bytes],
    *,
    slot_width: int = 240,
    slot_height: int = 240,
) -> bytes:
    """Tile raw model outputs side by side on an opaque white strip for quick review.

    Up to ``FRAMES_PER_ACTION`` frames are each scaled with nearest-neighbour sampling,
    preserving aspect ratio, to fit a ``slot_width`` x ``slot_height`` slot, and centred
    in it. Pixels with alpha below 255 are composited over white, so the result is
    fully opaque. A frame that ``_read_rgba_png`` rejects with ``ValueError`` leaves
    its slot white instead of aborting the probe. An empty ``frames`` list still
    yields one blank slot.

    Args:
        frames: PNG-encoded frames in display order.
        slot_width: Width in pixels of each frame slot (default 240, as before).
        slot_height: Height in pixels of each frame slot (default 240, as before).

    Returns:
        PNG bytes produced by ``_write_rgba_png``.

    Raises:
        ValueError: ``slot_width`` or ``slot_height`` is not positive.
    """
    if slot_width <= 0 or slot_height <= 0:
        raise ValueError(f"slot size must be positive, got {slot_width}x{slot_height}")
    frame_count = max(1, min(len(frames), FRAMES_PER_ACTION))
    strip_width = slot_width * frame_count
    pixels = [(255, 255, 255, 255)] * (strip_width * slot_height)
    for frame_index, image_bytes in enumerate(frames[:frame_count]):
        try:
            source_width, source_height, source_pixels = _read_rgba_png(image_bytes)
        except ValueError:
            # Best effort: an undecodable raw frame just leaves its slot white.
            continue
        scale = min(slot_width / source_width, slot_height / source_height)
        target_width = max(1, int(source_width * scale))
        target_height = max(1, int(source_height * scale))
        left = frame_index * slot_width + (slot_width - target_width) // 2
        top = (slot_height - target_height) // 2
        # The nearest-neighbour column lookup is the same for every row: compute it once
        # instead of target_width * target_height times.
        source_columns = [min(source_width - 1, int(x / scale)) for x in range(target_width)]
        for y in range(target_height):
            source_row_start = min(source_height - 1, int(y / scale)) * source_width
            target_row_start = (top + y) * strip_width + left
            for x, source_x in enumerate(source_columns):
                red, green, blue, alpha = source_pixels[source_row_start + source_x]
                if alpha < 255:
                    # Composite over the white background so the strip stays opaque.
                    opacity = alpha / 255
                    red = int(red * opacity + 255 * (1 - opacity))
                    green = int(green * opacity + 255 * (1 - opacity))
                    blue = int(blue * opacity + 255 * (1 - opacity))
                pixels[target_row_start + x] = (red, green, blue, 255)
    return _write_rgba_png(strip_width, slot_height, pixels)


def _slugify(value: str) -> str:
    """Turn *value* into a lowercase, hyphen-separated fragment for a directory name.

    Characters that are not ``str.isalnum()`` (which includes non-ASCII letters and
    digits) and are not ``-`` or ``_`` become ``-``. Runs of ``-`` collapse, edge
    hyphens are dropped, and the result is capped at 80 characters. Returns ``"pet"``
    when nothing usable remains.
    """
    safe = "".join(char if char.isalnum() or char in "-_" else "-" for char in value.strip().lower())
    # Strip again after truncating: the 80-char cut can land just after a separator
    # and would otherwise leave a dangling trailing "-".
    slug = "-".join(part for part in safe.split("-") if part)[:80].rstrip("-")
    return slug or "pet"


def _utc_stamp() -> str:
    """Return the current UTC time as a compact ``YYYYMMDDHHMMSS`` string (second resolution)."""
    now_utc = datetime.now(tz=timezone.utc)
    return f"{now_utc:%Y%m%d%H%M%S}"


def _sha256(value: str) -> str:
    """Return the hex SHA-256 digest of *value* encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(value.encode("utf-8"))
    return digest.hexdigest()


def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as 2-space-indented UTF-8 JSON, creating parent dirs.

    Non-ASCII text is kept verbatim (``ensure_ascii=False``) rather than ``\\u``-escaped.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
