#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import socket
import subprocess
import sys
import tempfile
from pathlib import Path


# Repository root: two levels above this script's resolved location.
ROOT = Path(__file__).resolve().parents[1]
# Migration directory, given relative to ROOT.
MIGRATION_DIR_LABEL = "db/supabase/migrations"
MIGRATION_DIR = ROOT.joinpath(MIGRATION_DIR_LABEL)
# File names that must be present in MIGRATION_DIR; every other *.sql file
# found there is applied as well.
EXPECTED_MIGRATION_FILENAMES = [
    "20260702_0001_cloud_contract.sql",
    "20260702_0002_postgrest_role_grants.sql",
]


# Minimal stand-ins for objects the migrations expect to already exist: the
# `auth` schema, the `authenticated` and `service_role` roles, and
# `auth.uid()`. Presumably these mirror what the Supabase platform provides;
# here they are created on a vanilla cluster before any migration runs.
# `auth.uid()` returns the `request.jwt.claim.sub` setting cast to uuid, and
# NULL when that setting is missing or empty.
BOOTSTRAP_SQL = """
create schema if not exists auth;
create role authenticated;
create role service_role;
create or replace function auth.uid()
returns uuid
language sql
stable
as $$
  select nullif(current_setting('request.jwt.claim.sub', true), '')::uuid
$$;
"""


# Privileges for the `authenticated` role. validate() runs this after all
# migrations, so "all tables/functions in schema" also covers the objects the
# migrations created. Only SELECT is granted on tables, so the smoke test
# exercises read visibility; per-row filtering is left to the RLS policies.
GRANT_SQL = """
grant usage on schema public to authenticated;
grant usage on schema auth to authenticated;
grant select on all tables in schema public to authenticated;
grant execute on all functions in schema public to authenticated;
grant execute on all functions in schema auth to authenticated;
"""


# Fixture rows, inserted as the superuser, which RLS does not filter for
# table owners. The fixtures are:
#   - two users, usr_owner and usr_other, each with one pet;
#   - one beta invite, one webhook event and one audit log row;
#   - an active subscription and a paid order for usr_owner only;
#   - one active package_runtime_access entitlement per user.
# The plan code 'sixxie_monthly_launch' is not inserted here; presumably the
# migrations seed it in price_plans (TODO confirm).
# NOTE(review): usr_other has no order, so the orders_visible assertion cannot
# detect cross-user leakage on public.orders. Consider seeding one.
SEED_SQL = """
insert into public.users (
  user_id, auth_uid, email, privacy_settings
) values
  ('usr_owner', '11111111-1111-1111-1111-111111111111', 'owner@example.test', '{}'::jsonb),
  ('usr_other', '22222222-2222-2222-2222-222222222222', 'other@example.test', '{}'::jsonb);

insert into public.pet_profiles (
  pet_id, user_id, pet_name, species, status
) values
  ('pet_owner', 'usr_owner', 'Owner pet', 'cat', 'generation_allowed'),
  ('pet_other', 'usr_other', 'Other pet', 'cat', 'generation_allowed');

insert into public.beta_invites (
  invite_id, code_hash, grant_plan_code, expires_at
) values (
  'inv_smoke', 'hash_smoke', 'sixxie_monthly_launch', now() + interval '30 days'
);

insert into public.webhook_events (
  webhook_event_id, provider, provider_event_id, event_type, status
) values (
  'wh_smoke', 'stripe', 'evt_webhook_smoke', 'checkout.session.completed', 'processed'
);

insert into public.audit_logs (
  audit_log_id, actor_type, actor_id, action, target_type, target_id
) values (
  'aud_smoke', 'system', 'worker', 'smoke_test', 'migration', '20260702_0001_cloud_contract'
);

insert into public.subscriptions (
  subscription_id,
  user_id,
  plan_code,
  provider,
  source,
  status,
  current_period_start,
  current_period_end,
  access_ends_at
) values (
  'sub_owner',
  'usr_owner',
  'sixxie_monthly_launch',
  'stripe',
  'checkout',
  'active',
  now(),
  now() + interval '30 days',
  now() + interval '30 days'
);

insert into public.orders (
  order_id,
  user_id,
  pet_id,
  provider,
  subscription_id,
  client_surface,
  provider_checkout_session_id,
  provider_event_id,
  status,
  currency,
  amount_total_cents,
  product_code,
  paid_at
) values (
  'ord_owner',
  'usr_owner',
  'pet_owner',
  'stripe',
  'sub_owner',
  'desktop_app',
  'cs_smoke',
  'evt_smoke',
  'paid',
  'usd',
  999,
  'sixxie_monthly_launch',
  now()
);

insert into public.entitlements (
  entitlement_id,
  user_id,
  subscription_id,
  order_id,
  pet_id,
  type,
  status,
  quantity_total,
  quantity_used,
  source_type,
  source_id,
  starts_at,
  expires_at
) values
  (
    'ent_owner_runtime',
    'usr_owner',
    'sub_owner',
    'ord_owner',
    'pet_owner',
    'package_runtime_access',
    'active',
    1,
    0,
    'stripe',
    'ord_owner',
    now(),
    now() + interval '30 days'
  ),
  (
    'ent_other_runtime',
    'usr_other',
    null,
    null,
    'pet_other',
    'package_runtime_access',
    'active',
    1,
    0,
    'manual',
    'manual_smoke',
    now(),
    now() + interval '30 days'
  );
"""


