#!/usr/bin/python3

import argparse
import json
from http.client import HTTPResponse
import io
import logging
import os
import pathlib
import re
import socket
import subprocess
import sys
import time
import uuid

import pyroute2
import requests
import urllib3

# Module logger. main() adds a stderr handler to it, and its records also
# propagate to the file handler set up by logging.basicConfig() in main().
LOG = logging.getLogger(__name__)
# Log file and record format passed to logging.basicConfig() in main().
log_filename = '/var/log/vstorage-ui-agent/register-node.log'
log_format = '%(asctime)s %(levelname)s %(message)s'

# Unix sockets of the local agent and backend uWSGI applications
# (used through UwsgiClient).
UWSGI_SOCK_AGENT = '/run/vstorage-ui/uwsgi-agent.sock'
UWSGI_SOCK_BACKEND = '/run/vstorage-ui/uwsgi-backend.sock'

# JSON file holding {"original": ..., "assigned": ...} hostnames; written by
# NodeRegistrator.set_hostname() and read by NodeRegistrator.get_hostname().
HOSTNAME_JSON = '/usr/libexec/vstorage-ui-agent/var/hostname.json'
# Backend roles file; its "WEB_CP_IF=<iface>" line, when present, selects
# the public (Admin panel) interface.
BACKEND_ROLES_INIT = '/usr/libexec/vstorage-ui-backend/etc/roles.init'


class RegisterError(Exception):
    """Raised when a node registration step cannot be completed.

    main() catches it (together with any other exception), logs its
    message and exits with status 1.
    """


def parse_args():
    """Parse the command line of the node registration tool.

    Reads ``sys.argv``. ``-m/--management`` is mandatory and ``-v``/``-q``
    cannot be combined.

    Returns:
        argparse.Namespace with the attributes ``verbose``, ``quiet``,
        ``token``, ``public_interface``, ``hostname``, ``ip_address`` and
        ``mn_host``.
    """
    parser = argparse.ArgumentParser(
        description='Register node in a control plane'
    )

    # Verbosity switches are mutually exclusive.
    verbosity = parser.add_mutually_exclusive_group()
    for flags, text in (
        (('-v', '--verbose'), 'Increase verbosity of output.'),
        (('-q', '--quiet'), 'Suppress output except warnings and errors.'),
    ):
        verbosity.add_argument(*flags, action='store_true', help=text)

    options = (
        (('-t', '--token'),
         dict(metavar='<token>',
              help='Secret token. Required for remote node registration.')),
        (('-x', '--public-interface'),
         dict(metavar='<iface>',
              help='Interface used for Admin panel. '
                   'Only for local node registration.')),
        (('--hostname',),
         dict(metavar='<hostname>',
              help='Hostname of the node. '
                   'Defaults to the system hostname.')),
        (('--ip-address',),
         dict(metavar='<ip-address>',
              help='IP address of the node. '
                   'Use this if the node is behind NAT.')),
        (('-m', '--management'),
         dict(dest='mn_host', metavar='<host>', required=True,
              help='Management node IP address or hostname.')),
    )
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)

    return parser.parse_args()


def is_installer_mode():
    """Return True when systemd's default target is an Anaconda target.

    ``/etc/systemd/system/default.target`` is resolved through symlinks and
    its final path component is checked for ``anaconda``. A missing target
    means "not in installer mode".
    """
    link = pathlib.Path('/etc/systemd/system/default.target')
    try:
        target = link.resolve(strict=True)
    except FileNotFoundError:
        return False
    return 'anaconda' in target.name


