Skip to content

Commit

Permalink
integration tests multi table
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 8, 2024
1 parent 5465932 commit 11f9b9b
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 4 deletions.
156 changes: 156 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import re
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1246,3 +1247,158 @@ def test__extract_parameters(self):
'scale': 0.
}
assert result == expected_result

def test_metadata_updated_no_warning(self, tmp_path):
"""Test scenario where no warning about metadata should be raised.
Run 1 - The medata is load from our demo datasets.
Run 2 - The metadata uses ``detect_from_dataframes`` but is saved to a file
before defining the syntheiszer.
Run 3 - The metadata is updated with a new column after the synthesizer
initialization, but is saved to a file before fitting.
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')

# Run 1
with warnings.catch_warnings(record=True) as captured_warnings:
warnings.simplefilter('always')
instance = HMASynthesizer(metadata)
instance.fit(data)

# Assert
assert len(captured_warnings) == 0

# Run 2
metadata_detect = MultiTableMetadata()
metadata_detect.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
for table_name, table_metadata in metadata.tables.items():
metadata_detect.tables[table_name].columns = table_metadata.columns
metadata_detect.tables[table_name].primary_key = table_metadata.primary_key

file_name = tmp_path / 'multitable_1.json'
metadata_detect.save_to_json(file_name)
with warnings.catch_warnings(record=True) as captured_warnings:
warnings.simplefilter('always')
instance = HMASynthesizer(metadata_detect)
instance.fit(data)

# Assert
assert len(captured_warnings) == 0

# Run 3
instance = HMASynthesizer(metadata_detect)
metadata_detect.update_column(
table_name='characters', column_name='age', sdtype='categorical')
file_name = tmp_path / 'multitable_2.json'
metadata_detect.save_to_json(file_name)
with warnings.catch_warnings(record=True) as captured_warnings:
warnings.simplefilter('always')
instance.fit(data)

# Assert
assert len(captured_warnings) == 0

def test_metadata_updated_warning_detect(self):
"""Test that using ``detect_from_dataframes`` without saving the metadata raise a warning.
The warning is expected to be raised only once during synthesizer initialization. It should
not be raised again when calling ``fit``.
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
metadata_detect = MultiTableMetadata()
metadata_detect.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
for table_name, table_metadata in metadata.tables.items():
metadata_detect.tables[table_name].columns = table_metadata.columns
metadata_detect.tables[table_name].primary_key = table_metadata.primary_key

expected_message = re.escape(
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)

# Run
with pytest.warns(UserWarning, match=expected_message) as record:
instance = HMASynthesizer(metadata_detect)
instance.fit(data)

# Assert
assert len(record) == 1


parametrization = [
('update_column', {
'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical'
}),
('set_primary_key', {'table_name': 'arrival', 'column_name': 'id_flight'}),
(
'add_column_relationship', {
'table_name': 'departure',
'relationship_type': 'address',
'column_names': ['city', 'country']
}
),
('add_alternate_keys', {'table_name': 'departure', 'column_names': ['city', 'country']}),
('set_sequence_key', {'table_name': 'departure', 'column_name': 'city'}),
('add_column', {
'table_name': 'departure', 'column_name': 'postal_code', 'sdtype': 'postal_code'
}),
]


@pytest.mark.parametrize(('method', 'kwargs'), parametrization)
def test_metadata_updated_warning(method, kwargs):
"""Test that modifying metadata without saving it raise a warning.
The warning should be raised during synthesizer initialization.
"""
metadata = MultiTableMetadata().load_from_dict({
'tables': {
'departure': {
'primary_key': 'id',
'columns': {
'id': {'sdtype': 'id'},
'date': {'sdtype': 'datetime'},
'city': {'sdtype': 'city'},
'country': {'sdtype': 'country'}
},
},
'arrival': {
'foreign_key': 'id',
'columns': {
'id': {'sdtype': 'id'},
'date': {'sdtype': 'datetime'},
'city': {'sdtype': 'city'},
'country': {'sdtype': 'country'},
'id_flight': {'sdtype': 'id'}
},
},
},
'relationships': [
{
'parent_table_name': 'departure',
'parent_primary_key': 'id',
'child_table_name': 'arrival',
'child_foreign_key': 'id'
}
]
})
expected_message = re.escape(
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)

# Run
metadata.__getattribute__(method)(**kwargs)
with pytest.warns(UserWarning, match=expected_message):
HMASynthesizer(metadata)

# Assert
assert metadata._updated is False
for table_name, table_metadata in metadata.tables.items():
assert table_metadata._updated is False
15 changes: 11 additions & 4 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):

@patch('sdv.single_table.base.BaseSingleTableSynthesizer._fit')
def test_metadata_updated_warning_detect(mock__fit):
"""Test that using ``detect_from_dataframes`` without saving the metadata raise a warning.
"""Test that using ``detect_from_dataframe`` without saving the metadata raise a warning.
The warning is expected to be raised only once during synthesizer initialization. It should
not be raised again when calling ``fit``.
Expand Down Expand Up @@ -561,10 +561,12 @@ def test_metadata_updated_warning_detect(mock__fit):
(
'add_column_relationship', {
'relationship_type': 'address',
'column_names': ['col 1', 'col 2']
'column_names': ['city', 'country']
}
),
('add_alternate_keys', {'column_names': ['col 1', 'col 2']}),
('set_sequence_key', {'column_name': 'col 1'}),
('add_column', {'column_name': 'col 6', 'sdtype': 'numerical'}),
]


Expand All @@ -579,14 +581,19 @@ def test_metadata_updated_warning(method, kwargs):
'col 1': {'sdtype': 'id'},
'col 2': {'sdtype': 'id'},
'col 3': {'sdtype': 'categorical'},
'col 4': {'sdtype': 'city'},
'col 5': {'sdtype': 'country'},
}
})
expected_message = re.escape(
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)

# Run and Assert
# Run
metadata.__getattribute__(method)(**kwargs)
with pytest.warns(UserWarning, match=expected_message):
metadata.__getattribute__(method)(**kwargs)
BaseSingleTableSynthesizer(metadata)

# Assert
assert metadata._updated is False

0 comments on commit 11f9b9b

Please sign in to comment.