# RLS assertions. The script switches to the `authenticated` role and sets the
# JWT subject to usr_owner's auth UID. It then counts the rows visible in each
# table and raises an exception for any mismatch; under `psql -c` that aborts
# the run with a non-zero exit.
#
# Expectations:
#   - usr_owner sees its own pet, its order and its runtime entitlement;
#   - usr_owner sees nothing of usr_other;
#   - service-only tables (beta_invites, price_plans, webhook_events,
#     audit_logs) look empty.
#
# The identity check uses `is distinct from` rather than `<>`. If
# current_app_user_id() returned NULL, `<>` would yield NULL, the IF would not
# fire, and a broken auth-uid mapping would pass silently.
RLS_ASSERT_SQL = """
set row_security = on;
set role authenticated;
select set_config('request.jwt.claim.sub', '11111111-1111-1111-1111-111111111111', false);

do $$
declare
  current_user_id text;
  own_pet_count integer;
  other_pet_count integer;
  beta_invites_visible integer;
  price_plans_visible integer;
  webhook_events_visible integer;
  audit_logs_visible integer;
  orders_visible integer;
  package_runtime_access_visible integer;
begin
  select public.current_app_user_id() into current_user_id;
  select count(*) into own_pet_count
    from public.pet_profiles
    where user_id = 'usr_owner';
  select count(*) into other_pet_count
    from public.pet_profiles
    where user_id = 'usr_other';
  select count(*) into beta_invites_visible
    from public.beta_invites;
  select count(*) into price_plans_visible
    from public.price_plans;
  select count(*) into webhook_events_visible
    from public.webhook_events;
  select count(*) into audit_logs_visible
    from public.audit_logs;
  select count(*) into orders_visible
    from public.orders;
  select count(*) into package_runtime_access_visible
    from public.entitlements
    where type = 'package_runtime_access';

  if current_user_id is distinct from 'usr_owner' then
    raise exception 'current_app_user_id() returned %, expected usr_owner', current_user_id;
  end if;
  if own_pet_count <> 1 then
    raise exception 'own_pet_count expected 1, got %', own_pet_count;
  end if;
  if other_pet_count <> 0 then
    raise exception 'other_pet_count expected 0, got %', other_pet_count;
  end if;
  if beta_invites_visible <> 0 then
    raise exception 'beta_invites_visible expected 0, got %', beta_invites_visible;
  end if;
  if price_plans_visible <> 0 then
    raise exception 'price_plans_visible expected 0, got %', price_plans_visible;
  end if;
  if webhook_events_visible <> 0 then
    raise exception 'webhook_events_visible expected 0, got %', webhook_events_visible;
  end if;
  if audit_logs_visible <> 0 then
    raise exception 'audit_logs_visible expected 0, got %', audit_logs_visible;
  end if;
  if orders_visible <> 1 then
    raise exception 'orders_visible expected 1, got %', orders_visible;
  end if;
  if package_runtime_access_visible <> 1 then
    raise exception 'package_runtime_access_visible expected 1, got %', package_runtime_access_visible;
  end if;
end $$;
reset role;
"""


