#!/usr/bin/python3
import re
import subprocess
import os
import sys
import json
import shlex
from optparse import OptionParser


# ANSI SGR escape sequences used to colorize terminal output.
# Wrap text as ANSI_colors[<NAME>] + text + ANSI_colors["END"]; "END" is the
# reset sequence that restores the terminal's default attributes.
ANSI_colors={
    "LGREEN":'\033[1;32m',
    "GREEN":'\033[0;32m',
    "YELLOW":'\033[0;33m',
    "MAGENTA":'\033[0;35m',
    "CYAN":'\033[0;36m',
    "WHITE":'\033[0;37m',
    "RED":'\033[0;31m',
    "GREY":'\033[1;30m',
    "END":'\033[0m'
}

def get_server_info(json_path="/etc/45drives/server_info/server_info.json"):
    """Load the server description JSON and return it as parsed data.

    Args:
        json_path: Location of the server info JSON file. Defaults to the
            path used on 45Drives servers.

    Returns:
        The object decoded from the JSON file.

    Exits:
        Terminates the program with status 1 (after printing a hint) when
        the file does not exist.
    """
    if not os.path.exists(json_path):
        print("{p} not found.".format(p=json_path))
        print("loadtest is designed to run on 45Drives server hardware only.")
        print("If you have a 45drives server, run 'dmap', and try again.")
        sys.exit(1)

    # The context manager closes the file even if json.load() raises.
    with open(json_path, "r") as f:
        return json.load(f)

def get_disks_from_vdev_id(vpath="/etc/vdev_id.conf"):
    """Parse the vdev_id configuration and describe every aliased slot.

    Only lines whose first token is exactly ``alias`` and that carry at least
    an alias name and a device link are used; anything else (comments, other
    keywords, truncated alias lines) is skipped instead of raising.

    Args:
        vpath: Path of the vdev_id configuration file.

    Returns:
        A list of dicts with keys ``id`` (alias name), ``path`` (device link
        as written in the file), ``occupied`` (whether that path exists),
        ``dev`` (kernel device name when the link resolves under /dev/, else
        "") and ``model`` (from get_disk_model(), "" for empty slots).

    Exits:
        Terminates the program with status 1 when the file does not exist.
    """
    if not os.path.exists(vpath):
        print("{p} not found. run 'dmap'.".format(p=vpath))
        sys.exit(1)

    disks = []
    with open(vpath, "r") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 3 or fields[0] != "alias":
                continue

            disk_id, raw_path = fields[1], fields[2]

            # Vdev path may be a real symlink OR a ".." traversal path.
            # Treat it as occupied if it exists.
            exists = os.path.exists(raw_path)

            # realpath resolves symlinks AND normalizes ".." paths
            resolved = os.path.realpath(raw_path) if exists else ""
            dev = os.path.basename(resolved) if resolved.startswith("/dev/") else ""

            disks.append({
                "id": disk_id,
                "path": raw_path,
                "occupied": exists,
                "dev": dev,
                "model": get_disk_model(raw_path) if exists else ""
            })

    return disks

