#!/usr/libexec/platform-python
"""
45d-firmware-check

Discovers installed firmware, fetches the manifest from repo.45drives.com,
compares versions, and writes a cache file at /var/cache/45drives/firmware.json.

Usage:
    45d-firmware-check              # run full check, write cache
    45d-firmware-check --stdout     # print result to stdout (don't write cache)
    45d-firmware-check --discover   # just run discovery, no manifest comparison

Exit codes:
    0 - success
    1 - discovery script failed
    2 - manifest fetch failed (cache written with partial data)
"""

import json
import os
import re
import subprocess
import sys
import tempfile
import urllib.request
import urllib.error
from datetime import datetime, timezone

###############################################################################
# Configuration
###############################################################################

REPO_CONF = "/usr/share/45drives/firmware/repo.conf"
GPG_PUBKEY = "/usr/share/45drives/firmware/45drives-firmware.gpg"
CACHE_DIR = "/var/cache/45drives"
CACHE_FILE = os.path.join(CACHE_DIR, "firmware.json")
LOG_DIR = "/var/log/45drives"
LOG_FILE = os.path.join(LOG_DIR, "firmware-check.log")
DISCOVER_SCRIPT = "/usr/share/cockpit/45drives-system/scripts/firmware-discover"
MANIFEST_FETCH_TIMEOUT = 30  # seconds


