#!/usr/bin/python3

#
# Copyright (c) 2000-2019 Virtuozzo International GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import json
import logging
import os
import sys
import socket
import time
from datetime import datetime
from enum import Enum, unique
from subprocess import PIPE
from subprocess import Popen

try:
    from monotonic import monotonic as get_time
except ImportError:
    from time import time as get_time

from keystoneauth1.identity import v3
from keystoneauth1.session import Session
from novaclient.client import Client as NovaClient
from oslo_config import cfg
from cinderclient.v3.client import Client as CinderClient

# Directory containing the kolla nova-compute configuration; its presence
# also marks this host as a compute node (checked in main()).
NOVA_CONF_DIR = '/etc/kolla/nova-compute/'
LOG_PATH = '/var/log/shaman/fence.log'
logging.basicConfig(filename=LOG_PATH,
                    format='%(asctime)s | %(levelname)s | %(message)s',
                    level=logging.INFO)

# Include exception tracebacks in logs only when debug logging is enabled.
# (Comparison already yields a bool; no conditional expression needed.)
ENABLE_TRACE = logging.root.level == logging.DEBUG

# NOTE: We use credentials specified in nova.conf for placement api.
# That's a bit odd, but it works.
placement_group = cfg.OptGroup('placement')
# Keystone credentials expected in the [placement] section of nova.conf;
# all are required, so parsing fails fast when one is missing.
placement_opts = [
    cfg.StrOpt('auth_url', required=True),
    cfg.StrOpt('username', required=True),
    cfg.StrOpt('project_name', required=True),
    cfg.StrOpt('user_domain_id', required=True),
    cfg.StrOpt('project_domain_id', required=True),
    cfg.StrOpt('password', required=True),
]
CONF = cfg.CONF
CONF.register_group(placement_group)
CONF.register_opts(placement_opts, placement_group)

# Microversion passed to both the nova and cinder clients.
API_MICROVERSION = '2.67'
# Minimum delay (seconds) between unfence retry attempts.
MIN_RETRY_TIME = 30
# Total time (seconds) to keep retrying unfencing before giving up.
WAIT_TIME_BEFORE_FAIL = 1800 # 30 minutes

@unique
class EndpointTypes(Enum):
    """Enumeration of OpenStack endpoint interface types.

    The value of each member is the interface name as passed to the
    client libraries via ``endpoint_type``.
    """

    admin = 'admin'
    internal = 'internal'
    public = 'public'

def execute(cmd):
    """Run *cmd* (an argv list) and capture its output.

    Returns a tuple ``(stdout_bytes, rc)`` where ``rc`` is 0 on success
    and 1 on a non-zero exit status or execution failure. ``stdout_bytes``
    may be None when the process could not be started at all.
    """
    logging.info('Execute: "{}"'.format(' '.join(cmd)))
    out = None
    try:
        p = Popen(cmd, stdout=PIPE, stderr=PIPE)
        out, err = p.communicate()
        if p.returncode != 0:
            # Decode stderr defensively: tools may emit non-ASCII bytes,
            # and a strict ascii decode would raise UnicodeDecodeError
            # here, masking the real failure reason in the log.
            logging.error('Retcode: {} stderr: {}'.format(
                p.returncode, err.decode('utf-8', errors='replace').strip()))
            return out, 1
    except Exception:
        logging.exception('Command execution failed')
        return out, 1
    logging.info('Command execution complete')
    return out, 0


class UnfenceException(Exception):
    """Base class for errors raised during node unfencing."""
    pass


class ServiceNotFound(UnfenceException):
    """Requested service binary is not registered on the node."""
    pass


class ServiceError(UnfenceException):
    """A service is in an unexpected state or an operation failed/timed out."""
    pass


class Service:
    """Abstract interface for managing one OpenStack service binary.

    Concrete subclasses (NovaService, CinderService below) implement the
    lookup/enable/disable/status operations for their service.
    """

    # Name of the service binary managed by the subclass (e.g. 'nova-compute').
    binary = None

    def get_service_list(self):
        """Return the service records; implemented by subclasses."""
        pass

    def enable(self, hostname):
        """Enable the service on *hostname*; implemented by subclasses."""
        pass

    def disable(self, hostname):
        """Disable the service on *hostname*; implemented by subclasses."""
        pass

    def is_enabled(self, hostname):
        """Return whether the service is enabled; implemented by subclasses."""
        pass


