#!/usr/bin/env python3
#
# Copyright (c) 2023-2024 Rackslab
#
# This file is part of Slurm-web.
#
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
import sshtunnel
import paramiko
import shlex
import logging
import subprocess
from pathlib import Path
import signal
import tempfile
import atexit
import shutil
import time
import argparse
import json
import socket
import sys
import getpass

import jinja2
from rfl.authentication.jwt import jwt_gen_key
from racksdb import RacksDB

# Module-level logger used by the helpers and classes of this script.
logger = logging.getLogger(__name__)

# Development server reached over SSH, both for running remote commands and
# for setting up the port forwards to the clusters.
DEV_HOST = "firehpc.dev.rackslab.io"
# Per-cluster settings, keyed by cluster name. Keys read by this script:
#   "auth": slurmrestd authentication method handed to the agent
#       configuration template; "jwt" additionally triggers retrieval of a
#       token or of the Slurm JWT signing key.
#   "scheme": how slurmrestd is reached. Defaults to "unix" when absent
#       (Unix socket relayed with socat); any other value makes the agent use
#       an http://localhost:<port> URI on the forwarded TCP port.
#   "mode": with "jwt" auth, "static" requests a static token with
#       `scontrol token` instead of downloading the signing key.
#   "racksdb", "metrics": booleans, both default to True when absent.
#   "policy_extra_user_group": boolean handed to the policy template,
#       defaults to False when absent.
#   "name": optional cluster name for the UI, defaults to the cluster name.
CLUSTERS = {
    "nova": {"auth": "jwt", "scheme": "http", "policy_extra_user_group": True},
    "zenith": {"auth": "jwt", "scheme": "http"},
    "quark": {"auth": "local", "racksdb": False, "metrics": False},
    "vortex": {"auth": "jwt", "scheme": "http", "metrics": False},
    "titan": {"auth": "jwt", "mode": "static"},
}
# Local login name, used as the SSH username on the development server.
USER = getpass.getuser()

# Values appended after the --debug-flags option of the launched agents and
# gateway.
DEBUG_FLAGS = ["slurmweb", "rfl", "werkzeug", "urllib3"]
# Alternative value kept at hand:
# DEBUG_FLAGS = ["ALL"]


def runcmd(cmd: list[str]) -> subprocess.Popen:
    """Start *cmd* as a child process without waiting for its completion.

    Returns the handle of the started process. When the executable cannot be
    found, an error is logged and the script exits with status 1.
    """
    try:
        child = subprocess.Popen(cmd)
    except FileNotFoundError:
        logger.error("Command not found: %s", shlex.join(cmd))
        sys.exit(1)
    else:
        return child


@dataclass
class PortForward:
    """Pair of TCP ports describing one service forwarded over SSH."""

    # Port of the service on the remote end of the tunnel.
    remote: int
    # Port bound on localhost that gives access to the remote service.
    local: int


class PortAllocator:
    """Hand out one set of local ports per cluster.

    Every call to allocate() shifts the base local ports by one more unit
    than the previous call, starting with a shift of *initial*. Remote ports
    are never shifted.
    """

    def __init__(self, initial=0):
        # allocate() increments before use, hence the start one step below.
        self.current = initial - 1
        # Base forwards, by service name: remote port and first local port.
        self.forwards = dict(
            ldap=PortForward(389, 3389),
            slurmrestd=PortForward(6820, 6820),
            redis=PortForward(6379, 6379),
            prometheus=PortForward(9090, 9090),
        )

    def allocate(self):
        """Return a new dict of forwards with the next local ports offset."""
        self.current = self.current + 1
        offset = self.current
        shifted = {}
        for service, base in self.forwards.items():
            shifted[service] = PortForward(base.remote, base.local + offset)
        return shifted


