#!/usr/bin/env python3
#
# Copyright (c) 2025 Rackslab
#
# This file is part of Slurm-web.
#
# SPDX-License-Identifier: MIT

"""Developer utility to compare slurmrestd OpenAPI descriptions between API versions.

This script loads OpenAPI descriptions for a single slurmrestd version and compares
all successive API versions, reporting differences in endpoints and schema structures.
"""

from __future__ import annotations

import argparse
import json
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any

# Directory holding the slurmrestd test assets, located relative to this script
# at <script dir>/../tests/assets/slurmrestd. The functions in this file look
# for "<slurm version>/openapi-v3.json" below this directory.
ASSETS_DIR = Path(__file__).parent.parent / "tests" / "assets" / "slurmrestd"


def load_openapi(slurm_version: str) -> dict[str, Any]:
    """Load OpenAPI JSON file for a given slurmrestd version.

    The file is read from ``ASSETS_DIR / slurm_version / "openapi-v3.json"``.

    Args:
        slurm_version: Slurm version (e.g., "25.11")

    Returns:
        Parsed OpenAPI specification

    Raises:
        FileNotFoundError: If OpenAPI file doesn't exist
        json.JSONDecodeError: If file is not valid JSON
    """
    openapi_path = ASSETS_DIR / slurm_version / "openapi-v3.json"
    if not openapi_path.exists():
        raise FileNotFoundError(f"OpenAPI file not found: {openapi_path}")

    # Decode explicitly as UTF-8: without an encoding argument open() falls back
    # to the locale's preferred encoding, which makes parsing platform-dependent.
    with open(openapi_path, encoding="utf-8") as f:
        return json.load(f)


def extract_api_versions(openapi: dict[str, Any]) -> list[str]:
    """Collect the distinct API versions that appear in the OpenAPI paths.

    A version is any ``/vX.Y.Z/`` segment found in a path key, e.g.
    ``/slurm/v0.0.41/...`` or ``/slurmdb/v0.0.42/...``.

    Args:
        openapi: OpenAPI specification

    Returns:
        API versions (without the leading "v"), sorted numerically component by
        component
    """
    version_re = re.compile(r"/v(\d+\.\d+\.\d+)/")

    found = {
        hit.group(1)
        for hit in map(version_re.search, openapi.get("paths", {}))
        if hit
    }

    def numeric_key(version: str) -> tuple[int, ...]:
        # Compare as integers so that "0.0.10" sorts after "0.0.9".
        return tuple(int(component) for component in version.split("."))

    return sorted(found, key=numeric_key)


def normalize_path(path: str) -> str:
    """Replace the API version segment of a path with a fixed placeholder.

    Args:
        path: Original path (e.g., "/slurm/v0.0.41/jobs")

    Returns:
        Path with every "/vX.Y.Z/" segment rewritten to "/v{version}/"
        (e.g., "/slurm/v{version}/jobs"); paths without such a segment are
        returned unchanged
    """
    version_segment = re.compile(r"/v\d+\.\d+\.\d+/")
    return version_segment.sub("/v{version}/", path)


