#!/usr/libexec/platform-python
"""
firmware-discover - Scans the system for firmware-bearing hardware (SATA/SAS
disks, NVMe drives, HBA/RAID controllers, NICs, BIOS/UEFI and the BMC) and
reports their current firmware versions. Outputs JSON to stdout.

Used by: 45d-firmware-check (comparison against manifest)
         SectionFirmware.vue (display in UI)

Exit codes:
    0 = success
    1 = error
"""

import json
import subprocess
import sys
import re
import os
import shutil
from pathlib import Path


def run_cmd(cmd, timeout=30):
    """Execute *cmd* and return a ``(stdout, stderr, returncode)`` tuple.

    A list/tuple command is executed directly (``shell=False``), which keeps
    it safe from shell injection. A plain string is handed to the shell
    (``shell=True``) so that pipes and redirections work; prefer the list
    form whenever no shell feature is needed.

    Both output streams are decoded as text and stripped of surrounding
    whitespace. This never raises: a timeout yields ``("", "timeout", -1)``
    and any other failure (e.g. a missing executable) yields
    ``("", str(exc), -1)``.
    """
    via_shell = isinstance(cmd, str)
    try:
        proc = subprocess.run(
            cmd,
            shell=via_shell,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return "", "timeout", -1
    except Exception as exc:
        return "", str(exc), -1
    return proc.stdout.strip(), proc.stderr.strip(), proc.returncode


def _hdd_sector_format(sector_info):
    """Classify a smartctl sector-size string as "512E", "4kN" or "512N".

    Understands the shapes smartctl prints for ATA drives, e.g.
    "512 bytes logical, 4096 bytes physical"  -> "512E" (emulated),
    "4096 bytes logical, 4096 bytes physical" -> "4kN"  (native 4K),
    "512 bytes logical/physical"              -> "512N" (native 512),
    and a bare "512 bytes"                    -> "512N".
    Returns "" for empty input or any other logical/physical combination.
    """
    if not sector_info:
        return ""

    def first_int(text):
        # Leading integer of *text*, or None when there isn't one.
        try:
            return int(text.split()[0])
        except (ValueError, IndexError):
            return None

    logical = physical = 0
    for piece in sector_info.lower().split(","):
        piece = piece.strip()
        if "logical" in piece:
            value = first_int(piece)
            if value is not None:
                logical = value
        elif "physical" in piece:
            value = first_int(piece)
            if value is not None:
                physical = value

    if logical and not physical:
        # "N bytes logical/physical": one number describes both.
        physical = logical
    elif not logical and not physical:
        # Bare "N bytes".
        value = first_int(sector_info)
        if value is not None:
            logical = physical = value

    return {
        (512, 4096): "512E",
        (4096, 4096): "4kN",
        (512, 512): "512N",
    }.get((logical, physical), "")


def _scsi_generic_map():
    """Map block-device names (e.g. "sda") to their SCSI generic node ("/dev/sgN").

    Built from /sys/class/scsi_generic/<sg>/device/block/. If several sg nodes
    list the same block device, the first one listed wins. Returns an empty
    dict when the sysfs class directory is missing or unreadable.
    """
    mapping = {}
    base = "/sys/class/scsi_generic"
    try:
        sg_names = os.listdir(base)
    except OSError:
        return mapping
    for sg in sg_names:
        try:
            blocks = os.listdir(os.path.join(base, sg, "device", "block"))
        except OSError:
            # No block child under this sg node, or it vanished mid-scan.
            continue
        for block in blocks:
            mapping.setdefault(block, f"/dev/{sg}")
    return mapping


def _vdev_slot_for(name):
    """Return the /dev/disk/by-vdev alias whose link points at block device *name*.

    Partition aliases (containing "-part") are ignored. The link target's
    basename must equal *name* exactly, so "sda" never matches a link to
    "sdaa". Returns "" when there is no such alias or the directory is absent.
    """
    vdev_dir = "/dev/disk/by-vdev"
    try:
        entries = sorted(os.listdir(vdev_dir))
    except OSError:
        return ""
    for entry in entries:
        if "-part" in entry:
            continue
        try:
            target = os.readlink(os.path.join(vdev_dir, entry))
        except OSError:
            # Not a symlink, or removed mid-scan: skip only this entry.
            continue
        if os.path.basename(target) == name:
            return entry
    return ""


def discover_hdds():
    """Discover SATA/SAS disk firmware using ``smartctl -i``.

    Whole disks are listed with lsblk. NVMe devices are skipped because
    discover_nvme handles them. Each drive that reports both a model and a
    firmware revision yields one dict with its identity fields. The dict also
    carries the drive's ``/dev/sgN`` node (for SeaChest) and its
    ``/dev/disk/by-vdev`` slot alias; either is "" when not found.
    """
    devices = []

    lsblk_out, _, rc = run_cmd(["lsblk", "-d", "-n", "-o", "NAME,TYPE"])
    if rc != 0 or not lsblk_out:
        return devices

    # The sysfs SCSI-generic layout is the same for every drive: scan it once.
    sg_by_block = _scsi_generic_map()

    for line in lsblk_out.splitlines():
        cols = line.split()
        # Keep whole disks only: compare the TYPE column exactly rather than
        # grepping the whole line for "disk".
        if len(cols) < 2 or cols[1] != "disk":
            continue
        name = cols[0]
        if name.startswith("nvme"):
            continue
        dev_path = f"/dev/{name}"

        smart_out, _, rc = run_cmd(["smartctl", "-i", dev_path])
        # Accept rc=4 as well as 0. That status (smartctl exit-status bit 2:
        # a SMART/ATA command failed) can still come with a usable identity
        # block.
        if rc not in (0, 4) or not smart_out:
            continue

        info = {}
        for field_line in smart_out.splitlines():
            if ":" in field_line:
                key, _, val = field_line.partition(":")
                info[key.strip()] = val.strip()

        # SATA drives use "Device Model" / "Firmware Version";
        # SAS drives use "Product" / "Revision".
        model = info.get("Device Model", "") or info.get("Product", "")
        firmware = info.get("Firmware Version", "") or info.get("Revision", "")
        if not model or not firmware:
            continue
        serial = info.get("Serial Number", "") or info.get("Serial number", "")

        interface = "SATA"
        if "SAS" in info.get("Transport protocol", ""):
            interface = "SAS"
        elif not info.get("SATA Version is", "") and "SAS" in smart_out:
            # No SATA version line: fall back to a loose text match.
            interface = "SAS"

        # Part number is the model suffix, e.g. ST16000NM000J-2TW103 -> 2TW103.
        part_number = model.split("-", 1)[1] if "-" in model else ""
        sector_info = info.get("Sector Sizes", "") or info.get("Sector Size", "")

        devices.append({
            "type": "hdd",
            "device": dev_path,
            "sg_device": sg_by_block.get(name, ""),
            "slot": _vdev_slot_for(name),
            "model": model.split("-")[0],
            "model_full": model,
            "part_number": part_number,
            "serial": serial,
            "firmware": firmware,
            "family": info.get("Model Family", ""),
            "interface": interface,
            "sector_format": _hdd_sector_format(sector_info),
            "capacity": info.get("User Capacity", ""),
        })

    return devices


def _nvme_via_smartctl():
    """Fallback NVMe discovery for when ``nvme list -o json`` is unusable.

    Lists block devices whose name contains "nvme" with lsblk and reads each
    one's identity with ``smartctl -i``. Only devices that report both a
    "Model Number" and a "Firmware Version" are kept.
    """
    devices = []
    lsblk_out, _, rc = run_cmd(["lsblk", "-d", "-n", "-o", "NAME"])
    if rc != 0 or not lsblk_out:
        return devices
    for line in lsblk_out.splitlines():
        name = line.strip()
        if "nvme" not in name:
            continue
        dev_path = f"/dev/{name}"
        info_out, _, _ = run_cmd(["smartctl", "-i", dev_path])
        if not info_out:
            continue
        info = {}
        for field_line in info_out.splitlines():
            if ":" in field_line:
                key, _, val = field_line.partition(":")
                info[key.strip()] = val.strip()
        model = info.get("Model Number", "")
        firmware = info.get("Firmware Version", "")
        if model and firmware:
            devices.append({
                "type": "nvme",
                "device": dev_path,
                "model": model,
                "serial": info.get("Serial Number", ""),
                "firmware": firmware,
            })
    return devices


def discover_nvme():
    """Discover NVMe firmware, preferring nvme-cli's JSON listing.

    Uses ``nvme list -o json``. If that command fails, prints nothing, or
    prints something that is not valid JSON, falls back to lsblk + smartctl.
    Returns a list of {"type", "device", "model", "serial", "firmware"} dicts.
    """
    stdout, _, rc = run_cmd(["nvme", "list", "-o", "json"])
    if rc != 0 or not stdout:
        return _nvme_via_smartctl()
    try:
        nvme_data = json.loads(stdout)
    except json.JSONDecodeError:
        # Previously this silently reported no NVMe devices at all.
        return _nvme_via_smartctl()

    def text(dev, key):
        # Stripped string field, or "" when missing / not a string.
        value = dev.get(key, "")
        return value.strip() if isinstance(value, str) else ""

    nvme_list = nvme_data.get("Devices", []) if isinstance(nvme_data, dict) else []
    if not isinstance(nvme_list, list):
        nvme_list = []

    devices = []
    for dev in nvme_list:
        if not isinstance(dev, dict):
            continue
        devices.append({
            "type": "nvme",
            "device": dev.get("DevicePath", ""),
            "model": text(dev, "ModelNumber"),
            "serial": text(dev, "SerialNumber"),
            "firmware": text(dev, "Firmware"),
        })
    return devices


def _storcli_pci_to_bus_info(pci_raw):
    """Convert storcli's "DD:BB:DD:FF" PCI address to "DDDD:BB:DD.F".

    The storcli fields are hex. The result is the form lspci -D and sysfs
    use, with lowercase, zero-padded hex for domain/bus/device. Returns ""
    in three cases:
      - the value is empty or not a string,
      - it does not have exactly four fields,
      - a field is not valid hex.
    """
    if not isinstance(pci_raw, str) or not pci_raw:
        return ""
    fields = pci_raw.split(":")
    if len(fields) != 4:
        return ""
    try:
        domain, bus, dev, func = (int(field, 16) for field in fields)
    except ValueError:
        return ""
    return f"{domain:04x}:{bus:02x}:{dev:02x}.{func}"


def _sysfs_hba_firmware(bus_info):
    """Read ``version_fw`` of the SCSI host attached to PCI function *bus_info*.

    This is the fallback for IT-mode HBAs (9305, 9400 in IT mode), which
    report "00.00.00.00" via storcli. Returns the first non-empty, non-zero
    version found, or "".
    """
    import glob

    for host_dir in sorted(glob.glob("/sys/class/scsi_host/host*")):
        device_path = os.path.realpath(os.path.join(host_dir, "device"))
        # Compare whole path components. A plain substring test would let
        # "0000:01:00.0" match inside e.g. "10000:01:00.0".
        if bus_info not in device_path.split(os.sep):
            continue
        try:
            with open(os.path.join(host_dir, "version_fw")) as fh:
                sysfs_fw = fh.read().strip()
        except OSError:
            continue
        if sysfs_fw and sysfs_fw != "00.00.00.00":
            return sysfs_fw
    return ""


def discover_hba():
    """Discover HBA/RAID controller firmware using storcli64.

    Tries storcli64 at several known locations and keeps the controllers
    from the first one that reports any. When storcli gives no firmware or
    "00.00.00.00", the matching SCSI host's sysfs ``version_fw`` is used
    instead. If no storcli works, RAID/SAS controllers found by ``lspci -D``
    are listed with an empty firmware field and an explanatory "note".
    """
    devices = []

    def as_dict(value):
        # storcli JSON sections are objects; treat anything else as empty.
        return value if isinstance(value, dict) else {}

    storcli_paths = [
        "/opt/45drives/tools/storcli64",
        "storcli64",
        "/opt/MegaRAID/storcli/storcli64",
        "/usr/local/bin/storcli64",
    ]
    for tool in storcli_paths:
        stdout, _, rc = run_cmd([tool, "/call", "show", "all", "J"])
        if rc != 0 or not stdout:
            continue
        try:
            data = json.loads(stdout)
        except json.JSONDecodeError:
            continue
        controllers = as_dict(data).get("Controllers", [])
        if not isinstance(controllers, list):
            continue

        for idx, ctrl in enumerate(controllers):
            resp = as_dict(as_dict(ctrl).get("Response Data"))
            if not resp:
                continue
            basics = as_dict(resp.get("Basics"))
            version = as_dict(resp.get("Version"))
            model = basics.get("Model", "")
            serial = basics.get("Serial Number", "")
            ctrl_idx = basics.get("Controller", idx)
            fw = (version.get("Firmware Package Build", "")
                  or version.get("Firmware Version", ""))
            # A malformed address now yields "" instead of crashing the script.
            bus_info = _storcli_pci_to_bus_info(basics.get("PCI Address", ""))
            if (not fw or fw == "00.00.00.00") and bus_info:
                fw = _sysfs_hba_firmware(bus_info) or fw

            if model:
                devices.append({
                    "type": "hba",
                    "model": model,
                    "serial": serial,
                    "firmware": fw,
                    "tool": tool,
                    "controller_index": ctrl_idx,
                    "bus_info": bus_info,
                })
        if devices:
            break

    # Fallback: use lspci to detect RAID/SAS controllers if storcli is not available.
    if not devices:
        stdout, _, rc = run_cmd(["lspci", "-D"])
        if rc == 0 and stdout:
            for line in stdout.splitlines():
                if not any(kw in line for kw in ("RAID", "SAS", "MegaRAID")):
                    continue
                # Line shape: "<pci address> <class>: <description>"
                pci_addr, _, desc = line.partition(" ")
                model = desc.split(": ", 1)[1].strip() if ": " in desc else ""
                # Skip SAS expanders and lines with no description.
                if not model or "Expander" in model:
                    continue
                devices.append({
                    "type": "hba",
                    "model": model,
                    "serial": "",
                    "firmware": "",
                    "tool": "",
                    "controller_index": 0,
                    "bus_info": pci_addr,
                    "note": "storcli64 not available — firmware version unknown",
                })

    return devices


def discover_nic():
    """Discover physical NIC firmware using ethtool, with model info from lspci.

    Walks /sys/class/net and skips loopback, bridges, bonds and any interface
    without a backing "device" link. Multi-port cards are reported once per
    PCI slot: the first interface in name order wins. Firmware strings are
    normalized for bnxt_en (package version) and mlx4/mlx5 (text before the
    parenthetical). The untouched ethtool value is kept in "firmware_raw".
    """
    devices = []

    net_dir = "/sys/class/net"
    try:
        # Exact "lo" check below; the old `grep -v lo` also dropped any
        # interface whose name merely contained "lo" (e.g. "wlo1").
        ifaces = sorted(os.listdir(net_dir))
    except OSError:
        return devices

    # Common chip identifiers used as the short model (e.g. BCM57416, ConnectX-6).
    chip_re = re.compile(r'(BCM\d+|ConnectX-\d+|MCX\d+|X[57]10|E810|XXV710|XL710|I350|I210)')
    seen_slots = set()  # PCI slots already reported (bus address minus ".function")

    for iface in ifaces:
        if iface == "lo":
            continue
        iface_dir = os.path.join(net_dir, iface)
        # Skip bridges, bonds, and virtual interfaces (no "device" link).
        if (os.path.exists(os.path.join(iface_dir, "bridge"))
                or os.path.exists(os.path.join(iface_dir, "bonding"))
                or not os.path.exists(os.path.join(iface_dir, "device"))):
            continue

        drv_out, _, _ = run_cmd(["ethtool", "-i", iface])
        if not drv_out:
            continue
        info = {}
        for field_line in drv_out.splitlines():
            if ":" in field_line:
                key, _, val = field_line.partition(":")
                info[key.strip()] = val.strip()

        firmware = info.get("firmware-version", "")
        driver = info.get("driver", "")
        bus = info.get("bus-info", "")
        if not firmware or not driver or not bus or bus == "N/A":
            continue

        # Deduplicate multi-port NICs by PCI slot (strip the ".X" function suffix).
        slot_base = bus.rsplit(".", 1)[0]
        if slot_base in seen_slots:
            continue
        seen_slots.add(slot_base)

        # One `lspci -vv` call supplies both the header line (model) and the
        # VPD "Part number" line.
        model = ""
        part_number = ""
        lspci_out, _, _ = run_cmd(["lspci", "-s", bus, "-vv"])
        if lspci_out:
            lspci_lines = lspci_out.splitlines()
            # Header: "<address> <class>: <description>", e.g.
            # "42:00.0 Ethernet controller: Broadcom Inc. ... BCM57416 ...".
            # Split after the address rather than counting colons, because
            # the address gains an extra colon when the PCI domain is non-zero.
            _, _, class_and_desc = lspci_lines[0].partition(" ")
            _, sep, desc = class_and_desc.partition(": ")
            if sep:
                model = desc.strip()
            for vpd_line in lspci_lines:
                if "part number" in vpd_line.lower():
                    part_number = vpd_line.split(":", 1)[-1].strip() if ":" in vpd_line else ""
                    break

        chip_match = chip_re.search(model)
        short_model = chip_match.group(1) if chip_match else model

        fw_normalized = firmware
        if driver == "bnxt_en" and "/pkg " in firmware:
            # Broadcom: "224.0.158.0/pkg 224.1.102.0" -> package version.
            fw_normalized = firmware.split("/pkg ")[-1].strip()
        elif driver in ("mlx5_core", "mlx4_core") and "(" in firmware:
            # Mellanox/NVIDIA: "20.31.1014 (MT_0000000225)" -> "20.31.1014".
            fw_normalized = firmware.split("(")[0].strip()

        devices.append({
            "type": "nic",
            "device": iface,
            "interface": iface,
            "model": short_model,
            "model_full": model,
            "part_number": part_number,
            "serial": "",
            "driver": driver,
            "firmware": fw_normalized,
            "firmware_raw": firmware,
            "bus_info": bus,
        })

    return devices


def discover_bios():
    """Report the system BIOS/UEFI version as read by ``dmidecode -t bios``.

    Returns a one-element list holding {"type", "vendor", "firmware",
    "release_date"}. Returns an empty list when dmidecode fails, prints
    nothing, or reports no "Version" field. When a key appears more than
    once in the output, its last occurrence is used.
    """
    stdout, _, rc = run_cmd("dmidecode -t bios 2>/dev/null")
    if rc != 0 or not stdout:
        return []

    fields = {}
    for raw in stdout.splitlines():
        key, sep, value = raw.strip().partition(":")
        if sep:
            fields[key.strip()] = value.strip()

    version = fields.get("Version", "")
    if not version:
        return []
    return [{
        "type": "bios",
        "vendor": fields.get("Vendor", ""),
        "firmware": version,
        "release_date": fields.get("Release Date", ""),
    }]


def discover_bmc():
    """Discover the BMC/IPMI firmware revision using ``ipmitool mc info``.

    Returns [] in three cases:
      - ipmitool is not on PATH,
      - the command fails or prints nothing,
      - no "Firmware Revision" line is present.
    The vendor is "Unknown" when no "Manufacturer Name" line is printed.
    """
    devices = []

    ipmitool = shutil.which("ipmitool")
    if not ipmitool:
        return devices

    # Run the binary shutil.which() resolved, without a shell. run_cmd
    # captures stderr, which replaces the old "2>/dev/null" redirection.
    stdout, _, rc = run_cmd([ipmitool, "mc", "info"])
    if rc != 0 or not stdout:
        return devices

    fw = ""
    mfg = "Unknown"
    for line in stdout.splitlines():
        line = line.strip()
        if "Firmware Revision" in line:
            fw = line.split(":", 1)[-1].strip()
        elif "Manufacturer Name" in line:
            mfg = line.split(":", 1)[-1].strip()

    if fw:
        devices.append({
            "type": "bmc",
            "model": "BMC/IPMI",
            "vendor": mfg,
            "firmware": fw,
        })

    return devices


def main():
    """Run every discoverer and print the combined report as JSON on stdout.

    The report is {"timestamp", "hostname", "devices"}. "timestamp" is the
    current UTC time in ISO-8601 with a trailing "Z". If a discoverer raises,
    the error is written to stderr and that discoverer is skipped, so the
    other components are still reported. In that case the return value
    (used as the exit status) is 1; otherwise it is 0.
    """
    from datetime import datetime, timezone
    import socket

    results = {
        # Naive-UTC isoformat + "Z" keeps the historical timestamp format
        # without the deprecated datetime.utcnow().
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        # Same value the `hostname` command prints, without spawning a shell.
        "hostname": socket.gethostname(),
        "devices": [],
    }

    failed = False
    for discover in (discover_hdds, discover_nvme, discover_hba,
                     discover_nic, discover_bios, discover_bmc):
        try:
            results["devices"].extend(discover())
        except Exception as exc:  # one broken probe must not hide the others
            failed = True
            print(f"firmware-discover: {discover.__name__} failed: {exc!r}",
                  file=sys.stderr)

    print(json.dumps(results, indent=2))
    return 1 if failed else 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
