diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 2d1a3b806..c78e4f818 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -86,6 +86,61 @@ def test___init__(self): # Assert assert instance.tables == {} assert instance.relationships == [] + assert instance._multi_table_updated is None + assert instance._multi_table_updated is None + + def test__check_metadata_updated_single_metadata_updated(self): + """Test ``_check_metadata_updated`` when a single table metadata has been updated.""" + # Setup + instance = MultiTableMetadata() + instance.tables['table_1'] = Mock() + instance.tables['table_2'] = Mock() + instance.tables['table_1']._updated = True + instance.tables['table_2']._updated = False + + # Run + instance._check_updated_flag() + + # Assert + assert instance._multi_table_updated is None + assert instance._updated is True + + def test__check_metadata_updated_multi_metadata_updated(self): + """Test ``_check_metadata_updated`` method when multi table metadata has been updated.""" + # Setup + instance = MultiTableMetadata() + instance.tables['table_1'] = Mock() + instance.tables['table_2'] = Mock() + instance.tables['table_1']._updated = False + instance.tables['table_2']._updated = False + instance._multi_table_updated = True + + # Run + instance._check_updated_flag() + + # Assert + assert instance._multi_table_updated is True + assert instance._updated is True + + def test__reset_updated_flag(self): + """Test the ``_reset_updated_flag`` method.""" + # Setup + instance = MultiTableMetadata() + instance.tables['table_1'] = Mock() + instance.tables['table_2'] = Mock() + instance.tables['table_1']._updated = False + instance.tables['table_2']._updated = True + instance._multi_table_updated = True + instance._updated = True + + # Run + instance._reset_updated_flag() + + # Assert + assert instance._multi_table_updated is False + assert instance._updated is False + assert instance.tables['table_1']._updated is False + assert instance.tables['table_2']._updated is False def test__validate_missing_relationship_keys_foreign_key(self): """Test the ``_validate_missing_relationship_keys`` method of ``MultiTableMetadata``. @@ -587,6 +642,7 @@ def test_add_relationship(self): instance._validate_relationship_does_not_exist.assert_called_once_with( 'users', 'id', 'sessions', 'user_id' ) + assert instance._multi_table_updated is True def test_add_relationship_child_key_is_primary_key(self): """Test that passing a primary key as ``child_foreign_key`` crashes.""" @@ -692,6 +748,7 @@ def test_remove_relationship(self): 'child_foreign_key': 'session_id', } ] + assert instance._multi_table_updated is True @patch('sdv.metadata.multi_table.warnings') def test_remove_relationship_relationship_not_found(self, warnings_mock): @@ -786,6 +843,7 @@ def test_remove_primary_key(self, logger_mock): "'table' was removed." ) logger_mock.info.assert_has_calls([call(msg1), call(msg2)]) + assert instance._multi_table_updated is True def test__validate_column_relationships_foreign_keys(self): """Test ``_validate_column_relationships_foriegn_keys.""" @@ -1363,6 +1421,7 @@ def test_add_table(self, table_metadata_mock): # Assert assert instance.tables == {'users': table_metadata_mock.return_value} + assert instance._multi_table_updated is True def test_add_table_empty_string(self): """Test that the method raises an error if the table name is an empty string.""" @@ -2686,6 +2745,7 @@ def test_save_to_json(self, tmp_path): """ # Setup instance = MultiTableMetadata() + instance._reset_updated_flag = Mock() # Run / Assert file_name = tmp_path / 'multitable.json' @@ -2695,6 +2755,8 @@ def test_save_to_json(self, tmp_path): saved_metadata = json.load(multi_table_file) assert saved_metadata == instance.to_dict() + instance._reset_updated_flag.assert_called_once() + def test__convert_relationships(self): """Test the ``_convert_relationships`` method. diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index eb9eb2af4..83e6f55c5 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -97,7 +97,8 @@ def test__print(self, mock_print): # Assert mock_print.assert_called_once_with('Fitting', end='') - def test___init__(self): + @patch('sdv.multi_table.base.BaseMultiTableSynthesizer._check_metadata_updated') + def test___init__(self, mock_check_metadata_updated): """Test that when creating a new instance this sets the defaults. Test that the metadata object is being stored and also being validated. Afterwards, this @@ -117,6 +118,7 @@ def test___init__(self): assert isinstance(instance._table_synthesizers['upravna_enota'], GaussianCopulaSynthesizer) assert instance._table_parameters == defaultdict(dict) instance.metadata.validate.assert_called_once_with() + mock_check_metadata_updated.assert_called_once() def test___init___synthesizer_kwargs_deprecated(self): """Test that the ``synthesizer_kwargs`` method is deprecated.""" @@ -132,6 +134,27 @@ def test___init___synthesizer_kwargs_deprecated(self): with pytest.warns(FutureWarning, match=warn_message): BaseMultiTableSynthesizer(metadata, synthesizer_kwargs={}) + def test__check_metadata_updated(self): + """Test the ``_check_metadata_updated`` method.""" + # Setup + instance = Mock() + instance.metadata = Mock() + instance.metadata._check_updated_flag = Mock() + instance.metadata._reset_updated_flag = Mock() + instance.metadata._updated = True + + # Run + expected_message = re.escape( + "We strongly recommend saving the metadata using 'save_to_json' for replicability" + ' in future SDV versions.' + ) + with pytest.warns(UserWarning, match=expected_message): + BaseMultiTableSynthesizer._check_metadata_updated(instance) + + # Assert + instance.metadata._check_updated_flag.assert_called_once() + instance.metadata._reset_updated_flag.assert_called_once() + def test_set_address_columns(self): """Test the ``set_address_columns`` method.""" # Setup @@ -836,6 +859,7 @@ def test_fit(self): # Assert instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) + instance._check_metadata_updated.assert_called_once() def test_reset_sampling(self): """Test that ``reset_sampling`` resets the numpy seed and the synthesizers."""