#!/usr/bin/env python3
import argparse
import json
import logging
import re
import shutil
import subprocess
import sys
import traceback

__author__ = "Andrii Konts"
__version__ = "0.1"  # reported by the -v/--version option
__date__ = "Feb 2023"

# Module logger; set_logs() attaches the console handler to the root logger,
# which this logger propagates to.
logger = logging.getLogger(__name__)
# Path of the lspci executable; assigned by main() before any lspci command runs.
lspci = ''

def set_logs(debug=False):
    """Configure console logging on the root logger.

    :param debug: if True the root logger level is DEBUG, otherwise INFO.

    Safe to call more than once. The ERROR level name is colored only once, and
    the console handler is added only if this function has not already
    installed one. Repeated calls therefore neither duplicate every log line
    nor nest escape sequences.
    """
    console_handler_name = "net_interface_discovery_console"

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if debug else logging.INFO)

    # dirty hack for colors:) - wrap the ERROR level name in ANSI bold/red
    # background codes; skip if it is already wrapped so escapes don't nest
    error_name = logging.getLevelName(logging.ERROR)
    if not error_name.startswith("\033["):
        logging.addLevelName(logging.ERROR, "\033[1;41m%s\033[1;0m" % error_name)

    # console (StreamHandler writes to sys.stderr by default); only one copy
    if not any(handler.name == console_handler_name for handler in root_logger.handlers):
        ch = logging.StreamHandler()
        ch.setFormatter(logging.Formatter("%(levelname)-6s: %(message)s"))
        ch.name = console_handler_name
        root_logger.addHandler(ch)

def get_args(args_list):
    """Build the command-line parser and parse *args_list*.

    Options: -v/--version, --debug, -o/--onboard NAME (default "mgmt"),
    -e/--external NAME (default "uplink"), plus the mutually exclusive output
    switches -j/--json and -u/--udev.

    :param args_list: argument strings, excluding the program name.
    :returns: argparse.Namespace with debug, onboard, external, json and udev.
    """
    version_text = '%(prog)s {version}'.format(version=__version__)

    cli = argparse.ArgumentParser(description='net_interface_discovery.py script')
    cli.add_argument('-v', '--version', action='version', version=version_text)
    cli.add_argument('--debug', dest='debug', action='store_true',
                     help="Set log level to DEBUG")
    cli.add_argument('-o', '--onboard', dest='onboard', default="mgmt",
                     help="Name for onboard interfaces")
    cli.add_argument('-e', '--external', dest='external', default="uplink",
                     help="Name for external interfaces")

    # at most one alternative output format may be selected
    output_format = cli.add_mutually_exclusive_group()
    output_format.add_argument('-j', '--json', dest='json', action='store_true',
                               help="Show output in json format")
    output_format.add_argument('-u', '--udev', dest='udev', action='store_true',
                               help="Show udev rules")

    parsed = cli.parse_args(args_list)
    return parsed

def run_cmd(command, interactive=False):
    """Run a shell command and return its decoded standard output.

    :param command: command line executed with ``shell=True``, so it may use
        pipes and ``&&``/``||``. Callers in this script build it from sysfs
        paths and the lspci location; it must not contain untrusted input.
    :param interactive: when True, stdout/stderr are not captured (they go to
        the terminal) and the function returns None.
    :returns: stdout decoded as text, or None in interactive mode.
    :raises Exception: if the command exits with a non-zero status; the message
        contains the return code, stdout and stderr.
    """
    logger.debug(f"Running command: {command}")

    pipe = None if interactive else subprocess.PIPE
    p = subprocess.Popen(command, shell=True, stdout=pipe, stderr=pipe)
    out, err = p.communicate()

    # Decode once, up front, so the debug log and the error message show text
    # rather than a b'...' bytes repr; undecodable bytes are replaced instead
    # of aborting with UnicodeDecodeError.
    out = out.decode(errors="replace") if out is not None else None
    err = err.decode(errors="replace") if err is not None else None

    logger.debug(f"Command output: \n{out}")
    if p.returncode:
        raise Exception('Return code: {code}\nOutput: {out}\nError: {err}'.format(code=p.returncode, out=out, err=err))
    return out

