#!/usr/bin/python2

import os
import sys
import re
import subprocess
import threading
import shutil
import tempfile
import datetime
import signal
import time

import parted
import _ped
import string
import argparse

class Logging:
    """Minimal logger that appends to a file and echoes to stdout by level.

    Every message is written to the log file when one is configured.  It is
    additionally printed when the instance's ``loglevel`` is greater than or
    equal to the level of the message (error=1, warning=2, info=3, debug=4).
    """

    def __init__(self, loglevel=1, logfile="/var/log/prepare_vstorage_drive.log"):
        """Open ``logfile`` for appending; an empty string disables the file."""
        # Define the attribute before open() so that __del__ stays safe when
        # open() raises and the instance is only partially initialised.
        self.logfile_fd = None
        self.loglevel = loglevel
        if logfile != "":
            self.logfile_fd = open(logfile, "a+")

    def __del__(self):
        """Close the log file, if one was opened."""
        logfile_fd = getattr(self, "logfile_fd", None)
        if logfile_fd:
            logfile_fd.close()

    def log_it(self, message, loglevel):
        """Write ``message`` to the log file and print it if verbose enough."""
        if self.logfile_fd:
            self.logfile_fd.write("%s\t%s\n" % (datetime.datetime.today(), message))
            # Several Logging instances may append to the same file; flushing
            # after each record keeps their lines in chronological order.
            self.logfile_fd.flush()
        if self.loglevel >= loglevel:
            print(message)

    def debug(self, message):
        """Log ``message`` at level 4."""
        self.log_it(message, 4)

    def info(self, message):
        """Log ``message`` at level 3."""
        self.log_it(message, 3)

    def warning(self, message):
        """Log ``message`` at level 2."""
        self.log_it(message, 2)

    def error(self, message):
        """Log ``message`` at level 1."""
        self.log_it(message, 1)


def isEfi():
    """Return True when the /sys/firmware/efi path exists on this system."""
    efi_sysfs_path = "/sys/firmware/efi"
    return os.path.exists(efi_sysfs_path)


class thread(threading.Thread):
    """Reader thread that logs every line from a pipe and forwards it.

    ``inputd`` is a readable file descriptor (wrapped in a file object here),
    ``outputd`` a descriptor the raw lines are copied to with ``os.write``,
    ``logmethod`` a callable receiving each line without its trailing newline
    and ``command`` the command name used in the read-failure message.
    """

    def __init__(self, inputd, outputd, logmethod, command):
        threading.Thread.__init__(self)
        self.inputd = os.fdopen(inputd, "r")
        self.outputd = outputd
        self.logmethod = logmethod
        self.running = True
        self.command = command

    def run(self):
        """Copy lines until EOF, a read error, or ``stop()`` is requested."""
        try:
            while self.running:
                try:
                    data = self.inputd.readline()
                except IOError:
                    self.logmethod("Failed to read from pipe during a call to %s."
                                   % self.command)
                    break
                if data == "":
                    # An empty read means the write end of the pipe was closed.
                    self.running = False
                    continue

                self.logmethod(data.rstrip('\n'))
                os.write(self.outputd, data)
        finally:
            # Release the read end as soon as the loop is over instead of
            # waiting for the thread object to be garbage collected.
            self.inputd.close()

    def stop(self):
        """Ask the loop to finish after the current line; returns ``self``."""
        self.running = False
        return self


