#!/usr/libexec/platform-python
"""
45d-pre-flash-check

Lightweight pre-flash validation for a single HBA.
Confirms the HBA exists in sysfs and reports attached drives.

Usage:
    pre-flash-check --device-path 0000:41:00.0
    pre-flash-check --device-path 41:00.0

Output: JSON with HBA info and attached devices.

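Exit codes:
    0 - HBA found (at least one SCSI host is attached to it), safe to flash
    1 - HBA not found or validation failed
    2 - invalid command-line arguments (reported by argparse)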
"""

import argparse
import glob
import json
import os
import sys


def normalize_pci_address(addr):
    """Return *addr* in the form used for sysfs device names (``DDDD:BB:SS.F``).

    Linux names PCI devices under /sys/bus/pci/devices with lowercase hex
    digits. The address is therefore stripped of surrounding whitespace and
    lowercased, so that ``4A:00.0`` or a value with a trailing newline still
    matches. A missing domain (``BB:SS.F``) is filled in as ``0000:``.

    Args:
        addr: PCI address with or without the domain prefix,
            e.g. ``"41:00.0"`` or ``"0000:41:00.0"``.

    Returns:
        The normalized address string. Input that does not have exactly one
        colon is returned stripped and lowercased but otherwise unchanged.
    """
    addr = addr.strip().lower()
    # Exactly one colon means "bus:slot.func" with no domain component.
    if addr.count(":") == 1:
        return "0000:{}".format(addr)
    return addr


def find_scsi_hosts_for_pci(pci_address, sysfs_root="/sys/bus/pci/devices"):
    """Return the sorted SCSI host numbers registered under a PCI device.

    The device's sysfs subtree is walked, and every directory named
    ``host<N>`` (N a decimal number) contributes N. Duplicates are collapsed.

    Args:
        pci_address: Normalized PCI address, e.g. ``"0000:41:00.0"``.
        sysfs_root: Directory containing the PCI device entries. Defaults
            to the real sysfs location; override only for testing.

    Returns:
        A sorted list of ints. It is empty if the device directory does not
        exist or contains no ``host<N>`` directories.
    """
    pci_path = "{}/{}".format(sysfs_root, pci_address)
    if not os.path.isdir(pci_path):
        return []
    hosts = set()
    # os.walk does not descend into symlinked directories by default, so the
    # walk stays inside this device's own subtree.
    for _dirpath, dirnames, _filenames in os.walk(pci_path):
        for name in dirnames:
            suffix = name[len("host"):]
            # isdecimal() rather than a bare int(): int() also accepts "1_0",
            # "+1" and " 1", which would turn unrelated names into host numbers.
            if name.startswith("host") and suffix.isdecimal():
                hosts.add(int(suffix))
    return sorted(hosts)


def find_block_devices_for_hosts(scsi_hosts, sysfs_root="/sys/bus/scsi/devices"):
    """Return the sorted block-device names attached to the given SCSI hosts.

    For each host number H, every SCSI device entry matching ``H:*`` is
    examined. The names listed in its ``block/`` subdirectory are collected.

    Args:
        scsi_hosts: Iterable of SCSI host numbers.
        sysfs_root: Directory containing the SCSI device entries. Defaults
            to the real sysfs location; override only for testing.

    Returns:
        A sorted list of unique block-device names (the directory entries
        found under each ``block/``).
    """
    devices = set()
    base = glob.escape(sysfs_root)
    for host_num in scsi_hosts:
        # The ":" right after the host number keeps host 1 from also matching
        # entries that start with "10:".
        pattern = "{}/{}:*".format(base, host_num)
        for scsi_dev in glob.glob(pattern):
            try:
                devices.update(os.listdir(os.path.join(scsi_dev, "block")))
            except OSError:
                # Either this SCSI device has no block/ directory, or it
                # disappeared mid-scan (hot removal). Skip it rather than
                # aborting the whole report.
                continue
    return sorted(devices)


def main():
    """Command-line entry point: locate the HBA, print a JSON report, exit.

    Parses ``--device-path`` and normalizes it. It then looks for SCSI hosts
    under that PCI device and collects their block devices. A JSON object
    is printed to stdout.

    Exits with status 0 when at least one SCSI host was found, 1 otherwise.
    argparse itself exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Pre-flash check for HBA firmware update")
    parser.add_argument("--device-path", required=True, help="PCI bus address (e.g., 0000:41:00.0)")
    args = parser.parse_args()

    pci_address = normalize_pci_address(args.device_path)
    scsi_hosts = find_scsi_hosts_for_pci(pci_address)
    found = bool(scsi_hosts)
    devices = find_block_devices_for_hosts(scsi_hosts) if found else []

    if found:
        message = "Controller located. A reboot will be required after flashing to activate new firmware."
    elif os.path.isdir("/sys/bus/pci/devices/{}".format(pci_address)):
        # The PCI device exists, but no SCSI host was registered under it.
        # Telling the operator to "verify the PCI address" here would be
        # misleading, since the address is valid.
        message = (
            "PCI device {} exists but has no SCSI host attached. "
            "Check that a driver is bound to the HBA.".format(pci_address)
        )
    else:
        message = "Controller not found at {}. Verify the PCI address.".format(pci_address)

    output = {
        "pci_address": pci_address,
        "found": found,
        "safe_to_flash": found,
        "scsi_hosts": scsi_hosts,
        "block_devices": devices,
        "device_count": len(devices),
        "message": message,
    }

    print(json.dumps(output, indent=2))
    sys.exit(0 if found else 1)


if __name__ == "__main__":
    # Run the check only when executed as a script, not when imported.
    main()