def get_config():
    """Read firmware repo configuration from repo.conf.

    The file read is REPO_CONF unless the FIRMWARE_REPO_CONF environment
    variable is set to a non-empty path. Format: one ``KEY=value`` per line;
    keys are case-insensitive, blank lines and lines starting with ``#`` are
    ignored, and a value may be wrapped in matching single or double quotes.

    Returns dict with keys: repo_url, manifest_url, verify_gpg.
    Falls back to production defaults (fw.45drives.com) for every setting the
    file does not provide, and for all settings if the file is missing or
    unreadable (a warning is printed for the unreadable case).

    VERIFY_GPG fails closed: only an explicit false/0/no/off disables
    signature checking. true/1/yes/on enable it, and any other value (including
    an empty one) keeps it enabled, with a warning on stderr.
    """
    config = {
        "repo_url": "https://fw.45drives.com",
        "manifest_url": "https://fw.45drives.com/manifest.json",
        "verify_gpg": True,
    }
    # A non-empty FIRMWARE_REPO_CONF overrides the packaged config file.
    conf_path = os.environ.get("FIRMWARE_REPO_CONF") or REPO_CONF
    try:
        with open(conf_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except FileNotFoundError:
        return config
    except (OSError, UnicodeDecodeError) as e:
        # Unreadable config: keep the safe defaults (GPG on) rather than crash.
        print(f"Warning: could not read {conf_path} ({e}) — using defaults", file=sys.stderr)
        return config

    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, val = line.partition("=")
        key = key.strip().lower()
        val = val.strip()
        # Accept shell-style quoting: KEY="value" or KEY='value'.
        if len(val) >= 2 and val[0] == val[-1] and val[0] in "\"'":
            val = val[1:-1]

        if key in ("repo_url", "manifest_url"):
            config[key] = val
        elif key == "verify_gpg":
            flag = val.strip().lower()
            if flag in ("false", "0", "no", "off"):
                config["verify_gpg"] = False
            else:
                # Fail closed: anything that is not an explicit "off" keeps
                # signature verification enabled.
                if flag not in ("true", "1", "yes", "on"):
                    print(
                        f"Warning: unrecognised VERIFY_GPG value {val!r} in {conf_path}"
                        " — keeping GPG verification enabled",
                        file=sys.stderr,
                    )
                config["verify_gpg"] = True
    return config

###############################################################################
# Discovery
###############################################################################

def run_discover():
    """Run the firmware discovery script and return its parsed JSON device list.

    The script is looked up at DISCOVER_SCRIPT, falling back to a
    ``firmware-discover`` file next to this program (dev mode). It is run with
    the same Python interpreter that is running this program.

    Returns:
        list of dict: the device records printed by the discovery script
        (non-object entries are dropped with a warning).

    Exits the process with status 1, after printing a diagnostic to stderr, if
    the script is missing, times out, fails, or prints anything other than a
    JSON list.
    """
    if os.path.isfile(DISCOVER_SCRIPT):
        script = DISCOVER_SCRIPT
    else:
        # Fallback: try relative path (dev mode)
        script_dir = os.path.dirname(os.path.abspath(__file__))
        script = os.path.join(script_dir, "firmware-discover")
        if not os.path.isfile(script):
            print(f"Error: discovery script not found at {DISCOVER_SCRIPT}", file=sys.stderr)
            sys.exit(1)

    # Use our own interpreter: the shebang runs this file under
    # /usr/libexec/platform-python, and a bare "python3" on PATH may be
    # missing or a different interpreter.
    interpreter = sys.executable or "python3"

    try:
        result = subprocess.run(
            [interpreter, script],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=120,
        )
    except subprocess.TimeoutExpired:
        print("Error: discovery script timed out", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error running discovery: {e}", file=sys.stderr)
        sys.exit(1)

    if result.returncode != 0:
        print(f"Error: discovery script failed (exit {result.returncode})", file=sys.stderr)
        if result.stderr:
            print(result.stderr, file=sys.stderr)
        sys.exit(1)

    try:
        devices = json.loads(result.stdout)
    except json.JSONDecodeError as e:
        print(f"Error: discovery output is not valid JSON: {e}", file=sys.stderr)
        print(f"Output was: {result.stdout[:500]}", file=sys.stderr)
        sys.exit(1)

    # Callers iterate the result and call .get() on each item, so anything
    # other than a list of objects would crash later with a bare traceback.
    if not isinstance(devices, list):
        print(f"Error: discovery output is not a JSON list (got {type(devices).__name__})", file=sys.stderr)
        sys.exit(1)

    records = [d for d in devices if isinstance(d, dict)]
    if len(records) != len(devices):
        print(
            f"Warning: ignoring {len(devices) - len(records)} non-object entries in discovery output",
            file=sys.stderr,
        )
    return records

###############################################################################
# GPG signature verification
###############################################################################

def verify_gpg_signature(data_bytes, sig_bytes, pubkey_path):
    """Verify a detached GPG signature using ``gpgv`` only (no gpg dependency).

    The payload and signature are written to temporary files and checked with
    ``gpgv --keyring <pubkey_path> <sig> <data>``. The temporary files are
    always removed afterwards.

    Args:
        data_bytes: the signed payload.
        sig_bytes: the detached signature over ``data_bytes``.
        pubkey_path: keyring file holding the trusted public key.

    Returns:
        True only when gpgv exits with status 0. Any failure (bad signature,
        gpgv not installed, timeout, temp-file error) returns False and is
        reported on stderr.
    """
    data_path = None
    sig_path = None
    try:
        # gpgv treats a --keyring name that has no slash as living in its home
        # directory (~/.gnupg), not the CWD; an absolute path pins this file.
        keyring = os.path.abspath(pubkey_path)

        # Record each temp name before writing so a failed write still gets
        # cleaned up in the finally block.
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as data_f:
            data_path = data_f.name
            data_f.write(data_bytes)
        with tempfile.NamedTemporaryFile(suffix=".sig", delete=False) as sig_f:
            sig_path = sig_f.name
            sig_f.write(sig_bytes)

        result = subprocess.run(
            ["gpgv", "--keyring", keyring, sig_path, data_path],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=10,
        )
        if result.returncode != 0:
            # Surface gpgv's own explanation to help diagnose rejections.
            detail = (result.stderr or "").strip()
            if detail:
                print(f"GPG: gpgv exit {result.returncode}: {detail}", file=sys.stderr)
            return False
        return True
    except Exception as e:  # includes TimeoutExpired and FileNotFoundError (no gpgv)
        print(f"GPG: Verification error: {e}", file=sys.stderr)
        return False
    finally:
        for p in (data_path, sig_path):
            if p is None:
                continue
            try:
                os.unlink(p)
            except OSError:
                pass


###############################################################################
# Manifest fetch
###############################################################################

def _read_capped(resp, limit):
    """Read *resp* to EOF, stopping as soon as more than *limit* bytes arrive.

    Returns at most ``limit + 1`` bytes, so the caller can detect an oversized
    body with ``len(result) > limit`` without ever buffering all of it.
    """
    chunks = []
    total = 0
    while total <= limit:
        chunk = resp.read(min(64 * 1024, limit + 1 - total))
        if not chunk:
            break
        chunks.append(chunk)
        total += len(chunk)
    return b"".join(chunks)


def fetch_manifest(config):
    """Fetch the firmware manifest from the configured URL with GPG verification.

    Args:
        config: dict as returned by get_config(); "manifest_url" and
            "verify_gpg" are used.

    Returns:
        The decoded manifest (a dict whose "components" value is a dict), or
        None if it could not be fetched, is too large, fails GPG verification,
        or is malformed. Failures are reported on stderr, never raised.
    """
    # Download caps. Reads stop at limit + 1 bytes, so an oversized (or
    # hostile) response cannot exhaust memory before it is rejected.
    max_manifest_bytes = 10 * 1024 * 1024
    max_sig_bytes = 64 * 1024
    headers = {"User-Agent": "45d-firmware-check/1.0"}

    manifest_url = config["manifest_url"]
    verify_gpg = config["verify_gpg"]

    try:
        req = urllib.request.Request(manifest_url, headers=headers)
        with urllib.request.urlopen(req, timeout=MANIFEST_FETCH_TIMEOUT) as resp:
            manifest_bytes = _read_capped(resp, max_manifest_bytes)

        # Enforce the size limit before any processing (10 MB)
        if len(manifest_bytes) > max_manifest_bytes:
            print("Warning: manifest too large (>10MB) — rejecting", file=sys.stderr)
            return None

        # GPG verification BEFORE JSON decode
        if not verify_gpg:
            print("GPG: Verification disabled via VERIFY_GPG=false", file=sys.stderr)
        else:
            gpg_key_path = os.environ.get("FIRMWARE_GPG_PUBKEY") or GPG_PUBKEY

            if not os.path.isfile(gpg_key_path):
                print(f"GPG: Public key not found at {gpg_key_path} — rejecting remote manifest!", file=sys.stderr)
                return None

            # Fetch the detached signature
            sig_url = manifest_url + ".sig"
            try:
                sig_req = urllib.request.Request(sig_url, headers=headers)
                with urllib.request.urlopen(sig_req, timeout=10) as sig_resp:
                    sig_bytes = _read_capped(sig_resp, max_sig_bytes)
            except OSError as e:  # URLError/HTTPError and socket timeouts
                print(f"GPG: Signature not available at {sig_url} ({e}) — rejecting remote manifest!", file=sys.stderr)
                return None

            if len(sig_bytes) > max_sig_bytes:
                print(f"GPG: Signature at {sig_url} exceeds {max_sig_bytes} bytes — rejecting remote manifest!", file=sys.stderr)
                return None

            # Verify — MUST pass before we trust the content
            if verify_gpg_signature(manifest_bytes, sig_bytes, gpg_key_path):
                print("GPG: Manifest signature verified ✓", file=sys.stderr)
            else:
                print("GPG: Manifest signature INVALID — rejecting remote manifest!", file=sys.stderr)
                return None

        # Only decode JSON after signature verification passes (or is disabled)
        data = json.loads(manifest_bytes.decode("utf-8"))
        # Downstream code does manifest["components"].get(...), so both levels
        # must be JSON objects.
        if not isinstance(data, dict):
            print("Warning: manifest is not a JSON object", file=sys.stderr)
            return None
        if "components" not in data:
            print("Warning: manifest missing 'components' key", file=sys.stderr)
            return None
        if not isinstance(data["components"], dict):
            print("Warning: manifest 'components' is not a JSON object", file=sys.stderr)
            return None

        return data

    except urllib.error.URLError as e:
        print(f"Warning: could not fetch manifest: {e}", file=sys.stderr)
        return None
    except json.JSONDecodeError as e:
        print(f"Warning: manifest is not valid JSON: {e}", file=sys.stderr)
        return None
    except UnicodeDecodeError as e:
        print(f"Warning: manifest is not valid UTF-8: {e}", file=sys.stderr)
        return None
    except Exception as e:
        print(f"Warning: manifest fetch error: {e}", file=sys.stderr)
        return None

###############################################################################
# Version comparison
###############################################################################

def normalize_version(ver_str):
    """Turn a firmware version into a tuple of ints for comparison.

    A leading "v"/"V" or "FW" (optionally followed by ":" / spaces) is removed,
    the remainder is split on ".", "-" and "_", and each piece is reduced to
    its digits (a piece without digits becomes 0), e.g. "v1.2-rc3" -> (1, 2, 3).

    Args:
        ver_str: the version. Non-string values (e.g. a JSON number) are
            converted with str(), and surrounding whitespace is ignored.

    Returns:
        A tuple of ints, or None when the version is missing or empty, is a
        placeholder ("unknown", "n/a", "-"), or contains no digits at all.
    """
    if ver_str is None:
        return None
    text = str(ver_str).strip()
    if not text or text.lower() in ("unknown", "n/a", "-"):
        return None
    # Remove common prefixes like 'v', 'V', 'FW'
    text = re.sub(r'^[vV]|^FW[: ]*', '', text).strip()
    # Without any digit there is nothing to compare; report it as unknown
    # rather than as version 0 (which would always look outdated).
    if not re.search(r'[0-9]', text):
        return None
    # Split on dots, dashes, underscores; keep only the digits of each piece
    parts = []
    for piece in re.split(r'[.\-_]', text):
        digits = re.sub(r'[^0-9]', '', piece)
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def version_compare(current, latest):
    """
    Compare an installed firmware version with the newest known one.

    Both inputs go through normalize_version(); the shorter tuple is padded
    with trailing zeros, so "1.2" and "1.2.0" compare equal.

    Returns:
        'current'    - current >= latest
        'outdated'   - current < latest
        'unknown'    - either version could not be normalized
    """
    have = normalize_version(current)
    want = normalize_version(latest)
    if have is None or want is None:
        return "unknown"

    width = max(len(have), len(want))
    have += (0,) * (width - len(have))
    want += (0,) * (width - len(want))
    return "outdated" if have < want else "current"

###############################################################################
# Matching
###############################################################################

def find_manifest_entry(device, manifest):
    """Find the manifest entry that matches a discovered device.

    The device's "type" selects a category in manifest["components"], and the
    device's "model" is searched case-insensitively with each entry's
    "model_match" regex; the first entry that matches wins.

    Args:
        device: discovered device dict.
        manifest: decoded manifest dict, or None.

    Returns:
        The matching entry dict, or None when there is no usable manifest, no
        category for the device type, or no entry matches. Malformed entries
        (non-objects, empty or invalid patterns) are skipped, with a warning
        for invalid regexes, so one bad entry cannot abort the whole check.
    """
    if not manifest or "components" not in manifest:
        return None
    components = manifest["components"]
    if not isinstance(components, dict):
        return None

    device_type = device.get("type", "")
    model = device.get("model", "")
    # Discovery may report a null or numeric model; re.search needs a str.
    model = "" if model is None else str(model)

    # Look in the matching component category
    entries = components.get(device_type, [])
    if not isinstance(entries, list):
        return None

    for entry in entries:
        if not isinstance(entry, dict):
            continue
        pattern = entry.get("model_match", "")
        if not pattern or not isinstance(pattern, str):
            continue
        try:
            if re.search(pattern, model, re.IGNORECASE):
                return entry
        except re.error as e:
            print(f"Warning: invalid model_match pattern {pattern!r}: {e}", file=sys.stderr)

    return None

###############################################################################
# Main logic
###############################################################################

def build_cache(devices, manifest, manifest_url=""):
    """Build the cache structure with comparison results."""
    results = []

    for device in devices:
        entry = find_manifest_entry(device, manifest)

        result = {
            "type": device.get("type", ""),
            "model": device.get("model", ""),
            "vendor": device.get("vendor", ""),
            "current_firmware": device.get("firmware", ""),
            "device_path": device.get("device_path", ""),
            "latest_firmware": None,
            "update_available": "unknown",
            "flash_tool": None,
            "flashable": False,
            "requires_reboot": False,
            "release_notes": "",
            "release_date": "",
        }

        if entry:
            result["latest_firmware"] = entry.get("latest_firmware")
            result["flash_tool"] = entry.get("flash_tool")
            result["flash_command"] = entry.get("flash_command", "")
            result["firmware_file"] = entry.get("firmware_file", "")
            result["sha256"] = entry.get("sha256", "")
            result["flashable"] = entry.get("flashable", False)
            result["requires_reboot"] = entry.get("requires_reboot", False)
            result["release_notes"] = entry.get("release_notes", "")
            result["release_date"] = entry.get("release_date", "")
            result["update_available"] = version_compare(
                device.get("firmware", ""),
                entry.get("latest_firmware", "")
            )

            # Look up flash tool binary SHA256 from manifest.flash_tools[]
            flash_tool = entry.get("flash_tool", "")
            tool_binary_sha256 = ""
            if flash_tool and manifest:
                for ft in manifest.get("flash_tools", []):
                    if ft.get("name", "").lower() == flash_tool.lower():
                        tool_binary_sha256 = ft.get("sha256", "")
                        break
            result["tool_binary_sha256"] = tool_binary_sha256

        results.append(result)

    cache = {
        "schema_version": "1.0",
        "last_checked": datetime.now(timezone.utc).isoformat(),
        "manifest_url": manifest_url,
        "manifest_available": manifest is not None,
        "device_count": len(results),
        "updates_available": sum(1 for r in results if r["update_available"] == "outdated"),
        "devices": results,
    }

    return cache


def write_cache(cache):
    """Write *cache* to CACHE_FILE as indented JSON, atomically.

    The JSON goes to a uniquely named temporary file inside CACHE_DIR. That file
    is flushed and fsync'ed, then renamed over CACHE_FILE with os.replace(). A
    reader therefore never sees a half-written cache, and two concurrent runs
    cannot clobber each other's temporary file.

    Raises:
        RuntimeError: if any filesystem step fails, including creating
            CACHE_DIR; the underlying OSError is chained as ``__cause__``.
    """
    tmp_path = None
    try:
        os.makedirs(CACHE_DIR, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(prefix=".firmware.json.", suffix=".tmp", dir=CACHE_DIR)
        with os.fdopen(fd, "w") as f:
            json.dump(cache, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        # mkstemp creates the file with mode 0600. Widen it to 0644 (what a
        # plain open() gives under the usual 022 umask) so non-root readers
        # can still read the cache.
        os.chmod(tmp_path, 0o644)
        os.replace(tmp_path, CACHE_FILE)
        tmp_path = None  # renamed into place; nothing left to clean up
    except OSError as e:
        raise RuntimeError(f"Error writing cache: {e}") from e
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def log_check(message):
    """Best-effort: append one UTC-timestamped line to LOG_FILE.

    Every failure (missing permissions, read-only filesystem, ...) is
    swallowed, so logging can never make a firmware check fail.
    """
    try:
        os.makedirs(LOG_DIR, mode=0o755, exist_ok=True)
        stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
        entry = f"{stamp}  {message}\n"
        with open(LOG_FILE, "a") as log_fh:
            log_fh.write(entry)
    except Exception:
        pass


def main():
    """Command-line entry point.

    Discovers devices and then, unless ``--discover`` was given, fetches the
    manifest and builds the comparison cache. Flags are recognised anywhere in
    sys.argv:

      --discover  print the raw discovery JSON and stop (takes precedence)
      --stdout    print the cache JSON instead of writing CACHE_FILE

    Exits with status 1 if the cache file cannot be written.
    """
    argv = sys.argv
    discover_only = "--discover" in argv
    stdout_only = "--stdout" in argv

    # Load configuration from repo.conf
    config = get_config()
    manifest_url = config["manifest_url"]

    devices = run_discover()
    if discover_only:
        print(json.dumps(devices, indent=2))
        return

    # Manifest is GPG-verified inside fetch_manifest (None on any failure)
    manifest = fetch_manifest(config)
    cache = build_cache(devices, manifest, manifest_url)

    if stdout_only:
        print(json.dumps(cache, indent=2))
        return

    try:
        write_cache(cache)
    except RuntimeError as err:
        print(err, file=sys.stderr)
        log_check(f"FAILED  error={err}")
        sys.exit(1)

    count = cache["device_count"]
    updates = cache["updates_available"]
    manifest_flag = "yes" if manifest else "no"
    log_check(f"OK  devices={count}  updates={updates}  manifest={manifest_flag}  url={manifest_url}")
    print(f"Cache written: {CACHE_FILE}")
    print(f"Devices found: {count}")
    print(f"Updates available: {updates}")


if __name__ == "__main__":
    main()