class NovaService(Service):
    """Nova API wrapper for managing the nova-compute service."""

    binary = 'nova-compute'

    def __init__(self, session):
        self._client = NovaClient(API_MICROVERSION, session=session,
                                  endpoint_type=EndpointTypes.internal.value)

    def get_service_list(self, hostname=None, binary=None):
        """Return nova services, optionally filtered by host and binary."""
        return self._client.services.list(host=hostname, binary=binary)

    def _get_compute_service_by_host(self, hostname):
        """Return the nova-compute service record for *hostname*.

        Raises ServiceNotFound when no such service is registered.
        """
        found = self.get_service_list(hostname=hostname, binary=self.binary)
        if found:
            return found[0]
        raise ServiceNotFound(('%(binary)s service not found on the node '
                               '%(host)s' %
                               {'binary': self.binary, 'host': hostname}))

    def get_service_by_host(self, hostname):
        """Public wrapper around the compute-service lookup."""
        return self._get_compute_service_by_host(hostname)

    def enable(self, hostname):
        """Clear the forced-down flag if set, then enable the service."""
        svc = self._get_compute_service_by_host(hostname)
        if svc.forced_down:
            self._client.services.force_down(svc.id, False)
        self._client.services.enable(svc.id)

    def disable(self, hostname):
        """Disable the service and mark it forced-down."""
        svc = self._get_compute_service_by_host(hostname)
        self._client.services.disable(svc.id)
        self._client.services.force_down(svc.id, True)

    def is_enabled(self, hostname):
        """Return True when the compute service status is 'enabled'."""
        svc = self._get_compute_service_by_host(hostname)
        logging.debug('service {%s} status {%s}', svc.binary, svc.status)
        return svc.status == 'enabled'

    def get_hypervisor_list(self):
        """Return the detailed list of hypervisors."""
        return self._client.hypervisors.list(detailed=True, servers=False)

    def cancel_evacuation(self, host):
        """Cancel the ongoing evacuation for the given hypervisor object."""
        self._client.hci_hosts.cancel_evacuation(host.id)


class CinderService(Service):
    """Cinder API wrapper for managing the cinder-volume service."""

    binary = 'cinder-volume'

    def __init__(self, session):
        self._client = CinderClient(API_MICROVERSION, session=session,
                                    endpoint_type=EndpointTypes.internal.value)

    def get_service_list(self, hostname=None, binary=None):
        """Return cinder services, optionally filtered by host and binary."""
        return self._client.services.list(host=hostname, binary=binary)

    def _get_volume_service_by_host(self, hostname):
        """Return the cinder-volume service record for *hostname*.

        Raises ServiceNotFound when no such service is registered.
        """
        found = self.get_service_list(hostname=hostname, binary=self.binary)
        if found:
            return found[0]
        raise ServiceNotFound(('%(binary)s service not found on the node '
                               '%(host)s' %
                               {'binary': self.binary, 'host': hostname}))

    def enable(self, hostname):
        """Enable the cinder-volume service on *hostname*."""
        self._client.services.enable(host=hostname, binary=self.binary)

    def disable(self, hostname):
        """Disable the cinder-volume service on *hostname*."""
        self._client.services.disable(host=hostname, binary=self.binary)

    def is_enabled(self, hostname):
        """Return True when the volume service status is 'enabled'."""
        svc = self._get_volume_service_by_host(hostname)
        logging.debug('service {%s} status {%s}', svc.binary, svc.status)
        return svc.status == 'enabled'