def execRedirect(command, argv, stdin=None, stdout=None, stderr=None):
    """Run ``command`` with ``argv`` and return its exit status.

    The child's stdout and stderr go through two pipes that are drained by
    reader threads: every line is logged via ``program_log`` (info for stdout,
    error for stderr) and copied to the ``stdout`` / ``stderr`` target.  The
    child runs with "/" as working directory and ``LC_ALL=C``.

    ``stdin``, ``stdout`` and ``stderr`` may each be:
      * a path string - opened here and closed again before returning (an
        unreadable ``stdin`` path falls back to ``sys.stdin``; identical
        ``stdout`` and ``stderr`` paths share one descriptor),
      * an integer file descriptor - used as is and left open,
      * a file object - ``stdin`` is passed to the child unchanged, while for
        ``stdout`` / ``stderr`` its descriptor is used,
      * anything else, including None - the matching ``sys`` stream is used.

    Raises RuntimeError when an OSError occurs while starting or waiting for
    the command.
    """
    # Descriptors opened by this call; they are closed again on every exit.
    opened_fds = []

    argv = list(argv)
    if isinstance(stdin, str):
        if os.access(stdin, os.R_OK):
            stdin = os.open(stdin, os.O_RDONLY)
            opened_fds.append(stdin)
        else:
            stdin = sys.stdin.fileno()
    elif isinstance(stdin, int):
        pass
    elif not isinstance(stdin, file):
        stdin = sys.stdin.fileno()

    orig_stdout = stdout
    if isinstance(stdout, str):
        stdout = os.open(stdout, os.O_RDWR | os.O_CREAT)
        opened_fds.append(stdout)
    elif isinstance(stdout, int):
        pass
    elif isinstance(stdout, file):
        # The reader thread forwards data with os.write(), which needs a
        # descriptor rather than a file object.
        stdout = stdout.fileno()
    else:
        stdout = sys.stdout.fileno()

    if isinstance(stderr, str) and isinstance(orig_stdout, str) and stderr == orig_stdout:
        stderr = stdout
    elif isinstance(stderr, str):
        stderr = os.open(stderr, os.O_RDWR | os.O_CREAT)
        opened_fds.append(stderr)
    elif isinstance(stderr, int):
        pass
    elif isinstance(stderr, file):
        stderr = stderr.fileno()
    else:
        stderr = sys.stderr.fileno()

    program_log.info("Running command... %s" % ([command] + argv,))

    pstdout, pstdin = os.pipe()
    perrout, perrin = os.pipe()

    env = os.environ.copy()
    # Set C locale
    env.update({"LC_ALL": "C"})

    readers = []
    try:
        for read_fd, target_fd, logmethod in ((pstdout, stdout, program_log.info),
                                              (perrout, stderr, program_log.error)):
            reader = thread(read_fd, target_fd, logmethod, command)
            reader.start()
            readers.append(reader)

        proc = subprocess.Popen([command] + argv, stdin=stdin,
                                stdout=pstdin,
                                stderr=perrin,
                                cwd="/",
                                env=env)
        ret = proc.wait()
    except OSError as e:
        errstr = "Error running command %s: %s" % (command, e.strerror)
        log.error(errstr)
        program_log.error(errstr)
        raise RuntimeError(errstr)
    finally:
        # Closing our copies of the write ends lets the readers see EOF,
        # so the joins below cannot block on an idle pipe.
        os.close(pstdin)
        os.close(perrin)
        for reader in readers:
            reader.join()
        for fd in opened_fds:
            os.close(fd)

    return ret


def execCapture(command, argv, stdin=None, stderr=None, fatal=False):
    """Run ``command`` with ``argv`` and return everything it wrote to stdout.

    Captured stdout lines are logged with ``program_log.info``; stderr lines
    are logged with ``program_log.error`` and copied to the ``stderr`` target.
    The child runs with "/" as working directory and ``LC_ALL=C``.

    ``stdin`` / ``stderr`` may be a path string (opened here and closed before
    returning; an unreadable ``stdin`` path falls back to ``sys.stdin``), an
    integer descriptor (used as is), a file object, or anything else including
    None (the matching ``sys`` stream is used).

    Raises RuntimeError when an OSError occurs while running the command, or
    when ``fatal`` is true and the command exits with a non-zero status.
    """
    # Descriptors opened by this call; closed again on every exit path.
    opened_fds = []
    argv = list(argv)

    try:
        if isinstance(stdin, str):
            if os.access(stdin, os.R_OK):
                stdin = os.open(stdin, os.O_RDONLY)
                opened_fds.append(stdin)
            else:
                stdin = sys.stdin.fileno()
        elif isinstance(stdin, int):
            pass
        elif not isinstance(stdin, file):
            stdin = sys.stdin.fileno()

        if isinstance(stderr, str):
            stderr = os.open(stderr, os.O_RDWR | os.O_CREAT)
            opened_fds.append(stderr)
        elif isinstance(stderr, int):
            pass
        elif isinstance(stderr, file):
            # os.write() below needs a descriptor, not a file object.
            stderr = stderr.fileno()
        else:
            stderr = sys.stderr.fileno()

        program_log.info("Running command... %s" % ([command] + argv,))

        env = os.environ.copy()
        # Set C locale
        env.update({"LC_ALL": "C"})

        captured = ""
        try:
            proc = subprocess.Popen([command] + argv, stdin=stdin,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    cwd="/",
                                    env=env)

            # communicate() returns only after the process has terminated,
            # so a single call collects all output and sets returncode.
            (outStr, errStr) = proc.communicate()
            if outStr:
                for line in outStr.splitlines():
                    program_log.info(line)
                captured += outStr
            if errStr:
                for line in errStr.splitlines():
                    program_log.error(line)
                os.write(stderr, errStr)

            if proc.returncode and fatal:
                # Routed through the handler below so both failure kinds are
                # logged and reported the same way.
                raise OSError(proc.returncode, errStr)
        except OSError as e:
            errstr = "Error running command " + command + ": " + e.strerror
            log.error(errstr)
            raise RuntimeError(errstr)

        return captured
    finally:
        for fd in opened_fds:
            os.close(fd)