def resolve_ref(
    openapi: dict[str, Any], ref: str, visited: set[str] | None = None
) -> dict[str, Any]:
    """Resolve $ref reference recursively.

    The reference is followed from the document root. When the target is itself
    an object carrying a ``$ref``, that alias is followed too and the other keys
    of the alias object are merged into the result (keys of the final target
    take precedence).

    Args:
        openapi: OpenAPI specification
        ref: Reference string (e.g., "#/components/schemas/v0.0.41_job_resp")
        visited: References currently being resolved, used to detect circular
            alias chains. It is left as it was found once the call returns or
            raises, so the same set can safely be reused by the caller.

    Returns:
        Resolved schema object. If ``ref`` is already being resolved (circular
        chain), the marker ``{"$ref": ref, "_circular": True}`` is returned
        instead.

    Raises:
        ValueError: If the reference does not start with "#/"
        KeyError: If reference cannot be resolved
    """
    if visited is None:
        visited = set()

    if not ref.startswith("#/"):
        raise ValueError(f"Invalid reference format: {ref}")

    # Prevent circular references
    if ref in visited:
        return {"$ref": ref, "_circular": True}

    visited.add(ref)
    try:
        current = openapi

        # Navigate through the reference path ("#/" prefix removed)
        for raw_part in ref[2:].split("/"):
            # JSON Pointer escapes: "~1" stands for "/" and "~0" for "~"; the
            # order of the two replacements matters.
            part = raw_part.replace("~1", "/").replace("~0", "~")
            if not isinstance(current, dict) or part not in current:
                raise KeyError(f"Cannot resolve reference: {ref} (missing {part})")
            current = current[part]

        # If the resolved object has a $ref, resolve it recursively
        if isinstance(current, dict) and "$ref" in current:
            resolved = resolve_ref(openapi, current["$ref"], visited)
            # Merge other properties from the current object
            result = {**current, **resolved}
            result.pop("$ref", None)
            return result

        return current
    finally:
        # Unwind on every exit path (plain target, alias, or error) so that
        # `visited` only ever describes the chain currently being followed.
        visited.discard(ref)


def resolve_schema(
    openapi: dict[str, Any],
    schema: dict[str, Any],
    _active_refs: frozenset[str] = frozenset(),
) -> dict[str, Any]:
    """Resolve all $ref references in a schema recursively.

    References are expanded in the schema itself and in its ``properties``,
    ``items``, ``allOf``, ``anyOf`` and ``oneOf`` members. The input schema and
    the specification are not modified; a new dictionary is returned.

    Args:
        openapi: OpenAPI specification
        schema: Schema object that may contain $ref
        _active_refs: Internal. References already being expanded on the way
            down to this schema; callers should leave the default.

    Returns:
        Fully resolved schema. A reference that points back to a schema that is
        currently being expanded is not followed again; it is replaced by the
        marker ``{"$ref": <ref>, "_circular": True}``. Non-dict values are
        returned unchanged.

    Raises:
        ValueError: If a reference does not start with "#/"
        KeyError: If a reference cannot be resolved
    """
    if not isinstance(schema, dict):
        return schema

    # A cycle marker is terminal: expanding it again (e.g. when an already
    # resolved schema is passed back in) would restart the recursion it stopped.
    if schema.get("_circular") is True and "$ref" in schema:
        return schema

    if "$ref" in schema:
        ref = schema["$ref"]
        if ref in _active_refs:
            # Self-referential schema: stop here instead of recursing forever.
            return {"$ref": ref, "_circular": True}
        resolved = resolve_ref(openapi, ref)
        # Merge other properties; keys of the referenced schema win.
        schema = {**schema, **resolved}
        schema.pop("$ref", None)
        _active_refs = _active_refs | {ref}
    else:
        # Shallow copy so that the nested replacements below never touch the
        # caller's object.
        schema = {**schema}

    # Recursively resolve nested schemas
    if "properties" in schema:
        schema["properties"] = {
            key: resolve_schema(openapi, value, _active_refs)
            for key, value in schema["properties"].items()
        }

    if "items" in schema:
        schema["items"] = resolve_schema(openapi, schema["items"], _active_refs)

    for combinator in ("allOf", "anyOf", "oneOf"):
        if combinator in schema:
            schema[combinator] = [
                resolve_schema(openapi, item, _active_refs)
                for item in schema[combinator]
            ]

    return schema


def get_schema_type(schema: dict[str, Any]) -> str:
    """Describe the type of a schema as a short string.

    Args:
        schema: Schema object

    Returns:
        The declared "type" when present, written as "array[<item type>]" for
        arrays that declare "items". Without a "type": "ref(<target>)" for a
        reference, otherwise the first of "allOf", "anyOf", "oneOf" that is
        present, and "object" as the fallback.
    """
    if "type" not in schema:
        if "$ref" in schema:
            return f"ref({schema['$ref']})"
        combinators = ("allOf", "anyOf", "oneOf")
        return next((name for name in combinators if name in schema), "object")

    declared = schema["type"]
    if declared != "array" or "items" not in schema:
        return declared

    element_type = get_schema_type(schema["items"])
    return f"array[{element_type}]"


