#!/usr/bin/env python3
#
# Copyright (c) 2024 Rackslab
#
# This file is part of Slurm-web.
#
# SPDX-License-Identifier: MIT

from pathlib import Path
import argparse
import shutil
import sys
import getpass
import logging

# Prepend the lib/ directory located next to this script to the module search
# path, ahead of every other entry.
# NOTE(review): presumably this is where the crawler package imported below
# lives — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parent / "lib"))

from rfl.log import setup_logger

from crawler.lib import (
    load_settings,
    DevelopmentHostClient,
    DevelopmentHostConnectionError,
    DevelopmentHostCluster,
    SUPPORTED_SLURMRESTD_API_VERSIONS,
)
from crawler.slurmrestd import SlurmrestdCrawler
from crawler.agent import AgentCrawler
from crawler.gateway import slurmweb_token, GatewayCrawler
from crawler.prometheus import PrometheusCrawler

from racksdb import RacksDB
from slurmweb.slurmrestd.auth import SlurmrestdAuthentifier

# Logger used by all functions of this script.
logger = logging.getLogger("crawl-tests-assets")

# Debug flags given to setup_logger() in main().
DEBUG_FLAGS = ["slurmweb", "rfl", "werkzeug", "urllib3", "crawler"]
# Host name given to DevelopmentHostClient in main().
DEV_HOST = "firehpc.dev.rackslab.io"
# Login name of the user running this script, given to DevelopmentHostClient in
# main().
USER = getpass.getuser()
# Infrastructure used in main() to request the Slurm-web token.
GATEWAY_PREFERRED_INFRASTRUCTURE = "nova"
# Only infrastructure for which Prometheus assets are crawled, and for which the
# crawl_metrics flag given to AgentCrawler is true.
METRICS_PREFERRED_INFRASTRUCTURE = "titan"
# Map between infrastructure names and cluster names that are visible in Slurm-web.
# Infrastructures absent from this map keep their own name, see
# slurmweb_cluster_name().
MAP_CLUSTER_NAMES = {}


class SimpleProgressBar:
    """Simple progress bar for TTY output, spanning full terminal width.

    The bar is written on standard output and redrawn in place, using a carriage
    return, every time the progress changes.
    """

    def __init__(self, total: int):
        """Initialize the progress bar and draw it in its initial state.

        Args:
            total: Number of items expected to be processed.
        """
        self.total = total
        self.current = 0
        self._update_display()

    def _update_display(self):
        """Redraw the progress bar on the current terminal line."""
        # Get current terminal width (may change if terminal is resized)
        terminal_width = shutil.get_terminal_size().columns

        # Integer arithmetic is used on purpose: the float expression
        # int((29 / 100) * 100) evaluates to 28 because of rounding errors.
        if self.total > 0:
            percent = self.current * 100 // self.total
        else:
            # Nothing to process, the work is considered complete.
            percent = 100

        # Calculate available width for the progress bar
        # Reserve space for: "Progress: [", "]", " XXX% (XXX/XXX)"
        prefix = "Progress: ["
        suffix = f"] {percent}% ({self.current}/{self.total})"
        reserved_width = len(prefix) + len(suffix)
        # Always draw at least one character, even on very narrow terminals.
        bar_length = max(1, terminal_width - reserved_width)

        if self.total > 0:
            filled = self.current * bar_length // self.total
        else:
            # Keep the bar consistent with the 100% displayed in the suffix.
            filled = bar_length
        bar = "=" * filled + "-" * (bar_length - filled)
        sys.stdout.write(f"\r{prefix}{bar}{suffix}")
        sys.stdout.flush()

    def update(self, n: int = 1):
        """Update progress by n items.

        The progress is kept between 0 and the total, so the bar can neither
        overflow nor be drawn with a negative number of filled characters.

        Args:
            n: Number of items processed since the previous update.
        """
        self.current = max(0, min(self.current + n, self.total))
        self._update_display()

    def close(self):
        """Close the progress bar by moving to the next terminal line."""
        sys.stdout.write("\n")
        sys.stdout.flush()


def slurmweb_cluster_name(infrastructure: str):
    """Return the name of an infrastructure as it is visible in Slurm-web.

    The name is looked up in MAP_CLUSTER_NAMES. Infrastructures without an entry
    in this map keep their own name.
    """
    try:
        return MAP_CLUSTER_NAMES[infrastructure]
    except KeyError:
        return infrastructure