class UnfenceNode:
    """Drives unfencing of a recovered compute node via OpenStack APIs.

    Authenticates with the [placement] credentials parsed from nova.conf
    and manages the node's nova-compute and cinder-volume services.
    """

    def __init__(self, host_id):
        # Parse nova.conf from the kolla config dir to populate the
        # [placement] options registered at module level.
        CONF(default_config_dirs=[NOVA_CONF_DIR], default_config_files=[])
        self.host_id = host_id

        auth = v3.Password(auth_url=CONF.placement.auth_url,
                           username=CONF.placement.username,
                           password=CONF.placement.password,
                           project_name=CONF.placement.project_name,
                           user_domain_id=CONF.placement.user_domain_id,
                           project_domain_id=CONF.placement.project_domain_id)

        # NOTE(review): TLS certificate verification is disabled here.
        self.session = Session(auth, verify=False)
        self.nova_service = NovaService(self.session)
        self._services = [self.nova_service,
                          CinderService(self.session)]

    def validate_services_on_node(self, hv_hostname):
        """Raise ServiceError unless every nova/cinder service on the node
        reports status 'enabled'."""
        # Validate all the services are running on the node after unfencing
        # the node.
        logging.info('Validate services are enabled after unfencing')
        for provider in self._services:
            services = provider.get_service_list(hostname=hv_hostname)
            for service in services:
                logging.debug('Validation Service %s status %s',
                             service.binary, service.status)
                if service.status != 'enabled':
                    raise ServiceError(
                        ('%(binary)s service is down after unfencing the '
                        'compute node %(host)s' % {'binary': service.binary,
                        'host': hv_hostname}))

    def get_host_hypervisor_details(self):
        """Return the hypervisor whose host_ip matches this node's resolved
        management address, or None when no hypervisor matches."""
        hypervisors = self.nova_service.get_hypervisor_list()
        host_ip = socket.gethostbyname(
            'management.%s.nodes.svc.vstoragedomain.' % self.host_id)
        host_hv = next((i for i in hypervisors
                        if i.host_ip == host_ip), None)
        return host_hv

    def _is_node_evacuating(self):
        """Return the hypervisor's is_evacuating flag.

        Raises ServiceError when hypervisor details cannot be obtained.
        """
        host_hv = self.get_host_hypervisor_details()
        if not host_hv:
            raise ServiceError('Unable to get host hypervisor information')
        return host_hv.is_evacuating

    def is_node_unfenced(self, hv_hostname):
        """Return True when both the nova and cinder services are enabled.

        Raises ServiceError when any status query fails.
        """
        try:
            if all([srv.is_enabled(hv_hostname)
                    for srv in self._services]):
                return True
            return False
        except Exception as ex:
            logging.exception(ex, exc_info=ENABLE_TRACE)
            raise ServiceError('Failed to get node fence status')

    def is_node_fenced_manually(self, hypervisor_hostname):
        """Return 1 when the nova-compute service was disabled without the
        NODE_CRASH marker in its disabled reason (operator action), else 0."""
        # Check the node had been fenced manually. In this case, do not
        # unfence the node automatically after crash.
        service = self.nova_service.get_service_by_host(
            hostname=hypervisor_hostname)
        if (service.status == 'disabled' and
                (not service.disabled_reason or
                'NODE_CRASH' not in service.disabled_reason)):
            logging.info('Node had been fenced manually with reason {%s}. '
                         'Node will not be unfenced automatically.',
                         service.disabled_reason)
            return 1
        return 0

    def unfence_node(self, host_hv):
        """Cancel ongoing evacuation, enable the node's services and
        validate them.

        Returns 0 on success (or when the node is already unfenced);
        raises ServiceError on evacuation-stop timeout or when a service
        stays down after enabling.
        """
        # If the node was manually unfenced, we need to stop unfencing the
        # node or retries.
        if self.is_node_unfenced(host_hv.hypervisor_hostname):
            logging.info('Node (%s) is already unfenced', self.host_id)
            return 0

        # Stop ongoing evacuation
        logging.info('Stopping ongoing VM evacuation on node (%s)',
                     host_hv.hypervisor_hostname)
        self.nova_service.cancel_evacuation(host_hv)

        # Wait for evacuation to be stopped or complete
        logging.info('Waiting for VM evacuation to be stopped or completed '
                      'on the node')
        start_time = get_time()
        while self._is_node_evacuating():
            if get_time() > (start_time + 300):   # 5minutes
                raise ServiceError('Timeout: Evacuation is not stopped.')
            time.sleep(10)

        logging.info('unfencing of compute node (%s) is scheduled',
                     host_hv.hypervisor_hostname)
        # enable cinder-volume and nova-compute service
        for srv in self._services:
            srv.enable(host_hv.hypervisor_hostname)

        # Validate all cinder and nova services are enabled
        self.validate_services_on_node(host_hv.hypervisor_hostname)

        logging.info('Node (%s) is unfenced successfully',
                      host_hv.hypervisor_hostname)
        return 0


def is_node_suspended(node_id):
    """Return whether shaman reports *node_id* as suspended.

    Returns True/False on success, or None when the shaman status cannot
    be obtained or the node is missing from the output.
    """
    cmd = ['/usr/sbin/shaman', 'stat', '-j']
    stat, ret = execute(cmd)
    if ret or stat is None:
        return None
    val = json.loads(stat.decode('utf-8'))
    # Walk status -> nodes -> <node_id> -> status in the JSON tree.
    for key in ['status', 'nodes', node_id, 'status']:
        val = val.get(key)
        if val is None:
            return None
    # The comparison itself is the boolean result; no if/else needed.
    return val == 'Suspended'