def compare_schemas(
    schema1: dict[str, Any],
    schema2: dict[str, Any],
    path: str = "",
    openapi1: dict[str, Any] | None = None,
    openapi2: dict[str, Any] | None = None,
) -> list[tuple[str, str]]:
    """Compare two schemas recursively and return differences.

    The comparison covers the schema type, the set of properties (recursing
    into properties present on both sides), the "required" lists and array
    "items". Entries are reported in a deterministic order: names are sorted
    within each category.

    Args:
        schema1: First schema
        schema2: Second schema
        path: Current path in schema (for reporting)
        openapi1: First OpenAPI spec (for resolving refs)
        openapi2: Second OpenAPI spec (for resolving refs)

    Returns:
        List of (prefix, description) tuples for differences, where prefix is
        "-" (only in schema1), "+" (only in schema2) or "~" (changed)
    """
    # Resolve once at the entry point: resolution already expands every nested
    # properties/items schema, so the recursive walk needs no further lookups.
    if openapi1:
        schema1 = resolve_schema(openapi1, schema1)
    if openapi2:
        schema2 = resolve_schema(openapi2, schema2)

    return _compare_resolved_schemas(schema1, schema2, path)


def _join_schema_path(path: str, suffix: str) -> str:
    """Append a suffix to a dotted report path, without a leading dot."""
    return f"{path}.{suffix}" if path else suffix


def _compare_resolved_schemas(
    schema1: dict[str, Any], schema2: dict[str, Any], path: str
) -> list[tuple[str, str]]:
    """Compare two schemas whose references have already been resolved."""
    differences: list[tuple[str, str]] = []

    # Compare types
    type1 = get_schema_type(schema1)
    type2 = get_schema_type(schema2)
    if type1 != type2:
        type_path = _join_schema_path(path, "type")
        differences.append(("~", f"{type_path} ({type1} → {type2})"))

    # Compare properties for objects
    props1 = schema1.get("properties", {})
    props2 = schema2.get("properties", {})

    # Properties only in schema1 (removed); sorted for reproducible output
    for prop in sorted(props1.keys() - props2.keys()):
        differences.append(("-", _join_schema_path(path, f"properties.{prop}")))

    # Properties only in schema2 (added)
    for prop in sorted(props2.keys() - props1.keys()):
        differences.append(("+", _join_schema_path(path, f"properties.{prop}")))

    # Compare common properties recursively
    for prop in sorted(props1.keys() & props2.keys()):
        differences.extend(
            _compare_resolved_schemas(
                props1[prop],
                props2[prop],
                _join_schema_path(path, f"properties.{prop}"),
            )
        )

    # Compare required fields
    required1 = set(schema1.get("required", []))
    required2 = set(schema2.get("required", []))
    for field in sorted(required1 - required2):
        differences.append(("-", _join_schema_path(path, f"required.{field}")))
    for field in sorted(required2 - required1):
        differences.append(("+", _join_schema_path(path, f"required.{field}")))

    # Compare items for arrays
    items_path = _join_schema_path(path, "items")
    if "items" in schema1 and "items" in schema2:
        differences.extend(
            _compare_resolved_schemas(schema1["items"], schema2["items"], items_path)
        )
    elif "items" in schema1:
        differences.append(("-", items_path))
    elif "items" in schema2:
        differences.append(("+", items_path))

    return differences


