#!/usr/libexec/platform-python
"""
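firmware-check - Compares discovered firmware against the firmware manifest
(the remote manifest when one is configured and can be fetched, otherwise the
local copy) and outputs a status for each device. Writes results to the cache
file.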

Used by: systemd timer (45d-firmware-check.timer)
         SectionFirmware.vue (on-demand refresh)
         App.vue (badge notification)

Exit codes:
    0 = no outdated firmware found (some devices may still be "unknown")
    1 = error (no usable manifest, or device discovery failed)
    2 = one or more devices have outdated firmware
"""

import json
import os
import re
import subprocess
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen


# Install locations. For development, each FIRMWARE_* environment variable
# below replaces the packaged default when it is set to a non-empty value.
MANIFEST_PATH = (os.environ.get("FIRMWARE_MANIFEST_PATH")
                 or "/usr/share/45drives/firmware/manifest.json")
CACHE_PATH = (os.environ.get("FIRMWARE_CACHE_PATH")
              or "/var/cache/45drives/firmware/status.json")
DISCOVER_SCRIPT = (os.environ.get("FIRMWARE_DISCOVER_PATH")
                   or "/usr/share/45drives/firmware/firmware-discover")
FIRMWARE_DIR = (os.environ.get("FIRMWARE_DIR")
                or "/usr/share/45drives/firmware/files")

# These two are not overridden here; get_repo_config() and
# fetch_remote_manifest() consult FIRMWARE_REPO_CONF / FIRMWARE_GPG_PUBKEY
# themselves at call time.
REPO_CONF = "/usr/share/45drives/firmware/repo.conf"
GPG_PUBKEY = "/usr/share/45drives/firmware/45drives-firmware.gpg"


def get_repo_config():
    """Read repo.conf and return its KEY=VALUE pairs as a dict.

    The file read is REPO_CONF, unless the FIRMWARE_REPO_CONF environment
    variable is set to a non-empty path. Blank lines, lines starting with
    ``#`` and lines without ``=`` are skipped. The value is everything after
    the first ``=``, and both key and value are whitespace-stripped.

    Returns:
        dict[str, str]: the parsed settings. The dict is empty when the file
        is missing. It is also empty when the file cannot be read or decoded;
        in that case a warning is printed to stderr. Callers therefore always
        get a dict and can fall back to local-only behaviour.
    """
    # The env override short-circuits, so REPO_CONF is only consulted when
    # no override is set.
    conf_path = os.environ.get("FIRMWARE_REPO_CONF") or REPO_CONF
    config = {}
    try:
        # Explicit UTF-8: the locale encoding can be ASCII (C/POSIX locale),
        # which would fail on any non-ASCII character in comments.
        with open(conf_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, val = line.split("=", 1)
                config[key.strip()] = val.strip()
    except FileNotFoundError:
        # No repo.conf at all simply means "no repo configured".
        return {}
    except (OSError, UnicodeDecodeError) as e:
        # Unreadable config (permissions, a directory, bad bytes): warn and
        # behave as if no repo is configured instead of aborting the check.
        print(f"Warning: could not read repo config {conf_path}: {e}", file=sys.stderr)
        return {}
    return config


def has_repo_configured():
    """Return True when repo.conf provides a non-empty REPO_URL setting."""
    repo_url = get_repo_config().get("REPO_URL") or ""
    return repo_url != ""


_USER_AGENT = "45drives-firmware-check/1.0"
# Upper bounds on what we are willing to download; enforced while reading so
# an oversized or malicious response is never fully buffered in memory.
_MAX_MANIFEST_BYTES = 10 * 1024 * 1024
_MAX_SIGNATURE_BYTES = 1024 * 1024


def _fetch_capped(url, timeout, limit):
    """GET *url* and return at most ``limit + 1`` bytes of its body.

    Reading one byte past *limit* lets the caller detect an oversized
    response (``len(result) > limit``) without buffering the whole body.
    The response is always closed. urllib/socket exceptions propagate.
    """
    req = Request(url, headers={"User-Agent": _USER_AGENT})
    chunks = []
    total = 0
    with urlopen(req, timeout=timeout) as response:
        while total <= limit:
            chunk = response.read(min(65536, limit + 1 - total))
            if not chunk:
                break
            chunks.append(chunk)
            total += len(chunk)
    return b"".join(chunks)


def fetch_remote_manifest():
    """Fetch the manifest from MANIFEST_URL (repo.conf), verifying it first.

    Unless repo.conf sets VERIFY_GPG to false/0/no, ``<MANIFEST_URL>.sig`` is
    downloaded. The manifest bytes are checked with verify_gpg_signature()
    against GPG_PUBKEY (or FIRMWARE_GPG_PUBKEY) *before* being JSON-decoded.

    Returns:
        The manifest dict (a JSON object with a dict ``components``), or None
        when no URL is configured, the download/verification fails, or the
        content is not a usable manifest. Reasons are printed to stderr; the
        caller is expected to fall back to the local manifest.
    """
    config = get_repo_config()
    manifest_url = config.get("MANIFEST_URL", "")
    if not manifest_url:
        return None

    try:
        manifest_bytes = _fetch_capped(manifest_url, 15, _MAX_MANIFEST_BYTES)

        # Enforce a size limit before any processing (10 MB)
        if len(manifest_bytes) > _MAX_MANIFEST_BYTES:
            print("Warning: manifest too large (>10MB) — rejecting", file=sys.stderr)
            return None

        # Check if GPG verification is required
        verify_gpg = config.get("VERIFY_GPG", "true").strip().lower()
        if verify_gpg not in ("false", "0", "no"):
            # Verify signature BEFORE JSON decode
            sig_url = manifest_url + ".sig"
            gpg_key_path = os.environ.get("FIRMWARE_GPG_PUBKEY") or GPG_PUBKEY

            if not os.path.isfile(gpg_key_path):
                print(f"GPG: Public key not found at {gpg_key_path} — rejecting remote manifest!", file=sys.stderr)
                print(f"GPG: The key file should ship with the cockpit-hardware package.", file=sys.stderr)
                return None

            # Fetch signature — MUST exist
            try:
                sig_bytes = _fetch_capped(sig_url, 10, _MAX_SIGNATURE_BYTES)
            except (HTTPError, URLError) as e:
                print(f"GPG: Signature not available at {sig_url} ({e}) — rejecting remote manifest!", file=sys.stderr)
                return None
            if len(sig_bytes) > _MAX_SIGNATURE_BYTES:
                print(f"GPG: Signature at {sig_url} is too large — rejecting remote manifest!", file=sys.stderr)
                return None

            # Verify signature — MUST pass before we trust the content
            if verify_gpg_signature(manifest_bytes, sig_bytes, gpg_key_path):
                print("GPG: Manifest signature verified ✓", file=sys.stderr)
            else:
                print("GPG: Manifest signature INVALID — rejecting remote manifest!", file=sys.stderr)
                return None
        else:
            print("GPG: Verification disabled via VERIFY_GPG=false (internal mode)", file=sys.stderr)

        # Only decode JSON after signature verification passes (or is disabled)
        data = json.loads(manifest_bytes.decode("utf-8"))
        # Callers use data.get(...) and data["components"].get(...), so anything
        # other than an object with an object-valued "components" is unusable.
        if not isinstance(data, dict) or not isinstance(data.get("components"), dict):
            print("Note: remote manifest has no usable 'components' section, using local.", file=sys.stderr)
            return None

        return data
    except Exception as e:
        # Best effort: any network, decode or parse failure means "use local".
        print(f"Note: Could not fetch remote manifest ({e}), using local.", file=sys.stderr)
    return None


def verify_gpg_signature(data_bytes, sig_bytes, pubkey_path):
    """Verify a detached OpenPGP signature using ``gpgv`` only (no gpg needed).

    Args:
        data_bytes: The signed content.
        sig_bytes: The detached signature over *data_bytes*.
        pubkey_path: Keyring file handed to ``gpgv --keyring``. It must be a
            binary (not ASCII-armored) OpenPGP public key.

    Returns:
        True only if gpgv exits with status 0. False on a bad signature, a
        missing gpgv binary, a timeout, or any other error; the reason is
        printed to stderr.
    """
    try:
        # gpgv treats a --keyring name that contains no slash as relative to
        # the GnuPG home directory, so pin it to the actual file path.
        keyring = os.path.abspath(pubkey_path)

        # Both temp files live in a private directory that is removed (with
        # its contents) when the block exits, even if writing or gpgv fails.
        with tempfile.TemporaryDirectory(prefix="firmware-check-gpg-") as tmpdir:
            data_path = os.path.join(tmpdir, "manifest.json")
            sig_path = os.path.join(tmpdir, "manifest.json.sig")
            with open(data_path, "wb") as f:
                f.write(data_bytes)
            with open(sig_path, "wb") as f:
                f.write(sig_bytes)

            # Output is captured as bytes, so the locale's codec can never
            # turn a good signature into a decode error.
            result = subprocess.run(
                ["gpgv", "--keyring", keyring, sig_path, data_path],
                stdout=subprocess.DEVNULL, stderr=subprocess.PIPE,
                timeout=10,
            )
    except Exception as e:
        print(f"GPG: Verification error: {e}", file=sys.stderr)
        return False

    if result.returncode != 0:
        detail = (result.stderr or b"").decode("utf-8", "replace").strip()
        if detail:
            print(f"GPG: gpgv reported: {detail}", file=sys.stderr)
        return False
    return True


def load_manifest():
    """Load the firmware manifest, preferring the remote copy.

    fetch_remote_manifest() is tried first. If it yields nothing, the local
    manifest at MANIFEST_PATH is read as UTF-8 JSON.

    Returns:
        The manifest dict, or None if no usable manifest could be loaded.
        The reason is printed to stderr.
    """
    # Try remote manifest
    remote = fetch_remote_manifest()
    if remote:
        return remote

    # Fall back to local
    try:
        # Explicit UTF-8: under a C/POSIX locale the default codec may be
        # ASCII, which would fail on non-ASCII text in the manifest.
        with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing or unreadable file.
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading manifest: {e}", file=sys.stderr)
        return None

    # main() calls manifest.get(...), so only a JSON object is usable.
    if not isinstance(manifest, dict):
        print(f"Error loading manifest: {MANIFEST_PATH} does not contain a JSON object", file=sys.stderr)
        return None
    return manifest


def run_discover():
    """Run DISCOVER_SCRIPT with the current interpreter and decode its JSON.

    The subprocess gets a 120 second timeout. A non-zero exit status, a
    launch/timeout failure, or unparsable output is reported on stderr and
    yields None; otherwise the decoded JSON value is returned.
    """
    try:
        proc = subprocess.run(
            [sys.executable, DISCOVER_SCRIPT],
            universal_newlines=True,
            timeout=120,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if proc.returncode == 0:
            return json.loads(proc.stdout)
        print(f"firmware-discover failed: {proc.stderr}", file=sys.stderr)
        return None
    except Exception as err:
        print(f"Error running firmware-discover: {err}", file=sys.stderr)
        return None


def _model_pattern_matches(pattern, candidates):
    """Return True if *pattern* matches any string in *candidates*.

    *pattern* is tried as a case-insensitive regular expression. If it is
    not a valid regex, a case-insensitive substring test is used instead.
    """
    try:
        return any(re.search(pattern, text, re.IGNORECASE) for text in candidates)
    except re.error:
        needle = pattern.lower()
        return any(needle in text.lower() for text in candidates)


def match_device_to_manifest(device, manifest):
    """
    Return the first manifest entry that matches a discovered device, or None.

    Entries come from ``manifest["components"][device["type"]]``. Each
    entry's ``model_match`` pattern is tested against the device's
    part_number, model and model_full (only the non-empty ones, in that
    order). Entries with an empty pattern are skipped.
    """
    entries = manifest.get("components", {}).get(device.get("type", ""), [])
    if not entries:
        return None

    # Candidate strings in priority order, empty values dropped.
    candidates = [
        value
        for value in (
            device.get("part_number", ""),
            device.get("model", ""),
            device.get("model_full", ""),
        )
        if value
    ]

    for entry in entries:
        pattern = entry.get("model_match", "")
        if pattern and _model_pattern_matches(pattern, candidates):
            return entry
    return None


def parse_version_tuple(version_str):
    """
    Parse a version string into a comparable tuple of integers.
    Handles formats like "16.35.8002", "9.56", "SN04", "24.22.0-0105".
    Extracts ALL numeric sequences from the string.
    Returns None if no numeric parts found.
    """
    # Extract all numeric sequences from the string
    parts = re.findall(r'\d+', version_str)
    if parts:
        return tuple(int(p) for p in parts)
    return None


def get_version_prefix(version_str):
    """
    Extract the leading alphabetic prefix from a version string, upper-cased.
    E.g. "SN04" → "SN", "tn04" → "TN", "16.35.8002" → ""
    Used to detect differing firmware tracks (e.g. Seagate SN vs TN).
    """
    m = re.match(r'^([A-Za-z]+)', version_str.strip())
    return m.group(1).upper() if m else ""


def compare_firmware(device_fw, manifest_fw):
    """
    Compare a device's firmware version with the manifest's target version.
    Returns: 'current', 'outdated', or 'unknown'

    Logic:
    - Missing or blank (empty/whitespace-only) value on either side → 'unknown'
    - Identical strings, or identical first tokens (case-insensitive) → 'current'
    - Different alphabetic track prefixes (e.g. SN04 vs TN04) → 'outdated'
    - Compatible numeric formats: device >= manifest → 'current', else
      'outdated'. The shorter tuple is zero-padded, so "9.56" equals "9.56.0".
    - If versions can't be compared numerically, a mismatch is 'outdated'
      (conservative).

    Handles multi-part version strings (e.g. Intel NICs report
    "9.56 0x80010136 1.3909.0" where the first token is the NVM version).
    """
    if not manifest_fw or not device_fw:
        return "unknown"

    # Normalize for comparison (str() tolerates numeric JSON values)
    device_fw_norm = str(device_fw).strip().upper()
    manifest_fw_norm = str(manifest_fw).strip().upper()
    # Whitespace-only values carry no version information either.
    if not device_fw_norm or not manifest_fw_norm:
        return "unknown"

    # Exact match (fast path)
    if device_fw_norm == manifest_fw_norm:
        return "current"

    # First token of multi-part strings
    # (Intel NICs: "9.56 0x80010136 1.3909.0" → first token is "9.56")
    device_tok = device_fw_norm.split()[0]
    manifest_tok = manifest_fw_norm.split()[0]
    if device_tok == manifest_tok:
        return "current"

    # Differing firmware track prefixes (e.g. SN04 vs TN04) are different
    # firmware trains and never compare as equal, even if the numbers match.
    device_prefix = get_version_prefix(device_tok)
    manifest_prefix = get_version_prefix(manifest_tok)
    if device_prefix and manifest_prefix and device_prefix != manifest_prefix:
        return "outdated"

    # Numeric comparison (device >= manifest means current)
    device_ver = parse_version_tuple(device_tok)
    manifest_ver = parse_version_tuple(manifest_tok)
    if device_ver and manifest_ver:
        # Only compare numerically when the formats look alike:
        # (16, 35, 8002) vs (16, 35, 4030) is fine, but
        # (24, 22, 0, 105) vs (3324,) is a different scheme entirely.
        same_len = len(device_ver) == len(manifest_ver)
        both_multi = len(device_ver) > 1 and len(manifest_ver) > 1
        if same_len or both_multi:
            # Zero-pad so trailing ".0" components don't make equal versions
            # compare as older (plain tuple order has (9, 56) < (9, 56, 0)).
            width = max(len(device_ver), len(manifest_ver))
            device_ver += (0,) * (width - len(device_ver))
            manifest_ver += (0,) * (width - len(manifest_ver))
            return "current" if device_ver >= manifest_ver else "outdated"

    # Fallback: string mismatch with no usable numeric comparison = outdated
    return "outdated"

def write_cache(results, path=None):
    """Write *results* as JSON to the cache file atomically.

    The JSON goes to a uniquely named temp file in the destination
    directory. That file is flushed and fsync'd, then renamed over the
    target with os.replace(). Readers therefore never see a half-written
    cache, and concurrent runs (e.g. the timer and an on-demand refresh)
    never share a temp file.

    Args:
        results: JSON-serializable data to store.
        path: Destination file. Defaults to CACHE_PATH.

    Errors creating the directory or writing the file are printed to stderr
    and otherwise ignored; any leftover temp file is removed.
    """
    target = path if path is not None else CACHE_PATH
    # dirname() is "" for a bare file name; treat that as the CWD.
    cache_dir = os.path.dirname(target) or "."
    tmp_path = None
    try:
        os.makedirs(cache_dir, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            prefix="." + os.path.basename(target) + ".", suffix=".tmp", dir=cache_dir
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        # mkstemp creates the file with mode 0600. Apply the mode a plain
        # open() would have given it (0666 minus the umask) so other readers
        # of the cache can still open it.
        umask = os.umask(0)
        os.umask(umask)
        os.chmod(tmp_path, 0o666 & ~umask)
        os.replace(tmp_path, target)
        tmp_path = None  # renamed into place; nothing left to clean up
    except OSError as e:
        print(f"Error writing cache: {e}", file=sys.stderr)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def _flash_tool_sha256(manifest, flash_tool):
    """Return the sha256 listed for *flash_tool* in manifest["flash_tools"], or ""."""
    if not flash_tool:
        return ""
    wanted = flash_tool.lower()
    for ft in manifest.get("flash_tools") or []:
        # Case-insensitive name match; tolerate entries with a null name.
        if (ft.get("name") or "").lower() == wanted:
            return ft.get("sha256", "")
    return ""


def _matched_device_result(idx, device, entry, manifest, repo_configured):
    """Build the cache record for a device that has a manifest entry.

    Args:
        idx: Position of the device in the discovery output.
        device: Discovered device dict.
        entry: Matching manifest entry.
        manifest: Whole manifest (used for the flash_tools lookup).
        repo_configured: Whether a firmware repo is configured, so missing
            payloads can be downloaded on demand.

    Returns:
        (status, record): status is compare_firmware()'s verdict.
    """
    flash_tool = entry.get("flash_tool", "")
    latest_fw = entry.get("latest_firmware", "")

    # Standard comparison: device firmware vs manifest latest_firmware
    status = compare_firmware(device.get("firmware", ""), latest_fw)

    device_result = {
        "cache_index": idx,
        "type": device.get("type", ""),
        "model": device.get("model", ""),
        "model_full": device.get("model_full", device.get("model", "")),
        "serial": device.get("serial", ""),
        "device": device.get("device", ""),
        "slot": device.get("slot", ""),
        "sg_device": device.get("sg_device", ""),
        "bus_info": device.get("bus_info", ""),
        "controller_index": device.get("controller_index", ""),
        "current_firmware": device.get("firmware", ""),
        "latest_firmware": latest_fw,
        "update_available": status,
        "flashable": entry.get("flashable", False),
        "requires_reboot": entry.get("requires_reboot", False),
        "flash_tool": flash_tool,
        "flash_command": entry.get("flash_command", ""),
        "tool_package": entry.get("tool_package", ""),
        "tool_package_sha256": entry.get("tool_package_sha256", ""),
        "firmware_file": entry.get("firmware_file", ""),
        "sha256": entry.get("sha256", ""),
        "family": entry.get("family", ""),
        "release_notes": entry.get("release_notes", ""),
        "upgrade_from": entry.get("upgrade_from", ""),
    }

    # Look up flash tool binary SHA256 from manifest.flash_tools[]
    device_result["tool_binary_sha256"] = _flash_tool_sha256(manifest, flash_tool)

    # Verify firmware payload exists on disk if device is flashable
    fw_file = device_result["firmware_file"]
    expected_sha256 = device_result["sha256"]
    if device_result["flashable"] and fw_file:
        fw_path = os.path.join(FIRMWARE_DIR, fw_file)
        if not os.path.isfile(fw_path):
            # File not local. It can only be fetched on demand when a repo is
            # configured AND a SHA256 is known (firmware-flash refuses
            # downloads without a hash).
            if repo_configured and expected_sha256:
                # Keep flashable=True; firmware-flash will download it
                device_result["payload_missing"] = False
                device_result["payload_remote"] = True
            else:
                device_result["flashable"] = False
                device_result["payload_missing"] = True
                if expected_sha256:
                    reason = f" [Firmware file not deployed: {fw_file}]"
                else:
                    reason = f" [No SHA256 hash — cannot safely download: {fw_file}]"
                # release_notes may be null in the manifest JSON.
                notes = device_result["release_notes"] or ""
                device_result["release_notes"] = (notes + reason).strip()

    return status, device_result


def _unmatched_device_result(idx, device):
    """Build the cache record (status 'unknown') for a device with no manifest entry.

    It carries the same keys as a matched record, so consumers of the cache
    see one consistent schema.
    """
    return {
        "cache_index": idx,
        "type": device.get("type", ""),
        "model": device.get("model", ""),
        "model_full": device.get("model_full", device.get("model", "")),
        "serial": device.get("serial", ""),
        "device": device.get("device", ""),
        "slot": device.get("slot", ""),
        "sg_device": device.get("sg_device", ""),
        "bus_info": device.get("bus_info", ""),
        "controller_index": device.get("controller_index", ""),
        "current_firmware": device.get("firmware", ""),
        "latest_firmware": "",
        "update_available": "unknown",
        "flashable": False,
        "requires_reboot": False,
        "flash_tool": "",
        "flash_command": "",
        "tool_package": "",
        "tool_package_sha256": "",
        "firmware_file": "",
        "sha256": "",
        "family": device.get("family", ""),
        "release_notes": "No manifest entry found for this device",
        "upgrade_from": "",
        "tool_binary_sha256": "",
    }


def main():
    """Run one firmware check, write the results cache and print a summary.

    Returns:
        1 if no usable manifest or discovery output is available, otherwise
        2 if any device is outdated and 0 if none is.
    """
    # Load manifest
    manifest = load_manifest()
    if not manifest:
        return 1

    # Run discovery
    discovered = run_discover()
    if not discovered:
        return 1

    results = {
        # Naive-UTC ISO timestamp with a literal "Z"; same format as the
        # deprecated datetime.utcnow().isoformat() + "Z".
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "hostname": discovered.get("hostname", ""),
        "manifest_version": manifest.get("schema_version", ""),
        "summary": {
            "total": 0,
            "current": 0,
            "outdated": 0,
            "unknown": 0,
        },
        "devices": [],
    }

    # Read repo.conf once per run instead of once per device.
    repo_configured = has_repo_configured()

    # Compare each device
    for idx, device in enumerate(discovered.get("devices", [])):
        entry = match_device_to_manifest(device, manifest)
        if entry:
            status, device_result = _matched_device_result(
                idx, device, entry, manifest, repo_configured
            )
        else:
            status, device_result = "unknown", _unmatched_device_result(idx, device)

        results["devices"].append(device_result)
        results["summary"]["total"] += 1
        results["summary"][status] += 1

    # Write cache
    write_cache(results)

    # Print summary
    s = results["summary"]
    print(f"Firmware check complete: {s['total']} devices, "
          f"{s['current']} current, {s['outdated']} outdated, {s['unknown']} unknown")

    return 2 if s["outdated"] else 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