def getoutput(cmd, **kwargs):
    """Run *cmd* and return its standard output with whitespace stripped.

    A ``str`` command is run through the shell; a list is executed directly
    (``shell=False``). Extra keyword arguments are forwarded to
    ``subprocess.run`` (for example ``input=`` or ``env=``).

    Raises:
        RegisterError: if the command exits with a non-zero status. The
            command, exit code, stdout and stderr are logged first.
    """
    shell = isinstance(cmd, str)
    LOG.debug('Running cmd (subprocess): %s', cmd)
    stime = time.time()
    res = subprocess.run(
        cmd, capture_output=True, text=True, shell=shell, **kwargs
    )
    LOG.debug(
        "CMD '%s' returned: %s in %0.3fs",
        cmd, res.returncode, time.time() - stime
    )
    if res.returncode != 0:
        LOG.error(
            'command: %(cmd)r\n'
            'exit code: %(code)r\n'
            'stdout: %(stdout)r\n'
            'stderr: %(stderr)r',
            {
                'cmd': cmd,
                'code': res.returncode,
                'stdout': res.stdout,
                'stderr': res.stderr
            }
        )
        # Carry a message so main()'s "Cannot register node: %s" line says
        # what failed instead of ending in an empty string.
        raise RegisterError(
            f'Command {cmd!r} failed with exit code {res.returncode}.'
        )
    return res.stdout.strip()


def roles_cfg_exec(args):
    """Run ``roles_cfg_cli`` with *args*, discarding its output.

    ``--offline`` is appended when is_installer_mode() is true. A non-zero
    exit status surfaces as RegisterError raised by getoutput().
    """
    offline_flag = ['--offline'] if is_installer_mode() else []
    getoutput(
        ['/usr/bin/python3', '/usr/bin/roles_cfg_cli', *args, *offline_flag]
    )


def encrypt(token, data):
    """Encrypt *data* with *token*; return base64 ciphertext (stripped).

    Runs ``openssl enc -aes-256-cbc -salt`` with the password taken from the
    ``MN_TOKEN`` environment variable, so the token never appears on the
    command line. The child environment contains only that variable.
    """
    command = 'openssl enc -base64 -e -aes-256-cbc -salt -pass env:MN_TOKEN'
    return getoutput(command, input=data, env={'MN_TOKEN': token})


def decrypt(token, data):
    """Decrypt base64 *data* encrypted with *token* (scheme of encrypt()).

    Runs ``openssl enc -d -aes-256-cbc`` reading the ciphertext from stdin
    and the password from the ``MN_TOKEN`` environment variable (the only
    variable in the child environment). Returns the stripped plaintext.
    """
    command = ('openssl enc -base64 -d -aes-256-cbc -salt '
               '-pass env:MN_TOKEN -in -')
    return getoutput(command, input=data, env={'MN_TOKEN': token})


class BaseClient:
    """Base class for the clients talking to the agent/backend services.

    Subclasses implement ``request``; ``get``/``post``/``put`` forward to it
    with the HTTP method filled in. Response decoding and debug logging
    (with secret values masked) are shared here.
    """

    # Keys whose values are masked in debug logs (see _sanitize_data).
    _SECRET_KEYS = frozenset(('certificates', 'token', 'key', 'cert'))

    def request(self, method, uri, **kwargs):
        """Send a request; must be implemented by subclasses."""
        raise NotImplementedError

    def get(self, uri, **kwargs):
        return self.request('GET', uri, **kwargs)

    def post(self, uri, **kwargs):
        return self.request('POST', uri, **kwargs)

    def put(self, uri, **kwargs):
        return self.request('PUT', uri, **kwargs)

    @staticmethod
    def _get_response_data(resp):
        """Return the decoded body of *resp*, raising on HTTP errors.

        The body is parsed as JSON when possible, otherwise returned as
        text. For error responses whose JSON body is an object, the
        server-provided details (``error.fields``, ``error.message``, a
        plain ``error`` value or the top-level ``message``) are appended to
        the ``HTTPError`` message.

        Raises:
            requests.exceptions.HTTPError: for 4xx/5xx responses.
        """
        try:
            resp_data = resp.json()
        except ValueError:
            # Not JSON (requests' JSON decode errors subclass ValueError).
            resp_data = resp.text

        try:
            resp.raise_for_status()
        except requests.exceptions.HTTPError as err:
            message = None
            if isinstance(resp_data, dict):
                message = resp_data.get('message')
                error = resp_data.get('error')
                if isinstance(error, dict):
                    message = error.get(
                        'fields', error.get('message', message)
                    )
                elif error:
                    # 'error' may be a plain string; calling .get() on it
                    # would raise AttributeError and hide the HTTPError.
                    message = error
            if message:
                err.args = (f'{err.args[0]}\n{message}',)
            raise
        return resp_data

    @staticmethod
    def _hidden(value):
        """Return the log placeholder used for a secret *value*."""
        try:
            return f'<hidden_{len(value)}>'
        except TypeError:
            # Values without a length (None, numbers) are still masked
            # rather than raising; a raise here would make _log_response
            # fall back to logging the raw, unmasked response text.
            return '<hidden>'

    @classmethod
    def _sanitize_data(cls, data):
        """Return a copy of *data* with secret values masked for logging.

        In mappings, values stored under ``_SECRET_KEYS`` are replaced by
        ``<hidden_N>`` (N = length of the value); other values, nested
        mappings and lists are processed recursively. Non-container values
        are returned unchanged. The input is never modified.
        """
        if isinstance(data, (list, tuple)):
            return [cls._sanitize_data(item) for item in data]
        if not hasattr(data, 'items'):
            return data
        return {
            key: (cls._hidden(value) if key in cls._SECRET_KEYS
                  else cls._sanitize_data(value))
            for key, value in data.items()
        }

    def _log_request(self, method, url, **kwargs):
        """Log a one-line description of the request at DEBUG level."""
        if LOG.isEnabledFor(logging.DEBUG):
            parts = self._log_request_parts(method, url, **kwargs)
            LOG.debug('REQ: %s', ' '.join(parts))

    def _log_request_parts(self, method, url, **kwargs):
        """Return the words describing a request; subclasses add details."""
        return [method, url]

    def _log_response(self, resp, stime):
        """Log status, (sanitized) body and elapsed time at DEBUG level."""
        if LOG.isEnabledFor(logging.DEBUG):
            parts = [f'RESP: [{resp.status_code}]']
            try:
                data = self._sanitize_data(resp.json())
                parts.append(json.dumps(data))
            except Exception:
                # Best-effort logging: non-JSON bodies are logged as text.
                parts.append(resp.text.strip())
            elapsed = time.time() - stime
            parts.append(f"{elapsed:.3f}s")
            LOG.debug(' '.join(parts))