def extract_endpoint_schema(
    endpoint: dict[str, Any], openapi: dict[str, Any], schema_type: str
) -> dict[str, Any] | None:
    """Return the resolved JSON schema of a request body or of one response.

    Only the "application/json" media type is considered.

    Args:
        endpoint: Endpoint operation object
        openapi: OpenAPI specification
        schema_type: "requestBody" or response code (e.g., "200")

    Returns:
        Resolved schema, or None when the request body/response, its
        "application/json" content or its schema is missing or empty
    """
    # Both a request body and a response keep their schema under
    # content -> application/json -> schema; only the container differs.
    if schema_type == "requestBody":
        container = endpoint.get("requestBody")
    else:
        container = endpoint.get("responses", {}).get(schema_type)
    if not container:
        return None

    media = container.get("content", {}).get("application/json")
    if not media:
        return None

    raw_schema = media.get("schema")
    return resolve_schema(openapi, raw_schema) if raw_schema else None


def compare_endpoints(
    endpoints1: dict[str, dict[str, Any]],
    endpoints2: dict[str, dict[str, Any]],
    openapi1: dict[str, Any],
    openapi2: dict[str, Any],
) -> dict[str, list[tuple[str, str]]]:
    """Compare two sets of endpoints.

    Operations present on one side only are reported as added/removed. For
    operations present on both sides, the request body, the responses (except
    "default") and the parameters are compared, in that order.

    Args:
        endpoints1: First set of endpoints (normalized_path -> {method -> operation})
        endpoints2: Second set of endpoints
        openapi1: First OpenAPI specification
        openapi2: Second OpenAPI specification

    Returns:
        Dictionary mapping "VERB endpoint" to list of differences. Only
        endpoints with at least one difference have an entry, so an empty
        mapping means the two sets are equivalent.
    """
    differences: dict[str, list[tuple[str, str]]] = defaultdict(list)

    for normalized_path in sorted(endpoints1.keys() | endpoints2.keys()):
        methods1 = endpoints1.get(normalized_path, {})
        methods2 = endpoints2.get(normalized_path, {})

        for method in sorted(methods1.keys() | methods2.keys()):
            endpoint_key = f"{method.upper()} {normalized_path}"

            if method not in methods1:
                differences[endpoint_key].append(("+", "endpoint"))
                continue
            if method not in methods2:
                differences[endpoint_key].append(("-", "endpoint"))
                continue

            op1 = methods1[method]
            op2 = methods2[method]

            endpoint_diffs = [
                *_compare_request_bodies(op1, op2, openapi1, openapi2),
                *_compare_responses(op1, op2, openapi1, openapi2),
                *_compare_parameters(op1, op2, openapi1, openapi2),
            ]
            # Do not create entries for unchanged endpoints.
            if endpoint_diffs:
                differences[endpoint_key].extend(endpoint_diffs)

    return differences


def _compare_request_bodies(
    op1: dict[str, Any],
    op2: dict[str, Any],
    openapi1: dict[str, Any],
    openapi2: dict[str, Any],
) -> list[tuple[str, str]]:
    """Compare the JSON request body schemas of two operations."""
    req_schema1 = extract_endpoint_schema(op1, openapi1, "requestBody")
    req_schema2 = extract_endpoint_schema(op2, openapi2, "requestBody")

    if req_schema1 is None and req_schema2 is None:
        return []
    if req_schema1 is None:
        return [("+", "requestBody")]
    if req_schema2 is None:
        return [("-", "requestBody")]
    return compare_schemas(req_schema1, req_schema2, "requestBody", openapi1, openapi2)


def _compare_responses(
    op1: dict[str, Any],
    op2: dict[str, Any],
    openapi1: dict[str, Any],
    openapi2: dict[str, Any],
) -> list[tuple[str, str]]:
    """Compare the responses of two operations, ignoring "default" responses."""
    diffs: list[tuple[str, str]] = []
    responses1 = op1.get("responses", {})
    responses2 = op2.get("responses", {})

    response_codes = (responses1.keys() | responses2.keys()) - {"default"}

    for code in sorted(response_codes):
        if code not in responses1:
            diffs.append(("+", f"responses.{code}"))
            continue
        if code not in responses2:
            diffs.append(("-", f"responses.{code}"))
            continue

        resp_schema1 = extract_endpoint_schema(op1, openapi1, code)
        resp_schema2 = extract_endpoint_schema(op2, openapi2, code)

        if resp_schema1 is None and resp_schema2 is not None:
            diffs.append(("+", f"responses.{code}.schema"))
        elif resp_schema1 is not None and resp_schema2 is None:
            diffs.append(("-", f"responses.{code}.schema"))
        elif resp_schema1 is not None and resp_schema2 is not None:
            diffs.extend(
                compare_schemas(
                    resp_schema1,
                    resp_schema2,
                    f"responses.{code}",
                    openapi1,
                    openapi2,
                )
            )

    return diffs


