Skip to content

Commit

Permalink
unit tests multi table
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 8, 2024
1 parent 4eb291b commit 5465932
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
62 changes: 62 additions & 0 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,61 @@ def test___init__(self):
# Assert
assert instance.tables == {}
assert instance.relationships == []
assert instance._multi_table_updated is None
assert instance._multi_table_updated is None

def test__check_metadata_updated_single_metadata_updated(self):
"""Test ``_check_metadata_updated`` when a single table metadata has been updated."""
# Setup
instance = MultiTableMetadata()
instance.tables['table_1'] = Mock()
instance.tables['table_2'] = Mock()
instance.tables['table_1']._updated = True
instance.tables['table_2']._updated = False

# Run
instance._check_updated_flag()

# Assert
assert instance._multi_table_updated is None
assert instance._updated is True

def test__check_metadata_updated_multi_metadata_updated(self):
"""Test ``_check_metadata_updated`` method when multi table metadata has been updated."""
# Setup
instance = MultiTableMetadata()
instance.tables['table_1'] = Mock()
instance.tables['table_2'] = Mock()
instance.tables['table_1']._updated = False
instance.tables['table_2']._updated = False
instance._multi_table_updated = True

# Run
instance._check_updated_flag()

# Assert
assert instance._multi_table_updated is True
assert instance._updated is True

def test__reset_updated_flag(self):
"""Test the ``_reset_updated_flag`` method."""
# Setup
instance = MultiTableMetadata()
instance.tables['table_1'] = Mock()
instance.tables['table_2'] = Mock()
instance.tables['table_1']._updated = False
instance.tables['table_2']._updated = True
instance._multi_table_updated = True
instance._updated = True

# Run
instance._reset_updated_flag()

# Assert
assert instance._multi_table_updated is False
assert instance._updated is False
assert instance.tables['table_1']._updated is False
assert instance.tables['table_2']._updated is False

def test__validate_missing_relationship_keys_foreign_key(self):
"""Test the ``_validate_missing_relationship_keys`` method of ``MultiTableMetadata``.
Expand Down Expand Up @@ -587,6 +642,7 @@ def test_add_relationship(self):
instance._validate_relationship_does_not_exist.assert_called_once_with(
'users', 'id', 'sessions', 'user_id'
)
assert instance._multi_table_updated is True

def test_add_relationship_child_key_is_primary_key(self):
"""Test that passing a primary key as ``child_foreign_key`` crashes."""
Expand Down Expand Up @@ -692,6 +748,7 @@ def test_remove_relationship(self):
'child_foreign_key': 'session_id',
}
]
assert instance._multi_table_updated is True

@patch('sdv.metadata.multi_table.warnings')
def test_remove_relationship_relationship_not_found(self, warnings_mock):
Expand Down Expand Up @@ -786,6 +843,7 @@ def test_remove_primary_key(self, logger_mock):
"'table' was removed."
)
logger_mock.info.assert_has_calls([call(msg1), call(msg2)])
assert instance._multi_table_updated is True

def test__validate_column_relationships_foreign_keys(self):
"""Test ``_validate_column_relationships_foriegn_keys."""
Expand Down Expand Up @@ -1363,6 +1421,7 @@ def test_add_table(self, table_metadata_mock):

# Assert
assert instance.tables == {'users': table_metadata_mock.return_value}
assert instance._multi_table_updated is True

def test_add_table_empty_string(self):
"""Test that the method raises an error if the table name is an empty string."""
Expand Down Expand Up @@ -2686,6 +2745,7 @@ def test_save_to_json(self, tmp_path):
"""
# Setup
instance = MultiTableMetadata()
instance._reset_updated_flag = Mock()

# Run / Assert
file_name = tmp_path / 'multitable.json'
Expand All @@ -2695,6 +2755,8 @@ def test_save_to_json(self, tmp_path):
saved_metadata = json.load(multi_table_file)
assert saved_metadata == instance.to_dict()

instance._reset_updated_flag.assert_called_once()

def test__convert_relationships(self):
"""Test the ``_convert_relationships`` method.
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def test__print(self, mock_print):
# Assert
mock_print.assert_called_once_with('Fitting', end='')

def test___init__(self):
@patch('sdv.multi_table.base.BaseMultiTableSynthesizer._check_metadata_updated')
def test___init__(self, mock_check_metadata_updated):
"""Test that when creating a new instance this sets the defaults.
Test that the metadata object is being stored and also being validated. Afterwards, this
Expand All @@ -117,6 +118,7 @@ def test___init__(self):
assert isinstance(instance._table_synthesizers['upravna_enota'], GaussianCopulaSynthesizer)
assert instance._table_parameters == defaultdict(dict)
instance.metadata.validate.assert_called_once_with()
mock_check_metadata_updated.assert_called_once()

def test___init___synthesizer_kwargs_deprecated(self):
"""Test that the ``synthesizer_kwargs`` method is deprecated."""
Expand All @@ -132,6 +134,27 @@ def test___init___synthesizer_kwargs_deprecated(self):
with pytest.warns(FutureWarning, match=warn_message):
BaseMultiTableSynthesizer(metadata, synthesizer_kwargs={})

def test__check_metadata_updated(self):
"""Test the ``_check_metadata_updated`` method."""
# Setup
instance = Mock()
instance.metadata = Mock()
instance.metadata._check_updated_flag = Mock()
instance.metadata._reset_updated_flag = Mock()
instance.metadata._updated = True

# Run
expected_message = re.escape(
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)
with pytest.warns(UserWarning, match=expected_message):
BaseMultiTableSynthesizer._check_metadata_updated(instance)

# Assert
instance.metadata._check_updated_flag.assert_called_once()
instance.metadata._reset_updated_flag.assert_called_once()

def test_set_address_columns(self):
"""Test the ``set_address_columns`` method."""
# Setup
Expand Down Expand Up @@ -836,6 +859,7 @@ def test_fit(self):
# Assert
instance.preprocess.assert_called_once_with(data)
instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value)
instance._check_metadata_updated.assert_called_once()

def test_reset_sampling(self):
"""Test that ``reset_sampling`` resets the numpy seed and the synthesizers."""
Expand Down

0 comments on commit 5465932

Please sign in to comment.