class HttpsClient(BaseClient):
    """JSON client for the management node's HTTPS API.

    Server certificates are not verified (``verify=False``). With a token,
    the JSON payload is encrypted by encrypt() and sent as
    ``{"data": "<ciphertext>"}``; without one it is sent as
    ``{"data": <payload>}``.
    """

    def __init__(self, addr, port):
        self.addr = addr
        self.port = port

    def request(self, method, uri, data=None, token=None):
        """Send a request to ``https://<addr>:<port>/<uri>``; return the body.

        Args:
            method: HTTP method.
            uri: Path on the server; a leading ``/`` is optional.
            data: Optional payload; a falsy value sends no body.
            token: Optional registration token used to encrypt *data*.

        Raises:
            requests.exceptions.HTTPError: for 4xx/5xx responses.
        """
        headers = {
            "X-Requested-With": "XMLHttpRequest",
            "Content-type": "application/json"
        }
        # (connect, read) timeouts in seconds.
        params = dict(timeout=(10, 30), verify=False, headers=headers)
        if data:
            if token:
                params['json'] = {'data': encrypt(token, json.dumps(data))}
            else:
                params['json'] = {'data': data}
        url = f"https://{self.addr}:{self.port}/{uri.lstrip('/')}"
        self._log_request(method, url, headers=headers,
                          data=data, encrypted=bool(token))

        stime = time.time()
        resp = requests.request(method, url, **params)
        self._log_response(resp, stime)
        return self._get_response_data(resp)

    def _log_request_parts(
        self, method, url, headers=None, data=None, encrypted=False
    ):
        """Describe the request as a curl-like command line for debug logs."""
        headers = headers or {}
        parts = ['curl --insecure -X', method, url]
        for header_name, header_value in headers.items():
            parts.append(f'-H "{header_name}: {header_value}"')
        if data:
            dat_str = json.dumps(self._sanitize_data(data))
            if encrypted:
                # Mark the payload as encrypted but log only the sanitized
                # form; the raw payload must never reach the log.
                dat_str = f'encrypted@{dat_str}'
            parts.append(f'-d {dat_str}')
        return parts