def parse_lspci_out(lspci_out):
    """Parse ``lspci -vmm`` machine-readable output into a list of dicts.

    Records are separated by blank lines; each line is ``Key:<TAB>Value``.
    Keys are lower-cased. When a value ends with `` [xxxx]`` (as printed with
    ``-nn``), the text before the bracket is stored under the key and the
    bracketed text under ``<key>_id``. For example,
    ``Vendor:<TAB>Intel Corporation [8086]`` yields
    ``{"vendor": "Intel Corporation", "vendor_id": "8086"}``.

    Lines without a ``:<TAB>`` separator are skipped, and empty records are
    dropped.

    :param lspci_out: raw text printed by lspci.
    :returns: list of dicts, one per device, in output order.
    """
    devices = []
    # Greedy first group: in "Foo [bar] baz [1234]" only the last bracket is the ID.
    lspci_regex = re.compile(r"(.*) \[(.*)\]$")
    for entry in lspci_out.split("\n\n"):
        device = {}
        for line in entry.splitlines():
            # partition keeps any further ":\t" inside the value intact
            raw_key, sep, value = line.partition(":\t")
            if not sep:
                # blank or malformed line: nothing to record
                continue
            key = raw_key.strip().lower()
            match = lspci_regex.match(value)
            if match:
                device[key] = match.group(1).strip()
                device[f"{key}_id"] = match.group(2).strip()
            else:
                device[key] = value.strip()
        if device:
            devices.append(device)
    return devices

def _read_network_interfaces():
    """Collect kernel network interface details keyed by PCI slot.

    Lists /sys/class/net entries whose link target contains "pci". For each
    distinct device directory it reads the first interface under ``net``.

    :returns: dict mapping slot (last component of the device's sysfs path,
        e.g. ``0000:00:1f.6``) to a dict with ``name``, ``mac``, ``speed``
        (string, or None when the kernel reports -1) and ``device_path``.
    """
    pci_addresses = set()
    # grep exits 1 when nothing matches, which run_cmd would turn into an
    # exception; "|| true" makes "no PCI interfaces" an empty result instead.
    ls_out = run_cmd("ls -l /sys/class/net | grep pci || true")
    for line in ls_out.splitlines():
        _, sep, link_target = line.partition(" -> ../..")
        if not sep:
            continue
        # "/devices/pci0000:00/0000:00:1f.6/net/eno1" -> drop the trailing
        # "net/<ifname>" to get the device directory, then prefix /sys.
        pci_addresses.add('/sys' + '/'.join(link_target.strip().split('/')[0:-2]))

    network_interfaces = {}
    for pci_address in pci_addresses:
        slot = pci_address.split("/")[-1]
        # only the first interface is used when a device exposes several
        interface_name = run_cmd(f"ls -1 {pci_address}/net | head -n 1").strip()
        device_path = f"{pci_address}/net/{interface_name}"
        mac = run_cmd(f"cat {device_path}/address").strip()
        # Prefer the permanent address of bonding slaves ("0" when absent);
        # presumably because a bond member's "address" may be the bond's MAC.
        perm_mac = run_cmd(f"test -f {device_path}/bonding_slave/perm_hwaddr && cat {device_path}/bonding_slave/perm_hwaddr || echo 0").strip()
        speed = run_cmd(f"test -f {device_path}/speed && cat {device_path}/speed || echo 0").strip()
        network_interfaces[slot] = {
            "name": interface_name,
            "mac": mac if perm_mac == '0' else perm_mac,
            "speed": speed if speed != "-1" else None,
            "device_path": device_path
        }
    return network_interfaces

def get_ethernet_controllers():
    """Return physical Ethernet controllers merged with their interface details.

    Lists PCI class 0200 (Ethernet) devices via lspci and drops virtual
    functions and BlueField devices. Each remaining controller dict is
    extended with ``name``, ``mac``, ``speed`` and ``device_path`` of its
    kernel interface, plus ``onboard`` (True when its subsystem vendor equals
    get_system_vendor()). A controller without a kernel network interface
    is skipped with a warning.

    :returns: list of controller dicts in lspci order.
    """
    # get ethernet controllers from lspci
    lspci_out = run_cmd(f"{lspci} -vmmnnD -d ::0200")
    # filter virtual functions and BlueField devices
    ethernet_controllers = [
        controller for controller in parse_lspci_out(lspci_out)
        if "Virtual Function" not in controller.get("device", "")
        and "BlueField" not in controller.get("device", "")
    ]
    logger.debug(json.dumps(ethernet_controllers, indent=2))

    # get system vendor
    system_vendor = get_system_vendor()
    logger.debug(f"System vendor: {system_vendor}")

    network_interfaces = _read_network_interfaces()

    result = []
    for controller in ethernet_controllers:
        interface = network_interfaces.get(controller.get("slot"))
        if interface is None:
            # No kernel interface for this slot (for instance, no network
            # driver bound): there is no name/MAC to report or rename.
            logger.warning(f"No network interface found for PCI slot {controller.get('slot')}, skipping")
            continue
        controller.update(interface)
        controller["onboard"] = controller.get("svendor") == system_vendor
        result.append(controller)

    return result

