#!/usr/bin/python3

# Copyright 2024 Virtuozzo International GmbH.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import argparse
import contextlib
import datetime
import io
import json
import logging
import os
import re
import secrets
import shutil
import signal
import socket
import subprocess
import sys
import time
import traceback

import libvirt
import libvirt_qemu


RETURN_CODE_OK = 0
RETURN_CODE_CANCELED = 100
RETURN_CODE_ERROR = 255

QEMU_STORAGE_CMD = "/usr/bin/qemu-storage-daemon"
QEMU_IMG_CMD = "/usr/bin/qemu-img"
BACKUPS_DIR = "/var/lib/cinder/qcow2_backups/backups"
ARCHIVE_DIR = "/var/lib/cinder/qcow2_backups/archive"
BACKUP_CLIENT_BIN = "/usr/libexec/vz_backup_client"
NAME_REGEX = re.compile("^[a-zA-Z0-9_-]+$")

NODE_UPDATE_FLAG_FILE = '/run/.node_is_being_updated'

LOG = logging.getLogger()


def setup_logger(path=None):
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
    if path:
        file_handler = logging.FileHandler(path)
        formatter = logging.Formatter(
            "%(asctime)s:%(levelname)s:%(name)s:%(message)s")
        file_handler.setFormatter(formatter)
        # swap the console handler installed by basicConfig for the
        # file handler; iterate over a copy so removal is safe
        for handler in list(LOG.handlers):
            LOG.removeHandler(handler)
        LOG.addHandler(file_handler)
        LOG.info("Logger configured")


def redirect_streams():
    stdout = sys.stdout
    sys.stdout = FileLikeLoggerWrapper(logging.getLogger("stdout"))
    stdout.close()
    stderr = sys.stderr
    sys.stderr = FileLikeLoggerWrapper(logging.getLogger("stderr"))
    stderr.close()
    sys.stdin.close()


class FileLikeLoggerWrapper:
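    """File-like object that forwards complete lines to a logger.

    A trailing partial line is buffered between write() calls and
    emitted on flush().
    """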
    def __init__(self, logger):
        self.logger = logger
        self.buffer = ""

    def write(self, text):
        lines = (self.buffer + text).splitlines()
        if text.endswith("\n"):
            self.buffer = ""
        else:
            # keep the trailing partial line for the next write()
            self.buffer = lines[-1] if lines else ""
            lines = lines[:-1]
        for line in lines:
            line = line.strip()
            if line:
                self.logger.info(line)

    def flush(self):
        if self.buffer:
            self.logger.info(self.buffer)
            self.buffer = ""

    def close(self):
        pass


def validate_name(name):
    if re.match(NAME_REGEX, name):
        return name
    raise argparse.ArgumentTypeError(f"invalid name: {name}")


class QMPClientMixin:
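    """QMP command helpers shared by all clients.

    Subclasses must implement call(cmd) to actually send the command.
    """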
    def query_block(self):
        return self.call({"execute": "query-block"})

    def query_block_jobs(self):
        return self.call({"execute": "query-block-jobs"})

    def cancel_block_job(self, device):
        return self.call({"execute": "block-job-cancel",
                          "arguments": {"device": device}})

    def complete_block_job(self, device):
        return self.call({"execute": "block-job-complete",
                          "arguments": {"device": device}})

    def dismiss_block_job(self, device):
        return self.call({"execute": "block-job-dismiss",
                          "arguments": {"id": device}})

    def query_named_block_nodes(self, flat=True):
        return self.call({"execute": "query-named-block-nodes",
                          "arguments": {"flat": flat}})

    def start_backup(self, pit, from_pit, blocks, persistent=False):
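        """Start backup jobs for all blocks in a single QMP transaction.

        A new dirty bitmap named after the pit is added to every node;
        when from_pit is given the jobs run incrementally against a
        temporary clone of that bitmap.
        """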
        add_new_bitmap = [
            {"type": "block-dirty-bitmap-add", "data": {
                "name": pit, "persistent": persistent, "node": block["name"]}}
            for block in blocks]
        backups = [
            {"type": "drive-backup", "data": {
                "auto-dismiss": False, "compress": True,
                "mode": "existing", "sync": "full",
                "device": block["name"], "job-id": block["name"],
                "target": block["file"]}}
            for block in blocks]
        if from_pit:
            from_pit_clone = from_pit + '-clone'
            self.clone_bitmap(
                [block["name"] for block in blocks], from_pit, from_pit_clone)
            for backup in backups:
                backup["data"]["sync"] = "incremental"
                backup["data"]["bitmap"] = from_pit_clone
        cmd = {
            "execute": "transaction",
            "arguments": {
                "actions": add_new_bitmap + backups,
            }
        }
        return self.call(cmd)

    def get_backup_progress(self, devs):
        rv = {}
        total_len = 0
        total_progress = 0
        rv["jobs"] = self.query_block_jobs()
        jobs = {job.get("device"): job for job in rv["jobs"]
                if job.get("device") in devs}
        for dev in devs:
            job = jobs.get(dev)
            if (not job or "len" not in job or "offset" not in job
                    or "status" not in job or job.get("type") != "backup"):
                rv[dev] = None
                continue
            rv[dev] = int(job["offset"] / job["len"] * 100
                          if job["len"] else 100)
            if job["status"] != "concluded":
                # unless job is finished, progress is no more than 99
                rv[dev] = min(rv[dev], 99)
            total_len += job["len"]
            total_progress += job["offset"]
        if any(val is None for val in rv.values()):
            rv["total"] = None
            return rv
        rv["total"] = int(total_progress / total_len * 100
                          if total_len else 100)
        return rv

    def finish_backup(self, pit, from_pit, devs):
        jobs = self.query_block_jobs()
        jobs = {job.get("device"): job for job in jobs}
        for dev in devs:
            job = jobs.get(dev)
            if not job or job["type"] != "backup":
                raise Exception(f"Job not found for dev '{dev}'")
            self.dismiss_block_job(dev)

        jobs = self.query_block_jobs()
        LOG.debug('Jobs after finishing a backup: %s', json.dumps(jobs))

    def abort_backup_jobs(self):
        jobs = self.query_block_jobs()
        for job in jobs:
            if job["type"] == "backup":
                try:
                    self.cancel_block_job(job["device"])
                except Exception as e:
                    LOG.error("unable to cancel block job: %s", e)
                    self.dismiss_block_job(job["device"])
        jobs = self.query_block_jobs()
        LOG.debug('Jobs after aborting a backup: %s', json.dumps(jobs))

    def remove_bitmaps(self, blocks, keep=None, prefix=""):
        keep = set(keep or [])
        for block in blocks.values():
            for bitmap in block["bitmaps"]:
                name = bitmap.get("name")
                if name and name not in keep and name.startswith(prefix):
                    self.remove_bitmap(block["name"], name)

    def rename_bitmap(self, blocks, name, new_name):
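        """Rename a dirty bitmap on every block.

        Done as add new name + merge old bitmap into it + remove old
        bitmap, all in one transaction.
        """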
        add_new_bitmap = [
            {"type": "block-dirty-bitmap-add", "data": {
                "name": new_name, "persistent": True, "node": block}}
            for block in blocks]
        copy_bitmap = [
            {"type": "block-dirty-bitmap-merge", "data": {
                "target": new_name, "node": block, "bitmaps": [name]}}
            for block in blocks]
        remove_old_bitmap = [
            {"type": "block-dirty-bitmap-remove", "data": {
                "name": name, "node": block}}
            for block in blocks]
        cmd = {
            "execute": "transaction",
            "arguments": {
                "actions": add_new_bitmap + copy_bitmap + remove_old_bitmap,
            }
        }
        return self.call(cmd)

    def clone_bitmap(self, blocks, name, new_name, persistent=False):
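        """Copy a dirty bitmap under a new name on every block."""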
        add_new_bitmap = [
            {"type": "block-dirty-bitmap-add", "data": {
                "name": new_name, "persistent": persistent, "node": block}}
            for block in blocks]
        copy_bitmap = [
            {"type": "block-dirty-bitmap-merge", "data": {
                "target": new_name, "node": block, "bitmaps": [name]}}
            for block in blocks]
        cmd = {
            "execute": "transaction",
            "arguments": {
                "actions": add_new_bitmap + copy_bitmap,
            }
        }
        return self.call(cmd)

    def remove_bitmap(self, device, bitmap):
        return self.call({"execute": "block-dirty-bitmap-remove",
                          "arguments": {"node": device, "name": bitmap}})

    def add_bitmap(self, device, bitmap):
        return self.call({"execute": "block-dirty-bitmap-add",
                          "arguments": {"node": device, "name": bitmap}})

    def get_nodes(self, files):
        rv = {}
        nodes = self.query_named_block_nodes()
        for node in nodes:
            if node.get("file") not in files:
                continue
            if node.get("drv") != "qcow2":
                continue
            bitmaps = node.get("dirty-bitmaps", [])
            name = node["node-name"]
            rv[name] = {
                "name": name,
                "bitmaps": bitmaps,
                "file": files[node["file"]]["target"]
            }

        return rv


class DomainQMPClient(QMPClientMixin):
    """Libvirt API QMP client"""
    def __init__(self, domain):
        self.domain_id = domain
        conn = libvirt.open("qemu:///system")
        self.domain = conn.lookupByName(self.domain_id)

    def call(self, cmd):
        LOG.debug("Domain QMP call: %s", cmd)
        resp = json.loads(libvirt_qemu.qemuMonitorCommand(
            self.domain, json.dumps(cmd), 0))
        LOG.debug("Domain QMP response: %s", resp)
        if "error" in resp:
            raise Exception(resp)
        return resp["return"]


class QMPError(Exception):
    pass


class QMPConnectionError(QMPError):
    pass


class QMPNotFound(QMPError):
    pass


class EmptyQMPResponse(QMPError):
    pass


class IncorrectQMPResponse(QMPError):
    pass


class ErrorQMPResponse(QMPError):
    pass


class UnixSocketQMPClient(QMPClientMixin):
    """Unix socket QMP client"""

    def __init__(self, path):
        self._client_id = f"client-{secrets.token_hex(6)}-"
        self._call_id = 0
        self.path = path

    def connect(self, wait=0.1, attempts=1):
        for _attempt in range(attempts):
            try:
                self._connect()
                break
            except QMPError:
                if _attempt == attempts - 1:
                    raise
                time.sleep(wait)

    def _connect(self):
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            self.sock.connect(self.path)
        except BrokenPipeError as e:
            raise QMPConnectionError(str(e))
        except FileNotFoundError as e:
            raise QMPNotFound(str(e))
        self.sockf = self.sock.makefile(mode="rw", buffering=2)
        # handshake
        json.loads(self.sockf.readline())
        self.call({"execute": "qmp_capabilities"})

    def call(self, cmd):
        call_id = self._call_id = self._call_id + 1
        cmd["id"] = self._client_id + str(call_id)
        LOG.debug("UnixSocket QMP call: %s", cmd)
        json.dump(cmd, self.sockf)
        self.sockf.flush()
        while True:
            resp = self.sockf.readline()
            if not resp:
                raise EmptyQMPResponse("Empty answer")
            try:
                LOG.debug("UnixSocket QMP response: %s", resp)
                resp = json.loads(resp)
            except ValueError:
                raise IncorrectQMPResponse(f"Value is not json: {resp}")
            if resp.get("id") != cmd["id"]:
                continue
            if "error" in resp:
                raise ErrorQMPResponse(resp["error"])
            return resp["return"]


def get_process_cmdline(pid):
    """Get process cmdline for pid from /proc."""

    cmdline_path = os.path.join("/proc", str(pid), "cmdline")
    process_cmdline = None
    try:
        with open(cmdline_path, "r") as f:
            process_cmdline = f.read().split("\x00")
    except FileNotFoundError:
        pass
    return process_cmdline


def check_backup_process(backup_status):
    """Check backup process pid and cmdline.

    It makes sure that expected process is still running.
    Set error return code if expected process is not found.
    """

    pid = backup_status["pid"]
    status_cmdline = backup_status["cmdline"]
    process_cmdline = get_process_cmdline(pid)
    if process_cmdline != status_cmdline:
        backup_status["rc"] = RETURN_CODE_ERROR
        backup_status["error"] = "Backup process not found."


def get_backup_info(path, detailed=False):
    """Get backup information as a dict.

    Parse the backup record directory content and return
    information about the backup as a dict.
    """

    args_path = os.path.join(path, "args.json")
    if not os.path.exists(args_path):
        print(f"file '{args_path}' not found", file=sys.stderr)
        return

    backup_info = {}

    try:
        with open(args_path, "r") as f:
            backup_args = json.load(f)
            backup_info["id"] = backup_args["pit"]
            backup_info["domain"] = backup_args["domain"]
            volumes = backup_args.get("volumes")
            volumes = volumes.split(",") if volumes else []
            backup_info["volumes"] = volumes
            backup_info["targets"] = dict(
                target.split(":", 1)
                for target in backup_args["targets"])
            if detailed:
                backup_info["args"] = backup_args
    except Exception as exc:
        print(f"unable to read '{args_path}': {exc}", file=sys.stderr)
        return

    result_path = os.path.join(path, "result.json")
    progress_path = os.path.join(path, "progress.json")
    if not os.path.exists(progress_path):
        print(f"file '{result_path}' not found", file=sys.stderr)
        return

    try:
        with open(progress_path, "r") as f:
            backup_status = json.load(f)
            if not os.path.exists(result_path):
                check_backup_process(backup_status)
    except Exception as exc:
        print(f"unable to read '{progress_path}': {exc}", file=sys.stderr)
        return

    try:
        if os.path.exists(result_path):
            # to avoid a race between the result appearing and the process
            # disappearing, we check the result file twice
            with open(result_path, "r") as f:
                backup_status = json.load(f)

        if not detailed:
            backup_status.pop("extra", None)
            backup_status.pop("cmdline", None)
        backup_info["status"] = backup_status
    except Exception as exc:
        print(f"unable to read '{result_path}': {exc}", file=sys.stderr)
        return

    return backup_info


def show_backup(args):
    if not os.path.exists(BACKUPS_DIR):
        print(f"directory '{BACKUPS_DIR}' not found", file=sys.stderr)
        return

    abs_path = os.path.join(BACKUPS_DIR, args.pit)
    if not os.path.exists(abs_path):
        print(f"directory '{abs_path}' not found", file=sys.stderr)
        return

    print(json.dumps(get_backup_info(abs_path, detailed=True)))


def abort_backup(args):
    """Terminate backup client process if it is running."""
    abs_path = os.path.join(BACKUPS_DIR, args.pit)
    backup_info = get_backup_info(abs_path, detailed=True)
    status = backup_info["status"]
    if status["pid"] and status["rc"] is None:
        os.kill(status["pid"], 15)
        for i in range(args.wait * 10):
            check_backup_process(status)
            if status["rc"] is None:
                time.sleep(0.1)
        print(json.dumps(get_backup_info(abs_path)))
        return
    print(json.dumps(backup_info))


def abort_backup_for_domain(args):
    backups = list_records(args, BACKUPS_DIR)
    for backup in backups:
        if not backup:
            continue
        if backup["domain"] == args.domain:
            args.pit = backup["id"]
            abort_backup(args)


def abort_backup_for_volume(args):
    backups = list_records(args, BACKUPS_DIR)
    for backup in backups:
        if not backup:
            continue
        if args.volume in backup["volumes"]:
            args.pit = backup["id"]
            abort_backup(args)


def abort_all_backups(args):
    backups = list_records(args, BACKUPS_DIR)
    for backup in backups:
        if not backup:
            continue
        args.pit = backup["id"]
        abort_backup(args)


def list_records(args, dir_path):
    """Print backup records information stored in the directory."""
    try:
        files = os.listdir(dir_path)
    except FileNotFoundError:
        print(f"directory '{BACKUPS_DIR}' not found", file=sys.stderr)
        return []

    backups = []
    for path in files:
        abs_path = os.path.join(dir_path, path)
        if not os.path.isdir(abs_path):
            continue

        backup_info = get_backup_info(abs_path)
        backups.append(backup_info)
    return backups


def list_backups(args):
    backups = list_records(args, BACKUPS_DIR)
    print(json.dumps(backups))


def list_archive(args):
    backups = list_records(args, ARCHIVE_DIR)
    print(json.dumps(backups))


class BackupContext:
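    """Mutable state of a single backup run.

    get_status() produces the dict that is written to the progress
    and result files.
    """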
    def __init__(self):
        self.args = None
        self.progress = 0
        self.progress_extra = None
        self.terminate = False
        self.backup_process = None
        self.rc = None
        self.error = None
        self.pid = os.getpid()
        self.cmdline = get_process_cmdline("self")
        self.started_at = datetime.datetime.utcnow().isoformat()
        self.finished_at = None

    def get_status(self):
        return {
            "pid": self.pid,
            "cmdline": self.cmdline,
            "rc": self.rc,
            "error": self.error,
            "progress": self.progress,
            "extra": self.progress_extra,
            "canceled": self.terminate,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
        }


@contextlib.contextmanager
def open_safe_write(path, mode):
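    """Write to a temporary file, fsync it and rename it into place.

    The rename is atomic, so readers never observe a partially
    written file.
    """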
    tmppath = path + "-tmp"
    with open(tmppath, mode) as f:
        yield f
        f.flush()
        os.fsync(f.fileno())
    os.rename(tmppath, path)


def write_status(path, status):
    with open_safe_write(path, "w") as f:
        json.dump(status, f)


def backup_images_using_qmp(args, context):
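    """Run backup jobs through the qemu-storage-daemon QMP socket.

    Starts one drive-backup job per image, polls progress into the
    context and, unless this is a dry run, renames the temporary
    bitmap to the final pit name afterwards.
    """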
    progress_file = args.progress_file
    client = UnixSocketQMPClient(args.socket_file)
    # wait up to 3 seconds for the QMP socket to appear
    client.connect(attempts=30, wait=0.1)

    targets = {}
    for target in args.targets:
        path, _sep, target = target.partition(":")
        targets[path] = {
            "target": target,
        }

    nodes = client.get_nodes(targets)
    new_pit = args.pit + "-new"
    client.start_backup(new_pit, args.from_pit, nodes.values())
    attempts = 4
    while True:
        try:
            progress = client.get_backup_progress(nodes)
            attempts = 4
        except QMPError:
            if not attempts:
                break
            client = UnixSocketQMPClient(args.socket_file)
            client.connect()
            attempts -= 1
            continue

        context.progress_extra = {"jobs": progress["jobs"]}
        lost = [node for node in nodes if progress[node] is None]
        if lost:
            context.rc = RETURN_CODE_ERROR
            context.error = f"Block Job disappeared for '{lost[0]}'"
            break
        backup_progress = progress["total"]
        if progress_file and backup_progress != context.progress:
            context.progress = backup_progress
            write_status(progress_file, context.get_status())

        if progress["total"] == 100:
            context.rc = RETURN_CODE_OK
            context.error = ""
            break
        for i in range(10):
            if context.terminate:
                break
            time.sleep(0.1)

    client.finish_backup(new_pit, args.from_pit, nodes)

    if args.dry_run:
        # New temporary bitmap is not stored on disk.
        # No need to do anything.
        return

    client.rename_bitmap(nodes, new_pit, args.pit)
    keep_bitmaps = {args.pit}
    client.remove_bitmaps(
        client.get_nodes(targets), keep_bitmaps, args.cleanup_prefix)


def start_qemu_process(args):
    """Start qemu storage process with a socket and requested images"""

    sock_path = args.socket_file
    qemu_cmdline = [
        QEMU_STORAGE_CMD, "--chardev",
        f"socket,path={sock_path},server=on,wait=off,id=char1",
        "--monitor", "chardev=char1"]
    for num, target in enumerate(args.targets):
        qcow2_path, _sep, target = target.partition(":")
        qemu_cmdline.extend([
            "--blockdev",
            f"driver=file,filename={qcow2_path},"
            f"node-name=target-storage-{num}",
            "--blockdev",
            f"driver=qcow2,file=target-storage-{num},"
            f"node-name=target-{num}"])
    return subprocess.Popen(
        qemu_cmdline, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE, text=True)


def backup_images(args):
    """Backup images that are not attached to any running domain"""
    try:
        context = BackupContext()
        setup_logger(args.log_file)
        redirect_streams()
        progress_file = args.progress_file
        if progress_file:
            write_status(progress_file, context.get_status())

        if os.path.exists(NODE_UPDATE_FLAG_FILE):
            context.rc = RETURN_CODE_CANCELED
            context.error = "Upgrade is in progress."
            write_status(args.result_file, context.get_status())
            return

        def signal_handler(_signo, _frame):
            if context.terminate:
                # already terminating
                return

            context.terminate = True
            if context.backup_process:
                context.backup_process.terminate()

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        context.args = args
        backup_process = context.backup_process = start_qemu_process(args)
        if context.terminate:
            # avoid race between signal and proc start
            backup_process.terminate()

        try:
            backup_images_using_qmp(args, context)
        except Exception:
            LOG.exception('Backup failed')
            if context.terminate:
                context.rc = RETURN_CODE_CANCELED
            else:
                context.rc = RETURN_CODE_ERROR

            context.error = traceback.format_exc()
            if context.terminate:
                context.error += "backup canceled"

        backup_process.terminate()
        process_error = ""
        try:
            _out, process_error = backup_process.communicate(timeout=15)
        except subprocess.TimeoutExpired:
            # qemu process took our volumes, we want it dead
            backup_process.kill()
            try:
                _out, process_error = backup_process.communicate(timeout=15)
            except Exception:
                process_error = traceback.format_exc()

        context.error = "\n".join(
            err for err in (context.error, process_error) if err)

        if progress_file:
            # for example context.terminate was updated
            write_status(progress_file, context.get_status())

        if args.result_file:
            context.finished_at = datetime.datetime.utcnow().isoformat()
            write_status(
                args.result_file, context.get_status())
    except Exception:
        context.error = "\n".join(
            err for err in (context.error, traceback.format_exc()) if err)
        write_status(args.result_file, context.get_status())


def backup_domain(args):
    """Backup images attached to a running domain"""
    # backup domain using backup client utility
    context = BackupContext()
    setup_logger(args.log_file)
    redirect_streams()
    progress_file = args.progress_file
    if progress_file:
        write_status(progress_file, context.get_status())

    if os.path.exists(NODE_UPDATE_FLAG_FILE):
        context.rc = RETURN_CODE_CANCELED
        context.error = "Upgrade is in progress."
        write_status(args.result_file, context.get_status())
        return

    def signal_handler(_signo, _frame):
        if context.terminate:
            # already terminating
            return

        context.terminate = True
        if context.backup_process:
            context.backup_process.terminate()

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    context.args = args
    cmd = [BACKUP_CLIENT_BIN] + context.args.extra
    backup_process = context.backup_process = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    if context.terminate:
        # avoid race between signal and proc start
        backup_process.terminate()

    stdout = backup_process.stdout
    new_progress = context.progress
    line = io.StringIO()
    while c := stdout.read(1):
        if c not in "\n\r":
            line.write(c)
            continue

        # example of progress line:  "Written 10%"
        line = line.getvalue()
        if line and line[-1] == "%":
            new_progress = line.split()[-1]
            try:
                new_progress = int(new_progress[:-1])
            except ValueError:
                new_progress = None

        if (new_progress and progress_file
                and context.progress != new_progress):
            context.progress = new_progress
            write_status(progress_file, context.get_status())

        line = io.StringIO()

    if progress_file:
        # for example context.terminate was updated
        write_status(progress_file, context.get_status())

    context.error = backup_process.communicate()[1]
    context.rc = backup_process.returncode
    if context.terminate:
        context.rc = RETURN_CODE_CANCELED

    if args.result_file:
        context.finished_at = datetime.datetime.utcnow().isoformat()
        write_status(
            args.result_file, context.get_status())


def wait_for_file(path, wait=0.1, attempts=50):
    for i in range(attempts):
        if os.path.exists(path):
            break
        if i == attempts - 1:
            raise FileNotFoundError(path)
        time.sleep(wait)


def _get_backup_commandline(pit, command):
    """Returns cmdline of same utility wrapped in systemd-run"""
    return [
        "systemd-run", "--slice", "vstorage-compute-storage",
        "--unit", f"cinder-backup-{pit}",
        sys.executable, os.path.abspath(__file__), command]


def start_backup_using_client(args):
    pit = f"{args.prefix}:{args.pit}" if args.prefix else args.pit
    from_pit = (f"{args.prefix}:{args.from_pit}"
                if args.prefix else args.from_pit)

    cmdargs = _get_backup_commandline(args.pit, "backup-domain")
    cmdargs += [
        "--log-file", args.log_file,
        "--result-file", args.result_file,
        "--progress-file", args.progress_file,
        "--", "append" if args.from_pit else "create",
        "--name", args.domain_name,
        "--cached",
        "--pit", from_pit if args.from_pit else pit,
    ]
    if args.from_pit:
        cmdargs += ["--rename-pit", pit]
    cmdargs += ["--cleanup-prefix",
                args.prefix or args.from_pit or ""]
    if args.dry_run:
        cmdargs += ["--dry-run"]
    for target in args.targets:
        disk_path, _sep, target_path = target.partition(":")
        cmdargs += [
            "--image", f"{disk_path}::{target_path}"
        ]

    # start the daemonized process
    p = subprocess.Popen(cmdargs, start_new_session=True)
    if args.progress_file:
        try:
            wait_for_file(args.progress_file)
        except FileNotFoundError:
            p.terminate()
            raise


def start_backup_using_qemu(args):
    pit = f"{args.prefix}:{args.pit}" if args.prefix else args.pit
    from_pit = (f"{args.prefix}:{args.from_pit}"
                if args.prefix else args.from_pit)

    cmdargs = _get_backup_commandline(args.pit, "backup-images")
    cmdargs += [
        "--log-file", args.log_file,
        "--result-file", args.result_file,
        "--progress-file", args.progress_file,
        "--socket-file", args.socket_file,
        pit,
    ]
    if args.domain:
        cmdargs += ["--domain", args.domain]
    if args.from_pit:
        cmdargs += ["--from-pit", from_pit]
    cmdargs += ["--cleanup-prefix", args.prefix or args.from_pit or ""]
    if args.dry_run:
        cmdargs += ["--dry-run"]

    cmdargs += args.targets

    # start the daemonized process
    p = subprocess.Popen(cmdargs, start_new_session=True)
    if args.progress_file:
        try:
            wait_for_file(args.progress_file)
        except FileNotFoundError:
            p.terminate()
            raise


def start_backup(args):
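    """Dispatch the backup.

    Use the backup client when the domain exists and is running,
    otherwise back the images up directly with qemu-storage-daemon.
    """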
    prepare_backup_info(args)
    if args.domain:
        conn = libvirt.open("qemu:///system")
        try:
            domain = conn.lookupByUUIDString(args.domain)
            args.domain_name = domain.name()
        except libvirt.libvirtError as ex:
            error_code = ex.get_error_code()
            if error_code != libvirt.VIR_ERR_NO_DOMAIN:
                raise
        else:
            if domain.isActive():
                return start_backup_using_client(args)

    return start_backup_using_qemu(args)


def list_bitmaps_using_client(args):
    cmd = [
        BACKUP_CLIENT_BIN, "bitmaps",
        "--name", args.domain_name,
        "--image", args.target]

    output = subprocess.check_output(cmd, text=True)
    data = json.loads(output)
    bitmaps = data["image"]["bitmaps"] or []
    print(json.dumps(bitmaps))


def list_bitmaps_using_qemu_img(args):
    cmd = [
        QEMU_IMG_CMD, "info", args.target,
        "--output", "json"]

    output = subprocess.check_output(cmd, text=True)
    data = json.loads(output)
    bitmaps = data.get("format-specific", {}).get(
        "data", {}).get("bitmaps", [])
    bitmaps = [bitmap["name"] for bitmap in bitmaps
               if "in-use" not in bitmap.get("flags", [])
               and bitmap.get("name")]
    print(json.dumps(bitmaps))


def list_bitmaps(args):
    if args.domain:
        conn = libvirt.open("qemu:///system")
        try:
            domain = conn.lookupByUUIDString(args.domain)
            args.domain_name = domain.name()
        except libvirt.libvirtError as ex:
            error_code = ex.get_error_code()
            if error_code != libvirt.VIR_ERR_NO_DOMAIN:
                raise
        else:
            if domain.isActive():
                return list_bitmaps_using_client(args)

    return list_bitmaps_using_qemu_img(args)


def remove_bitmaps_using_client(args):
    for pit in args.pits:
        cmd = [BACKUP_CLIENT_BIN, "remove_bitmap",
               "--name", args.domain_name,
               "--image", args.target,
               "--pit", pit]
        rv = subprocess.run(cmd)
        if rv.returncode:
            sys.exit(rv.returncode)
    sys.exit(0)


def remove_bitmaps_using_qemu_img(args):
    for pit in args.pits:
        cmd = [QEMU_IMG_CMD, "bitmap", args.target, "--remove", pit]
        rv = subprocess.run(cmd)
        if rv.returncode:
            sys.exit(rv.returncode)
    sys.exit(0)


def remove_bitmaps(args):
    if args.domain:
        conn = libvirt.open("qemu:///system")
        try:
            domain = conn.lookupByUUIDString(args.domain)
            args.domain_name = domain.name()
        except libvirt.libvirtError as ex:
            error_code = ex.get_error_code()
            if error_code != libvirt.VIR_ERR_NO_DOMAIN:
                raise
        else:
            if domain.isActive():
                return remove_bitmaps_using_client(args)

    return remove_bitmaps_using_qemu_img(args)


def prepare_backup_info(args):
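    """Create the backup record directory and store the call arguments.

    Also derives the per-backup progress, result, log and QMP socket
    paths inside that directory.
    """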
    params = vars(args)
    params.pop("run_command")

    backup_dir = os.path.join(BACKUPS_DIR, args.pit)
    os.mkdir(backup_dir)

    args_file = os.path.join(backup_dir, "args.json")
    write_status(args_file, params)

    args.progress_file = os.path.join(backup_dir, "progress.json")
    args.result_file = os.path.join(backup_dir, "result.json")
    args.log_file = os.path.join(backup_dir, "log")
    args.socket_file = os.path.join(backup_dir, "socket")


def remove_backup_record(args):
    abs_path = os.path.join(BACKUPS_DIR, args.pit)
    if not os.path.exists(abs_path):
        print(f"directory '{abs_path}' not found", file=sys.stderr)
        return

    backup_info = get_backup_info(abs_path)
    if backup_info:
        if backup_info["status"]["rc"] is None:
            print("backup is not finished", file=sys.stderr)
            sys.exit(1)
    shutil.rmtree(abs_path)


def archive_backup_record(args):
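    """Move a finished backup record into the archive directory.

    The record is renamed with a hex-encoded timestamp suffix so that
    archived names stay unique.
    """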
    abs_path = os.path.join(BACKUPS_DIR, args.pit)
    if not os.path.exists(abs_path):
        print(f"directory '{abs_path}' not found", file=sys.stderr)
        return

    backup_info = get_backup_info(abs_path)
    if not backup_info:
        print(f"backup record '{abs_path}' corrupted", file=sys.stderr)
        return

    postfix = format(int(time.time()), "x")
    archive_path = os.path.join(ARCHIVE_DIR, f"{args.pit}-{postfix}")
    os.rename(abs_path, archive_path)


def remove_archive(args):
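    """Remove archived backup records older than args.before seconds."""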
    current_time = time.time()
    try:
        files = os.listdir(ARCHIVE_DIR)
    except FileNotFoundError:
        print(f"directory '{ARCHIVE_DIR}' not found", file=sys.stderr)
        return

    for path in files:
        abs_path = os.path.join(ARCHIVE_DIR, path)
        if not os.path.isdir(abs_path):
            continue

        if os.path.islink(abs_path):
            os.unlink(abs_path)
            continue

        mtime = os.path.getmtime(abs_path)
        if current_time - mtime <= args.before:
            continue

        sub_files = os.listdir(abs_path)
        for file_path in sub_files:
            file_abs_path = os.path.join(abs_path, file_path)
            os.unlink(file_abs_path)
        os.rmdir(abs_path)


def main():
    parser = argparse.ArgumentParser()
    command_parser = parser.add_subparsers(dest="command")
    list_backups_parser = command_parser.add_parser("list")
    list_backups_parser.set_defaults(run_command=list_backups)

    list_archive_parser = command_parser.add_parser("list-archive")
    list_archive_parser.set_defaults(run_command=list_archive)

    remove_archive_parser = command_parser.add_parser("remove-archive")
    remove_archive_parser.set_defaults(run_command=remove_archive)
    remove_archive_parser.add_argument(
        "--before", metavar="<seconds>", type=int, default=60 * 60 * 24 * 7)

    backup_images_parser = command_parser.add_parser("backup-images")
    backup_images_parser.set_defaults(run_command=backup_images)
    backup_images_parser.add_argument("--domain", dest="domain")
    backup_images_parser.add_argument(
        "--progress-file", metavar="<PATH>")
    backup_images_parser.add_argument(
        "--result-file", metavar="<PATH>")
    backup_images_parser.add_argument(
        "--log-file", metavar="<PATH>")
    backup_images_parser.add_argument(
        "--socket-file", metavar="<PATH>")
    backup_images_parser.add_argument("--from-pit")
    backup_images_parser.add_argument("--cleanup-prefix", default="")
    backup_images_parser.add_argument(
        "--dry-run", action="store_true")
    backup_images_parser.add_argument("pit")
    backup_images_parser.add_argument("targets", nargs="+")

    backup_domain_parser = command_parser.add_parser("backup-domain")
    backup_domain_parser.set_defaults(run_command=backup_domain)
    backup_domain_parser.add_argument(
        "--progress-file", metavar="<PATH>")
    backup_domain_parser.add_argument(
        "--result-file", metavar="<PATH>")
    backup_domain_parser.add_argument(
        "--log-file", metavar="<PATH>")
    backup_domain_parser.add_argument(
        "extra", nargs="*", metavar="<BACKUP ARGS>")

    start_backup_parser = command_parser.add_parser("start-backup")
    start_backup_parser.set_defaults(run_command=start_backup)
    start_backup_parser.add_argument("--domain")
    start_backup_parser.add_argument("--volumes")
    start_backup_parser.add_argument("--from-pit", type=validate_name)
    start_backup_parser.add_argument("--prefix", default="")
    start_backup_parser.add_argument("--dry-run", action="store_true")
    start_backup_parser.add_argument("pit", type=validate_name)
    start_backup_parser.add_argument("targets", nargs="+")

    list_bitmaps_parser = command_parser.add_parser("list-bitmaps")
    list_bitmaps_parser.set_defaults(run_command=list_bitmaps)
    list_bitmaps_parser.add_argument("--domain", dest="domain")
    list_bitmaps_parser.add_argument("target")

    remove_bitmaps_parser = command_parser.add_parser("remove-bitmaps")
    remove_bitmaps_parser.set_defaults(run_command=remove_bitmaps)
    remove_bitmaps_parser.add_argument("--domain")
    remove_bitmaps_parser.add_argument("target")
    remove_bitmaps_parser.add_argument("pits", nargs="+")

    show_backup_parser = command_parser.add_parser("show-backup")
    show_backup_parser.set_defaults(run_command=show_backup)
    show_backup_parser.add_argument("pit")

    abort_backup_parser = command_parser.add_parser("abort-backup")
    abort_backup_parser.set_defaults(run_command=abort_backup)
    abort_backup_parser.add_argument(
        "--wait", dest="wait", type=int, default=10)
    abort_backup_parser.add_argument("pit", type=validate_name)

    abort_backup_domain_parser = command_parser.add_parser(
        "abort-backup-for-domain")
    abort_backup_domain_parser.set_defaults(
        run_command=abort_backup_for_domain)
    abort_backup_domain_parser.add_argument(
        "--wait", dest="wait", type=int, default=10)
    abort_backup_domain_parser.add_argument(
        "domain", type=validate_name)

    abort_backup_volume_parser = command_parser.add_parser(
        "abort-backup-for-volume")
    abort_backup_volume_parser.set_defaults(
        run_command=abort_backup_for_volume)
    abort_backup_volume_parser.add_argument(
        "--wait", dest="wait", type=int, default=10)
    abort_backup_volume_parser.add_argument(
        "volume", type=validate_name)

    abort_all_backups_parser = command_parser.add_parser(
        "abort-all-backups")
    abort_all_backups_parser.set_defaults(
        run_command=abort_all_backups)
    abort_all_backups_parser.add_argument(
        "--wait", dest="wait", type=int, default=60)

    remove_backup_parser = command_parser.add_parser("remove-backup-record")
    remove_backup_parser.set_defaults(run_command=remove_backup_record)
    remove_backup_parser.add_argument("pit", type=validate_name)

    archive_backup_parser = command_parser.add_parser("archive-backup-record")
    archive_backup_parser.set_defaults(run_command=archive_backup_record)
    archive_backup_parser.add_argument("pit", type=validate_name)

    args = parser.parse_args()
    args.run_command(args)


if __name__ == "__main__":
    main()
    sys.exit(0)
