#!/usr/bin/python3
# group: quick
#
# Test vz-cprsave/vz-cprload QMP commands with various devices
#
# Copyright (c) 2025 Virtuozzo International GmbH.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import logging
import os
import re
from tempfile import NamedTemporaryFile

import iotests

# Scratch disk image used by every test case: created in TestVzCpr.setUp and
# removed in TestVzCpr.tearDown.
disk_path = os.path.join(iotests.test_dir, 'disk')

# Matches a top-level ``path=``/``fd=`` option (not e.g. ``server.path=``).
# The value extends up to the next unescaped comma -- QEMU escapes a literal
# comma inside an option value as ``,,`` -- so any options that follow it are
# preserved instead of being swallowed by a greedy match.
sock_srv_re = re.compile(r'(?:^|(?<=,))(path|fd)=(?:[^,]|,,)*')


def sub_fd_cpr(cmdarg: str) -> str:
    """Rewrite a listening-socket argument for the CPR command line.

    Arguments that do not contain ``,server=on`` are returned unchanged.
    For listening sockets, the value of the ``path=`` or ``fd=`` option is
    replaced with ``vz-cpr-fd`` (the option key itself is kept); all other
    options, and their order, are left untouched.

    :param cmdarg: a single QEMU command-line argument.
    :return: the possibly rewritten argument.
    """
    if ',server=on' not in cmdarg:
        return cmdarg
    return sock_srv_re.sub(r'\g<1>=vz-cpr-fd', cmdarg)


