#!/usr/libexec/platform-python
"""
firmware-flash - Flashes firmware to a specific device.

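Usage:
    firmware-flash --cache-index 3   (flash device at index 3 from cache)
    firmware-flash --cache-index 3 --dry-run   (show what would be done)
    firmware-flash --cache-index 3 --allow-download   (pre-approve fetching missing firmware/tools from the repo)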

Exit codes:
    0 = success
    1 = error (bad args, file not found, etc.)
    2 = flash failed (tool reported failure)
    3 = device not found or not flashable
"""

import json
import subprocess
import sys
import os
import shutil
import shlex
import argparse
from pathlib import Path
from datetime import datetime, timezone
from urllib.request import urlopen, Request
from urllib.error import URLError, HTTPError
from urllib.parse import quote
import hashlib


# Default filesystem locations. CACHE_PATH, FIRMWARE_DIR and SEACHEST_PATH may
# be replaced by environment variables in the "Environment overrides" section
# further down this file.
CACHE_PATH = "/var/cache/45drives/firmware/status.json"  # JSON status cache read by load_cache()
FIRMWARE_DIR = "/usr/share/45drives/firmware/files"  # NOTE(review): presumably local firmware images; not used in the visible code
SEACHEST_PATH = "/opt/45drives/tools/SeaChest_Firmware"  # preferred SeaChest binary and its download destination
REPO_CONF = "/usr/share/45drives/firmware/repo.conf"  # provides REPO_URL=...; read by get_repo_url()
LOG_DIR = "/var/log/45drives"
LOG_FILE = os.path.join(LOG_DIR, "firmware-flash.log")  # appended to by log_flash_event()

# Global flag: set via --allow-download CLI arg (UI passes this after user confirms)
# When True, confirm_download() approves repo downloads without prompting.
ALLOW_DOWNLOAD = False


def log_flash_event(device_info, success, message=""):
    """Append one structured line to LOG_FILE describing a flash attempt.

    Args:
        device_info: dict describing the device. The keys ``model``,
            ``device`` (falling back to ``sg_device``), ``type``,
            ``current_firmware``, ``latest_firmware``, ``flash_tool`` and
            ``firmware_file`` are read; missing ones are logged as
            "unknown" (model) or "?".
        success: truthy for SUCCESS, falsy for FAILED.
        message: error detail, appended as ``error=...`` only on failure.

    Every field is collapsed onto a single line (flash-tool output usually
    spans many lines) so each attempt stays exactly one log record, and the
    file is written as UTF-8 regardless of the process locale. Logging is
    best-effort: any error is swallowed so it can never abort a flash.
    """
    def one_line(value):
        # Collapse every whitespace run, including newlines, to one space.
        return " ".join(str(value).split())

    try:
        os.makedirs(LOG_DIR, mode=0o755, exist_ok=True)
        timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
        status = "SUCCESS" if success else "FAILED"
        model = one_line(device_info.get("model", "unknown"))
        device = one_line(device_info.get("device", device_info.get("sg_device", "?")))
        dev_type = one_line(device_info.get("type", "?"))
        current_fw = one_line(device_info.get("current_firmware", "?"))
        target_fw = one_line(device_info.get("latest_firmware", "?"))
        flash_tool = one_line(device_info.get("flash_tool", "?"))
        fw_file = one_line(device_info.get("firmware_file", "?"))
        line = (f"{timestamp}  {status}  {dev_type}/{model}  "
                f"{current_fw} -> {target_fw}  tool={flash_tool}  "
                f"device={device}  file={fw_file}")
        detail = one_line(message) if message else ""
        if not success and detail:
            line += f"  error={detail}"
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")
    except Exception:
        pass  # Never let logging break the flash process

# Environment overrides for development.
# Each variable replaces the default defined above only when it is set to a
# non-empty value; an unset or empty variable keeps the built-in path.
# These reassignments run at import time, so they must stay after the
# defaults they override. (REPO_CONF has its own override, FIRMWARE_REPO_CONF,
# which get_repo_url() reads on every call.)
if os.environ.get("FIRMWARE_CACHE_PATH"):
    CACHE_PATH = os.environ["FIRMWARE_CACHE_PATH"]
if os.environ.get("FIRMWARE_DIR"):
    FIRMWARE_DIR = os.environ["FIRMWARE_DIR"]
if os.environ.get("SEACHEST_PATH"):
    SEACHEST_PATH = os.environ["SEACHEST_PATH"]


def get_repo_url():
    """Return the firmware repository base URL from repo.conf, or None.

    The config path is REPO_CONF unless the FIRMWARE_REPO_CONF environment
    variable is set to a non-empty value. Blank lines and ``#`` comments are
    skipped and the first ``REPO_URL=...`` line wins.

    The value is normalised so callers can append ``/<file>`` directly:
    surrounding single/double quotes (shell-style ``REPO_URL="https://..."``)
    and trailing slashes are removed.

    Returns:
        The URL string, or None if the file is missing or unreadable, or
        has no non-empty REPO_URL.
    """
    conf_path = os.environ.get("FIRMWARE_REPO_CONF") or REPO_CONF
    try:
        with open(conf_path, "r", encoding="utf-8", errors="replace") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                if line.startswith("REPO_URL="):
                    value = line.split("=", 1)[1].strip()
                    # Drop one pair of matching surrounding quotes, if present.
                    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                        value = value[1:-1].strip()
                    value = value.rstrip("/")
                    return value or None
    except OSError:
        # Missing file, permission denied, path is a directory, ...: treat
        # all of these as "no repo configured" rather than crashing.
        pass
    return None


def fetch_firmware_from_repo(firmware_file, dest_path, expected_sha256=""):
    """Make *firmware_file* available at *dest_path*, downloading it if needed.

    A SHA256 from the manifest is mandatory; without one the download is
    refused. An existing local file is reused only when its hash matches.
    Otherwise it is deleted and fetched again from
    ``<REPO_URL>/<firmware_file>``.

    The download is streamed into a temporary file in the destination
    directory. It is moved into place with ``os.replace`` (atomic on the same
    filesystem) only after the hash has been verified. An interrupted,
    failed or tampered download therefore never leaves a bad file at
    *dest_path*. Hash comparison ignores case and surrounding whitespace.

    Returns:
        True if a verified file is now at *dest_path*, False otherwise.
    """
    import tempfile

    # Enforce: SHA256 must be present in manifest for integrity verification
    if not expected_sha256:
        print(f"  ✗ REFUSED: No SHA256 hash in manifest for {firmware_file}.")
        print(f"    Firmware files without integrity hashes cannot be safely downloaded.")
        print(f"    Run 'firmware-release update-hashes' to populate the manifest.")
        return False
    expected = expected_sha256.strip().lower()

    if os.path.isfile(dest_path):
        # File exists locally — reuse it only if the hash matches
        if compute_sha256(dest_path) == expected:
            return True
        print(f"  Local file SHA256 mismatch! Re-downloading...")
        os.unlink(dest_path)

    repo_url = get_repo_url()
    if not repo_url:
        print("  No firmware repo configured (repo.conf missing or no REPO_URL)")
        return False

    url = f"{repo_url}/{quote(firmware_file)}"
    print(f"  Firmware not found locally. Downloading from repo...")
    print(f"  URL: {url}")

    tmp_path = None  # set once the temp file exists; cleared after it is moved into place
    try:
        req = Request(url, headers={"User-Agent": "45drives-firmware-flash/1.0"})
        with urlopen(req, timeout=60) as response:
            dest_dir = os.path.dirname(dest_path) or "."
            os.makedirs(dest_dir, exist_ok=True)
            # Same directory as the destination so os.replace stays atomic.
            fd, tmp_path = tempfile.mkstemp(prefix=".fw-download-", dir=dest_dir)
            sha256_hash = hashlib.sha256()
            with os.fdopen(fd, "wb") as f:
                for chunk in iter(lambda: response.read(65536), b""):
                    f.write(chunk)
                    sha256_hash.update(chunk)

        file_size = os.path.getsize(tmp_path)
        actual_sha256 = sha256_hash.hexdigest()
        print(f"  Downloaded: {firmware_file} ({file_size} bytes)")
        print(f"  SHA256: {actual_sha256}")

        # Verify integrity before the file becomes visible at dest_path
        if actual_sha256 != expected:
            print(f"  ✗ INTEGRITY FAILURE: SHA256 mismatch!")
            print(f"    Expected: {expected_sha256}")
            print(f"    Got:      {actual_sha256}")
            return False

        # mkstemp creates 0600 files; firmware images are ordinary readable data.
        os.chmod(tmp_path, 0o644)
        os.replace(tmp_path, dest_path)
        tmp_path = None
        return True

    except HTTPError as e:
        print(f"  Download failed: HTTP {e.code} - {e.reason}")
        return False
    except URLError as e:
        print(f"  Download failed: {e.reason}")
        return False
    except Exception as e:
        print(f"  Download failed: {e}")
        return False
    finally:
        # Remove any partial or unverified download.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def fetch_tool_from_repo(tool_name, dest_path, expected_sha256=""):
    """Make flash tool *tool_name* available as an executable at *dest_path*.

    Tools are served from the repo under ``tools/<tool_name>``.

    - An existing executable at *dest_path* is accepted as-is when no hash is
      given. When a hash is given, it is accepted only if the hash matches;
      otherwise it is deleted and re-downloaded.
    - A download is refused outright without a SHA256 (tool binaries run as
      root).
    - The download is streamed to a temporary file in the destination
      directory. It is marked executable and atomically moved to *dest_path*
      only after the hash matches, so an unverified binary is never
      executable at the final path.

    Hash comparison ignores case and surrounding whitespace.

    Returns:
        True if the tool is now available at *dest_path*, False otherwise.
    """
    import tempfile

    expected = (expected_sha256 or "").strip().lower()

    if os.path.isfile(dest_path) and os.access(dest_path, os.X_OK):
        if not expected:
            return True
        # If we have a hash, verify the existing binary
        if compute_sha256(dest_path) == expected:
            return True
        print(f"  ⚠ Local {tool_name} SHA256 mismatch — re-downloading...")
        os.unlink(dest_path)

    # Refuse to download without a SHA256 hash — tool binaries run as root
    if not expected:
        print(f"  ✗ REFUSED: No SHA256 hash in manifest for tool '{tool_name}'.")
        print(f"    Tool binaries without integrity hashes cannot be safely downloaded.")
        print(f"    The manifest must include a sha256 in flash_tools[] for this tool.")
        return False

    repo_url = get_repo_url()
    if not repo_url:
        print(f"  No firmware repo configured — cannot download {tool_name}")
        return False

    url = f"{repo_url}/tools/{quote(tool_name)}"
    print(f"  {tool_name} not found locally. Downloading from repo...")
    print(f"  URL: {url}")

    tmp_path = None  # set once the temp file exists; cleared after it is moved into place
    try:
        req = Request(url, headers={"User-Agent": "45drives-firmware-flash/1.0"})
        with urlopen(req, timeout=60) as response:
            dest_dir = os.path.dirname(dest_path) or "."
            os.makedirs(dest_dir, exist_ok=True)
            # Same directory as the destination so os.replace stays atomic.
            fd, tmp_path = tempfile.mkstemp(prefix=".tool-download-", dir=dest_dir)
            sha256_hash = hashlib.sha256()
            with os.fdopen(fd, "wb") as f:
                for chunk in iter(lambda: response.read(65536), b""):
                    f.write(chunk)
                    sha256_hash.update(chunk)

        file_size = os.path.getsize(tmp_path)
        actual_sha256 = sha256_hash.hexdigest()
        print(f"  Downloaded: {tool_name} ({file_size} bytes)")
        print(f"  SHA256: {actual_sha256}")

        # Verify integrity BEFORE the binary is made executable / put in place
        if actual_sha256 != expected:
            print(f"  ✗ INTEGRITY FAILURE: SHA256 mismatch!")
            print(f"    Expected: {expected_sha256}")
            print(f"    Got:      {actual_sha256}")
            return False
        print(f"  ✓ Integrity verified")

        os.chmod(tmp_path, 0o755)
        os.replace(tmp_path, dest_path)
        tmp_path = None
        return True

    except HTTPError as e:
        print(f"  Tool download failed: HTTP {e.code} - {e.reason}")
        return False
    except URLError as e:
        print(f"  Tool download failed: {e.reason}")
        return False
    except Exception as e:
        print(f"  Tool download failed: {e}")
        return False
    finally:
        # Remove any partial or unverified download.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def compute_sha256(filepath):
    """Return the lowercase hex SHA-256 digest of the file at *filepath*.

    The file is read in 64 KiB pieces, so large images are never loaded
    into memory all at once.
    """
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        while True:
            piece = stream.read(65536)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()


def confirm_download(item_description):
    """Decide whether *item_description* may be downloaded from the repo.

    Decision order:
      1. ALLOW_DOWNLOAD is set (the user already approved via --allow-download
         in the UI): approve immediately.
      2. stdin is not a terminal: refuse, and tell the user to approve and retry.
      3. Otherwise ask interactively; only "y"/"yes" (case-insensitive) counts
         as consent. EOF or Ctrl-C counts as "no".

    Returns:
        True to proceed with the download, False otherwise.
    """
    if ALLOW_DOWNLOAD:
        print(f"  ℹ {item_description} — downloading from repo (user approved)...")
        return True

    if not sys.stdin.isatty():
        print(f"  ✗ {item_description} not found locally.")
        print(f"    Approve the download and retry.")
        return False

    prompt = f"  {item_description} not found locally. Download from repo? [y/N]: "
    try:
        reply = input(prompt)
    except (EOFError, KeyboardInterrupt):
        print()
        return False
    return reply.strip().lower() in {"y", "yes"}


# SeaChest_Firmware exit codes.
# flash_seachest() reports success for any exit code in SEACHEST_SUCCESS_CODES
# and failure for every other code.
# NOTE(review): SEACHEST_FAILURE_CODES is not referenced in the visible part of
# this file; it appears to be informational only — confirm before relying on it.
SEACHEST_SUCCESS_CODES = {0, 32, 33, 34}  # 32=FW complete, 33=deferred complete, 34=activated
SEACHEST_FAILURE_CODES = {3, 35, 36, 37}  # 3=generic, 35=no match, 36=model mismatch, 37=hash error


def load_cache():
    """Load and return the parsed firmware status cache from CACHE_PATH.

    Returns:
        The decoded JSON object, or None when the cache cannot be used. The
        reason is printed to stderr. Causes include a missing file,
        permission denied, the path being a directory, or content that is
        not valid UTF-8 JSON.
    """
    try:
        with open(CACHE_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable path.
        # ValueError: json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading cache: {e}", file=sys.stderr)
        return None


def _resolve_seachest_binary(tool_sha256):
    """Locate or fetch a SeaChest_Firmware binary matching the manifest hash.

    Lookup order: SEACHEST_PATH, then the fixed fallback locations, then a
    user-approved download from the repo into SEACHEST_PATH. A missing hash
    is refused. A hash mismatch triggers a user-approved re-download; the
    mismatched binary is deleted first only if it lives under
    /opt/45drives/tools/.

    Returns:
        ``(path, None)`` on success or ``(None, error_message)`` on failure.
    """
    seachest_bin = SEACHEST_PATH
    if not os.path.isfile(seachest_bin):
        # Try common locations
        for path in [
            "/usr/local/bin/SeaChest_Firmware",
            "/opt/45drives/tools/SeaChest_Firmware",
            "/usr/bin/SeaChest_Firmware",
        ]:
            if os.path.isfile(path):
                seachest_bin = path
                break
        else:
            # Try downloading from repo (with user permission)
            if not confirm_download("SeaChest_Firmware"):
                return None, "SeaChest_Firmware binary not found and download was declined."
            print("  SeaChest_Firmware not found locally, downloading from repo...")
            if not fetch_tool_from_repo("SeaChest_Firmware", SEACHEST_PATH, tool_sha256):
                return None, "SeaChest_Firmware binary not found and could not be downloaded from repo"
            seachest_bin = SEACHEST_PATH

    # Verify local binary integrity against manifest hash (unconditional)
    if not tool_sha256:
        print(f"  ✗ REFUSED: No tool_binary_sha256 in manifest for 'SeaChest_Firmware'.")
        print(f"    Cannot execute flash tools without integrity verification.")
        return None, "No tool_binary_sha256 for 'SeaChest_Firmware' — manifest flash_tools[] entry missing or incomplete"

    actual = compute_sha256(seachest_bin)
    if actual != tool_sha256:
        # Local binary is stale/wrong — try re-downloading the correct version
        print(f"  ⚠ Local SeaChest_Firmware SHA256 mismatch, re-downloading from repo...")
        print(f"    Path:     {seachest_bin}")
        print(f"    Expected: {tool_sha256}")
        print(f"    Got:      {actual}")
        if not confirm_download("SeaChest_Firmware (integrity mismatch — re-download)"):
            return None, "SeaChest_Firmware integrity check failed and re-download was declined."
        # Only remove the binary if it's in our managed directory
        if seachest_bin.startswith("/opt/45drives/tools/") and os.path.isfile(seachest_bin):
            os.unlink(seachest_bin)
        if not fetch_tool_from_repo("SeaChest_Firmware", SEACHEST_PATH, tool_sha256):
            return None, "SeaChest_Firmware integrity check failed and re-download from repo also failed"
        seachest_bin = SEACHEST_PATH

    return seachest_bin, None


def flash_seachest(sg_device, firmware_path, tool_sha256=""):
    """Flash a drive through its sg device using SeaChest_Firmware.

    Args:
        sg_device: SCSI generic device path passed to ``-d``; required.
        firmware_path: firmware image to download to the drive.
        tool_sha256: manifest SHA256 the SeaChest binary must match.

    The image is copied to a short temporary path first (SeaChest has
    trouble with long paths/filenames). That copy is removed on every exit
    path, including early refusals and exceptions.

    Returns:
        ``(success, output_or_error)``. Success means the SeaChest exit code
        is in SEACHEST_SUCCESS_CODES.
    """
    if not sg_device:
        print("Error: No sg device path available for SeaChest", file=sys.stderr)
        return False, "No sg device path"

    # SeaChest has issues with long paths/filenames — copy to simple temp path
    tmp_path = f"/root/45d_fw_flash_{os.getpid()}.LOD"
    try:
        try:
            shutil.copy2(firmware_path, tmp_path)
        except OSError as e:
            print(f"Error copying firmware file: {e}", file=sys.stderr)
            return False, str(e)

        seachest_bin, error = _resolve_seachest_binary(tool_sha256)
        if error:
            return False, error

        cmd = [seachest_bin, "-d", sg_device, "--downloadFW", tmp_path]
        print(f"Running: {' '.join(shlex.quote(a) for a in cmd)}")

        try:
            result = subprocess.run(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=300
            )
        except subprocess.TimeoutExpired:
            return False, "Flash command timed out (300s)"
        except Exception as e:
            return False, str(e)

        output = result.stdout + result.stderr
        print(output)

        if result.returncode in SEACHEST_SUCCESS_CODES:
            return True, output
        return False, f"Exit code {result.returncode}: {output}"
    finally:
        # Always clean up the temporary firmware copy (including partial copies).
        if os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def flash_storcli(device_info, firmware_path):
    """Flash an HBA/RAID controller with storcli64 (or the manifest's flash_tool).

    Workflow:
      1. Locate the tool in the known install directories. If it is absent,
         download it (with user consent) into /opt/45drives/tools/.
      2. Require the manifest's tool_binary_sha256 and verify the binary
         against it, re-fetching into /opt/45drives/tools/ on mismatch.
      3. Build the command: either the manifest ``flash_command`` template
         (tool name -> resolved path, ``<firmware_path>`` -> quoted path,
         ``/cX`` -> ``/c<controller_index>``) or a default
         ``<tool> /c<N> download file=<firmware>`` command. Run it through
         the shell with a 300 s timeout.

    Returns:
        ``(success, output_or_error)``; success means exit code 0.
    """
    tool = device_info.get("flash_tool", "storcli64")
    ctrl_idx = device_info.get("controller_index", 0)
    tool_sha256 = device_info.get("tool_binary_sha256", "")
    managed_path = f"/opt/45drives/tools/{tool}"

    # Resolve full path for storcli (first existing candidate wins)
    search_order = (
        managed_path,
        f"/opt/MegaRAID/storcli/{tool}",
        f"/usr/local/bin/{tool}",
        f"/usr/bin/{tool}",
    )
    tool_path = next((candidate for candidate in search_order if os.path.isfile(candidate)), None)

    if tool_path is None:
        # Not found locally — try downloading from repo (with user permission)
        if not confirm_download(tool):
            return False, f"{tool} binary not found and download was declined."
        print(f"  {tool} not found locally, downloading from repo...")
        if not fetch_tool_from_repo(tool, managed_path, tool_sha256):
            return False, f"{tool} binary not found and could not be downloaded from repo"
        tool_path = managed_path

    # Verify local binary integrity against manifest hash (unconditional)
    if not tool_sha256:
        print(f"  ✗ REFUSED: No tool_binary_sha256 in manifest for '{tool}'.")
        print(f"    Cannot execute flash tools without integrity verification.")
        return False, f"No tool_binary_sha256 for '{tool}' — manifest flash_tools[] entry missing or incomplete"

    observed = compute_sha256(tool_path)
    if observed != tool_sha256:
        # Local binary is stale/wrong — try re-downloading the correct version
        print(f"  ⚠ Local {tool} SHA256 mismatch, re-downloading from repo...")
        print(f"    Path:     {tool_path}")
        print(f"    Expected: {tool_sha256}")
        print(f"    Got:      {observed}")
        if not confirm_download(f"{tool} (integrity mismatch — re-download)"):
            return False, f"{tool} integrity check failed and re-download was declined."
        # Only remove the binary if it's in our managed directory
        if tool_path.startswith("/opt/45drives/tools/") and os.path.isfile(tool_path):
            os.unlink(tool_path)
        if not fetch_tool_from_repo(tool, managed_path, tool_sha256):
            return False, f"{tool} integrity check failed and re-download from repo also failed"
        tool_path = managed_path

    # Build command — manifest template with placeholders, or the default form
    template = device_info.get("flash_command", "")
    if template:
        command = (
            template.replace(tool, tool_path)
            .replace("<firmware_path>", shlex.quote(firmware_path))
            .replace("/cX", f"/c{ctrl_idx}")
        )
    else:
        command = f"{tool_path} /c{ctrl_idx} download file={shlex.quote(firmware_path)}"

    print(f"Running: {command}")

    try:
        proc = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            timeout=300,
        )
        combined = proc.stdout + proc.stderr
        print(combined)
        if proc.returncode != 0:
            return False, f"Exit code {proc.returncode}: {combined}"
        return True, combined
    except Exception as e:
        return False, str(e)


def flash_nvme(device_path, firmware_path):
    """Download and activate NVMe firmware with nvme-cli.

    Runs ``nvme fw-download <device> --fw=<file>`` and then
    ``nvme fw-activate <device> --slot=1 --action=1``. Arguments are passed
    as a list (no shell), so paths containing spaces or shell metacharacters
    are handled safely and cannot inject commands.

    Args:
        device_path: NVMe device node to flash.
        firmware_path: firmware image file.

    Returns:
        ``(success, output_or_error)``. fw-activate exit code 11 is treated
        as success (activation on next reset).
    """
    def shown(cmd):
        return " ".join(shlex.quote(str(a)) for a in cmd)

    # Download firmware
    cmd_dl = ["nvme", "fw-download", device_path, f"--fw={firmware_path}"]
    print(f"Running: {shown(cmd_dl)}")

    try:
        result = subprocess.run(
            cmd_dl, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=120
        )
        if result.returncode != 0:
            return False, f"fw-download failed: {result.stdout}{result.stderr}"

        # Activate firmware
        cmd_act = ["nvme", "fw-activate", device_path, "--slot=1", "--action=1"]
        print(f"Running: {shown(cmd_act)}")
        result = subprocess.run(
            cmd_act, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=60
        )
        output = result.stdout + result.stderr
        print(output)

        # fw-activate may return non-zero for "activate on next reset"
        if result.returncode in (0, 11):  # 11 = activate on next reset
            return True, output
        return False, f"fw-activate failed (code {result.returncode}): {output}"
    except FileNotFoundError:
        return False, "nvme-cli not found: 'nvme' is not in PATH"
    except subprocess.TimeoutExpired as e:
        return False, f"nvme command timed out ({e.timeout}s)"
    except Exception as e:
        return False, str(e)


def _niccli_adapter_id_for_bus(nic_list, bus_info):
    """Return the niccli adapter ID whose listed PCI address matches *bus_info*.

    Only tokens shaped like PCI addresses (``[dddd:]bb:ss.f``) are compared,
    so MAC addresses or other fields that merely contain the bus digits
    (e.g. ``...:42:00:...``) cannot select the wrong adapter. An exact
    domain/bus/slot/function match is preferred. Failing that, the first
    row on the same bus:slot (another function of a multi-port card) is
    used. A missing domain is treated as ``0000``, and comparison ignores
    case.

    If *bus_info* itself is not a PCI address, rows are matched by an exact
    whitespace-separated token instead.

    Returns:
        The first token of the matching row with a trailing ``)`` removed
        (e.g. ``"1)"`` -> ``"1"``), or None if no row matches.
    """
    import re

    pci_re = re.compile(
        r"(?<![0-9A-Fa-f:])(?:([0-9A-Fa-f]{4}):)?"
        r"([0-9A-Fa-f]{2}:[0-9A-Fa-f]{2})\.([0-7])(?![0-9A-Fa-f])"
    )

    def pci_keys(text):
        # (domain, bus:slot, function), lower-cased; a missing domain means 0000.
        return [
            ((domain or "0000").lower(), bus_slot.lower(), function)
            for domain, bus_slot, function in pci_re.findall(text)
        ]

    def first_token_id(line):
        parts = line.split()
        return parts[0].strip().rstrip(")") if parts else None

    lines = nic_list.splitlines()
    targets = pci_keys(bus_info)
    if not targets:
        for line in lines:
            if bus_info in line.split():
                return first_token_id(line)
        return None

    target = targets[0]
    # Pass 1: exact domain/bus/slot/function match.
    for line in lines:
        if target in pci_keys(line):
            return first_token_id(line)
    # Pass 2: same domain and bus:slot, any function (multi-port adapters).
    for line in lines:
        if any(key[:2] == target[:2] for key in pci_keys(line)):
            return first_token_id(line)
    return None


def flash_niccli(device_info, firmware_path):
    """Flash a Broadcom NIC using niccli.

    niccli addresses adapters by its own IDs (from ``niccli --list``), so the
    PCI bus address in ``device_info["bus_info"]`` is resolved to a niccli
    adapter ID at flash time.

    The niccli binary is searched in fixed locations, then on PATH. If it is
    absent, it is downloaded with user consent into /opt/45drives/tools/. It
    must match the manifest ``tool_binary_sha256``; on mismatch it is
    re-fetched into /opt/45drives/tools/ (a vendor-installed binary is left
    in place). All commands run without a shell.

    Returns:
        ``(success, output_or_error)``. The update counts as successful on
        exit code 0, or when the output reports that the update is in
        progress, succeeded, or the firmware is already up to date.
    """
    bus_info = device_info.get("bus_info", "")
    if not bus_info:
        return False, "No bus_info available for niccli flash"
    tool_sha256 = device_info.get("tool_binary_sha256", "")

    # Find niccli binary
    niccli_bin = None
    for path in ["/opt/niccli/niccli", "/usr/local/bin/niccli", "/usr/bin/niccli", "/opt/45drives/tools/niccli"]:
        if os.path.isfile(path):
            niccli_bin = path
            break
    if not niccli_bin:
        niccli_bin = shutil.which("niccli")
    if not niccli_bin:
        # Try downloading from repo
        dest = "/opt/45drives/tools/niccli"
        if not confirm_download("niccli"):
            return False, "niccli not found and download was declined."
        if not fetch_tool_from_repo("niccli", dest, tool_sha256):
            return False, "niccli not found and could not be downloaded from repo"
        niccli_bin = dest

    # Verify local binary integrity against manifest hash (unconditional)
    if not tool_sha256:
        print(f"  ✗ REFUSED: No tool_binary_sha256 in manifest for 'niccli'.")
        print(f"    Cannot execute flash tools without integrity verification.")
        return False, "No tool_binary_sha256 for 'niccli' — manifest flash_tools[] entry missing or incomplete"
    actual = compute_sha256(niccli_bin)
    if actual != tool_sha256:
        # Local binary is stale/wrong — try re-downloading the correct version
        print(f"  ⚠ Local niccli SHA256 mismatch, re-downloading from repo...")
        print(f"    Path:     {niccli_bin}")
        print(f"    Expected: {tool_sha256}")
        print(f"    Got:      {actual}")
        dest = "/opt/45drives/tools/niccli"
        if not confirm_download("niccli (integrity mismatch — re-download)"):
            return False, "niccli integrity check failed and re-download was declined."
        # Don't delete the old binary if it's not ours (e.g. /opt/niccli/niccli from vendor package)
        if not fetch_tool_from_repo("niccli", dest, tool_sha256):
            return False, "niccli integrity check failed and re-download from repo also failed"
        niccli_bin = dest

    # Get adapter list from niccli
    print(f"Discovering Broadcom adapter ID for PCI {bus_info}...")
    try:
        result = subprocess.run(
            [niccli_bin, "--list"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=30
        )
        if result.returncode != 0:
            # niccli returns non-zero when no devices found
            if "not found" in result.stdout.lower() or "not found" in result.stderr.lower():
                return False, "No Broadcom NICs found by niccli"
            return False, f"niccli --list failed: {result.stdout}{result.stderr}"

        nic_list = result.stdout
        print(f"niccli --list output:\n{nic_list}")
    except Exception as e:
        return False, f"Failed to run niccli --list: {e}"

    adapter_id = _niccli_adapter_id_for_bus(nic_list, bus_info)
    if not adapter_id:
        return False, f"Could not find niccli adapter ID for PCI bus {bus_info}"

    # Flash using niccli (argument list: no shell interpretation of parsed tokens/paths)
    cmd = [niccli_bin, "-i", adapter_id, "fw", "--update", "-f", firmware_path, "--yes"]
    print(f"Running: {' '.join(shlex.quote(a) for a in cmd)}")

    try:
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=300, input="y\n"
        )
        output = result.stdout + result.stderr
        print(output)
        lowered = output.lower()

        if "already up-to-date" in lowered:
            return True, "Firmware already up-to-date. Reboot to activate."
        if ("update is in progress" in lowered
                or "updated to package version" in lowered
                or "update successful" in lowered):
            return True, output
        if result.returncode == 0:
            return True, output
        return False, f"niccli exit code {result.returncode}: {output}"
    except subprocess.TimeoutExpired:
        return False, "niccli firmware update timed out (300s)"
    except Exception as e:
        return False, str(e)


def flash_mlxup(device_info, firmware_path=None):
    """Flash a Mellanox/NVIDIA NIC using mlxup.

    With an existing *firmware_path*, the pinned image is flashed via
    ``--image-file ... --force``. Otherwise mlxup's own online update is
    used (legacy behaviour). A warning is printed when a path was given but
    does not exist, because that fallback fetches firmware from NVIDIA rather
    than the pinned image. ``--dev <bus_info>`` is added when
    ``device_info["bus_info"]`` is set.

    The mlxup binary is searched in fixed locations, then on PATH. If it is
    absent, it is downloaded with user consent into /opt/45drives/tools/. It
    must match the manifest ``tool_binary_sha256``. All commands run without
    a shell.

    Returns:
        ``(success, output_or_error)``. Success means exit code 0, or output
        saying the firmware is already up to date.
    """
    bus_info = device_info.get("bus_info", "")
    tool_sha256 = device_info.get("tool_binary_sha256", "")

    # Find mlxup binary
    mlxup_bin = None
    for path in ["/usr/bin/mlxup", "/usr/local/bin/mlxup", "/opt/45drives/tools/mlxup", "/root/nic-firmware-tools/mlxup"]:
        if os.path.isfile(path):
            mlxup_bin = path
            break
    if not mlxup_bin:
        mlxup_bin = shutil.which("mlxup")
    if not mlxup_bin:
        # Try downloading from repo
        dest = "/opt/45drives/tools/mlxup"
        if not confirm_download("mlxup"):
            return False, "mlxup not found and download was declined."
        if not fetch_tool_from_repo("mlxup", dest, tool_sha256):
            return False, "mlxup not found and could not be downloaded from repo"
        mlxup_bin = dest

    # Verify local binary integrity against manifest hash (unconditional)
    if not tool_sha256:
        print(f"  ✗ REFUSED: No tool_binary_sha256 in manifest for 'mlxup'.")
        print(f"    Cannot execute flash tools without integrity verification.")
        return False, "No tool_binary_sha256 for 'mlxup' — manifest flash_tools[] entry missing or incomplete"
    actual = compute_sha256(mlxup_bin)
    if actual != tool_sha256:
        # Local binary is stale/wrong — try re-downloading the correct version
        print(f"  ⚠ Local mlxup SHA256 mismatch, re-downloading from repo...")
        print(f"    Path:     {mlxup_bin}")
        print(f"    Expected: {tool_sha256}")
        print(f"    Got:      {actual}")
        dest = "/opt/45drives/tools/mlxup"
        if not confirm_download("mlxup (integrity mismatch — re-download)"):
            return False, "mlxup integrity check failed and re-download was declined."
        # Only remove the binary if it's in our managed directory
        if mlxup_bin.startswith("/opt/45drives/tools/") and os.path.isfile(mlxup_bin):
            os.unlink(mlxup_bin)
        if not fetch_tool_from_repo("mlxup", dest, tool_sha256):
            return False, "mlxup integrity check failed and re-download from repo also failed"
        mlxup_bin = dest

    # Build command as an argument list (no shell); option order matches the
    # historical shell command lines.
    cmd = [mlxup_bin, "--update"]
    if bus_info:
        cmd += ["--dev", bus_info]
    if firmware_path and os.path.isfile(firmware_path):
        cmd += ["--image-file", firmware_path, "--force", "--yes"]
        note = "  (using pinned firmware image from 45Drives repo)"
    else:
        if firmware_path:
            print(f"  ⚠ Pinned firmware image not found: {firmware_path} — falling back to mlxup online update")
        # Fallback to auto-update from NVIDIA (legacy behavior)
        cmd.append("--yes")
        note = "  (mlxup downloads firmware from NVIDIA servers automatically)"
    print(f"Running: {' '.join(shlex.quote(a) for a in cmd)}")
    print(note)

    try:
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=300
        )
        output = result.stdout + result.stderr
        print(output)

        if result.returncode == 0:
            return True, output
        lowered = output.lower()
        if "already updated" in lowered or "up to date" in lowered:
            return True, "Firmware already up-to-date."
        return False, f"mlxup exit code {result.returncode}: {output}"
    except subprocess.TimeoutExpired:
        return False, "mlxup firmware update timed out (300s)"
    except Exception as e:
        return False, str(e)


def _is_trusted_path(path, want_dir=False):
    """Return True if *path* can be trusted even though it lives in /tmp.

    "Trusted" means: not a symlink; a regular file (or a directory when
    *want_dir* is true); owned by the current effective user; and not
    writable by group or others.  Such a path cannot have been planted or
    modified by another local user.  A missing path is not trusted.
    """
    import stat

    try:
        st = os.lstat(path)
    except OSError:
        return False
    right_kind = stat.S_ISDIR(st.st_mode) if want_dir else stat.S_ISREG(st.st_mode)
    return bool(right_kind
                and st.st_uid == os.geteuid()
                and not st.st_mode & (stat.S_IWGRP | stat.S_IWOTH))


def _remove_path(path):
    """Delete *path* if it exists.

    Handles a directory tree, a regular file, or a symlink.  For a symlink
    only the link itself is removed; its target is never touched.
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.unlink(path)


def _download_tool_package(url, tarball_path, tool_package, expected_sha256):
    """Download *url* to *tarball_path* and verify its SHA256.

    The file is created with O_CREAT|O_EXCL.  If another user has planted a
    file or symlink at that predictable /tmp path, the download fails
    instead of writing through it.

    Returns:
        (True, "") on success.  On failure, (False, error message), with any
        partial or invalid file removed.
    """
    print(f"  Downloading {tool_package} from repo...")
    print(f"  URL: {url}")

    sha256_hash = hashlib.sha256()
    try:
        req = Request(url, headers={"User-Agent": "45drives-firmware-flash/1.0"})
        with urlopen(req, timeout=120) as response:
            fd = os.open(tarball_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
            with os.fdopen(fd, "wb") as out:
                for chunk in iter(lambda: response.read(65536), b""):
                    out.write(chunk)
                    sha256_hash.update(chunk)
    except Exception as e:  # HTTPError/URLError, socket errors, O_EXCL collision, ...
        try:
            _remove_path(tarball_path)
        except OSError:
            pass
        return False, f"Failed to download {tool_package}: {e}"

    file_size = os.path.getsize(tarball_path)
    actual_sha256 = sha256_hash.hexdigest()
    print(f"  Downloaded: {tool_package} ({file_size / 1024 / 1024:.1f} MB)")
    print(f"  SHA256: {actual_sha256}")

    # Verify integrity against manifest hash
    if actual_sha256 != expected_sha256:
        print(f"  ✗ INTEGRITY FAILURE: SHA256 mismatch!")
        print(f"    Expected: {expected_sha256}")
        print(f"    Got:      {actual_sha256}")
        os.unlink(tarball_path)
        return False, f"Integrity check failed for {tool_package}: SHA256 mismatch"
    print(f"  ✓ Integrity verified")
    return True, ""


def _extract_tool_package(tarball_path, extract_dir, flash_tool, tool_package):
    """Extract the verified tarball into a freshly created *extract_dir*.

    Every member is validated before anything is written.  Absolute paths,
    ``..`` traversal, and links that point outside the directory cause the
    whole archive to be rejected.  Only regular files and directories are
    extracted.

    Afterwards every entry is chowned to the current user and stripped of
    setuid/setgid/sticky and group/other write bits.  Extension-less files
    and the *flash_tool* file are then made executable (0o755).

    Returns:
        (True, "") on success.  On failure, (False, error message), with
        *extract_dir* removed.
    """
    import stat
    import tarfile as tarfile_mod

    def _fail(reason):
        shutil.rmtree(extract_dir, ignore_errors=True)
        return False, reason

    try:
        _remove_path(extract_dir)
        os.makedirs(extract_dir, mode=0o755)
        # makedirs' mode is filtered through the umask; set it explicitly.
        os.chmod(extract_dir, 0o755)
        real_root = os.path.realpath(extract_dir)

        with tarfile_mod.open(tarball_path, "r:gz") as tar:
            members = tar.getmembers()
            # Validate all members before extraction
            for member in members:
                name = member.name
                member_path = os.path.normpath(name)
                if name.startswith(("/", "\\")):
                    return _fail(f"Unsafe archive: absolute path '{name}'")
                if member_path.startswith("..") or "/../" in name:
                    return _fail(f"Unsafe archive: path traversal in '{name}'")
                if member.issym() or member.islnk():
                    link_target = member.linkname
                    if link_target.startswith(("/", "..")):
                        return _fail(f"Unsafe archive: symlink escape '{name}' -> '{link_target}'")
                # Ensure the resolved path stays within extract_dir (separator-aware prefix test)
                abs_path = os.path.realpath(os.path.join(extract_dir, member_path))
                if abs_path != real_root and not abs_path.startswith(real_root + os.sep):
                    return _fail(f"Unsafe archive: member '{name}' escapes extraction directory")

            # Safe to extract — only regular files and directories
            safe_members = [m for m in members if m.isreg() or m.isdir()]
            # Use Python's "data" extraction filter where available (3.12+ and
            # security backports) as a second line of defence.
            extract_kwargs = {"filter": "data"} if hasattr(tarfile_mod, "data_filter") else {}
            tar.extractall(path=extract_dir, members=safe_members, **extract_kwargs)

        # When run as root, tarfile may restore the archive's owner and mode
        # bits.  Normalise both so no other local user can modify anything
        # the root-run tool will execute or read.
        uid, gid = os.geteuid(), os.getegid()
        for dirpath, dirnames, filenames in os.walk(extract_dir):
            for name in dirnames + filenames:
                path = os.path.join(dirpath, name)
                os.lchown(path, uid, gid)
                os.chmod(path, stat.S_IMODE(os.lstat(path).st_mode) & 0o755)
            for fname in filenames:
                # Make executable if it looks like a binary (no extension or known tool)
                if not os.path.splitext(fname)[1] or fname == flash_tool:
                    os.chmod(os.path.join(dirpath, fname), 0o755)
    except (tarfile_mod.TarError, OSError) as e:
        return _fail(f"Failed to extract {tool_package}: {e}")

    return True, ""


def flash_tool_package(device_info):
    """Generic handler for tools distributed as .tar.gz packages.

    Works entirely from manifest fields — no hardcoded tool names:
      - tool_package: tarball filename (e.g. "nvmupdate_E810.tar.gz").  It
        must be a bare file name because it becomes a path under /tmp.
      - tool_package_sha256: integrity hash.  Required (fail-closed).
      - flash_tool: binary name inside the tarball (e.g. "nvmupdate64e").
      - flash_command: command template to run from the extraction dir.
        Supported placeholders: <extract_dir>, <device>, <bus_info>,
        <sg_device>, <firmware_path>.  Substituted values are shell-quoted.
        Note that <firmware_path> receives the manifest's ``firmware_file``
        value as-is; it is not resolved against FIRMWARE_DIR.

    The tarball is cached at /tmp/<tool_package> and extracted to
    /tmp/45drives-toolpkg-<name>.  /tmp is shared and tool packages run as
    root.  A cached tarball or extraction is therefore reused only if it:
      - is owned by the current user,
      - is not a symlink,
      - is not group/world-writable, and
      - (for the tarball) still matches tool_package_sha256.
    Anything else is downloaded and/or extracted afresh.

    This means adding a new tarball-based tool only requires filling in
    the admin UI — no code changes.

    Returns:
        (success, message) tuple.  A non-zero exit is still reported as
        success when the output says no update is available / up to date.
    """
    tool_package = device_info.get("tool_package", "")
    if not tool_package:
        return False, "No tool_package specified in manifest"

    flash_tool = device_info.get("flash_tool", "")
    flash_command = device_info.get("flash_command", "")
    expected_sha256 = device_info.get("tool_package_sha256", "")

    # Fail-closed: refuse to proceed without an integrity hash — tool packages run as root
    if not expected_sha256:
        print(f"  ✗ REFUSED: No tool_package_sha256 in manifest for '{tool_package}'.")
        print(f"    Tool packages without integrity hashes cannot be safely used.")
        print(f"    The manifest must include tool_package_sha256 for this entry.")
        return False, f"No tool_package_sha256 in manifest for '{tool_package}' — cannot safely proceed"

    if not flash_command:
        return False, f"No flash_command specified in manifest for tool_package '{tool_package}'"

    # The name becomes a path under /tmp and part of a root shell command,
    # so it must be a plain file name: no directories, no traversal, no dotfiles.
    if os.path.basename(tool_package) != tool_package or tool_package.startswith("."):
        return False, f"Unsafe tool_package name in manifest: {tool_package!r}"

    # hashlib hexdigests are lowercase; tolerate upper-case hashes in the manifest.
    expected_sha256 = expected_sha256.strip().lower()

    # Derive a stable extraction directory name from the package
    pkg_name = tool_package.replace(".tar.gz", "").replace(".tgz", "")
    extract_dir = f"/tmp/45drives-toolpkg-{pkg_name}"
    tarball_path = os.path.join("/tmp", tool_package)
    tool_bin = os.path.join(extract_dir, flash_tool) if flash_tool else ""

    # 1. Is there a cached tarball we can trust?
    tarball_ok = False
    if os.path.lexists(tarball_path):
        if not _is_trusted_path(tarball_path):
            print(f"  Local tool package is not a trusted file (owner/type/permissions) — re-downloading...")
            _remove_path(tarball_path)
        elif (compute_sha256(tarball_path) or "").lower() != expected_sha256:
            print(f"  Local file SHA256 mismatch! Re-downloading...")
            _remove_path(tarball_path)
        else:
            tarball_ok = True

    # 2. Reuse the existing extraction only alongside that verified tarball,
    #    and only if nobody else could have modified it.
    reuse_extraction = tarball_ok and _is_trusted_path(extract_dir, want_dir=True) and (
        _is_trusted_path(tool_bin) if tool_bin else bool(os.listdir(extract_dir))
    )

    if not reuse_extraction:
        if not tarball_ok:
            if not confirm_download(f"Tool package '{tool_package}'"):
                return False, f"Tool package download was declined: {tool_package}"

            repo_url = get_repo_url()
            if not repo_url:
                return False, "No firmware repo configured — cannot download tool package"

            url = f"{repo_url}/tools/{quote(tool_package)}"
            ok, error = _download_tool_package(url, tarball_path, tool_package, expected_sha256)
            if not ok:
                return False, error

        print(f"  Extracting {tool_package} to {extract_dir}...")
        ok, error = _extract_tool_package(tarball_path, extract_dir, flash_tool, tool_package)
        if not ok:
            return False, error

    # Verify the flash tool binary exists (if specified)
    if tool_bin and not os.path.isfile(tool_bin):
        return False, f"Flash tool '{flash_tool}' not found in extracted package at {extract_dir}"

    def _shell_arg(key):
        # Quote non-empty values; keep the original behaviour of an empty
        # placeholder expanding to nothing.
        value = str(device_info.get(key) or "")
        return shlex.quote(value) if value else value

    # Build the command from flash_command template with placeholder substitution
    cmd = flash_command
    cmd = cmd.replace("<extract_dir>", shlex.quote(extract_dir))
    cmd = cmd.replace("<device>", _shell_arg("device"))
    cmd = cmd.replace("<bus_info>", _shell_arg("bus_info"))
    cmd = cmd.replace("<firmware_path>", shlex.quote(str(device_info.get("firmware_file") or "")))
    cmd = cmd.replace("<sg_device>", _shell_arg("sg_device"))

    # If the command starts with the tool name (no path prefix), add ./
    stripped = cmd.strip()
    if flash_tool and stripped.startswith(flash_tool) and not stripped.startswith("./"):
        cmd = "./" + stripped

    # If the command doesn't have an explicit cd, prepend one to the extraction dir
    if not cmd.strip().startswith("cd "):
        cmd = f"cd {shlex.quote(extract_dir)} && {cmd}"

    print(f"  Running: {cmd}")

    try:
        result = subprocess.run(
            cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=600, cwd=extract_dir
        )
        output = result.stdout + result.stderr
        print(output)

        if result.returncode == 0:
            return True, output
        elif "no update available" in output.lower() or "up to date" in output.lower():
            return True, "Firmware already up-to-date."
        else:
            return False, f"{flash_tool or tool_package} exit code {result.returncode}: {output}"
    except subprocess.TimeoutExpired:
        return False, f"{flash_tool or tool_package} timed out (600s)"
    except Exception as e:
        return False, str(e)


# Keep backward-compat alias: code that still calls the old name
# flash_nvmupdate gets the generic tarball-based tool handler.
flash_nvmupdate = flash_tool_package


def flash_device(device_info):
    """Flash a device based on its cache/manifest entry.

    Entries with ``tool_package`` are handed straight to flash_tool_package.
    Otherwise, the image named by ``firmware_file`` is located under
    FIRMWARE_DIR.  It is looked for at its plain relative path first, then
    under a ``<type>/`` subdirectory.

    A local copy is verified against ``sha256`` whenever the manifest
    provides one, wherever it was found.  If no local copy exists, it is
    downloaded (after confirmation) to the plain path.  The image is then
    passed to the handler selected by ``flash_tool``.  Unknown tools fall
    back to a ``flash_command`` template, with shell-quoted placeholder
    values.

    Args:
        device_info: one device entry from the status cache.

    Returns:
        (success, message) tuple.
    """
    flash_tool = device_info.get("flash_tool", "")
    firmware_file = device_info.get("firmware_file", "")
    device_type = device_info.get("type", "")
    expected_sha256 = device_info.get("sha256", "")
    tool_package = device_info.get("tool_package", "")

    # If a tool_package is specified, use the generic tool-package handler.
    # This is fully data-driven — works for any tarball-based tool without code changes.
    if tool_package:
        return flash_tool_package(device_info)

    if not firmware_file:
        return False, "No firmware file specified in manifest"

    # Locate firmware file (reject absolute paths / traversal)
    firmware_file_norm = os.path.normpath(firmware_file).lstrip(os.sep)
    if firmware_file_norm.startswith("..") or firmware_file_norm in ("", "."):
        return False, f"Unsafe firmware_file path in manifest: {firmware_file}"
    firmware_file = firmware_file_norm

    default_path = os.path.join(FIRMWARE_DIR, firmware_file)
    search_paths = [default_path]
    # Only use "type" as a subdirectory when it is a plain directory name, so a
    # malformed cache entry cannot point the lookup outside FIRMWARE_DIR.
    if (device_type and isinstance(device_type, str)
            and os.sep not in device_type and device_type not in (".", "..")):
        search_paths.append(os.path.join(FIRMWARE_DIR, device_type, firmware_file))
    firmware_path = next((p for p in search_paths if os.path.isfile(p)), None)

    if firmware_path is None:
        # Not found locally — ask permission then fetch from repo
        if not confirm_download(f"Firmware file '{firmware_file}'"):
            return False, f"Firmware file not found locally and download was declined: {firmware_file}"
        if not fetch_firmware_from_repo(firmware_file, default_path, expected_sha256):
            return False, f"Firmware file not found locally and could not be downloaded: {firmware_file}"
        firmware_path = default_path
    elif expected_sha256:
        # Verify whichever local copy was found, including one in the <type>/ subdirectory
        actual = compute_sha256(firmware_path)
        if actual != expected_sha256:
            print(f"  ✗ Local file SHA256 mismatch! Expected {expected_sha256}, got {actual}")
            return False, f"Firmware file integrity check failed (SHA256 mismatch)"

    # Route to appropriate flash function
    if "SeaChest" in flash_tool:
        sg_device = device_info.get("sg_device", "")
        tool_sha256 = device_info.get("tool_binary_sha256", "")
        return flash_seachest(sg_device, firmware_path, tool_sha256)
    elif "storcli" in flash_tool:
        return flash_storcli(device_info, firmware_path)
    elif "niccli" in flash_tool:
        return flash_niccli(device_info, firmware_path)
    elif "mlxup" in flash_tool:
        return flash_mlxup(device_info, firmware_path)
    elif "nvme" in flash_tool:
        device_path = device_info.get("device", "")
        return flash_nvme(device_path, firmware_path)

    # Last resort: if there's a flash_command template, try running it generically
    flash_command = device_info.get("flash_command", "")
    if not flash_command:
        return False, f"Unsupported flash tool: {flash_tool}"

    # The command runs through a shell, so quote every substituted value
    # (empty values still expand to nothing, as before).
    cmd = flash_command.replace("<firmware_path>", shlex.quote(firmware_path))
    for placeholder, key in (("<device>", "device"), ("<bus_info>", "bus_info"),
                             ("<sg_device>", "sg_device")):
        value = str(device_info.get(key) or "")
        cmd = cmd.replace(placeholder, shlex.quote(value) if value else value)
    print(f"  Running (generic): {cmd}")
    try:
        result = subprocess.run(
            cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=600
        )
        output = result.stdout + result.stderr
        print(output)
        if result.returncode == 0:
            return True, output
        return False, f"{flash_tool} exit code {result.returncode}: {output}"
    except subprocess.TimeoutExpired:
        return False, f"{flash_tool} timed out (600s)"
    except Exception as e:
        return False, str(e)


def check_raid_membership(device_info):
    """Decide whether an HDD may be flashed with respect to RAID membership.

    The drive is identified by its serial number in storcli's JSON report
    for every controller (``/call/eall/sall show all J``).  The text-based
    fallback (_check_raid_text_fallback) is used instead when:
      - storcli exits non-zero,
      - its output is not valid JSON,
      - the call times out, or
      - no drive matches.

    Returns:
        "jbod"       — drive is in JBOD/passthrough mode (safe to flash)
        "raid"       — drive is a member of a RAID virtual disk (DO NOT flash)
        "rebuilding" — drive is rebuilding (absolutely DO NOT flash)
        "unknown"    — could not determine.  Causes: no serial, storcli not
                       found, or an unexpected error.
    """
    target_serial = (device_info.get("serial") or "").strip()
    if not target_serial:
        # Without a serial there is nothing to match against.
        return "unknown"

    known_locations = (
        "/opt/45drives/tools/storcli64",
        "/opt/MegaRAID/storcli/storcli64",
        "/usr/local/bin/storcli64",
        "/usr/bin/storcli64",
    )
    storcli_bin = next((p for p in known_locations if os.path.isfile(p)), None)
    storcli_bin = storcli_bin or shutil.which("storcli64")
    if not storcli_bin:
        # No storcli means no RAID controller to ask (direct SATA/AHCI).
        return "unknown"

    try:
        proc = subprocess.run(
            [storcli_bin, "/call/eall/sall", "show", "all", "J"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=60
        )
        if proc.returncode == 0:
            payload = json.loads(proc.stdout)
            # Each controller's "Response Data" holds two keys per drive:
            #   "Drive /c0/e0/s0"                          -> [summary row with State, DG]
            #   "Drive /c0/e0/s0 - Detailed Information"   -> {"... Device attributes": {"SN": ...}}
            for controller in payload.get("Controllers", []):
                response = controller.get("Response Data", {})
                for key in list(response.keys()):
                    if not key.startswith("Drive /c") or "Detailed" in key:
                        continue

                    rows = response.get(key, [])
                    if not rows or not isinstance(rows, list):
                        continue
                    head = rows[0]
                    summary = head if isinstance(head, dict) else {}
                    state = str(summary.get("State", "")).strip()
                    group = str(summary.get("DG", "")).strip()

                    detail = response.get(f"{key} - Detailed Information", {})
                    if not isinstance(detail, dict):
                        continue

                    drive_serial = ""
                    for section_name, section in detail.items():
                        if "Device attributes" in section_name and isinstance(section, dict):
                            drive_serial = (str(section.get("SN", "")).strip()
                                            or str(section.get("Serial Number", "")).strip())
                            break

                    if drive_serial and drive_serial == target_serial:
                        return _interpret_drive_state(state, group)

        # Non-zero exit, or the JSON report did not contain the drive.
        return _check_raid_text_fallback(storcli_bin, device_info)

    except (subprocess.TimeoutExpired, ValueError):
        return _check_raid_text_fallback(storcli_bin, device_info)
    except Exception:
        return "unknown"


def _storcli_output_has_serial(text, serial):
    """Return True if storcli ``show all`` text reports exactly *serial*.

    Lines of the form ``SN = <value>`` or ``Serial Number = <value>`` are
    compared exactly.  Only when the output contains no such line does this
    fall back to looking for *serial* as a whole whitespace-separated token.
    """
    saw_serial_field = False
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        if sep and key.strip() in ("SN", "Serial Number"):
            saw_serial_field = True
            if value.strip() == serial:
                return True
    if saw_serial_field:
        return False
    return serial in text.split()


def _check_raid_text_fallback(storcli_bin, device_info):
    """Fallback RAID-membership check using storcli's plain-text output.

    Lists every physical drive with ``storcli /call/eall/sall show``.  Each
    slot is then queried with ``show all`` and its reported serial is
    compared with ``device_info["serial"]``.

    The serial must match exactly.  A substring test would let serial "ABC"
    match a drive whose serial is "ABC123", and so return another drive's
    RAID state.

    Args:
        storcli_bin: path to the storcli binary.
        device_info: cache entry for the target drive; only "serial" is read.

    Returns:
        The _interpret_drive_state verdict for the matching slot.  Returns
        "unknown" if no serial is known, storcli fails, or no slot matches.
    """
    serial = (device_info.get("serial") or "").strip()
    if not serial:
        # Nothing to match against — every per-slot query would be wasted.
        return "unknown"

    # Failures that mean "storcli could not tell us": missing/unrunnable
    # binary, timeout, or output that cannot be decoded as text.
    storcli_errors = (subprocess.SubprocessError, OSError, ValueError)

    try:
        listing = subprocess.run(
            [storcli_bin, "/call/eall/sall", "show"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=30
        )
    except storcli_errors:
        return "unknown"
    if listing.returncode != 0:
        return "unknown"

    # Parse the PD LIST table:
    # EID:Slt DID State DG Size Intf Med SED PI SeSz Model Sp
    candidates = []  # [(enclosure_id, slot, state, dg)]
    for line in listing.stdout.splitlines():
        parts = line.split()
        if len(parts) < 8 or ":" not in parts[0]:
            continue
        eid, slt = parts[0].split(":", 1)
        try:
            int(eid)
            int(slt)
        except ValueError:
            continue
        candidates.append((eid, slt, parts[2], parts[3]))

    # Query each candidate slot's details to find the exact drive by serial.
    for eid, slt, state, dg in candidates:
        try:
            detail = subprocess.run(
                [storcli_bin, f"/call/e{eid}/s{slt}", "show", "all"],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                universal_newlines=True, timeout=10
            )
        except storcli_errors:
            continue
        if _storcli_output_has_serial(detail.stdout, serial):
            return _interpret_drive_state(state, dg)

    return "unknown"


def _interpret_drive_state(state, dg):
    """Interpret storcli State and DG columns into a flash-safety verdict."""
    state = state.strip()
    dg = str(dg).strip()

    if state == "JBOD" or dg == "-":
        return "jbod"
    elif state in ("Rbld", "Rebuild"):
        return "rebuilding"
    elif state in ("Onln", "DHS", "GHS"):
        return "raid"
    elif dg != "-":
        # Any drive group assignment means RAID membership
        return "raid"
    return "unknown"


def _preflight_refusal(device_info):
    """Run the safety checks that can refuse a flash outright.

    1. RAID membership: an HDD flashed with SeaChest must not be a RAID
       member or rebuilding (see check_raid_membership).
    2. upgrade_from: when the manifest lists compatible firmware prefixes,
       the drive's current firmware must start with one of them.

    Refusals are printed to stderr and logged via log_flash_event.

    Returns:
        4 for a RAID-member or rebuilding drive, 5 for an incompatible
        current firmware, or None when flashing may proceed.
    """
    # Pre-flight: RAID membership check for HDDs behind HBA controllers
    if device_info.get("type") == "hdd" and "SeaChest" in device_info.get("flash_tool", ""):
        raid_status = check_raid_membership(device_info)
        if raid_status == "raid":
            print(f"✗ REFUSED: Drive is part of a RAID virtual disk.", file=sys.stderr)
            print(f"  Seagate strongly warns against firmware updates on RAID-member drives.", file=sys.stderr)
            print(f"  Remove the drive from the RAID array first, or use your RAID controller's tools.", file=sys.stderr)
            log_flash_event(device_info, False, "REFUSED: drive is RAID member")
            return 4
        if raid_status == "rebuilding":
            print(f"✗ REFUSED: Drive is currently rebuilding in a RAID array.", file=sys.stderr)
            log_flash_event(device_info, False, "REFUSED: drive is rebuilding in RAID")
            return 4
        if raid_status == "unknown":
            # Proceed, but make the uncertainty visible.
            print("Warning: Could not determine RAID membership for this drive; proceeding.", file=sys.stderr)

    # Pre-flight: upgrade_from compatibility check (Seagate chain-flash protection)
    upgrade_from = device_info.get("upgrade_from", "")
    if upgrade_from:
        current_fw = device_info.get("current_firmware", "")
        compatible_prefixes = [p.strip() for p in upgrade_from.split(",") if p.strip()]
        if current_fw and not any(current_fw.startswith(prefix) for prefix in compatible_prefixes):
            print(f"✗ REFUSED: Current firmware '{current_fw}' is not compatible with this upgrade.", file=sys.stderr)
            print(f"  This firmware can only be applied to drives running: {', '.join(compatible_prefixes)}", file=sys.stderr)
            print(f"  Your drive may have an OEM firmware variant. Contact your vendor.", file=sys.stderr)
            log_flash_event(device_info, False, f"REFUSED: current FW '{current_fw}' not in upgrade_from list")
            return 5

    return None


def _build_preflight_info(device_info):
    """Describe what flashing *device_info* would do, as a JSON-serializable dict.

    "downloads" lists the tools/firmware that would have to be fetched from
    the repo; "actions" is a human-readable plan.  The checks follow the
    routing in flash_device.  An entry with ``tool_package`` goes to
    flash_tool_package, which caches its tarball at /tmp/<tool_package> and
    never downloads a separate firmware file.
    """
    flash_tool = device_info.get("flash_tool", "")
    firmware_file = device_info.get("firmware_file", "")
    tool_package = device_info.get("tool_package", "")
    model = device_info.get("model", "")

    info = {
        "model": model,
        "device": device_info.get("device", ""),
        "type": device_info.get("type", ""),
        "current_firmware": device_info.get("current_firmware", ""),
        "latest_firmware": device_info.get("latest_firmware", ""),
        "flash_tool": flash_tool,
        "firmware_file": firmware_file,
        "downloads": [],
        "actions": [],
    }
    downloads = info["downloads"]

    def need_tool(name, reason):
        downloads.append({"type": "tool", "name": name, "reason": reason})

    # Check if flash tool needs downloading
    if tool_package:
        if not os.path.isfile(os.path.join("/tmp", tool_package)):
            if "nvmupdate" in flash_tool:
                need_tool(tool_package, "Intel NVM update package not cached locally")
            else:
                need_tool(tool_package, "Tool package not cached locally")
    elif "SeaChest" in flash_tool:
        seachest_paths = [SEACHEST_PATH, "/usr/local/bin/SeaChest_Firmware",
                          "/opt/45drives/tools/SeaChest_Firmware", "/usr/bin/SeaChest_Firmware"]
        if not any(os.path.isfile(p) for p in seachest_paths):
            need_tool("SeaChest_Firmware", "Flash tool not installed locally")
    elif "storcli" in flash_tool:
        # Same lookup as the flash path; an unresolved tool comes back as its bare name.
        storcli_path = resolve_storcli_path(flash_tool)
        if not os.path.isfile(storcli_path):
            need_tool(storcli_path, "Flash tool not installed locally")
    elif "niccli" in flash_tool:
        niccli_paths = ["/opt/niccli/niccli", "/usr/local/bin/niccli", "/usr/bin/niccli", "/opt/45drives/tools/niccli"]
        if not any(os.path.isfile(p) for p in niccli_paths) and not shutil.which("niccli"):
            need_tool("niccli", "Broadcom NIC flash tool not installed locally")
    elif "mlxup" in flash_tool:
        mlxup_paths = ["/usr/bin/mlxup", "/usr/local/bin/mlxup", "/opt/45drives/tools/mlxup", "/root/nic-firmware-tools/mlxup"]
        if not any(os.path.isfile(p) for p in mlxup_paths) and not shutil.which("mlxup"):
            need_tool("mlxup", "Mellanox/NVIDIA NIC flash tool not installed locally")

    # Check if firmware file needs downloading (same normalisation/lookup as flash_device)
    self_updating_tools = ["nvmupdate64e"]
    is_self_updating = any(t in flash_tool for t in self_updating_tools)
    if firmware_file and not tool_package and not is_self_updating:
        relative = os.path.normpath(firmware_file).lstrip(os.sep)
        device_type = device_info.get("type", "")
        candidates = [os.path.join(FIRMWARE_DIR, relative)]
        if device_type:
            candidates.append(os.path.join(FIRMWARE_DIR, device_type, relative))
        if not any(os.path.isfile(c) for c in candidates):
            downloads.append({"type": "firmware", "name": firmware_file, "reason": "Firmware file not cached locally"})

    # Describe actions
    actions = info["actions"]
    for dl in downloads:
        actions.append(f"Download {dl['name']} from firmware repo")

    if "mlxup" in flash_tool:
        actions.append(f"Flash {model} using mlxup (pinned firmware image from 45Drives repo)")
    elif "nvmupdate" in flash_tool:
        actions.append(f"Flash {model} using nvmupdate64e (auto-detects NIC and applies correct firmware)")
    else:
        actions.append(f"Flash {model} using {flash_tool}")
    if device_info.get("requires_reboot"):
        actions.append("Reboot required after flash to activate new firmware")

    return info


def main():
    """Command-line entry point: flash (or describe flashing) one cached device.

    Exit codes:
        0 = success, or --dry-run / --preflight output printed
        1 = error (unusable cache, invalid --cache-index)
        2 = flash failed
        3 = device not flashable
        4 = refused: drive is a RAID member or rebuilding
        5 = refused: current firmware not in the manifest's upgrade_from list
    """
    parser = argparse.ArgumentParser(description="Flash firmware to a device")
    parser.add_argument("--cache-index", type=int, required=True, help="Index of device in cache to flash")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be done without flashing")
    parser.add_argument("--preflight", action="store_true", help="Output JSON describing what will happen (downloads needed, etc.)")
    parser.add_argument("--allow-download", action="store_true", help="Allow downloading missing tools/firmware from repo (UI passes this after user confirms)")
    args = parser.parse_args()

    # Publish the CLI choice through the module-level ALLOW_DOWNLOAD flag.
    global ALLOW_DOWNLOAD
    ALLOW_DOWNLOAD = args.allow_download

    cache = load_cache()
    if not cache:
        return 1

    devices = cache.get("devices", [])
    if not 0 <= args.cache_index < len(devices):
        print(f"Error: Invalid cache index {args.cache_index} (have {len(devices)} devices)", file=sys.stderr)
        return 1

    device_info = devices[args.cache_index]

    if not device_info.get("flashable", False):
        print(f"Error: Device is not flashable: {device_info.get('model', 'unknown')}", file=sys.stderr)
        return 3

    if device_info.get("update_available") != "outdated":
        print(f"Warning: Device firmware is already current or unknown", file=sys.stderr)

    refusal = _preflight_refusal(device_info)
    if refusal is not None:
        return refusal

    if args.dry_run:
        print(f"Would flash: {device_info.get('model', '')} ({device_info.get('device', '')})")
        print(f"  Current FW: {device_info.get('current_firmware', '')}")
        print(f"  Target FW:  {device_info.get('latest_firmware', '')}")
        print(f"  Tool: {device_info.get('flash_tool', '')}")
        print(f"  File: {device_info.get('firmware_file', '')}")
        return 0

    if args.preflight:
        print(json.dumps(_build_preflight_info(device_info)))
        return 0

    print(f"Flashing: {device_info.get('model', '')} ({device_info.get('device', '')})")
    print(f"  Current: {device_info.get('current_firmware', '')} -> Target: {device_info.get('latest_firmware', '')}")

    success, message = flash_device(device_info)
    log_flash_event(device_info, success, message)

    if success:
        print(f"\n✓ Flash successful")
        # Post-flash verification: re-read firmware version from device (best-effort)
        verify_firmware_after_flash(device_info)
        return 0

    print(f"\n✗ Flash failed: {message}", file=sys.stderr)
    return 2


def resolve_storcli_path(flash_tool):
    """Return the storcli executable to run for *flash_tool*.

    The requested name is reduced to a base tool name: "storcli64" is tried
    before "storcli", and an empty name means "storcli64".  That name is
    then looked up in a fixed set of install directories, followed by PATH
    (intended to match the flash logic).  If nothing is found, the bare tool
    name is returned unchanged.
    """
    requested = flash_tool or "storcli64"
    tool = next((base for base in ("storcli64", "storcli") if base in requested), requested)

    search_dirs = (
        "/opt/45drives/tools",
        "/opt/MegaRAID/storcli",
        "/usr/local/bin",
        "/usr/bin",
    )
    for directory in search_dirs:
        candidate = f"{directory}/{tool}"
        if os.path.isfile(candidate):
            return candidate

    # Fallback to PATH lookup, then to the bare name
    return shutil.which(tool) or tool


def _verify_run(argv):
    """Run a read-only query command; return the CompletedProcess, or None if it cannot be launched."""
    # Argument list, no shell: device paths / bus IDs taken from the status
    # cache are passed verbatim and can never be interpreted by a shell.
    try:
        return subprocess.run(
            argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, timeout=30
        )
    except OSError:
        # Missing or non-executable tool: treat it like a failed query (the
        # previous shell=True form surfaced this as exit status 127).
        return None


def _verify_find_binary(candidates, name):
    """Return the first existing path in *candidates*, else the $PATH lookup of *name* (or None)."""
    for path in candidates:
        if os.path.isfile(path):
            return path
    return shutil.which(name)


def _verify_parse_smartctl(output):
    """Extract the firmware string from ``smartctl -i`` output ("Firmware Version" or SCSI "Revision")."""
    for line in output.splitlines():
        key, sep, val = line.partition(":")
        if sep and key.strip() in ("Firmware Version", "Revision"):
            return val.strip()
    return None


def _verify_parse_storcli(output):
    """Extract the first "FW Package Build"/"Firmware Version" value from ``storcli /cN show all`` output."""
    for line in output.splitlines():
        if "FW Package Build" in line or "Firmware Version" in line:
            parts = line.split("=")
            if len(parts) >= 2:
                return parts[-1].strip()
    return None


def _verify_parse_nvme(output):
    """Extract the ``fr`` (firmware revision) field from ``nvme id-ctrl -o json`` output."""
    try:
        ctrl_info = json.loads(output)
    except json.JSONDecodeError:
        return None
    if not isinstance(ctrl_info, dict):
        return None
    return str(ctrl_info.get("fr", "")).strip() or None


def _verify_parse_niccli(output, bus_info):
    """Extract the firmware version following the adapter matching *bus_info* in ``niccli -i all info`` output."""
    # Normalize the PCI address the same way the niccli flash path does:
    # "0000:3b:00.0" -> "3b:00.0" -> "3b:00".
    short_bus = bus_info[5:] if bus_info.startswith("0000:") else bus_info
    bus_slot = short_bus.rsplit(".", 1)[0]
    needles = (bus_info.lower(), short_bus.lower(), bus_slot.lower())

    # NOTE(review): once a line mentions our bus address, the first later
    # "firmware version" line is taken, even if another adapter's section
    # starts in between; the short "bb:dd" needle could also match unrelated
    # text (e.g. a MAC address). Confirm against real niccli output.
    in_our_adapter = False
    for line in output.splitlines():
        line_lower = line.lower()
        if any(needle in line_lower for needle in needles):
            in_our_adapter = True
        elif in_our_adapter and "firmware version" in line_lower:
            parts = line.split(":")
            return parts[-1].strip() if len(parts) >= 2 else None
    return None


def _verify_parse_mlxup(output):
    """Extract the current FW version from ``mlxup --query`` output."""
    for line in output.splitlines():
        parts = line.split()
        if "FW Version" in line or "fw_ver" in line.lower():
            return parts[-1].strip() if parts else None
        if len(parts) >= 2 and parts[0] == "FW":
            # Row of mlxup's "Versions:  Current  Available" table:
            # "FW  <current>  <available>" -- column 1 is what is running now.
            return parts[1]
    return None


def verify_firmware_after_flash(device_info):
    """Re-read firmware version after flash to confirm update took effect.

    This is best-effort logging only — the return value does NOT affect the
    exit code of the script. A mismatch here does NOT mean the flash failed;
    most devices need a reboot/power-cycle before the new firmware is reported.

    The probe is chosen from ``device_info["type"]`` and
    ``device_info["flash_tool"]``:
      - hdd            -> ``smartctl -i <device>``
      - hba + storcli  -> ``<storcli> /c<controller_index> show all``
      - nvme           -> ``nvme id-ctrl <device> -o json``
      - nic + niccli   -> ``niccli -i all info`` (matched by ``bus_info``)
      - mlxup tool     -> ``mlxup --query [--dev <bus_info>]``
    All commands run without a shell. Any unexpected exception is printed
    and reported as "skipped".

    Args:
        device_info: device record; keys read are ``type``, ``latest_firmware``,
            ``device``, ``flash_tool``, ``controller_index`` and ``bus_info``.

    Returns:
        "verified"  — new firmware version confirmed active
        "pending"   — device still reports old version (needs power cycle)
        "skipped"   — could not read version (unsupported device type or error)
    """
    device_type = device_info.get("type", "")
    expected_fw = device_info.get("latest_firmware", "")
    device_path = device_info.get("device", "")
    # Normalise None to "" so the substring checks below cannot raise.
    flash_tool = device_info.get("flash_tool", "") or ""

    if not expected_fw:
        return "skipped"

    print("\n── Post-flash verification (best-effort) ──")
    actual_fw = None

    try:
        if device_type == "hdd" and device_path:
            result = _verify_run(["smartctl", "-i", device_path])
            # smartctl's exit status is a bitmask: bit 0 = command line did not
            # parse, bit 1 = device open / IDENTIFY failed. Higher bits report
            # SMART command failures or health/log findings, which do not
            # invalidate the -i identity output.
            if result is not None and (result.returncode & 0b11) == 0:
                actual_fw = _verify_parse_smartctl(result.stdout)

        elif device_type == "hba" and "storcli" in flash_tool:
            # Same controller key that firmware-check writes and flash_storcli reads.
            ctrl_idx = device_info.get("controller_index", 0)
            storcli_path = resolve_storcli_path(flash_tool)
            result = _verify_run([storcli_path, f"/c{ctrl_idx}", "show", "all"])
            if result is not None and result.returncode == 0:
                actual_fw = _verify_parse_storcli(result.stdout)

        elif device_type == "nvme" and device_path:
            result = _verify_run(["nvme", "id-ctrl", device_path, "-o", "json"])
            if result is not None and result.returncode == 0:
                actual_fw = _verify_parse_nvme(result.stdout)

        elif device_type == "nic" and "niccli" in flash_tool:
            bus_info = device_info.get("bus_info", "")
            if bus_info:
                # Same binary search order as flash_niccli().
                niccli_bin = _verify_find_binary(
                    ["/opt/niccli/niccli", "/usr/local/bin/niccli",
                     "/usr/bin/niccli", "/opt/45drives/tools/niccli"],
                    "niccli",
                )
                if niccli_bin:
                    result = _verify_run([niccli_bin, "-i", "all", "info"])
                    if result is not None and result.returncode == 0:
                        actual_fw = _verify_parse_niccli(result.stdout, bus_info)

        elif "mlxup" in flash_tool:
            bus_info = device_info.get("bus_info", "")
            mlxup_bin = _verify_find_binary(
                ["/usr/bin/mlxup", "/usr/local/bin/mlxup", "/opt/45drives/tools/mlxup"],
                "mlxup",
            )
            if mlxup_bin:
                argv = [mlxup_bin, "--query"]
                if bus_info:
                    argv += ["--dev", bus_info]
                result = _verify_run(argv)
                if result is not None and result.returncode == 0:
                    actual_fw = _verify_parse_mlxup(result.stdout)

        else:
            print(f"  ⚠ No verification method for type={device_type}, tool={flash_tool}")
            return "skipped"

        # Report result. An empty string is not a version, so it counts as unreadable.
        if not actual_fw:
            print("  ⚠ Could not re-read firmware version (device may need power cycle)")
            return "skipped"
        if actual_fw == expected_fw:
            print(f"  ✓ Verified: firmware is now {actual_fw}")
            return "verified"
        print(f"  ⚠ Device still reports: {actual_fw} (expected: {expected_fw})")
        print("    This is normal — most devices need a reboot/power-cycle to activate new firmware.")
        return "pending"

    except Exception as e:
        # Deliberate best-effort boundary: verification must never fail the flash.
        print(f"  ⚠ Verification skipped: {e}")
        return "skipped"


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit
    # status (see the exit codes listed in the module docstring).
    raise SystemExit(main())
