#!/usr/bin/env python3
from __future__ import annotations

import argparse
import html
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# Two directory levels above this file (presumably the repository root).
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path so that the ``services`` import below
# resolves when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from services.ai.qwen_client import (  # noqa: E402
    QwenImageGenerationClient,
    QwenRequestError,
    _ensure_qwen_image_model,
    _to_data_url,
)


# Prompt sent for every frame; ``{frame}`` is filled with the 1-based frame
# number by build_prompt. English gist: "Based on this image, generate frame
# {frame} of an animation of the same character walking to the right. Keep the
# character's main appearance features. Plain, clean, clutter-free white
# background."
PROMPT_TEMPLATE = "基于这张图，生成同一角色的正在往右走路的动画第 {frame} 帧。保留角色主要外貌特征。背景纯白色干净无杂物。"

# qwen-image generation model IDs, swept when the family is "generation" or
# "all". Each ID is checked with _ensure_qwen_image_model at runtime; any ID
# that fails that check is recorded as "skipped" rather than requested.
QWEN_IMAGE_MODELS = [
    "qwen-image-2.0-pro-2026-04-22",
    "qwen-image-2.0-pro-2026-03-03",
    "qwen-image-2.0-2026-03-03",
    "qwen-image-2.0",
    "qwen-image-max-2025-12-30",
    "qwen-image-max",
    "qwen-image-plus",
    "qwen-image-plus-2026-01-09",
    "qwen-image",
]

# qwen-image edit model IDs; the default family ("edit") sweeps only these,
# and "all" sweeps them before the generation models.
QWEN_IMAGE_EDIT_MODELS = [
    "qwen-image-edit",
    "qwen-image-edit-max-2026-01-16",
    "qwen-image-edit-plus",
    "qwen-image-edit-plus-2025-10-30",
    "qwen-image-edit-plus-2025-12-15",
]