def _check_node_can_be_unfenced(node_id, is_retry=False):
    """Decide whether *node_id* may be unfenced automatically.

    Consults the shaman crash-per-hour threshold and the node's recent
    crash history so a node that is about to be suspended is left fenced.

    Returns True when unfencing may proceed, False when it must not, and
    None when the required shaman data cannot be obtained.
    """
    if is_retry:
        # The first iteration already approved unfencing; retries skip
        # re-querying shaman.
        return True
    # get current cluster configuration
    get_config_cmd = ['/usr/sbin/shaman', 'get-config', '-j']
    cur_config, ret = execute(get_config_cmd)
    if ret or cur_config is None:
        logging.error("Failed to get shaman global configuration")
        return None
    val = json.loads(cur_config.decode('utf-8'))
    crash_threshold = val.get('node_crash_per_hour_threshold')

    if crash_threshold == 0:
        # Node suspension is disabled. Node will be unfenced everytime.
        return True

    if crash_threshold == 1:
        # Node will be suspended on first failure. Node will not be unfenced
        logging.info("Node will be suspended on first failure. Not unfencing the node.")
        return False

    # Get node crash history
    cmd = ['/usr/sbin/shaman', 'kvtool', 'get', '/ha/shaman/master/state',
           '-Vz']
    state, ret = execute(cmd)
    if ret or state is None:
        logging.error("Failed to get cluster stat from shaman kv")
        return None
    val = json.loads(state.decode('utf-8'))
    # Walk nodes -> <node_id> -> crash_history in the state JSON.
    for key in ['nodes', node_id, 'crash_history']:
        val = val.get(key)
        if val is None:
            logging.info("No crash history found for the node.")
            return None

    logging.debug("crash_history %s", val)
    # Convert the crash_history time string to timestamp and get crash_history
    # for last 1 hour. Crashed node can recover and start after 1 hour.
    curr_time = datetime.timestamp(datetime.now())
    crash_ts_last_hr = []
    for ts in val:
        # Remove the nanoseconds from crash history timestamp
        split_ts = ts.split(".")
        ts_val = datetime.timestamp(datetime.strptime(split_ts[0],
                                                      '%Y-%m-%dT%H:%M:%S'))
        if curr_time - ts_val <= 3600:
            crash_ts_last_hr.append(ts)

    logging.debug("crash timestamps in last hr %s", crash_ts_last_hr)
    # Don't unfence the node, if the node will be suspended on next crash.
    # i.e. In case node_crash_threshold=3, the node would unfenced 1 time
    # (on first crash), on next crash, node will not be unfenced.
    if len(crash_ts_last_hr) < (crash_threshold - 1):
        return True

    # Node shouldn't be unfenced
    logging.info("Node will reach the crash threshold and will be suspended "
                 "on next crash occurrence. Not unfencing the node.")
    return False