class UwsgiClient(BaseClient):
    """Client for a local uWSGI application reachable over a unix socket.

    Requests are encoded with the uwsgi binary protocol (a 4-byte header
    followed by length-prefixed variable pairs), and the HTTP response read
    back is wrapped in a ``requests.Response``.
    """

    def __init__(self, sock_path):
        self.sock_path = sock_path

    def request(self, method, uri, data=None, env_vars=None):
        """Send a request over the uWSGI socket and return the decoded body.

        Args:
            method: HTTP method, sent as ``REQUEST_METHOD``.
            uri: Request path, sent as ``PATH_INFO``.
            data: Optional JSON-serialisable body; falsy sends no body.
            env_vars: Optional extra request variables (string values). The
                caller's dict is not modified.

        Raises:
            requests.exceptions.HTTPError: for 4xx/5xx responses.
            OSError: if the socket cannot be connected to.
        """
        # Work on a copy so the caller's dict is left untouched.
        env_vars = dict(env_vars or {})
        env_vars['REQUEST_METHOD'] = method
        env_vars['PATH_INFO'] = uri
        data_bytes = (json.dumps(data) if data else '').encode('utf8')
        if data_bytes:
            env_vars['CONTENT_TYPE'] = 'application/json'
            env_vars['CONTENT_LENGTH'] = str(len(data_bytes))

        url = f'unix://{self.sock_path} {uri}'
        self._log_request(method, url, data=data, env_vars=env_vars)

        stime = time.time()
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.connect(self.sock_path)
            # send() may transmit only part of the buffer; sendall() loops
            # until everything is written.
            sock.sendall(self._pack_uwsgi_data(env_vars) + data_bytes)
            http_resp = HTTPResponse(sock)
            http_resp.begin()

            resp = requests.Response()
            resp.url = url
            resp.status_code = http_resp.status
            resp.reason = http_resp.reason
            # requests expects case-insensitive header lookups.
            resp.headers = requests.structures.CaseInsensitiveDict(
                http_resp.getheaders()
            )
            resp.raw = io.BytesIO(http_resp.read())

        self._log_response(resp, stime)
        return self._get_response_data(resp)

    def _log_request_parts(self, method, url, data=None, env_vars=None):
        """Return the URL plus sanitized request variables and body."""
        parts = [url]
        for dat in filter(None, (env_vars, data)):
            parts.append(json.dumps(self._sanitize_data(dat)))
        return parts

    def _pack_uwsgi_data(self, data):
        """Encode *data* (str -> str) as a uwsgi packet header + vars block.

        Each key and value is UTF-8 encoded and prefixed with its byte
        length; the block is prefixed with modifier1=0, its 16-bit size and
        modifier2=0.
        """
        chunks = []
        for key, value in data.items():
            key_bytes = key.encode('utf8')
            value_bytes = value.encode('utf8')
            # Sizes must count encoded bytes, not characters.
            chunks += [self._uwsgi_sz(key_bytes), key_bytes,
                       self._uwsgi_sz(value_bytes), value_bytes]
        body = b''.join(chunks)
        return b'\x00' + self._uwsgi_sz(body) + b'\x00' + body

    @staticmethod
    def _uwsgi_sz(x):
        """Return ``len(x)`` as a 2-byte little-endian uwsgi size field.

        Raises:
            ValueError: if the length does not fit in 16 bits.
        """
        size = len(x)
        if size > 0xFFFF:
            raise ValueError(f'uwsgi field too long: {size} bytes')
        return size.to_bytes(2, 'little')


