#!/usr/bin/python3
"""
Collects SMART information from disks and outputs it
in a Prometheus-compatible format.
"""

from __future__ import print_function
import os
import sys
import time
import subprocess
import json
from past.builtins import long


class Output(object):
    # pylint: disable=useless-object-inheritance
    """Output accumulates script output per metric name.

    Samples are grouped by metric name and printed by :meth:`flush` in the
    Prometheus text exposition format, metrics in first-added order.
    """

    def __init__(self, prefix):
        """Initializes Output class.

        All metrics added to this class will have :prefix: prepended to their
        name (joined with '_') during final output.
        """

        # metric name -> list of (value, labels dict) samples
        self.metrics = {}
        # metric names in the order they were first added (stable output)
        self.order = []
        self.prefix = prefix

    def add(self, name, val, **kwargs):
        """Add a sample of metric :name: with a value :val:.

        Arbitrary number of label/value pairs may be provided in :kwargs:;
        labels are printed in the order they are passed.
        """
        if name not in self.metrics:
            self.order.append(name)
            self.metrics[name] = []
        self.metrics[name].append((val, kwargs))

    @staticmethod
    def _escape_label_value(value):
        """Escapes a label value as required by the Prometheus text format.

        Backslash, double quote and line feed must be escaped; otherwise a
        value such as a drive model containing '"' corrupts the output.
        """
        return (str(value).replace('\\', '\\\\')
                .replace('"', '\\"')
                .replace('\n', '\\n'))

    def flush(self):
        """Prints accumulated metrics in a Prometheus-compatible format.

        Every metric is declared as a gauge. Accumulated samples are not
        cleared by this call.
        """
        for name in self.order:
            print('# HELP {0}_{1} metric {1}'.format(self.prefix, name))
            print('# TYPE {0}_{1} gauge'.format(self.prefix, name))
            for val, lbls in self.metrics[name]:
                slbls = ','.join(
                    '{}="{}"'.format(key, self._escape_label_value(value))
                    for key, value in lbls.items())
                print('{0}_{1}{{{2}}} {3}'.format(
                    self.prefix, name, slbls, val))


def maybe_fraction(which):
    """Builds a parser for SMART raw values that may be written as "a/b".

    The returned callable splits its string argument on '/' and converts part
    number :which: to an integer. When the value has too few parts (e.g. a
    plain "42"), the first part is converted instead.
    """

    def pick(raw):
        pieces = raw.split('/')
        chosen = pieces[which] if which < len(pieces) else pieces[0]
        return long(chosen)

    return pick