def get_system_vendor():
    """Guess the system vendor from the VGA controllers' subsystem vendor.

    Runs lspci for PCI class 0300 (VGA) devices and returns the SVendor of the
    first one that reports it. The result is compared against each Ethernet
    controller's SVendor to decide which controllers are onboard.

    :returns: subsystem vendor name as printed by lspci (ID suffix stripped).
    :raises Exception: if no VGA controller reports a subsystem vendor.
    """
    lspci_out = run_cmd(f"{lspci} -vmmnnD -d ::0300")
    vga_controllers = parse_lspci_out(lspci_out)
    # Skip devices without an SVendor line instead of failing with KeyError.
    vendors = [controller["svendor"] for controller in vga_controllers if controller.get("svendor")]
    if not vendors:
        raise Exception("failed to identify system vendor")
    return vendors[0]

def gen_udev_rules(ethernet_controllers, onboard_interfaces_name="mgmt", external_interfaces_name="uplink"):
    """Build udev rules that name network interfaces by MAC address.

    Onboard controllers are named ``<onboard_interfaces_name>N`` and all others
    ``<external_interfaces_name>N``. Each group is numbered separately from 1,
    in input order.

    :param ethernet_controllers: iterable of dicts with "onboard" and "mac" keys.
    :returns: the rules, one per line, joined with newlines ("" for no input).
    """
    udev_line = "SUBSYSTEM==\"net\", ACTION==\"add\", DRIVERS==\"?*\", ATTR{{address}}==\"{mac}\", NAME=\"{name}\""
    prefix_for = {True: onboard_interfaces_name, False: external_interfaces_name}
    last_index = {True: 0, False: 0}

    rules = []
    for ctrl in ethernet_controllers:
        is_onboard = bool(ctrl['onboard'])
        last_index[is_onboard] += 1
        rules.append(udev_line.format(mac=ctrl["mac"], name=f"{prefix_for[is_onboard]}{last_index[is_onboard]}"))
    return "\n".join(rules)

def gen_table(ethernet_controllers):
    """Render controllers as a fixed-width, pipe-delimited text table.

    The first row is a header. Each following row shows one controller's slot,
    vendor, device (labelled "Model"), speed, MAC and interface name. A falsy
    speed (e.g. None) is shown as "---". Values longer than their column are
    not truncated.

    :param ethernet_controllers: iterable of dicts with "slot", "vendor",
        "device", "speed", "mac" and "name" keys.
    :returns: the table as one newline-joined string (no trailing newline).
    """
    row_template = "| {slot:<13}| {vendor:<30}| {device:<40}| {speed:<8}| {mac:<18}| {name:<15}|"

    lines = [
        row_template.format(slot="Slot", vendor="Vendor", device="Model",
                            speed="Speed", mac="MAC", name="Name")
    ]
    lines.extend(
        row_template.format(
            slot=ctrl["slot"],
            vendor=ctrl["vendor"],
            device=ctrl["device"],
            speed=ctrl["speed"] or "---",
            mac=ctrl["mac"],
            name=ctrl["name"],
        )
        for ctrl in ethernet_controllers
    )
    return "\n".join(lines)

def main(args_list):
    """Discover Ethernet controllers and print them.

    Output goes to stdout as udev rules (--udev), JSON (--json), or a text
    table (default). Sets the module-level ``lspci`` path used by the
    discovery functions.

    :param args_list: command-line arguments, excluding the program name.
    :raises Exception: if lspci cannot be found or discovery fails.
    """
    global lspci
    args = get_args(args_list)
    set_logs(debug=args.debug)
    logger.debug(args)
    # shutil.which avoids depending on an external `which` binary and lets us
    # report a clear error instead of an opaque "Return code: 1".
    lspci = shutil.which("lspci")
    if lspci is None:
        raise Exception("lspci not found in PATH (is pciutils installed?)")
    logger.debug(f"lspci location: {lspci}")
    ethernet_controllers = get_ethernet_controllers()

    if args.udev:
        print(
            gen_udev_rules(ethernet_controllers, args.onboard, args.external)
        )
    elif args.json:
        print(json.dumps(ethernet_controllers, indent=2))
    else:
        print(gen_table(ethernet_controllers))


if __name__ == "__main__":
    # SystemExit is deliberately not caught: it derives from BaseException, so
    # argparse's own exit status (0 for --help/--version, 2 for usage errors)
    # propagates unchanged.
    try:
        main(sys.argv[1:])
    except KeyboardInterrupt:
        logger.error('Killed by CTRL+C')
        # 128 + SIGINT: the conventional status for an interrupted program
        sys.exit(130)
    except Exception as e:
        logger.error(e)
        # Traceback goes to stderr so it never lands in redirected stdout
        # (e.g. a generated udev rules file).
        traceback.print_exc()
        sys.exit(1)

