From 66876e719bbe7fe14233911d9bee4e191c60c28c Mon Sep 17 00:00:00 2001 From: John La Date: Wed, 14 Aug 2024 10:36:44 -0500 Subject: [PATCH] Update tests to use Metadata (#2178) --- sdv/io/local/local.py | 4 +- sdv/lite/single_table.py | 15 ++- sdv/metadata/metadata.py | 25 +++- sdv/sampling/independent_sampler.py | 2 +- sdv/single_table/base.py | 5 +- sdv/single_table/copulagan.py | 3 +- sdv/single_table/copulas.py | 3 +- sdv/single_table/ctgan.py | 6 +- .../data_processing/test_data_processor.py | 27 ++-- .../evaluation/test_multi_table.py | 3 +- .../evaluation/test_single_table.py | 6 +- tests/integration/io/local/test_local.py | 8 +- tests/integration/lite/test_single_table.py | 14 +- .../metadata/test_visualization.py | 8 +- tests/integration/multi_table/test_hma.py | 59 +++++---- tests/integration/sequential/test_par.py | 54 ++++---- tests/integration/single_table/test_base.py | 123 ++++++++++-------- .../single_table/test_constraints.py | 110 +++++++++------- .../integration/single_table/test_copulas.py | 24 ++-- tests/integration/single_table/test_ctgan.py | 23 ++-- tests/integration/utils/test_poc.py | 4 +- tests/integration/utils/test_utils.py | 4 +- tests/unit/evaluation/test_multi_table.py | 18 +-- tests/unit/evaluation/test_single_table.py | 105 +++++++++------ tests/unit/io/local/test_local.py | 6 +- tests/unit/lite/test_single_table.py | 2 +- tests/unit/metadata/test_metadata.py | 14 +- tests/unit/multi_table/test_base.py | 19 +-- tests/unit/multi_table/test_hma.py | 116 ++++++++++++++++- tests/unit/multi_table/test_utils.py | 26 ++-- tests/unit/sequential/test_par.py | 45 ++++--- tests/unit/single_table/test_base.py | 84 ++++++------ tests/unit/single_table/test_copulagan.py | 46 +++---- tests/unit/single_table/test_copulas.py | 38 +++--- tests/unit/single_table/test_ctgan.py | 67 +++++----- tests/unit/utils/test_poc.py | 12 +- tests/utils.py | 4 +- 37 files changed, 670 insertions(+), 462 deletions(-) diff --git a/sdv/io/local/local.py 
b/sdv/io/local/local.py index ccba2f51d..ca827eddf 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -7,7 +7,7 @@ import pandas as pd -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata class BaseLocalHandler: @@ -29,7 +29,7 @@ def create_metadata(self, data): An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata properties from the data. """ - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) return metadata diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py index 7ad82bd1e..619ae234d 100644 --- a/sdv/lite/single_table.py +++ b/sdv/lite/single_table.py @@ -7,6 +7,7 @@ import cloudpickle +from sdv.metadata.metadata import Metadata from sdv.single_table import GaussianCopulaSynthesizer LOGGER = logging.getLogger(__name__) @@ -20,13 +21,19 @@ "functionality, please use the 'GaussianCopulaSynthesizer'." ) +META_DEPRECATION_MSG = ( + "The 'SingleTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." +) + class SingleTablePreset: """Class for all single table synthesizer presets. Args: - metadata (sdv.metadata.SingleTableMetadata): - ``SingleTableMetadata`` instance. + metadata (sdv.metadata.Metadata): + ``Metadata`` instance. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. name (str): The preset to use. 
locales (list or str): @@ -49,6 +56,10 @@ def __init__(self, metadata, name, locales=['en_US']): raise ValueError(f"'name' must be one of {PRESETS}.") self.name = name + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + else: + warnings.warn(META_DEPRECATION_MSG, FutureWarning) if name == FAST_ML_PRESET: self._setup_fast_preset(metadata, self.locales) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index dcb1e0b4c..3e855a6eb 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -1,5 +1,7 @@ """Metadata.""" +import warnings + from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata @@ -10,6 +12,7 @@ class Metadata(MultiTableMetadata): """Metadata class that handles all metadata.""" METADATA_SPEC_VERSION = 'V1' + DEFAULT_SINGLE_TABLE_NAME = 'default_table_name' @classmethod def load_from_json(cls, filepath, single_table_name=None): @@ -66,9 +69,29 @@ def _set_metadata_dict(self, metadata, single_table_name=None): super()._set_metadata_dict(metadata) else: if single_table_name is None: - single_table_name = 'default_table_name' + single_table_name = self.DEFAULT_SINGLE_TABLE_NAME self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) + def _get_single_table_name(self): + """Get the table name if there is only one table. + + Checks to see if the metadata contains only a single table, if so + return the name. Otherwise warn the user and return None. + + Args: + metadata (dict): + Python dictionary representing a ``MultiTableMetadata`` or + ``SingleTableMetadata`` object. + """ + if len(self.tables) != 1: + warnings.warn( + 'This metadata does not contain only a single table. Could not determine ' + 'single table name and will return None.' 
+ ) + return None + + return next(iter(self.tables), None) + def _convert_to_single_table(self): if len(self.tables) > 1: raise InvalidMetadataError( diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index a86b0ddc0..9fc4507c7 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -98,7 +98,7 @@ def _finalize(self, sampled_data): final_data = {} for table_name, table_rows in sampled_data.items(): synthesizer = self._table_synthesizers.get(table_name) - metadata = synthesizer.get_metadata() + metadata = synthesizer.get_metadata()._convert_to_single_table() dtypes = synthesizer._data_processor._dtypes dtypes_to_sdtype = synthesizer._data_processor._DTYPE_TO_SDTYPE diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index d618a4c48..e18a3931d 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -112,7 +112,9 @@ def __init__( ): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata + self._table_name = Metadata.DEFAULT_SINGLE_TABLE_NAME if isinstance(metadata, Metadata): + self._table_name = metadata._get_single_table_name() self.metadata = metadata._convert_to_single_table() elif isinstance(metadata, SingleTableMetadata): warnings.warn(DEPRECATION_MSG, FutureWarning) @@ -284,7 +286,8 @@ def get_parameters(self): def get_metadata(self): """Return the ``Metadata`` for this synthesizer.""" - return Metadata.load_from_dict(self.metadata.to_dict()) + table_name = getattr(self, '_table_name', None) + return Metadata.load_from_dict(self.metadata.to_dict(), table_name) def load_custom_constraint_classes(self, filepath, class_names): """Load a custom constraint class for the current synthesizer. 
diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 19ca50b2e..1713cef45 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -59,8 +59,9 @@ class CopulaGANSynthesizer(CTGANSynthesizer): Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 26167a946..c47b67424 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -29,8 +29,9 @@ class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer): """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index 30174f43a..1d03fe956 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -100,8 +100,9 @@ class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): """Model wrapping ``CTGAN`` model. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. 
enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen @@ -316,8 +317,9 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): """Model wrapping ``TVAE`` model. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index a3217f712..dba25709b 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -23,6 +23,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata class TestDataProcessor: @@ -155,12 +156,12 @@ def test_with_primary_key_numerical(self): """ # Load metadata and data data, _ = download_demo('single_table', 'adult') - adult_metadata = SingleTableMetadata() - adult_metadata.detect_from_dataframe(data=data) + adult_metadata = Metadata() + adult_metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('id', sdtype='id') - adult_metadata.set_primary_key('id') + adult_metadata.add_column('adult', 'id', sdtype='id') + adult_metadata.set_primary_key('adult', 'id') # Add id size = len(data) @@ -169,7 +170,7 @@ def test_with_primary_key_numerical(self): data['id'] = ids # Instance ``DataProcessor`` - dp = DataProcessor(adult_metadata) + dp = 
DataProcessor(adult_metadata._convert_to_single_table()) # Fit dp.fit(data) @@ -195,17 +196,17 @@ def test_with_alternate_keys(self): # Load metadata and data data, _ = download_demo('single_table', 'adult') data['fnlwgt'] = data['fnlwgt'].astype(str) - adult_metadata = SingleTableMetadata() - adult_metadata.detect_from_dataframe(data=data) + adult_metadata = Metadata() + adult_metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('id', sdtype='id') - adult_metadata.set_primary_key('id') + adult_metadata.add_column('adult', 'id', sdtype='id') + adult_metadata.set_primary_key('adult', 'id') - adult_metadata.add_column('secondary_id', sdtype='id') - adult_metadata.update_column('fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') + adult_metadata.add_column('adult', 'secondary_id', sdtype='id') + adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') - adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt']) + adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt']) # Add id size = len(data) @@ -215,7 +216,7 @@ def test_with_alternate_keys(self): data['secondary_id'] = ids # Instance ``DataProcessor`` - dp = DataProcessor(adult_metadata) + dp = DataProcessor(adult_metadata._convert_to_single_table()) # Fit dp.fit(data) diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index 5f816fb87..7ca797435 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -2,7 +2,6 @@ from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic from sdv.metadata.metadata import Metadata -from sdv.metadata.multi_table import MultiTableMetadata def test_evaluation(): @@ -18,7 +17,7 @@ def test_evaluation(): 'table1': table, 'table2': slightly_different_table, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 
'tables': { 'table1': { 'columns': { diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index 95e1b671e..8b3cf23fa 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -3,7 +3,6 @@ from sdv.datasets.demo import download_demo from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -11,8 +10,9 @@ def test_evaluation(): """Test ``evaluate_quality`` and ``run_diagnostic``.""" # Setup data = pd.DataFrame({'col': [1, 2, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') # Run and Assert diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py index 2fa7e00c3..c52372f30 100644 --- a/tests/integration/io/local/test_local.py +++ b/tests/integration/io/local/test_local.py @@ -1,7 +1,7 @@ import pandas as pd from sdv.io.local import CSVHandler, ExcelHandler -from sdv.metadata import MultiTableMetadata +from sdv.metadata import Metadata class TestCSVHandler: @@ -27,7 +27,7 @@ def test_integration_write_and_read(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) @@ -57,7 +57,7 @@ def test_integration_write_and_read(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert 
isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) @@ -91,7 +91,7 @@ def test_integration_write_and_read_append_mode(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data expected_table_one = pd.concat( diff --git a/tests/integration/lite/test_single_table.py b/tests/integration/lite/test_single_table.py index 90c66cb5d..ad6a42aea 100644 --- a/tests/integration/lite/test_single_table.py +++ b/tests/integration/lite/test_single_table.py @@ -3,7 +3,7 @@ import pytest from sdv.lite import SingleTablePreset -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata def test_sample(): @@ -12,8 +12,8 @@ def test_sample(): data = pd.DataFrame({'a': [1, 2, 3, np.nan]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'adult': data}) preset = SingleTablePreset(metadata, name='FAST_ML') preset.fit(data) samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) @@ -29,8 +29,8 @@ def test_sample_with_constraints(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) preset = SingleTablePreset(metadata, name='FAST_ML') constraints = [ { @@ -57,8 +57,8 @@ def test_warnings_are_shown(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) with 
pytest.warns(FutureWarning, match=warn_message): preset = SingleTablePreset(metadata, name='FAST_ML') diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index 07cb870b6..eb8c6d657 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -1,6 +1,6 @@ import pandas as pd -from sdv.metadata import MultiTableMetadata, SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -9,8 +9,8 @@ def test_visualize_graph_for_single_table(): """Test it runs when a column name contains symbols.""" # Setup data = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) model = GaussianCopulaSynthesizer(metadata) # Run @@ -26,7 +26,7 @@ def test_visualize_graph_for_multi_table(): data1 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) tables = {'1': data1, '2': data2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(tables) metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index d161c6fa2..4393481e3 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -20,7 +20,6 @@ from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot from sdv.metadata.metadata import Metadata -from sdv.metadata.multi_table import 
MultiTableMetadata from sdv.multi_table import HMASynthesizer from tests.integration.single_table.custom_constraints import MyConstraint from tests.utils import catch_sdv_logs @@ -125,7 +124,7 @@ def test_get_info(self): # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} today = datetime.datetime.today().strftime('%Y-%m-%d') - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') synthesizer = HMASynthesizer(metadata) @@ -220,7 +219,7 @@ def get_custom_constraint_data_and_metadata(self): 'numerical_col_2': [2, 4, 6], }) - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('parent', parent_data) metadata.update_column('parent', 'primary_key', sdtype='id') metadata.detect_table_from_dataframe('child', child_data) @@ -360,7 +359,7 @@ def test_hma_with_inequality_constraint(self): data = {'parent_table': parent_table, 'child_table': child_table} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table) metadata.update_column('parent_table', 'id', sdtype='id') metadata.detect_table_from_dataframe(table_name='child_table', data=child_table) @@ -449,7 +448,7 @@ def test_hma_primary_key_and_foreign_key_only(self): data = {'users': users, 'sessions': sessions, 'games': games} - metadata = MultiTableMetadata() + metadata = Metadata() for table_name, table in data.items(): metadata.detect_table_from_dataframe(table_name, table) @@ -585,7 +584,7 @@ def test_use_own_data_using_hma(self, tmp_path): assert datasets.keys() == {'guests', 'hotels'} # Metadata - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe(table_name='guests', data=datasets['guests']) metadata.detect_table_from_dataframe(table_name='hotels', data=datasets['hotels']) @@ -676,7 +675,7 @@ def test_use_own_data_using_hma(self, tmp_path): # Save and load metadata 
metadata_path = tmp_path / 'metadata.json' metadata.save_to_json(metadata_path) - loaded_metadata = MultiTableMetadata.load_from_json(metadata_path) + loaded_metadata = Metadata.load_from_json(metadata_path) # Assert loaded metadata matches saved assert metadata.to_dict() == loaded_metadata.to_dict() @@ -768,7 +767,7 @@ def test_hma_three_linear_nodes(self): } ) data = {'grandparent': grandparent, 'parent': parent, 'child': child} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'grandparent': { 'primary_key': 'grandparent_ID', @@ -847,7 +846,7 @@ def test_hma_one_parent_two_children(self): } ) data = {'parent': parent, 'child1': child1, 'child2': child2} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'parent_ID', @@ -920,7 +919,7 @@ def test_hma_two_parents_one_child(self): data={'parent_ID2': [0, 1, 2, 3, 4], 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No']} ) data = {'parent1': parent1, 'child': child, 'parent2': parent2} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent1': { 'primary_key': 'parent_ID1', @@ -1018,7 +1017,7 @@ def test_hma_two_lineages_one_grandchild(self): 'child2': child2, 'grandchild': grandchild, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'root1': { 'primary_key': 'id', @@ -1173,7 +1172,7 @@ def test__extract_parameters(self): '__sessions__user_id__loc': 0.5, '__sessions__user_id__scale': -0.25, }) - instance = HMASynthesizer(MultiTableMetadata()) + instance = HMASynthesizer(Metadata()) instance.extended_columns = { 'sessions': { '__sessions__user_id__num_rows': FloatFormatter(enforce_min_max_values=True), @@ -1206,7 +1205,7 @@ def test__recreate_child_synthesizer_with_default_parameters(self): f'{prefix}univariates__brand__loc': 0.5, f'{prefix}univariates__brand__scale': -0.25, }) - metadata = 
MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'users': {'columns': {'user_id': {'sdtype': 'id'}}, 'primary_key': 'user_id'}, 'sessions': { @@ -1351,7 +1350,7 @@ def test_metadata_updated_warning_detect(self): def test_null_foreign_keys(self): """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('parent_table1') metadata.add_column('parent_table1', 'id', sdtype='id') metadata.set_primary_key('parent_table1', 'id') @@ -1456,7 +1455,7 @@ def test_sampling_with_unknown_sdtype_numerical_column(self): tables_dict = {'people': table1, 'company': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(tables_dict) # Run @@ -1508,7 +1507,7 @@ def test_hma_0_1_child(num_rows): ) data = {'parent': parent_table, 'child': pd.DataFrame(data=child_table_data)} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'id', @@ -1596,7 +1595,7 @@ def test_hma_0_1_grandparent(): }, ], } - metadata = MultiTableMetadata().load_from_dict(metadata_dict) + metadata = Metadata().load_from_dict(metadata_dict) metadata.validate() metadata.validate_data(data) synthesizer = HMASynthesizer(metadata=metadata, verbose=False) @@ -1635,7 +1634,7 @@ def test_metadata_updated_warning(method, kwargs): The warning should be raised during synthesizer initialization. 
""" - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1685,7 +1684,7 @@ def test_metadata_updated_warning(method, kwargs): def test_save_and_load_with_downgraded_version(tmp_path): """Test that synthesizers are raising errors if loaded on a downgraded version.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1736,7 +1735,7 @@ def test_save_and_load_with_downgraded_version(tmp_path): def test_fit_raises_version_error(): """Test that a ``VersionError`` is being raised if the current version is newer.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1828,7 +1827,7 @@ def test_fit_and_sample_numerical_col_names(): data['0'][1] = primary_key data['1'][1] = primary_key data['1'][2] = primary_key_2 - metadata = MultiTableMetadata() + metadata = Metadata() metadata_dict = {'tables': {}} for table_idx in range(num_tables): metadata_dict['tables'][str(table_idx)] = {'columns': {}} @@ -1844,7 +1843,7 @@ def test_fit_and_sample_numerical_col_names(): 'child_foreign_key': 2, } ] - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) metadata.set_primary_key('0', '1') # Run @@ -1875,7 +1874,7 @@ def test_detect_from_dataframe_numerical_col(): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('parent_data', parent_data) metadata.detect_table_from_dataframe('child_data', child_data) metadata.update_column('parent_data', '1', sdtype='id') @@ -1890,7 +1889,7 @@ def test_detect_from_dataframe_numerical_col(): child_table_name='child_data', ) - test_metadata = MultiTableMetadata() + test_metadata = Metadata() 
test_metadata.detect_from_dataframes(data) test_metadata.update_column('parent_data', '1', sdtype='id') test_metadata.update_column('child_data', '3', sdtype='id') @@ -1914,7 +1913,7 @@ def test_detect_from_dataframe_numerical_col(): assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist() assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist() - test_metadata = MultiTableMetadata() + test_metadata = Metadata() test_metadata.detect_from_dataframes(data) @@ -1930,7 +1929,7 @@ def test_table_name_logging(caplog): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) instance = HMASynthesizer(metadata) @@ -1952,7 +1951,7 @@ def test_disjointed_tables(): remove_some_dict = metadata.to_dict() half_list = remove_some_dict['relationships'][1::2] remove_some_dict['relationships'] = half_list - disjoined_metadata = MultiTableMetadata.load_from_dict(remove_some_dict) + disjoined_metadata = Metadata.load_from_dict(remove_some_dict) # Run disjoin_synthesizer = HMASynthesizer(disjoined_metadata) @@ -2009,7 +2008,7 @@ def test_hma_synthesizer_with_fixed_combinations(): } # Creating metadata for the dataset - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) metadata.update_column('users', 'user_id', sdtype='id') @@ -2059,7 +2058,7 @@ def test_fit_int_primary_key_regex_includes_zero(regex): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) metadata.set_primary_key('parent_data', 'parent_id') @@ -2106,7 +2105,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(): 'grandparent': grandparent, 'parent': parent, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = 
Metadata.load_from_dict({ 'tables': { 'root1': { 'primary_key': 'R1', diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 567586525..c9628a429 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -8,7 +8,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sequential import PARSynthesizer @@ -21,11 +21,11 @@ def _get_par_data_and_metadata(): 'entity': [1, 1, 2, 2], 'context': ['a', 'a', 'b', 'b'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('entity', sdtype='id') - metadata.set_sequence_key('entity') - metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'entity', sdtype='id') + metadata.set_sequence_key('table', 'entity') + metadata.set_sequence_index('table', 'date') return data, metadata @@ -34,11 +34,11 @@ def test_par(): # Setup data = load_demo() data['date'] = pd.to_datetime(data['date']) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('store_id', sdtype='id') - metadata.set_sequence_key('store_id') - metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'store_id', sdtype='id') + metadata.set_sequence_key('table', 'store_id') + metadata.set_sequence_index('table', 'date') model = PARSynthesizer( metadata=metadata, context_columns=['region'], @@ -68,11 +68,11 @@ def test_column_after_date_simple(): 'date': [date, date], 'col2': ['hello', 'world'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('col', sdtype='id') - metadata.set_sequence_key('col') - 
metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'col', sdtype='id') + metadata.set_sequence_key('table', 'col') + metadata.set_sequence_index('table', 'date') # Run model = PARSynthesizer(metadata=metadata, epochs=1) @@ -114,7 +114,7 @@ def test_save_and_load(tmp_path): # Assert assert isinstance(loaded_instance, PARSynthesizer) - assert metadata == instance.metadata + assert metadata._convert_to_single_table().to_dict() == instance.metadata.to_dict() def test_synthesize_sequences(tmp_path): @@ -229,7 +229,7 @@ def test_par_subset_of_data_simplified(): 'date': ['2020-01-01', '2020-01-02', '2020-01-03'], }) data.index = [0, 1, 5] - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'sequence_index': 'date', 'sequence_key': 'id', 'columns': { @@ -261,7 +261,7 @@ def test_par_missing_sequence_index(): 'sequence_key': 'e_id', } - metadata = SingleTableMetadata().load_from_dict(metadata_dict) + metadata = Metadata().load_from_dict(metadata_dict) data = pd.DataFrame({'value': [10, 20, 30], 'e_id': [1, 2, 3]}) @@ -348,11 +348,11 @@ def test_par_unique_sequence_index_with_enforce_min_max(): test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply( pd.to_datetime, format='%Y-%m-%d', errors='coerce' ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(test_df) - metadata.update_column(column_name='s_key', sdtype='id') - metadata.set_sequence_key('s_key') - metadata.set_sequence_index('visits') + metadata = Metadata() + metadata.detect_from_dataframes({'table': test_df}) + metadata.update_column(table_name='table', column_name='s_key', sdtype='id') + metadata.set_sequence_key('table', 's_key') + metadata.set_sequence_index('table', 'visits') synthesizer = PARSynthesizer( metadata, enforce_min_max_values=True, enforce_rounding=False, epochs=100, verbose=True ) @@ -378,7 +378,7 @@ def 
test_par_sequence_index_is_numerical(): 'sequence_key': 'engine_no', 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) data = pd.DataFrame({'engine_no': [0, 0, 1, 1], 'time_in_cycles': [1, 2, 3, 4]}) s1 = PARSynthesizer(metadata) @@ -396,7 +396,7 @@ def test_init_error_sequence_key_in_context(): }, 'sequence_key': 'A', } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) sequence_key_context_column_error_msg = re.escape( "The sequence key ['A'] cannot be a context column. " 'To proceed, please remove the sequence key from the context_columns parameter.' @@ -418,7 +418,7 @@ def test_par_with_datetime_context(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'user_id': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{2}'}, 'birthdate': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 4832a2746..5f99abf93 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -23,7 +23,7 @@ ) from sdv.single_table.base import BaseSingleTableSynthesizer -METADATA = SingleTableMetadata.load_from_dict({ +METADATA = Metadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': {'sdtype': 'numerical'}, @@ -91,10 +91,11 @@ def test_sample_from_conditions_with_batch_size(): 'column3': list(range(100)), }) - metadata = SingleTableMetadata() - metadata.add_column('column1', sdtype='numerical') - metadata.add_column('column2', sdtype='numerical') - metadata.add_column('column3', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'column1', sdtype='numerical') + metadata.add_column('table', 'column2', sdtype='numerical') 
+ metadata.add_column('table', 'column3', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -117,10 +118,11 @@ def test_sample_from_conditions_negative_float(): 'column3': list(range(100)), }) - metadata = SingleTableMetadata() - metadata.add_column('column1', sdtype='numerical') - metadata.add_column('column2', sdtype='numerical') - metadata.add_column('column3', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'column1', sdtype='numerical') + metadata.add_column('table', 'column2', sdtype='numerical') + metadata.add_column('table', 'column3', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -230,10 +232,11 @@ def test_multiple_fits(): 'state': ['CA', 'CA', 'IL', 'CA', 'CA'], 'measurement': [27.1, 28.7, 26.9, 21.2, 30.9], }) - metadata = SingleTableMetadata() - metadata.add_column('city', sdtype='categorical') - metadata.add_column('state', sdtype='categorical') - metadata.add_column('measurement', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'city', sdtype='categorical') + metadata.add_column('table', 'state', sdtype='categorical') + metadata.add_column('table', 'measurement', sdtype='numerical') constraint = { 'constraint_class': 'FixedCombinations', 'constraint_parameters': {'column_names': ['city', 'state']}, @@ -273,7 +276,7 @@ def test_sampling(synthesizer): @pytest.mark.parametrize('synthesizer', SYNTHESIZERS) def test_sampling_reset_sampling(synthesizer): """Test ``sample`` method for each synthesizer using ``reset_sampling``.""" - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': {'sdtype': 'numerical'}, @@ -310,11 +313,13 @@ def test_config_creation_doesnt_raise_error(): 'address_col': ['223 Williams Rd', '75 Waltham St', '77 Mass Ave'], 'numerical_col': [1, 2, 3], }) - 
test_metadata = SingleTableMetadata() + test_metadata = Metadata() # Run - test_metadata.detect_from_dataframe(test_data) - test_metadata.update_column(column_name='address_col', sdtype='address', pii=False) + test_metadata.detect_from_dataframes({'table': test_data}) + test_metadata.update_column( + table_name='table', column_name='address_col', sdtype='address', pii=False + ) synthesizer = GaussianCopulaSynthesizer(test_metadata) synthesizer.fit(test_data) @@ -330,11 +335,13 @@ def test_transformers_correctly_auto_assigned(): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}') - metadata.set_primary_key('primary_key') - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column( + table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}' + ) + metadata.set_primary_key('table', 'primary_key') + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -391,7 +398,7 @@ def test_modeling_with_complex_datetimes(): } # Run - metadata = SingleTableMetadata.load_from_dict(test_metadata) + metadata = Metadata.load_from_dict(test_metadata) metadata.validate() synth = GaussianCopulaSynthesizer(metadata) synth.validate(data) @@ -418,13 +425,13 @@ def test_auto_assign_transformers_and_update_with_pii(): } ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) # Run - metadata.update_column(column_name='id', sdtype='first_name') - metadata.update_column(column_name='name', sdtype='name') - metadata.set_primary_key('id') + 
metadata.update_column(table_name='table', column_name='id', sdtype='first_name') + metadata.update_column(table_name='table', column_name='name', sdtype='name') + metadata.set_primary_key('table', 'id') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.auto_assign_transformers(data) @@ -451,11 +458,11 @@ def test_refitting_a_model(): } ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='name', sdtype='name') - metadata.update_column('id', sdtype='id') - metadata.set_primary_key('id') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='name', sdtype='name') + metadata.update_column('table', 'id', sdtype='id') + metadata.set_primary_key('table', 'id') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) @@ -478,8 +485,9 @@ def test_get_info(): # Setup data = pd.DataFrame({'col': [1, 2, 3]}) today = datetime.datetime.today().strftime('%Y-%m-%d') - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata) # Run @@ -512,7 +520,7 @@ def test_get_info(): def test_save_and_load(tmp_path): """Test that synthesizers can be saved and loaded properly.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) synthesizer_path = tmp_path / 'synthesizer.pkl' instance.save(synthesizer_path) @@ -534,7 +542,7 @@ def test_save_and_load(tmp_path): def test_save_and_load_no_id(tmp_path): """Test that synthesizers can be saved and loaded properly.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) synthesizer_path = tmp_path / 'synthesizer.pkl' delattr(instance, '_synthesizer_id') @@ -560,7 +568,7 @@ def 
test_save_and_load_no_id(tmp_path): def test_save_and_load_with_downgraded_version(tmp_path): """Test that synthesizers are raising errors if loaded on a downgraded version.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._fitted = True instance._fitted_sdv_version = '10.0.0' @@ -693,13 +701,17 @@ def test_metadata_updated_warning(method, kwargs): The warning should be raised during synthesizer initialization. """ # Setup - metadata = SingleTableMetadata().load_from_dict({ - 'columns': { - 'col 1': {'sdtype': 'id'}, - 'col 2': {'sdtype': 'id'}, - 'col 3': {'sdtype': 'categorical'}, - 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country_code'}, + metadata = Metadata().load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'col 1': {'sdtype': 'id'}, + 'col 2': {'sdtype': 'id'}, + 'col 3': {'sdtype': 'categorical'}, + 'city': {'sdtype': 'city'}, + 'country': {'sdtype': 'country_code'}, + } + } } }) expected_message = re.escape( @@ -708,12 +720,13 @@ def test_metadata_updated_warning(method, kwargs): ) # Run - metadata.__getattribute__(method)(**kwargs) + single_metadata = metadata._convert_to_single_table() + single_metadata.__getattribute__(method)(**kwargs) with pytest.warns(UserWarning, match=expected_message): - BaseSingleTableSynthesizer(metadata) + BaseSingleTableSynthesizer(single_metadata) # Assert - assert metadata._updated is False + assert single_metadata._updated is False def test_fit_raises_version_error(): @@ -724,8 +737,8 @@ def test_fit_raises_version_error(): 'col 2': [4, 5, 6], 'col 3': ['a', 'b', 'c'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) instance = BaseSingleTableSynthesizer(metadata) instance._fitted_sdv_version = '1.0.0' @@ -755,11 +768,11 @@ def test_fit_and_sample_numerical_col_names(synthesizer_class): num_cols = 10 values = {i: 
np.random.randint(0, 100, size=num_rows) for i in range(num_cols)} data = pd.DataFrame(values) - metadata = SingleTableMetadata() + metadata = Metadata() metadata_dict = {'columns': {}} for i in range(num_cols): metadata_dict['columns'][i] = {'sdtype': 'numerical'} - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) # Run synth = synthesizer_class(metadata) @@ -779,7 +792,7 @@ def test_fit_and_sample_numerical_col_names(synthesizer_class): def test_sample_not_fitted(synthesizer): """Test that a synthesizer raises an error when trying to sample without fitting.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() synthesizer = synthesizer.__class__(metadata) expected_message = re.escape( 'This synthesizer has not been fitted. Please fit your synthesizer first before' @@ -800,10 +813,10 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class): 2: [4, 5, 6], 3: ['a', 'b', 'c'], }) - metadata = SingleTableMetadata() + metadata = Metadata() # Run - metadata.detect_from_dataframe(data) + metadata.detect_from_dataframes({'table': data}) instance = synthesizer_class(metadata) instance.fit(data) sample = instance.sample(5) diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 5a70ffea8..e42cdb227 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -11,7 +11,7 @@ from sdv.constraints import Constraint, create_custom_constraint_class from sdv.constraints.errors import AggregateConstraintsError from sdv.datasets.demo import download_demo -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import GaussianCopulaSynthesizer from tests.integration.single_table.custom_constraints import MyConstraint @@ -72,10 +72,11 @@ def 
test_fit_with_unique_constraint_on_data_with_only_index_column(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('index', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'index', sdtype='categorical') + metadata.set_primary_key('table', 'key') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -136,11 +137,12 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('index', sdtype='categorical') - metadata.add_column('test_column', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'index', sdtype='categorical') + metadata.add_column('table', 'test_column', sdtype='categorical') + metadata.set_primary_key('table', 'key') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -194,10 +196,11 @@ def test_fit_with_unique_constraint_on_data_subset(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('test_column', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'test_column', sdtype='categorical') + metadata.set_primary_key('table', 'key') test_df = test_df.iloc[[1, 3, 4]] constraint = { @@ -227,7 +230,7 @@ def test_conditional_sampling_with_constraints(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'numerical'}, 'B': {'sdtype': 'numerical'}, @@ -290,10 +293,11 @@ def 
test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc 'age': [27, 28, 26, 21, 30], }) - metadata = SingleTableMetadata() - metadata.add_column('city', sdtype='categorical') - metadata.add_column('state', sdtype='categorical') - metadata.add_column('age', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'city', sdtype='categorical') + metadata.add_column('table', 'state', sdtype='categorical') + metadata.add_column('table', 'age', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) @@ -335,9 +339,9 @@ def test_custom_constraints_from_file(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -379,9 +383,9 @@ def test_custom_constraints_from_object(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -444,7 +448,7 @@ def test_inequality_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -505,7 +509,7 @@ def 
test_scalar_inequality_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -556,7 +560,7 @@ def test_scalar_range_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -623,7 +627,7 @@ def test_range_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -697,7 +701,7 @@ def test_inequality_constraint_all_possible_nans_configurations(): # Setup data = pd.DataFrame(data={'A': [0, 1, np.nan, np.nan, 2], 'B': [2, np.nan, 3, np.nan, 3]}) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'numerical'}, 'B': {'sdtype': 'numerical'}, @@ -742,7 +746,7 @@ def test_range_constraint_all_possible_nans_configurations(): } } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) synthesizer = GaussianCopulaSynthesizer(metadata) my_constraint = { @@ -812,10 +816,10 @@ def reverse_transform(column_names, data): 'number': ['1', '2', '3'], 'other': [7, 8, 9], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('key', sdtype='id', regex_format=r'\w_\d') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d') + metadata.set_primary_key('table', 'key') 
synth = GaussianCopulaSynthesizer(metadata) synth.add_custom_constraint_class(custom_constraint, 'custom') @@ -842,9 +846,10 @@ def test_timezone_aware_constraints(): data['col1'] = pd.to_datetime(data['col1']).dt.tz_localize('UTC') data['col2'] = pd.to_datetime(data['col2']).dt.tz_localize('UTC') - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='datetime') - metadata.add_column('col2', sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='datetime') + metadata.add_column('table', 'col2', sdtype='datetime') my_constraint = { 'constraint_class': 'Inequality', @@ -960,18 +965,25 @@ def _transform(self, table_data): def test_constraint_datetime_check(): """Test datetime columns are correctly identified in constraints. GH#1692""" # Setup - data = pd.DataFrame( - data={ - 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], - 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'], - } - ) - metadata = SingleTableMetadata.load_from_dict({ - 'columns': { - 'low_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, - 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + data = { + 'table': pd.DataFrame( + data={ + 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], + 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'], + } + ) + } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'low_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + } + } } }) + my_constraint = { 'constraint_class': 'Inequality', 'constraint_parameters': { @@ -987,7 +999,7 @@ def test_constraint_datetime_check(): synth = GaussianCopulaSynthesizer(metadata) synth.add_constraints([my_constraint]) - synth.fit(data) + synth.fit(data['table']) samples = synth.sample(3) # Assert diff --git a/tests/integration/single_table/test_copulas.py 
b/tests/integration/single_table/test_copulas.py index 0a8477409..5e46a8417 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -15,7 +15,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import ConstraintsNotMetError from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import GaussianCopulaSynthesizer @@ -277,10 +277,10 @@ def test_update_transformers_with_id_generator(): sample_num = 20 data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) - stm.update_column('user_id', sdtype='id') - stm.set_primary_key('user_id') + stm = Metadata() + stm.detect_from_dataframes({'table': data}) + stm.update_column('table', 'user_id', sdtype='id') + stm.set_primary_key('table', 'user_id') gc = GaussianCopulaSynthesizer(stm) custom_id = IDGenerator(starting_value=min_value_id) @@ -332,7 +332,7 @@ def test_numerical_columns_gets_pii(): data = pd.DataFrame( data={'id': [0, 1, 2, 3, 4], 'city': [0, 0, 0, 0, 0], 'numerical': [21, 22, 23, 24, 25]} ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'primary_key': 'id', 'columns': { 'id': {'sdtype': 'id'}, @@ -406,8 +406,8 @@ def test_categorical_column_with_numbers(): 'numerical_col': np.random.rand(20), }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) synthesizer = GaussianCopulaSynthesizer(metadata) @@ -435,9 +435,9 @@ def test_unknown_sdtype(): 'numerical_col': np.random.rand(3), }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('unknown', sdtype='unknown') + metadata = Metadata() + 
metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'unknown', sdtype='unknown') synthesizer = GaussianCopulaSynthesizer(metadata) @@ -482,7 +482,7 @@ def test_support_nullable_pandas_dtypes(): 'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'), 'Float64': pd.Series([1.113, 2.22, 3.3, pd.NA], dtype='Float64'), }) - metadata = SingleTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'columns': { 'Int8': {'sdtype': 'numerical', 'computer_representation': 'Int8'}, 'Int16': {'sdtype': 'numerical', 'computer_representation': 'Int16'}, diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index 26bc906b2..9f892f878 100644 --- a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -8,20 +8,21 @@ from sdv.datasets.demo import download_demo from sdv.errors import InvalidDataTypeError from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.single_table import CTGANSynthesizer, TVAESynthesizer def test__estimate_num_columns(): """Test the number of columns is estimated correctly.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('numerical', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') - metadata.add_column('categorical2', sdtype='categorical') - metadata.add_column('categorical3', sdtype='categorical') - metadata.add_column('datetime', sdtype='datetime') - metadata.add_column('boolean', sdtype='boolean') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'numerical', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') + metadata.add_column('table', 'categorical2', sdtype='categorical') + metadata.add_column('table', 'categorical3', sdtype='categorical') + 
metadata.add_column('table', 'datetime', sdtype='datetime') + metadata.add_column('table', 'boolean', sdtype='boolean') data = pd.DataFrame({ 'numerical': [0.1, 0.2, 0.3], 'datetime': ['2020-01-01', '2020-01-02', '2020-01-03'], @@ -134,7 +135,7 @@ def test_categoricals_are_not_preprocessed(): 'alcohol': ['medium', 'medium', 'low', 'high', 'low'], } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'age': {'sdtype': 'numerical'}, 'therapy': {'sdtype': 'boolean'}, @@ -180,7 +181,7 @@ def test_categorical_metadata_with_int_data(): }, } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) data = pd.DataFrame({ 'A': list(range(50)), 'B': list(range(50)), @@ -270,7 +271,7 @@ def test_ctgan_with_dropped_columns(): 'columns': {'user_id': {'sdtype': 'id'}, 'user_ssn': {'sdtype': 'ssn'}}, } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) # Run synth = CTGANSynthesizer(metadata) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index b3dfcffdc..b7b3247e8 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -7,7 +7,7 @@ import pytest from sdv.datasets.demo import download_demo -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer from sdv.multi_table.utils import _get_total_estimated_columns from sdv.utils.poc import get_random_subset, simplify_schema @@ -15,7 +15,7 @@ @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict({ + return Metadata.load_from_dict({ 'tables': { 'parent': { 'columns': { diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 8c3498311..ad381784a 100644 --- a/tests/integration/utils/test_utils.py +++ 
b/tests/integration/utils/test_utils.py @@ -7,13 +7,13 @@ from sdv.datasets.demo import download_demo from sdv.errors import InvalidDataError -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.utils import drop_unknown_references, get_random_sequence_subset @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict({ + return Metadata.load_from_dict({ 'tables': { 'parent': { 'columns': { diff --git a/tests/unit/evaluation/test_multi_table.py b/tests/unit/evaluation/test_multi_table.py index 2995f9dd5..2ecdf52d1 100644 --- a/tests/unit/evaluation/test_multi_table.py +++ b/tests/unit/evaluation/test_multi_table.py @@ -11,7 +11,7 @@ get_column_plot, run_diagnostic, ) -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata def test_evaluate_quality(): @@ -20,7 +20,7 @@ def test_evaluate_quality(): table = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table} data2 = {'table': pd.DataFrame({'col': [2, 1, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table) QualityReport.generate = Mock() @@ -37,7 +37,7 @@ def test_run_diagnostic(): table = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table} data2 = {'table': pd.DataFrame({'col': [2, 1, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table) DiagnosticReport.generate = Mock() @@ -60,7 +60,7 @@ def test_get_column_plot(mock_plot): table2 = pd.DataFrame({'col': [2, 1, 3]}) data1 = {'table': table1} data2 = {'table': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -79,7 +79,7 @@ def test_get_column_plot_only_real_or_synthetic(mock_plot): # Setup table1 = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table1} - metadata = MultiTableMetadata() + metadata = Metadata() 
metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -103,7 +103,7 @@ def test_get_column_pair_plot(mock_plot): table2 = pd.DataFrame({'col1': [2, 1, 3], 'col2': [1, 2, 3]}) data1 = {'table': table1} data2 = {'table': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -122,7 +122,7 @@ def test_get_column_pair_plot_only_real_or_synthetic(mock_plot): # Setup table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]}) data1 = {'table': table1} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -170,7 +170,7 @@ def test_get_cardinality_plot(mock_plot): ], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) mock_plot.return_value = 'plot' # Run @@ -213,7 +213,7 @@ def test_get_cardinality_plot_plot_type(mock_plot): ], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) mock_plot.return_value = 'plot' # Run diff --git a/tests/unit/evaluation/test_single_table.py b/tests/unit/evaluation/test_single_table.py index 3a53e11e2..f669a6e40 100644 --- a/tests/unit/evaluation/test_single_table.py +++ b/tests/unit/evaluation/test_single_table.py @@ -14,7 +14,6 @@ run_diagnostic, ) from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata def test_evaluate_quality(): @@ -22,15 +21,18 @@ def test_evaluate_quality(): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') 
QualityReport.generate = Mock() # Run evaluate_quality(data1, data2, metadata) # Assert - QualityReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) + QualityReport.generate.assert_called_once_with( + data1, data2, metadata._convert_to_single_table().to_dict(), True + ) def test_evaluate_quality_metadata(): @@ -55,15 +57,18 @@ def test_run_diagnostic(): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') DiagnosticReport.generate = Mock(return_value=123) # Run run_diagnostic(data1, data2, metadata) # Assert - DiagnosticReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) + DiagnosticReport.generate.assert_called_once_with( + data1, data2, metadata._convert_to_single_table().to_dict(), True + ) def test_run_diagnostic_metadata(): @@ -93,8 +98,9 @@ def test_get_column_plot_continuous_data(mock_get_plot): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -135,8 +141,9 @@ def test_get_column_plot_discrete_data(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='categorical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -178,8 +185,9 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot): # Setup data1 = 
pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='categorical') # Run plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot') @@ -221,8 +229,9 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='id') # Run and Assert error_msg = re.escape( @@ -265,8 +274,9 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='id') # Run plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar') @@ -307,8 +317,9 @@ def test_get_column_plot_with_datetime_sdtype(mock_get_plot): # Setup real_data = pd.DataFrame({'datetime': ['2021-02-01', '2021-12-01']}) synthetic_data = pd.DataFrame({'datetime': ['2023-02-21', '2022-12-13']}) - metadata = SingleTableMetadata() - metadata.add_column('datetime', sdtype='datetime', datetime_format='%Y-%m-%d') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'datetime', sdtype='datetime', datetime_format='%Y-%m-%d') # Run plot = get_column_plot(real_data, synthetic_data, metadata, 'datetime') @@ -341,9 +352,10 @@ def test_get_column_pair_plot_with_continous_data(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'date': ['2021-01-01', '2022-01-01', '2023-01-01'], }) - metadata = SingleTableMetadata() - 
metadata.add_column('amount', sdtype='numerical') - metadata.add_column('date', sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'date', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -375,9 +387,10 @@ def test_get_column_pair_plot_with_discrete_data(mock_get_plot): columns = ['name', 'subscriber'] real_data = pd.DataFrame({'name': ['John', 'Emily'], 'subscriber': [True, False]}) synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'subscriber': [False, False]}) - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='categorical') - metadata.add_column('subscriber', sdtype='boolean') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'subscriber', sdtype='boolean') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -401,9 +414,10 @@ def test_get_column_pair_plot_with_mixed_data(mock_get_plot): columns = ['name', 'counts'] real_data = pd.DataFrame({'name': ['John', 'Emily'], 'counts': [1, 2]}) synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'counts': [3, 1]}) - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='categorical') - metadata.add_column('counts', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'counts', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -433,9 +447,10 @@ def test_get_column_pair_plot_with_forced_plot_type(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'date': ['2021-01-01', '2022-01-01', '2023-01-01'], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('date', 
sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'date', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -474,9 +489,10 @@ def test_get_column_pair_plot_with_invalid_sdtype(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'id': [1, 2, 3], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('id', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'id', sdtype='id') # Run and Assert error_msg = re.escape( @@ -504,9 +520,10 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'id': [1, 2, 3], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('id', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'id', sdtype='id') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -532,9 +549,10 @@ def test_get_column_pair_plot_with_sample_size(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'price': [11.0, 22.0, 33.0], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('price', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'price', sdtype='numerical') # Run get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2) @@ -594,9 +612,10 @@ def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'price': [11.0, 22.0, 33.0], }) - 
metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('price', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'price', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10) diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py index 48395f276..478af4c36 100644 --- a/tests/unit/io/local/test_local.py +++ b/tests/unit/io/local/test_local.py @@ -8,7 +8,7 @@ import pytest from sdv.io.local.local import BaseLocalHandler, CSVHandler, ExcelHandler -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata import Metadata class TestBaseLocalHandler: @@ -34,9 +34,9 @@ def test_create_metadata(self): metadata = instance.create_metadata(data) # Assert - assert isinstance(metadata, MultiTableMetadata) + assert isinstance(metadata, Metadata) assert metadata.to_dict() == { - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', 'relationships': [], 'tables': { 'guests': { diff --git a/tests/unit/lite/test_single_table.py b/tests/unit/lite/test_single_table.py index c51d99b83..d9880cc16 100644 --- a/tests/unit/lite/test_single_table.py +++ b/tests/unit/lite/test_single_table.py @@ -21,7 +21,7 @@ def test___init__invalid_name(self): """ # Run and Assert with pytest.raises(ValueError, match=r"'name' must be one of *"): - SingleTablePreset(metadata=SingleTableMetadata(), name='invalid') + SingleTablePreset(metadata=Metadata(), name='invalid') @patch('sdv.lite.single_table.GaussianCopulaSynthesizer') def test__init__speed_passes_correct_parameters(self, gaussian_copula_mock): diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 0d282c347..da2046596 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ 
-391,7 +391,7 @@ def test__set_metadata_multi_table(self, mock_singletablemetadata): Setup: - instance of ``Metadata``. - - A dict representing a ``MultiTableMetadata``. + - A dict representing a ``Metadata``. Mock: - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` @@ -457,16 +457,10 @@ def test__set_metadata_single_table(self): Setup: - instance of ``Metadata``. - - A dict representing a ``SingleTableMetadata``. - - Mock: - - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` - - Side Effects: - - ``SingleTableMetadata.load_from_dict`` has been called. + - A dict representing a single table ``Metadata``. """ # Setup - multitable_metadata = { + single_table_metadata = { 'columns': {'my_column': 'value'}, 'primary_key': 'pk', 'alternate_keys': [], @@ -478,7 +472,7 @@ instance = Metadata() # Run - instance._set_metadata_dict(multitable_metadata) + instance._set_metadata_dict(single_table_metadata) # Assert assert instance.tables['default_table_name'].columns == {'my_column': 'value'} diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 3ca501aaf..8e0c60629 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -151,7 +151,8 @@ def test___init___deprecated(self): """Test that init with old MultiTableMetadata gives a future warnging.""" # Setup metadata = get_multi_table_metadata() - metadata.validate = Mock() + multi_metadata = MultiTableMetadata.load_from_dict(metadata.to_dict()) + multi_metadata.validate = Mock() deprecation_msg = re.escape( "The 'MultiTableMetadata' is deprecated. 
Please use the new " @@ -160,7 +161,7 @@ def test___init___deprecated(self): # Run with pytest.warns(FutureWarning, match=deprecation_msg): - BaseMultiTableSynthesizer(metadata) + BaseMultiTableSynthesizer(multi_metadata) @patch('sdv.metadata.single_table.is_faker_function') def test__init__column_relationship_warning(self, mock_is_faker_function): @@ -231,7 +232,7 @@ def test__check_metadata_updated(self): def test_set_address_columns(self): """Test the ``set_address_columns`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'address_table': { 'columns': { @@ -274,7 +275,7 @@ def test_set_address_columns(self): def test_set_address_columns_error(self): """Test that ``set_address_columns`` raises an error for unknown table.""" # Setup - metadata = MultiTableMetadata() + metadata = Metadata() columns = ('country_column', 'city_column') metadata.validate = Mock() SingleTableMetadata.validate = Mock() @@ -857,7 +858,7 @@ def test_preprocess_int_columns(self): } ], } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) instance = BaseMultiTableSynthesizer(metadata) instance.validate = Mock() instance._table_synthesizers = {'first_table': Mock(), 'second_table': Mock()} @@ -1425,7 +1426,7 @@ def test_add_constraints_missing_table_name(self): """Test error raised when ``table_name`` is missing.""" # Setup data = pd.DataFrame({'col': [1, 2, 3]}) - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', data) constraint = {'constraint_class': 'Inequality'} model = BaseMultiTableSynthesizer(metadata) @@ -1524,7 +1525,7 @@ def test_get_info(self, mock_version): """ # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') mock_version.public = '1.0.0' @@ -1572,7 
+1573,7 @@ def test_get_info_with_enterprise(self, mock_version): """ # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') mock_version.public = '1.0.0' @@ -1636,7 +1637,7 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): def test_save_warning(self, tmp_path): """Test that the synthesizer produces a warning if saved without fitting.""" # Setup - synthesizer = BaseMultiTableSynthesizer(MultiTableMetadata()) + synthesizer = BaseMultiTableSynthesizer(Metadata()) # Run and Assert warn_msg = re.escape( diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index ef0085c19..72abc0e25 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -6,7 +6,7 @@ import pytest from sdv.errors import SynthesizerInputError -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer from tests.utils import get_multi_table_data, get_multi_table_metadata @@ -790,7 +790,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'col1': [0, 1, 2], }) data = {'parent': parent, 'child': child} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'id', @@ -879,7 +879,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): 'col': {'sdtype': 'numerical'}, }, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'id', @@ -1022,7 +1022,7 @@ def test__estimate_num_columns_to_be_modeled(self): 'parent': parent, 'child': child, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 
'root1': { 'primary_key': 'R1', @@ -1123,3 +1123,111 @@ def test__estimate_num_columns_to_be_modeled(self): num_table_cols -= 1 assert num_table_cols == estimated_num_columns[table_name] + + def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): + """Test the estimated number of columns is correct for various sdtypes. + + To check that the number columns is correct we Mock the ``_finalize`` method + and compare its output with the estimated number of columns. + + The dataset used follows the structure below: + R1 R2 + | / + GP + | + P + """ + # Setup + root1 = pd.DataFrame({'R1': [0, 1, 2]}) + root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) + grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}) + parent = pd.DataFrame({ + 'P': [0, 1, 2], + 'GP': [0, 1, 2], + 'numerical': [0.1, 0.5, np.nan], + 'categorical': ['a', np.nan, 'c'], + 'datetime': [None, '2019-01-02', '2019-01-03'], + 'boolean': [float('nan'), False, True], + 'id': [0, 1, 2], + }) + data = { + 'root1': root1, + 'root2': root2, + 'grandparent': grandparent, + 'parent': parent, + } + metadata = Metadata.load_from_dict({ + 'tables': { + 'root1': { + 'primary_key': 'R1', + 'columns': { + 'R1': {'sdtype': 'id'}, + }, + }, + 'root2': { + 'primary_key': 'R2', + 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, + }, + 'grandparent': { + 'primary_key': 'GP', + 'columns': { + 'GP': {'sdtype': 'id'}, + 'R1': {'sdtype': 'id'}, + 'R2': {'sdtype': 'id'}, + }, + }, + 'parent': { + 'primary_key': 'P', + 'columns': { + 'P': {'sdtype': 'id'}, + 'GP': {'sdtype': 'id'}, + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime': {'sdtype': 'datetime'}, + 'boolean': {'sdtype': 'boolean'}, + 'id': {'sdtype': 'id'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'root1', + 'parent_primary_key': 'R1', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R1', + }, + { + 'parent_table_name': 'root2', 
+ 'parent_primary_key': 'R2', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R2', + }, + { + 'parent_table_name': 'grandparent', + 'parent_primary_key': 'GP', + 'child_table_name': 'parent', + 'child_foreign_key': 'GP', + }, + ], + }) + synthesizer = HMASynthesizer(metadata) + synthesizer._finalize = Mock(return_value=data) + + # Run estimation + estimated_num_columns = synthesizer._estimate_num_columns(metadata) + + # Run actual modeling + synthesizer.fit(data) + synthesizer.sample() + + # Assert estimated number of columns is correct + tables = synthesizer._finalize.call_args[0][0] + for table_name, table in tables.items(): + # Subtract all the id columns present in the data, as those are not estimated + num_table_cols = len(table.columns) + if table_name in {'parent', 'grandparent'}: + num_table_cols -= 3 + if table_name in {'root1', 'root2'}: + num_table_cols -= 1 + + assert num_table_cols == estimated_num_columns[table_name] diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index e01195f89..57b52fa08 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -8,7 +8,7 @@ import pytest from sdv.errors import InvalidDataError, SamplingError -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.utils import ( _drop_rows, _get_all_descendant_per_root_at_order_n, @@ -605,7 +605,7 @@ def test__get_disconnected_roots_from_table(table_name, expected_result): def test__simplify_relationships_and_tables(): """Test the ``_simplify_relationships`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, @@ -646,7 +646,7 @@ def test__simplify_grandchildren(): """Test the ``_simplify_grandchildren`` 
method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, @@ -697,7 +697,7 @@ def test__get_num_column_to_drop(): datetime_columns = {f'col_{i}': {'sdtype': 'datetime'} for i in range(600, 900)} id_columns = {f'col_{i}': {'sdtype': 'id'} for i in range(900, 910)} email_columns = {f'col_{i}': {'sdtype': 'email'} for i in range(910, 920)} - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'child': { 'columns': { @@ -881,11 +881,11 @@ def test__simplify_children(mock_get_columns_to_drop_child, mock_hma): child_1_before_simplify['columns']['col_4'] = {'sdtype': 'categorical'} child_2_before_simplify = deepcopy(child_2) child_2_before_simplify['columns']['col_8'] = {'sdtype': 'categorical'} - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relatioships, 'tables': {'child_1': child_1_before_simplify, 'child_2': child_2_before_simplify}, }) - metadata_after_simplify_2 = MultiTableMetadata().load_from_dict({ + metadata_after_simplify_2 = Metadata().load_from_dict({ 'relationships': relatioships, 'tables': {'child_1': child_1, 'child_2': child_2}, }) @@ -951,7 +951,7 @@ def test__simplify_metadata_no_child_simplification(mock_hma): 'other_table': {'columns': {'col_8': {'sdtype': 'numerical'}}}, 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relationships, 'tables': tables, }) @@ -1047,7 +1047,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): }, 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relationships, 
'tables': tables, }) @@ -1122,7 +1122,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): def test__simplify_data(): """Test the ``_simplify_data`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'parent': {'columns': {'col_1': {'sdtype': 'id'}}}, 'child': {'columns': {'col_2': {'sdtype': 'id'}}}, @@ -1249,7 +1249,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo 'col_12': [6, 7, 8, 9, 10], }), } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'disconnected_root': { 'columns': { @@ -1378,7 +1378,7 @@ def test__get_primary_keys_referenced(): }), } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { @@ -1598,7 +1598,7 @@ def test__subsample_ancestors(): 'child': {21, 22, 23, 24, 25}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { @@ -1785,7 +1785,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'parent_2': {31, 32, 33, 34, 35}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 21a40922c..966939869 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -19,16 +19,17 @@ class TestPARSynthesizer: def get_metadata(self, add_sequence_key=True, add_sequence_index=False): - metadata = SingleTableMetadata() - metadata.add_column('time', sdtype='datetime') - metadata.add_column('gender', sdtype='categorical') - metadata.add_column('name', sdtype='id') - metadata.add_column('measurement', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'time', 
sdtype='datetime') + metadata.add_column('table', 'gender', sdtype='categorical') + metadata.add_column('table', 'name', sdtype='id') + metadata.add_column('table', 'measurement', sdtype='numerical') if add_sequence_key: - metadata.set_sequence_key('name') + metadata.set_sequence_key('table', 'name') if add_sequence_index: - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') return metadata @@ -76,7 +77,10 @@ def test___init__(self): 'verbose': False, } assert isinstance(synthesizer._data_processor, DataProcessor) - assert synthesizer._data_processor.metadata == metadata + assert ( + synthesizer._data_processor.metadata.to_dict() + == metadata._convert_to_single_table().to_dict() + ) assert isinstance(synthesizer._context_synthesizer, GaussianCopulaSynthesizer) assert synthesizer._context_synthesizer.metadata.columns == { 'gender': {'sdtype': 'categorical'}, @@ -238,14 +242,14 @@ def test_get_metadata(self): result = instance.get_metadata() # Assert - assert result._convert_to_single_table().to_dict() == metadata.to_dict() + assert result.to_dict() == metadata.to_dict() assert isinstance(result, Metadata) def test_validate_context_columns_unique_per_sequence_key(self): """Test error is raised if context column values vary for each tuple of sequence keys. Setup: - A ``SingleTableMetadata`` instance where the context columns vary for different + A ``Metadata`` instance where the context columns vary for different combinations of values of the sequence keys. 
""" # Setup @@ -255,12 +259,13 @@ def test_validate_context_columns_unique_per_sequence_key(self): 'ct_col1': [1, 2, 2, 3, 2], 'ct_col2': [3, 3, 4, 3, 2], }) - metadata = SingleTableMetadata() - metadata.add_column('sk_col1', sdtype='id') - metadata.add_column('sk_col2', sdtype='id') - metadata.add_column('ct_col1', sdtype='numerical') - metadata.add_column('ct_col2', sdtype='numerical') - metadata.set_sequence_key('sk_col1') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'sk_col1', sdtype='id') + metadata.add_column('table', 'sk_col2', sdtype='id') + metadata.add_column('table', 'ct_col1', sdtype='numerical') + metadata.add_column('table', 'ct_col2', sdtype='numerical') + metadata.set_sequence_key('table', 'sk_col1') instance = PARSynthesizer(metadata=metadata, context_columns=['ct_col1', 'ct_col2']) # Run and Assert @@ -524,7 +529,7 @@ def test_auto_assign_transformers_without_enforce_min_max(self, mock_get_transfo 'measurement': [55, 60, 65], }) metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') mock_get_transfomers.return_value = {'time': FloatFormatter} # Run @@ -606,7 +611,7 @@ def test__fit_sequence_columns_with_categorical_float( data = self.get_data() data['measurement'] = data['measurement'].astype(float) metadata = self.get_metadata() - metadata.update_column('measurement', sdtype='categorical') + metadata.update_column('table', 'measurement', sdtype='categorical') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = [ {'context': np.array(['M'], dtype=object), 'data': [['2020-01-03'], [65.0]]}, @@ -644,7 +649,7 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock 'measurement': [55, 60, 65, 65, 70], }) metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = 
[ {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]}, @@ -835,7 +840,7 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock): """ # Setup metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) model_mock = Mock() par._model = model_mock diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 5fd111960..158f71325 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -206,7 +206,7 @@ def test___init__invalid_enforce_min_max_values(self): ' Please provide True or False.' ) with pytest.raises(SynthesizerInputError, match=err_msg): - BaseSingleTableSynthesizer(SingleTableMetadata(), enforce_min_max_values='invalid') + BaseSingleTableSynthesizer(Metadata(), enforce_min_max_values='invalid') def test___init__invalid_enforce_rounding(self): """Test it crashes when ``enforce_rounding`` is not a boolean.""" @@ -216,12 +216,12 @@ def test___init__invalid_enforce_rounding(self): ' Please provide True or False.' 
) with pytest.raises(SynthesizerInputError, match=err_msg): - BaseSingleTableSynthesizer(SingleTableMetadata(), enforce_rounding='invalid') + BaseSingleTableSynthesizer(Metadata(), enforce_rounding='invalid') def test_set_address_columns_warning(self): """Test ``set_address_columns`` method when the synthesizer has been fitted.""" # Setup - synthesizer = BaseSingleTableSynthesizer(SingleTableMetadata()) + synthesizer = BaseSingleTableSynthesizer(Metadata()) # Run and Assert expected_message = re.escape( @@ -286,7 +286,7 @@ def test_auto_assign_transformers(self): def test_auto_assign_transformers_with_invalid_data(self): """Test that auto_assign_transformer throws useful error about invalid data""" # Setup - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'a': {'sdtype': 'categorical'}, } @@ -590,7 +590,7 @@ def test_validate(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock() instance._validate_constraints = Mock() @@ -612,7 +612,7 @@ def test_validate_raises_constraints_error(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock(return_value=[]) instance._validate_constraints = Mock() @@ -639,7 +639,7 @@ def test_validate_raises_invalid_data_for_metadata(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock(return_value=[]) instance._validate_constraints = Mock() @@ -707,11 +707,12 @@ def test_update_transformers_invalid_keys(self): """ # Setup column_name_to_transformer = {'col2': RegexGenerator(), 'col3': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col2', sdtype='id') - 
metadata.add_column('col3', sdtype='id') - metadata.set_sequence_key(('col2')) - metadata.add_alternate_keys(['col3']) + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col2', sdtype='id') + metadata.add_column('table', 'col3', sdtype='id') + metadata.set_sequence_key('table', 'col2') + metadata.add_alternate_keys('table', ['col3']) instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -728,9 +729,10 @@ def test_update_transformers_already_fitted(self): fitted_transformer = FloatFormatter() fitted_transformer.fit(pd.DataFrame({'col': [1]}), 'col') column_name_to_transformer = {'col1': BinaryEncoder(), 'col2': fitted_transformer} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='boolean') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='boolean') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -742,9 +744,10 @@ def test_update_transformers_warns_gaussian_copula(self): """Test warning is raised when ohe is used for categorical column in the GaussianCopula.""" # Setup column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='categorical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='categorical') + metadata.add_column('table', 'col2', sdtype='numerical') instance = GaussianCopulaSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -769,9 +772,10 @@ def test_update_transformers_warns_models(self): """ # Setup column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='categorical') 
- metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='categorical') + metadata.add_column('table', 'col2', sdtype='numerical') # NOTE: when PARSynthesizer is implemented, add it here as well for model in [CTGANSynthesizer, CopulaGANSynthesizer, TVAESynthesizer]: @@ -794,9 +798,10 @@ def test_update_transformers_warns_fitted(self): """ # Setup column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='numerical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) instance._fitted = True @@ -812,9 +817,10 @@ def test_update_transformers(self): """Test method correctly updates the transformers in the HyperTransformer.""" # Setup column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='numerical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -1912,7 +1918,7 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): def test_save_warning(self, tmp_path): """Test that the synthesizer produces a warning if saved without fitting.""" # Setup - synthesizer = BaseSynthesizer(SingleTableMetadata()) + synthesizer = BaseSynthesizer(Metadata()) # Run and Assert warn_msg = 
re.escape( @@ -2033,7 +2039,7 @@ def test_add_custom_constraint_class(self): def test_add_constraint_warning(self): """Test a warning is raised when the synthesizer had already been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._fitted = True @@ -2045,8 +2051,9 @@ def test_add_constraint_warning(self): def test_add_constraints(self): """Test a list of constraints can be added to the synthesizer.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2075,8 +2082,9 @@ def test_add_constraints(self): def test_get_constraints(self): """Test a list of constraints is returned by the method.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2109,8 +2117,9 @@ def test_get_info_no_enterprise(self, mock_sdv_version): data = pd.DataFrame({'col': [1, 2, 3]}) mock_sdv_version.public = '1.0.0' mock_sdv_version.enterprise = None - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) @@ -2156,8 +2165,9 @@ def test_get_info_with_enterprise(self, mock_sdv_version): data = pd.DataFrame({'col': [1, 2, 3]}) mock_sdv_version.public = '1.0.0' mock_sdv_version.enterprise = '1.2.0' - metadata = SingleTableMetadata() - 
metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index e41fbdc83..e7c531a26 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -9,7 +9,6 @@ from sdv.errors import SynthesizerInputError from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulagan import CopulaGANSynthesizer @@ -17,7 +16,7 @@ class TestCopulaGANSynthesizer: def test___init__(self): """Test creating an instance of ``CopulaGANSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -89,8 +88,9 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('field', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -158,7 +158,7 @@ def test___init__custom(self): def test___init__incorrect_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = 'invalid' # Run @@ -169,7 +169,7 @@ def test___init__incorrect_numerical_distributions(self): def test___init__invalid_column_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` includes invalid columns.""" # Setup - metadata = 
SingleTableMetadata() + metadata = Metadata() numerical_distributions = {'totally_fake_column_name': 'beta'} # Run @@ -184,7 +184,7 @@ def test___init__invalid_column_numerical_distributions(self): def test_get_params(self): """Test that inherited method ``get_params`` returns all the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CopulaGANSynthesizer(metadata) # Run @@ -224,18 +224,11 @@ def test__create_gaussian_normalizer_config(self, mock_rdt): """ # Setup numerical_distributions = {'age': 'gamma'} - metadata = SingleTableMetadata() - metadata.columns = { - 'name': { - 'sdtype': 'categorical', - }, - 'age': { - 'sdtype': 'numerical', - }, - 'account': { - 'sdtype': 'numerical', - }, - } + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'age', sdtype='numerical') + metadata.add_column('table', 'account', sdtype='numerical') instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame({ @@ -280,8 +273,9 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger): were renamed/dropped during preprocessing. """ # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame() @@ -305,7 +299,7 @@ def test__fit(self, mock_rdt, mock_ctgansynthesizer__fit): one of the ``copulas`` distributions. 
""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CopulaGANSynthesizer(metadata) instance._create_gaussian_normalizer_config = Mock() processed_data = pd.DataFrame() @@ -333,8 +327,8 @@ def test_get_learned_distributions(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) cgs = CopulaGANSynthesizer(stm) zero_transformer_mock = Mock(spec_set=GaussianNormalizer) zero_transformer_mock._univariate.to_dict.return_value = { @@ -378,8 +372,8 @@ def test_get_learned_distributions_raises_an_error(self): """Test that ``get_learned_distributions`` raises an error.""" # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) cgs = CopulaGANSynthesizer(stm) # Run and Assert diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 130fb068d..ce4300a12 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -9,7 +9,6 @@ from sdv.errors import SynthesizerInputError from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -37,7 +36,7 @@ def test_get_distribution_class_not_in_distributions(self): def test___init__(self): """Test creating an instance of ``GaussianCopulaSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True numerical_distributions = None @@ -91,8 +90,9 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() - 
metadata.add_column('field', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False numerical_distributions = {'field': 'gamma'} @@ -118,7 +118,7 @@ def test___init__custom(self): def test___init__incorrect_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = 'invalid' # Run @@ -129,7 +129,7 @@ def test___init__incorrect_numerical_distributions(self): def test___init__incorrect_column_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` includes invalid columns.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = {'totally_fake_column_name': 'beta'} # Run @@ -144,7 +144,7 @@ def test___init__incorrect_column_numerical_distributions(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specified init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = GaussianCopulaSynthesizer(metadata) # Run @@ -167,8 +167,9 @@ def test__fit_logging(self, mock_logger): were renamed/dropped during preprocessing. """ # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = GaussianCopulaSynthesizer( metadata, numerical_distributions=numerical_distributions @@ -194,9 +195,10 @@ def test__fit(self, mock_multivariate, mock_warnings): the ``numerical_distributions``. 
""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='numerical') - metadata.add_column('user.id', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='numerical') + metadata.add_column('table', 'user.id', sdtype='numerical') numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'} processed_data = pd.DataFrame({ @@ -339,7 +341,7 @@ def test__rebuild_gaussian_copula(self): - numpy array, Square correlation matrix """ # Setup - metadata = SingleTableMetadata() + metadata = Metadata() gaussian_copula = GaussianCopulaSynthesizer(metadata) model_parameters = { 'univariates': { @@ -373,7 +375,7 @@ def test__rebuild_gaussian_copula(self): def test__rebuild_gaussian_copula_with_defaults(self, logger_mock): """Test the method with invalid parameters and default fallbacks.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() gaussian_copula = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') distribution_mock = Mock() delattr(distribution_mock.MODEL_CLASS, '_argcheck') @@ -472,8 +474,8 @@ def test_get_learned_distributions(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) gcs = GaussianCopulaSynthesizer(stm, numerical_distributions={'one': 'uniform'}) gcs.fit(data) @@ -497,8 +499,8 @@ def test_get_learned_distributions_raises_an_error(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) gcs = GaussianCopulaSynthesizer(stm) # Run and Assert diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 0f81558e6..5cb3e6000 100644 --- a/tests/unit/single_table/test_ctgan.py +++ 
b/tests/unit/single_table/test_ctgan.py @@ -7,7 +7,7 @@ from sdmetrics import visualization from sdv.errors import InvalidDataTypeError, NotFittedError -from sdv.metadata.single_table import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer, _validate_no_category_dtype @@ -33,7 +33,7 @@ class TestCTGANSynthesizer: def test___init__(self): """Test creating an instance of ``CTGANSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -65,7 +65,7 @@ def test___init__(self): def test___init__with_unified_metadata(self): """Test creating an instance of ``CTGANSynthesizer`` with Metadata.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -97,7 +97,7 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``CTGANSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -157,7 +157,7 @@ def test___init__custom(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) # Run @@ -187,10 +187,11 @@ def test_get_parameters(self): def test__estimate_num_columns(self): """Test that ``_estimate_num_columns`` returns without crashing the number of columns.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('id', sdtype='numerical') - metadata.add_column('name', sdtype='categorical') - metadata.add_column('surname', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'id', sdtype='numerical') + 
metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'surname', sdtype='categorical') data = pd.DataFrame({ 'id': np.random.rand(1_001), 'name': [f'cat_{i}' for i in range(1_001)], @@ -212,9 +213,10 @@ def test__estimate_num_columns(self): def test_preprocessing_many_categories(self, capfd): """Test a message is printed during preprocess when a column has many categories.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(1_001), 'categorical': [f'cat_{i}' for i in range(1_001)], @@ -242,9 +244,10 @@ def test_preprocessing_many_categories(self, capfd): def test_preprocessing_few_categories(self, capfd): """Test a message is not printed during preprocess when a column has few categories.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(10), 'categorical': [f'cat_{i}' for i in range(10)], @@ -269,8 +272,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_c that have been detected by the utility function. 
""" # Setup - metadata = SingleTableMetadata() - instance = CTGANSynthesizer(metadata) + metadata = Metadata() + single_metadata = metadata._convert_to_single_table() + instance = CTGANSynthesizer(single_metadata) processed_data = Mock() # Run @@ -279,7 +283,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_c # Assert mock_category_validate.assert_called_once_with(processed_data) mock_detect_discrete_columns.assert_called_once_with( - metadata, processed_data, instance._data_processor._hyper_transformer.field_transformers + single_metadata, + processed_data, + instance._data_processor._hyper_transformer.field_transformers, ) mock_ctgan.assert_called_once_with( batch_size=500, @@ -309,7 +315,7 @@ def test_get_loss_values(self): mock_model = Mock() loss_values = pd.DataFrame({'Epoch': [0, 1, 2], 'Loss': [0.8, 0.6, 0.5]}) mock_model.loss_values = loss_values - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) instance._model = mock_model instance._fitted = True @@ -323,7 +329,7 @@ def test_get_loss_values(self): def test_get_loss_values_error(self): """Test the ``get_loss_values`` errors if synthesizer has not been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) # Run / Assert @@ -335,7 +341,7 @@ def test_get_loss_values_error(self): def test_get_loss_values_plot(self, mock_line_plot): """Test the ``get_loss_values_plot`` method from ``CTGANSynthesizer.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) mock_loss_value = Mock() mock_loss_value.item.return_value = 0.1 @@ -364,7 +370,7 @@ class TestTVAESynthesizer: def test___init__(self): """Test creating an instance of ``TVAESynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -391,7 +397,7 @@ def test___init__(self): def 
test___init__custom(self): """Test creating an instance of ``TVAESynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -436,7 +442,7 @@ def test___init__custom(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) # Run @@ -468,8 +474,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_t that have been detected by the utility function. """ # Setup - metadata = SingleTableMetadata() - instance = TVAESynthesizer(metadata) + metadata = Metadata() + single_metadata = metadata._convert_to_single_table() + instance = TVAESynthesizer(single_metadata) processed_data = Mock() # Run @@ -478,7 +485,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_t # Assert mock_category_validate.assert_called_once_with(processed_data) mock_detect_discrete_columns.assert_called_once_with( - metadata, processed_data, instance._data_processor._hyper_transformer.field_transformers + single_metadata, + processed_data, + instance._data_processor._hyper_transformer.field_transformers, ) mock_tvae.assert_called_once_with( batch_size=500, @@ -503,7 +512,7 @@ def test_get_loss_values(self): mock_model = Mock() loss_values = pd.DataFrame({'Epoch': [0, 1, 2], 'Loss': [0.8, 0.6, 0.5]}) mock_model.loss_values = loss_values - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) instance._model = mock_model instance._fitted = True @@ -517,7 +526,7 @@ def test_get_loss_values(self): def test_get_loss_values_error(self): """Test the ``get_loss_values`` errors if synthesizer has not been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) 
# Run / Assert diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index 6cd82c183..1473ed539 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -6,8 +6,8 @@ import pytest from sdv.errors import InvalidDataError -from sdv.metadata import MultiTableMetadata from sdv.metadata.errors import InvalidMetadataError +from sdv.metadata.metadata import Metadata from sdv.utils.poc import ( drop_unknown_references, get_random_subset, @@ -68,7 +68,7 @@ def test_simplify_schema( # Setup data = Mock() metadata = Mock() - simplified_metatadata = MultiTableMetadata() + simplified_metatadata = Metadata() mock_get_total_estimated_columns.return_value = 2000 mock_simplify_metadata.return_value = simplified_metatadata mock_simplify_data.return_value = { @@ -91,7 +91,7 @@ def test_simplify_schema( def test_simplify_schema_invalid_metadata(): """Test ``simplify_schema`` when the metadata is not invalid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': {'table1': {'columns': {'column1': {'sdtype': 'categorical'}}}}, 'relationships': [ { @@ -119,7 +119,7 @@ def test_simplify_schema_invalid_metadata(): def test_simplify_schema_invalid_data(): """Test ``simplify_schema`` when the data is not valid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'table1': {'columns': {'column1': {'sdtype': 'id'}}, 'primary_key': 'column1'}, 'table2': { @@ -152,7 +152,7 @@ def test_simplify_schema_invalid_data(): def test_get_random_subset_invalid_metadata(): """Test ``get_random_subset`` when the metadata is invalid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': {'table1': {'columns': {'column1': {'sdtype': 'categorical'}}}}, 'relationships': [ { @@ -180,7 +180,7 @@ def test_get_random_subset_invalid_metadata(): def test_get_random_subset_invalid_data(): """Test 
``get_random_subset`` when the data is not valid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'table1': {'columns': {'column1': {'sdtype': 'id'}}, 'primary_key': 'column1'}, 'table2': { diff --git a/tests/utils.py b/tests/utils.py index a2fce6d1f..eded42cb1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ import pandas as pd from sdv.logging import get_sdv_logger -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata class DataFrameMatcher: @@ -80,7 +80,7 @@ def get_multi_table_metadata(): 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - return MultiTableMetadata.load_from_dict(dict_metadata) + return Metadata.load_from_dict(dict_metadata) def get_multi_table_data():