#!/usr/bin/python3
"""Flask WSGI dashboard for slurm-quota stats."""

from __future__ import annotations

import json
import os
from http.client import HTTPException
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode, urljoin
from urllib.request import Request, urlopen

from flask import Flask, render_template, request


def _assets_root() -> Path:
    """Locate the directory holding the dashboard's templates and static files.

    Lookup order:
      1. ``SLURM_QUOTA_WEB_ASSETS_DIR``, when set to a non-empty value;
      2. a ``webapp`` directory next to this module, if it exists;
      3. the fixed fallback ``/usr/share/slurm-quota-web``.
    """
    override = os.environ.get("SLURM_QUOTA_WEB_ASSETS_DIR")
    if not override:
        bundled = Path(__file__).resolve().parent / "webapp"
        return bundled if bundled.exists() else Path("/usr/share/slurm-quota-web")
    return Path(override)


# Asset directories are resolved once, at import time, and baked into the
# Flask app below; changing SLURM_QUOTA_WEB_ASSETS_DIR afterwards does not
# affect this app object.
_WEB_ROOT: Path = _assets_root()
app = Flask(
    __name__,
    template_folder=str(_WEB_ROOT / "templates"),
    static_folder=str(_WEB_ROOT / "static"),
)
# WSGI servers (uWSGI, Gunicorn, etc.) default to a callable named "application".
application = app


def _stats_url(username: Optional[str], account: Optional[str]) -> str:
    """Build the URL of the quota service's ``stats`` endpoint.

    The service base URL comes from ``SLURM_QUOTA_URL``; when it is unset or
    empty, ``http://127.0.0.1:9911/`` is used. ``stats`` is resolved relative
    to that base, so a base with a path prefix such as ``http://host/quota``
    yields ``http://host/quota/stats``.

    Args:
        username: optional user filter, sent as the ``username`` query param.
        account: optional account filter, sent as the ``account`` query param.

    Returns:
        The stats URL, with a query string only when at least one filter is set.
    """
    # ``or`` (not a .get default) so that an empty variable also falls back
    # instead of producing the relative, unopenable URL "stats".
    base_url = os.environ.get("SLURM_QUOTA_URL") or "http://127.0.0.1:9911/"
    # urljoin() replaces the last path segment of a base lacking a trailing
    # slash ("http://host/quota" + "stats" -> "http://host/stats"), so
    # normalise the base to a directory-style URL first.
    if not base_url.endswith("/"):
        base_url += "/"
    stats_url = urljoin(base_url, "stats")

    params: Dict[str, str] = {}
    if username:
        params["username"] = username
    if account:
        params["account"] = account
    return f"{stats_url}?{urlencode(params)}" if params else stats_url


