#!/usr/bin/python3
# group: quick
#
# Test vz-cprsave/vz-cprload QMP commands with various devices
#
# Copyright (c) 2025 Virtuozzo International GmbH.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import logging
import os
import re
from tempfile import NamedTemporaryFile

import iotests

# Scratch image used by every test case: created in setUp(), removed in tearDown().
disk_path: str = os.path.join(iotests.test_dir, 'disk')

# Address option of a QEMU option string: 'path=...' or 'fd=...' at the start
# of the string or right after a separating comma.  The value ends at the first
# comma that is not part of an escaped ',,' pair, so neighbouring options are
# never swallowed by the match.
sock_srv_re = re.compile(r'(^|,)(path|fd)=(?:[^,]|,,)*')


def sub_fd_cpr(cmdarg: str) -> str:
    """Rewrite a server socket's address option for the CPR command line.

    For an argument containing ``,server=on``, the value of its ``path=`` or
    ``fd=`` option is replaced with ``vz-cpr-fd``; every other option, in
    whatever order it appears, is kept unchanged.  Arguments without
    ``,server=on`` are returned as is.

    NOTE(review): presumably ``vz-cpr-fd`` tells the restarted QEMU to reuse
    the already listening descriptor -- confirm against the vz-cpr code.
    """
    if ',server=on' not in cmdarg:
        return cmdarg
    return sock_srv_re.sub(r'\g<1>\g<2>=vz-cpr-fd', cmdarg)


