From 04b12955e29740047af11389dd74c5aba609422c Mon Sep 17 00:00:00 2001 From: Freya Gustavsson Date: Tue, 25 Jun 2024 11:54:50 +0000 Subject: [PATCH] fix: prevent duplicate backups It was possible for multiple packages to depend on the same file. As we fetch the files for each package that needs to be backed up this caused issues when rolling back a system as we try to restore it twice. This fixes that issue by making sure we don't backup to the same path multiple times Signed-off-by: Freya Gustavsson --- .../actions/pre_ponr_changes/backup_system.py | 6 ++++- convert2rhel/backup/__init__.py | 12 +++++++++- convert2rhel/backup/files.py | 16 +++++++++---- convert2rhel/unit_tests/__init__.py | 11 +++++++++ convert2rhel/unit_tests/backup/backup_test.py | 19 ++++++++++++++- convert2rhel/unit_tests/backup/files_test.py | 24 ++++++++++++++++--- convert2rhel/unit_tests/subscription_test.py | 2 +- 7 files changed, 79 insertions(+), 11 deletions(-) diff --git a/convert2rhel/actions/pre_ponr_changes/backup_system.py b/convert2rhel/actions/pre_ponr_changes/backup_system.py index d5000c1298..097f3c731a 100644 --- a/convert2rhel/actions/pre_ponr_changes/backup_system.py +++ b/convert2rhel/actions/pre_ponr_changes/backup_system.py @@ -170,7 +170,10 @@ def run(self): def _get_changed_package_files(self): """Get the output from rpm -Va command from during resolving system info to get changes made to package files. - Return them as a list of dict, for example: + + The dict itself is unique and does not have duplicate entries + + :return dict: Return them as a list of dict, for example: [{"status":"S5T", "file_type":"c", "path":"/etc/yum.repos.d/CentOS-Linux-AppStream.repo"}] """ data = [] @@ -197,6 +200,7 @@ def _get_changed_package_files(self): for line in lines: parsed_line = self._parse_line(line.strip()) + # We first check that it has a path and status, otherwise nothing to backup if parsed_line["path"] and parsed_line["status"]: data.append(parsed_line) diff --git a/convert2rhel/backup/__init__.py b/convert2rhel/backup/__init__.py index 86db6b0a4e..73ea984cf9 100644 --- a/convert2rhel/backup/__init__.py +++ b/convert2rhel/backup/__init__.py @@ -58,7 +58,7 @@ class BackupController: """ def __init__(self): - self._restorables = [] + self._restorables = [] # type: list[RestorableChange] self._rollback_failures = [] def push(self, restorable): @@ -70,6 +70,13 @@ def push(self, restorable): if not isinstance(restorable, RestorableChange): raise TypeError("`%s` is not a RestorableChange object" % restorable) + # Check if the restorable is already backed up + # if it is, we skip it + for r in self._restorables: + if r == restorable: + loggerinst.debug("Skipping: {} has already been backed up".format(restorable.__class__.__name__)) + return + restorable.enable() self._restorables.append(restorable) @@ -150,6 +157,9 @@ def rollback_failures(self): """ return self._rollback_failures + def __len__(self): + return len(self._restorables) + @six.add_metaclass(abc.ABCMeta) class RestorableChange: diff --git a/convert2rhel/backup/files.py b/convert2rhel/backup/files.py index 2e8bf4eb19..8ef936cbb4 100644 --- a/convert2rhel/backup/files.py +++ b/convert2rhel/backup/files.py @@ -45,7 +45,7 @@ def __init__(self, filepath): raise TypeError("Path must be a file not a directory.") self.filepath = filepath - self._backup_path = None + self.backup_path = None def enable(self): """Save current version of a file""" @@ -57,8 +57,8 @@ def enable(self): if os.path.isfile(self.filepath): try: backup_path = self._hash_backup_path() + self.backup_path = backup_path shutil.copy2(self.filepath, backup_path) - self._backup_path = backup_path loggerinst.debug("Copied %s to %s." % (self.filepath, backup_path)) except (OSError, IOError) as err: # IOError for py2 and OSError for py3 @@ -129,10 +129,10 @@ def restore(self, rollback=True): return # Possible exceptions will be handled in the BackupController - shutil.copy2(self._backup_path, self.filepath) + shutil.copy2(self.backup_path, self.filepath) if rollback: # Remove the backed up file only when processing rollback - os.remove(self._backup_path) + os.remove(self.backup_path) if rollback: loggerinst.info("File %s restored." % self.filepath) @@ -150,6 +150,14 @@ def remove(self): except (OSError, IOError): loggerinst.debug("Couldn't remove restored file %s" % self.filepath) + def __eq__(self, value): + if hash(self) == hash(value): + return True + return False + + def __hash__(self): + return hash(self.filepath) if self.filepath else super(RestorableFile, self).__hash__() + class MissingFile(RestorableChange): """ diff --git a/convert2rhel/unit_tests/__init__.py b/convert2rhel/unit_tests/__init__.py index 095ef6bfaf..e1b7154d2a 100644 --- a/convert2rhel/unit_tests/__init__.py +++ b/convert2rhel/unit_tests/__init__.py @@ -869,6 +869,17 @@ def restore(self): super(MinimalRestorable, self).restore() +class FilePathRestorable(MinimalRestorable): + def __init__(self, filepath=None): + self.backup_path = filepath + super(FilePathRestorable, self).__init__() + + def __eq__(self, value): + if self.backup_path: + return self.backup_path == value.backup_path + return super(FilePathRestorable, self).__eq__(value) + + class ErrorOnRestoreRestorable(MinimalRestorable): def __init__(self, exception=None): self.exception = exception or Exception() diff --git a/convert2rhel/unit_tests/backup/backup_test.py b/convert2rhel/unit_tests/backup/backup_test.py index cf24d3aed8..d50c929189 100644 --- a/convert2rhel/unit_tests/backup/backup_test.py +++ b/convert2rhel/unit_tests/backup/backup_test.py @@ -3,7 +3,7 @@ import pytest from convert2rhel import backup -from convert2rhel.unit_tests import ErrorOnRestoreRestorable, MinimalRestorable +from convert2rhel.unit_tests import ErrorOnRestoreRestorable, FilePathRestorable, MinimalRestorable @pytest.fixture @@ -29,6 +29,23 @@ def test_pop(self, backup_controller, restorable): assert popped_restorable is restorable assert restorable.called["restore"] == 1 + def test_backup_same_paths(self, backup_controller): + restorable1 = FilePathRestorable("samepath") + restorable2 = FilePathRestorable("samepath") + + backup_controller.push(restorable1) + backup_controller.push(restorable2) + + # If we backup the same filepath we should ignore the next one + assert restorable1.backup_path == restorable2.backup_path + + # Same path on both restorables means only one gets backed up + assert len(backup_controller) == 1 + assert len(backup_controller.pop_all()) == 1 + + assert restorable1.called["restore"] == 1 + assert restorable2.called["restore"] == 0 + def test_pop_multiple(self, backup_controller): restorable1 = MinimalRestorable() restorable2 = MinimalRestorable() diff --git a/convert2rhel/unit_tests/backup/files_test.py b/convert2rhel/unit_tests/backup/files_test.py index b74679be24..f75de4dc9f 100644 --- a/convert2rhel/unit_tests/backup/files_test.py +++ b/convert2rhel/unit_tests/backup/files_test.py @@ -202,7 +202,7 @@ def test_restorable_file_restore(self, tmpdir, caplog, messages, enabled, rollba file_backup = RestorableFile(orig_file_path) file_backup.enabled = enabled - file_backup._backup_path = backedup_file_path + file_backup.backup_path = backedup_file_path file_backup.restore(rollback=rollback) # Check if the correct messages printed @@ -233,7 +233,7 @@ def test_restorable_file_missing_backup(self, tmpdir, pretend_os): file_backup = RestorableFile(orig_file_path) file_backup.enabled = True - file_backup._backup_path = backedup_file_path + file_backup.backup_path = backedup_file_path # Check if the exception is raised when the file is missing in the backup folder with pytest.raises(OSError): @@ -252,7 +252,7 @@ def test_restorable_file_missing_backup(self, tmpdir, pretend_os): file_backup = RestorableFile(orig_file_path) file_backup.enabled = True - file_backup._backup_path = backedup_file_path + file_backup.backup_path = backedup_file_path # Check if the exception is raised when the file is missing in the backup folder with pytest.raises(IOError): @@ -307,6 +307,24 @@ def test_hash_backup_path(self, filepath, tmpdir, monkeypatch): assert os.path.exists(os.path.dirname(result)) assert result == expected + @pytest.mark.parametrize( + ("filepath1", "filepath2"), + ( + ("/test.txt", "/test.txt"), + ("/another/directory/file.txt", "/test.txt"), + ), + ) + def test___eq___works(self, filepath1, filepath2, tmpdir, monkeypatch): + backup_dir = str(tmpdir) + monkeypatch.setattr(files, "BACKUP_DIR", backup_dir) + file1 = RestorableFile(filepath1) + file2 = RestorableFile(filepath2) + + if filepath1 == filepath2: + assert file1 == file2 + else: + assert file1 != file2 + class TestMissingFile: @pytest.mark.parametrize( diff --git a/convert2rhel/unit_tests/subscription_test.py b/convert2rhel/unit_tests/subscription_test.py index 8c58e6a7c4..961580a4cb 100644 --- a/convert2rhel/unit_tests/subscription_test.py +++ b/convert2rhel/unit_tests/subscription_test.py @@ -405,7 +405,7 @@ def test_register_system_os_release_fail(self, monkeypatch, tmpdir, caplog): os_release_file.enable() # Remove the file from the backup and orig path, so there will be failure during restoring the file - os.remove(os_release_file._backup_path) + os.remove(os_release_file.backup_path) os.remove(str(os_release_path)) ### Test the register system