class TestVzCpr(iotests.QMPTestCase):
    """Exercise vz-cprsave/vz-cprload against VMs with various devices.

    Each test case adds its device arguments to ``self.vm`` and then calls
    :meth:`_test_cpr`, which launches the VM, resumes it and runs two
    save/load cycles, the second one through a freshly created file.
    """

    def setUp(self) -> None:
        iotests.qemu_img_create('-f', iotests.imgfmt, disk_path, '64M')
        self.vm = iotests.VM()
        # The VM starts paused (-S); _test_cpr() resumes it with 'cont'.
        self.vm.add_args('-S', '-cpr-ram-rdonly')
        self.vm.add_drive(disk_path, opts='node-name=drive0-node',
                          interface='none')
        self.cprfile = self._new_cprfile()

    def tearDown(self) -> None:
        try:
            # Stop the VM even when the test failed before _test_cpr() could
            # shut it down.  NOTE(review): relies on shutdown() being a no-op
            # for a VM that is not running, as other iotests' tearDown do.
            self.vm.shutdown()
        finally:
            for path in (disk_path, self.cprfile):
                if os.path.isfile(path):
                    os.remove(path)

    @staticmethod
    def _new_cprfile() -> str:
        """Create an empty file in the test directory and return its path.

        The file object is closed right away (only the path is passed to
        vz-cprsave); delete=False keeps the file, callers remove it.
        """
        with NamedTemporaryFile(dir=iotests.test_dir, delete=False) as f:
            return f.name

    def _recreate_cprfile(self) -> None:
        """Remove the current self.cprfile and replace it with a new one."""
        if self.cprfile and os.path.isfile(self.cprfile):
            os.remove(self.cprfile)
        self.cprfile = self._new_cprfile()

    def _save_and_load(self, cpr_cmdline: str) -> None:
        """Run one vz-cprsave/vz-cprload cycle through self.cprfile."""
        self.assertTrue(os.path.isfile(self.cprfile))
        self.vm.qmp_check_log('vz-cprsave', file=self.cprfile,
                              mode='restart', cmdline=cpr_cmdline)
        self.vm.qmp_check_log('vz-cprload', file=self.cprfile)

    def _test_cpr(self, pre_cpr_hook=None) -> None:
        """Launch the VM and run two vz-cprsave/vz-cprload cycles.

        The command line handed to vz-cprsave is the VM's full argument list
        with listening sockets rewritten by sub_fd_cpr().  The VM is shut
        down at the end.

        :param pre_cpr_hook: optional zero-argument callable invoked after
            'cont' and before the first save (e.g. to alter the block graph).
        """
        self.vm.launch()
        cpr_cmdline = ' '.join(map(sub_fd_cpr, self.vm._qemu_full_args))

        self.vm.qmp_check_log('cont')

        if pre_cpr_hook is not None:
            pre_cpr_hook()

        self._save_and_load(cpr_cmdline)

        # Second cycle goes through a brand-new file.
        self._recreate_cprfile()
        self._save_and_load(cpr_cmdline)

        self.vm.shutdown()

    def test_virtio_blk_hd(self) -> None:
        self.vm.add_args('-device', 'virtio-blk-pci,drive=drive0')
        self._test_cpr()

    def test_virtio_scsi_hd(self) -> None:
        self.vm.add_args('-device', 'virtio-scsi-pci')
        self.vm.add_args('-device', 'scsi-hd,drive=drive0')
        self._test_cpr()

    def test_usb_hd(self) -> None:
        self.vm.add_args('-device', 'nec-usb-xhci,id=usb')
        self.vm.add_args('-device', 'usb-storage,drive=drive0,bus=usb.0')
        self._test_cpr()

    def test_virtio_net(self) -> None:
        self.vm.add_args('-device', 'virtio-net-pci')
        self._test_cpr()

    def test_two_network_ifs(self) -> None:
        self.vm.add_args('-device', 'virtio-net-pci')
        self.vm.add_args('-device', 'e1000')
        self._test_cpr()

    def test_usb_tablet(self) -> None:
        self.vm.add_args('-device', 'nec-usb-xhci,id=usb',
                         '-device', 'usb-tablet,bus=usb.0')
        self._test_cpr()

    def test_ide_cd(self) -> None:
        self.vm.add_args('-device', 'ide-cd,drive=drive0')
        self._test_cpr()

    def test_scsi_cd(self) -> None:
        self.vm.add_args('-device', 'virtio-scsi-pci')
        self.vm.add_args('-device', 'scsi-cd,drive=drive0')
        self._test_cpr()

    def test_floppy(self) -> None:
        self.vm.add_args('-device', 'floppy,drive=drive0')
        self._test_cpr()

    def test_null_block(self) -> None:
        self.vm.add_args('-blockdev',
                         'driver=null-co,read-zeroes=on,node-name=null')
        self.vm.add_args('-device', 'virtio-blk-pci,drive=null')
        self._test_cpr()

    def test_chardev_serial(self) -> None:
        self.vm.add_args('-chardev', 'null,id=char0',
                         '-device', 'isa-serial,chardev=char0')
        self._test_cpr()

    def test_chardev_srv_unix(self) -> None:
        """A listening unix-socket chardev (rewritten by sub_fd_cpr)."""
        chr_sock = iotests.file_path('chr_sock')
        self.vm.add_args('-chardev',
                         f'socket,id=char0,path={chr_sock},server=on,wait=off')
        self._test_cpr()

    def test_nbd_attached_hd(self) -> None:
        """A virtio-blk disk backed by a qemu-nbd export on a unix socket."""
        nbd_img = os.path.join(iotests.test_dir, 'nbd_disk')
        iotests.qemu_img_create('-f', 'qcow2', nbd_img, '64M')

        nbd_sock = iotests.file_path('nbd.sock', base_dir=iotests.sock_dir)
        try:
            with iotests.change_log_level('qemu.iotests.diff_io',
                                          level=logging.WARNING), \
                 iotests.qemu_nbd_popen('--persistent', '--socket', nbd_sock,
                                        '--export-name=default', '-f', 'qcow2',
                                        nbd_img):
                self.vm.add_args('-blockdev', 'driver=nbd,server.type=unix,' +
                                 f'server.path={nbd_sock},export=default,'
                                 'node-name=nbd0')
                self.vm.add_args('-device', 'virtio-blk-pci,drive=nbd0')
                self._test_cpr()
        finally:
            # Clean up even if the test failed.
            if os.path.isfile(nbd_img):
                os.remove(nbd_img)
            if os.path.exists(str(nbd_sock)):
                os.remove(str(nbd_sock))

    def test_live_external_snapshot(self) -> None:
        """Take an external snapshot of drive0 while running, then do CPR."""
        snapshot_path = os.path.join(iotests.test_dir, 'snapshot.qcow2')

        def create_snapshot() -> None:
            self.vm.qmp_check_log('blockdev-snapshot-sync',
                                  node_name='drive0-node',
                                  snapshot_file=snapshot_path, format='qcow2',
                                  snapshot_node_name='drive0-snap-node')

        try:
            self._test_cpr(pre_cpr_hook=create_snapshot)
        finally:
            # Clean up even if the test failed.
            if os.path.isfile(snapshot_path):
                os.remove(snapshot_path)

    def test_vnc_server(self) -> None:
        self.vm.add_args('-vnc', ':1')
        self._test_cpr()


if __name__ == '__main__':
    # Turn on iotests logging, then run the test cases; only qcow2 images
    # on the file protocol are supported.
    iotests.activate_logging()
    iotests.main(
        supported_fmts=['qcow2'],
        supported_protocols=['file'],
    )
