#!/usr/libexec/platform-python
"""
45d-firmware-discover

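Discovers the server's permanently installed firmware-bearing devices (BIOS,
BMC, HBAs, NICs, and SAS expanders) and prints them as a JSON array:
  [{type, model, vendor, firmware, device_path}, ...]
Entries whose firmware is empty, "N/A", or "-" are omitted.

Leverages the existing 45drives-system 'pci' script for HBA and NIC detection
(it already parses storcli, lspci, network helpers, and firmware versions),
falling back to querying storcli directly when that script is unavailable or
fails, and to vendor NIC tools (mlxup/mlxfwmanager, niccli, ethtool) when it
reports no NICs. Adds BIOS, BMC, and SAS expander discovery on top.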

Usage:
  firmware-discover            # compact JSON to stdout
  firmware-discover --pretty   # pretty-printed JSON
"""

import json
import os
import re
import shutil
import subprocess
import sys

# Accumulator for discovered devices: add() appends one dict per device and
# main() serialises the whole list to stdout.
devices = list()

# Candidate locations of the 45drives-system 'pci' helper script, tried in
# order: the installed Cockpit plugin copy first, then one sitting next to
# this script.
PCI_SCRIPT_PATHS = [
    os.path.join("/usr/share/cockpit/45drives-system", "scripts", "pci"),
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "pci"),
]


def run(cmd, timeout=30):
    """Execute *cmd* (an argv list) and return its stripped stdout.

    Every failure mode collapses to "": a non-zero exit status, a missing
    executable, a timeout, or any other error raised while spawning. This
    lets discovery steps treat "tool unavailable" and "no output" alike.
    """
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            timeout=timeout,
        )
    except Exception:
        return ""
    if proc.returncode != 0:
        return ""
    return proc.stdout.strip()


def run_any(cmd, timeout=30):
    """Execute *cmd* (an argv list) and return its stripped stdout.

    Unlike run(), the exit status is ignored, so output is returned even
    when the command exits non-zero. Only a failure to run at all (missing
    executable, timeout, other spawn error) produces "".
    """
    try:
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            timeout=timeout,
        )
    except Exception:
        return ""
    return completed.stdout.strip()


def add(dtype, model, vendor, firmware, device_path):
    """Append one device record to the module-level ``devices`` list.

    Args:
        dtype: Device category ("bios", "bmc", "hba", "nic", "expander").
        model: Human-readable model name.
        vendor: Vendor name.
        firmware: Firmware version. Non-string values (e.g. a number parsed
            from JSON) are converted with str().
        device_path: Where the device was found (PCI address, /dev node,
            or the tool name).

    Records are silently skipped when *firmware* is falsy or is one of the
    placeholders "", "n/a", "-" (case-insensitive, surrounding whitespace
    ignored). Callers can therefore pass raw tool output without
    pre-filtering.
    """
    if not firmware:
        return
    if not isinstance(firmware, str):
        # Previously a non-string value crashed on .lower().
        firmware = str(firmware)
    if firmware.strip().lower() in ("", "n/a", "-"):
        return
    devices.append({
        "type": dtype,
        "model": model,
        "vendor": vendor,
        "firmware": firmware,
        "device_path": device_path,
    })


###############################################################################
# HBA + NIC via existing pci script
###############################################################################
def discover_pci_devices():
    """
    Discover HBAs and NICs using the existing 45drives-system 'pci' script.

    The script already detects:
      - HBAs (via storcli64/storcli2 + lspci + server_info.json)
      - Network Cards (via network helper)
      - Firmware versions for HBAs
    Expected output: a JSON list of slot objects
      [{slot, type, availibility, busAddress, cardType, cardModel, firmwareVersion}]

    Falls back to discover_hba_fallback() when the script is missing, exits
    with an error, prints nothing, or prints something that is not a JSON
    list. Slots whose firmware is missing or "-" are skipped.
    """
    pci_script = next((p for p in PCI_SCRIPT_PATHS if os.path.isfile(p)), None)
    if not pci_script:
        discover_hba_fallback()
        return

    output = run(["python3", pci_script], timeout=60)
    try:
        pci_data = json.loads(output) if output else None
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pci_data = None

    # Anything but a list of slots (no output, bad JSON, an error object) means
    # the script is unusable here. Without this check a dict would be iterated
    # key-by-key and crash on str.get().
    if not isinstance(pci_data, list):
        discover_hba_fallback()
        return

    for slot in pci_data:
        if not isinstance(slot, dict):
            continue
        card_type = slot.get("cardType", "-")
        # JSON nulls arrive as None. Normalise them so .lower() and add() are safe.
        model = str(slot.get("cardModel") or "-")
        fw = str(slot.get("firmwareVersion") or "-")
        bus = slot.get("busAddress") or ""
        if fw == "-":
            continue

        model_lc = model.lower()
        if card_type == "HBA":
            # Vendor is inferred from the model name only.
            vendor = "Intel" if "intel" in model_lc else "Broadcom"
            add("hba", model, vendor, fw, bus)
        elif card_type == "Network Card":
            if "mellanox" in model_lc or "connectx" in model_lc:
                vendor = "NVIDIA/Mellanox"
            elif any(k in model_lc for k in ("intel", "x710", "e810")):
                vendor = "Intel"
            elif "broadcom" in model_lc or "bcm" in model_lc:
                vendor = "Broadcom"
            else:
                vendor = "Unknown"
            add("nic", model, vendor, fw, bus)


def _storcli_pci_to_bdf(pci_raw):
    """Convert a storcli "DD:BB:DD:FF" PCI address to "DDDD:BB:DD.F".

    Returns "" when the value is empty or is not four hex fields. A single
    malformed address therefore cannot abort parsing of other controllers.
    """
    parts = str(pci_raw or "").split(":")
    if len(parts) != 4:
        return ""
    domain, bus, dev, func = parts
    try:
        return "{:04x}:{}:{}.{}".format(int(domain, 16), bus, dev, int(func, 16))
    except ValueError:
        return ""


def _dict_field(mapping, key):
    """Return mapping[key] if both are dicts, else {} (tolerates oddly shaped storcli JSON)."""
    value = mapping.get(key) if isinstance(mapping, dict) else None
    return value if isinstance(value, dict) else {}


def _hba_from_storcli_json(storcli):
    """Add HBAs from `<storcli> /call show all J`; return True if any had model and firmware."""
    json_output = run_any([storcli, "/call", "show", "all", "J"], timeout=60)
    if not json_output:
        return False
    try:
        data = json.loads(json_output)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return False
    controllers = data.get("Controllers") if isinstance(data, dict) else None
    if not isinstance(controllers, list):
        return False

    found = False
    for idx, ctrl in enumerate(controllers):
        resp = _dict_field(ctrl, "Response Data")
        if not resp:
            continue
        basics = _dict_field(resp, "Basics")
        version = _dict_field(resp, "Version")
        model = basics.get("Model", "")
        fw = version.get("Firmware Package Build", "") or version.get("Firmware Version", "")
        if model and fw:
            # Prefer the PCI address; fall back to storcli's controller index.
            bus_info = _storcli_pci_to_bdf(basics.get("PCI Address", ""))
            add("hba", str(model), "Broadcom", str(fw), bus_info or "/c{}".format(idx))
            found = True
    return found


def _hba_from_storcli_text(storcli):
    """Add HBAs parsed from plain `<storcli> /call show`; return True if any had model and firmware."""
    output = run_any([storcli, "/call", "show"], timeout=60)
    if not output or "No Controller found" in output:
        return False

    # With a capture group, re.split returns [preamble, id0, body0, id1, body1, ...].
    sections = re.split(r'^Controller\s*=\s*(\d+)', output, flags=re.MULTILINE)
    found = False
    for ctrl_id, ctrl_data in zip(sections[1::2], sections[2::2]):
        model = ""
        fw = ""
        for line in ctrl_data.splitlines():
            if "=" not in line:
                continue
            value = line.split("=", 1)[-1].strip()
            if "Product Name" in line:
                model = value
            elif not fw and ("FW Version" in line or "FW Package Build" in line):
                # The first firmware line seen wins.
                fw = value
        if model and fw:
            add("hba", model, "Broadcom", fw, "/c{}".format(ctrl_id.strip()))
            found = True
    return found


def discover_hba_fallback():
    """Fallback HBA discovery if the pci script is not available or unusable.

    Tries each executable storcli binary in turn: the 45Drives tool
    directories first, then storcli64/storcli2 on PATH. JSON output is
    preferred because it carries the PCI address; plain-text output is used
    otherwise. Stops at the first binary that yields at least one controller.
    """
    storcli_paths = [
        "/opt/45drives/tools/storcli64",
        "/opt/45drives/tools/storcli2",
        "/opt/45drives/bin/storcli64",
    ]
    for name in ("storcli64", "storcli2"):
        p = shutil.which(name)
        if p and p not in storcli_paths:
            storcli_paths.append(p)

    for storcli in storcli_paths:
        if not (os.path.isfile(storcli) and os.access(storcli, os.X_OK)):
            continue
        if _hba_from_storcli_json(storcli) or _hba_from_storcli_text(storcli):
            break


###############################################################################
# NIC firmware via ethtool (fallback if pci script doesn't report NICs)
###############################################################################
def _strip_mlx_rev_suffix(part_number):
    """Drop trailing silicon-revision tags ("_Ax", "_Ax_Bx", ...) from a Mellanox part number.

    The previous str.rstrip("_Ax") removed any trailing run of the characters
    '_', 'A' and 'x', so "MCX512A-ACA_Ax" became "MCX512A-AC". The anchored
    regex removes only whole "_<Letter>x" tags.
    """
    return re.sub(r"(?:_[A-Z]x)+$", "", part_number)


def _discover_mellanox_nics():
    """Add Mellanox/NVIDIA NICs reported by `mlxup --query` (or mlxfwmanager).

    Records a "tool not found" placeholder entry when neither tool exists.
    """
    bundled = "/root/nic-firmware-tools/mlxup"
    if os.path.isfile(bundled) and os.access(bundled, os.X_OK):
        mlxup = bundled
    else:
        mlxup = shutil.which("mlxup") or shutil.which("mlxfwmanager")
    if not mlxup:
        add("nic", "Mellanox NIC (tool not found)", "NVIDIA/Mellanox", "unknown", "mlxup missing")
        return

    output = run_any([mlxup, "--query"], timeout=60)
    # Output is split into "Device #N" sections; text before the first is skipped.
    for block in re.split(r"Device #\d+", output)[1:]:
        model = ""
        fw = ""
        pci = ""
        for line in block.splitlines():
            if "Part Number" in line:
                model = _strip_mlx_rev_suffix(line.split(":")[-1].strip())
            elif "Description" in line and not model:
                model = line.split(":")[-1].strip()
            elif re.search(r"Current FW|FW Version", line):
                fw = line.split(":")[-1].strip()
            elif re.search(r"PCI Device Name|Device Name", line):
                pci = line.split(":", 1)[-1].strip()
            elif not fw:
                # Tabular layout under "Versions:  Current  Available", e.g.
                # "     FW     16.35.2000     N/A"; the first value is the
                # current one. Requiring a leading digit skips label-only rows.
                # NOTE(review): verify against the installed mlxup version.
                row = re.match(r"\s*FW\s+(\d\S*)", line)
                if row:
                    fw = row.group(1)
        if model and fw:
            add("nic", model, "NVIDIA/Mellanox", fw, pci)


def _discover_broadcom_nics():
    """Add Broadcom NICs listed by `niccli --list`, one per PCI device (bus:dev).

    Records a "tool not found" placeholder entry when niccli is not on PATH.
    """
    if not shutil.which("niccli"):
        add("nic", "Broadcom NIC (tool not found)", "Broadcom", "unknown", "niccli missing")
        return

    output = run_any(["niccli", "--list"], timeout=30)
    if not output or "BCM" not in output.upper():
        return
    seen_slots = set()
    for line in output.splitlines():
        if "BCM" not in line.upper():
            continue
        parts = line.split()
        if len(parts) < 5:
            continue
        model_match = re.search(r"BCM\d+", line)
        if not model_match:
            continue
        # Column positions are taken as-is from niccli's list layout.
        fw, pci_addr = parts[3], parts[4]
        # Ports of one card differ only in the PCI function; report each card once.
        slot_key = pci_addr.rsplit(".", 1)[0]
        if slot_key in seen_slots:
            continue
        seen_slots.add(slot_key)
        add("nic", model_match.group(0), "Broadcom", fw, pci_addr)


def _discover_intel_nics():
    """Add Intel NICs (PCI vendor 0x8086) with firmware from `ethtool -i`, one per bus:dev."""
    net_dir = "/sys/class/net"
    if not os.path.isdir(net_dir):
        return
    seen_slots = set()
    for iface in sorted(os.listdir(net_dir)):
        device_link = os.path.join(net_dir, iface, "device")
        if not os.path.islink(device_link):
            continue
        if os.path.isdir("/sys/devices/virtual/net/{}".format(iface)):
            continue
        try:
            with open(os.path.join(device_link, "vendor")) as f:
                vendor_id = f.read().strip()
        except (OSError, IOError):
            continue
        if vendor_id != "0x8086":
            continue
        pci_addr = os.path.basename(os.path.realpath(device_link))
        # Ports of a multi-port card typically differ only in the PCI function.
        slot_key = pci_addr.rsplit(".", 1)[0]
        if slot_key in seen_slots:
            continue
        seen_slots.add(slot_key)

        fw = ""
        for line in run(["ethtool", "-i", iface]).splitlines():
            if line.startswith("firmware-version:"):
                fw = line.split(":", 1)[-1].strip()
                break
        if not fw or fw == "N/A":
            continue

        model = "Intel NIC"
        model_out = run(["lspci", "-s", pci_addr])
        if model_out:
            m = re.sub(r".*Ethernet controller:\s*Intel Corporation\s*", "", model_out)
            m = re.sub(r"\s*\(rev.*", "", m)
            if m:
                model = m
        add("nic", model, "Intel", fw, pci_addr)


def discover_nic_fallback():
    """Query vendor NIC tools directly, but only if no NIC has been recorded yet.

    lspci output decides which vendors to probe. Mellanox/NVIDIA NICs are
    queried via mlxup/mlxfwmanager, Broadcom via niccli, and Intel via
    ethtool (only when ethtool is on PATH).
    """
    if any(d["type"] == "nic" for d in devices):
        return  # Already found NICs via pci script

    lspci_out = run(["lspci"])
    if re.search(r"Mellanox|ConnectX|NVIDIA.*Ethernet", lspci_out, re.IGNORECASE):
        _discover_mellanox_nics()
    if re.search(r"Broadcom.*Ethernet", lspci_out, re.IGNORECASE):
        _discover_broadcom_nics()
    if re.search(r"Intel.*Ethernet", lspci_out, re.IGNORECASE) and shutil.which("ethtool"):
        _discover_intel_nics()


###############################################################################
# BIOS
###############################################################################
def discover_bios():
    """Record the system BIOS vendor and version as reported by dmidecode.

    Skipped entirely when dmidecode is not on PATH. An entry is added only
    when a version string comes back; a missing vendor becomes "Unknown".
    """
    if shutil.which("dmidecode") is None:
        return
    bios_vendor, bios_version = (
        run(["dmidecode", "-s", keyword]) for keyword in ("bios-vendor", "bios-version")
    )
    if not bios_version:
        return
    add("bios", "System BIOS", bios_vendor or "Unknown", bios_version, "dmidecode")


###############################################################################
# BMC
###############################################################################
def discover_bmc():
    """Record the BMC firmware revision and manufacturer from `ipmitool mc info`.

    Skipped when ipmitool is not on PATH or returns nothing. For each field
    the value after the last ':' on the matching line is used, and later
    lines overwrite earlier ones.
    """
    if shutil.which("ipmitool") is None:
        return
    info = run(["ipmitool", "mc", "info"])
    if not info:
        return
    firmware, manufacturer = "", "Unknown"
    for row in info.splitlines():
        value = row.split(":")[-1].strip()
        if "Firmware Revision" in row:
            firmware = value
        elif "Manufacturer Name" in row:
            manufacturer = value
    if firmware:
        add("bmc", "BMC/IPMI", manufacturer, firmware, "ipmitool")


###############################################################################
# SAS Expander
###############################################################################
def discover_expander():
    """Record SAS expanders found by SCSI INQUIRY on the /dev/sg* nodes.

    Each generic SCSI node is queried with sg_inq. Only nodes whose inquiry
    text mentions "enclosure" are considered. Each model/firmware pair is
    recorded once, at the lowest-numbered sg node reporting it (presumably
    the same expander reachable over several paths; NOTE(review): confirm).
    Does nothing when sg_inq is not on PATH.
    """
    sg_inq = shutil.which("sg_inq")
    if not sg_inq:
        return

    dev_dir = "/dev"
    if not os.path.isdir(dev_dir):
        return

    sg_nodes = [e for e in os.listdir(dev_dir) if re.match(r"sg\d+$", e)]
    # Sort numerically (sg2 before sg10). A plain string sort made the
    # reported device_path and output order depend on digit count.
    sg_nodes.sort(key=lambda name: int(name[2:]))

    seen_models = set()
    for entry in sg_nodes:
        dev = "/dev/{}".format(entry)

        output = run([sg_inq, dev])
        if not output or "enclosure" not in output.lower():
            continue

        fw = ""
        model = "SAS Expander"
        vendor = "Unknown"
        for line in output.splitlines():
            if "Product rev" in line:
                # Use the text after the colon, so a blank revision yields ""
                # rather than the label's last word ("level:").
                if ":" in line:
                    fw = line.split(":", 1)[-1].strip()
                else:
                    tokens = line.split()
                    fw = tokens[-1] if tokens else ""
            elif "Product identification" in line:
                if ":" in line:
                    # Keep the default when the field is blank.
                    model = line.split(":")[-1].strip() or model
            elif "Vendor identification" in line:
                if ":" in line:
                    vendor = line.split(":")[-1].strip() or vendor

        key = "{}:{}".format(model, fw)
        if key in seen_models:
            continue
        seen_models.add(key)

        if fw:
            add("expander", model, vendor, fw, dev)


###############################################################################
# Main
###############################################################################
def main():
    """Run every discovery step, then print the collected devices as JSON.

    Step order matters: discover_nic_fallback() only acts when
    discover_pci_devices() recorded no NICs. Pass --pretty on the command
    line for 2-space indented output; otherwise the JSON is compact.
    """
    steps = (
        discover_bios,
        discover_bmc,
        discover_pci_devices,
        discover_nic_fallback,
        discover_expander,
    )
    for step in steps:
        step()

    indent = 2 if "--pretty" in sys.argv else None
    print(json.dumps(devices, indent=indent))


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 after printing.
    sys.exit(main())