@dataclass
class ClusterChannel:
    """Network plumbing between the local host and one remote cluster.

    Holds the SSH tunnel forwarding the cluster services ports and, for
    clusters whose slurmrestd is reached through a Unix socket, the local
    socat process relaying that socket.
    """

    # Cluster name, as used in CLUSTERS and in remote host names.
    name: str
    # SSH connection to the development server, used to run remote commands.
    connection: paramiko.client.SSHClient
    # Started SSH tunnel forwarding the cluster ports.
    forwarder: sshtunnel.SSHTunnelForwarder
    # Local socat process, or None when slurmrestd is not reached through a
    # Unix socket.
    process: subprocess.Popen | None

    @classmethod
    def connect(
        cls,
        host: str,
        ssh_client: paramiko.client.SSHClient,
        forwards: dict[str, PortForward],
        cluster: str,
        cluster_id: int,
    ):
        """Open the tunnel to *cluster* and return the resulting channel.

        Args:
            host: SSH server through which the ports are forwarded.
            ssh_client: connected SSH client used to run remote commands.
            forwards: port forwards to set up, by service name. Must contain
                a "slurmrestd" entry when the cluster uses the Unix socket
                scheme.
            cluster: cluster name, must be a key of CLUSTERS.
            cluster_id: currently unused, kept for interface compatibility.
        """
        # The remote host name is the same for all forwards: compute it once.
        username = ssh_client.get_transport().get_username()
        remote_host = f"admin.{cluster}.{username}"
        remote_bind_addresses = [
            (remote_host, forward.remote) for forward in forwards.values()
        ]
        local_bind_addresses = [
            ("localhost", forward.local) for forward in forwards.values()
        ]
        forwarder = sshtunnel.SSHTunnelForwarder(
            host,
            ssh_username=USER,
            allow_agent=True,
            remote_bind_addresses=remote_bind_addresses,
            local_bind_addresses=local_bind_addresses,
        )
        forwarder.start()

        process = None
        if CLUSTERS[cluster].get("scheme", "unix") == "unix":
            # run socat to forward slurmrestd UNIX socket
            cmd = shlex.join(
                [
                    "firehpc",
                    "ssh",
                    f"admin.{cluster}",
                    "socat",
                    f"TCP-LISTEN:{forwards['slurmrestd'].remote},fork",
                    "UNIX-CONNECT:/run/slurmrestd/slurmrestd.socket",
                ]
            )
            logger.info("Running command on development server: %s", cmd)
            ssh_client.exec_command(cmd)

            # Local counterpart: expose the forwarded TCP port as a Unix
            # socket.
            cmd = [
                "socat",
                f"UNIX-LISTEN:/tmp/slurmrestd-{cluster}.socket,fork",
                f"TCP-CONNECT:localhost:{forwards['slurmrestd'].local}",
            ]
            logger.info("Running command locally: %s", shlex.join(cmd))
            process = runcmd(cmd)

        return cls(cluster, ssh_client, forwarder, process)

    def stop(self) -> None:
        """Stop the socat relays, if any, and the SSH tunnel."""
        # The remote socat is started together with the local one: there is
        # nothing to kill remotely when no local socat process exists.
        if self.process:
            logger.info("Stopping %s remote socat command", self.name)
            cmd = shlex.join(
                ["firehpc", "ssh", f"admin.{self.name}", "killall", "socat"]
            )
            self.connection.exec_command(cmd)
        logger.info("Stopping %s forwarder", self.name)
        self.forwarder.stop()
        if self.process:
            logger.info("Stopping %s socat process", self.name)
            self.process.kill()
            # Reap the killed child to avoid leaving a zombie process.
            self.process.wait()