def _create_cluster(
    infrastructure: str,
    dev_host: DevelopmentHostClient,
    dev_tmp_dir: Path,
) -> DevelopmentHostCluster:
    """Build the DevelopmentHostCluster object of an infrastructure.

    Args:
        infrastructure: Name of the infrastructure (cluster).
        dev_host: Client of the development host.
        dev_tmp_dir: Temporary directory of the development environment.

    Returns:
        The DevelopmentHostCluster created with the agent settings of this
        infrastructure and the slurmrestd authentifier built from them.
    """
    # Agent configuration of this infrastructure
    agent_settings = load_settings(
        "conf/vendor/agent.yml", dev_tmp_dir, f"agent-{infrastructure}.ini"
    )
    # All authentication parameters come from the slurmrestd settings section.
    slurmrestd = agent_settings.slurmrestd
    return DevelopmentHostCluster(
        dev_host,
        infrastructure,
        agent_settings,
        SlurmrestdAuthentifier(
            slurmrestd.auth,
            slurmrestd.jwt_mode,
            slurmrestd.jwt_user,
            slurmrestd.jwt_key,
            slurmrestd.jwt_lifespan,
            slurmrestd.jwt_token,
        ),
    )


def _get_clusters_with_latest_api(
    clusters_to_process: list[str],
    dev_host: DevelopmentHostClient,
    dev_tmp_dir: Path,
    clusters_cache: dict[str, DevelopmentHostCluster],
) -> set[str]:
    """Return the infrastructures whose cluster offers the latest slurmrestd API.

    The latest version is the first item of SUPPORTED_SLURMRESTD_API_VERSIONS.
    Clusters missing from the cache are created and added to it.

    Args:
        clusters_to_process: Names of the infrastructures to check.
        dev_host: Client of the development host.
        dev_tmp_dir: Temporary directory of the development environment.
        clusters_cache: Clusters indexed by infrastructure name, completed in
            place when an infrastructure is missing.

    Returns:
        Set of infrastructure names for which the latest API version has been
        discovered.
    """
    latest = SUPPORTED_SLURMRESTD_API_VERSIONS[0]
    supporting = set()

    for name in clusters_to_process:
        # Create the cluster on first use and keep it in the cache.
        try:
            cluster = clusters_cache[name]
        except KeyError:
            cluster = _create_cluster(name, dev_host, dev_tmp_dir)
            clusters_cache[name] = cluster

        # Only API versions matter here, Slurm versions are ignored.
        api_versions = {api for _, api in cluster.discover_api_versions()}

        if latest not in api_versions:
            logger.info(
                "Cluster %s does not support latest API version %s (supports: %s)",
                name,
                latest,
                sorted(api_versions),
            )
            continue

        supporting.add(name)
        logger.info("Cluster %s supports latest API version %s", name, latest)

    return supporting


def _count_total_assets(
    clusters_to_process: list[str],
    components_to_process: list[str],
    clusters_cache: dict[str, DevelopmentHostCluster],
    dev_tmp_dir: Path,
    token: str,
    clusters_with_latest_api: set[str],
    asset_filter: list[str] | None = None,
) -> int:
    """Compute how many assets the selected crawlers have to crawl.

    Args:
        clusters_to_process: Names of the infrastructures to process.
        components_to_process: Names of the selected components.
        clusters_cache: Clusters indexed by infrastructure name.
        dev_tmp_dir: Temporary directory of the development environment.
        token: Token given to the gateway and agent crawlers.
        clusters_with_latest_api: Infrastructures for which gateway and agent
            assets are counted.
        asset_filter: Optional asset names forwarded to every crawler.

    Returns:
        Sum of the counts reported by all the crawlers involved.
    """
    count = 0
    for infrastructure in clusters_to_process:
        cluster = clusters_cache[infrastructure]
        has_latest_api = infrastructure in clusters_with_latest_api
        is_metrics_infra = infrastructure == METRICS_PREFERRED_INFRASTRUCTURE

        # slurmrestd: one crawler per discovered Slurm/API versions couple
        if "slurmrestd" in components_to_process:
            count += sum(
                SlurmrestdCrawler(
                    cluster, cluster.auth, slurm_version, api_version
                ).count_assets_to_crawl(asset_filter)
                for slurm_version, api_version in cluster.discover_api_versions()
            )

        # Gateway, only with the latest API version
        if "gateway" in components_to_process and has_latest_api:
            count += GatewayCrawler(
                token,
                cluster,
                infrastructure,
                dev_tmp_dir,
            ).count_assets_to_crawl(asset_filter)

        # Agent, only with the latest API version
        if "agent" in components_to_process and has_latest_api:
            count += AgentCrawler(
                cluster.settings.service.port,
                token,
                is_metrics_infra,
                cluster,
            ).count_assets_to_crawl(asset_filter)

        # Prometheus, only on the metrics infrastructure
        if "prometheus" in components_to_process and is_metrics_infra:
            count += PrometheusCrawler(
                cluster.settings.metrics.host.geturl(),
                cluster.settings.metrics.job,
                cluster,
            ).count_assets_to_crawl(asset_filter)
    return count


