#!/usr/bin/python3

# 2025 Josh Boudreau <jboudreau@45drives.com>

import json
from typing import List, Tuple, Optional, Dict
import pyudev
import sys
import os
import re

# Module-level pyudev context shared by every device lookup in this script.
context = pyudev.Context()

def set_env_var(key, value):
    """
    Write a single ``KEY=VALUE`` line to standard output.

    Every result and error this script reports goes through this helper.
    """
    line = "{}={}".format(key, value)
    print(line)


def get_chassis_size_and_mobo() -> Optional[Tuple[str, str]]:
    """
    Read the chassis size and motherboard product name from the server
    info JSON file.

    Returns a ``(chassis_size, motherboard_product_name)`` tuple. If the
    file cannot be opened or parsed, or an expected key is missing, the
    error text is reported via CUSTOM_ALIAS_ERR and None is returned.
    """
    info_path = "/etc/45drives/server_info/server_info.json"
    try:
        with open(info_path, "r") as info_file:
            info = json.load(info_file)
        chassis = info["Chassis Size"]
        board = info["Motherboard"]["Product Name"]
    except Exception as err:
        # Any failure here just means no alias can be determined.
        set_env_var("CUSTOM_ALIAS_ERR", str(err))
        return None
    return chassis, board


def get_alias_pcieport_parents(
    chassis_size: str, mobo: str
) -> Optional[Dict[str, str]]:
    """
    Look up the pcieport-to-alias table for a chassis/motherboard pair.

    Args:
        chassis_size: chassis size string, e.g. "VM2".
        mobo: motherboard product name, e.g. "ME03-CE0-000".

    Returns:
        A dict mapping the sys name of a pcieport device (a PCI address
        such as "0000:40:03.1") to the alias assigned to devices found
        below that port. If the chassis size or motherboard is not in the
        table, the reason is reported via CUSTOM_ALIAS_ERR and None is
        returned.
    """
    # chassis size -> motherboard product name -> {pcieport sys name: alias}
    lut_chassissize_mobo_to_aliases: Dict[str, Dict[str, Dict[str, str]]] = {
        "VM2": {
            "ME03-CE0-000": {
                "0000:40:03.1": "1-1",
                "0000:40:03.2": "1-2",
            }
        }
    }
    if chassis_size not in lut_chassissize_mobo_to_aliases:
        set_env_var("CUSTOM_ALIAS_ERR", f"unsupported chassis size: {chassis_size}")
        return None
    lut = lut_chassissize_mobo_to_aliases[chassis_size]
    if mobo not in lut:
        set_env_var(
            "CUSTOM_ALIAS_ERR", f"unsupported motherboard for {chassis_size}: {mobo}"
        )
        return None
    return lut[mobo]


def get_device_from_devpath(devpath: str) -> pyudev.Device:
    """
    Resolve a DEVPATH (e.g. '/devices/.../tty/ttyUSB0') to its pyudev Device.

    The DEVPATH is relative to the sysfs mount point, so its leading
    slashes are stripped and the remainder is joined under '/sys' before
    the lookup.
    """
    relative_path = devpath.lstrip("/")
    full_sys_path = os.path.join("/sys", relative_path)
    return pyudev.Devices.from_sys_path(context, full_sys_path)


def get_pcieport_parent_sys_name(device: pyudev.Device) -> Optional[str]:
    """
    Return the sys name of the first ancestor of *device* that is bound to
    the "pcieport" driver, or None if no ancestor uses that driver.
    """
    for ancestor in device.ancestors:
        if ancestor.driver == "pcieport":
            return ancestor.sys_name
    # walked all ancestors without finding a pcieport device
    return None


def main():
    """
    Print the vdev alias properties for the device being processed.

    The device path is taken from the first command-line argument, or from
    the DEVPATH environment variable when no argument is given. If the
    device's pcieport ancestor is in the alias table for this server's
    chassis/motherboard, ID_VDEV, ID_VDEV_PATH and ID_VDEV_ALT_PATH are
    printed as KEY=VALUE lines (presumably consumed by a udev rule).

    Failures are reported as a CUSTOM_ALIAS_ERR line rather than raised;
    unmapped devices produce no output at all.
    """
    result = get_chassis_size_and_mobo()
    if result is None:
        return
    chassis_size, mobo = result
    alias_parents = get_alias_pcieport_parents(chassis_size, mobo)
    if alias_parents is None:
        # not a custom alias server
        return

    # an explicit argument takes precedence over the environment
    devpath = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("DEVPATH")
    if not devpath:
        set_env_var("CUSTOM_ALIAS_ERR", "device path not provided!")
        return

    try:
        device = get_device_from_devpath(devpath)
        pcieport_parent_sys_name = get_pcieport_parent_sys_name(device)
        if pcieport_parent_sys_name is None:
            # special case for NVMe:
            # try stripping down to nvme subsystem device
            nvme_block_name: str = device.sys_name
            m = re.match(r"(nvme\d+)n\d+", nvme_block_name)
            if m:
                nvme_device = pyudev.Devices.from_name(context, "nvme", m.group(1))
                pcieport_parent_sys_name = get_pcieport_parent_sys_name(nvme_device)
    except pyudev.DeviceNotFoundError as e:
        # The sysfs path or the derived nvme controller name may not exist;
        # report it like every other failure instead of dying with a traceback.
        set_env_var("CUSTOM_ALIAS_ERR", str(e))
        return

    # a None key (no pcieport ancestor found) is simply absent from the table
    alias = alias_parents.get(pcieport_parent_sys_name)
    if alias is None:
        # not mapped
        return

    set_env_var("ID_VDEV", alias)
    set_env_var("ID_VDEV_PATH", os.path.join("disk/by-vdev", alias))
    set_env_var("ID_VDEV_ALT_PATH", alias)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