# Path to the smartctl binary used for all non-NVMe devices.
SMARTCTL = '/usr/sbin/smartctl'
# Path to the storcli utility; megaraid_list_devices() queries it to map
# MegaRAID virtual disks to their physical drives.
STORCTL = '/opt/MegaRAID/storcli/storcli64'
# ATA SMART attributes that are exported by smart_print_counters().
# Keys are attribute names as printed by `smartctl -a`, lower-cased with '-'
# and '/' replaced by '_'; attributes not listed here are ignored.
# Values optionally override raw-value parsing: None parses the raw column
# with long(raw, 0) (decimal or 0x-prefixed hex), while maybe_fraction(N)
# takes part N of an "a/b" style raw value.
SMARTATTR = {
    'airflow_temperature_cel': None,
    'available_reservd_space': None,
    'average_erase_count': None,
    'average_slc_erase_ct': None,
    'calibration_retry_count': None,
    'command_timeout': None,
    'crc_error_count': None,
    'current_pending_sector': None,
    'disk_shift': None,
    'ecc_uncorr_error_count': maybe_fraction(0),
    'end_to_end_error': None,
    'erase_fail_count': None,
    'erase_fail_count_total': None,
    'flash_writes_gib': None,
    'g_sense_error_rate': None,
    'hardware_ecc_recovered': None,
    'head_flying_hours': None,
    'high_fly_writes': None,
    'host_reads_mib': None,
    'host_reads_32mib': None,
    'host_writes_mib': None,
    'host_writes_32mib': None,
    'initial_bad_block_count': None,
    'life_curve_status': None,
    'lifetime_writes_gib': None,
    'lifetime_reads_gib': None,
    'load_cycle_count': None,
    'load_friction': None,
    'load_in_time': None,
    'load_retry_count': None,
    'loaded_hours': None,
    'max_erase_count': None,
    'max_slc_erase_ct': None,
    'min_erase_count': None,
    'min_slc_erase_ct': None,
    'maxavgerase_ct': None,
    'media_wearout_indicator': None,
    'multi_zone_error_rate': None,
    'nand_writes_1gib': None,
    'nand_writes_32mib': None,
    'offline_uncorrectable': None,
    'percent_lifetime_remain': None,
    'power_cycle_count': None,
    'power_loss_cap_test': None,
    'power_off_retract_count': None,
    'power_on_hours': None,
    'program_fail_count': None,
    'program_fail_cnt_total': None,
    'raid_recoverty_ct': None,
    'raw_read_error_rate': maybe_fraction(0),
    'read_soft_error_rate': None,
    'reallocated_event_count': None,
    'reallocated_sector_ct': None,
    'remaining_lifetime_perc': None,
    'reported_uncorrect': None,
    'retired_block_count': None,
    'runtime_bad_block': None,
    'sandforce_internal': None,
    'sata_downshift_count': None,
    'sata_phy_error_count': None,
    'seek_error_rate': None,
    'seek_time_performance': None,
    'slc_writes_32mib': None,
    'soft_ecc_correct_rate': maybe_fraction(0),
    'spin_retry_count': None,
    'spin_up_time': None,
    'ssd_life_left': None,
    'start_stop_count': None,
    'temperature_case': None,
    'temperature_celsius': None,
    'temperature_internal': None,
    'thermal_throttle': maybe_fraction(1),
    'throughput_performance': None,
    'tlc_writes_32mib': None,
    'total_erase_count': None,
    'total_lbas_read': None,
    'total_lbas_written': None,
    'total_slc_erase_ct': None,
    'udma_crc_error_count': None,
    'unc_soft_read_err_rate': maybe_fraction(0),
    'uncorrectable_error_cnt': None,
    'unexpect_power_loss_ct': None,
    'unsafe_shutdown_count': None,
    'unused_rsvd_blk_cnt_tot': None,
    'used_rsvd_blk_cnt_tot': None,
    'valid_spare_block_cnt': None,
    'wear_range_delta': None,
    'workld_host_reads_perc': None,
    'workld_media_wear_indic': None,
    'workload_minutes': None,
}
# Accumulator for all smartctl-derived metrics; flushed by smart_collect()
# with every metric name prefixed by "smart_".
SMARTOUT = Output("smart")