@dataclass
class SlurmwebAgent:
    """One Slurm-web agent process with its channel and generated files."""

    # Cluster name, as used in CLUSTERS and in remote commands.
    name: str
    # Cluster name given to the agent configuration template.
    ui_name: str
    # Tunnel and relays to the cluster.
    channel: ClusterChannel
    # Port forwards allocated for this cluster, by service name.
    forwards: dict[str, PortForward]
    # Port given to the agent configuration template as service port.
    service_port: int
    # Path of the agent configuration file, set by render_conf().
    conf_path: Path | None = None
    # Path of the policy file, set by render_policy().
    policy_path: Path | None = None
    # Agent process, set by launch().
    process: subprocess.Popen | None = None
    # Decoded output of `firehpc status --json`, set by render_policy().
    cluster_status: dict | None = None

    @classmethod
    def init(
        cls,
        host: str,
        port_allocator: PortAllocator,
        ssh_client: paramiko.client.SSHClient,
        name: str,
        ui_name: str,
        cluster_id: int,
    ):
        """Allocate ports, open the cluster channel and return the agent.

        The agent service port is 5013 + *cluster_id*.
        """
        forwards = port_allocator.allocate()
        channel = ClusterChannel.connect(host, ssh_client, forwards, name, cluster_id)
        return cls(name, ui_name, channel, forwards, 5013 + cluster_id)

    def render_policy(self, tmpdir: Path, anonymous: bool) -> None:
        """Generate the policy file of this agent in *tmpdir*.

        The cluster status is retrieved with `firehpc status` on the
        development server and kept in self.cluster_status.

        Raises:
            json.decoder.JSONDecodeError: if the status output is not valid
                JSON.
        """
        environment = jinja2.Environment(loader=jinja2.FileSystemLoader("dev/conf/"))
        template = environment.get_template("policy.ini.j2")
        self.policy_path = tmpdir / f"policy-{self.name}.ini"
        _, stdout, _ = self.channel.connection.exec_command(
            shlex.join(["firehpc", "status", "--cluster", self.name, "--json"])
        )
        try:
            self.cluster_status = json.loads(stdout.read())
        except json.decoder.JSONDecodeError:
            logger.critical("Unable to load %s cluster status", self.name)
            raise
        with open(self.policy_path, "w") as fh:
            fh.write(
                template.render(
                    cluster=self.name,
                    users=self.cluster_status["users"],
                    groups=self.cluster_status["groups"],
                    anonymous=anonymous,
                    policy_extra_user_group=CLUSTERS[self.name].get(
                        "policy_extra_user_group", False
                    ),
                )
            )

    def _read_redis_password(self, sftp) -> str | None:
        """Return the remote Redis password, or None if its file is missing."""
        logger.info("Reading remote redis password for cluster %s", self.name)
        try:
            with sftp.open(
                f".local/state/firehpc/clusters/{self.name}/redis/redis.password"
            ) as fh:
                return fh.read().decode()
        except FileNotFoundError:
            logger.warning(
                "Remote Redis password file not found for cluster %s, force "
                "disabling cache",
                self.name,
            )
            return None

    def _request_static_token(self) -> str:
        """Return a static JWT generated with `scontrol token` on the cluster."""
        logger.info("Generating static JWT for cluster %s", self.name)
        # get infinite token for slurm user
        cmd = shlex.join(
            [
                "firehpc",
                "ssh",
                f"admin.{self.name}",
                "scontrol",
                "token",
                "lifespan=infinite",
                "username=slurm",
            ]
        )
        _, stdout, _ = self.channel.connection.exec_command(cmd)
        output = stdout.read().decode()
        try:
            # Keep everything after the first "=" and drop the end of line
            # so the token can be written as is in the configuration file.
            return output.split("=", 1)[1].strip()
        except IndexError:
            logger.critical("Unable to get static JWT on cluster %s", self.name)
            raise

    def _download_slurm_jwt_key(self, sftp, tmpdir: Path) -> Path | None:
        """Copy the remote Slurm JWT signing key in *tmpdir*.

        Returns the path of the local copy, or None if the remote key is
        missing.
        """
        logger.info(
            "Downloading remote Slurm JWT signing key for cluster %s", self.name
        )
        try:
            with sftp.open(
                f".local/state/firehpc/clusters/{self.name}/slurm/jwt_hs256.key"
            ) as fh:
                slurm_jwt_key = fh.read()
        except FileNotFoundError:
            logger.warning(
                "Remote Slurm JWT signature key not found for cluster %s",
                self.name,
            )
            return None
        local_key_path = tmpdir / f"{self.name}_jwt_hs256.key"
        with open(local_key_path, "wb") as fh:
            fh.write(slurm_jwt_key)
        return local_key_path

    def render_conf(
        self,
        tmpdir: Path,
        jwt_key_path: Path,
        cache_enabled: bool,
        racksdb_enabled: bool,
        metrics_enabled: bool,
    ) -> None:
        """Generate the configuration file of this agent in *tmpdir*.

        Args:
            tmpdir: directory in which generated files are written.
            jwt_key_path: path given to the template as JWT key.
            cache_enabled: whether the cache is requested; forced to False
                when the remote Redis password file is missing.
            racksdb_enabled: value given to the template.
            metrics_enabled: value given to the template.

        Raises:
            IndexError: if the output of `scontrol token` cannot be parsed
                (static JWT mode only).
        """
        environment = jinja2.Environment(loader=jinja2.FileSystemLoader("dev/conf/"))
        template = environment.get_template("agent.ini.j2")
        cluster_cfg = CLUSTERS[self.name]

        redis_password = None
        slurmrestd_token = None
        slurmrestd_jwt_key = None
        # The SFTP session is only needed to fetch remote secrets: make sure
        # it is closed afterwards, even on error.
        with self.channel.connection.open_sftp() as sftp:
            if cache_enabled:
                redis_password = self._read_redis_password(sftp)
                # The cache is force disabled only when the file is missing.
                cache_enabled = redis_password is not None

            if cluster_cfg["auth"] == "jwt":
                if cluster_cfg.get("mode") == "static":
                    slurmrestd_token = self._request_static_token()
                else:
                    slurmrestd_jwt_key = self._download_slurm_jwt_key(sftp, tmpdir)

        if cluster_cfg.get("scheme", "unix") == "unix":
            slurmrestd_uri = f"unix:///tmp/slurmrestd-{self.name}.socket"
        else:
            slurmrestd_uri = f"http://localhost:{self.forwards['slurmrestd'].local}"

        self.conf_path = tmpdir / f"agent-{self.name}.ini"
        with open(self.conf_path, "w") as fh:
            fh.write(
                template.render(
                    cluster_name=self.ui_name,
                    service_port=self.service_port,
                    slurmrestd_auth=cluster_cfg["auth"],
                    slurmrestd_uri=slurmrestd_uri,
                    slurmrestd_jwt_mode=cluster_cfg.get("mode"),
                    slurmrestd_jwt_key=slurmrestd_jwt_key,
                    slurmrestd_token=slurmrestd_token,
                    jwt_key=jwt_key_path,
                    cache_enabled=cache_enabled,
                    racksdb_enabled=racksdb_enabled,
                    policy_path=self.policy_path,
                    redis_port=self.forwards["redis"].local,
                    redis_password=redis_password,
                    prometheus_port=self.forwards["prometheus"].local,
                    infrastructure=self.name if self.ui_name != self.name else None,
                    metrics_enabled=metrics_enabled,
                )
            )

    def launch(self) -> None:
        """Start the agent process with the generated configuration file."""
        cmd = (
            [
                "slurm-web",
                "agent",
                "--log-component",
                f"agent-{self.name}",
                "--debug",
                "--debug-flags",
            ]
            + DEBUG_FLAGS
            + [
                "--conf-defs",
                "conf/vendor/agent.yml",
                "--conf",
                str(self.conf_path),
            ]
        )
        logger.info("Launching Slurm-web agent %s: %s", self.name, shlex.join(cmd))
        self.process = runcmd(cmd)

    def stop(self) -> None:
        """Kill the agent process, if launched, and stop the cluster channel."""
        logger.info("Stopping slurm-web agent %s", self.name)
        if self.process is not None:
            self.process.kill()
            # Reap the killed child to avoid leaving a zombie process.
            self.process.wait()
        self.channel.stop()