def _compare_parameters(
    op1: dict[str, Any],
    op2: dict[str, Any],
    openapi1: dict[str, Any],
    openapi2: dict[str, Any],
) -> list[tuple[str, str]]:
    """Compare the parameters of two operations, matched by parameter name."""
    diffs: list[tuple[str, str]] = []
    params1 = {p.get("name"): p for p in op1.get("parameters", [])}
    params2 = {p.get("name"): p for p in op2.get("parameters", [])}

    def ordered(names):
        # Sorted for reproducible output; key=str also tolerates a parameter
        # object without a "name" (whose key is None).
        return sorted(names, key=str)

    for param_name in ordered(params1.keys() - params2.keys()):
        diffs.append(("-", f"parameters.{param_name}"))

    for param_name in ordered(params2.keys() - params1.keys()):
        diffs.append(("+", f"parameters.{param_name}"))

    # Compare common parameters
    for param_name in ordered(params1.keys() & params2.keys()):
        param1 = params1[param_name]
        param2 = params2[param_name]

        # Compare parameter schema
        schema1 = param1.get("schema", {})
        schema2 = param2.get("schema", {})
        if schema1 or schema2:
            diffs.extend(
                compare_schemas(
                    schema1, schema2, f"parameters.{param_name}", openapi1, openapi2
                )
            )

        # Compare required status. A missing "required" is treated like a false
        # one, so an absent flag versus an explicit false is not reported as a
        # meaningless "optional → optional" change.
        required1 = bool(param1.get("required"))
        required2 = bool(param2.get("required"))
        if required1 != required2:
            status1 = "required" if required1 else "optional"
            status2 = "required" if required2 else "optional"
            diffs.append(
                (
                    "~",
                    f"parameters.{param_name}.required ({status1} → {status2})",
                )
            )

    return diffs


def build_endpoint_map(
    openapi: dict[str, Any], api_version: str
) -> dict[str, dict[str, Any]]:
    """Collect the operations of one API version, keyed by normalized path.

    Args:
        openapi: OpenAPI specification
        api_version: API version to extract (e.g., "0.0.41")

    Returns:
        Dictionary mapping normalized_path -> {method -> operation}. Paths of
        other API versions and paths without any HTTP operation are left out.
    """
    http_methods = (
        "get",
        "post",
        "put",
        "patch",
        "delete",
        "head",
        "options",
        "trace",
    )
    # The version is escaped so that its dots only match literal dots.
    wanted_version = re.compile(rf"/v{re.escape(api_version)}/")

    by_path: dict[str, dict[str, Any]] = defaultdict(dict)
    for raw_path, path_item in openapi.get("paths", {}).items():
        if wanted_version.search(raw_path) is None:
            continue

        version_free_path = normalize_path(raw_path)
        operations = {
            verb: path_item[verb] for verb in http_methods if verb in path_item
        }
        if operations:
            by_path[version_free_path].update(operations)

    return by_path


def format_diff(differences: dict[str, list[tuple[str, str]]]) -> str:
    """Format differences with unified diff style.

    Endpoints are listed in sorted order, each followed by its differences as
    indented "<prefix> <description>" lines. Endpoints whose list of
    differences is empty are skipped.

    Args:
        differences: Dictionary mapping endpoint to list of (prefix, description) tuples

    Returns:
        Formatted string, or "  No differences found.\\n" when there is nothing
        to report (empty mapping, or only endpoints without differences)
    """
    lines: list[str] = []
    for endpoint in sorted(differences):
        endpoint_diffs = differences[endpoint]
        if not endpoint_diffs:
            continue

        lines.append(f"\n{endpoint}")
        lines.extend(f"  {prefix} {desc}" for prefix, desc in endpoint_diffs)

    # Decide on what was actually rendered rather than on the mapping size: a
    # mapping holding only empty lists has nothing to report either.
    if not lines:
        return "  No differences found.\n"

    return "\n".join(lines)