def _crawl_assets(
    clusters_to_process: list[str],
    components_to_process: list[str],
    clusters_cache: dict[str, DevelopmentHostCluster],
    dev_tmp_dir: Path,
    token: str,
    progress_bar: SimpleProgressBar | None,
    clusters_with_latest_api: set[str],
    asset_filter: list[str] | None = None,
) -> None:
    """Crawl the assets of every selected component on every infrastructure.

    Args:
        clusters_to_process: Names of the infrastructures to process.
        components_to_process: Names of the selected components.
        clusters_cache: Clusters indexed by infrastructure name.
        dev_tmp_dir: Temporary directory of the development environment.
        token: Token given to the gateway and agent crawlers.
        progress_bar: Progress bar forwarded to the crawlers, or None.
        clusters_with_latest_api: Infrastructures for which gateway and agent
            assets are crawled.
        asset_filter: Optional asset names forwarded to every crawler.
    """
    for infrastructure in clusters_to_process:
        cluster = clusters_cache[infrastructure]
        has_latest_api = infrastructure in clusters_with_latest_api
        # Metrics are crawled on a single infrastructure (agent and Prometheus).
        is_metrics_infra = infrastructure == METRICS_PREFERRED_INFRASTRUCTURE

        # slurmrestd: one crawler per discovered Slurm/API versions couple
        if "slurmrestd" in components_to_process:
            for slurm_version, api_version in cluster.discover_api_versions():
                logger.info(
                    "Crawling assets for Slurm %s API version %s",
                    slurm_version,
                    api_version,
                )
                SlurmrestdCrawler(
                    cluster, cluster.auth, slurm_version, api_version
                ).crawl_all_assets(progress_bar, asset_filter)

        if "gateway" in components_to_process:
            if not has_latest_api:
                logger.info(
                    "Skipping gateway for cluster %s "
                    "(does not support latest API version)",
                    infrastructure,
                )
            else:
                GatewayCrawler(
                    token,
                    cluster,
                    infrastructure,
                    dev_tmp_dir,
                ).crawl_all_assets(progress_bar, asset_filter)

        if "agent" in components_to_process:
            if not has_latest_api:
                logger.info(
                    "Skipping agent for cluster %s "
                    "(does not support latest API version)",
                    infrastructure,
                )
            else:
                AgentCrawler(
                    cluster.settings.service.port,
                    token,
                    is_metrics_infra,
                    cluster,
                ).crawl_all_assets(progress_bar, asset_filter)

        if "prometheus" in components_to_process:
            if not is_metrics_infra:
                logger.info(
                    "Skipping prometheus for cluster %s (not metrics infrastructure)",
                    infrastructure,
                )
            else:
                PrometheusCrawler(
                    cluster.settings.metrics.host.geturl(),
                    cluster.settings.metrics.job,
                    cluster,
                ).crawl_all_assets(progress_bar, asset_filter)


def _run_crawl(
    clusters_to_process: list[str],
    components_to_process: list[str],
    dev_host: DevelopmentHostClient,
    dev_tmp_dir: Path,
    token: str,
    asset_filter: list[str] | None = None,
) -> None:
    """Discover the clusters, count the assets, then crawl them.

    A progress bar is displayed during the crawl when standard output is a TTY
    and there is at least one asset to crawl.

    Args:
        clusters_to_process: Names of the infrastructures to process.
        components_to_process: Names of the selected components.
        dev_host: Client of the development host.
        dev_tmp_dir: Temporary directory of the development environment.
        token: Token given to the gateway and agent crawlers.
        asset_filter: Optional asset names forwarded to every crawler.
    """

    print("Discovering development clusters")
    # Every cluster is created only once and shared by the following steps.
    clusters_cache: dict[str, DevelopmentHostCluster] = {
        name: _create_cluster(name, dev_host, dev_tmp_dir)
        for name in clusters_to_process
    }

    # Gateway and agent are crawled only on clusters with the latest API version.
    clusters_with_latest_api = _get_clusters_with_latest_api(
        clusters_to_process, dev_host, dev_tmp_dir, clusters_cache
    )

    if any(name in components_to_process for name in ("gateway", "agent")):
        if clusters_with_latest_api:
            logger.info(
                "Clusters supporting latest API version %s: %s",
                SUPPORTED_SLURMRESTD_API_VERSIONS[0],
                sorted(clusters_with_latest_api),
            )
        else:
            logger.warning(
                "No clusters found supporting latest API version %s. "
                "Gateway and agent assets will not be crawled.",
                SUPPORTED_SLURMRESTD_API_VERSIONS[0],
            )

    print("Counting missing assets")
    total_assets = _count_total_assets(
        clusters_to_process,
        components_to_process,
        clusters_cache,
        dev_tmp_dir,
        token,
        clusters_with_latest_api,
        asset_filter,
    )

    # The progress bar requires a TTY and something to crawl.
    progress_bar = None
    if sys.stdout.isatty() and total_assets > 0:
        progress_bar = SimpleProgressBar(total_assets)

    print("Crawling assets")
    try:
        _crawl_assets(
            clusters_to_process,
            components_to_process,
            clusters_cache,
            dev_tmp_dir,
            token,
            progress_bar,
            clusters_with_latest_api,
            asset_filter,
        )
    finally:
        # Terminate the progress bar line even when the crawl fails.
        if progress_bar is not None:
            progress_bar.close()