def _which(executable: str, pg_bin: Path | None) -> str:
    if pg_bin is not None:
        candidate = pg_bin / executable
        if candidate.exists():
            return str(candidate)
    resolved = shutil.which(executable)
    if resolved:
        return resolved
    raise FileNotFoundError(
        f"Could not find {executable}. Pass --pg-bin or set SIXXIE_PG_BIN."
    )


def _free_port() -> int:
    """Return a TCP port number the OS currently reports as unused on 127.0.0.1.

    The probe socket is closed before returning. Another process could grab
    the port before the caller uses it.
    """
    # socket.socket() defaults to AF_INET / SOCK_STREAM.
    with socket.socket() as probe:
        probe.bind(("127.0.0.1", 0))  # port 0 lets the kernel pick
        _host, chosen = probe.getsockname()
    return int(chosen)


def _run(
    cmd: list[str],
    *,
    env: dict[str, str] | None = None,
    input_text: str | None = None,
) -> subprocess.CompletedProcess[str]:
    completed = subprocess.run(
        cmd,
        input=input_text,
        text=True,
        capture_output=True,
        env=env,
        check=False,
    )
    if completed.returncode != 0:
        raise RuntimeError(
            "Command failed: "
            + " ".join(cmd)
            + "\nSTDOUT:\n"
            + completed.stdout
            + "\nSTDERR:\n"
            + completed.stderr
        )
    return completed


def _psql(
    psql: str,
    env: dict[str, str],
    *,
    sql: str | None = None,
    file_path: Path | None = None,
) -> None:
    cmd = [psql, "-X", "-v", "ON_ERROR_STOP=1", "-q"]
    if file_path is not None:
        cmd.extend(["-f", str(file_path)])
        _run(cmd, env=env)
        return
    if sql is None:
        raise ValueError("sql or file_path is required")
    cmd.extend(["-c", sql])
    _run(cmd, env=env)


def _migration_paths() -> list[Path]:
    """Return every ``*.sql`` file in MIGRATION_DIR, sorted by path.

    Raises:
        FileNotFoundError: If the directory holds no ``*.sql`` files at all,
            or if any name in EXPECTED_MIGRATION_FILENAMES is absent.
    """
    candidates = sorted(MIGRATION_DIR.glob("*.sql"))
    if not candidates:
        raise FileNotFoundError(f"Missing migration files in: {MIGRATION_DIR}")

    on_disk = frozenset(candidate.name for candidate in candidates)
    absent = [name for name in EXPECTED_MIGRATION_FILENAMES if name not in on_disk]
    if not absent:
        return candidates
    raise FileNotFoundError("Missing expected migration files: " + ", ".join(absent))