class NodeRegistrator:
    """Registers this node with the management node.

    Talks to the local agent/backend over their uWSGI sockets and to the
    management node's backend over HTTPS; see register_node() for the
    sequence of steps.
    """

    def __init__(self, mn_addr, mn_port, args):
        """
        Args:
            mn_addr: Address of the management node (main() passes the
                result of ``socket.gethostbyname``).
            mn_port: HTTPS port of the management node backend.
            args: Parsed command-line namespace from parse_args().
        """
        self.mn_addr = mn_addr
        self.mn_port = mn_port
        self.args = args

        self.uwsgi_agent = UwsgiClient(UWSGI_SOCK_AGENT)
        self.uwsgi_backend = UwsgiClient(UWSGI_SOCK_BACKEND)
        self.http_backend = HttpsClient(self.mn_addr, self.mn_port)

    def register_node(self):
        """Register this node with the management node.

        Steps:
        1. Find the interface and source IP used to reach the management
           node.
        2. If the management node is this host, either skip (installer
           mode) or fetch a token from the local backend when none was
           given.
        3. POST the node identity, encrypted with the token.
        4. Apply the hostname assigned by the server.
        5. Install the returned certificates and switch Nginx to TLS.
        6. PUT the management and optional public interface settings.

        Raises:
            RegisterError: on missing token/route data or a failed command.
            requests.exceptions.HTTPError: if a backend request is rejected.
        """
        args = self.args

        route_info, is_local = self.get_route_info(self.mn_addr)
        bind_dev = route_info.get('dev')
        if not bind_dev:
            raise RegisterError(
                f'Cannot determine the interface to reach {self.mn_addr}.'
            )
        ip_address = args.ip_address or route_info.get('src')
        if not ip_address:
            raise RegisterError(
                f'Cannot determine the source IP address to reach '
                f'{self.mn_addr}; pass it with --ip-address.'
            )

        if is_local:
            if is_installer_mode():
                # The backend node is installing, the registration is done
                # on the first boot
                LOG.info('Local registering skipped in anaconda mode')
                return

            if not args.token:
                # The management node is this host: ask the local backend
                # for a registration token over its socket.
                env_vars = {
                    'HTTP_X_REQUESTED_WITH': 'XMLHttpRequest',
                    'SKIP_AUTH': 'True'
                }
                args.token = self.uwsgi_backend.post(
                    '/api/v2/nodes/registration/token/',
                    env_vars=env_vars
                )['token']
        if not args.token:
            raise RegisterError('Registration token must be provided.')

        host_id = self.uwsgi_agent.get('/api/v1/host_id/')['host_id']
        node_id = self.uwsgi_agent.get('/api/v1/node_id/')['node_id']
        # Validate and normalise to the canonical UUID string form.
        node_id = str(uuid.UUID(node_id))
        hostname = args.hostname if args.hostname else self.get_hostname()
        release_version = self.get_release_version()
        public_interface = self.get_public_interface()

        register_data = {
            'host_id': host_id,
            'node_id': node_id,
            'ip_addr': ip_address,
            'hostname': hostname,
            'release_version': release_version,
        }
        if args.hostname:
            register_data['force_hostname'] = True

        if not is_local:
            LOG.info('Opening agent port in the firewall')
            roles_cfg_exec(['append', '-i', bind_dev, '-r', 'agent:17514',
                            '--outbound-allow-list', '0.0.0.0:any:0'])

        LOG.info(
            'Registering node %r with IP address %r and hostname %r',
            node_id, ip_address, hostname
        )
        resp_data = self.http_backend.post(
            '/api/v3/nodes/registration/register/',
            data=register_data, token=args.token
        )

        # FIXME: Hostname assignment is VHI behavior.
        # We should reconsider it for VHP registration:
        LOG.info('Assigning hostname %r', resp_data['hostname'])
        self.set_hostname(hostname, resp_data['hostname'])

        LOG.info('Saving certificates')
        certificates = decrypt(args.token, resp_data['certificates'])
        self.save_certificates(certificates)

        LOG.info('Configuring Nginx')
        self.configure_nginx()

        update_data = {
            'node_id': node_id,
            'management_iface': bind_dev,
        }
        if public_interface:
            update_data['webcp_iface'] = public_interface

        LOG.info('Requesting node %r configuration update', node_id)
        self.http_backend.put(
            '/api/v3/nodes/registration/register/',
            data=update_data, token=args.token
        )
        LOG.info('Node %r has been registered', node_id)

    def save_certificates(self, certificates):
        """Split the decrypted certificate bundle and install its parts.

        Each item is the text between the first two occurrences of its
        marker (e.g. ``CA.CRT ... CA.CRT``). The CA cert and the agent
        cert/key go to /etc/nginx/ssl; the IPsec cert/key and the CA cert
        are pushed to the local agent API.

        Raises:
            RegisterError: if a marker block is missing.
        """
        certs = {}
        mapping = {
            'CA.CRT': 'ca.crt',
            'AGENT.CRT': 'server.crt',
            'AGENT.KEY': 'server.key',
            'IPSEC.CRT': 'ipsec.crt',
            'IPSEC.KEY': 'ipsec.key'
        }
        for marker, cert in mapping.items():
            parts = certificates.split(marker)
            if len(parts) < 3:
                msg = f'Cannot find {cert} certificate block in the response.'
                raise RegisterError(msg)
            certs[cert] = f'{parts[1].strip()}\n'

        os.makedirs('/etc/nginx/ssl', exist_ok=True)
        for cert in ('ca.crt', 'server.crt', 'server.key'):
            # The private key must not be readable by other users.
            mode = 0o600 if cert.endswith('.key') else None
            self.write_file(os.path.join('/etc/nginx/ssl', cert),
                            certs[cert], mode=mode)

        cacert_data = {'cert': certs['ca.crt']}
        cert_data = {'key': certs['ipsec.key'],
                     'cert': certs['ipsec.crt']}
        self.uwsgi_agent.put('/api/v1/ipsec/cert/', data=cert_data)
        self.uwsgi_agent.put('/api/v1/ipsec/cacert/', data=cacert_data)

    def configure_nginx(self):
        """Switch the agent's Nginx site to TLS and restart Nginx.

        Every ``listen <port> ...;`` directive becomes
        ``listen <port> ssl;``, commented-out ``#ssl...;`` directives are
        uncommented, and trailing whitespace is stripped from each line.
        """
        path = '/etc/nginx/conf.d/vstorage-ui-agent.conf'
        listen_re = re.compile(r'([ \t]*)listen[ \t]+(\d+).*;')
        ssl_re = re.compile(r'([ \t]*)#(ssl.*;)')

        new_lines = []
        for line in self.read_file(path).splitlines():
            line = listen_re.sub(r'\1listen \2 ssl;', line)
            line = ssl_re.sub(r'\1\2', line)
            new_lines.append(line.rstrip())

        # Keep the conventional trailing newline of a text file.
        self.write_file(path, '\n'.join(new_lines) + '\n')
        getoutput(['systemctl', 'restart', 'nginx.service'])

    def get_hostname(self):
        """Return the hostname to register with.

        If HOSTNAME_JSON records an assigned hostname that differs from the
        current one, the user changed it, so the current hostname is used.
        Otherwise the recorded original hostname (if any) is preferred.
        """
        current = socket.gethostname()

        data = {}
        if os.path.exists(HOSTNAME_JSON):
            data = json.loads(self.read_file(HOSTNAME_JSON))
        original = data.get("original")
        assigned = data.get("assigned")

        if assigned and current != assigned:
            LOG.info('Hostname change was detected')
            return current
        return original if original else current

    def set_hostname(self, orig_hostname, new_hostname):
        """Record original/assigned hostnames and apply the assigned one.

        *new_hostname* comes from the management node's response, so it is
        passed to ``hostnamectl`` as a separate argv element instead of
        being interpolated into a shell command line.
        """
        data = {
            "original": orig_hostname,
            "assigned": new_hostname
        }
        self.write_file(HOSTNAME_JSON, json.dumps(data, indent=2) + '\n')
        getoutput(
            ['hostnamectl', 'set-hostname', new_hostname,
             '--static', '--transient']
        )

    def get_release_version(self):
        """Return ``"<X.Y.Z>-<build>"`` parsed from /etc/hci-release.

        The file is searched for ``X.Y.Z (BUILD)``; each component may have
        any number of digits.

        Raises:
            RegisterError: if no such version string is found.
        """
        data = self.read_file('/etc/hci-release')
        m = re.search(r'([0-9]+\.[0-9]+\.[0-9]+) \(([0-9]+)\)', data)
        if not m:
            raise RegisterError('Cannot determine release version.')
        return '-'.join(m.groups())

    def get_public_interface(self):
        """Return the interface to use for the Admin panel, or None.

        An explicit ``--public-interface`` wins and must exist. Otherwise
        the ``WEB_CP_IF=`` value from BACKEND_ROLES_INIT is used when the
        file exists and the named interface exists.

        Raises:
            RegisterError: if ``--public-interface`` names a missing link.
        """
        if self.args.public_interface:
            self.validate_ifname(self.args.public_interface)
            return self.args.public_interface

        # The backend configures the public interface itself.
        # Node registration should respect this setting.
        if not os.path.exists(BACKEND_ROLES_INIT):
            return None
        for line in self.read_file(BACKEND_ROLES_INIT).splitlines():
            if not line.startswith('WEB_CP_IF='):
                continue
            ifname = line.partition('=')[2].strip()
            if ifname and self._link_exists(ifname):
                LOG.info(
                    'Using %r from %s as public interface',
                    line, BACKEND_ROLES_INIT
                )
                return ifname
        return None

    def get_route_info(self, dst_ip):
        """Parse ``ip route get <dst_ip>`` output.

        Returns:
            ``(route, is_local)``: *route* may hold ``'dev'`` and ``'src'``.
            When the route goes through ``lo`` (dst_ip is a local address),
            *is_local* is True and ``'dev'`` is the interface owning dst_ip.

        Raises:
            RegisterError: if ``ip`` fails or no interface owns dst_ip.
        """
        route = {}
        is_local = False
        parts = getoutput(['ip', 'route', 'get', dst_ip]).split()
        for key, value in zip(parts, parts[1:]):
            if key == 'dev':
                if value == 'lo':
                    is_local = True
                    value = self.get_iface_by_addr(dst_ip)
                route['dev'] = value
            elif key == 'src':
                route['src'] = value
        return route, is_local

    @staticmethod
    def get_iface_by_addr(ip_addr):
        """Return the name of the interface that has *ip_addr* assigned.

        Raises:
            RegisterError: if no interface has that address.
        """
        with pyroute2.IPRoute() as ipr:
            for addr in ipr.get_addr():
                attrs = dict(addr['attrs'])
                if attrs.get('IFA_ADDRESS') != ip_addr:
                    continue
                link = ipr.get_links(addr['index'])[0]
                # Look the name up by attribute type, not list position.
                return link.get_attr('IFLA_IFNAME')
        raise RegisterError(
            f'Interface with IP address "{ip_addr}" not found.'
        )

    @staticmethod
    def read_file(filename):
        """Return the text content of *filename*."""
        with open(filename) as fd:
            return fd.read()

    @staticmethod
    def write_file(filename, data, mode=None):
        """Write text *data* to *filename*, replacing any previous content.

        Args:
            mode: Optional permission bits (e.g. ``0o600``). When given, the
                file is created with them and an existing file is chmod-ed
                to them before any data is written. ``None`` keeps the
                default ``open()`` behaviour.
        """
        opener = None
        if mode is not None:
            def opener(path, flags):
                return os.open(path, flags, mode)
        with open(filename, 'w', opener=opener) as fd:
            if mode is not None:
                # O_CREAT's mode only applies to new files.
                os.fchmod(fd.fileno(), mode)
            fd.write(data)

    @staticmethod
    def _link_exists(ifname):
        """Return True if a network link named *ifname* exists."""
        with pyroute2.IPRoute() as ipr:
            return bool(ipr.link_lookup(ifname=ifname))

    @staticmethod
    def validate_ifname(ifname):
        """Raise RegisterError unless a network link named *ifname* exists."""
        if not NodeRegistrator._link_exists(ifname):
            raise RegisterError(f'Interface "{ifname}" not found.')


def main():
    args = parse_args()

    level = (logging.DEBUG if args.verbose else
             logging.ERROR if args.quiet else
             logging.INFO)

    logging.basicConfig(level=level, filename=log_filename, format=log_format)
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setLevel(level)
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    LOG.addHandler(console_handler)

    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    LOG.info('Running command: %s', ' '.join(sys.argv))

    try:
        mn_addr = socket.gethostbyname(args.mn_host)
    except socket.error as err:
        LOG.error('Cannot use %r as management host: %s', args.mn_host, err)
        sys.exit(1)

    mn_port = os.getenv("MN_PORT", 8888)
    try:
        mn_port = int(mn_port)
    except ValueError:
        LOG.error('Cannot use %r as MN_PORT', mn_port)
        sys.exit(1)

    registrator = NodeRegistrator(mn_addr, mn_port, args)
    try:
        registrator.register_node()
    except Exception as err:
        LOG.error("Cannot register node: %s", err)
        if args.verbose:
            raise
        sys.exit(1)


if __name__ == '__main__':
    main()