def main() -> None:
    """Crawl and save test assets from Slurm-web gateway, agent and slurmrestd.

    Parse the command line arguments, locate the temporary directory of the
    development environment, select the clusters, components and assets to
    process, connect to the development host, then run the crawl.

    The process exits with status 1 when there is not exactly one development
    temporary directory, when an unknown cluster name is requested or when the
    connection to the development host fails. It exits with status 130 when the
    crawl is interrupted with Ctrl-C.
    """
    # Single definition shared by the argument parser and the default selection.
    all_components = ["slurmrestd", "gateway", "agent", "prometheus"]

    parser = argparse.ArgumentParser(
        description="Crawl and save test assets from Slurm-web components"
    )
    parser.add_argument(
        "--clusters",
        nargs="+",
        help="Restrict crawling to specific cluster names (default: all clusters)",
    )
    parser.add_argument(
        "--components",
        nargs="+",
        choices=all_components,
        help="Restrict crawling to specific components (default: all components)",
    )
    parser.add_argument(
        "--assets",
        nargs="+",
        help="Restrict crawling to specific asset names (default: all assets)",
    )
    args = parser.parse_args()

    # Setup logger
    setup_logger(
        debug=True,
        log_flags=["ALL"],
        debug_flags=DEBUG_FLAGS,
    )

    # Search for slurm-web development environment temporary directory
    dev_tmp_dirs = list(Path("/tmp").glob("slurm-web-*"))
    # Explicit check instead of an assert statement: assertions are removed when
    # Python runs with -O, which would result in an IndexError without any
    # directory or in an arbitrary choice with several directories.
    if len(dev_tmp_dirs) != 1:
        logger.error(
            "Unexpectedly found %d Slurm-web development temporary directories",
            len(dev_tmp_dirs),
        )
        sys.exit(1)
    dev_tmp_dir = dev_tmp_dirs[0]
    logger.info(
        "Slurm-web development environment temporary directory: %s", dev_tmp_dir
    )

    # Load cluster list from RacksDB database
    db = RacksDB.load(db="dev/firehpc/db", schema="../RacksDB/schemas/racksdb.yml")
    all_clusters = list(db.infrastructures.keys())
    logger.info("Available clusters: %s", all_clusters)

    # Filter clusters based on argument
    if args.clusters:
        invalid_clusters = [
            cluster for cluster in args.clusters if cluster not in all_clusters
        ]
        if invalid_clusters:
            logger.error("Invalid cluster names: %s", invalid_clusters)
            logger.error("Available clusters: %s", all_clusters)
            sys.exit(1)
        # All the requested clusters are valid at this stage, and there is at
        # least one of them.
        clusters_to_process = list(args.clusters)
        logger.info("Restricting to clusters: %s", clusters_to_process)
    else:
        clusters_to_process = all_clusters
        logger.info("Processing all clusters: %s", clusters_to_process)

    # Determine which components to process
    if args.components:
        components_to_process = args.components
        logger.info("Restricting to components: %s", components_to_process)
    else:
        components_to_process = all_components
        logger.info("Processing all components: %s", components_to_process)

    dev_host = DevelopmentHostClient(DEV_HOST, USER)
    try:
        dev_host.connect()
    except DevelopmentHostConnectionError as err:
        logger.error(err)
        sys.exit(1)

    # Get Slurm-web JWT for authentication on gateway and agent
    token = slurmweb_token(
        dev_host,
        slurmweb_cluster_name(GATEWAY_PREFERRED_INFRASTRUCTURE),
        GATEWAY_PREFERRED_INFRASTRUCTURE,
        dev_tmp_dir,
    )

    # Get asset filter from arguments
    asset_filter = args.assets if args.assets else None
    if asset_filter:
        logger.info("Restricting to assets: %s", asset_filter)

    try:
        _run_crawl(
            clusters_to_process,
            components_to_process,
            dev_host,
            dev_tmp_dir,
            token,
            asset_filter,
        )
    except KeyboardInterrupt:
        logger.warning("Interrupted by user (Ctrl-C)")
        # 130 is 128 + 2, 2 being the number of the SIGINT signal.
        sys.exit(130)

# Run the crawler only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