@dataclass
class SlurmwebGateway:
    """Slurm-web gateway process with its generated files."""

    # Path of the gateway configuration file, set by render_conf().
    conf_path: Path | None = None
    # URL path prefix given to the configuration template.
    gateway_prefix: str = "/"
    # Path of the service message file, set by render_conf().
    message_path: Path | None = None
    # Gateway process, set by launch().
    process: subprocess.Popen | None = None

    def render_conf(
        self,
        tmpdir: Path,
        jwt_key_path: Path,
        agents: list[SlurmwebAgent],
        ui: Path | None,
        anonymous: bool,
        gateway_prefix: str = "/",
    ):
        """Generate the gateway configuration file in *tmpdir*.

        The LDAP settings are taken from the first agent of *agents*.

        Raises:
            IndexError: if *agents* is empty.
        """
        self.gateway_prefix = gateway_prefix
        environment = jinja2.Environment(loader=jinja2.FileSystemLoader("dev/conf/"))
        template = environment.get_template("gateway.ini.j2")
        self.conf_path = tmpdir / "gateway.ini"
        self.message_path = tmpdir / "message.md"
        with open(self.conf_path, "w") as fh:
            fh.write(
                template.render(
                    service_port=5012,
                    jwt_key=jwt_key_path,
                    ldap_port=agents[0].forwards["ldap"].local,
                    ldap_base=agents[0].name,
                    agents=agents,
                    ui=ui,
                    message=self.message_path,
                    anonymous=anonymous,
                    gateway_prefix=gateway_prefix,
                )
            )

    def render_message(
        self,
        tmpdir: Path,
        users,
        groups,
    ):
        """Generate the service message file from *users* and *groups*."""
        # Same location as the one set by render_conf(), so this method also
        # works when it is called first.
        if self.message_path is None:
            self.message_path = tmpdir / "message.md"
        environment = jinja2.Environment(loader=jinja2.FileSystemLoader("dev/conf/"))
        template = environment.get_template("message.md.j2")
        with open(self.message_path, "w") as fh:
            fh.write(template.render(users=users, groups=groups))

    def launch(self):
        """Start the gateway process with the generated configuration file."""
        cmd = (
            [
                "slurm-web",
                "gateway",
                "--log-component",
                "gateway",
                "--debug",
                "--debug-flags",
            ]
            + DEBUG_FLAGS
            + [
                "--conf-defs",
                "conf/vendor/gateway.yml",
                "--conf",
                str(self.conf_path),
            ]
        )
        logger.info("Launching Slurm-web gateway: %s", shlex.join(cmd))
        self.process = runcmd(cmd)

    def stop(self):
        """Kill the gateway process, if it has been launched."""
        logger.info("Stopping slurm-web gateway")
        if self.process is not None:
            self.process.kill()
            # Reap the killed child to avoid leaving a zombie process.
            self.process.wait()


