#!/usr/libexec/platform-python
"""
45d-pre-reboot-storage-check

System-wide storage safety check before activation reboot.
Scans ALL detected HBAs/controllers, evaluates storage activity on each,
and provides a per-controller risk assessment.

Usage:
    pre-reboot-storage-check
    pre-reboot-storage-check --samples 3 --interval 1

Output: JSON with overall risk level and per-controller breakdown.

Exit codes:
    0 - check completed (see JSON for risk level)
    1 - error during check
"""

import argparse
import glob
import json
import os
import re
import subprocess
import sys
import time

###############################################################################
# Configuration
###############################################################################

DEFAULT_SAMPLES = 3  # default number of /sys/block/<dev>/inflight polls per run
DEFAULT_INTERVAL = 1  # seconds between samples

# Risk levels (ordered)
RISK_SAFE = "safe"
RISK_INFO = "info"
RISK_WARNING = "warning"
RISK_HIGH = "high"
RISK_CRITICAL = "critical"

# Lowest to highest severity; risk_max() compares levels by their index in this list.
RISK_ORDER = [RISK_SAFE, RISK_INFO, RISK_WARNING, RISK_HIGH, RISK_CRITICAL]


def risk_max(a, b):
    """Return whichever of two risk levels ranks higher in RISK_ORDER.

    On a tie the first argument is returned. A level not present in
    RISK_ORDER raises ValueError.
    """
    return max(a, b, key=RISK_ORDER.index)


###############################################################################
# Discover All Storage Controllers (HBAs)
###############################################################################