def get_latest_slurm_version() -> str:
    """Get the latest slurmrestd version from available assets.

    A version is a sub-directory of ``ASSETS_DIR`` that contains an
    "openapi-v3.json" file and whose name is made of dot-separated integers
    (e.g. "24.05", "25.11"). Other directories are ignored.

    Returns:
        Latest version string (e.g., "25.11")

    Raises:
        FileNotFoundError: If the assets directory doesn't exist
        ValueError: If no slurmrestd version is found
    """
    if not ASSETS_DIR.exists():
        raise FileNotFoundError(f"Assets directory not found: {ASSETS_DIR}")

    versions: list[tuple[tuple[int, ...], str]] = []
    for path in ASSETS_DIR.iterdir():
        if not (path.is_dir() and (path / "openapi-v3.json").exists()):
            continue
        try:
            # Numeric key so that versions compare component by component.
            sort_key = tuple(int(part) for part in path.name.split("."))
        except ValueError:
            # Not a version directory; a stray name must not abort detection.
            continue
        versions.append((sort_key, path.name))

    if not versions:
        raise ValueError("No slurmrestd versions found")

    return max(versions)[1]


def main():
    """Parse the command line and print the differences between API versions.

    The slurmrestd version given with --slurm-version (or the latest one found
    in the assets) is loaded, then each pair of successive API versions found
    in its OpenAPI description is compared and reported on standard output.
    Exits with status 1 when the version or its OpenAPI file cannot be loaded,
    and with status 0 when fewer than two API versions are available.
    """
    args = _build_arg_parser().parse_args()

    # Fall back to the most recent slurmrestd version when none is requested
    slurm_version = args.slurm_version
    if not slurm_version:
        try:
            slurm_version = get_latest_slurm_version()
            if args.verbose:
                print(f"Using latest slurmrestd version: {slurm_version}")
        except (FileNotFoundError, ValueError) as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)

    try:
        openapi = load_openapi(slurm_version)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Error loading OpenAPI: {e}", file=sys.stderr)
        sys.exit(1)

    api_versions = extract_api_versions(openapi)
    version_count = len(api_versions)
    if version_count < 2:
        print(
            f"Found only {version_count} API version(s), "
            "need at least 2 for comparison"
        )
        print(f"Available versions: {', '.join(api_versions)}")
        sys.exit(0)

    if args.verbose:
        print(f"Found API versions: {', '.join(api_versions)}")

    # Walk the sorted versions pairwise: (v[0], v[1]), (v[1], v[2]), ...
    banner = "=" * 80
    for older, newer in zip(api_versions, api_versions[1:]):
        print(f"\n{banner}")
        print(f"Comparing API versions: {older} → {newer}")
        print(banner)

        older_endpoints = build_endpoint_map(openapi, older)
        newer_endpoints = build_endpoint_map(openapi, newer)

        # Both versions live in the same document, hence the same spec twice
        report = format_diff(
            compare_endpoints(older_endpoints, newer_endpoints, openapi, openapi)
        )
        print(report if report.strip() else "  No differences found.")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser of the script."""
    usage_examples = """
Examples:
  # Compare all successive API versions for latest slurmrestd version
  %(prog)s

  # Compare API versions for specific slurmrestd version
  %(prog)s --slurm-version 24.11

  # Verbose output
  %(prog)s --verbose
        """
    cli = argparse.ArgumentParser(
        description="Compare slurmrestd OpenAPI descriptions between API versions",
        # Raw formatter keeps the line breaks of the examples in --help output
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument(
        "--slurm-version",
        help="Slurm version (default: latest available)",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Show detailed output",
    )
    return cli


# Run the comparison only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
