from __future__ import annotations

import hashlib
import json
import math
from pathlib import Path
from statistics import median
from typing import Any


def provider_call_metrics_from_build_dir(build_dir: str | Path) -> list[dict[str, Any]]:
    """Load ``manifest.json`` from *build_dir* and extract provider call metrics.

    This is a best-effort reader: any manifest that cannot be used yields an
    empty list rather than an exception.

    Args:
        build_dir: Directory expected to contain ``manifest.json``.

    Returns:
        The result of ``provider_call_metrics_from_manifest`` for the parsed
        manifest, or ``[]`` when the file is missing or unreadable, is not
        valid UTF-8, is not valid JSON, or does not hold a JSON object.
    """
    manifest_path = Path(build_dir) / "manifest.json"
    try:
        # Read directly instead of checking ``exists()`` first: the file can
        # vanish between the check and the read, and a directory or an
        # unreadable file at that path should also count as "no manifest".
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return []
    if not isinstance(manifest, dict):
        return []
    return provider_call_metrics_from_manifest(manifest)


def provider_call_metrics_from_manifest(manifest: dict[str, Any]) -> list[dict[str, Any]]:
    """Build one metric record per distinct provider request found in *manifest*.

    Everything is read from ``manifest["source"]``: the provider and model
    names, the canonical request id, the per-action request id lists and the
    ``generation_usage`` mapping (a ``canonical`` entry plus lists under
    ``actions``). A value that is missing or has an unexpected type is
    treated as empty.

    Usage entries that are not dicts are skipped. Request ids are reduced
    with ``_base_request_id`` and each base id is reported once. Entries
    without a request id get a positional key, so they are not merged with
    one another.

    Returns:
        Metric dicts holding ``provider``, ``model``, ``operation``,
        ``action``, ``status`` and ``request_id``, plus duration fields when
        the usage entry carries ``observed_duration_seconds`` or
        ``actual_time``.
    """
    source = manifest.get("source")
    if not isinstance(source, dict):
        source = {}
    usage = source.get("generation_usage")
    if not isinstance(usage, dict):
        usage = {}
    provider_name = str(source.get("image_generation_provider") or "unknown")
    model_name = str(source.get("image_model") or "")

    # Collect (action, raw request id, usage entry) candidates in the order
    # they are reported: the canonical image first, then every action frame.
    candidates: list[tuple[str, Any, Any]] = [
        (
            "canonical",
            source.get("canonical_request_id") or source.get("image_request_id"),
            usage.get("canonical"),
        )
    ]
    usage_by_action = usage.get("actions")
    if not isinstance(usage_by_action, dict):
        usage_by_action = {}
    ids_by_action = source.get("action_request_ids")
    if not isinstance(ids_by_action, dict):
        ids_by_action = {}
    for action_name, usage_items in usage_by_action.items():
        if not isinstance(usage_items, list):
            continue
        action_ids = ids_by_action.get(action_name)
        if not isinstance(action_ids, list):
            action_ids = []
        for position, usage_entry in enumerate(usage_items):
            # Usage entries and request ids are matched by list position.
            raw_id = action_ids[position] if position < len(action_ids) else ""
            candidates.append((str(action_name), raw_id, usage_entry))

    collected: list[dict[str, Any]] = []
    seen_keys: set[str] = set()
    for action_name, raw_id, usage_entry in candidates:
        if not isinstance(usage_entry, dict):
            continue
        base_id = _base_request_id(str(raw_id or ""))
        dedupe_key = base_id if base_id else f"{action_name}:{len(collected)}"
        if dedupe_key in seen_keys:
            continue
        seen_keys.add(dedupe_key)

        entry: dict[str, Any] = {
            "provider": provider_name,
            "model": model_name,
            "operation": "image_generation",
            "action": action_name,
            "status": "ok",
            "request_id": base_id,
        }
        reported_seconds = _float_or_none(usage_entry.get("actual_time"))
        observed_seconds = _float_or_none(usage_entry.get("observed_duration_seconds"))
        if observed_seconds is not None:
            # The locally observed duration wins as the headline figure; the
            # provider-reported one is still recorded alongside when present.
            observed_ms = round(observed_seconds * 1000)
            entry["observed_duration_ms"] = observed_ms
            entry["duration_ms"] = observed_ms
            entry["duration_source"] = "observed"
            if reported_seconds is not None:
                entry["provider_reported_duration_ms"] = round(reported_seconds * 1000)
        elif reported_seconds is not None:
            reported_ms = round(reported_seconds * 1000)
            entry["provider_reported_duration_ms"] = reported_ms
            entry["duration_ms"] = reported_ms
            entry["duration_source"] = "provider_reported"
        collected.append(entry)
    return collected