def discover_storage_controllers():
    """Find all PCI storage controllers (HBAs) that have SCSI hosts.

    Walks /sys/class/scsi_host/, resolves each ``hostN`` entry to its sysfs
    device path, and groups the host numbers by the PCI address of the
    device that owns them. Hosts whose path contains no PCI address are
    skipped. Descriptions come from ``lspci -D -mm`` when it is available.

    Returns:
        list of dicts sorted by PCI address:
        [{"pci_address": str, "scsi_hosts": [int, ...], "description": str}].
        Empty when /sys/class/scsi_host does not exist.
    """
    scsi_host_dir = "/sys/class/scsi_host"
    if not os.path.isdir(scsi_host_dir):
        return []

    # sysfs prints the PCI domain as at least four lowercase hex digits
    # ("0000", "000a", or "10000" behind Intel VMD); a decimal-only \d{4}
    # would truncate such domains into the wrong address.
    pci_addr_re = re.compile(r'[0-9a-f]{4,}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-7]')

    # Map PCI address -> list of host numbers
    pci_to_hosts = {}
    for host_entry in sorted(os.listdir(scsi_host_dir)):
        if not host_entry.startswith("host"):
            continue
        try:
            host_num = int(host_entry[4:])
        except ValueError:
            continue

        real_path = os.path.realpath(os.path.join(scsi_host_dir, host_entry))
        # Use the LAST match: earlier ones are upstream bridges / root ports.
        pci_matches = pci_addr_re.findall(real_path)
        if pci_matches:
            pci_to_hosts.setdefault(pci_matches[-1], []).append(host_num)

    # Get descriptions via lspci. -D always prints the domain, so slots can
    # be looked up exactly; a suffix match would let "0000:41:00.0" also
    # claim "10000:41:00.0".
    pci_descriptions = {}
    try:
        result = subprocess.run(
            ["lspci", "-D", "-mm"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=10
        )
        if result.returncode == 0:
            for line in result.stdout.splitlines():
                # Format: 0000:41:00.0 "Class" "Vendor" "Device" ...
                parts = line.split('"')
                if len(parts) < 6:
                    continue
                slot = line.split()[0]
                if slot in pci_to_hosts:
                    pci_descriptions[slot] = parts[3]
    except (OSError, subprocess.TimeoutExpired):
        # lspci missing/unrunnable or hung: descriptions stay empty.
        pass

    return [
        {
            "pci_address": pci_addr,
            "scsi_hosts": sorted(set(hosts)),
            "description": pci_descriptions.get(pci_addr, ""),
        }
        for pci_addr, hosts in sorted(pci_to_hosts.items())
    ]


###############################################################################
# Block Device Mapping
###############################################################################

def find_block_devices_for_hosts(scsi_hosts):
    """Return the sorted, de-duplicated block device names attached to the given SCSI hosts.

    Each SCSI device is listed under /sys/bus/scsi/devices as "H:C:T:L";
    one backed by a block device has a "block" subdirectory whose entries
    are the kernel device names.
    """
    found = set()
    for host in scsi_hosts:
        for scsi_path in glob.glob("/sys/bus/scsi/devices/{}:*".format(host)):
            block_path = os.path.join(scsi_path, "block")
            if not os.path.isdir(block_path):
                continue
            found.update(os.listdir(block_path))
    return sorted(found)


###############################################################################
# I/O Activity Detection
###############################################################################

def read_inflight(device):
    """Return (reads, writes) currently in flight according to /sys/block/<device>/inflight.

    Yields (0, 0) when the file is missing or unreadable, holds fewer than
    two fields, or the fields are not integers.
    """
    inflight_path = "/sys/block/{}/inflight".format(device)
    try:
        with open(inflight_path) as handle:
            fields = handle.read().split()
    except (IOError, ValueError):
        return 0, 0
    if len(fields) < 2:
        return 0, 0
    try:
        return int(fields[0]), int(fields[1])
    except ValueError:
        return 0, 0


def read_diskstats(device):
    """Return cumulative completed read/write counts for *device* from /proc/diskstats.

    Rows with at least 14 whitespace-separated columns are considered; column 2
    is the device name, column 3 reads completed, column 7 writes completed.
    Returns None when the device has no row or the file cannot be read/parsed.
    """
    stats_path = "/proc/diskstats"
    try:
        with open(stats_path) as stats:
            for row in stats:
                cols = row.split()
                if len(cols) >= 14 and cols[2] == device:
                    reads_done, writes_done = int(cols[3]), int(cols[7])
                    return {"reads_completed": reads_done, "writes_completed": writes_done}
    except (IOError, ValueError):
        return None
    return None


def sample_io_activity(devices, samples=DEFAULT_SAMPLES, interval=DEFAULT_INTERVAL):
    """Measure I/O on *devices* over one sampling window.

    /proc/diskstats is read once before and once after the window. In between,
    /sys/block/<dev>/inflight is polled ``samples`` times, ``interval`` seconds
    apart (no sleep before the first or after the last poll), keeping the
    per-device maximum of in-flight reads and writes.

    Returns:
        {} when *devices* is empty; otherwise
        {device: {"max_reads_inflight", "max_writes_inflight",
                  "total_reads_during_sample", "total_writes_during_sample",
                  "active", "read_active", "write_active"}}.
        The completed-I/O fields stay 0/False for a device whose diskstats
        row is missing at either end of the window.
    """
    if not devices:
        return {}

    result = {
        name: {
            "max_reads_inflight": 0,
            "max_writes_inflight": 0,
            "total_reads_during_sample": 0,
            "total_writes_during_sample": 0,
            "active": False,
            "read_active": False,
            "write_active": False,
        }
        for name in devices
    }

    # Counters at the start of the window.
    baseline = {name: read_diskstats(name) for name in devices}

    for round_index in range(samples):
        if round_index > 0:
            time.sleep(interval)
        for name in devices:
            in_reads, in_writes = read_inflight(name)
            entry = result[name]
            entry["max_reads_inflight"] = max(entry["max_reads_inflight"], in_reads)
            entry["max_writes_inflight"] = max(entry["max_writes_inflight"], in_writes)

    # Close the window: completed-I/O deltas per device.
    for name in devices:
        after = read_diskstats(name)
        before = baseline[name]
        if not (before and after):
            continue
        read_delta = after["reads_completed"] - before["reads_completed"]
        write_delta = after["writes_completed"] - before["writes_completed"]
        entry = result[name]
        entry["total_reads_during_sample"] = read_delta
        entry["total_writes_during_sample"] = write_delta
        entry["read_active"] = read_delta > 0
        entry["write_active"] = write_delta > 0
        entry["active"] = entry["read_active"] or entry["write_active"]

    return result


def get_activity_status(io_activity):
    """Summarize I/O activity across all devices of one controller.

    A device counts as writing (reading) when writes (reads) completed during
    the sample window OR at least one write (read) was seen in flight. The
    in-flight counts are included so this status agrees with
    assess_controller_risk, which also treats in-flight I/O as activity.

    Args:
        io_activity: mapping device -> per-device dict as returned by
            sample_io_activity; may be empty.

    Returns:
        "no_devices" for an empty mapping; "heavy_activity" when more than
        100 writes or more than 1000 reads completed in total; otherwise
        "write_active", "read_active", or "idle".
    """
    if not io_activity:
        return "no_devices"

    stats = list(io_activity.values())
    any_writes = any(a.get("write_active") or a.get("max_writes_inflight", 0) > 0 for a in stats)
    any_reads = any(a.get("read_active") or a.get("max_reads_inflight", 0) > 0 for a in stats)
    total_writes = sum(a.get("total_writes_during_sample", 0) for a in stats)
    total_reads = sum(a.get("total_reads_during_sample", 0) for a in stats)

    if total_writes > 100 or total_reads > 1000:
        return "heavy_activity"
    if any_writes:
        return "write_active"
    if any_reads:
        return "read_active"
    return "idle"


###############################################################################
# Storage Layer Detection
###############################################################################

def resolve_all_descendants(devices):
    """Return *devices* plus every block device stacked on top of them.

    Runs ``lsblk -J`` and walks its JSON tree. A node is collected when it is
    one of *devices*, names one of them as its parent (PKNAME), or lies below
    a collected node. For each collected node both its NAME (e.g. "sda1",
    "vg-lv") and its KNAME (e.g. "dm-0") are added, so callers that resolve
    /dev symlinks with realpath() can also match "/dev/dm-N".

    If lsblk is missing, fails, times out, or prints unusable JSON, only the
    names in *devices* are returned.

    Returns:
        set of device names.
    """
    wanted = set(devices)
    related = set(wanted)
    try:
        result = subprocess.run(
            ["lsblk", "-J", "-o", "NAME,KNAME,PKNAME,TYPE,MOUNTPOINT"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=10
        )
    except (OSError, subprocess.TimeoutExpired):
        return related
    if result.returncode != 0:
        return related
    try:
        data = json.loads(result.stdout)
    except ValueError:
        return related
    if not isinstance(data, dict):
        return related

    def walk(node, parent_match):
        name = node.get("name", "")
        is_match = parent_match or name in wanted or node.get("pkname") in wanted
        if is_match:
            related.add(name)
            kname = node.get("kname")
            if kname:
                related.add(kname)
        for child in node.get("children") or []:
            walk(child, is_match)

    for dev_node in data.get("blockdevices", []):
        walk(dev_node, False)
    return related


def get_mounted_filesystems(devices):
    """Return the /proc/mounts entries whose source lives on *devices*.

    The related names (the devices plus partitions, device-mapper/LVM, md and
    other holders) come from resolve_all_descendants(). A mount source under
    /dev/ matches when its basename, or the basename of its fully resolved
    path (following /dev/disk/by-* and /dev/mapper symlinks), is EXACTLY one
    of those names. Exact matching avoids the substring false positive where
    "/dev/sda" would match "/dev/sdaa1" on systems with more than 26 disks.

    Returns:
        list of {"device", "mount_point", "filesystem"} dicts in /proc/mounts
        order; empty if /proc/mounts cannot be read.
    """
    mounts = []
    all_related = resolve_all_descendants(devices)
    try:
        with open("/proc/mounts") as f:
            mount_lines = f.readlines()
    except IOError:
        return mounts

    for line in mount_lines:
        parts = line.split()
        if len(parts) < 3:
            continue
        mount_dev = parts[0]
        # Pseudo filesystems (proc, tmpfs, ZFS datasets, ...) are not /dev nodes.
        if not mount_dev.startswith("/dev/"):
            continue
        candidates = {os.path.basename(mount_dev)}
        try:
            candidates.add(os.path.basename(os.path.realpath(mount_dev)))
        except OSError:
            pass
        if candidates & all_related:
            mounts.append({
                "device": mount_dev,
                "mount_point": parts[1],
                "filesystem": parts[2],
            })
    return mounts


def get_zfs_pools(devices):
    """Return the ZFS pools that have at least one vdev on *devices*.

    Runs ``zpool status -L`` so vdevs opened through /dev/disk/by-id, by-path
    or by-vdev aliases are shown by their real kernel names. If that command
    fails (e.g. an older zpool that rejects -L), plain ``zpool status`` is
    tried. In the config table, a line matches a device when its first token
    is exactly that device or one of its partitions ("sda", "sda1",
    "nvme0n1p1") -- never a longer name such as "sdaa".

    Returns:
        list of {"name", "state", "scan", "devices_on_controller"} dicts in
        ``zpool status`` order. Empty when *devices* is empty, zpool is
        missing or fails, or no pool uses these devices.
    """
    pools = []
    if not devices:
        return pools

    stdout = None
    for cmd in (["zpool", "status", "-L"], ["zpool", "status"]):
        try:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=10
            )
        except (OSError, subprocess.TimeoutExpired):
            return pools
        if result.returncode == 0:
            stdout = result.stdout
            break
    if stdout is None:
        return pools

    matchers = [(dev, re.compile(re.escape(dev) + r'(?:p?\d+)?')) for dev in devices]

    current = None
    for line in stdout.splitlines():
        stripped = line.strip()

        pool_match = re.match(r'pool:\s+(.+)', stripped)
        if pool_match:
            current = {
                "name": pool_match.group(1),
                "state": None,
                "scan": None,
                "devices_on_controller": [],
            }
            pools.append(current)
            continue
        if current is None:
            continue

        state_match = re.match(r'state:\s+(.+)', stripped)
        if state_match:
            current["state"] = state_match.group(1)
            continue

        scan_match = re.match(r'scan:\s+(.+)', stripped)
        if scan_match:
            current["scan"] = scan_match.group(1).strip()
            continue

        tokens = stripped.split()
        if not tokens:
            continue
        # Full paths (if any) are reduced to their last component.
        vdev = tokens[0].rsplit("/", 1)[-1]
        if vdev == current["name"]:
            # The pool's own row at the top of the config table.
            continue
        for dev, pattern in matchers:
            if pattern.fullmatch(vdev) and dev not in current["devices_on_controller"]:
                current["devices_on_controller"].append(dev)

    return [p for p in pools if p["devices_on_controller"]]


def get_mdraid_arrays(devices, mdstat_path="/proc/mdstat"):
    """Return the md RAID arrays that have at least one member on *devices*.

    The mdstat file is parsed array by array. Each array starts with a header
    line such as "md0 : active raid1 sdb1[1] sda1[0]" and runs until the next
    blank line. A member matches a device when it is exactly that device or
    one of its partitions ("sda", "sda1" -- but not "sdaa1").

    The "rebuilding" flag is derived from THAT array's own lines only, so a
    resync on one array no longer marks every array as rebuilding. Each array
    is reported once, listing all of its members found on *devices*.

    Args:
        devices: kernel block device names, e.g. ["sda", "sdb"].
        mdstat_path: file to parse; defaults to /proc/mdstat.

    Returns:
        list of {"name": str, "devices_on_controller": [str, ...]} dicts in
        mdstat order, with "rebuilding": True added when the array's lines
        mention "recovery" or "resync". Empty if the file cannot be read.
    """
    try:
        with open(mdstat_path) as f:
            content = f.read()
    except IOError:
        return []

    # Split into per-array sections: [name, header_tail, [status lines]].
    sections = []
    current = None
    for line in content.splitlines():
        header = re.match(r'^(md\w+)\s*:\s*(.*)$', line)
        if header:
            current = [header.group(1), header.group(2), []]
            sections.append(current)
        elif not line.strip():
            current = None
        elif current is not None:
            current[2].append(line)

    matchers = [(dev, re.compile(re.escape(dev) + r'(?:p?\d+)?')) for dev in devices]

    arrays = []
    for name, tail, body in sections:
        # Members look like "sda1[0]", optionally followed by "(F)"/"(S)".
        members = re.findall(r'(\S+?)\[\d+\]', tail)
        matched = [dev for dev, pattern in matchers
                   if any(pattern.fullmatch(member) for member in members)]
        if not matched:
            continue
        array_info = {"name": name, "devices_on_controller": matched}
        section_text = "\n".join([tail] + body)
        if "recovery" in section_text or "resync" in section_text:
            array_info["rebuilding"] = True
        arrays.append(array_info)
    return arrays


def check_is_boot_disk(devices):
    """Return which of *devices* back "/", "/boot" or "/boot/efi".

    The sources of those mount points in /proc/mounts are reduced to their
    basename and the basename of their fully resolved path (following
    /dev/disk/by-* and /dev/mapper symlinks). These names are compared
    EXACTLY against descendant sets from resolve_all_descendants(), so
    partition, LVM, dm and md layers are followed.

    Only disks that actually hold a boot mount are returned, in *devices*
    order. The previous version returned every device on the controller once
    any one of them matched. It also used substring matching, where "sda"
    matched "sdaa1".

    resolve_all_descendants() runs once for the whole controller; only when
    that shows a boot mount here does it run once per disk.
    """
    boot_names = set()
    try:
        with open("/proc/mounts") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 2 or parts[1] not in ("/", "/boot", "/boot/efi"):
                    continue
                mount_dev = parts[0]
                if not mount_dev.startswith("/dev/"):
                    continue
                boot_names.add(os.path.basename(mount_dev))
                try:
                    boot_names.add(os.path.basename(os.path.realpath(mount_dev)))
                except OSError:
                    pass
    except IOError:
        return []

    if not boot_names or not devices:
        return []
    # Cheap whole-controller check before resolving each disk separately.
    if not resolve_all_descendants(devices) & boot_names:
        return []
    return [dev for dev in devices if resolve_all_descendants([dev]) & boot_names]


def check_lvm_usage(devices):
    """Return the LVM physical volumes that live on *devices*.

    ``pvs`` lists every PV with its volume group. A PV belongs here when its
    basename is in the descendant set from resolve_all_descendants().

    The "device" field names the disk holding the PV:
      * an exact disk-or-partition name match ("sda2" -> "sda"; "sdaa2" is
        never attributed to "sda");
      * otherwise the first disk whose own descendant set contains the PV
        (PVs on md, dm-crypt or multipath layers), resolved lazily per disk;
      * as a last resort the first device.

    Returns:
        list of {"pv": str, "vg": str, "device": str} dicts in ``pvs`` order.
        Empty when *devices* is empty or pvs is missing, fails, or times out.
    """
    lvm_info = []
    if not devices:
        return lvm_info

    all_related = resolve_all_descendants(devices)
    try:
        result = subprocess.run(
            ["pvs", "--noheadings", "-o", "pv_name,vg_name"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, timeout=10
        )
    except (OSError, subprocess.TimeoutExpired):
        return lvm_info
    if result.returncode != 0:
        return lvm_info

    disk_or_partition = [(dev, re.compile(re.escape(dev) + r'(?:p?\d+)?')) for dev in devices]
    descendants_by_dev = {}  # filled lazily: dev -> resolve_all_descendants([dev])

    for line in result.stdout.splitlines():
        parts = line.split()
        # PVs not assigned to a VG have no second column.
        if len(parts) < 2:
            continue
        pv, vg = parts[0], parts[1]
        pv_base = os.path.basename(pv)
        if pv_base not in all_related:
            continue

        matched_dev = next((dev for dev, pattern in disk_or_partition if pattern.fullmatch(pv_base)), None)
        if matched_dev is None:
            for dev in devices:
                if dev not in descendants_by_dev:
                    descendants_by_dev[dev] = resolve_all_descendants([dev])
                if pv_base in descendants_by_dev[dev]:
                    matched_dev = dev
                    break
        if matched_dev is None:
            matched_dev = devices[0]
        lvm_info.append({"pv": pv, "vg": vg, "device": matched_dev})
    return lvm_info


###############################################################################
# Per-Controller Risk Assessment
###############################################################################

def assess_controller_risk(devices, io_activity, mounts, zfs_pools, md_arrays, boot_disks, lvm_info):
    """Assess reboot risk for a single controller.

    Signals that raise the risk level:
      * write I/O completed during sampling or seen in flight -> RISK_HIGH,
        with a recommendation to wait for writes to finish;
      * otherwise, read I/O completed or seen in flight -> RISK_WARNING;
      * an mdraid array here marked as rebuilding -> RISK_WARNING, with a
        recommendation to wait for the rebuild.
    Boot disks and ZFS pools only add informational warnings; they do not
    change the risk level. ``mounts`` and ``lvm_info`` are accepted for a
    uniform call signature but are not used in the assessment.

    Returns:
        {"risk_level": str, "warnings": [str], "recommendations": [str]}.
        A controller with no devices is always RISK_SAFE with no messages.
    """
    # No drives attached: nothing on this controller can be at risk.
    if not devices:
        return {
            "risk_level": RISK_SAFE,
            "warnings": [],
            "recommendations": [],
        }

    risk_level = RISK_SAFE
    warnings = []
    recommendations = []

    # Treat active storage I/O (completed in the window or in flight) as a reboot risk signal.
    activity = list((io_activity or {}).values())
    any_writes = any(a.get("write_active") or a.get("max_writes_inflight", 0) > 0 for a in activity)
    any_reads = any(a.get("read_active") or a.get("max_reads_inflight", 0) > 0 for a in activity)
    if any_writes:
        risk_level = risk_max(risk_level, RISK_HIGH)
        warnings.append("Active write I/O detected")
        recommendations.append("Wait for write activity to finish before rebooting")
    elif any_reads:
        risk_level = risk_max(risk_level, RISK_WARNING)
        warnings.append("Active read I/O detected")

    # Boot disk (informational)
    if boot_disks:
        warnings.append("OS/boot disk ({}) is on this controller".format(", ".join(boot_disks)))

    # ZFS pools (informational — ZFS handles reboots gracefully)
    if zfs_pools:
        pool_names = [p["name"] for p in zfs_pools]
        warnings.append("ZFS pools: {}".format(", ".join(pool_names)))

    # mdraid — only an active rebuild raises the risk level
    for arr in md_arrays or []:
        if arr.get("rebuilding"):
            risk_level = risk_max(risk_level, RISK_WARNING)
            warnings.append("mdraid rebuild in progress on '{}'".format(arr["name"]))
            recommendations.append("Consider waiting for RAID rebuild to complete on '{}'".format(arr["name"]))

    return {
        "risk_level": risk_level,
        "warnings": warnings,
        "recommendations": recommendations,
    }


###############################################################################
# Main
###############################################################################

def main():
    """Scan every storage controller and print a JSON reboot-risk report to stdout.

    Flow:
      1. discover controllers and their block devices;
      2. sample I/O on ALL devices in one shared window;
      3. gather storage layers and assess risk per controller;
      4. build the summary and global recommendations, then print JSON.
    Invalid arguments are reported by argparse (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Pre-reboot system-wide storage safety check")
    parser.add_argument("--samples", type=int, default=DEFAULT_SAMPLES,
                        help="Number of I/O samples (default: %(default)s)")
    parser.add_argument("--interval", type=int, default=DEFAULT_INTERVAL,
                        help="Seconds between samples (default: %(default)s)")
    args = parser.parse_args()
    # Zero samples would give no inflight data and a zero-length diskstats
    # window; a negative interval would make time.sleep() raise.
    if args.samples < 1:
        parser.error("--samples must be at least 1")
    if args.interval < 0:
        parser.error("--interval must not be negative")

    # Step 1: Discover all storage controllers and their block devices
    controllers = discover_storage_controllers()
    devices_by_ctrl = [find_block_devices_for_hosts(ctrl["scsi_hosts"]) for ctrl in controllers]

    # Step 2: Sample I/O on every device in ONE window. Sampling controller by
    # controller would take controllers x samples x interval seconds and observe
    # each controller at a different moment.
    all_devices = sorted({dev for devs in devices_by_ctrl for dev in devs})
    all_activity = sample_io_activity(all_devices, args.samples, args.interval)

    # Step 3: Per-controller storage layers and risk
    controller_results = []
    overall_risk = RISK_SAFE
    active_writes = False
    active_reads = False

    for ctrl, devices in zip(controllers, devices_by_ctrl):
        io_activity = {dev: all_activity[dev] for dev in devices}

        mounts = get_mounted_filesystems(devices)
        zfs_pools = get_zfs_pools(devices)
        md_arrays = get_mdraid_arrays(devices)
        boot_disks = check_is_boot_disk(devices)
        lvm_info = check_lvm_usage(devices)

        risk = assess_controller_risk(devices, io_activity, mounts, zfs_pools, md_arrays, boot_disks, lvm_info)
        overall_risk = risk_max(overall_risk, risk["risk_level"])

        # Same notion of "active" as assess_controller_risk: completed in the
        # window or seen in flight. Reads and writes are tracked independently
        # (a read-heavy controller must not count as writing).
        for stats in io_activity.values():
            if stats.get("write_active") or stats.get("max_writes_inflight", 0) > 0:
                active_writes = True
            if stats.get("read_active") or stats.get("max_reads_inflight", 0) > 0:
                active_reads = True

        controller_results.append({
            "pci_address": ctrl["pci_address"],
            "description": ctrl["description"],
            "scsi_hosts": ctrl["scsi_hosts"],
            "block_devices": devices,
            "device_count": len(devices),
            "activity_status": get_activity_status(io_activity),
            "mounted_filesystems": mounts,
            "zfs_pools": zfs_pools,
            "mdraid_arrays": md_arrays,
            "boot_disks": boot_disks,
            "lvm_volumes": lvm_info,
            "risk_level": risk["risk_level"],
            "warnings": risk["warnings"],
            "recommendations": risk["recommendations"],
        })

    # Step 4: Build summary. Pools/arrays spanning several controllers are
    # reported per controller, so names are de-duplicated here.
    total_drives = sum(c["device_count"] for c in controller_results)
    total_zfs = []
    total_mounts = []
    resilver_rebuild = []
    md_rebuilding = []

    for c in controller_results:
        for p in c["zfs_pools"]:
            if p["name"] not in total_zfs:
                total_zfs.append(p["name"])
            scan = p.get("scan") or ""
            if ("scrub in progress" in scan or "resilver in progress" in scan) and p["name"] not in resilver_rebuild:
                resilver_rebuild.append(p["name"])
        for m in c["mounted_filesystems"]:
            if m["mount_point"] not in total_mounts:
                total_mounts.append(m["mount_point"])
        for arr in c["mdraid_arrays"]:
            if arr.get("rebuilding"):
                if arr["name"] not in resilver_rebuild:
                    resilver_rebuild.append(arr["name"])
                if arr["name"] not in md_rebuilding:
                    md_rebuilding.append(arr["name"])

    summary = {
        "total_drives": total_drives,
        "active_writes": active_writes,
        "active_reads": active_reads,
        "zfs_pools": total_zfs,
        "mounted_filesystems": total_mounts,
        "resilver_rebuild_in_progress": resilver_rebuild,
    }

    # Step 5: Global recommendations. The mdraid message is driven by an actual
    # rebuilding array, not by the overall risk level (which read I/O alone can
    # raise to "warning", and write I/O can raise above it).
    global_recommendations = []
    if md_rebuilding:
        global_recommendations.append("An mdraid rebuild is in progress — consider waiting for it to finish")

    # Step 6: Output
    output = {
        "overall_risk_level": overall_risk,
        "safe_to_reboot": overall_risk in (RISK_SAFE, RISK_INFO),
        "summary": summary,
        "controller_count": len(controller_results),
        "controllers": controller_results,
        "global_recommendations": global_recommendations,
        "samples_taken": args.samples,
        "sample_interval_sec": args.interval,
    }

    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    # Script entry point: main() parses arguments and prints the JSON report.
    main()