def _parse_int(value: Any, default: int = 0) -> int:
    """Coerce *value* to ``int``, returning *default* when it cannot be converted.

    Accepts anything ``int()`` accepts (ints, floats truncated toward zero,
    integral strings). ``None``, non-numeric strings and non-finite floats
    yield *default*. ``json`` can produce the non-finite floats ``inf`` and
    ``nan`` from ``Infinity`` and ``NaN``; ``int(inf)`` raises OverflowError
    rather than ValueError, hence the third exception type.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _format_minutes(value_minutes: int, display_hours: bool) -> str:
    """Render a minute count as text, optionally converted to hours.

    Hours are shown with exactly two decimals; minutes are shown unchanged.
    """
    if not display_hours:
        return str(value_minutes)
    hours = value_minutes / 60
    return format(hours, ".2f")


def _quota_label(quota: int, display_hours: bool) -> str:
    """Return the display text for a quota; any negative quota renders as "∞"."""
    is_unlimited = quota < 0
    return "∞" if is_unlimited else _format_minutes(quota, display_hours)


def _usage_percent(consumed: int, preallocated: int, quota: int) -> Optional[float]:
    """Percentage of *quota* taken up by consumed plus preallocated minutes.

    The result is capped at 100.0. Returns ``None`` when *quota* is not
    positive, so there is nothing to divide by.
    """
    # NOTE(review): a quota of exactly 0 lands here as None (rendered as an
    # "unlimited" bar), while _quota_label shows it as "0" rather than "∞".
    # Confirm whether 0 means "unlimited" or "nothing allowed".
    if quota > 0:
        used = consumed + preallocated
        return min(used / quota * 100.0, 100.0)
    return None


def _status_class(percent: Optional[float]) -> str:
    """Map a usage percentage to the ``bar-*`` class name for its usage bar.

    ``None`` -> ``bar-unlimited``; >= 95 -> ``bar-danger``;
    >= 80 -> ``bar-warning``; anything lower -> ``bar-ok``.
    """
    if percent is None:
        return "bar-unlimited"
    # Checked from most to least severe; the first matching threshold wins.
    for threshold, css_class in ((95.0, "bar-danger"), (80.0, "bar-warning")):
        if percent >= threshold:
            return css_class
    return "bar-ok"


def _resource_summary(
    item: Dict[str, Any], resource: str, display_hours: bool
) -> Dict[str, Any]:
    """Build the display dict for one resource (``"cpu"`` or ``"gpu"``) of a row.

    Reads ``total_consumed_<resource>_minutes``,
    ``total_preallocated_<resource>_minutes`` and ``quota_<resource>_minutes``
    from *item*.
    """
    consumed = _parse_int(item.get(f"total_consumed_{resource}_minutes"))
    preallocated = _parse_int(item.get(f"total_preallocated_{resource}_minutes"))
    # A missing or unparseable quota becomes -1, which _quota_label renders as
    # "∞" and _usage_percent maps to None (an "unlimited" bar).
    quota = _parse_int(item.get(f"quota_{resource}_minutes"), -1)
    percent = _usage_percent(consumed, preallocated, quota)
    return {
        "consumed": _format_minutes(consumed, display_hours),
        "preallocated": _format_minutes(preallocated, display_hours),
        "quota": _quota_label(quota, display_hours),
        "percent": percent,
        "status_class": _status_class(percent),
    }


def _decorate_rows(
    rows: List[Dict[str, Any]], name_key: str, display_hours: bool
) -> List[Dict[str, Any]]:
    """Turn raw stats rows from the service into template-ready dicts.

    Args:
        rows: stats rows as returned by ``fetch_stats``.
        name_key: key holding the row's display name (e.g. ``"username"``).
        display_hours: render minute values as hours when true.

    Returns:
        One dict per input row with ``name``, ``job_count``, ``last_updated``
        and per-resource ``cpu`` / ``gpu`` summaries.
    """
    decorated: List[Dict[str, Any]] = []
    for item in rows:
        name = item.get(name_key)
        decorated.append(
            {
                # A present-but-null name would otherwise render as "None".
                "name": "?" if name is None else str(name),
                "job_count": _parse_int(item.get("job_count")),
                "last_updated": item.get("last_updated") or "n/a",
                "cpu": _resource_summary(item, "cpu", display_hours),
                "gpu": _resource_summary(item, "gpu", display_hours),
            }
        )
    return decorated


def _payload_rows(payload: Dict[str, Any], key: str) -> List[Dict[str, Any]]:
    """Return ``payload[key]`` as a list of row dicts; missing or null gives ``[]``.

    Raises:
        RuntimeError: if the value is present but not a JSON array of objects.
    """
    rows = payload.get(key)
    if rows is None:
        return []
    if not isinstance(rows, list) or not all(isinstance(row, dict) for row in rows):
        raise RuntimeError(f"unexpected response: {key!r} must be a list of objects")
    return list(rows)


def fetch_stats(
    username: Optional[str], account: Optional[str]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Query the quota service's ``stats`` endpoint and return its rows.

    Args:
        username: optional user filter forwarded to the service.
        account: optional account filter forwarded to the service.

    Returns:
        ``(users, accounts)``: the ``users`` and ``accounts`` arrays from the
        JSON response, each an empty list when absent or null.

    Raises:
        URLError / HTTPError: as raised by ``urlopen`` (e.g. connection
            failures, non-2xx responses).
        RuntimeError: for a non-200 status, a transport error while reading
            the response, a malformed service URL, a body that is not valid
            JSON, or a JSON payload of the wrong shape.
    """
    url = _stats_url(username, account)
    try:
        req = Request(url, headers={"Accept": "application/json"})
        with urlopen(req, timeout=5) as resp:
            status = getattr(resp, "status", 200)
            if status != 200:
                raise RuntimeError(f"HTTP {status}")
            payload = json.load(resp)
    except URLError:
        # Already a type callers handle (HTTPError is a subclass); keep as-is.
        raise
    except (OSError, HTTPException) as exc:
        # urllib wraps only errors raised while sending the request in
        # URLError. A read timeout or a dropped connection while receiving the
        # status line or body surfaces as a bare OSError / http.client error.
        raise RuntimeError(f"error reading response: {exc}") from exc
    except ValueError as exc:
        # Invalid JSON, undecodable bytes, or "unknown url type" from Request.
        raise RuntimeError(f"invalid response: {exc}") from exc
    if not isinstance(payload, dict):
        raise RuntimeError("unexpected response: expected a JSON object")
    return _payload_rows(payload, "users"), _payload_rows(payload, "accounts")