def get_disk_model(disk_path):
    """Return the model string smartctl reports for a disk.

    Args:
        disk_path: Path (possibly a symlink) to the disk's device node.

    Returns:
        "" when ``disk_path`` does not exist; otherwise the ``model_name``
        field of smartctl's JSON output, falling back to ``product``, and
        "Unknown" when neither is available or smartctl could not be run.
    """
    if not os.path.exists(disk_path):
        return ""

    device_name = os.path.basename(os.path.realpath(disk_path))
    try:
        # run() drains stdout while waiting, so large reports cannot stall on
        # a full pipe. An argument list avoids re-tokenizing the device name.
        smartctl = subprocess.run(
            ["smartctl", "--all", "/dev/{pth}".format(pth=device_name), "-j"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
    except (OSError, ValueError):
        # smartctl missing/not executable, or its output was not decodable.
        return "Unknown"

    try:
        smart_info = json.loads(smartctl.stdout)
    except ValueError:
        smart_info = None

    if isinstance(smart_info, dict):
        for key in ("model_name", "product"):
            if key in smart_info:
                return smart_info[key]
    return "Unknown"

def get_scsi_info(command=("lsscsi", "-U")):
    """Run lsscsi and collect SCSI address, identifier and device name per disk.

    Args:
        command: Argument vector of the listing command to execute.

    Returns:
        A list of dicts with keys ``scsi_id`` (text between the square
        brackets), ``uuid`` (the column before the device node) and ``dev``
        (device name without the ``/dev/`` prefix). Lines that do not name a
        ``/dev/`` node are ignored.
    """
    # Raw string: "\[" and "\s" are regex escapes, not string escapes.
    line_pattern = re.compile(r"\[(.*)\]\s+\S+\s+(\S+)\s+/dev/(\S+)")

    # run() waits for the child and closes its pipe before returning.
    lsscsi = subprocess.run(
        list(command), stdout=subprocess.PIPE, universal_newlines=True)

    scsi_info = []
    for line in lsscsi.stdout.splitlines():
        found = line_pattern.search(line)
        if found is not None:
            scsi_info.append({
                "scsi_id": found.group(1),
                "uuid": found.group(2),
                "dev": found.group(3)
            })
    return scsi_info

def consolidate_data(server_info, disks, scsi_info):
    """Merge vdev_id disks with lsscsi data into ``server_info`` in place.

    Each HBA whose "Bus Address" appears in a disk's path receives that disk
    in its "disks" list, provided the disk's device is known to lsscsi.
    ``server_info["disks"]`` is extended with every lsscsi-matched disk,
    followed by unmatched disks whose device name starts with "nvme".

    Args:
        server_info: Parsed server description; mutated in place.
        disks: Disk dicts as produced by get_disks_from_vdev_id().
        scsi_info: lsscsi records as produced by get_scsi_info().
    """
    def build_record(disk, scsi_id, uuid):
        # Single place that defines the shape of a consolidated disk entry.
        return {
            "id": disk["id"],
            "path": disk["path"],
            "occupied": disk["occupied"],
            "dev": disk["dev"],
            "scsi_id": scsi_id,
            "uuid": uuid,
            "model": disk["model"],
        }

    # lsscsi records keyed by kernel device name (sda, sdb, ...).
    lsscsi_lookup = {}
    for record in scsi_info:
        if record.get("dev"):
            lsscsi_lookup[record["dev"]] = record

    # Aliases of disks that were found in the lsscsi output.
    seen_ids = set()

    # Attach matched disks to the HBA whose bus address is part of their path.
    for hba in server_info.get("HBA", []):
        bus_address = hba.get("Bus Address", "")
        for disk in disks:
            if not bus_address or bus_address not in disk.get("path", ""):
                continue
            scsi = lsscsi_lookup.get(disk.get("dev", ""))
            if not scsi:
                continue
            seen_ids.add(disk["id"])
            hba.setdefault("disks", []).append(
                build_record(disk, scsi.get("scsi_id", ""), scsi.get("uuid", "")))

    all_disks = server_info.setdefault("disks", [])

    # Every disk lsscsi knows about goes first, in vdev_id order.
    for disk in disks:
        scsi = lsscsi_lookup.get(disk.get("dev", ""))
        if not scsi:
            continue
        seen_ids.add(disk["id"])
        all_disks.append(
            build_record(disk, scsi.get("scsi_id", ""), scsi.get("uuid", "")))

    # NVMe devices that lsscsi did not report are still included, with no
    # SCSI address or identifier.
    for disk in disks:
        if disk["id"] not in seen_ids and disk.get("dev", "").startswith("nvme"):
            all_disks.append(build_record(disk, "", ""))

def setup_fio_file(server_info, RO, RW, job_file="/opt/45drives/tools/loadtest.fio"):
    """Write the fio job file describing one job per disk.

    The file starts with a ``[global]`` section (read-only random reads when
    ``RO`` is set, otherwise a 20% read / 80% write mix when ``RW`` is set,
    and no global section when neither is set), followed by a ``[groupN]``
    section per entry of ``server_info["disks"]`` targeting ``/dev/<id>``.

    Args:
        server_info: Consolidated server description; only "disks" is read.
        RO: Use the read-only parameter set. Takes precedence over ``RW``.
        RW: Use the read/write parameter set.
        job_file: Path of the job file to (over)write.
    """
    # Parameters shared by both modes, split so that the mode-specific keys
    # keep their position in the emitted file.
    leading = {
        "runtime": "30",
        "time_based": "1",
        "name": "loadtest",
        "filesize": "10G",
        "bs": "64k",
    }
    trailing = {
        "iodepth": "64",
        "numjobs": "1",
        "ioengine": "libaio",
        "direct": "1",
        "group_reporting": "1",
    }

    if RO:
        mode = {"readwrite": "randread"}
    elif RW:
        mode = {"readwrite": "rw", "rwmixread": "20"}
    else:
        mode = None

    fio_params = {}
    if mode is not None:
        fio_params["global"] = {**leading, **mode, **trailing}

    # Build the whole file in memory first so a malformed disk entry cannot
    # leave a half-written job file behind.
    lines = []
    for param_group, params in fio_params.items():
        lines.append(f"[{param_group}]\n")
        for param, value in params.items():
            # An empty value means a bare flag with no "=value" part.
            lines.append(f"{param}\n" if value == "" else f"{param}={value}\n")
    lines.append("\n")

    for disk_count, disk in enumerate(server_info.get("disks", [])):
        lines.append(f"[group{disk_count}]\n")
        lines.append("filename=/dev/{did}\n\n".format(did=disk["id"]))

    # Opening with "w" truncates any previous job file; the context manager
    # guarantees the content is flushed and the handle closed before fio runs.
    with open(job_file, "w") as fio_job_file:
        fio_job_file.writelines(lines)

def run_fio(duration):
    """Set the runtime in the fio job file and run fio on it.

    Args:
        duration: Trial length in seconds; written as ``runtime=<int>`` over
            the existing ``runtime=`` line of the job file.

    fio's standard output is captured and discarded. A non-zero exit status
    is reported on stderr but does not stop the loadtest.
    """
    job_file = '/opt/45drives/tools/loadtest.fio'

    with open(job_file, 'r') as file:
        filedata = file.read()

    # set the proper duration in the fio job file.
    filedata = re.sub(r'^runtime=(.*)$', f'runtime={int(duration)}', filedata, flags=re.M)

    # Write the changes to the fio job file.
    with open(job_file, 'w') as file:
        file.write(filedata)

    print("running fio job: /opt/45drives/tools/loadtest.fio  (duration: {d}s)".format(d=duration))

    # run fio
    fio = subprocess.run(
        ["fio", job_file], stdout=subprocess.PIPE, universal_newlines=True)
    if fio.returncode != 0:
        # Surface the failure: without this, a fio run that never started
        # would be indistinguishable from a clean trial.
        print("Warning: fio exited with status {rc}.".format(rc=fio.returncode),
              file=sys.stderr)

def get_loadtest_info():
    """Gather disk, server and SCSI data and return the merged server info.

    Reads the vdev_id aliases, then the server info JSON, then the lsscsi
    listing, and lets consolidate_data() fold the disks into the server info.

    Returns:
        The server info object with its disk lists filled in.
    """
    vdev_disks = get_disks_from_vdev_id()
    info = get_server_info()
    consolidate_data(info, vdev_disks, get_scsi_info())
    return info

def clear_log():
    """Clear the kernel ring buffer with ``dmesg --clear``.

    Exits:
        Terminates the program with status 1 when dmesg reports failure.
    """
    rv = subprocess.run(
        ["dmesg", "--clear"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    ).returncode
    if rv:
        print("Unable to clear dmesg log")
        # sys.exit() is always available; the bare exit() helper is only
        # injected by the site module.
        sys.exit(1)

def get_log():
    """Return the current dmesg output as a list of lines.

    dmesg is given 15 seconds to finish; after that it is killed and whatever
    it produced so far is used. The output is decoded as UTF-8 and split on
    newline characters.
    """
    proc = subprocess.Popen(
        ["dmesg"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    try:
        raw_output, _ = proc.communicate(timeout=15)
    except subprocess.TimeoutExpired:
        # Stop the stuck process, then collect the partial output.
        proc.kill()
        raw_output, _ = proc.communicate()
    text = raw_output.decode("UTF-8")
    return text.split("\n")
        
def show_loadtest_disks(server_info):
    """Print the aliases of the disks under test, one printed line per group.

    The group of a disk is the part of its alias before the first "-". A new
    output line is started whenever the group changes; the first line is
    reserved for group "1".
    """
    print("Performing a loadtest on the following disks: ")

    current_group = "1"
    pending = []
    for entry in server_info.get("disks", []):
        group = entry["id"].split("-")[0]
        if group != current_group:
            # Group changed: flush what was collected and start a new line.
            print("".join(pending))
            current_group = group
            pending = []
        pending.append("{d}, ".format(d=entry["id"]))
    print("".join(pending))


def _first_match(patterns, line):
    """Return the first match of any compiled pattern in ``line``, or None."""
    for pattern in patterns:
        found = pattern.search(line)
        if found is not None:
            return found
    return None


def _nvme_parts(dev):
    """Split an NVMe namespace/partition name into (controller, namespace).

    "nvme0n1p2" -> ("nvme0", "nvme0n1"); anything else -> ("", "").
    """
    found = re.fullmatch(r"((nvme\d+)n\d+)(?:p\d+)?", dev)
    if found is None:
        return "", ""
    return found.group(2), found.group(1)


def check_log(log, server_info):
    """Scan kernel log lines for disk errors and map them to disk aliases.

    Recognized events are SCSI power-on/device resets, SCSI failed commands,
    block-layer I/O errors on NVMe devices and NVMe controller
    timeout/reset/failure messages. Every event that can be attributed to an
    entry of ``server_info["disks"]`` is printed and the disk is recorded.

    NVMe attribution compares whole names: a namespace or partition event
    (nvme0n1, nvme0n1p1) belongs to the disk with the same namespace, and a
    controller event (nvme1) belongs only to namespaces of that controller,
    so nvme1 never matches nvme10n1.

    Args:
        log: Iterable of kernel log lines.
        server_info: Consolidated server description; only "disks" is read.

    Returns:
        The set of disk aliases ("id") that had at least one matching error.
    """
    print("Checking log for errors...", end=" ")
    power_reset_errors = []
    failed_io_errors = []

    # SCSI patterns
    err_pwr_rst = re.compile(r"sd (\w+:\w+:\w+:\w+):.*(Power-on or device reset occurred)")
    err_failed_io = re.compile(r"sd (\w+:\w+:\w+:\w+): \[(sd\w+)\] (tag#\d+ FAILED .*$)")

    # NVMe / block-layer patterns to catch "Buffer I/O error on dev nvme0n1", "I/O error, dev nvme0n1", blk-mq variants.
    blk_io_patterns = (
        re.compile(r"\b(Buffer I/O error|I/O error)\b.*\bdev\s+(nvme\d+n\d+(?:p\d+)?)\b"),
        re.compile(r"\bblk_update_request: I/O error, dev\s+(nvme\d+n\d+(?:p\d+)?)\b"),
        re.compile(r"\bblk_mq_end_request: I/O error, dev\s+(nvme\d+n\d+(?:p\d+)?)\b"),
    )

    # NVMe controller-level patterns to catch timeouts/resets/controller failures
    nvme_ctrl_patterns = (
        re.compile(r"\bnvme\s+(nvme\d+):.*\b(timeout|timed out|I/O .* timeout)\b", re.IGNORECASE),
        re.compile(r"\bnvme\s+(nvme\d+):.*\b(reset|resetting|controller reset)\b", re.IGNORECASE),
        re.compile(r"\bnvme\s+(nvme\d+):.*\b(failed|failure|fatal|offline|dead)\b", re.IGNORECASE),
        # Some kernels log "nvme0: ..." (without leading "nvme ")
        re.compile(r"\b(nvme\d+):.*\b(timeout|timed out|reset|resetting|failed|fatal|offline|dead)\b", re.IGNORECASE),
    )

    for line in log:
        # SCSI detection
        m = err_pwr_rst.search(line)
        if m is not None:
            power_reset_errors.append({"scsi_id": m.group(1), "error_msg": m.group(2)})

        m = err_failed_io.search(line)
        if m is not None:
            failed_io_errors.append({
                "scsi_id": m.group(1),
                "dev": m.group(2),
                "error_msg": m.group(3)
            })

        # NVMe block-device detection; the device name is always the last
        # capture group of whichever pattern matched.
        m = _first_match(blk_io_patterns, line)
        if m is not None:
            failed_io_errors.append({
                "scsi_id": "",
                "dev": m.group(m.lastindex),
                "error_msg": line.strip()
            })

        # NVMe controller-level detection
        m = _first_match(nvme_ctrl_patterns, line)
        if m is not None:
            failed_io_errors.append({
                "scsi_id": "",
                "dev": m.group(1),    # controller name (nvme0)
                "error_msg": line.strip()
            })

    disks = server_info.get("disks", [])
    bad_disks = set()

    # SCSI correlation
    for perr in power_reset_errors:
        for disk in disks:
            if perr["scsi_id"] == disk.get("scsi_id", ""):
                print("\033[0;31mERROR:\033[0;0m {d}: [{sid}] -> {emsg} uuid={u}".format(
                    d=disk["id"], sid=disk.get("scsi_id", ""), emsg=perr["error_msg"], u=disk.get("uuid", "")
                ))
                bad_disks.add(disk["id"])

    for ioerr in failed_io_errors:
        evdev = ioerr.get("dev", "")
        _, event_namespace = _nvme_parts(evdev)
        event_is_controller = re.fullmatch(r"nvme\d+", evdev) is not None

        for disk in disks:
            # SCSI match by scsi_id
            if ioerr.get("scsi_id"):
                if ioerr["scsi_id"] == disk.get("scsi_id", ""):
                    print("\033[0;31mERROR:\033[0;0m {d}: [{sid}][{sd}] -> {emsg} uuid={u}".format(
                        d=disk["id"], sd=evdev, sid=disk.get("scsi_id", ""),
                        emsg=ioerr["error_msg"], u=disk.get("uuid", "")
                    ))
                    bad_disks.add(disk["id"])
                continue

            if not evdev:
                continue

            ddev = disk.get("dev", "")
            disk_controller, disk_namespace = _nvme_parts(ddev)

            # Namespace/partition match: identical names, or both names
            # refer to the same namespace (nvme0n1p1 belongs to nvme0n1).
            if evdev == ddev or (event_namespace and event_namespace == disk_namespace):
                print("\033[0;31mERROR:\033[0;0m {d}: [dev={sd}] -> {emsg}".format(
                    d=disk["id"], sd=evdev, emsg=ioerr["error_msg"]
                ))
                bad_disks.add(disk["id"])

            # Controller event match: "nvme0" matches disk dev "nvme0n1"
            # only when nvme0 is exactly that namespace's controller.
            elif event_is_controller and disk_controller == evdev:
                print("\033[0;31mERROR:\033[0;0m {d}: [ctrl={sd}] -> {emsg}".format(
                    d=disk["id"], sd=evdev, emsg=ioerr["error_msg"]
                ))
                bad_disks.add(disk["id"])

    if not bad_disks:
        print("No errors found.")
    else:
        print("Errors found")
    return bad_disks

def write_logs(logs, log_path):
    """Write the collected kernel logs to ``log_path``, replacing the file.

    Each log is preceded by a "Trial <n> of <total>" banner, numbered from 1
    in the order given, and written one line per entry.

    Args:
        logs: Sequence of logs, each a sequence of lines without newlines.
        log_path: Destination path; "~" is expanded and the path made
            absolute.
    """
    # NOTE(review): main() seeds `logs` with a snapshot taken before the
    # first trial, so that snapshot is numbered as trial 1 here and <total>
    # counts it too - confirm that numbering is intended.
    full_path = os.path.abspath(os.path.expanduser(log_path))
    separator = "---------------------------------------------------------------------------------\n"

    print("Writing log file to: {lp}".format(lp=full_path))
    # Mode "w" truncates an existing file; the context manager closes the
    # handle even if a write fails.
    with open(full_path, "w") as log_file:
        for trial_num, log in enumerate(logs, start=1):
            log_file.write(separator)
            log_file.write("Trial {tn} of {tot}\n".format(tn=trial_num, tot=len(logs)))
            log_file.write(separator)
            log_file.writelines(line + "\n" for line in log)
    
def disable_dmesg_console():
    """Run ``dmesg --console-off`` and report, without exiting, if it fails."""
    outcome = subprocess.run(
        ["dmesg", "--console-off"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if outcome.returncode != 0:
        print("Unable to disable dmesg console messages. (dmesg --console-off)")

def enable_dmesg_console():
    """Run ``dmesg --console-on`` to restore kernel messages on the console.

    Exits:
        Terminates the program with status 1 when dmesg reports failure.
    """
    rv = subprocess.run(
        ["dmesg", "--console-on"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    ).returncode
    if rv:
        # Name the command that actually failed (--console-on).
        print("Unable to enable dmesg console messages. (dmesg --console-on)")
        sys.exit(1)

def show_welcome_and_confirmation(ro, rw):
    """Print the banner and settle which mode the loadtest runs in.

    When neither flag is given the user is warned about data loss and must
    confirm twice before a read/write test is allowed.

    Args:
        ro: True when --readonly was requested.
        rw: True when --readwrite was requested.

    Returns:
        A ``(RO, RW)`` tuple of booleans with exactly one of them True.

    Exits:
        Status 1 when both flags are set; status 0 when the user declines
        either confirmation prompt.
    """
    # The banner rows are raw strings: the ASCII art contains backslashes
    # that are not valid escape sequences in ordinary string literals.
    print("{s:{c}^{l}}".format(s="-",l=80,c="-"))
    print(r"{r} /$$   /$$ /$$$$$$$ {e}{w} /$$$$$$$            /$$                              {e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}| $$  | $$| $$____/ {e}{w}| $$__  $$          |__/                              {e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}| $$  | $$| $$      {e}{w}| $$  \ $$  /$$$$$$  /$$ /$$    /$$ /$$$$$$   /$$$$$$${e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}| $$$$$$$$| $$$$$$$ {e}{w}| $$  | $$ /$$__  $$| $$|  $$  /$$//$$__  $$ /$$_____/{e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}|_____  $$|_____  $${e}{w}| $$  | $$| $$  \__/| $$ \  $$/$$/| $$$$$$$$|  $$$$$$ {e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}      | $$ /$$  \ $${e}{w}| $$  | $$| $$      | $$  \  $$$/ | $$_____/ \____  $${e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}      | $$|  $$$$$$/{e}{w}| $$$$$$$/| $$      | $$   \  $/  |  $$$$$$$ /$$$$$$$/{e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print(r"{r}      |__/ \______/ {e}{w}|_______/ |__/      |__/    \_/    \_______/|_______/ {e}".format(w=ANSI_colors["WHITE"],r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    print("                                                                          ")
    print("                                                                          ")
    print(r"       {w} /$$                           /$$   /$$                           /$$    {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$                          | $$  | $$                          | $$    {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$  /$$$$$$   /$$$$$$   /$$$$$$$ /$$$$$$    /$$$$$$   /$$$$$$$ /$$$$$$  {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$ /$$__  $$ |____  $$ /$$__  $$|_  $$_/   /$$__  $$ /$$_____/|_  $$_/  {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$| $$  \ $$  /$$$$$$$| $$  | $$  | $$    | $$$$$$$$|  $$$$$$   | $$    {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$| $$  | $$ /$$__  $$| $$  | $$  | $$ /$$| $$_____/ \____  $$  | $$ /$${e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}| $$|  $$$$$$/|  $$$$$$$|  $$$$$$$  |  $$$$/|  $$$$$$$ /$$$$$$$/  |  $$$$/{e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print(r"       {w}|__/ \______/  \_______/ \_______/   \___/   \_______/|_______/    \___/  {e}".format(w=ANSI_colors["WHITE"],e=ANSI_colors["END"]))
    print("                                                                        v{v}".format(v=get_version()))
    print("{s:{c}^{l}}".format(s="-",l=80,c="-"))

    if ro and rw:
        print("--readonly and --readwrite options are mutually exclusive. Use only one.")
        sys.exit(1)

    if not ro and not rw:
        # No mode chosen: default to read/write, but only after the user has
        # explicitly accepted the risk twice.
        print("{r}!!! WARNING !!!{e} loadtest will perform random reads/writes on all storage drives installed in the system.".format(r=ANSI_colors["RED"],e=ANSI_colors["END"]))
        print("                Random writes performed on storage drives can lead to {r}IRRECOVERABLE CORRUPTION OF STORED DATA.{e}".format(r=ANSI_colors["RED"],e=ANSI_colors["END"]))
        confirm = input("Continue with loadtest (y/n): ")
        if confirm not in ["y","Y"]:
            print("Note: you can perform the loadtest using the --readonly option. Use --help for more information.")
            print("loadtest cancelled")
            sys.exit(0)

        confirm2 = input("Are you sure? type 'yesiamsure': ")
        if confirm2 != "yesiamsure":
            print(f"Confirmation '{confirm2}' did not match 'yesiamsure'. Loadtest cancelled")
            sys.exit(0)
        return False, True

    if ro:
        print("Performing loadtest in {g}readonly{e} mode.".format(g=ANSI_colors["GREEN"],e=ANSI_colors["END"]))
        return True, False

    print("Performing loadtest in {r}readwrite{e} mode.".format(r=ANSI_colors["RED"],e=ANSI_colors["END"]))
    return False, True

def get_version(version_file_path="/etc/45drives/server_info/tools_version"):
    """Return the tools version string, or "" when it is not recorded.

    Args:
        version_file_path: File whose first line holds the version.

    Returns:
        The first line of the file with surrounding whitespace removed, or
        "" when the file does not exist.
    """
    if not os.path.exists(version_file_path):
        return ""

    with open(version_file_path, "r") as f:
        return f.readline().strip()

def write_to_kmsg(msg):
    """Append ``msg`` to the kernel log by writing it to /dev/kmsg.

    The message is written directly instead of through a shell command, so
    quotes or other shell metacharacters in ``msg`` (for example in a disk
    model string) are logged literally. Failure to write is reported on
    stderr and otherwise ignored, keeping the logging best-effort.

    Args:
        msg: Text of the log record.
    """
    try:
        with open("/dev/kmsg", "w") as kmsg:
            kmsg.write("{m}\n".format(m=msg))
    except OSError as err:
        print("Unable to write to /dev/kmsg: {e}".format(e=err), file=sys.stderr)

def check_log_path(log_path):
    """Make sure the log file can be written and create it with a header.

    Missing parent directories are created after the user agrees; an
    existing file is replaced only after the user agrees.

    Args:
        log_path: Requested log file path; "~" is expanded and the path made
            absolute.

    Exits:
        Status 1 when the user declines either prompt or the directories
        cannot be created.
    """
    full_path = os.path.abspath(os.path.expanduser(log_path))
    log_dir = os.path.dirname(full_path)

    # create directories in user specified path if needed.
    if not os.path.isdir(log_dir):
        print("Missing director(y/ies) in path provided. ({lp})".format(lp=log_path))
        response = input("Would you like to create the the missing director(y/ies) '{ld}' (y/n): ".format(ld=log_dir))
        if response not in ["y","Y"]:
            print("Run loadtest again using valid path argument. (option: -l <log_file>)")
            sys.exit(1)

        print("Making directories required for log file: {ld}".format(ld=log_dir))
        try:
            os.makedirs(log_dir)
        except OSError:
            # Only filesystem failures are handled here; anything else
            # (including Ctrl-C) propagates.
            print("Unable to make directories: {ld}".format(ld=log_dir))
            print("Exiting loadtest..")
            sys.exit(1)

    if os.path.isfile(full_path):
        # prompt user for overwrite if logfile exists already.
        overwrite = input("Overwrite existing file: {lp} ? (y/n): ".format(lp=full_path))
        if overwrite not in ["y","Y"]:
            print("Run loadtest again using different path. (option: -l <log_file>)")
            sys.exit(1)

    # Create (or truncate) the log file and inform the user of its location.
    print("Making Log File: {lp}".format(lp=full_path))
    with open(full_path, "w") as log_file:
        log_file.write("loadtest Log File. {lp}".format(lp=full_path))

def write_disk_models_to_kmsg(server_info):
    """Log a header line and one "<alias>: <dev> (<model>)" line per occupied disk.

    Entries of ``server_info["disks"]`` without a truthy "occupied" value are
    skipped.
    """
    inventory = server_info.get("disks", [])
    write_to_kmsg("Disks present during loadtest: ")
    for entry in inventory:
        if not entry.get("occupied"):
            continue
        write_to_kmsg(f'{entry["id"]}: {entry["dev"]} ({entry["model"]})')

def is_root():
    """Return True when the effective user id of this process is 0."""
    effective_uid = os.geteuid()
    return not effective_uid

def _run_trials(options, RO, RW):
    """Run the trial loop, report the outcome and write the log file."""
    rule = "--------------------------------------------------------------------------------"

    check_log_path(options.log_file)
    server_info = get_loadtest_info()

    logs = []
    bad_disks = set()
    total_trials = options.test_duration // options.trial_duration
    current_trial = 1

    write_to_kmsg("Starting 45Drives loadtest.")
    write_disk_models_to_kmsg(server_info)
    # Snapshot of the kernel log before any trial has run.
    logs.append(get_log())

    # Stop at the first trial that reports a bad disk.
    while not bad_disks and current_trial <= total_trials:
        print(rule)
        print("Trial {x} of {y}".format(x=current_trial,y=total_trials))
        print(rule)
        if not options.noflush:
            clear_log()
        write_to_kmsg("Begin 45Drives loadtest trial {x} of {y}.".format(x=current_trial,y=total_trials))
        # Re-read the hardware state so disks that appear or vanish between
        # trials are picked up.
        server_info = get_loadtest_info()
        show_loadtest_disks(server_info)
        setup_fio_file(server_info,RO,RW)
        run_fio(options.trial_duration)
        log = get_log()
        logs.append(log)
        bad_disks = check_log(log,server_info)
        if not bad_disks:
            print("Trial {x} of {y} {g}Passed{e}".format(x=current_trial,y=total_trials,g=ANSI_colors["GREEN"],e=ANSI_colors["END"]))
            print(rule)
            write_to_kmsg("End 45Drives loadtest trial {x} of {y}.".format(x=current_trial,y=total_trials))
            # Refresh the stored log so it includes the end-of-trial marker.
            logs[-1] = get_log()
            current_trial = current_trial + 1

    if bad_disks:
        print("Trial {x} of {y} {r}Failed{e}".format(x=current_trial,y=total_trials,r=ANSI_colors["RED"],e=ANSI_colors["END"]))
        print(rule)
        print("Loadtest Halted. Errors encountered with the following disks:")
        write_to_kmsg("End 45Drives loadtest trial {x} of {y}.".format(x=current_trial,y=total_trials))
        write_to_kmsg("45Drives Loadtest Halted.")
        write_to_kmsg("45Drives Loadtest Failed")
        logs[-1] = get_log()

        for sus_disk in bad_disks:
            print(f"\t{sus_disk}")
    else:
        print("\n{g}Loadtest Passed{e}".format(g=ANSI_colors["GREEN"],e=ANSI_colors["END"]))
        write_to_kmsg("45Drives Loadtest Completed")
        write_to_kmsg("45Drives Loadtest Passed")
        logs[-1] = get_log()

    write_logs(logs,options.log_file)


def main():
    """Parse the command line and run the loadtest.

    Requires root. The total duration is divided into trials of
    ``--trial_duration`` seconds; the test halts at the first trial whose
    kernel log shows disk errors. Unless ``--noflush`` is given, dmesg
    console output is turned off for the run and turned back on afterwards,
    including when the run ends early through an exit or an exception.
    """
    parser = OptionParser() #use optparse to handle command line arguments
    parser.add_option("-w", "--readwrite", action="store_true", dest="readwrite",
            default=False, help="perform loadtest using read/write options without provoking warning message.")
    parser.add_option("-r", "--readonly", action="store_true", dest="readonly",
        default=False, help="perform loadtest in readonly mode.")
    parser.add_option("-d", "--duration", action="store", dest="test_duration",
        default=600, type="int", help="overall test duration in seconds")
    parser.add_option("-t", "--trial_duration", action="store", dest="trial_duration",
        default=30, type="int", help="trial duration in seconds")
    parser.add_option("-n", "--noflush", action="store_true", dest="noflush",
        default=False, help="don't flush dmesg log during loadtest")
    parser.add_option("-l", "--logfile", action="store", dest="log_file",
        default="/opt/45drives/tools/loadtest.log", type="str", help="specify path for loadtest log file.")
    (options, args) = parser.parse_args()

    if not is_root():
        print("Loadtest must be run with root privelages.")
        sys.exit(1)

    RO, RW = show_welcome_and_confirmation(options.readonly, options.readwrite)

    # A zero or negative trial length would divide by zero (or produce a
    # negative trial count) when the number of trials is computed.
    if options.trial_duration <= 0:
        print("Trial duration must be a positive number of seconds.")
        sys.exit(1)

    if options.test_duration < options.trial_duration:
        print("Trial duration entered must be less than test duration.")
        sys.exit(1)

    if not options.noflush:
        disable_dmesg_console()
    try:
        _run_trials(options, RO, RW)
    finally:
        # Restore console messages on every exit path, not just success.
        if not options.noflush:
            enable_dmesg_console()


# Run the loadtest only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()