def sanitize_provider_call_metrics(
    metrics: Any,
    *,
    user_id: str,
    pet_id: str,
    build_id: str,
    created_at: str,
) -> list[dict[str, Any]]:
    """Turn raw metric dicts into bounded, id-stamped records.

    Args:
        metrics: Expected to be a list of dicts; any other value yields
            ``[]`` and non-dict list items are skipped.
        user_id: Copied into every record.
        pet_id: Copied into every record.
        build_id: Copied into every record.
        created_at: Copied into every record.

    Returns:
        One record per input dict. Text fields are stripped and truncated,
        ``status`` is collapsed to ``"ok"`` or ``"failed"``, and duration
        fields are included only when a value can be derived. The raw
        ``request_id`` is never copied; only the first 16 hex characters of
        its SHA-256 digest are stored as ``request_id_hash``.
    """
    if not isinstance(metrics, list):
        return []

    # (output key, fallback seconds key) for the two per-source durations.
    duration_fields = (
        ("provider_reported_duration_ms", "provider_reported_duration_seconds"),
        ("observed_duration_ms", "observed_duration_seconds"),
    )

    cleaned: list[dict[str, Any]] = []
    for raw in metrics:
        if not isinstance(raw, dict):
            continue

        effective_ms = _duration_ms(raw)
        status_text = str(raw.get("status") or "ok").lower()
        record: dict[str, Any] = {
            "user_id": user_id,
            "pet_id": pet_id,
            "build_id": build_id,
            "provider": _short_text(raw.get("provider") or "unknown", 40),
            "model": _short_text(raw.get("model") or "", 80),
            "operation": _short_text(raw.get("operation") or "image_generation", 80),
            "action": _short_text(raw.get("action") or "", 40),
            "status": "ok" if status_text == "ok" else "failed",
            "created_at": created_at,
        }
        if effective_ms is not None:
            record["duration_ms"] = effective_ms

        for ms_key, seconds_key in duration_fields:
            # Re-key the values so ``_duration_ms`` only looks at this one
            # source instead of falling back to the other duration fields.
            source_ms = _duration_ms(
                {
                    "duration_ms": raw.get(ms_key),
                    "duration_seconds": raw.get(seconds_key),
                }
            )
            if source_ms is not None:
                record[ms_key] = source_ms

        source_label = raw.get("duration_source")
        if source_label:
            record["duration_source"] = _short_text(source_label, 40)

        raw_request_id = str(raw.get("request_id") or "")
        if raw_request_id:
            digest = hashlib.sha256(raw_request_id.encode("utf-8")).hexdigest()
            record["request_id_hash"] = digest[:16]

        cleaned.append(record)
    return cleaned