def smartctl_list_devices():
    """Lists all SMART-compatible devices on the node.

    Parses the output of `smartctl --scan-open`, where each usable line has
    the shape "<device> -d <type> [# comment]".

    Returns:
        A tuple (devs, megaraid_devs):
          * devs maps a device node to a one-element list of descriptors
            {'device', 'smartdev', 'type'};
          * megaraid_devs lists descriptors {'device', 'type'} whose type
            contains 'megaraid'; they are resolved by megaraid_list_devices().
    """

    proc = subprocess.Popen([SMARTCTL, '--scan-open'],
                            stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()

    devs = {}
    megaraid_devs = []
    for line in out.splitlines():
        line = line.strip()
        # Blank and comment lines carry no device.
        if not line or line.startswith('#'):
            continue
        parts = line.split()
        # Lines without the "<device> -d <type>" shape would make parts[2]
        # raise IndexError and abort the whole run; skip them instead.
        if len(parts) < 3:
            continue
        device, dev_type = parts[0], parts[2]
        if 'megaraid' in dev_type:
            megaraid_devs.append({
                'device': device,
                'type': dev_type,
            })
            continue
        devs[device] = [
            {
                'device': device,
                'smartdev': device,
                'type': dev_type,
            }
        ]
    return devs, megaraid_devs


def megaraid_process_vd(virt_disk, data, result, smart_devs):
    """Collects info about single MegaRAID virtual disk.

    Resolves the OS device node of virtual disk :virt_disk: via the
    /dev/disk/by-id/wwn-0x<SCSI NAA Id> symlink and pairs it with the smartctl
    megaraid descriptors of the physical disks listed for that virtual disk.

    Args:
        virt_disk: virtual disk number on the controller.
        data: the controller's 'Response Data' dict from storcli JSON output.
        result: dict updated in place; when at least one physical disk is
            matched, result[<device node>] is set to the list of
            {'device', 'smartdev', 'type'} descriptors.
        smart_devs: megaraid descriptors {'device', 'type'} as returned by
            smartctl_list_devices().
    """

    properties = 'VD{} Properties'.format(virt_disk)
    if properties not in data:
        return
    if 'SCSI NAA Id' not in data[properties]:
        return

    scsi_id = data[properties]['SCSI NAA Id']
    try:
        device = '/dev/' + os.path.basename(
            os.readlink('/dev/disk/by-id/wwn-0x{}'.format(scsi_id)))
    except OSError:
        # Missing link or not a symlink: this VD cannot be mapped to a
        # device node, so skip it.
        return

    phys_disks = 'PDs for VD {}'.format(virt_disk)
    if phys_disks not in data:
        return

    disks = []
    for phys_disk in data[phys_disks]:
        if 'DID' not in phys_disk:
            continue

        # smartctl addresses a physical disk behind the controller as
        # "megaraid,<DID>" (possibly with a prefix, hence endswith).
        dev_type = 'megaraid,{}'.format(phys_disk['DID'])
        for smart_dev in smart_devs:
            if not smart_dev['type'].endswith(dev_type):
                continue
            disks.append({
                'device': device,
                'smartdev': smart_dev['device'],
                'type': smart_dev['type'],
            })
            break

    if disks:
        result[device] = disks


def megaraid_list_devices(smart_devs):
    """Lists devices behind MegaRAID controllers.

    Queries storcli for all virtual disks of all controllers and maps each
    virtual disk's device node to the smartctl descriptors of its physical
    disks (see megaraid_process_vd).

    Args:
        smart_devs: megaraid descriptors {'device', 'type'} as returned by
            smartctl_list_devices().

    Returns:
        Dict mapping a device node ('/dev/<name>') to a list of
        {'device', 'smartdev', 'type'} descriptors. It is empty when any of
        these holds: there are no megaraid devices, storcli is not
        installed, or its output is not valid JSON.
    """

    result = {}
    # Without megaraid devices there is nothing to map. This also avoids
    # launching a missing storcli binary, which would raise
    # FileNotFoundError and abort collection on hosts without MegaRAID.
    if not smart_devs or not os.path.isfile(STORCTL):
        return result

    proc = subprocess.Popen([STORCTL, '/call', '/vall', 'show', 'all', 'J'],
                            stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()
    try:
        response = json.loads(out)
    except ValueError:
        return result

    for ctrl in response.get('Controllers', []):
        status = ctrl.get('Command Status', {})
        if status.get('Status') != 'Success':
            continue
        if 'Response Data' not in ctrl:
            continue

        ctrl_cmd = "/c{}/v".format(status['Controller'])
        data = ctrl['Response Data']

        # Keys of the form "/c<ctrl>/v<N>" name the controller's virtual
        # disks; ignore any key whose suffix is not a plain number.
        vds = []
        for key in data:
            suffix = key[len(ctrl_cmd):]
            if key.startswith(ctrl_cmd) and suffix.isdigit():
                vds.append(int(suffix))

        for virt_disk in vds:
            megaraid_process_vd(virt_disk, data, result, smart_devs)

    return result


def smart_merge_with_megaraid(smart_devs, megaraid_devs):
    """Merges SMART capable devices with the devices behind MegaRAID
    controllers (if any), removing all the duplicates.

    For every device node reported by smartctl, the physical disks resolved
    through MegaRAID (if any) are used instead of smartctl's own entry.
    MegaRAID entries whose device node smartctl did not report are ignored.
    Neither input is modified.

    Args:
        smart_devs: dict device node -> list of device descriptors.
        megaraid_devs: dict device node -> list of physical disk descriptors.

    Returns:
        Flat list of device descriptors, in smart_devs order.
    """

    devices = []
    for dev_node, disks in smart_devs.items():
        devices.extend(megaraid_devs.get(dev_node, disks))
    return devices


def smart_list_devices():
    """Returns the flat list of SMART-capable devices to query.

    Devices found behind MegaRAID controllers take the place of the virtual
    disk entry smartctl reports for them (see smart_merge_with_megaraid).
    """

    plain, behind_raid = smartctl_list_devices()
    resolved = megaraid_list_devices(behind_raid)
    return smart_merge_with_megaraid(plain, resolved)


def smart_print_device_info(dev):
    """Collects SMART device info for a particular device.

    Supports any device type except NVMe, which is processed separately
    using the 'nvme' tool.

    Runs `smartctl -i -H` and adds the device_info, device_smart_available,
    device_smart_enabled, device_smart_healthy and device_capacity_bytes
    metrics to SMARTOUT.

    Args:
        dev: device descriptor dict with 'device', 'smartdev' and 'type' keys.

    Returns:
        True if smartctl reports SMART as enabled, False otherwise (also for
        NVMe devices and when smartctl's exit status signals a failure).
    """

    # NVMe devices are processed by nvme_collect() using the 'nvme' tool, so
    # skip them for smartctl.
    if dev['type'] == 'nvme':
        return False

    # NOTE(review): `env` replaces smartctl's entire environment (no PATH,
    # LANG, ...), not just the encoding; confirm this is intended.
    encoding = {'PYTHONIOENCODING': 'utf-8'}
    proc = subprocess.Popen([SMARTCTL, '-i', '-H', '-d',
                             dev['type'], dev['smartdev']],
                            stdout=subprocess.PIPE, env=encoding, text=True)
    out, _ = proc.communicate()
    # Low three bits of smartctl's exit status mark command-line, device-open
    # and SMART-command failures (see smartctl(8)); the output is unusable.
    if proc.returncode & 0x7 != 0:
        return False

    # "Key: value" lines -> {lower-cased key: value}. Only the first ':'
    # separates, so values that contain ':' are kept. If a key repeats, the
    # last occurrence wins.
    infos = {}
    for line in out.splitlines():
        key, sep, value = line.partition(':')
        if not sep:
            continue
        infos[key.strip().lower()] = value.strip()

    model = infos.get('device model', '')
    serial = infos.get('serial number', "UNKNOWN")
    if model == '':
        # No "Device Model" line: fall back to "<vendor> <product>".
        model = infos.get('vendor', 'UNKNOWN') + ' ' + \
            infos.get('product', 'UNKNOWN')
    # A missing "SMART support is" line means neither available nor enabled,
    # instead of a KeyError that would abort the whole run.
    support = infos.get('smart support is', '')
    enabled = 'Enabled' in support
    available = enabled or 'Disabled' in support
    healthy = \
        infos.get('smart overall-health self-assessment test result', '') == \
        'PASSED' or infos.get('smart health status', '') == 'OK'
    # Capacity is the first space-separated token of the field. Keep only its
    # digits so any thousands separator is accepted; 0 if there are none.
    capacity_token = infos.get('user capacity', '').split(' ')[0]
    capacity_digits = ''.join(
        ch for ch in capacity_token if ch in '0123456789')
    capacity = long(capacity_digits) if capacity_digits else 0
    SMARTOUT.add("device_info", 1, disk=dev['device'], type=dev['type'],
                 device_model=model,
                 serial_number=serial)
    SMARTOUT.add('device_smart_available', int(available), disk=dev['device'],
                 type=dev['type'])
    SMARTOUT.add('device_smart_enabled', int(enabled), disk=dev['device'],
                 device_model=model, serial_number=serial, type=dev['type'])
    SMARTOUT.add('device_smart_healthy', int(healthy), disk=dev['device'],
                 device_model=model, serial_number=serial, type=dev['type'])
    SMARTOUT.add('device_capacity_bytes', capacity,
                 disk=dev['device'], type=dev['type'])

    return enabled


def scsi_parse_temperature(name):
    """Builds a parser for a SCSI temperature line such as "...:  35 C".

    The parser returns [(name, degrees)]. It drops the trailing two
    characters (the unit suffix) of the text after the first ':' and
    reports 0 when that text is '<not available>'.
    """

    def parse(line):
        reading = line.split(':')[1].strip()
        if reading == '<not available>':
            return [(name, 0)]
        return [(name, int(reading[:-2]))]

    return parse


def scsi_parse_percent(name):
    """Builds a parser for a SCSI percentage line such as "...: 3%".

    The parser returns [(name, 100 - percent)], i.e. the complement of the
    percentage printed after the first ':'.
    """

    def parse(line):
        used = int(line.split(':')[1].strip()[:-1])
        return [(name, 100 - used)]

    return parse


def scsi_parse_int(name):
    """Builds a parser for SCSI lines of the form "<label>: <integer>".

    The parser returns [(name, integer)], taking the text between the first
    and second ':' (whitespace-stripped) as the integer.
    """

    def parse(line):
        field = line.split(':')[1]
        return [(name, int(field.strip()))]

    return parse


def scsi_parse_equal_sign(name, term, typ):
    """Builds a parser for SCSI lines that use '=' as a separator.

    The parser splits the line on '=', converts field number :term:
    (whitespace-stripped) with :typ: and returns [(name, value)].
    """

    def parse(line):
        field = line.split('=')[term]
        return [(name, typ(field.strip()))]

    return parse


def scsi_parse_errors_log(name):
    """Builds a parser for one row (:name:) of the SCSI error counter log.

    The row is split on whitespace. Columns 1, 2, 3, 5 and 7 are integer
    counters; column 6 is scaled by 10**9 into "data processed"; column 4 is
    not exported. Metric names are prefixed with "<name>_errors_".
    """

    def parse(line):
        cols = line.split()
        prefix = name + '_errors_'
        return [
            (prefix + 'corrected_ecc_fast', long(cols[1])),
            (prefix + 'corrected_ecc_delayed', long(cols[2])),
            (prefix + 'corrected_reread_rewrite', long(cols[3])),
            (prefix + 'corrections', long(cols[5])),
            (prefix + 'data_processed', float(cols[6]) * (10**9)),
            (prefix + 'uncorrected', int(cols[7])),
        ]

    return parse


# SCSI counters exported by smart_scsi_print_counters(). Every key is tried
# as a plain substring against every line of `smartctl -a` output, and the
# matching parser returns a list of (metric_name, value) pairs that are
# emitted with a "scsi_" prefix. Iteration order of this dict determines the
# order in which metrics are first added.
SMART_SCSI_ATTR = {
    'Current Drive Temperature': scsi_parse_temperature('temperature_celsius'),
    'Drive Trip Temperature':
    scsi_parse_temperature('trip_temperature_celsius'),
    'Percentage used endurance indicator': scsi_parse_percent('life_left'),
    'Specified cycle count': scsi_parse_int('start_stop_spec'),
    'Accumulated start-stop cycles': scsi_parse_int('start_stop_cycles'),
    'Specified load-unload count': scsi_parse_int('load_cycle_spec'),
    'Accumulated load-unload cycles': scsi_parse_int('load_unload_cycles'),
    'Elements in grown defect list': scsi_parse_int('reallocated_sector_ct'),
    'Blocks sent to initiator':
    scsi_parse_equal_sign('total_blocks_read', 1, long),
    'Blocks received from initiator':
    scsi_parse_equal_sign('total_blocks_written', 1, long),
    'Blocks read from cache and sent to initiator':
    scsi_parse_equal_sign('total_blocks_cache_read', 1, long),
    # The label itself contains '=' ("<="), so the value is field 2.
    'Number of read and write commands whose size <= segment size':
    scsi_parse_equal_sign('smaller_than_segment_reqs', 2, long),
    # NOTE(review): "larger_that" looks like a typo of "larger_than", but it
    # is part of the exported metric name; renaming breaks existing queries.
    'Number of read and write commands whose size > segment size':
    scsi_parse_equal_sign('larger_that_segment_reqs', 1, long),
    'number of hours powered up':
    scsi_parse_equal_sign('power_on_hours', 1, float),
    # Error counter log rows, one per operation; note that substring matching
    # means any line containing e.g. "read:" is handed to this parser.
    'read:': scsi_parse_errors_log('read'),
    'write:': scsi_parse_errors_log('write'),
    'verify:': scsi_parse_errors_log('verify'),
    'Non-medium error count': scsi_parse_int('non_medium_errors'),
}


def smart_scsi_print_counters(dev, out):
    """Collects SMART counters from a SCSI device.

    Every line of :out: (`smartctl -a` output) is matched against each key of
    SMART_SCSI_ATTR. Each (name, value) pair produced by the matching parser
    is added to SMARTOUT as "scsi_<name>".

    Args:
        dev: device descriptor dict with 'device' and 'type' keys.
        out: text output of `smartctl -a` for that device.
    """

    for line in out.splitlines():
        for known_counter, parser in SMART_SCSI_ATTR.items():
            if known_counter not in line:
                continue
            try:
                metrics = parser(line)
            except (ValueError, IndexError):
                # Line does not have the expected layout (substring matches
                # can hit unrelated lines); skip it rather than abort the
                # collection for every remaining device.
                continue
            for metric_name, metric_value in metrics:
                SMARTOUT.add('scsi_' + metric_name, metric_value,
                             disk=dev['device'], type=dev['type'])


def smart_print_counters(dev):
    """Collects SMART counters from a particular device.

    Runs `smartctl -a`. SCSI and MegaRAID output goes to
    smart_scsi_print_counters(). Otherwise the ATA attribute table is parsed:
    for every attribute listed in SMARTATTR, both the normalized and the raw
    value are added to SMARTOUT.

    Args:
        dev: device descriptor dict with 'device', 'smartdev' and 'type' keys.
    """

    proc = subprocess.Popen([SMARTCTL, '-a', '-d', dev['type'],
                             '-v', '9,raw24(raw8)',  # Power_On_Hours
                             '-v', '240,raw56:3210r54',  # Head_Flying_Hours
                             dev['smartdev']],
                            stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()

    if dev['type'].startswith(('scsi', 'megaraid')):
        smart_scsi_print_counters(dev, out)
        return

    for line in out.splitlines():
        # Attribute rows have at least 10 whitespace-separated columns:
        # cols[0] is the attribute id, cols[1] its name, cols[3] the
        # normalized value and cols[9] the raw value. Blank and shorter
        # lines are skipped.
        cols = line.split()
        if len(cols) < 10:
            continue

        aname = cols[1].lower().replace("-", "_").replace("/", "_")
        if aname not in SMARTATTR:
            continue

        parse_raw = SMARTATTR[aname]
        try:
            normalized = long(cols[3])
            # Base 0 accepts both decimal and 0x-prefixed hex raw values.
            raw = long(cols[9], 0) if parse_raw is None else \
                parse_raw(cols[9])
        except ValueError:
            # Unparseable value: skip this attribute instead of aborting the
            # collection for every remaining device.
            continue

        SMARTOUT.add(aname, normalized, smart_id=cols[0],
                     value='normalized', disk=dev['device'],
                     type=dev['type'])
        SMARTOUT.add(aname, raw, smart_id=cols[0], value='raw',
                     disk=dev['device'], type=dev['type'])


def smart_collect():
    """Collects and prints SMART metrics for every non-NVMe device.

    For each device a smartctl_run timestamp is recorded. Counters are only
    gathered when smart_print_device_info() reports SMART as enabled. All
    accumulated metrics are printed at the end.
    """

    for device in smart_list_devices():
        SMARTOUT.add("smartctl_run", int(time.time()),
                     disk=device['device'], type=device['type'])
        if smart_print_device_info(device):
            smart_print_counters(device)
    SMARTOUT.flush()


# Path to the nvme-cli binary; NVMe collection is skipped when it is absent.
NVME = '/usr/sbin/nvme'
# Accumulator for NVMe metrics; flushed by nvme_collect() with every metric
# name prefixed by "smart_nvme_".
NVMEOUT = Output("smart_nvme")


def nvme_intel_print_counters(dev):
    """Collects Intel-specific SMART counters from Intel NVMe devices.

    Parses `nvme intel smart-log-add -j` output. Each entry of its
    'Device stats' object is exported as "intel_<stat>":
      * a scalar value gets a 'type' label;
      * a nested dict value gets 'type' and 'subtype' labels.
    Output that is not valid JSON or lacks the expected structure is ignored.

    Args:
        dev: descriptor with a 'device' key (NVMe device path).
    """

    proc = subprocess.Popen([NVME, 'intel', 'smart-log-add', '-j',
                             dev['device']], stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()
    if proc.returncode != 0:
        return
    try:
        nvme_smart_log_add = json.loads(out)
    except ValueError:
        return
    if not isinstance(nvme_smart_log_add, dict):
        return
    device_stats = nvme_smart_log_add.get('Device stats')
    if not isinstance(device_stats, dict):
        return

    for stat, statv in device_stats.items():
        if not isinstance(statv, dict):
            continue
        for counter_type, value in statv.items():
            if isinstance(value, dict):
                for counter_subtype, subtype_value in value.items():
                    NVMEOUT.add('intel_'+stat, subtype_value,
                                type=counter_type, subtype=counter_subtype,
                                disk=dev['device'])
            else:
                NVMEOUT.add('intel_'+stat, value,
                            type=counter_type, disk=dev['device'])


# Vendor-specific NVMe collectors, keyed by the 'vid' (vendor id) value that
# nvme_list_devices() reads from `nvme id-ctrl`; 0x8086 is Intel.
NVME_VENDOR_SPECIFIC = {
    0x8086: nvme_intel_print_counters,
}


def nvme_list_devices():
    """Lists NVMe devices on the node along with their info.

    For every device reported by `nvme list`, adds device_info and
    capacity_bytes to NVMEOUT and queries `nvme id-ctrl` for the vendor id.

    Returns:
        A list of descriptors {'device': <device path>, 'vendor': <vid>}.
        'vendor' is None when id-ctrl fails or returns no usable 'vid'.
        The list is empty when `nvme list` produces no valid JSON.
    """

    devs = []

    proc = subprocess.Popen([NVME, 'list', '-o', 'json'],
                            stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()
    if not out:
        return devs
    try:
        nvme_list = json.loads(out)
    except ValueError:
        return devs

    if 'Devices' not in nvme_list:
        return devs

    for device in nvme_list['Devices']:
        path = device['DevicePath']
        proc = subprocess.Popen([NVME, 'id-ctrl', '-o', 'json', path],
                                stdout=subprocess.PIPE, text=True)
        out, _ = proc.communicate()

        # A failing id-ctrl only disables vendor-specific counters for this
        # device; it must not abort collection for all devices.
        vendor = None
        if proc.returncode == 0:
            try:
                nvme_idctrl = json.loads(out)
            except ValueError:
                nvme_idctrl = {}
            if isinstance(nvme_idctrl, dict):
                vendor = nvme_idctrl.get('vid')

        NVMEOUT.add("device_info", 1, disk=path,
                    device_model=device['ModelNumber'],
                    serial_number=device['SerialNumber'])
        NVMEOUT.add("capacity_bytes", device['PhysicalSize'],
                    disk=path)

        devs.append({'device': path, 'vendor': vendor})

    return devs


def nvme_print_counters(dev):
    """Collects SMART counters for a particular NVMe device.

    The function does the following:
      * adds every scalar field of `nvme smart-log -o json` output to
        NVMEOUT;
      * runs the vendor-specific collector registered for dev['vendor'], if
        any;
      * emits smart_enabled and smart_healthy so the NVMe output matches the
        smartctl metrics.

    Args:
        dev: descriptor from nvme_list_devices() with 'device' and 'vendor'.
    """

    proc = subprocess.Popen([NVME, 'smart-log', '-o', 'json',
                             dev['device']], stdout=subprocess.PIPE, text=True)
    out, _ = proc.communicate()
    if proc.returncode != 0:
        return
    try:
        nvme_smart_log = json.loads(out)
    except ValueError:
        return
    if not isinstance(nvme_smart_log, dict):
        return

    for counter, value in nvme_smart_log.items():
        # Structured values would be printed as Python reprs, which are not
        # valid Prometheus sample values.
        if isinstance(value, (dict, list)):
            continue
        NVMEOUT.add(counter, value, disk=dev['device'])

    if dev['vendor'] in NVME_VENDOR_SPECIFIC:
        NVME_VENDOR_SPECIFIC[dev['vendor']](dev)

    # A non-zero NVMe critical_warning bitfield means the controller reports
    # a problem. Missing or non-integer representations are treated as
    # healthy.
    critical_warning = nvme_smart_log.get('critical_warning', 0)
    healthy = not (isinstance(critical_warning, int) and
                   critical_warning != 0)

    # make nvme output uniform like smartctl
    NVMEOUT.add('smart_enabled', 1, disk=dev['device'])
    NVMEOUT.add('smart_healthy', int(healthy), disk=dev['device'])


def nvme_collect():
    """Collects and prints SMART metrics for every NVMe device.

    Records an nvme_run timestamp per device, gathers its counters, then
    prints all accumulated NVMe metrics.
    """

    for nvme_dev in nvme_list_devices():
        NVMEOUT.add("nvme_run", int(time.time()), disk=nvme_dev['device'])
        nvme_print_counters(nvme_dev)
    NVMEOUT.flush()


if __name__ == '__main__':
    # The collectors must run as root; bail out early with a clear message.
    if os.geteuid() != 0:
        print('You need to have root privileges to run this script.',
              file=sys.stderr)
        sys.exit(1)

    smart_collect()
    # NVMe metrics are only collected when nvme-cli is installed.
    if os.path.isfile(NVME):
        nvme_collect()
    sys.exit(0)