def main() -> int:
    """Run a Qwen image/edit model matrix for one pet and write a comparison report.

    Every candidate model in the selected family is asked for up to
    ``--frames`` frames (clamped to 1..4) of the walking-animation prompt,
    conditioned on the ``--canonical`` image. Per-model prompts and frames,
    ``manifest.json`` and ``matrix.html`` are written to a new
    ``matrix-<pet>-<UTC stamp>`` directory under ``--output``, and a JSON
    summary of those paths is printed to stdout.

    The manifest and HTML report are written even when the sweep is aborted by
    an unexpected exception (or Ctrl-C), so frames that were already generated
    stay indexed; ``manifest["completed"]`` tells the two cases apart.

    Returns:
        ``0`` once the sweep has finished.

    Raises:
        ValueError: if the canonical image file is empty.
    """
    parser = argparse.ArgumentParser(description="Run a Qwen image/edit model matrix for one pet action prompt.")
    parser.add_argument("--pet-name", required=True)
    parser.add_argument("--canonical", required=True)
    parser.add_argument("--output", default=str(ROOT / "outputs" / "pet_builds" / "qwen_model_matrix"))
    parser.add_argument("--size", default="1024*1024")
    parser.add_argument("--frames", type=int, default=1, help="Frames per model. Use 1 for a low-cost sweep, 4 for action-loop comparison.")
    parser.add_argument(
        "--family",
        default="edit",
        choices=("edit", "generation", "all"),
        help="Which model family to test.",
    )
    parser.add_argument("--include-generation", action="store_true", help="Also test qwen-image generation models, not only edit models.")
    parser.add_argument("--continue-on-account-error", action="store_true", help="Keep trying models even after account-level billing errors.")
    args = parser.parse_args()

    # Clamp to 1..4 frames per model (4 is the action-loop comparison size).
    frames = max(1, min(args.frames, 4))
    canonical_path = Path(args.canonical)
    canonical_bytes = canonical_path.read_bytes()
    if not canonical_bytes:
        raise ValueError("canonical image is empty")

    # The --include-generation switch overrides --family.
    family = "all" if args.include_generation else args.family
    models = candidate_models(family=family)
    run_dir = Path(args.output) / f"matrix-{_slugify(args.pet_name)}-{_utc_stamp()}"
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / "canonical.png").write_bytes(canonical_bytes)
    # The reference image is identical for every request: encode it once, and
    # label it with its actual type instead of always claiming PNG.
    canonical_url = _to_data_url(canonical_bytes, _sniff_image_mime(canonical_bytes))

    client = QwenImageGenerationClient.from_env()
    results: list[dict[str, Any]] = []
    account_blocked = False
    completed = False  # False in the manifest means the sweep was aborted.
    try:
        for model in models:
            record = _run_model(
                client,
                model=model,
                run_dir=run_dir,
                canonical_url=canonical_url,
                frames=frames,
                size=args.size,
            )
            results.append(record)
            # Account/billing errors hit every model, so stop unless told otherwise.
            if record.get("error_kind") == "account_error" and not args.continue_on_account_error:
                account_blocked = True
                break
        completed = True
    finally:
        # Runs even on an unexpected error so already-paid results are not lost.
        manifest = {
            "kind": "qwen_model_matrix",
            "pet_name": args.pet_name,
            "canonical_path": str(canonical_path),
            "canonical_copy": "canonical.png",
            "size": args.size,
            "frames_per_model": frames,
            "family": family,
            "include_generation": family == "all",
            "account_blocked": account_blocked,
            "completed": completed,
            "results": results,
        }
        (run_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
        (run_dir / "matrix.html").write_text(render_html(manifest), encoding="utf-8")
    print(json.dumps({"run_dir": str(run_dir), "manifest": str(run_dir / "manifest.json"), "matrix": str(run_dir / "matrix.html"), "account_blocked": account_blocked}, ensure_ascii=False, indent=2))
    return 0


def _sniff_image_mime(data: bytes) -> str:
    """Guess an image MIME type from magic bytes, defaulting to ``image/png``."""
    if data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    return "image/png"


def _run_model(
    client: QwenImageGenerationClient,
    *,
    model: str,
    run_dir: Path,
    canonical_url: str,
    frames: int,
    size: str,
) -> dict[str, Any]:
    """Request ``frames`` frames from ``model`` and return its manifest record.

    Prompts and frames are written under ``run_dir / _slugify(model)``. The
    record's ``status`` is ``"skipped"`` when ``_ensure_qwen_image_model``
    raises, ``"ok"`` when every frame was generated, ``"failed"`` when the
    first request failed, and ``"partial"`` when a later request failed after
    some frames had succeeded. Only ``QwenRequestError`` is caught around the
    requests; any other exception propagates to the caller.
    """
    model_dir = run_dir / _slugify(model)
    model_dir.mkdir(parents=True, exist_ok=True)
    record: dict[str, Any] = {
        "model": model,
        "status": "pending",
        "frames": [],
        "prompt_template": PROMPT_TEMPLATE,
    }
    try:
        _ensure_qwen_image_model(model)
    except Exception as exc:  # any failure here just marks this model as skipped
        record.update({"status": "skipped", "error": str(exc)})
        return record

    for frame_index in range(frames):
        prompt = build_prompt(frame_index)
        (model_dir / f"prompt-{frame_index}.txt").write_text(prompt, encoding="utf-8")
        content = [
            {"image": canonical_url},
            {"text": prompt},
        ]
        try:
            generated = client._generate_image(
                content=content,
                prompt=prompt,
                size=size,
                model=model,
                negative_prompt="",
            )
        except QwenRequestError as exc:
            error_text = str(exc)
            record.update(
                {
                    # Keep already-generated frames visible as a partial result.
                    "status": "partial" if record["frames"] else "failed",
                    "error": error_text,
                    "error_kind": classify_error(error_text),
                }
            )
            return record

        frame_path = model_dir / f"frame-{frame_index}.png"
        frame_path.write_bytes(generated["image_bytes"])
        record["frames"].append(
            {
                "frame": frame_index,
                "request_id": generated.get("request_id"),
                "model": generated.get("model"),
                "usage": generated.get("usage", {}),
                # POSIX separators so the path works as an HTML src on every OS.
                "path": frame_path.relative_to(run_dir).as_posix(),
                "prompt": prompt,
            }
        )

    record["status"] = "ok"
    return record


def candidate_models(*, family: str = "edit", include_generation: bool | None = None) -> list[str]:
    """Return the model IDs to sweep for ``family``, without duplicates.

    Args:
        family: ``"edit"``, ``"generation"`` or ``"all"`` (edit models first,
            then generation models).
        include_generation: legacy switch; when not ``None`` it replaces
            ``family`` (``True`` -> ``"all"``, ``False`` -> ``"edit"``).

    Returns:
        A new list in first-occurrence order.

    Raises:
        ValueError: if ``family`` is not one of the known names.
    """
    if include_generation is not None:
        family = "all" if include_generation else "edit"

    if family == "all":
        pool = [*QWEN_IMAGE_EDIT_MODELS, *QWEN_IMAGE_MODELS]
    elif family == "generation":
        pool = QWEN_IMAGE_MODELS
    elif family == "edit":
        pool = QWEN_IMAGE_EDIT_MODELS
    else:
        raise ValueError(f"unknown model family: {family}")

    # dict keys keep insertion order, so repeats are dropped but the first
    # occurrence of each ID keeps its position.
    return list(dict.fromkeys(pool))


def build_prompt(frame_index: int) -> str:
    """Return the prompt for the zero-based ``frame_index``.

    The prompt itself numbers frames from 1, so index 0 becomes frame 1.
    """
    frame_number = frame_index + 1
    return PROMPT_TEMPLATE.format_map({"frame": frame_number})


def classify_error(error_text: str) -> str:
    """Bucket a request error message into a coarse failure kind.

    Matching is a case-insensitive substring search, with billing markers
    checked first:

    * ``"account_error"`` - arrearage / overdue-payment / account-standing text.
    * ``"unsupported_model_or_request"`` - "invalid", "not support", "unsupported".
    * ``"request_error"`` - anything else.
    """
    text = error_text.lower()
    account_markers = ("arrearage", "account is in good standing", "overdue-payment")
    unsupported_markers = ("invalid", "not support", "unsupported")
    if any(marker in text for marker in account_markers):
        return "account_error"
    if any(marker in text for marker in unsupported_markers):
        return "unsupported_model_or_request"
    return "request_error"


def render_html(manifest: dict[str, Any]) -> str:
    """Render a run manifest as a standalone HTML comparison page.

    Each entry in ``manifest["results"]`` becomes a ``<section>`` containing:

    * the model name;
    * its status, plus ``error_kind`` when present;
    * its error message when present;
    * a grid of its generated frames.

    Image ``src`` values are the frame ``path`` entries from the manifest, so
    the page is meant to sit in the directory those paths are relative to.
    All manifest-derived text is HTML-escaped.
    """
    sections = []
    for result in manifest["results"]:
        model_html = html.escape(result["model"])
        status_html = html.escape(str(result.get("status")))
        if result.get("error_kind"):
            status_html += " / " + html.escape(str(result.get("error_kind")))
        frame_cells = [
            f'<figure><img src="{html.escape(frame["path"])}" alt="{model_html} frame {frame["frame"] + 1}">'
            f'<figcaption>frame {frame["frame"] + 1}</figcaption></figure>'
            for frame in result.get("frames", [])
        ]
        error = result.get("error")
        # Show the error even when some frames succeeded before the failure;
        # otherwise a failure after the first frame is invisible in the report.
        error_html = f'<p class="error">{html.escape(str(error))}</p>' if error else ""
        if not frame_cells and not error:
            # Nothing else to display for this model: fall back to the status.
            frame_cells.append(f'<p class="error">{html.escape(str(result.get("status", "")))}</p>')
        sections.append(
            "<section>"
            f"<h2>{model_html}</h2>"
            f"<p>Status: {status_html}</p>"
            f"{error_html}"
            f'<div class="frames">{"".join(frame_cells)}</div>'
            "</section>"
        )
    if manifest["results"]:
        prompt_template = manifest["results"][0]["prompt_template"]
    else:
        prompt_template = PROMPT_TEMPLATE
    return """<!doctype html>
<meta charset="utf-8">
<title>Qwen model matrix</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; margin: 24px; background: #f7f7f4; color: #202020; }
section { border: 1px solid #d6d6d0; background: white; margin: 16px 0; padding: 16px; border-radius: 8px; }
h1, h2 { margin: 0 0 8px; }
.frames { display: grid; grid-template-columns: repeat(4, minmax(160px, 1fr)); gap: 12px; align-items: start; }
figure { margin: 0; }
img { width: 100%; height: auto; background: white; border: 1px solid #e3e3df; }
figcaption, p { color: #555; }
.error { white-space: pre-wrap; color: #9b1c1c; }
</style>
<h1>Qwen model matrix</h1>
<p>Prompt template: """ + html.escape(str(prompt_template)) + """</p>
""" + "\n".join(sections)


def _slugify(value: str) -> str:
    safe = "".join(char if char.isalnum() or char in "-_" else "-" for char in value.strip().lower())
    return "-".join(part for part in safe.split("-") if part)[:90] or "model"


def _utc_stamp() -> str:
    return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