def execWithTimeout(timeout, func, *args):
    """Call ``func(*args)`` about once per second until it returns a true value.

    ``timeout`` is the maximum number of attempts; one second is slept after
    every unsuccessful attempt.  Exceptions raised by ``func`` are logged and
    count as an unsuccessful attempt.

    Returns True as soon as ``func`` succeeds, False when the attempts are
    used up (or ``timeout`` is not positive).
    """
    while timeout > 0:
        try:
            if func(*args):
                return True
        except Exception as e:
            # Not every callable has __name__ (e.g. functools.partial), and
            # the reason for the failure is what makes this line useful.
            func_name = getattr(func, "__name__", repr(func))
            log.error("Can't execute function %s: %s" % (func_name, e))

        time.sleep(1)
        timeout -= 1

    return False


def create_input(data):
    """Return the read end of a pipe pre-filled with ``data`` plus a newline.

    The write end is closed before returning, so a reader sees EOF after the
    data.  The caller owns the returned descriptor and must close it.

    NOTE: the data is written before any reader exists, so it has to fit into
    the pipe buffer; larger payloads would block in ``os.write``.
    """
    read_fd, write_fd = os.pipe()
    try:
        payload = data + '\n'
        if not isinstance(payload, bytes):
            # os.write() needs a byte string; encode text input.
            payload = payload.encode()
        os.write(write_fd, payload)
    except Exception:
        # Do not leak the read end when the pipe could not be filled.
        os.close(read_fd)
        raise
    finally:
        os.close(write_fd)
    return read_fd