def validate(pg_bin: Path | None, keep_tmp: bool) -> dict[str, object]:
    """Apply the migrations to a throwaway Postgres cluster and run the RLS checks.

    Steps:
      1. Create a fresh cluster in a new temporary directory (trust auth,
         superuser ``postgres``, reachable only through a Unix socket in that
         directory).
      2. Apply BOOTSTRAP_SQL, then every migration in sorted order, then
         GRANT_SQL and SEED_SQL.
      3. Run RLS_ASSERT_SQL.

    The server is always stopped if it was started. The temporary directory
    is removed unless *keep_tmp* is true.

    Args:
        pg_bin: Optional directory holding initdb, pg_ctl and psql; ``PATH``
            is searched otherwise.
        keep_tmp: Leave the temporary directory (data dir, socket, server log)
            on disk for debugging.

    Returns:
        A JSON-serialisable report. It holds the status, the migrations
        applied (relative to ROOT), the Postgres bin directory, the list of
        checks, and the temp dir path (only when *keep_tmp* is set).

    Raises:
        FileNotFoundError: If migration files or Postgres binaries are missing.
        RuntimeError: If any Postgres command fails, including a failed SQL
            assertion.
    """
    migration_paths = _migration_paths()

    initdb = _which("initdb", pg_bin)
    pg_ctl = _which("pg_ctl", pg_bin)
    psql = _which("psql", pg_bin)
    port = _free_port()

    # Use mkdtemp rather than TemporaryDirectory. A TemporaryDirectory object
    # deletes its directory from a finalizer once it is garbage-collected (or
    # at interpreter exit), which silently defeated --keep-tmp.
    tmp_path = Path(tempfile.mkdtemp(prefix="sixxie-supabase-smoke-"))
    data_dir = tmp_path / "pgdata"
    socket_dir = tmp_path / "socket"
    log_path = tmp_path / "postgres.log"
    started = False

    # libpq reads these variables, so every psql call reaches the private server.
    env = os.environ.copy()
    env.update(
        {
            "PGHOST": str(socket_dir),
            "PGPORT": str(port),
            "PGUSER": "postgres",
            "PGDATABASE": "postgres",
        }
    )

    try:
        socket_dir.mkdir()
        _run([initdb, "-D", str(data_dir), "-A", "trust", "-U", "postgres"])
        _run(
            [
                pg_ctl,
                "-D",
                str(data_dir),
                "-l",
                str(log_path),
                "-o",
                # Unix socket only: no TCP listener; the port just names the socket.
                f"-k {socket_dir} -p {port} -c listen_addresses=''",
                "start",
                "-w",
            ]
        )
        started = True

        _psql(psql, env, sql=BOOTSTRAP_SQL)
        for migration_path in migration_paths:
            _psql(psql, env, file_path=migration_path)
        _psql(psql, env, sql=GRANT_SQL)
        _psql(psql, env, sql=SEED_SQL)
        _psql(psql, env, sql=RLS_ASSERT_SQL)

        return {
            "status": "ok",
            "migrations": [str(path.relative_to(ROOT)) for path in migration_paths],
            "postgres_bin": str(Path(psql).parent),
            "checks": [
                "migration_applied",
                "auth_uid_mapping",
                "authenticated_rls_owner_isolation",
                "sensitive_service_tables_hidden",
                "runtime_entitlement_visible_to_owner",
            ],
            "tmp_dir": str(tmp_path) if keep_tmp else None,
        }
    finally:
        # Nested so the scratch directory is removed even when stopping the
        # server raises.
        try:
            if started:
                _run([pg_ctl, "-D", str(data_dir), "stop", "-m", "fast", "-w"], env=env)
        finally:
            if not keep_tmp:
                # Best-effort: a leftover scratch dir must not mask the real outcome.
                shutil.rmtree(tmp_path, ignore_errors=True)


def main() -> int:
    """Parse the command line, run the smoke test, and print its JSON report.

    ``--pg-bin`` defaults to the ``SIXXIE_PG_BIN`` environment variable.
    Failures are not caught here; they propagate as exceptions.

    Returns:
        ``0`` once the report has been printed.
    """
    cli = argparse.ArgumentParser(
        description="Run a local Postgres smoke test for the Supabase migration."
    )
    cli.add_argument(
        "--pg-bin",
        default=os.environ.get("SIXXIE_PG_BIN"),
        help="Directory containing initdb, pg_ctl, and psql.",
    )
    cli.add_argument(
        "--keep-tmp",
        action="store_true",
        help="Keep the temporary Postgres data directory for debugging.",
    )
    options = cli.parse_args()

    # An unset or empty --pg-bin means "search PATH".
    pg_bin = None
    if options.pg_bin:
        pg_bin = Path(options.pg_bin).expanduser().resolve()

    report = validate(pg_bin, options.keep_tmp)
    print(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    try:
        exit_code = main()
    except Exception as exc:
        # Any failure becomes one JSON object on stderr plus exit status 1.
        # SystemExit (e.g. from argparse) is not an Exception and passes through.
        failure = {"status": "failed", "error": str(exc)}
        print(json.dumps(failure, ensure_ascii=False), file=sys.stderr)
        raise SystemExit(1)
    raise SystemExit(exit_code)