def sanitize_cloud_provider_call_metrics(
    metrics: Any,
    *,
    user_id: str,
    pet_id: str,
    build_id: str,
    created_at: str,
) -> list[dict[str, Any]]:
    """Sanitize *metrics* and reshape them into cloud provider-call records.

    The input first goes through ``sanitize_provider_call_metrics``; each
    resulting record is then mapped onto the cloud field names.

    Returns:
        One dict per sanitized record. ``provider_call_id`` is derived from
        a SHA-256 digest of the build id, the record's position and
        ``created_at``. ``provider_request_id`` carries the hashed request
        id, both cost fields are always ``0``, and ``completed_at`` is set
        to the same value as ``created_at``.
    """
    sanitized = sanitize_provider_call_metrics(
        metrics,
        user_id=user_id,
        pet_id=pet_id,
        build_id=build_id,
        created_at=created_at,
    )

    def to_cloud_record(position: int, record: dict[str, Any]) -> dict[str, Any]:
        """Map one sanitized record onto the cloud field layout."""
        succeeded = record.get("status") == "ok"
        id_seed = f"{build_id}:{position}:{created_at}"
        id_digest = hashlib.sha256(id_seed.encode("utf-8")).hexdigest()
        return {
            "provider_call_id": f"pcm_{id_digest[:24]}",
            "user_id": user_id,
            "pet_id": pet_id,
            "build_id": build_id,
            "provider": _cloud_provider(record.get("provider")),
            "model": _short_text(record.get("model") or "", 80),
            "provider_request_id": record.get("request_id_hash"),
            "purpose": _cloud_metric_purpose(str(record.get("action") or "")),
            "status": "succeeded" if succeeded else "failed",
            "latency_ms": record.get("duration_ms"),
            "estimated_cost_cents": 0,
            "actual_cost_cents": 0,
            "sanitized_failure_code": None if succeeded else "provider_error",
            "safe_metadata": {
                "action": _short_text(record.get("action") or "", 40),
                "duration_source": _short_text(record.get("duration_source") or "", 40),
            },
            "created_at": created_at,
            "completed_at": created_at,
        }

    return [to_cloud_record(position, record) for position, record in enumerate(sanitized)]


def summarize_provider_call_metrics(
    metrics: list[dict[str, Any]],
    *,
    provider: str | None = None,
    build_id: str | None = None,
) -> dict[str, Any]:
    """Aggregate call counts, failures and latency statistics for *metrics*.

    ``provider`` and ``build_id`` are only echoed into the result as labels
    (``provider`` falls back to ``"all"``); this function does not filter
    *metrics* by them.

    Only integer ``duration_ms`` values take part in the statistics, and
    every status other than ``"ok"`` counts as a failure.

    Returns:
        A dict with ``provider``, ``build_id``, ``call_count``,
        ``failure_count``, an overall ``duration_ms`` summary and a
        ``by_provider`` mapping from provider name to its own summary, in
        sorted provider order.
    """
    all_durations: list[int] = []
    durations_by_provider: dict[str, list[int]] = {}
    failure_count = 0
    for item in metrics:
        # Bucket under the same normalized name that is used as the result
        # key, so calls with a missing or empty provider are counted under
        # "unknown" instead of being dropped from the per-provider stats.
        provider_key = str(item.get("provider") or "unknown")
        bucket = durations_by_provider.setdefault(provider_key, [])
        duration = item.get("duration_ms")
        if isinstance(duration, int):
            all_durations.append(int(duration))
            bucket.append(int(duration))
        if item.get("status") != "ok":
            failure_count += 1
    return {
        "provider": provider or "all",
        "build_id": build_id,
        "call_count": len(metrics),
        "failure_count": failure_count,
        "duration_ms": _duration_summary(all_durations),
        "by_provider": {
            provider_key: _duration_summary(durations_by_provider[provider_key])
            for provider_key in sorted(durations_by_provider)
        },
    }


def _duration_summary(values: list[int]) -> dict[str, int | None]:
    """Summarize *values* as ``avg``, ``p50``, ``p95``, ``min`` and ``max``.

    Every statistic is ``None`` when *values* is empty. ``avg`` and ``p50``
    are rounded to integers; ``p95`` comes from ``_nearest_rank``.
    """
    stat_names = ("avg", "p50", "p95", "min", "max")
    if not values:
        return dict.fromkeys(stat_names)

    ascending = sorted(values)
    mean_value = sum(ascending) / len(ascending)
    summary: dict[str, int | None] = {}
    summary["avg"] = round(mean_value)
    summary["p50"] = round(median(ascending))
    summary["p95"] = _nearest_rank(ascending, 0.95)
    summary["min"] = ascending[0]
    summary["max"] = ascending[-1]
    return summary