class PrepareHDD:
    def __init__(self, device, out="/dev/null", err="/dev/null", loglevel=3, ssd=False, boot=True, fs_uuid=None):
        """Store the target device and options and create a scratch directory.

        ``out`` / ``err`` are kept as the stdout / stderr targets handed to the
        external commands run by this class.  A temporary directory is created
        under /tmp; the boot mount point and the GRUB device map live in it.
        """
        # Target and options.
        self.device = device
        self.ssd = ssd
        self.boot = boot
        self.fs_uuid = fs_uuid
        self.stdout = out
        self.stderr = err
        self.log = Logging(loglevel, logfile="")

        # Partition layout, in sectors.
        self.product_name = "Virtuozzo Server"
        self.startpart = 64
        boot_bytes = 512 * 1024 * 1024
        sector_bytes = 512
        if self.boot:
            self.bootsize = boot_bytes / sector_bytes  # 512MB in sectors
        else:
            self.bootsize = self.startpart

        # Labels and UUIDs of the main boot file systems.
        # FAT label is only 11 symbols long...
        self.bootlabel = "PCSBoot"
        self.bootuuid = "Undefined"
        self.bootlabel_efi = "PCSBootEFI"
        self.bootuuid_efi = "UndefinedUUID"
        self.bootlabel_installed = False
        # TODO: generate unique label for disk. This label will used by hotplug in future
        self.vstorage_cs_label = "vstorage-hotplug"

        # Scratch space for the boot mount point and the GRUB device map.
        self.tempdir = tempfile.mkdtemp(prefix="PrepareHDD_", dir="/tmp")
        self.mpath = os.path.join(self.tempdir, "boot_cs")
        self.device_map = os.path.join(self.tempdir, "device.map")
        self.mounted = False

    def check(self):
        if not os.path.exists(self.device):
            self.log.error("Given device %s does not exist" % self.device)
            return False

        return True

    def check_partition(self):
        if not os.path.exists(self.device):
            self.log.error("Given partition %s does not exist" % self.device)
            return False

        return True

    def __del__(self):
        try:
            self.umount()
            os.rmdir(self.mpath)
            os.unlink(self.device_map)
            os.rmdir(self.tempdir)
        except OSError:
            return

    def umount(self):
        if not self.mounted:
            return True

        # Umount
        rc = execRedirect("/bin/umount", [self.mpath],
                          stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to mount %s GRUB disk: %i..." % (self.device, rc))
            return False

        self.mounted = False
        return True

    def zero_device(self):
        # Zero mbr + first sectors + end
        self.log.info("Zeroing out beginning and end of %s..." % self.device)
        fd = None

        try:
            fd = os.open(self.device, os.O_RDWR)
            buf = '\0' * 1024 * 1024
            os.write(fd, buf)
            os.lseek(fd, -1024 * 1024, 2)
            os.write(fd, buf)
            os.close(fd)
        except Exception as e:
            if getattr(e, "errno", None) != 28:  # No space left in device
                self.log.error("error zeroing out %s: %s" % (self.device, e))
            if fd:
                os.close(fd)

            return False

        return True

    def wait_for_device(self, device):
        """Wait for the path ``device`` to exist, polling via ``execWithTimeout``.

        Returns True once the path exists, False (after logging an error) when
        it did not show up within 60 polling attempts.
        """
        # This method previously mixed a tab-indented line with space-indented
        # ones; it is now indented with spaces only.
        self.log.info("Waiting for kernel...")
        self.log.info("Waiting for kernel device {}".format(device))

        # Wait for device
        if execWithTimeout(60, os.path.exists, device):
            return True

        self.log.error("Failed to wait for %s disk" % device)
        return False

    def udev_settle(self):
        """Run ``udevadm settle``; return True when it exits with status 0."""
        self.log.info("Settle udev...")

        status = execRedirect("/usr/sbin/udevadm", ["settle"],
                              stdin=None, stdout=self.stdout, stderr=self.stderr)
        if not status:
            return True

        self.log.error("Failed settle udev")
        return False

    def exec_disk_operation(self, command, argv, stdin=None, stdout=None, stderr=None):
        """Settle udev, then run ``command``; return True only if both succeed.

        The command is not started when udev fails to settle.  Any failure is
        logged as an error and reported as False.
        """
        # Wait until the device is released by other processes and then execute disk operation
        settled = self.udev_settle()
        if settled and not execRedirect(command, argv, stdin, stdout, stderr):
            return True

        self.log.error("Failed to execute disk operation %s for %s." % (command, self.device))
        return False

    def commit_disk_partitions(self, disk):
        """Settle udev and call ``disk.commit()``.

        Returns True when the commit was performed, False when udev did not
        settle or the commit raised (the exception is logged).
        """
        try:
            if self.udev_settle():
                disk.commit()
                return True
        except Exception as e:
            # Include the reason; the bare message made failures undiagnosable.
            self.log.error("Failed to commit modifications to disk %s: %s" % (self.device, e))

        return False

    def prepare_vstorage_partition(self):
        if not self.zero_device():
            return False

        if self.is_mounted():
            self.log.error("Failed to prepare partition: %s are mounted..." % (self.device))
            return False

        return self.prepare_data(self.device)

    def prepare_vstorage_data_disk(self):
        """Repartition ``self.device`` with a GPT label and format it.

        With ``self.boot`` set, partition 1 becomes a FAT boot partition
        (prepared by ``prepare_boot``) and partition 2 the ext4 data
        partition; otherwise the data partition is the only one.

        Returns True on success, False on any failure.
        """
        if not self.zero_device():
            return False

        # Partition nodes take a "p" separator when the disk name ends in a
        # digit.  The boot branch used to ignore this and always used
        # "<disk>1" / "<disk>2".
        separator = "p" if self.device[-1] in string.digits else ""
        first_part = "%s%s1" % (self.device, separator)
        second_part = "%s%s2" % (self.device, separator)

        try:
            self.log.info("Partitioning %s..." % self.device)
            device = parted.Device(self.device)
            # create GPT label
            disk = parted.freshDisk(device, "gpt")
            constraint = device.optimalAlignedConstraint
            if self.boot:
                # create 1 boot partition
                geometry = parted.Geometry(device=device,
                                           start=self.startpart,
                                           end=(self.bootsize - 1))
                # Make it fat
                partition_ped = _ped.Partition(disk.getPedDisk(),
                                               parted.PARTITION_NORMAL,
                                               geometry.start, geometry.end,
                                               _ped.file_system_type_get("fat32"))
                partition_ped.set_name("EFI System Partition")
                partition = parted.Partition(PedPartition=partition_ped)
                disk.addPartition(partition=partition, constraint=constraint)

            # create main partition
            geometry = parted.Geometry(device=device,
                                       start=self.bootsize,
                                       end=(constraint.maxSize - 1))
            # Make it as ext4
            filesystem = parted.FileSystem(type="ext4", geometry=geometry)
            partition = parted.Partition(disk=disk, fs=filesystem,
                                         type=parted.PARTITION_NORMAL, geometry=geometry)
            disk.addPartition(partition=partition, constraint=constraint)

            # Apply modifications to disk with several attempts
            committed = execWithTimeout(10, self.commit_disk_partitions, disk)
            if not committed:
                self.log.error("Failed to commit modifications to disk %s" % (self.device))
                return False

        except Exception as e:
            self.log.error("Failed to repartition disk %s:\n%s" % (self.device, e))
            return False

        # Every later step needs the partition node, so give up if it never
        # shows up (wait_for_device has already logged the reason).
        if not self.wait_for_device(first_part):
            return False

        if not self.boot:
            return self.prepare_data(first_part)

        # Drop the current node before the partition table is rewritten -
        # presumably so that the wait below only succeeds once the kernel has
        # re-created it; TODO confirm.
        try:
            os.unlink(first_part)
        except OSError as e:
            self.log.warning("Failed to remove %s: %s" % (first_part, e))

        # Set active flag for GPT due to buggy BIOSes
        fdisk_input = create_input("a\n1\nw")
        try:
            execRedirect("/sbin/fdisk", [self.device],
                         stdin=fdisk_input,
                         stdout=self.stdout,
                         stderr=self.stderr)
            # Ignore exit code, due to kernel lag, doesn't matter will be partition re-readed
        finally:
            # execRedirect leaves integer descriptors open; close ours.
            os.close(fdisk_input)

        if not self.wait_for_device(first_part):
            return False

        # Set boot flag on EFI GPT partition
        # Due to parted bug it can't be properly set on partition creation...
        try:
            device = parted.Device(self.device)
            disk = parted.Disk(device)
            partition = disk.getPartitionByPath(first_part)
            partition.setFlag(parted.PARTITION_BOOT)
            disk.commitToDevice()
        except Exception as e:
            self.log.error("Failed to set boot flag on %s:\n%s" % (first_part, e))
            return False

        # Prepare boot
        if not self.prepare_boot(first_part):
            return False
        return self.prepare_data(second_part)

    def get_device_for_mpoint(self, mpoint):
        """Return the /dev device mounted on ``mpoint`` according to /proc/mounts.

        Only device paths of the form ``/dev/<letters and digits>`` are
        considered.  Returns the device path of the first matching line, or an
        empty string when there is none.
        """
        # Escape the mount point so it is matched literally, not as a regex.
        mount_line = re.compile(r"^/dev/[a-zA-Z0-9]+ %s " % re.escape(mpoint))

        with open("/proc/mounts", 'r') as f:
            for line in f:
                if mount_line.search(line):
                    return line.split()[0]

        return ""

    def is_mounted(self):
        with open("/proc/mounts", 'r') as f:
            for line in f.readlines():
                if re.search(r"^%s\s" % self.device, line):
                    return True
        return False

    def set_bootlabel(self, device):
        """Make sure ``device`` carries the label ``self.bootlabel``.

        The current label is read with /sbin/e2label and only rewritten when
        it differs.  Returns False if writing the label fails, True otherwise.
        """
        current_label = execCapture("/sbin/e2label", [device],
                                    stderr=os.path.join(self.stderr))
        if current_label.rstrip("\n") == self.bootlabel:
            return True

        self.log.info("Creating %s label..." % device)
        status = execRedirect("/sbin/e2label", [device, self.bootlabel],
                              stdout=self.stdout, stderr=self.stderr)
        if not status:
            return True

        self.log.error("Failed to set boot label for %s: %i..." % (device, status))
        return False

    def set_bootlabel_efi(self, device):
        """Make sure ``device`` carries the label ``self.bootlabel_efi``.

        The current label is read with /sbin/dosfslabel and only rewritten
        when it differs.  Returns False if writing the label fails, True
        otherwise.
        """
        current_label = execCapture("/sbin/dosfslabel", [device],
                                    stderr=os.path.join(self.stderr))
        if current_label.rstrip("\n") == self.bootlabel_efi:
            return True

        self.log.info("Creating %s label..." % device)
        status = execRedirect("/sbin/dosfslabel", [device, self.bootlabel_efi],
                              stdout=self.stdout, stderr=self.stderr)
        if not status:
            return True

        self.log.error("Failed to set EFI boot label for %s: %i..." % (device, status))
        return False

    def get_uuid(self, bootdev):
        """Return the lower-cased file system UUID of ``bootdev``.

        The UUID is taken from the first ``UUID=`` line of
        ``blkid -o export``.  Returns an empty string, after logging an error,
        when no such line is present.
        """
        blkid = execCapture("/sbin/blkid", ["-o", "export", bootdev],
                            stderr=self.stderr)
        prefix = "UUID="
        uuid = ""
        # Use a separate loop variable: reusing ``uuid`` left the last output
        # line in it whenever no UUID line was found.
        for line in blkid.split("\n"):
            if line.startswith(prefix):
                uuid = line[len(prefix):]
                break
        if uuid == "":
            self.log.error("Failed to get UUID for %s disk" % bootdev)
            return ""

        return uuid.lower()

    def boot_label(self):
        """Label the running system's boot file systems and record their UUIDs.

        The device mounted on /boot (or on / when /boot is not a separate
        mount) gets ``self.bootlabel`` and its UUID is stored in
        ``self.bootuuid``.  On EFI systems the device mounted on /boot/efi is
        handled the same way with ``self.bootlabel_efi`` / ``self.bootuuid_efi``.

        Returns True on success, False as soon as any step fails.
        """
        self.log.info("Detecting labels...")

        main_dev = self.get_device_for_mpoint("/boot")
        if main_dev == "":
            main_dev = self.get_device_for_mpoint("/")
        if main_dev == "":
            self.log.error("Failed to detect boot disk")
            return False

        if not self.set_bootlabel(main_dev):
            return False

        # Get UUID
        self.bootuuid = self.get_uuid(main_dev)
        if self.bootuuid == "":
            return False

        if not isEfi():
            return True

        # EFI part
        efi_dev = self.get_device_for_mpoint("/boot/efi")
        if efi_dev == "":
            self.log.error("Failed to detect EFI boot disk")
            return False

        if not self.set_bootlabel_efi(efi_dev):
            return False

        # Get EFI UUID
        self.bootuuid_efi = self.get_uuid(efi_dev)
        return self.bootuuid_efi != ""

    def prepare_boot(self, device):
        """Format the boot partition ``device`` and install a redirecting GRUB.

        Steps: create a FAT file system, label the main boot file systems
        (``boot_label``), mount the new partition, copy the GRUB stage files
        and write GRUB / EFI configs that redirect to the main boot loader,
        unmount, then run /sbin/grub against ``self.device``.

        Returns True on success, False on any failure.
        """
        self.log.info("Formatting %s partition..." % device)
        # Format additional /boot
        rc = execRedirect("/sbin/mkdosfs", [device],
                          stdout=self.stdout,
                          stderr=self.stderr)
        if rc:
            self.log.error("Failed to format %s disk: %i..." % (device, rc))
            return False

        # Create boot labels on main /boot
        if not self.boot_label():
            return False

        # Mount
        os.makedirs(self.mpath)
        rc = execRedirect("/bin/mount", ["-t", "vfat", device, self.mpath],
                          stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to mount %s disk: %i..." % (device, rc))
            return False

        self.mounted = True

        # Create bootloader data
        prepared = False
        unmounted = False
        try:
            self.log.info("Installing bootloader to %s..." % device)
            os.mkdir(self.mpath + "/grub")
            if os.path.exists("/usr/share/grub/x86_64-unknown"):
                grub_dir = "/usr/share/grub/x86_64-unknown"
            elif os.path.exists("/usr/share/grub/x86_64-redhat"):
                grub_dir = "/usr/share/grub/x86_64-redhat"
            else:
                raise OSError("grub not found")
            shutil.copy("%s/stage1" % grub_dir, self.mpath + "/grub")
            shutil.copy("%s/stage2" % grub_dir, self.mpath + "/grub")
            if os.path.exists("/boot/grub/splash.xpm.gz"):
                shutil.copy("/boot/grub/splash.xpm.gz", self.mpath + "/grub")
            with open(self.mpath + "/grub/grub.conf", "a+") as f:
                f.write(
                    "background FFFFFF\n"
                    "foreground 000000\n"
                    "default=0\n"
                    "timeout=5\n"
                    "splashimage=(hd0,0)/grub/splash.xpm.gz\n"
                    "hiddenmenu\n"
                    "title Redirect to %s main boot loader\n"
                    "        uuid_label_load_mbr %s %s\n" %
                    (self.product_name, self.bootuuid, self.bootlabel))
            with open(self.device_map, "a+") as f:
                # Map the whole disk.  Stripping every digit run from the
                # partition path also mangled disk names that contain digits.
                f.write("(hd1) %s\n" % self.device)
            # Place EFI bootloader, so user will be able to chainload original EFI loader
            os.makedirs(self.mpath + "/EFI/BOOT")
            shutil.copy("/boot/efi/EFI/BOOT/BOOTX64.efi", self.mpath + "/EFI/BOOT")
            with open(self.mpath + "/EFI/BOOT/BOOTX64.conf", "a+") as f:
                f.write(
                    "background FFFFFF\n"
                    "foreground 000000\n"
                    "default=0\n"
                    "timeout=5\n"
                    "splashimage=(hd0,0)/grub/splash.xpm.gz\n"
                    "hiddenmenu\n"
                    "title Redirect to %s main boot loader\n"
                    "        uuid_label %s %s\n"
                    "        chainloader /EFI/BOOT/BOOTX64.efi\n\n" %
                    (self.product_name, self.bootuuid_efi, self.bootlabel_efi))
            prepared = True
        except Exception as e:
            self.log.error("Failed to prepare GRUB data for drive %s:\n%s" % (device, e))
        finally:
            # Always unmount.  Remove the mount point only when the unmount
            # worked, since rmdir on a busy mount point raises.
            unmounted = self.umount()
            if unmounted:
                try:
                    os.rmdir(self.mpath)
                except OSError as e:
                    self.log.warning("Failed to remove mount point %s: %s" % (self.mpath, e))

        if not (prepared and unmounted):
            return False

        # Install bootloader
        grub_input = create_input("root (hd1,0)\nsetup (hd1)\nquit")
        try:
            rc = execRedirect("/sbin/grub", ["--batch", "--device-map=%s" % self.device_map],
                              stdin=grub_input,
                              stdout=self.stdout, stderr=self.stderr)
        finally:
            # execRedirect leaves integer descriptors open; close ours.
            os.close(grub_input)
        if rc:
            self.log.error("Failed to install GRUB on %s disk: %i..." % (device, rc))
            return False

        return True

    def prepare_data(self, device):
        """Create and tune the ext4 data file system on ``device``.

        Runs mkfs.ext4 (retried through ``execWithTimeout``), sets the error
        behaviour to remount-ro with tune2fs and finally tries to set SCT
        error recovery control with smartctl.  A smartctl failure only
        produces a warning.

        Returns True on success, False if formatting or tune2fs failed.
        """
        self.log.info("Formatting %s partition..." % device)
        opts = [device, "-q", "-E", "lazy_itable_init=1", "-O", "uninit_bg", "-m", "0"]
        if self.fs_uuid:
            opts += ['-U', self.fs_uuid]

        # Format disk with several attempts
        formatted = execWithTimeout(10, self.exec_disk_operation,
                                    "/sbin/mkfs.ext4", opts, None,
                                    self.stdout, self.stderr)
        if not formatted:
            self.log.error("Failed to format %s disk" % device)
            return False

        # Set error behavior to remount-ro
        self.log.info("Set error behavior for %s partition..." % device)
        rc = execRedirect("/usr/sbin/tune2fs", ["-e", "remount-ro", device],
                          stdout=self.stdout, stderr=self.stderr)
        if rc:
            self.log.error("Failed to set error behavior on %s disk: %i..." % (device, rc))
            return False

        # Set SCT error recovery control for disk
        # (this step used to log the copy-pasted tune2fs message)
        self.log.info("Set SCT error recovery control for %s partition..." % device)
        try:
            rc = execRedirect("/usr/sbin/smartctl", ["-q", "errorsonly", "-l", "scterc,70,70", device],
                              stdout=self.stdout, stderr=self.stderr)
        except RuntimeError as e:
            # execRedirect raises when the binary cannot be started; this step
            # is optional, so treat that like any other smartctl failure.
            self.log.warning("Failed to run smartctl for %s disk: %s" % (device, e))
        else:
            if rc:
                self.log.warning("Failed to set SCT error recovery control on %s disk: %i..." % (device, rc))

        return True


def signal_handler(signal, frame):
    """Signal handler: emit a newline and terminate the process with status 0."""
    # NOTE(review): the exit status is 0 even though the run was interrupted -
    # confirm that callers do not rely on the status to detect an abort.
    print("")
    sys.exit(0)

# Leave cleanly on Ctrl-C / termination requests.
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


# Set log levels to exec functions
log = Logging(1)
program_log = Logging(0)

parser = argparse.ArgumentParser(description='Prepare device for vstorage services')
parser.add_argument('device', help='Drive (or partition. see -p option) path that you want to prepare for vstorage')
parser.add_argument('-y', action='store_true', help='Do not ask any questions')
parser.add_argument('--ssd', action='store_true', help='Drive is SSD')
parser.add_argument('--noboot', action='store_true', help='Do not install GRUB bootloader')
parser.add_argument('-p', action='store_true', help='Prepare only given partition')
parser.add_argument('-U', help='UUID of the newly created filesystem')

args = parser.parse_args()

# argparse is the single source of truth for the options; the earlier manual
# scan of sys.argv was overwritten by these values anyway.
ssd = args.ssd
boot = not args.noboot
ask = not args.y
part = args.p

# Use the parsed positional argument.  sys.argv[1] is an option rather than
# the device whenever options are given first (e.g. "-y /dev/sdb").
hdd = PrepareHDD(args.device, ssd=ssd, boot=boot, fs_uuid=args.U)
if (not part and not hdd.check()) or (part and not hdd.check_partition()):
    sys.exit(1)

if ask:
    print("ALL data on %s will be completely destroyed. Are you sure to continue? [y]" % args.device)
    if raw_input("") != "y":
        sys.exit(0)

if not part:
    if not hdd.prepare_vstorage_data_disk():
        print("Prepare vstorage data disk failed")
        sys.exit(1)
else:
    if not hdd.prepare_vstorage_partition():
        print("Prepare vstorage partition failed")
        sys.exit(1)

print("Done!")

sys.exit(0)