def delete_tmpdir(path: Path) -> None:
    """Remove the directory *path* with all its content.

    Registered with atexit by main() to discard the generated files.
    """
    # Use the module logger, like the rest of this script, rather than the
    # root logger.
    logger.info("Removing temporary directory %s", path)
    shutil.rmtree(path)


def main():
    """Launch the Slurm-web development environment and wait for ^c.

    Connects to the development server, starts one agent per selected
    cluster found in RacksDB database, then the gateway. Everything that has
    been started is stopped when the script is interrupted or when the
    startup fails.
    """
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description="Launch Slurm-web development environment"
    )
    parser.add_argument(
        "--no-cache",
        dest="cache",
        action="store_false",
        help="Disable redis cache",
    )
    parser.add_argument(
        "--with-ui",
        dest="ui",
        type=Path,
        help="Path to frontend application",
    )
    parser.add_argument(
        "--no-service-message",
        dest="service_message",
        action="store_false",
        help="Disable service message",
    )
    parser.add_argument(
        "--anonymous",
        action="store_true",
        help="Enable anonymous mode and disable authentication",
    )
    parser.add_argument(
        "--clusters",
        nargs="+",
        choices=list(CLUSTERS.keys()),
        default=list(CLUSTERS.keys()),
        help="Select specific clusters to launch (default: all clusters)",
    )
    parser.add_argument(
        "--no-metrics",
        dest="metrics",
        action="store_false",
        help="Disable metrics collection globally",
    )
    parser.add_argument(
        "--gateway-prefix",
        default="/",
        help=(
            "URL path prefix under which to serve the gateway application "
            "(default: %(default)s)"
        ),
    )
    args = parser.parse_args()

    # Load cluster list from RacksDB database
    db = RacksDB.load(db="dev/firehpc/db", schema="../RacksDB/schemas/racksdb.yml")

    # Connect to development host
    ssh_client = paramiko.SSHClient()
    ssh_client.load_host_keys(Path("~/.ssh/known_hosts").expanduser())
    logger.info("Connecting to development host %s", DEV_HOST)
    try:
        ssh_client.connect(DEV_HOST, username=USER)
    except socket.gaierror as err:
        logger.error("Unable to get address of %s: %s", DEV_HOST, str(err))
        sys.exit(1)
    except paramiko.ssh_exception.PasswordRequiredException as err:
        logger.error("Unable to connect on %s@%s: %s", USER, DEV_HOST, str(err))
        sys.exit(1)
    # Create temporary directory and register its automatic deletion at exit
    tmpdir = Path(tempfile.mkdtemp(prefix="slurm-web-dev"))
    atexit.register(delete_tmpdir, tmpdir)

    # Generate random JWT key
    jwt_key_path = tmpdir / "jwt.key"
    jwt_gen_key(jwt_key_path)

    agents = []
    # Only set once the gateway process is launched, so the cleanup below
    # never tries to stop a gateway that has not been started.
    gateway = None
    port_allocator = PortAllocator()
    try:
        # Launch all clusters agents
        for infrastructure in db.infrastructures:
            # Skip infrastructure if not in selected clusters
            if infrastructure.name not in args.clusters:
                continue

            cluster_cfg = CLUSTERS[infrastructure.name]
            agent = SlurmwebAgent.init(
                DEV_HOST,
                port_allocator,
                ssh_client,
                infrastructure.name,
                cluster_cfg.get("name", infrastructure.name),
                # Agents are numbered in launch order, starting from 0.
                len(agents),
            )
            # Generate agent configuration file
            agent.render_policy(tmpdir, args.anonymous)
            agent.render_conf(
                tmpdir,
                jwt_key_path,
                args.cache,
                cluster_cfg.get("racksdb", True),
                args.metrics and cluster_cfg.get("metrics", True),
            )
            # Launch agent in background
            agent.launch()
            agents.append(agent)

        # The gateway settings are derived from the first agent: there is
        # nothing to launch without at least one of them.
        if not agents:
            logger.error(
                "No cluster to launch among selected clusters: %s",
                ", ".join(args.clusters),
            )
            sys.exit(1)

        # Wait a second for all agents to start properly
        time.sleep(1)

        # Launch gateway
        starting_gateway = SlurmwebGateway()
        starting_gateway.render_conf(
            tmpdir, jwt_key_path, agents, args.ui, args.anonymous, args.gateway_prefix
        )
        if args.service_message:
            starting_gateway.render_message(
                tmpdir,
                agents[0].cluster_status["users"],
                agents[0].cluster_status["groups"],
            )
        starting_gateway.launch()
        gateway = starting_gateway

        logger.info("Development environment is ready, type ^c to stop")
        signal.pause()
    except KeyboardInterrupt:
        logger.info("Interrupted, stopping development environment")
    finally:
        # Run the cleanup on errors as well, so launched processes do not
        # outlive this script.
        # Stop gateway
        if gateway is not None:
            gateway.stop()
        # Stop all cluster agents
        for agent in agents:
            agent.stop()
        logger.info("Closing SSH connection")
        ssh_client.close()
        logger.info("Development environment stopped")


# Launch the development environment only when this file is run as a script.
if __name__ == "__main__":
    main()