def _nearest_rank(ordered_values: list[int], percentile: float) -> int:
    """Return the element of *ordered_values* at the given percentile rank.

    *ordered_values* must be sorted ascending and non-empty (an empty list
    raises ``IndexError``). ``percentile`` is a fraction such as ``0.95``.
    """
    last_position = len(ordered_values) - 1
    # Adding just under one half before rounding makes any fractional rank
    # round up to the next whole rank while an exact whole rank stays put.
    rank = int(round(percentile * len(ordered_values) + 0.499999))
    position = rank - 1
    if position > last_position:
        position = last_position
    if position < 0:
        position = 0
    return ordered_values[position]


def _duration_ms(item: dict[str, Any]) -> int | None:
    """Return the first usable duration in *item*, in whole milliseconds.

    Millisecond fields are tried first (``duration_ms``,
    ``observed_duration_ms``, ``provider_reported_duration_ms``), then the
    second-based fields (``duration_seconds``, ``observed_duration_seconds``),
    which are scaled by 1000. Returns ``None`` when no field converts to a
    number.
    """
    for ms_key in ("duration_ms", "observed_duration_ms", "provider_reported_duration_ms"):
        ms_value = _float_or_none(item.get(ms_key))
        if ms_value is not None:
            return round(ms_value)

    for seconds_key in ("duration_seconds", "observed_duration_seconds"):
        seconds_value = _float_or_none(item.get(seconds_key))
        if seconds_value is not None:
            return round(seconds_value * 1000)

    return None


def _base_request_id(request_id: str) -> str:
    """Strip a ``:frame-…`` and then a ``:copy-fallback…`` suffix from *request_id*."""
    without_frame, _, _ = request_id.partition(":frame-")
    base, _, _ = without_frame.partition(":copy-fallback")
    return base


def _float_or_none(value: Any) -> float | None:
    """Convert *value* to a finite float, or return ``None`` if that fails.

    ``None`` is returned when *value* cannot be converted at all (wrong type,
    unparsable text, an integer too large for a float) and also when the
    result is NaN or infinite.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    # Python's JSON decoder accepts NaN/Infinity literals, and callers pass
    # the result straight to round(), which raises on non-finite input.
    # Treat such values as "no usable number" instead.
    return number if math.isfinite(number) else None


def _short_text(value: Any, limit: int) -> str:
    """Return *value* as stripped text cut to at most *limit* characters.

    Any falsy value (``None``, ``0``, ``""``, …) becomes the empty string.
    """
    text = str(value) if value else ""
    trimmed = text.strip()
    return trimmed[:limit]


def _cloud_provider(value: Any) -> str:
    """Normalize *value* to one of the accepted cloud provider names.

    The value is lower-cased and stripped. Anything other than ``"openai"``,
    ``"qwen"`` or ``"apimart"``, including a falsy value, maps to
    ``"openai"``.
    """
    accepted = ("openai", "qwen", "apimart")
    normalized = str(value or "openai").lower().strip()
    if normalized in accepted:
        return normalized
    return "openai"


def _cloud_metric_purpose(action: str) -> str:
    """Map a metric *action* name to its cloud ``purpose`` label.

    ``"canonical"`` and ``"revision"`` have their own labels, the three
    review actions share ``"visual_review"``, and every other action is
    reported as ``"action_frame"``.
    """
    purpose_by_action = {
        "canonical": "canonical_image",
        "visual_review": "visual_review",
        "style_review": "visual_review",
        "semantic_review": "visual_review",
        "revision": "revision",
    }
    return purpose_by_action.get(action, "action_frame")