def unfence_compute_node(host_id):
    """Unfence compute node after recovery from crash. It will stop the on-going
    evacuation and unfence the node.

    host_id : ID of the crashed host received from shaman event

    Returns 0 when the node is unfenced or unfencing is intentionally
    skipped, 1 on failure, 8 when hypervisor details cannot be found.
    """
    start_time = get_time()
    # Absolute deadline after which ServiceError retries stop failing over.
    retry_wait_time = start_time + WAIT_TIME_BEFORE_FAIL
    compute_node = UnfenceNode(host_id)

    def _wait_for_node_to_be_fenced(hypervisor_hostname):
        # In some host, it would take few seconds for shaman to detect
        # the crash. Wait for few seconds and retry.
        # Returns 1 once the node is observed fenced within ~60s, else 0.
        wait_time = get_time() + 60 # 60 seconds
        logging.info('Wait for node(%s) to be fenced', hypervisor_hostname)
        while wait_time > get_time():
            try:
                if not compute_node.is_node_unfenced(hypervisor_hostname):
                    logging.debug('node (%s) is fenced.', host_id)
                    return 1
            except Exception as exc:
                logging.exception(exc, exc_info=ENABLE_TRACE)
            time.sleep(30)
        return 0

    is_retry = False
    while True:
        try:
            host_hv = compute_node.get_host_hypervisor_details()
            if not host_hv:
                logging.error('Hypervisor details for the node %s is not '
                              'found', compute_node.host_id)
                return 8
        except Exception as ex:
            logging.exception(ex, exc_info=ENABLE_TRACE)
            time.sleep(30)
            continue

        # Check the node has been fenced manually. Do not UNFENCE manually
        # fenced node.
        if compute_node.is_node_fenced_manually(host_hv.hypervisor_hostname):
            return 0

        try:
            if not _check_node_can_be_unfenced(host_id, is_retry):
                return 0
        except Exception as ex:
            logging.exception(ex, exc_info=False)
            return 0

        # Check the node is fenced before trying to unfence.
        # Sometimes when the backend node is crashed, it would take few
        # seconds to failover. Node would be fenced only after the failover.
        if not _wait_for_node_to_be_fenced(host_hv.hypervisor_hostname):
            logging.info('Node (%s) is not fenced after the crash', host_id)
            return 0

        # Unfence the compute node
        logging.debug('unfence compute node %s', host_hv.hypervisor_hostname)
        t1 = get_time()
        try:
            return compute_node.unfence_node(host_hv)
        except ServiceNotFound as ex:
            logging.error(ex, exc_info=ENABLE_TRACE)
            return 1
        except ServiceError as ex:
            # When a service is not enabled, retry unfencing for some time
            # before failing.
            logging.exception(ex, exc_info=ENABLE_TRACE)
            if get_time() > retry_wait_time:
                logging.error('Exceeded maximum retry time of 30minutes. '
                              'Stopping unfencing of node %s', host_id)
                return 1
        except Exception as ex:
            logging.error('Failed to schedule UNFENCE node')
            logging.exception(ex, exc_info=False)

        # Sleep at least MIN_RETRY_TIME seconds between attempts, counting
        # the time the failed attempt itself took.
        is_retry = True
        t2 = get_time()
        time.sleep(max(0, MIN_RETRY_TIME - int(t2 - t1)))

def main():
    """Handle a shaman FENCE/UNFENCE event for this node.

    The event type comes from the EVENT environment variable and the node
    identifier from NODE_ID. FENCE installs an OVS drop flow to isolate VM
    traffic; UNFENCE removes it and then forks a child that re-enables the
    compute services. Returns 0 on success or ignored events, 1 on
    fence/unfence command failure, 8 when the node cannot be unfenced.
    """
    try:
        event_type = os.getenv('EVENT')
        if not os.path.exists(NOVA_CONF_DIR):
            # Only compute nodes carry the kolla nova-compute config dir.
            logging.info('Ignore event %s: it is not a compute node', event_type)
            return 0
        if event_type == 'FENCE':
            logging.info('Fence vms')
            # the last argument is the rule, so it must be passed as one parameter
            cmd = ['/usr/bin/ovs-ofctl', 'add-flow', 'br-int', 'table=0,priority=65535,actions=drop']
        elif event_type == 'UNFENCE':
            logging.info('Unfence vms')
            # the last argument is the rule, so it must be passed as one parameter
            cmd = ['/usr/bin/ovs-ofctl', '--strict', 'del-flows', 'br-int', 'table=0,priority=65535']
        else:
            logging.info('Event %s ignored', event_type)
            return 0

        out, ret = execute(cmd)
        if ret:
            # Command execution failed
            return ret
    except Exception:
        logging.exception('Failed to fence/unfence vms')
        return 1

    # unfence the compute node
    if event_type == 'UNFENCE':
        host_id = os.getenv('NODE_ID')
        if host_id is None:
            return 8

        # After the node reaches the failure threshold configured in
        # NODE_CRASH_PER_HOUR_THRESHOLD, it is suspended.
        ret = is_node_suspended(host_id)
        if ret is None:
            logging.info('Failed to find node (%s) status ', host_id)
            return 8
        if ret:
            logging.info('Node (%s) is suspended and will not be unfenced',
                         host_id)
            return 8

        # Retry in a separate process so that shaman is not blocked
        # (the parent returns immediately; the child continues below).
        if os.fork():
            return 0

        # Unfence Compute node
        logging.debug("unfencing node %s", host_id)
        try:
            unfence_compute_node(host_id)
        except Exception as ex:
            logging.exception(ex)

# Script entry point: propagate main()'s return code as the exit status.
if __name__ == '__main__':
    sys.exit(main())
