#!/usr/bin/python3
import json
import os
import re
import shutil
import subprocess
import sys

# Module-wide result document. The functions below add keys to it and main()
# serializes it to stdout as JSON. "zfs_installed" stays False unless main()
# finds the zfs command.
json_zfs = {"zfs_installed": False}


def get_zfs_list():
    """Return the pool-level datasets reported by ``zfs list -H``.

    Each output row is split on tabs. Rows whose name contains ``/`` or ``@``
    are skipped, so only entries named like a pool itself remain. Blank rows
    and rows with fewer than five columns are ignored instead of raising
    ``IndexError``.

    Returns:
        A list of dicts with the string values ``name``, ``used``, ``avail``,
        ``refer`` and ``mountpoint`` (taken positionally from the first five
        columns), or ``False`` when the ``zfs`` command could not be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        completed = subprocess.run(
            ["zfs", "list", "-H"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except OSError:
        return False

    fields = ("name", "used", "avail", "refer", "mountpoint")
    datasets = []
    for line in completed.stdout.split("\n"):
        columns = line.split("\t")
        if len(columns) < len(fields):
            # Blank or truncated row: nothing usable to report.
            continue
        if "/" in columns[0] or "@" in columns[0]:
            continue
        datasets.append(dict(zip(fields, columns)))
    return datasets


def get_zpool_list():
    """Return one dict per pool from ``zpool list -H``, enriched with zfs data.

    The first eleven tab-separated columns of every row are stored positionally
    under ``name``, ``raw_size``, ``raw_alloc``, ``raw_free``, ``ckpoint``,
    ``expandsz``, ``frag``, ``cap``, ``dedup``, ``health`` and ``altroot``.
    Blank rows and rows with fewer columns are skipped instead of raising
    ``IndexError``.

    Each pool additionally receives ``used``, ``avail``, ``refer`` and
    ``mountpoint`` from the ``get_zfs_list()`` entry with the same name, or
    ``"-"`` for all four when no such entry exists (including when
    ``get_zfs_list()`` itself reports failure).

    Returns:
        The list of pool dicts, or ``False`` when the ``zpool`` command could
        not be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        completed = subprocess.run(
            ["zpool", "list", "-H"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except OSError:
        return False

    fields = (
        "name", "raw_size", "raw_alloc", "raw_free", "ckpoint", "expandsz",
        "frag", "cap", "dedup", "health", "altroot",
    )
    zpools = []
    for line in completed.stdout.split("\n"):
        columns = line.split("\t")
        if len(columns) < len(fields):
            continue
        zpools.append(dict(zip(fields, columns)))

    # get_zfs_list() returns False when `zfs` cannot be run; treat that as
    # "no dataset information" rather than iterating over a bool.
    zfs_by_name = {entry["name"]: entry for entry in (get_zfs_list() or [])}
    for pool in zpools:
        entry = zfs_by_name.get(pool["name"])
        for field in ("used", "avail", "refer", "mountpoint"):
            pool[field] = entry[field] if entry is not None else "-"
    return zpools


def zpool_status_flags(pool_name):
    """Run ``zpool status -P`` for *pool_name* and extract device tree and state.

    ``-P`` is the zpool option for printing full device paths; the parsers in
    this file rely on it to recognise ``/dev/...`` entries.

    A "section" is a line that starts with exactly one tab followed by a
    non-blank character, plus every directly following line that starts with a
    tab and at least one space. All sections found are joined with ``"\\n"``.

    Args:
        pool_name: Name of the pool to query.

    Returns:
        ``{pool_name: <joined section text>, "state": <str>}`` where ``state``
        is the first token after the first ``state:`` label in the output, or
        ``"UNKNOWN"`` when no such label is present. A failing ``zpool``
        (non-zero exit) simply yields empty text and ``"UNKNOWN"``.

    Raises:
        SystemExit: with code 1 when the ``zpool`` command cannot be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        # The option is placed before the operand so argument parsers that do
        # not permute options still see it.
        completed = subprocess.run(
            ["zpool", "status", "-P", pool_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except (OSError, ValueError):
        print(f"failed to run 'zpool status {pool_name}'")
        sys.exit(1)

    report = completed.stdout

    # Header line of a section, then one or more indented child lines.
    sections = [
        match.group(0)
        for match in re.finditer(
            r"^\t(\S+).*$\n(?:^\t +.*$\n)+", report, flags=re.MULTILINE
        )
    ]

    state_match = re.search(r"^.*state:\s+(\S+)", report, flags=re.MULTILINE)
    state_value = state_match.group(1) if state_match else "UNKNOWN"

    return {
        pool_name: "\n".join(sections),
        "state": state_value
    }


def zpool_iostat_flags(pool_name):
    """Run ``zpool iostat -vP`` for *pool_name* and extract its device tree.

    ``-P`` is the zpool option for printing full device paths; the parsers in
    this file rely on it to recognise ``/dev/...`` entries.

    A "section" is a line that starts with a non-blank character, plus every
    directly following line that starts with at least one space. All sections
    found are joined with ``"\\n"``.

    Args:
        pool_name: Name of the pool to query.

    Returns:
        ``{pool_name: <joined section text>}``; the text is empty when the
        command produced no matching output.

    Raises:
        SystemExit: with code 1 when the ``zpool`` command cannot be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        completed = subprocess.run(
            ["zpool", "iostat", "-vP", pool_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except (OSError, ValueError):
        print(f"failed to run 'zpool iostat -v {pool_name}'")
        sys.exit(1)

    # Header line of a section, then one or more indented child lines.
    sections = [
        match.group(0)
        for match in re.finditer(
            r"^(\S+).*$\n(?:^ +.*$\n)+", completed.stdout, flags=re.MULTILINE
        )
    ]

    return {pool_name: "\n".join(sections)}


def zpool_status(pool_name):
    """Run ``zpool status`` for *pool_name* and extract device tree and state.

    A "section" is a line that starts with exactly one tab followed by a
    non-blank character, plus every directly following line that starts with a
    tab and at least one space. All sections found are joined with ``"\\n"``.

    Args:
        pool_name: Name of the pool to query.

    Returns:
        ``{pool_name: <joined section text>, "state": <str>}`` where ``state``
        is the first token after the first ``state:`` label in the output, or
        ``"UNKNOWN"`` when no such label is present. A failing ``zpool``
        (non-zero exit) simply yields empty text and ``"UNKNOWN"``.

    Raises:
        SystemExit: with code 1 when the ``zpool`` command cannot be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        completed = subprocess.run(
            ["zpool", "status", pool_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except (OSError, ValueError):
        print(f"failed to run 'zpool status {pool_name}'")
        sys.exit(1)

    report = completed.stdout

    # Header line of a section, then one or more indented child lines.
    sections = [
        match.group(0)
        for match in re.finditer(
            r"^\t(\S+).*$\n(?:^\t +.*$\n)+", report, flags=re.MULTILINE
        )
    ]

    state_match = re.search(r"^.*state:\s+(\S+)", report, flags=re.MULTILINE)
    state_value = state_match.group(1) if state_match else "UNKNOWN"

    return {
        pool_name: "\n".join(sections),
        "state": state_value
    }


def zpool_iostat(pool_name):
    """Run ``zpool iostat -v`` for *pool_name* and extract its device tree.

    A "section" is a line that starts with a non-blank character, plus every
    directly following line that starts with at least one space. All sections
    found are joined with ``"\\n"``.

    Args:
        pool_name: Name of the pool to query.

    Returns:
        ``{pool_name: <joined section text>}``; the text is empty when the
        command produced no matching output.

    Raises:
        SystemExit: with code 1 when the ``zpool`` command cannot be started.
    """
    try:
        # run() waits for the child and closes its pipe, unlike a bare Popen.
        completed = subprocess.run(
            ["zpool", "iostat", "-v", pool_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
        )
    except (OSError, ValueError):
        print(f"failed to run 'zpool iostat -v {pool_name}'")
        sys.exit(1)

    # Header line of a section, then one or more indented child lines.
    sections = [
        match.group(0)
        for match in re.finditer(
            r"^(\S+).*$\n(?:^ +.*$\n)+", completed.stdout, flags=re.MULTILINE
        )
    ]

    return {pool_name: "\n".join(sections)}

def zpool_status_parse(zp_status_obj, key, pool_name):
    """Split one ``zpool status`` device tree into vdevs, disks and counts.

    The text stored under *key* in *zp_status_obj* is walked line by line in
    parallel with the ``-P`` (full path) variant fetched through
    ``zpool_status_flags(pool_name)``:

    * a line indented by a tab and two spaces is recorded as a vdev; when the
      same line of the ``-P`` output shows a ``/dev/...`` path, it is recorded
      as a disk as well (a disk used directly as a vdev);
    * a line indented by a tab and four spaces is recorded as a disk when both
      variants match (the ``-P`` one with a ``/dev/...`` path).

    Args:
        zp_status_obj: Dict as returned by ``zpool_status``.
        key: Entry of *zp_status_obj* to parse.
        pool_name: Pool to query again with ``-P``.

    Returns:
        A 3-tuple ``(vdevs, disks, counts)``. ``vdevs`` and ``disks`` are lists
        of dicts with ``tag``, ``name``, ``state``, ``read_errors``,
        ``write_errors`` and ``checksum_errors``; ``counts`` holds the number
        of disks collected per vdev, in order. All three are empty when *key*
        is not present.
    """
    if key not in zp_status_obj:
        # Same arity as the normal return so callers can always unpack three.
        return [], [], []

    zp_status_obj_path = zpool_status_flags(pool_name)

    zp_status_default = zp_status_obj[key].splitlines()
    zp_status_path = zp_status_obj_path.get(key, "").splitlines()

    vdev_default_re = re.compile(r"^\t  (\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    vdev_path_re = re.compile(r"^\t  (/dev/\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    disk_path_re = re.compile(r"^\t    (/dev/\S+)(?:-part[0-9])?\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    disk_default_re = re.compile(r"^\t    (\S+)(?:-part[0-9])?\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")

    def describe(match):
        # A fresh dict per call: vdev and disk records are mutated
        # independently by the caller and must never share an object.
        return {
            "tag": key,
            "name": match.group(1),
            "state": match.group(2),
            "read_errors": match.group(3),
            "write_errors": match.group(4),
            "checksum_errors": match.group(5)
        }

    vdevs = []
    disks = []
    counts = []
    disk_count = 0
    initial_disk = True
    for i, default_line in enumerate(zp_status_default):
        # The -P output may be shorter if the pool changed between the two
        # zpool calls; a missing line simply matches nothing.
        path_line = zp_status_path[i] if i < len(zp_status_path) else ""

        re_vdev_default = vdev_default_re.search(default_line)
        re_vdev_path = vdev_path_re.search(path_line)
        re_disk_path = disk_path_re.search(path_line)
        re_disk_default = disk_default_re.search(default_line)

        if re_vdev_default is not None:
            # we have a vdev
            vdevs.append(describe(re_vdev_default))
            if re_vdev_path is not None:
                # vdev is also a disk. store it in the array of disks
                disks.append(describe(re_vdev_default))
                if not initial_disk:
                    # close the previous vdev's count; this disk opens a new one
                    counts.append(disk_count)
                    disk_count = 1
                else:
                    disk_count += 1
                    initial_disk = False
            elif not initial_disk:
                counts.append(disk_count)
                disk_count = 0

        # The vdev and disk patterns demand different indentation, so a line
        # can never match both; the disk fields always come from the disk match.
        if re_disk_path is not None and re_disk_default is not None:
            # we've encountered a "regular" disk
            initial_disk = False
            disks.append(describe(re_disk_default))
            disk_count += 1

    # fix for legacy "part-N" naming
    exception_re = re.compile(r"^(\d+-\d+)(?:-part[0-9])")
    for disk in disks:
        match = exception_re.match(disk["name"])
        if match:
            disk["name"] = match.group(1)

    counts.append(disk_count)
    return vdevs, disks, counts


def verify_zfs_device_format(zp_status_obj, pool_name):
    """Build warning text for disks that are not named by a device alias.

    Disk lines (a tab plus four spaces of indentation) in
    ``zp_status_obj[pool_name]`` are matched twice: once requiring an alias of
    the form ``<digits>-<digits>`` with an optional ``-part<digit>`` suffix,
    and once accepting any name. When the second match finds more lines than
    the first, the pool contains disks this module cannot display.

    Args:
        zp_status_obj: Dict as returned by ``zpool_status``.
        pool_name: Key of *zp_status_obj* holding the device tree text; also
            used in the message and as the ``tag`` of each record.

    Returns:
        A list of message fragments (strings) to show to the user, or an empty
        list when every disk line conforms.
    """
    alert = []
    default_pattern = r"^\t    (\d+-\d+)(?:-part[0-9])?\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*"
    unsupported_pattern = r"^\t    (\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*"
    status_text = zp_status_obj[pool_name]

    def collect(pattern):
        # One record per line of the status text that matches *pattern*.
        return [
            {
                "tag": pool_name,
                "name": match.group(1),
                "state": match.group(2),
                "read_errors": match.group(3),
                "write_errors": match.group(4),
                "checksum_errors": match.group(5)
            }
            for match in re.finditer(pattern, status_text, flags=re.MULTILINE)
        ]

    disks = collect(default_pattern)
    unsupported_disks = collect(unsupported_pattern)
    if len(unsupported_disks) > len(disks):
        for disk in disks:
            if disk in unsupported_disks:
                unsupported_disks.remove(disk)
        # Aliases carrying a "-partN" suffix are captured with the suffix by
        # the catch-all pattern, so they survive the removal above. Filter into
        # a new list: removing while iterating would skip the element after
        # each removed one and misreport it as non-conforming.
        exception_match = r"^(\d+-\d+)(?:-part[0-9])"
        unsupported_disks = [
            disk for disk in unsupported_disks
            if not re.match(exception_match, disk["name"])
        ]
        alert.append(f"ZFS status displayed by this module for zpool '{pool_name}' may be incomplete.\n\n")
        alert.append("This module can only display zfs status information for devices that are created using a device alias.\n\n")
        alert.append("This can be done using the 45Drives cockpit-zfs-manager package:\nhttps://github.com/45Drives/cockpit-zfs-manager/releases/\n\n")
        if unsupported_disks:
            alert.append("The following zfs devices do not conform:\n")
            for disk in unsupported_disks:
                alert.append(f"\t  {disk['name']}\n")
        alert.append("\n")
    return alert


def zpool_iostat_parse(zp_iostat_obj, key, pool_name):
    """Split one ``zpool iostat -v`` device tree into vdevs, disks and counts.

    The text stored under *key* in *zp_iostat_obj* is walked line by line in
    parallel with the ``-P`` (full path) variant fetched through
    ``zpool_iostat_flags(pool_name)``:

    * a line indented by two spaces is recorded as a vdev; when the same line
      of the ``-P`` output shows a ``/dev/...`` path, it is recorded as a disk
      as well and the vdev's ``raid_level`` is set to ``"Disk"``;
    * a line indented by four spaces is recorded as a disk when both variants
      match (the ``-P`` one with a ``/dev/...`` path).

    Args:
        zp_iostat_obj: Dict as returned by ``zpool_iostat``.
        key: Entry of *zp_iostat_obj* to parse.
        pool_name: Pool to query again with ``-P``.

    Returns:
        A 3-tuple ``(vdevs, disks, counts)``. Vdev dicts carry ``tag``,
        ``raid_level``, ``alloc``, ``free``, ``read_ops``, ``write_ops``,
        ``read_bw`` and ``write_bw``; disk dicts carry ``name`` instead of
        ``raid_level``. ``counts`` holds the number of disks collected per
        vdev, in order. All three are empty when *key* is not present.
    """
    if key not in zp_iostat_obj:
        # Same arity as the normal return so callers can always unpack three.
        return [], [], []

    zp_iostat_obj_path = zpool_iostat_flags(pool_name)

    zp_iostat_default = zp_iostat_obj[key].splitlines()
    zp_iostat_path = zp_iostat_obj_path.get(key, "").splitlines()

    vdev_default_re = re.compile(r"^  (\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    vdev_path_re = re.compile(r"^  (/dev/\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    disk_path_re = re.compile(r"^    (/dev/\S+)(?:-part[0-9])?\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")
    disk_default_re = re.compile(r"^    (\S+)(?:-part[0-9])?\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+).*")

    def describe(match, label_key):
        # A fresh dict per call: vdev and disk records are mutated
        # independently by the caller and must never share an object.
        return {
            "tag": key,
            label_key: match.group(1),
            "alloc": match.group(2),
            "free": match.group(3),
            "read_ops": match.group(4),
            "write_ops": match.group(5),
            "read_bw": match.group(6),
            "write_bw": match.group(7)
        }

    vdevs = []
    disks = []
    counts = []
    disk_count = 0
    initial_disk = True
    for i, default_line in enumerate(zp_iostat_default):
        # The -P output may be shorter if the pool changed between the two
        # zpool calls; a missing line simply matches nothing.
        path_line = zp_iostat_path[i] if i < len(zp_iostat_path) else ""

        re_vdev_default = vdev_default_re.search(default_line)
        re_vdev_path = vdev_path_re.search(path_line)
        re_disk_path = disk_path_re.search(path_line)
        re_disk_default = disk_default_re.search(default_line)

        if re_vdev_default:
            vdevs.append(describe(re_vdev_default, "raid_level"))
            if re_vdev_path:
                # vdev is also a disk
                disks.append(describe(re_vdev_default, "name"))
                # Mark the vdev as "Disk"
                vdevs[-1]["raid_level"] = "Disk"
                if not initial_disk:
                    # close the previous vdev's count; this disk opens a new one
                    counts.append(disk_count)
                    disk_count = 1
                else:
                    disk_count += 1
                    initial_disk = False
            elif not initial_disk:
                counts.append(disk_count)
                disk_count = 0

        if re_disk_path and re_disk_default:
            initial_disk = False
            disks.append(describe(re_disk_default, "name"))
            disk_count += 1

    # fix for legacy "part-N" naming
    exception_re = re.compile(r"^(\d+-\d+)(?:-part[0-9])")
    for disk in disks:
        match = exception_re.match(disk["name"])
        if match:
            disk["name"] = match.group(1)

    counts.append(disk_count)
    return vdevs, disks, counts


def get_zpool_status():
    """Attach state, vdevs and per-disk statistics to every pool in json_zfs.

    For each pool in ``json_zfs["zpools"]`` the outputs of ``zpool_status`` and
    ``zpool_iostat`` are parsed and merged: every vdev from the status output
    receives the iostat figures of the vdev at the same position plus a
    ``disks`` list, and every disk receives its iostat figures and the index
    (``vdev_idx``) of its vdev within ``pool["vdevs"]``. Messages produced by
    ``verify_zfs_device_format`` are collected in ``json_zfs["warnings"]``.

    Raises:
        SystemExit: with code 1, after printing the raw and parsed data to
            stdout, when the two outputs could not be interpreted or do not
            line up with each other.
    """
    iostat_fields = ("alloc", "free", "read_ops", "write_ops", "read_bw", "write_bw")

    json_zfs["warnings"] = []
    for pool in json_zfs["zpools"]:
        status_output = zpool_status(pool["name"])
        iostat_output = zpool_iostat(pool["name"])
        pool["state"] = status_output["state"]
        pool["vdevs"] = []

        json_zfs["warnings"].extend(
            verify_zfs_device_format(status_output, pool["name"])
        )

        for key in status_output:
            # parse the output of both commands by top-level entry
            if key not in iostat_output:
                continue

            status_vdevs, status_disks, status_disk_counts = zpool_status_parse(
                status_output, key, pool["name"]
            )
            iostat_vdevs, iostat_disks, iostat_disk_counts = zpool_iostat_parse(
                iostat_output, key, pool["name"]
            )

            # Number of disk records the merge loop below is going to index.
            needed_disks = sum(status_disk_counts[:len(status_vdevs)])
            unusable = (
                not status_disks
                or not iostat_disks
                or not status_disk_counts
                or not iostat_disk_counts
                # Each of the following would otherwise surface as a bare
                # IndexError in the merge loop; report it with context instead.
                or len(iostat_vdevs) < len(status_vdevs)
                or len(status_disk_counts) < len(status_vdevs)
                or needed_disks > len(status_disks)
                or needed_disks > len(iostat_disks)
            )
            if unusable:
                print("/usr/share/cockpit/45drives-disks/scripts/zfs_info failed to interpret zfs information:")
                print(f"zpool status {pool['name']}:")
                print(status_output[key])
                print(f"zpool iostat -v {pool['name']}:")
                print(iostat_output[key])
                print("Other Information: ")
                print("status_vdevs", json.dumps(status_vdevs, indent=2))
                print("status_disks", json.dumps(status_disks, indent=2))
                print("status_disk_counts", json.dumps(status_disk_counts, indent=2))
                print("iostat_vdevs", json.dumps(iostat_vdevs, indent=2))
                print("iostat_disks", json.dumps(iostat_disks, indent=2))
                print("iostat_disk_counts", json.dumps(iostat_disk_counts, indent=2))
                sys.exit(1)

            disk_index = 0
            for i, vdev in enumerate(status_vdevs):
                # Merge iostat + status info on each vdev
                vdev["raid_level"] = iostat_vdevs[i]["raid_level"]
                for field in iostat_fields:
                    vdev[field] = iostat_vdevs[i][field]
                vdev["disks"] = []

                for j in range(disk_index, disk_index + status_disk_counts[i]):
                    # Combine the disk info from both outputs
                    disk = status_disks[j]
                    for field in iostat_fields:
                        disk[field] = iostat_disks[j][field]
                    # Index this vdev is about to get in pool["vdevs"].
                    disk["vdev_idx"] = len(pool["vdevs"])
                    vdev["disks"].append(disk)

                pool["vdevs"].append(vdev)
                disk_index += status_disk_counts[i]


def check_zfs():
    """Report whether a ``zfs`` executable can be found on ``PATH``.

    Uses ``shutil.which`` instead of spawning a shell to run
    ``command -v zfs``; no child process is started for the lookup.

    Returns:
        ``True`` when ``zfs`` is found, ``False`` otherwise.
    """
    return shutil.which("zfs") is not None


def create_disk_entries():
    """Flatten pools, vdevs and disks into ``json_zfs["zfs_disks"]``.

    Walks ``json_zfs["zpools"]`` and stores one record per disk, keyed by the
    disk's name. Each record combines values of the owning pool (``zpool_*``,
    including its position in the pool list), of the owning vdev (``vdev_*``)
    and of the disk itself. A later disk with the same name replaces the
    earlier record.
    """
    flattened = {}
    for pool_position, pool in enumerate(json_zfs["zpools"]):
        for vdev in pool["vdevs"]:
            for disk in vdev["disks"]:
                # Key order is significant: it is the order of the JSON output.
                flattened[disk["name"]] = {
                    "zpool_name": pool["name"],
                    "zpool_used": pool["used"],
                    "zpool_avail": pool["avail"],
                    "zpool_mountpoint": pool["mountpoint"],
                    "zpool_state": pool["state"],
                    "zpool_idx": pool_position,
                    "vdev_raid_level": vdev["raid_level"],
                    "vdev_alloc": vdev["alloc"],
                    "vdev_free": vdev["free"],
                    "vdev_read_ops": vdev["read_ops"],
                    "vdev_write_ops": vdev["write_ops"],
                    "vdev_read_bw": vdev["read_bw"],
                    "vdev_write_bw": vdev["write_bw"],
                    "name": disk["name"],
                    "alloc": disk["alloc"],
                    "free": disk["free"],
                    "read_ops": disk["read_ops"],
                    "write_ops": disk["write_ops"],
                    "read_bw": disk["read_bw"],
                    "write_bw": disk["write_bw"],
                    "vdev_idx": disk["vdev_idx"],
                    "state": disk["state"],
                    "read_errors": disk["read_errors"],
                    "write_errors": disk["write_errors"],
                    "checksum_errors": disk["checksum_errors"],
                    "tag": disk["tag"],
                }

    json_zfs["zfs_disks"] = flattened


def main():
    """Collect ZFS information and print it to stdout as indented JSON.

    When ``check_zfs()`` reports that zfs is available, the pool list, pool
    status and per-disk entries are gathered into ``json_zfs``; otherwise only
    ``{"zfs_installed": false}`` is printed.
    """
    if check_zfs():
        json_zfs["zfs_installed"] = True
        # get_zpool_list() returns False when `zpool` cannot be executed;
        # fall back to an empty list so the steps below, which iterate over
        # the pools, still produce valid (empty) output instead of raising.
        json_zfs["zpools"] = get_zpool_list() or []
        get_zpool_status()
        create_disk_entries()

    print(json.dumps(json_zfs, indent=4))


if __name__ == "__main__":
    main()