def _display_hours_requested(args: Any) -> bool:
    """Decide from query args whether to show hours instead of minutes.

    ``unit=hours|h`` or ``unit=minutes|m|min`` (case-insensitive) wins.
    Otherwise the legacy ``hours=1|true|yes|on`` flag is honoured. The
    default is minutes.
    """
    unit = (args.get("unit") or "").strip().lower()
    if unit in ("hours", "h"):
        return True
    if unit in ("minutes", "m", "min"):
        return False
    # Legacy query param from older dashboard links.
    return (args.get("hours") or "").lower() in {"1", "true", "yes", "on"}


@app.get("/")
def dashboard() -> str:
    """Render the quota dashboard, optionally filtered by user or account.

    Query params: ``username`` and ``account`` (mutually exclusive filters),
    plus the unit selectors read by ``_display_hours_requested``. Problems
    reaching or decoding the stats service are reported in the page through
    ``error`` instead of failing the request.
    """
    username = (request.args.get("username") or "").strip() or None
    account = (request.args.get("account") or "").strip() or None
    display_hours = _display_hours_requested(request.args)

    error: Optional[str] = None
    users: List[Dict[str, Any]] = []
    accounts: List[Dict[str, Any]] = []

    if username and account:
        error = "username and account filters are mutually exclusive."
    else:
        try:
            users_raw, accounts_raw = fetch_stats(username, account)
        except (OSError, HTTPException, ValueError, RuntimeError) as exc:
            # OSError covers URLError/HTTPError as well as raw socket errors
            # such as read timeouts; ValueError covers an undecodable body.
            error = f"Failed to retrieve stats from service: {exc}"
        else:
            users = _decorate_rows(users_raw, "username", display_hours)
            accounts = _decorate_rows(accounts_raw, "account", display_hours)

    return render_template(
        "dashboard.html",
        error=error,
        users=users,
        accounts=accounts,
        selected_username=username or "",
        selected_account=account or "",
        display_hours=display_hours,
        unit_label="hours" if display_hours else "minutes",
    )


if __name__ == "__main__":
    # Direct execution runs Flask's built-in server via app.run(); when this
    # module is imported (e.g. by a WSGI server) this branch does not run.
    run_options = dict(
        host=os.environ.get("SLURM_QUOTA_WEB_HOST", "127.0.0.1"),
        port=_parse_int(os.environ.get("SLURM_QUOTA_WEB_PORT"), 5000),
        # Only the exact string "1" enables debug mode.
        debug=os.environ.get("SLURM_QUOTA_WEB_DEBUG") == "1",
    )
    app.run(**run_options)
