#!/usr/bin/python

import os
import sys
import parted
import subprocess
import time
import tempfile
import tarfile
import shutil
import glob
import re
from optparse import OptionParser
from netaddr import IPNetwork, IPAddress
import prlsdkapi
from signal import signal, SIGPIPE, SIG_DFL 

# Name used for both the LVM volume group and the logical volume.
lv_vz_name = "vz"
# Mount point of the datastore ("/vz").
part_name = "/%s" % lv_vz_name
# Device node of the logical volume ("/dev/vz/vz").
lv_vz_devname = "/dev/%s/%s" % (lv_vz_name, lv_vz_name)
# (mount point, device) pairs collected while partitioning disks.
created_parts = []
# Names of services launched through start_service(); they are stopped again
# before the systemd units are started.
started_services = []
# Scratch directory created as soon as the script starts; used as a temporary
# mount point and removed by cleanup().
vztmp = tempfile.mkdtemp(prefix="/tmp%s_mnt_" % part_name)
progname = os.path.basename(sys.argv[0])
# Directory that holds per-cluster MDS/CS/client directories.
vstorage_def_path = "/vstorage"
vstorage_bin = "/usr/bin/vstorage"
shaman_bin = "/usr/sbin/shaman"
vstorage_service = "/usr/libexec/vstorage/vstorage-service"
# Sub-directories created on the cluster and symlinked under part_name.
private_parts = ["private", "vmprivate"]
# Template for the per-step log files that receive stderr of vstorage calls.
log_prefix = "/var/log/vstorage/storage-cfg_%s.log"
vstorage_auth_log = log_prefix % "auth"
vstorage_mds_log = log_prefix % "mds"
vstorage_cs_log = log_prefix % "cs"
vstorage_mount_log = log_prefix % "mount"
# Value passed to "vstorage set-attr ... replicas=" after cluster creation.
vstorage_replicas="1:1"
# Port used in bs.list entries and in the make-mds "-a IP:PORT" argument.
mds_port = 2510
# Devices smaller than this (compared against the size in GiB) are skipped.
minimal_cs_size = 100 # 100 GiB
# OID subtree written into the snmpd rwcommunity entries.
snmp_mib_tree = ".1.3.6.1.4.1.26171"
# systemd units stopped before and restarted after the reconfiguration;
# they are stopped in the listed order and started in reverse order.
ha_services = ("shaman", "pdrs")
vz_services = ("prl-disp", "libvirtd", "pfcached", "pfcached-mount", "vz", "vcmmd", "vstorage-fs")
systemctl = "/usr/bin/systemctl"
unitdir = "/usr/lib/systemd/system"
# Base name of the generated .service/.timer units and of the /usr/sbin tool
# they execute.
update_mds_list = "update_mds_list"
# Used for both OnActiveSec and OnUnitActiveSec of the generated timer.
update_mds_list_interval = "1h"

def wait_for_dev(part, num_try, opts):
    """Wait until the path *part* exists, polling once per second.

    part    -- filesystem path of the device node to wait for
    num_try -- number of polls already performed (callers pass 0)
    opts    -- parsed options; opts.timeout is the maximum number of polls

    Calls error() (which terminates the script) once num_try exceeds
    opts.timeout.  Implemented as a loop rather than recursion so that long
    timeouts cannot exhaust the interpreter's recursion limit.
    """
    while True:
        if num_try > opts.timeout:
            error("Timeout waiting for %s in %i sec" % (part, opts.timeout))
        if os.path.exists(part):
            return
        time.sleep(1)
        num_try += 1

