#!/usr/libexec/platform-python
"""
45d-firmware-flash

Downloads and flashes firmware for a specific device.
Called from Cockpit UI with device info from the firmware cache.

Usage:
    firmware-flash --type hba --model "9400-16i" --device-path "0000:01:00.0"
    firmware-flash --type nic --model "ConnectX-6" --device-path "0000:41:00.0"

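Output: a single JSON object on stdout with keys success, reboot_needed, and
either output (on success) or error (on failure). Progress and diagnostic
messages are written to stderr.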

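Exit codes:
    0 - flash successful
    1 - flash failed (also used when not running as root or when the flash
        tool is missing)
    2 - invalid arguments, device not in cache, device not flashable, or no
        firmware file for a device type that needs one
    3 - tool not found (reserved; a missing tool is currently reported with 1)
    4 - firmware download failed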
"""

import argparse
import hashlib
import json
import os
import re
import shutil
import subprocess
import sys
import tempfile
import urllib.error
import urllib.request

###############################################################################
# Configuration
###############################################################################

REPO_CONF = "/usr/share/45drives/firmware/repo.conf"

# Downloaded firmware lands in the packaged on-disk directory so it persists
# between runs and lines up with the firmware-check tooling and docs.
FIRMWARE_DOWNLOAD_DIR = "/usr/share/45drives/firmware/files"

# Device caches, consulted in this order by get_device_from_cache():
#   1. written by the Cockpit firmware-check
CACHE_FILE = "/var/cache/45drives/firmware.json"
#   2. fallback written by the installed system firmware-check
CACHE_FILE_SYSTEM = "/var/cache/45drives/firmware/status.json"

# Candidate install locations for each flash tool, searched in list order.
TOOL_PATHS = {
    "storcli2": [
        "/opt/45drives/tools/storcli2",
        "/usr/local/bin/storcli2",
        "/opt/MegaRAID/storcli/storcli2",
    ],
    "storcli64": [
        "/opt/45drives/tools/storcli64",
        "/opt/45drives/bin/storcli64",
        "/opt/MegaRAID/storcli/storcli64",
    ],
    "mlxup": [
        "/opt/45drives/tools/mlxup",
        "/root/nic-firmware-tools/mlxup",
        "/usr/local/bin/mlxup",
    ],
    "niccli": [
        "/opt/45drives/niccli/niccli",
        "/usr/local/bin/niccli",
    ],
    "sas3flash": [
        "/opt/45drives/bin/sas3flash",
        "/usr/local/bin/sas3flash",
    ],
    "nvme-cli": [
        "/usr/sbin/nvme",
        "/usr/bin/nvme",
    ],
}

# Directory that tool binaries fetched from the firmware repo are written to.
TOOL_INSTALL_DIR = "/opt/45drives/tools"

###############################################################################
# Helpers
###############################################################################

def compute_sha256(filepath):
    """Return the lowercase hex SHA-256 digest of the file at *filepath*.

    The file is consumed in 64 KiB pieces so large images are never held
    in memory all at once.
    """
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        while True:
            piece = stream.read(65536)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()


def get_repo_url():
    """Read the firmware repo base URL from repo.conf.

    The first ``REPO_URL=`` line wins; blank lines and ``#`` comments are
    skipped. Shell-style surrounding quotes (``REPO_URL="https://..."``) and
    trailing slashes are removed so callers can append ``/<path>`` directly.

    Returns:
        The URL string, or None when repo.conf is missing or unreadable, or
        when the first REPO_URL entry is empty.
    """
    try:
        with open(REPO_CONF, "r") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                if line.startswith("REPO_URL="):
                    value = line.split("=", 1)[1].strip()
                    # Accept shell-style quoting, which the original parser
                    # passed through verbatim and produced an invalid URL.
                    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                        value = value[1:-1].strip()
                    # Avoid "//" when callers build f"{repo_url}/{path}".
                    value = value.rstrip("/")
                    return value or None
    except FileNotFoundError:
        return None
    except OSError as e:
        # e.g. PermissionError: report it instead of crashing without JSON output.
        print(f"Cannot read {REPO_CONF}: {e}", file=sys.stderr)
        return None
    return None


def find_tool(tool_name, expected_sha256=""):
    """Locate an executable flash tool, downloading a verified copy if needed.

    Search order:
      1. The known locations in TOOL_PATHS[tool_name], then
         TOOL_INSTALL_DIR/<tool_name> (where download_tool() installs to),
         so a previously downloaded copy is found again on later runs.
      2. The executable on $PATH (``nvme`` for the ``nvme-cli`` key).
      3. install_nvme_cli() for nvme-cli, otherwise download_tool().

    With expected_sha256 set, a candidate is returned only if its SHA-256
    matches. A mismatching binary at a known location is removed (best
    effort) and the search moves on to the next location. A mismatching
    binary found on $PATH is left in place and skipped.

    Args:
        tool_name: Tool key, e.g. "storcli2", "mlxup", "nvme-cli".
        expected_sha256: Hex SHA-256 the binary must have; "" skips checks.

    Returns:
        Path to a usable executable, or None if none could be found/obtained.
    """
    if not tool_name:
        # A cache entry with "flash_tool": null would otherwise crash below.
        return None

    candidates = list(TOOL_PATHS.get(tool_name, []))
    if os.path.basename(tool_name) == tool_name:
        installed = os.path.join(TOOL_INSTALL_DIR, tool_name)
        if installed not in candidates:
            candidates.append(installed)

    for path in candidates:
        if not (os.path.isfile(path) and os.access(path, os.X_OK)):
            continue
        if not expected_sha256:
            return path
        actual = compute_sha256(path)
        if actual == expected_sha256:
            return path
        print(f"WARNING: {path} SHA256 mismatch (expected {expected_sha256[:12]}..., got {actual[:12]}...). Removing it.", file=sys.stderr)
        try:
            os.unlink(path)
        except OSError as e:
            print(f"WARNING: could not remove {path}: {e}", file=sys.stderr)
        # Keep scanning: a later location may hold a verified copy.

    # shutil.which avoids depending on the external `which` binary.
    path_name = "nvme" if tool_name == "nvme-cli" else tool_name
    found_path = shutil.which(path_name)
    if found_path:
        if not expected_sha256 or compute_sha256(found_path) == expected_sha256:
            return found_path
        print(f"WARNING: {found_path} SHA256 mismatch. Will download verified copy.", file=sys.stderr)

    # nvme-cli is only ever obtained through the OS package manager.
    if tool_name == "nvme-cli":
        return install_nvme_cli()

    return download_tool(tool_name, expected_sha256)


def download_tool(tool_name, expected_sha256=""):
    """Download a tool binary from the firmware repo with SHA256 verification.

    Fetches ``<repo>/tools/<tool_name>`` into a temporary file inside
    TOOL_INSTALL_DIR. Only after its SHA-256 matches is it made executable
    (0755) and atomically renamed to TOOL_INSTALL_DIR/<tool_name>. An
    unverified or partially downloaded binary therefore never sits at the
    path that gets executed as root.

    Args:
        tool_name: Bare file name of the tool (no path separators).
        expected_sha256: Required hex SHA-256; downloads are refused without it.

    Returns:
        The installed path, or None on any failure (reason printed to stderr).
    """
    repo_url = get_repo_url()
    if not repo_url:
        print(f"No firmware repo configured (repo.conf missing) — cannot download {tool_name}", file=sys.stderr)
        return None

    # Refuse to download tool binaries without a SHA256 hash — they run as root
    if not expected_sha256:
        print(f"REFUSED: No SHA256 hash available for tool '{tool_name}'. "
              f"Cannot safely download and execute unverified binaries.", file=sys.stderr)
        return None

    # Sanitize tool_name to prevent path traversal
    safe_name = os.path.basename(tool_name)
    if safe_name != tool_name or safe_name in ("", ".", ".."):
        print(f"REFUSED: Invalid tool name '{tool_name}'", file=sys.stderr)
        return None

    url = f"{repo_url}/tools/{safe_name}"
    dest = os.path.join(TOOL_INSTALL_DIR, safe_name)
    tmp_path = None

    try:
        os.makedirs(TOOL_INSTALL_DIR, exist_ok=True)
        print(f"Downloading {tool_name} from {url}...", file=sys.stderr)
        req = urllib.request.Request(url, headers={"User-Agent": "45d-firmware-flash/1.0"})
        sha256_hash = hashlib.sha256()
        with urllib.request.urlopen(req, timeout=60) as resp:
            # Same directory as dest so os.replace() is an atomic rename.
            fd, tmp_path = tempfile.mkstemp(prefix=f".{safe_name}.", suffix=".part", dir=TOOL_INSTALL_DIR)
            with os.fdopen(fd, "wb") as f:
                for chunk in iter(lambda: resp.read(65536), b""):
                    f.write(chunk)
                    sha256_hash.update(chunk)

        actual_sha256 = sha256_hash.hexdigest()
        if actual_sha256 != expected_sha256:
            print(f"INTEGRITY FAILURE: {tool_name} SHA256 mismatch!", file=sys.stderr)
            print(f"  Expected: {expected_sha256}", file=sys.stderr)
            print(f"  Got:      {actual_sha256}", file=sys.stderr)
            return None  # temp file removed in finally

        os.chmod(tmp_path, 0o755)
        os.replace(tmp_path, dest)
        tmp_path = None  # now owned by dest; nothing to clean up
        print(f"Installed {tool_name} to {dest} (SHA256 verified)", file=sys.stderr)
        return dest
    except Exception as e:
        print(f"Failed to download {tool_name}: {e}", file=sys.stderr)
        return None
    finally:
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def install_nvme_cli():
    """Report a missing nvme-cli; packages are never installed from here.

    Always returns None so the caller treats the tool as unavailable.
    """
    message = "nvme-cli is not installed. Please install the 'nvme-cli' package and retry."
    print(message, file=sys.stderr)
    return None


def download_firmware(firmware_file, device_type, model, expected_sha256=""):
    """Fetch a firmware image from the repo, verifying its SHA-256.

    The image lives at FIRMWARE_DOWNLOAD_DIR/<firmware_file>. A copy that is
    already there is reused if it matches expected_sha256, or unconditionally
    when no hash is given. A mismatching copy is removed and fetched again.
    Downloads are refused without a hash. New downloads are streamed to a
    temporary file in the target directory and renamed into place only
    after the hash matches, so a partial or corrupt image never appears at
    the final path.

    Args:
        firmware_file: Repo-relative path from the manifest, e.g.
            "hba/9400-16i/HBA_9400-16i_SAS_SATA_Profile.bin".
        device_type: Not used here; kept for call compatibility.
        model: Not used here; kept for call compatibility.
        expected_sha256: Hex SHA-256 the image must have.

    Returns:
        Local path of the verified image, or None on any failure, including
        an unwritable download directory. The caller can then fall back to a
        pre-staged copy.
    """
    rel_path = os.path.normpath(firmware_file).lstrip(os.sep)
    # Path traversal guard: nothing may resolve outside FIRMWARE_DOWNLOAD_DIR,
    # and degenerate paths ("", ".") name no file at all.
    if rel_path in ("", ".", "..") or rel_path.startswith(".." + os.sep):
        print(f"REFUSED: Invalid firmware path '{firmware_file}'", file=sys.stderr)
        return None
    local_path = os.path.join(FIRMWARE_DOWNLOAD_DIR, rel_path)
    local_dir = os.path.dirname(local_path)

    # If already downloaded, verify integrity if hash is available
    if os.path.isfile(local_path):
        if not expected_sha256:
            return local_path
        if compute_sha256(local_path) == expected_sha256:
            return local_path
        print("Cached firmware SHA256 mismatch, re-downloading...", file=sys.stderr)
        try:
            os.unlink(local_path)
        except OSError as e:
            print(f"WARNING: could not remove {local_path}: {e}", file=sys.stderr)

    # Refuse to download firmware without a SHA256 hash
    if not expected_sha256:
        print(f"REFUSED: No SHA256 hash available for firmware '{firmware_file}'. "
              f"Cannot safely download unverified firmware.", file=sys.stderr)
        return None

    repo_url = get_repo_url()
    if not repo_url:
        print("No firmware repo configured (repo.conf missing) — cannot download firmware", file=sys.stderr)
        return None

    # The manifest's firmware_file is the path relative to the repo root.
    url = f"{repo_url}/{rel_path}"
    tmp_path = None

    try:
        # Inside the try: a read-only or unwritable download dir must yield
        # None (so main() can use a pre-staged file), not a traceback.
        os.makedirs(local_dir, exist_ok=True)
        print(f"Downloading firmware from {url}...", file=sys.stderr)
        req = urllib.request.Request(url, headers={"User-Agent": "45d-firmware-flash/1.0"})
        sha256_hash = hashlib.sha256()
        with urllib.request.urlopen(req, timeout=120) as resp:
            fd, tmp_path = tempfile.mkstemp(prefix=".", suffix=".part", dir=local_dir)
            with os.fdopen(fd, "wb") as f:
                for chunk in iter(lambda: resp.read(65536), b""):
                    f.write(chunk)
                    sha256_hash.update(chunk)

        actual_sha256 = sha256_hash.hexdigest()
        if actual_sha256 != expected_sha256:
            print("INTEGRITY FAILURE: firmware SHA256 mismatch!", file=sys.stderr)
            print(f"  Expected: {expected_sha256}", file=sys.stderr)
            print(f"  Got:      {actual_sha256}", file=sys.stderr)
            return None  # temp file removed in finally

        # mkstemp creates 0600; use a conventional world-readable data mode.
        os.chmod(tmp_path, 0o644)
        os.replace(tmp_path, local_path)
        tmp_path = None
        print(f"Downloaded firmware to {local_path} (SHA256 verified)", file=sys.stderr)
        return local_path
    except Exception as e:
        print(f"  Failed: {e}", file=sys.stderr)
        return None
    finally:
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def get_device_from_cache(device_type, device_path):
    """Return the firmware-cache entry for a device, or None if absent.

    CACHE_FILE (Cockpit) is searched first, then CACHE_FILE_SYSTEM. An
    entry matches when its "type" equals *device_type* and its address
    equals *device_path*. The address is taken from "device_path", else
    "bus_info", else "device".

    A cache file that is missing, unreadable, not valid JSON, or not shaped
    like {"devices": [ {...}, ... ]} is skipped, so the lookup falls through
    to the next cache instead of crashing.
    """
    for cache_path in (CACHE_FILE, CACHE_FILE_SYSTEM):
        try:
            with open(cache_path) as f:
                cache = json.load(f)
        except (OSError, ValueError):
            # OSError covers missing/permission errors; ValueError covers
            # JSONDecodeError and undecodable bytes.
            continue
        if not isinstance(cache, dict):
            continue
        devices = cache.get("devices")
        if not isinstance(devices, list):
            continue
        for dev in devices:
            if not isinstance(dev, dict):
                continue
            dev_path = dev.get("device_path") or dev.get("bus_info") or dev.get("device")
            if dev.get("type") == device_type and dev_path == device_path:
                return dev
    return None


def is_hba_noop_flash(output):
    """Tell whether storcli output states that no firmware was written.

    The phrases are matched case-insensitively anywhere in *output*. Only an
    explicit "already current / nothing to do" style message counts.
    """
    phrases = (
        r"\balready\s+up[-\s]?to[-\s]?date\b",
        r"\balready\s+current\b",
        r"\bno\s+(?:firmware\s+)?(?:update|flash)\s+(?:is\s+)?(?:required|needed|necessary|performed)\b",
        r"\bnot\s+flashed\b",
        r"\bno\s+changes?\s+(?:made|were\s+made|required)\b",
    )
    for phrase in phrases:
        if re.search(phrase, output, re.IGNORECASE):
            return True
    return False

###############################################################################
# Flash handlers
###############################################################################

def flash_hba(device, firmware_path, tool_sha256=""):
    """Flash HBA firmware using storcli2 or storcli64.

    The device's "flash_tool" (default "storcli2") is located with the
    manifest's tool hash and tried first. The other storcli variant is
    located without a hash and tried as a fallback. For each tool the target
    controller is resolved from the device's PCI address:

    * the controller whose bus number matches is flashed;
    * if the tool lists exactly one controller, that one is used;
    * if it lists several and none matches, the tool is skipped rather than
      guessing, because flashing the wrong HBA is worse than failing;
    * if it cannot be queried or lists none, ``/c0`` is attempted (a tool that
      cannot see the controller reports "not found" and the next tool is tried).

    A timeout ends the whole operation. The killed tool may have been
    mid-write, so a second flash with another tool is not started.

    Args:
        device: Cache entry; reads "flash_tool", "device_path"/"bus_info",
            "requires_reboot".
        firmware_path: Local path of the verified firmware image.
        tool_sha256: Expected SHA-256 of the primary flash tool ("" = unchecked).

    Returns:
        dict with "success", "reboot_needed" and "output" (on success) or
        "error" (on failure).
    """
    # `or`, not a .get default: the cache may hold "flash_tool": null.
    flash_tool = device.get("flash_tool") or "storcli2"
    tool_path = find_tool(flash_tool, tool_sha256)

    # Build list of tools to try (primary + fallback)
    tools_to_try = []
    if tool_path:
        tools_to_try.append(tool_path)
    fallback = "storcli64" if flash_tool == "storcli2" else "storcli2"
    fallback_path = find_tool(fallback)
    if fallback_path and fallback_path not in tools_to_try:
        tools_to_try.append(fallback_path)

    if not tools_to_try:
        return {"success": False, "error": f"Flash tool '{flash_tool}' not found on system (tried both storcli2 and storcli64)", "reboot_needed": False}

    device_path = device.get("device_path") or device.get("bus_info") or ""

    last_error = ""
    for try_tool in tools_to_try:
        tool_label = os.path.basename(try_tool)
        print(f"Trying flash with {try_tool}...", file=sys.stderr)

        controller_idx, ambiguous = _resolve_controller(try_tool, device_path)
        if ambiguous:
            last_error = (f"{tool_label} reports multiple controllers and none matches "
                          f"PCI address '{device_path}'; refusing to guess which one to flash")
            print(last_error, file=sys.stderr)
            continue
        if controller_idx is None:
            print(f"Could not map PCI address to controller with {try_tool}, trying /c0...", file=sys.stderr)
            controller_idx = 0

        cmd = [try_tool, f"/c{controller_idx}", "download", f"file={firmware_path}"]
        try:
            result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=300)
        except subprocess.TimeoutExpired:
            return {"success": False, "error": "Flash command timed out (300s)", "reboot_needed": False}
        except Exception as e:
            last_error = str(e)
            continue

        output = result.stdout + result.stderr

        # "not found" means this tool doesn't see the controller — try the next one.
        if "not found" in output.lower() and fallback_path and try_tool != tools_to_try[-1]:
            print(f"{tool_label} can't see controller, trying next tool...", file=sys.stderr)
            last_error = output
            continue

        # A no-op can still return 0. It should not inherit the manifest's
        # requires_reboot flag because no firmware was actually written.
        if is_hba_noop_flash(output):
            return {"success": True, "output": "Firmware already up to date.", "reboot_needed": False}

        # Whole-word match so "unsuccessfully" / "not successfully" are not
        # mistaken for success (the old substring test accepted them).
        if result.returncode == 0 or re.search(r"(?<!not )\bsuccessfully\b", output, re.IGNORECASE):
            reboot = bool(re.search(r"reboot|power cycle|activation.*pending|reset required", output, re.IGNORECASE))
            # Also honour the manifest's requires_reboot flag.
            reboot = reboot or bool(device.get("requires_reboot", False))
            return {"success": True, "output": output, "reboot_needed": reboot}
        last_error = output

    return {"success": False, "error": last_error, "reboot_needed": False}


def get_controller_index(tool_path, pci_address):
    """Map a PCI address to a storcli controller index.

    Returns the matching controller, or the sole controller when the tool
    lists exactly one. Returns None when the tool cannot be queried, lists no
    controllers, or lists several of which none matches (ambiguous).
    """
    return _resolve_controller(tool_path, pci_address)[0]


def _parse_pci_bus(pci_address):
    """Return the bus number of a "DDDD:BB:DD.F" or "BB:DD.F" address, or None."""
    match = re.search(r"(?:[0-9a-fA-F]+:)?([0-9a-fA-F]{1,2}):[0-9a-fA-F]{1,2}(?:\.[0-7])?", pci_address or "")
    return int(match.group(1), 16) if match else None


def _resolve_controller(tool_path, pci_address):
    """Resolve *pci_address* to a controller using ``<tool> /call show``.

    Returns:
        (index, ambiguous). ``ambiguous`` is True only when the tool lists
        several controllers and none has a matching bus number. ``index`` is
        None when unknown.
    """
    try:
        result = subprocess.run([tool_path, "/call", "show"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=30)
    except Exception as e:
        print(f"Error detecting controller index: {e}", file=sys.stderr)
        return None, False
    if result.returncode != 0 or "No Controller found" in result.stdout:
        # e.g. storcli2 may not see controllers handled by storcli64.
        return None, False

    target_bus = _parse_pci_bus(pci_address)
    controllers = []
    current_controller = None
    for line in result.stdout.splitlines():
        ctrl_match = re.search(r"Controller\s*=\s*(\d+)", line)
        if ctrl_match:
            current_controller = int(ctrl_match.group(1))
            if current_controller not in controllers:
                controllers.append(current_controller)
        # NOTE(review): "Bus Number" is parsed as decimal and compared with
        # the hex bus from the PCI address (unchanged assumption) — confirm
        # against real storcli/storcli2 output.
        bus_match = re.search(r"Bus Number\s*=\s*(\d+)", line)
        if (bus_match and current_controller is not None and target_bus is not None
                and int(bus_match.group(1)) == target_bus):
            return current_controller, False

    if len(controllers) == 1:
        return controllers[0], False
    if controllers:
        return None, True
    return None, False


def flash_mellanox(device, tool_sha256=""):
    """Update a Mellanox/NVIDIA NIC with mlxup.

    No local firmware file is passed: mlxup is run as
    ``mlxup --update --yes --dev <device_path>`` and sources the image itself.

    Returns:
        dict with "success", "reboot_needed" and "output" or "error". Output
        mentioning "up to date" / "same version" / "no update" is reported as
        an already-current device that needs no reboot.
    """
    def _failure(message):
        return {"success": False, "error": message, "reboot_needed": False}

    mlxup = find_tool("mlxup", tool_sha256)
    if not mlxup:
        return _failure("mlxup not found on system")

    command = [mlxup, "--update", "--yes", "--dev", device.get("device_path", "")]
    try:
        proc = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=600)
    except subprocess.TimeoutExpired:
        return _failure("mlxup timed out (600s)")
    except Exception as exc:
        return _failure(str(exc))

    combined = proc.stdout + proc.stderr
    if proc.returncode != 0:
        return _failure(combined)

    if re.search(r"up to date|same version|no update", combined, re.IGNORECASE):
        return {"success": True, "output": "Firmware already up to date.", "reboot_needed": False}

    needs_reboot = bool(re.search(r"power cycle|reboot|restart required", combined, re.IGNORECASE))
    return {"success": True, "output": combined, "reboot_needed": needs_reboot}


def flash_broadcom_nic(device, firmware_path, tool_sha256=""):
    """Flash a Broadcom NIC with niccli.

    A device_path that looks like a full PCI BDF (DDDD:BB:DD.F) is passed
    with ``--pci``; anything else (e.g. an index or MAC) with ``--dev``.

    Output reporting "already up to date" counts as success with no reboot,
    whatever the exit status. This matches the HBA and mlxup handlers, and
    previously a zero exit here still claimed a reboot was needed. Any other
    zero exit is a real flash and needs a reboot.

    Returns:
        dict with "success", "reboot_needed" and "output" or "error".
    """
    tool_path = find_tool("niccli", tool_sha256)
    if not tool_path:
        return {"success": False, "error": "niccli not found on system", "reboot_needed": False}

    device_path = device.get("device_path") or ""
    # Use --pci for PCI BDF addresses, --dev for index/MAC
    if re.match(r"[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.\d", device_path):
        selector = "--pci"
    else:
        selector = "--dev"
    cmd = [tool_path, selector, device_path, "fw", "--update", "-f", firmware_path]

    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=300)
    except subprocess.TimeoutExpired:
        return {"success": False, "error": "niccli timed out (300s)", "reboot_needed": False}
    except Exception as e:
        return {"success": False, "error": str(e), "reboot_needed": False}

    output = result.stdout + result.stderr
    if re.search(r"already\s+up[-\s]?to[-\s]?date", output, re.IGNORECASE):
        return {"success": True, "output": output, "reboot_needed": False}
    if result.returncode == 0:
        return {"success": True, "output": output, "reboot_needed": True}
    return {"success": False, "error": output, "reboot_needed": False}


def flash_nvme(device, firmware_path):
    """Flash an NVMe drive with nvme-cli.

    Runs ``nvme fw-download <dev> --fw=<file>`` and then
    ``nvme fw-activate <dev> --slot=1 --action=1``. A successful activation is
    reported with reboot_needed=True.

    NOTE(review): commit action 1 is understood to stage the image and
    activate it at the next controller reset. Some drives mark slot 1
    read-only, in which case --slot=0 (controller picks a slot) may be
    needed. Confirm against the nvme-cli fw-commit documentation.

    Returns:
        dict with "success", "reboot_needed" and "output" or "error".
    """
    tool_path = find_tool("nvme-cli")
    if not tool_path:
        return {"success": False, "error": "nvme-cli not found on system", "reboot_needed": False}

    device_path = device.get("device_path") or ""

    cmd_download = [tool_path, "fw-download", device_path, f"--fw={firmware_path}"]
    cmd_activate = [tool_path, "fw-activate", device_path, "--slot=1", "--action=1"]

    try:
        result = subprocess.run(cmd_download, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=120)
        if result.returncode != 0:
            return {"success": False, "error": f"fw-download failed: {result.stdout + result.stderr}", "reboot_needed": False}

        result = subprocess.run(cmd_activate, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=60)
        output = result.stdout + result.stderr

        # Whole-word match: the old substring test treated "unsuccessfully"
        # as success.
        if result.returncode == 0 or re.search(r"(?<!not )\bsuccessfully\b", output, re.IGNORECASE):
            return {"success": True, "output": output, "reboot_needed": True}
        return {"success": False, "error": f"fw-activate failed: {output}", "reboot_needed": False}
    except subprocess.TimeoutExpired:
        return {"success": False, "error": "NVMe flash timed out", "reboot_needed": False}
    except Exception as e:
        return {"success": False, "error": str(e), "reboot_needed": False}

###############################################################################
# Main
###############################################################################

def _emit(result, exit_code):
    """Print *result* as the single JSON line on stdout and exit with *exit_code*."""
    print(json.dumps(result))
    sys.exit(exit_code)


def _failure(message):
    """Build the standard failure payload."""
    return {"success": False, "error": message, "reboot_needed": False}


def _find_prestaged_firmware(firmware_file, device_type, expected_sha256):
    """Return a pre-staged copy under /opt/45drives/firmware, or None.

    Candidates are tried in order. When expected_sha256 is given, a
    candidate whose hash differs is skipped with a warning.
    """
    candidates = (
        f"/opt/45drives/firmware/{firmware_file}",
        f"/opt/45drives/firmware/{device_type}/{firmware_file}",
    )
    for candidate in candidates:
        if not os.path.isfile(candidate):
            continue
        if expected_sha256 and compute_sha256(candidate) != expected_sha256:
            print(f"WARNING: Pre-staged {candidate} SHA256 mismatch, skipping.", file=sys.stderr)
            continue
        return candidate
    return None


def _flash_device(args):
    """Look up the device, obtain its firmware, and run the matching flash handler.

    Returns the handler's result dict. Setup errors exit directly through
    _emit(): code 2 for a missing/non-flashable device or a missing firmware
    file, and code 4 for a firmware file that could not be obtained.
    """
    device = get_device_from_cache(args.type, args.device_path)
    if not device:
        _emit(_failure(f"Device not found in cache: {args.type} at {args.device_path}"), 2)

    # Ensure device_path is set (cache entry may have None if discovered without PCI info)
    if not device.get("device_path"):
        device["device_path"] = args.device_path

    if not device.get("flashable", False):
        _emit(_failure("Device is not marked as flashable in manifest"), 2)

    # `or` defaults: cache fields may be present but null, which previously
    # crashed below (e.g. None.lower(), "P210p" in None).
    flash_tool = device.get("flash_tool") or ""
    firmware_file = device.get("firmware_file") or ""
    model = device.get("model") or args.model
    firmware_sha256 = device.get("sha256") or ""
    tool_sha256 = device.get("tool_binary_sha256") or ""

    # Mellanox NICs: mlxup sources the image itself, no local firmware file.
    if flash_tool == "mlxup":
        return flash_mellanox(device, tool_sha256)

    firmware_path = None
    if firmware_file:
        firmware_path = (download_firmware(firmware_file, args.type, model, firmware_sha256)
                         or _find_prestaged_firmware(firmware_file, args.type, firmware_sha256))
        if not firmware_path:
            _emit(_failure(f"Could not download firmware file: {firmware_file}"), 4)

    if args.type == "hba":
        if not firmware_path:
            _emit(_failure("No firmware file specified for HBA"), 2)
        return flash_hba(device, firmware_path, tool_sha256)

    if args.type == "nic":
        vendor = (device.get("vendor") or "").lower()
        is_broadcom = "broadcom" in vendor or "P210p" in model or "BCM" in model.upper()
        if not is_broadcom:
            return _failure(f"Unsupported NIC type: {model}")
        if not firmware_path:
            _emit(_failure("No firmware file specified for Broadcom NIC"), 2)
        return flash_broadcom_nic(device, firmware_path, tool_sha256)

    if args.type == "nvme":
        if not firmware_path:
            _emit(_failure("No firmware file specified for NVMe"), 2)
        return flash_nvme(device, firmware_path)

    return _failure(f"Unsupported device type for flashing: {args.type}")


def main():
    """CLI entry point: flash one device and report a JSON result on stdout.

    Exit codes: 0 on success; 1 when the flash fails, when not running as
    root, or on an unexpected error; 2 when the device is not in the cache,
    is not flashable, or lacks a required firmware file; 4 when the firmware
    file cannot be obtained.
    """
    parser = argparse.ArgumentParser(description="Flash firmware for a specific device")
    parser.add_argument("--type", required=True, help="Device type (hba, nic, nvme)")
    parser.add_argument("--model", required=True, help="Device model string")
    parser.add_argument("--device-path", required=True, help="PCI bus address or device path")
    args = parser.parse_args()

    if os.geteuid() != 0:
        _emit(_failure("Must be run as root"), 1)

    try:
        result = _flash_device(args)
    except Exception as e:
        # Keep stdout machine-readable for the Cockpit UI even on bugs or bad
        # cache data; SystemExit from _emit() is not an Exception and passes through.
        print(f"Unexpected error: {e!r}", file=sys.stderr)
        result = _failure(f"Unexpected error: {e}")

    _emit(result, 0 if result["success"] else 1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