class TestVzCpr(iotests.QMPTestCase):
    """Exercise vz-cprsave/vz-cprload against a range of VM configurations.

    Each test_* method adds device/backend arguments to the VM prepared in
    setUp() and then calls _test_cpr(), which launches the VM, resumes it and
    runs two vz-cprsave/vz-cprload cycles before shutting it down.
    """

    def setUp(self) -> None:
        iotests.qemu_img_create('-f', iotests.imgfmt, disk_path, '64M')
        self.vm = iotests.VM()
        # Start paused (-S); _test_cpr() issues 'cont' itself.
        self.vm.add_args('-S', '-cpr-ram-rdonly')
        self.vm.add_drive(disk_path, opts='node-name=drive0-node',
                          interface='none')
        self.cprfile = self._new_cprfile()

    def tearDown(self) -> None:
        # _test_cpr() only shuts the VM down on its success path, so stop it
        # here too: otherwise a failing test leaks the QEMU process and the
        # files below are removed while still open.  QEMUMachine.shutdown()
        # returns early when the VM is not launched, so this is harmless
        # after a successful run.
        self.vm.shutdown()
        for path in (disk_path, self.cprfile):
            if os.path.isfile(path):
                os.remove(path)

    @staticmethod
    def _new_cprfile() -> str:
        """Create an empty file in the test directory and return its path."""
        # delete=False keeps the file after the handle is closed; the
        # with-block guarantees the handle really is closed.
        with NamedTemporaryFile(dir=iotests.test_dir, delete=False) as tmp:
            return tmp.name

    def _recreate_cprfile(self) -> None:
        """Replace self.cprfile with a freshly created empty file."""
        if self.cprfile and os.path.isfile(self.cprfile):
            os.remove(self.cprfile)
        self.cprfile = self._new_cprfile()

    def _cpr_cycle(self, cpr_cmdline: str) -> None:
        """Run one vz-cprsave (mode 'restart') followed by vz-cprload.

        :param cpr_cmdline: command line passed to vz-cprsave.
        """
        self.assertTrue(os.path.isfile(self.cprfile))
        self.vm.qmp_check_log('vz-cprsave', file=self.cprfile,
                              mode='restart', cmdline=cpr_cmdline)
        self.vm.qmp_check_log('vz-cprload', file=self.cprfile)

    def _test_cpr(self, pre_cpr_hook=None) -> None:
        """Launch the VM and run two CPR save/load cycles on it.

        :param pre_cpr_hook: optional zero-argument callable run after 'cont'
            and before the CPR command line is captured, so it may issue QMP
            commands and/or adjust self.vm._qemu_full_args.
        """
        self.vm.launch()

        self.vm.qmp_check_log('cont')

        if pre_cpr_hook is not None:
            pre_cpr_hook()

        # Command line for the new QEMU on cpr-exec: the current arguments,
        # with server=on socket addresses rewritten by sub_fd_cpr().
        cpr_cmdline = ' '.join(map(sub_fd_cpr, self.vm._qemu_full_args))

        self._cpr_cycle(cpr_cmdline)
        # The second cycle gets a brand new state file.
        self._recreate_cprfile()
        self._cpr_cycle(cpr_cmdline)

        self.vm.shutdown()

    def test_virtio_blk_hd(self) -> None:
        self.vm.add_args('-device', 'virtio-blk-pci,drive=drive0')
        self._test_cpr()

    def test_virtio_scsi_hd(self) -> None:
        self.vm.add_args('-device', 'virtio-scsi-pci')
        self.vm.add_args('-device', 'scsi-hd,drive=drive0')
        self._test_cpr()

    def test_usb_hd(self) -> None:
        self.vm.add_args('-device', 'nec-usb-xhci,id=usb')
        self.vm.add_args('-device', 'usb-storage,drive=drive0,bus=usb.0')
        self._test_cpr()

    def test_virtio_net(self) -> None:
        self.vm.add_args('-device', 'virtio-net-pci')
        self._test_cpr()

    def test_two_network_ifs(self) -> None:
        self.vm.add_args('-device', 'virtio-net-pci')
        self.vm.add_args('-device', 'e1000')
        self._test_cpr()

    def test_usb_tablet(self) -> None:
        self.vm.add_args('-device', 'nec-usb-xhci,id=usb',
                         '-device', 'usb-tablet,bus=usb.0')
        self._test_cpr()

    def test_ide_cd(self) -> None:
        self.vm.add_args('-device', 'ide-cd,drive=drive0')
        self._test_cpr()

    def test_scsi_cd(self) -> None:
        self.vm.add_args('-device', 'virtio-scsi-pci')
        self.vm.add_args('-device', 'scsi-cd,drive=drive0')
        self._test_cpr()

    def test_floppy(self) -> None:
        self.vm.add_args('-device', 'floppy,drive=drive0')
        self._test_cpr()

    def test_null_block(self) -> None:
        self.vm.add_args('-blockdev', 'driver=null-co,read-zeroes=on,node-name=null')
        self.vm.add_args('-device', 'virtio-blk-pci,drive=null')
        self._test_cpr()

    def test_chardev_serial(self) -> None:
        self.vm.add_args('-chardev', 'null,id=char0',
                         '-device', 'isa-serial,chardev=char0')
        self._test_cpr()

    def test_chardev_srv_unix(self) -> None:
        chr_sock = iotests.file_path('chr_sock')
        self.vm.add_args('-chardev',
                         f'socket,id=char0,path={chr_sock},server=on,wait=off')
        self._test_cpr()

    def test_nbd_attached_hd(self) -> None:
        nbd_img = os.path.join(iotests.test_dir, 'nbd_disk')
        iotests.qemu_img_create('-f', 'qcow2', nbd_img, '64M')

        nbd_sock = iotests.file_path('nbd.sock', base_dir=iotests.sock_dir)
        # try/finally so the image and socket are removed even on failure.
        try:
            with iotests.change_log_level('qemu.iotests.diff_io',
                                          level=logging.WARNING), \
                 iotests.qemu_nbd_popen('--persistent', '--socket', nbd_sock,
                                        '--export-name=default', '-f', 'qcow2',
                                        nbd_img):
                self.vm.add_args('-blockdev', 'driver=nbd,server.type=unix,' +
                                 f'server.path={nbd_sock},export=default,node-name=nbd0')
                self.vm.add_args('-device', 'virtio-blk-pci,drive=nbd0')
                self._test_cpr()
        finally:
            if os.path.isfile(nbd_img):
                os.remove(nbd_img)
            if os.path.exists(str(nbd_sock)):
                os.remove(str(nbd_sock))

    def test_live_external_snapshot(self) -> None:
        snapshot_path = os.path.join(iotests.test_dir, 'snapshot.qcow2')

        def create_snapshot() -> None:
            self.vm.qmp_check_log('blockdev-snapshot-sync', node_name='drive0-node',
                                  snapshot_file=snapshot_path, format='qcow2',
                                  snapshot_node_name='drive0-snap-node')

        # try/finally so the overlay image is removed even on failure.
        try:
            self._test_cpr(pre_cpr_hook=create_snapshot)
        finally:
            if os.path.isfile(snapshot_path):
                os.remove(snapshot_path)

    def test_vnc_server(self) -> None:
        self.vm.add_args('-vnc', ':1')
        self._test_cpr()

    def test_vcpu_hotplug(self) -> None:
        def hotplug_vcpu() -> None:
            cpus = self.vm.qmp('query-hotpluggable-cpus')['return']
            self.assertEqual(len(cpus), 2)
            # Entries without 'qom-path' are slots with no CPU plugged in yet.
            cpu = next(c for c in cpus if 'qom-path' not in c)
            self.vm.qmp_check_log('device_add', driver=cpu['type'],
                                  id='cpu1', **cpu['props'])

            # Change '-smp 1,maxcpus=2' to '-smp 2' for new QEMU on cpr-exec
            args = list(self.vm._qemu_full_args)
            args[args.index('-smp') + 1] = '2'
            self.vm._qemu_full_args = tuple(args)

        self.vm.add_args('-smp', '1,maxcpus=2')
        self._test_cpr(pre_cpr_hook=hotplug_vcpu)

    def test_memory_hotplug(self) -> None:
        def hotplug_dimm() -> None:
            self.vm.qmp_check_log('object-add', qom_type='memory-backend-ram',
                                  id='mem1', size=1024 * 1024 * 1024)
            self.vm.qmp_check_log('device_add', driver='pc-dimm',
                                  id='dimm1', memdev='mem1')

            # Add the DIMM to _qemu_full_args so the new QEMU on cpr-exec
            # has the device present when loading migration state
            args = list(self.vm._qemu_full_args)
            args += ['-object', 'memory-backend-ram,id=mem1,size=1G',
                     '-device', 'pc-dimm,id=dimm1,memdev=mem1']
            self.vm._qemu_full_args = tuple(args)

        self.vm.add_args('-m', '1G,slots=2,maxmem=2G',
                         '-object', 'memory-backend-ram,id=ram0,size=1G',
                         '-numa', 'node,nodeid=0,memdev=ram0')
        self._test_cpr(pre_cpr_hook=hotplug_dimm)


if __name__ == '__main__':
    # Enable iotests logging before handing control to the test runner.
    iotests.activate_logging()
    iotests.main(
        supported_fmts=["qcow2"],
        supported_protocols=["file"],
    )