def re_read(drive, num_try, opts):
    """Ask the kernel to re-read the partition table of *drive*.

    Runs "blockdev --flushbufs --rereadpt" and retries once per second while
    the tool reports that the device is busy.  Any other blockdev failure is
    only printed; the function then returns without retrying.

    drive   -- path of the whole-disk device
    num_try -- number of attempts already made (callers pass 0)
    opts    -- parsed options; opts.timeout is the maximum number of attempts

    Calls error() (which terminates the script) once num_try exceeds
    opts.timeout.  Written as a loop instead of recursion so that long
    timeouts cannot hit the recursion limit.
    """
    while True:
        if num_try > opts.timeout:
            error("Failed to re-read partition table for %s in %i sec" % (drive, opts.timeout))
        p = subprocess.Popen(["/usr/sbin/blockdev", "--flushbufs", "--rereadpt", drive], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        if p.returncode == 0:
            return
        print("Utility blockdev failed: exited with code %d" % p.returncode)
        print("blockdev stdout: %s" % stdout)
        print("blockdev stderr: %s" % stderr)
        # Only a busy drive is worth retrying; give up on anything else.
        if "BLKRRPART: Device or resource busy" not in stderr:
            return
        time.sleep(1)
        num_try += 1

def get_output(cmd, ignoreerr = False, stdout = subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE):
    """Run *cmd* through the shell and return (captured stdout, exit status).

    cmd       -- command line, interpreted by the shell
    ignoreerr -- when false, a non-zero exit status prints the captured
                 output and raises ValueError
    stdout, stdin, stderr -- passed straight to subprocess.Popen; streams
                 redirected elsewhere are not captured
    """
    child = subprocess.Popen(cmd,
                             shell=True,
                             stdin=stdin,
                             stdout=stdout,
                             stderr=stderr)
    captured, errors = child.communicate()
    status = child.returncode
    if status and not ignoreerr:
        print("Output:\n%s" % captured)
        print("Errors:\n%s" % errors)
        raise ValueError('`%s\'returned %i:\n %s\n' % (cmd, status, errors))
    return (captured, status)

def stop_vz_services(services, ignoreerr = False):
    """Stop every unit named in *services*, in the order given.

    ignoreerr is forwarded to get_output(); when false a failing
    "systemctl stop" raises ValueError.
    """
    for unit in services:
        print("Stopping %s..." % unit)
        command = "%s stop %s.service" % (systemctl, unit)
        get_output(command, ignoreerr)

def start_vz_services(services, ignoreerr = False):
    """Start every unit named in *services*, walking the sequence backwards.

    ignoreerr is forwarded to get_output(); when false a failing
    "systemctl start" raises ValueError.
    """
    for unit in reversed(services):
        print("Starting %s..." % unit)
        command = "%s start %s.service" % (systemctl, unit)
        get_output(command, ignoreerr)

def cleanup():
    """Remove temporary artefacts left by this script.

    Deletes /tmp/prepare_datastore if present, unmounts the scratch
    directory vztmp and removes it.  The directory is left untouched when it
    (or one of its direct children) is still a mount point: the script
    bind-mounts "/" onto vztmp, so deleting it recursively after a failed
    umount would erase the mounted file system.
    """
    try:
        os.unlink("/tmp/prepare_datastore")
    except OSError:
        # The marker file is optional.
        pass
    get_output("/bin/umount %s" % vztmp, ignoreerr = True)
    if not os.path.isdir(vztmp):
        return
    candidates = [vztmp] + [os.path.join(vztmp, name) for name in os.listdir(vztmp)]
    mounted = [path for path in candidates if os.path.ismount(path)]
    if mounted:
        print("Warning: %s still mounted, not removing %s" % (", ".join(mounted), vztmp))
        return
    shutil.rmtree(vztmp)

def error(msg):
    """Report *msg* on stdout, run cleanup() and exit with status 1."""
    print(msg)
    cleanup()
    raise SystemExit(1)

def stop_service(service):
    """Kill the processes named *service* with killall.

    Returns True when killall succeeded, False (after printing a note)
    when it exited with a non-zero status.
    """
    print("Stopping %s service..." % service)
    _, status = get_output("/usr/bin/killall %s" % service, ignoreerr=True)
    if not status:
        return True
    print("Failed to stop %s service: %i . Looks like already stopped" % (service, status))
    return False

def start_service(service, opts, started_services, args):
    """Launch *service* for cluster opts.cluster via the vstorage-service helper.

    service          -- service name passed to the helper (e.g. "mdsd", "csd")
    opts             -- parsed options; only opts.cluster is read
    started_services -- list that receives *service* on success
    args             -- extra command line arguments (list of strings)

    Returns True on success.  On failure get_output() raises ValueError
    (after printing the command output); that is turned into error(), so the
    script cleans up and exits instead of dying with a traceback.
    """
    print("Starting %s service..." % service)
    command = "%s start %s %s %s" % (vstorage_service, opts.cluster, service, ' '.join(args))
    try:
        get_output(command)
    except ValueError as exc:
        error("Failed to start %s service: %s" % (service, exc))
    started_services.append(service)
    return True

def create_vstorage_input(pw):
    """Return the read end of a pipe pre-filled with *pw* and a newline.

    The write end is always closed before returning, so a reader sees EOF
    right after the password line.  The caller owns the returned file
    descriptor and must close it.  If the write fails, both ends are closed
    and the exception is re-raised, so no descriptor is leaked.
    """
    read_fd, write_fd = os.pipe()
    try:
        os.write(write_fd, pw + b'\n')
    except Exception:
        os.close(read_fd)
        raise
    finally:
        os.close(write_fd)
    return read_fd

def vstorage_auth_node(opts):
    """Authenticate this node in cluster opts.cluster, retrying once a minute.

    The password (opts.password) is fed to "vstorage auth-node" on stdin and
    the tool's stderr goes to vstorage_auth_log, which is rewritten on every
    attempt.  opts.timeout (seconds) is the retry budget and is reduced by 60
    after each failed attempt.

    Returns 0 on success.  A wrong password (exit code 253) or an exhausted
    budget ends in error(), which terminates the script.
    """
    rc = None   # stays None if the budget is already exhausted on entry
    while opts.timeout > 0:
        input_data = create_vstorage_input(opts.password)
        auth_start_time = time.time()
        log = open(vstorage_auth_log, "w")
        try:
            rc = get_output("%s -vvvvvvvv -t 30 -c %s auth-node %s -P" % (vstorage_bin, opts.cluster, opts.join_mds_ip_p), stdin=input_data, ignoreerr=True, stderr=log)[1]
            overal_auth_time = time.time() - auth_start_time
            log.write("\nvstorage auth-node exit code is %i, working time: %.2f sec.\n" % (rc, overal_auth_time))
        finally:
            log.close()
            os.close(input_data)

        # code 253 - wrong password.  Also stop right after the last
        # attempt instead of sleeping out a minute nobody will use.
        if rc == 0 or rc == 253 or opts.timeout <= 60:
            break

        if overal_auth_time > 60:
            print("Time exceeded execution of command vstorage auth-node: %.2f sec." % overal_auth_time)
        else:
            time.sleep(60 - overal_auth_time)
        opts.timeout -= 60

    if rc == 0:
        print("The server has been successfully authenticated.")
    elif rc == 253:
        error("Failed to authenticate the server in the cluster: wrong password")
    else:
        error("Authentication failed: Could not connect to the cluster! "
              "see vstorage error(s) in program log")

    return rc

def get_host_ip(subnet):
    """Pick the IP address of this host that should be used for the cluster.

    subnet -- network in a form accepted by netaddr.IPNetwork, or an empty
              value

    With a subnet: return the first non-local address inside it; an
    unparsable subnet or no matching address ends in error().
    Without a subnet: return the first IPv4 address configured on br0.
    Returns "" when the host has no usable address or br0 has none.
    """
    network = None
    if subnet:
        # Validate the subnet once, before looking at the interfaces.
        try:
            network = IPNetwork(subnet)
        except Exception as e:
            error("Failed to parse given cluster_network: %s" % e)
    ipl = get_nonlocal_ip_list()
    if not ipl:
        return ""
    if subnet:
        try:
            return [ip for ip in ipl if IPAddress(ip) in network][0]
        except Exception:
            error("No any IP (%s) detected in subnet %s" % (" ".join(ipl), subnet))
    # A missing br0 makes "ip" fail; treat that as "no address" so the
    # caller can report it, rather than raising from get_output().
    br0_info = get_output("/sbin/ip -4 a s br0", ignoreerr=True)[0]
    for line in br0_info.splitlines():
        m = re.match(r"^.*inet ([0-9\.]+)/.*$", line, re.I)
        if m:
            return m.group(1)
    return ""

def _usage_error(parser, msg):
    """Print *msg* followed by the usage line and exit with status 1."""
    print(msg)
    parser.print_usage()
    sys.exit(1)

def parseArgs(args):
    """Parse and validate the command line.

    args -- list of command line arguments (without the program name)

    Returns the optparse values object with two adjustments: timeout is
    converted from minutes to seconds, and in vstorage mode join_mds_ip_p
    holds the ready-made "-b IP" argument (or "").  Any invalid combination
    prints a message plus the usage line and exits with status 1.
    """
    usage = """
    %s --storage_mode=(local|vstorage) [--mds=(create|join)] [--cluster_network=IP_NETWORK] [--join_mds_ip=IP] [--cs] [--client] [--cluster=CLUSTER_NAME] [--ha] [--timeout=TIMEOUT] [--password=PASSWORD] [--onboot] [--update_mds_ips]""" % progname

    parser = OptionParser(version = "%s 0.9.5.2" % progname, usage=usage)

    # query options
    parser.add_option("-s", "--storage_mode", default="", action="store",
                      help="Storage creation mode (local|vstorage).")
    parser.add_option("-m", "--mds", default="", action="store",
                      help="MDS creation mode (create|join).")
    parser.add_option("-c", "--cs", default=False, action="store_true",
                      help="CS creation (on all available HDDs).")
    parser.add_option("-l", "--client", default=False, action="store_true",
                      help="Client creation.")
    parser.add_option("-r", "--cluster", default="", action="store",
                      help="Cluster name.")
    parser.add_option("-a", "--ha", default=False, action="store_true",
                      help="Enable High Availability.")
    parser.add_option("-t", "--timeout", default="1", action="store",
                      help="Timeout in minutes. Default is 1 minute.")
    parser.add_option("-p", "--password", default="", action="store",
                      help="Cluster password.")
    parser.add_option("-j", "--join_mds_ip", default="", action="store",
                      help="Master MDS ip.")
    parser.add_option("-n", "--cluster_network", default="", action="store",
                      help="Network for Cluster MDS IP to be selected.")
    parser.add_option("-o", "--onboot", default=False, action="store_true",
                      help="As a service during boot mode.")
    parser.add_option("-u", "--update_mds_ips", default=False, action="store_true",
                      help="Update MDS ips using service timer.")
    # Parse the list we were given rather than silently re-reading sys.argv.
    (opts, argsleft) = parser.parse_args(args)

    try:
        opts.timeout = int(opts.timeout) * 60
    except ValueError:
        _usage_error(parser, "Timeout should be an integer number of minutes.")
    if opts.timeout < 60:
        _usage_error(parser, "Timeout should not be set for less than a minute value.")

    if opts.storage_mode not in ("local", "vstorage"):
        _usage_error(parser, "You should select storage mode: local or vstorage")

    if opts.storage_mode == "local":
        if opts.mds or opts.cs or opts.client or opts.ha or opts.cluster:
            _usage_error(parser, "Unsupported for local mode.")
        return opts

    if not opts.cluster:
        _usage_error(parser, "Cluster name not specified!")

    if not opts.client and not opts.mds and not opts.cs:
        _usage_error(parser, "You should create MDS, CS or Client for Vstorage..")

    if opts.mds and opts.mds not in ("create", "join"):
        _usage_error(parser, "MDS can be only joined or created.")

    if opts.mds and not opts.password:
        _usage_error(parser, "Password required for MDS.")

    if not opts.client and opts.ha:
        _usage_error(parser, "High availability can't be used without client mode.")

    if opts.ha and not opts.cluster_network:
        _usage_error(parser, "HA availability configuration requires cluster network.")

    opts.join_mds_ip_p = ""
    if opts.join_mds_ip:
        # Ignore join IP option for creation case
        if opts.mds == "create":
            opts.join_mds_ip = ""
        else:
            opts.join_mds_ip_p = "-b %s" % opts.join_mds_ip

    return opts

def add_fstab(line, path='/etc/fstab'):
    """Append *line* to the fstab file, starting on a fresh line.

    line -- text to append (callers include the trailing newline)
    path -- file to modify; defaults to /etc/fstab

    If the file's last byte is not a newline, one is written first so the
    new entry never ends up glued to the previous one.  A missing or empty
    file is handled: the entry simply becomes its first line.
    """
    needs_newline = False
    try:
        with open(path, 'rb') as f:
            f.seek(0, os.SEEK_END)
            if f.tell() > 0:
                f.seek(-1, os.SEEK_END)
                needs_newline = f.read(1) != b'\n'
    except IOError:
        # No readable file yet; opening for append below creates it.
        pass
    with open(path, 'a') as f:
        if needs_newline:
            f.write('\n')
        f.write(line)

def getIPs():
    """Return every IPv4 and IPv6 address reported by "ip a l".

    Addresses are returned as strings without the prefix length, in the
    order they appear in the command output.
    """
    ipv4_re = re.compile(r"^.*inet ([0-9\.]+)/.*$", re.I)
    ipv6_re = re.compile(r"^.*inet6 ([a-f0-9:]+)/.*$", re.I)
    out = get_output("/sbin/ip a l")[0]
    ips = []
    for line in out.splitlines():
        for pattern in (ipv4_re, ipv6_re):
            m = pattern.match(line)
            if m:
                ips.append(m.group(1))
    return ips

def get_nonlocal_ip_list(ipv4_only = True):
    """Return the host's addresses without loopback and link-local ones.

    ipv4_only -- when true, addresses that contain ':' and no '.' (plain
                 IPv6 notation) are dropped as well

    Prints a hint when the resulting list is empty.
    """
    ip_list = []

    for ip in getIPs():
        if ip in ("127.0.0.1", "::1") or ip.startswith("fe80"):
            continue
        # "':' in ip" also catches addresses that begin with a colon,
        # which a "find(':') > 0" test would let through.
        if ipv4_only and ':' in ip and '.' not in ip:
            continue
        ip_list.append(ip)

    if not ip_list:
        print("No network interface card on this server has "
              "a valid IP address assigned.\nMake sure the "
              "network for this server is configured properly.")

    return ip_list


def get_vstorage_client_ip(opts):
    """Retrieve ip address of client from 'vstorage stat' output.
       Here we assume that this address can be used to contact cluster.
       Return ip address in string format or None"""
    # stderr of the stat call is discarded; the handle is closed even when
    # get_output() raises.
    with open("/dev/null", "w") as devnull:
        stat = get_output("%s -c %s stat -X" % (vstorage_bin, opts.cluster), stderr=devnull)[0]

    # Take everything up to the next tag.  (A character class listing the
    # letters of "host" would cut values off at the first h/o/s/t.)
    cluster_ips = set(re.findall(r'<host>([^<]*)', stat))
    node_ips = set(get_nonlocal_ip_list())

    client_ip = node_ips & cluster_ips
    if not client_ip:
        return None

    return client_ip.pop()

def configure_snmpd(opts):
    """Append two rwcommunity entries to /etc/snmp/snmpd.local.conf.

    One entry is for 127.0.0.1 and one for opts.cluster_network, both
    limited to snmp_mib_tree.  An IOError is reported on stdout and
    otherwise ignored.
    """
    template = "rwcommunity\tparallels\t%s\t%s"
    try:
        with open("/etc/snmp/snmpd.local.conf", "a") as conf:
            # grant read/write access for localhost
            conf.write(template % ("127.0.0.1", snmp_mib_tree) + "\n")
            # grant read/write access for cluster subnet
            conf.write(template % (opts.cluster_network, snmp_mib_tree) + "\n")
    except IOError:
        print("Failed to edit configuration file of the SNMP daemon")

def is_vmct_running():
    """Tell whether any virtual machine or container is in the running state.

    Logs in to the local dispatcher through prlsdkapi, lists VMs and
    containers and checks their state.  Every SDK call waits at most 60
    seconds.  Any SDK failure is treated as "nothing is running" and yields
    False.  Only Exception subclasses are swallowed, so Ctrl-C or
    sys.exit() during the wait still abort the script instead of being
    mistaken for "no VMs running".
    """
    timeout = 60 * 1000

    try:
        # Connect to dispatcher
        prlsdkapi.init_server_sdk()
        server = prlsdkapi.Server()
        server.login_local().wait(msecs = timeout)

        # Get VMs list
        vms = server.get_vm_list_ex(prlsdkapi.consts.PVTF_VM + prlsdkapi.consts.PVTF_CT)

        # Keep only the running ones
        running = [ r for r in vms.wait(msecs = timeout) if r.get_vm_info().get_state() == prlsdkapi.consts.VMS_RUNNING ]

        try:
            # Disconnect; failures here do not affect the result
            server.logoff().wait(msecs = timeout)
            prlsdkapi.deinit_sdk()
        except Exception:
            pass
    except Exception:
        # Something went wrong, assume that no VMs/CTs are running
        return False

    return len(running) > 0

def mkdir_p(path):
    """Create directory *path* together with any missing parent directories.

    Does nothing when *path* already is a directory.  Each '/'-separated
    prefix of *path* that is non-empty and not yet a directory is created
    with os.mkdir, shortest prefix first.
    """
    if os.path.isdir(path):
        return
    seen = []
    for component in path.split('/'):
        seen.append(component)
        prefix = '/'.join(seen)
        if prefix and not os.path.isdir(prefix):
            os.mkdir(prefix)

# Parse and validate the command line; exits with a usage message on error.
opts = parseArgs(sys.argv[1:])

# Determine the address of this host (inside opts.cluster_network when one
# was given).  It is later passed to make-mds as the MDS address.
mds_ip = get_host_ip(opts.cluster_network)
if not mds_ip:
    error("Could not determine host IP!")

# Refuse to continue while any VM or container is running.
if is_vmct_running():
    error("Some VM(s)/Container(s) are running. Please stop them.")

if opts.storage_mode == "vstorage":
    # The vstorage root and the private directories must not exist yet.
    if os.path.lexists(vstorage_def_path):
        error("%s default directory already exists" % vstorage_def_path)
    for p in private_parts:
        chk_dir = "%s/%s" % (part_name, p)
        # An empty directory is silently removed; anything else (non-empty
        # directory, file, symlink) survives rmdir and is reported below.
        try:
            os.rmdir(chk_dir)
        except OSError:
            pass
        if os.path.lexists(chk_dir):
            error("%s dir already exists" % chk_dir)
    for required in (vstorage_bin, vstorage_service):
        if not os.path.exists(required):
            error("%s doesn't exists" % required)
    os.mkdir(vstorage_def_path)
else:
    # Local mode: part_name must still live on the root file system,
    # otherwise a datastore partition has already been set up.
    if os.stat('/').st_dev != os.stat(part_name).st_dev:
        error("%s already partitioned" % part_name)

# When run manually (not from the boot-time service) the virtualization
# services are stopped for the duration of the reconfiguration.
if not opts.onboot:
    stop_vz_services(ha_services, ignoreerr = True)
    stop_vz_services(vz_services)

# Partition every unused, large enough disk and turn the result into either
# one LVM-backed datastore (local mode) or one chunk-server directory per
# disk (vstorage mode with --cs).
if opts.storage_mode == "local" or opts.cs:
    csnum = 1
    for dev in parted.getAllDevices():
        if dev.sectorSize == 2048:
            print("Looks like %s is a CDROM, skipping..." % dev.path)
            continue
        if dev.busy:
            continue
        if dev.getLength(unit='GiB') < minimal_cs_size:
            print("Skipped %s - too small..." % dev.path)
            continue
        print("Going to attach %s to datastore..." % dev.path)
        # create GPT label
        disk = parted.freshDisk(dev, "gpt")
        constraint = dev.optimalAlignedConstraint
        # create main partition
        geometry = parted.Geometry(device=dev, start=64, end=(constraint.maxSize - 1))
        filesystem = parted.FileSystem(type="ext4", geometry=geometry)
        partition = parted.Partition(disk=disk, fs=filesystem, type=parted.PARTITION_NORMAL, geometry=geometry)
        # Make it as lvm
        if opts.storage_mode == "local":
            partition.setFlag(parted.PARTITION_LVM)
        disk.addPartition(partition=partition, constraint=constraint)
        # Apply
        disk.commit()
        # Device names that end in a digit (nvme0n1, mmcblk0, ...) get a
        # "p" separator before the partition number.
        part = dev.path + ("p1" if dev.path[-1].isdigit() else "1")
        # Make the kernel pick up the new table
        re_read(dev.path, 0, opts)
        # Wait for part device
        wait_for_dev(part, 0, opts)
        if opts.storage_mode == "local":
            # Wipe the start of the partition, then create the PV
            get_output("/usr/bin/dd if=/dev/zero of=%s bs=16K count=1" % part)
            get_output("/usr/sbin/pvcreate %s" % part)
            created_parts.append((part_name, part))
        else:
            created_parts.append(("%s/%s-cs%i" % (vstorage_def_path, opts.cluster, csnum), part))
            csnum += 1

    if not created_parts:
        # error() also removes the temporary directory before exiting
        error("No empty block devices found!..")

    if opts.storage_mode == "local":
        print("Creating volume group %s ..." % lv_vz_name)
        # Create vg out of all prepared partitions
        get_output("/usr/sbin/vgcreate %s %s" % (lv_vz_name, ' '.join([p[1] for p in created_parts])))

        # Create lv that takes all free space of the vg
        vg_free = int(get_output("/usr/sbin/vgs --noheadings -ovg_free --nosuffix --units b %s" % lv_vz_name)[0])
        get_output("/usr/sbin/lvcreate -L %ib --name %s %s" % (vg_free, lv_vz_name, lv_vz_name))

        # From here on only the logical volume is formatted and mounted
        created_parts = [(part_name, lv_vz_devname)]

    for part, dev in created_parts:
        print("Formatting %s to ext4 FS..." % dev)
        # Format
        get_output("/usr/sbin/mkfs.ext4 %s" % dev)

        print("Attaching %s to the system..." % part)

        # Get UUID
        uuid = get_output("/usr/sbin/blkid -o value -s UUID %s" % dev)[0].rstrip('\n')

        # Write to fstab
        add_fstab("UUID=%s %s                     ext4    defaults,noatime,lazytime 1 2\n" % (uuid, part))

        # Create part dir
        if not os.path.exists(part):
            os.mkdir(part)

        # Mount (options come from the fstab entry written above)
        get_output("/bin/mount %s" % part)
else:
    print("Skipped datastore creation: vstorage without CS")

# Re-copy VZ
if opts.storage_mode == "local":
    # Bind-mount "/" so the old content of part_name (now hidden below the
    # freshly mounted datastore) is reachable under vztmp.
    get_output("/bin/mount -o bind / %s" % vztmp)

    print("Copying old %s data..." % part_name)

    # Use default sighandler here
    signal(SIGPIPE, SIG_DFL)

    # The parent packs the old data into a tar stream; a forked child
    # unpacks that stream into the new datastore.
    rfd, wfd = os.pipe()
    r = os.fdopen(rfd, 'r')
    w = os.fdopen(wfd, 'w')

    # os.fork() raises OSError on failure, so only 0 (child) and a positive
    # pid (parent) can be returned here.
    pid = os.fork()
    if pid == 0:
        child_status = 1
        try:
            w.close()
            rtar = tarfile.open(mode='r|', fileobj = r)
            rtar.extractall(path=part_name)
            rtar.close()
            r.close()
            child_status = 0
        except Exception as e:
            sys.stderr.write("Failed to extract data to %s: %s\n" % (part_name, e))
        finally:
            # Never let the child fall through into the parent's code.
            os._exit(child_status)

    r.close()
    wtar = tarfile.open(mode='w|', fileobj = w)
    curdir = os.getcwd()
    os.chdir("%s%s/" % (vztmp, part_name))
    # NOTE(review): glob("*") does not match dot-files - confirm that
    # hidden entries in the old directory may be left behind.
    for f in glob.glob("*"):
        wtar.add(name=f)
    os.chdir(curdir)
    wtar.close()
    w.close()
    if os.waitpid(pid, 0)[1] != 0:
        error("Failed to copy old %s data" % part_name)
else:
    # Create bs.list
    bs_dir = "/etc/vstorage/clusters/%s" % opts.cluster
    bs_list = "%s/bs.list" % bs_dir
    if not os.path.exists(bs_list) and opts.join_mds_ip:
        mkdir_p(bs_dir)
        with open(bs_list, "w") as f:
            f.write("%s:%i\n" % (opts.join_mds_ip, mds_port))
    # MDS
    if not opts.mds == "create":
        print("Authenticating on %s cluster..." % opts.cluster)
        vstorage_auth_node(opts)
    if opts.mds:
        init_arg = ""
        pin_arg = ""
        mds_stdin_fd = None
        if opts.mds == "create":
            # The new cluster's password is read from stdin by make-mds
            mds_stdin_fd = create_vstorage_input(opts.password)
            mds_stdin = mds_stdin_fd
            init_arg = "--init-cluster"
            pin_arg = "-P"
            print("Creating MDS %s..." % opts.cluster)
        else:
            print("Adding MDS for %s cluster..." % opts.cluster)
            mds_stdin = subprocess.PIPE
        mds_path = vstorage_def_path + "/%s-mds" % opts.cluster
        os.mkdir(mds_path)
        mds_path += "/data"
        mds_log = open(vstorage_mds_log, "w")
        try:
            # ignoreerr: the exit code is checked right below so that a
            # failure goes through error() and its cleanup
            rc = get_output("%s -vvvvvvvv -t 90 -c %s make-mds %s %s -a %s:%i -r %s" % (vstorage_bin, opts.cluster, init_arg, pin_arg, mds_ip, mds_port, mds_path), stdin=mds_stdin, stderr=mds_log, ignoreerr=True)[1]
        finally:
            mds_log.close()
            if mds_stdin_fd is not None:
                os.close(mds_stdin_fd)
        if rc:
            error("Failed to make-mds: %i" % rc)
        if opts.mds == "create":
            print("Creating private links on %s cluster..." % opts.cluster)
            start_service("mdsd", opts, started_services, [mds_path])
            # Mount the new cluster temporarily to create the private dirs
            tmpdir = "%s/vstorage_m" % vztmp
            os.mkdir(tmpdir)
            with open(vstorage_mount_log, "w") as mount_log:
                get_output("/bin/mount -t fuse.vstorage vstorage://%s %s" % (opts.cluster, tmpdir), stderr=mount_log)
            for priv in private_parts:
                os.mkdir(tmpdir + "/" + priv)
            # Append, so the umount messages do not wipe the mount messages
            with open(vstorage_mount_log, "a") as mount_log:
                get_output("/bin/umount %s" % tmpdir, stderr=mount_log)
            # Setting replicas to default
            get_output("%s -c %s set-attr -p / -R replicas=%s" % (vstorage_bin, opts.cluster, vstorage_replicas))

    if opts.update_mds_ips:
        # Generate a service + timer pair that periodically runs the
        # update_mds_list tool for this cluster.
        with open("%s/%s.service" % (unitdir, update_mds_list), 'w') as f:
            f.write("[Unit]\n"
                    "Description=Update MDS IP list service\n"
                    "\n"
                    "[Service]\n"
                    "ExecStart=/usr/sbin/%s %s\n"
                    "Type=simple\n"
                    "\n"
                    "[Install]\n"
                    "WantedBy=multi-user.target\n" % (update_mds_list, opts.cluster))
        with open("%s/%s.timer" % (unitdir, update_mds_list), 'w') as f:
            f.write("[Unit]\n"
                    "Description=Update MDS IP list timer\n"
                    "\n"
                    "[Timer]\n"
                    "OnActiveSec=%s\n"
                    "OnUnitActiveSec=%s\n"
                    "\n"
                    "[Install]\n"
                    "WantedBy=multi-user.target\n" % (update_mds_list_interval, update_mds_list_interval))

    if opts.cs:
        for part in [ p[0] for p in created_parts]:
            cspath = part + "/data"
            rc = None
            # Retry once a minute until the shared timeout budget runs out
            while opts.timeout > 0:
                print("Creating %s CS..." % cspath)
                start_time = time.time()
                cs_log = open(vstorage_cs_log, "w")
                try:
                    # ignoreerr: a failed attempt must reach the retry
                    # logic below instead of raising
                    rc = get_output("%s -vvvvvvvv -t 90 -c %s make-cs -r %s %s" % (vstorage_bin, opts.cluster, cspath, opts.join_mds_ip_p), stderr=cs_log, ignoreerr=True)[1]
                finally:
                    cs_log.close()
                overal_time = time.time() - start_time
                if not rc:
                    start_service("csd", opts, started_services, [cspath])
                    break
                print("CS join failed: %i, will try again" % rc)
                opts.timeout -= 60
                if overal_time > 60:
                    print("Time exceeded execution of command vstorage make-cs: %.2f sec." % overal_time)
                elif opts.timeout > 0:
                    time.sleep(60 - overal_time)
            if rc is None:
                print("CS join skipped for %s: timeout exhausted" % cspath)
            elif rc:
                print("CS join failed: %i" % rc)

    if opts.client:
        print("Configuring %s client..." % opts.cluster)
        # TODO For now we assume that these dirs are empty
        client_dir = "%s/%s" % (vstorage_def_path, opts.cluster)
        os.mkdir(client_dir)
        for p in private_parts:
            t_dir = "%s/%s" % (client_dir, p)
            os.symlink(t_dir, "%s/%s" % (part_name, p))
        add_fstab("vstorage://%s %s fuse.vstorage _netdev 0 0\n" % (opts.cluster, client_dir))

    # Stop the services that were launched by hand above; the systemd
    # units started below take over.
    for srv in sorted(set(started_services)):
        stop_service(srv)

    get_output("%s daemon-reload" % systemctl, ignoreerr=True)

    # Start vstorage created services
    for srv in glob.glob("%s/vstorage*%s*.service" % (unitdir, opts.cluster)):
        get_output("%s start %s" % (systemctl, os.path.basename(srv)), ignoreerr=True)

    if opts.update_mds_ips:
        for act in ("enable", "start"):
            for srvtype in ("service", "timer"):
                get_output("%s %s %s.%s" % (systemctl, act, update_mds_list, srvtype), ignoreerr=True)

    if opts.ha:
        # parseArgs() only accepts --ha together with --client, so
        # client_dir is defined here.
        get_output("/bin/mount %s" % client_dir)
        get_output("%s -c %s join -r" % (shaman_bin, opts.cluster))
        get_output("/bin/umount %s" % client_dir)
        configure_snmpd(opts)
# Let systemd re-read its unit files; a failure here is ignored.
get_output("%s daemon-reload" % systemctl, ignoreerr=True)

# Bring back the services that were stopped at the beginning (only done
# when the script was not started in --onboot mode).
if not opts.onboot:
    start_vz_services(vz_services)
    start_vz_services(ha_services, ignoreerr = True)

print "Cleanup..."

# Unmount and remove the temporary directory.
cleanup()

print "All done!"
