Skip to content

Commit

Permalink
Update tests to use Metadata (#2178)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored and R-Palazzo committed Sep 26, 2024
1 parent 609e545 commit 66876e7
Show file tree
Hide file tree
Showing 37 changed files with 670 additions and 462 deletions.
4 changes: 2 additions & 2 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.metadata.metadata import Metadata


class BaseLocalHandler:
Expand All @@ -29,7 +29,7 @@ def create_metadata(self, data):
An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
properties from the data.
"""
metadata = MultiTableMetadata()
metadata = Metadata()
metadata.detect_from_dataframes(data)
return metadata

Expand Down
15 changes: 13 additions & 2 deletions sdv/lite/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cloudpickle

from sdv.metadata.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer

LOGGER = logging.getLogger(__name__)
Expand All @@ -20,13 +21,19 @@
"functionality, please use the 'GaussianCopulaSynthesizer'."
)

META_DEPRECATION_MSG = (
"The 'SingleTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
)


class SingleTablePreset:
"""Class for all single table synthesizer presets.
Args:
metadata (sdv.metadata.SingleTableMetadata):
``SingleTableMetadata`` instance.
metadata (sdv.metadata.Metadata):
``Metadata`` instance.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
name (str):
The preset to use.
locales (list or str):
Expand All @@ -49,6 +56,10 @@ def __init__(self, metadata, name, locales=['en_US']):
raise ValueError(f"'name' must be one of {PRESETS}.")

self.name = name
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()
else:
warnings.warn(META_DEPRECATION_MSG, FutureWarning)
if name == FAST_ML_PRESET:
self._setup_fast_preset(metadata, self.locales)

Expand Down
25 changes: 24 additions & 1 deletion sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Metadata."""

import warnings

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
Expand All @@ -10,6 +12,7 @@ class Metadata(MultiTableMetadata):
"""Metadata class that handles all metadata."""

METADATA_SPEC_VERSION = 'V1'
DEFAULT_SINGLE_TABLE_NAME = 'default_table_name'

@classmethod
def load_from_json(cls, filepath, single_table_name=None):
Expand Down Expand Up @@ -66,9 +69,29 @@ def _set_metadata_dict(self, metadata, single_table_name=None):
super()._set_metadata_dict(metadata)
else:
if single_table_name is None:
single_table_name = 'default_table_name'
single_table_name = self.DEFAULT_SINGLE_TABLE_NAME
self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)

def _get_single_table_name(self):
"""Get the table name if there is only one table.
Checks to see if the metadata contains only a single table, if so
return the name. Otherwise warn the user and return None.
Args:
metadata (dict):
Python dictionary representing a ``MultiTableMetadata`` or
``SingleTableMetadata`` object.
"""
if len(self.tables) != 1:
warnings.warn(
'This metadata does not contain only a single table. Could not determine '
'single table name and will return None.'
)
return None

return next(iter(self.tables), None)

def _convert_to_single_table(self):
if len(self.tables) > 1:
raise InvalidMetadataError(
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/independent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _finalize(self, sampled_data):
final_data = {}
for table_name, table_rows in sampled_data.items():
synthesizer = self._table_synthesizers.get(table_name)
metadata = synthesizer.get_metadata()
metadata = synthesizer.get_metadata()._convert_to_single_table()
dtypes = synthesizer._data_processor._dtypes
dtypes_to_sdtype = synthesizer._data_processor._DTYPE_TO_SDTYPE

Expand Down
5 changes: 4 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def __init__(
):
self._validate_inputs(enforce_min_max_values, enforce_rounding)
self.metadata = metadata
self._table_name = Metadata.DEFAULT_SINGLE_TABLE_NAME
if isinstance(metadata, Metadata):
self._table_name = metadata._get_single_table_name()
self.metadata = metadata._convert_to_single_table()
elif isinstance(metadata, SingleTableMetadata):
warnings.warn(DEPRECATION_MSG, FutureWarning)
Expand Down Expand Up @@ -284,7 +286,8 @@ def get_parameters(self):

def get_metadata(self):
"""Return the ``Metadata`` for this synthesizer."""
return Metadata.load_from_dict(self.metadata.to_dict())
table_name = getattr(self, '_table_name', None)
return Metadata.load_from_dict(self.metadata.to_dict(), table_name)

def load_custom_constraint_classes(self, filepath, class_names):
"""Load a custom constraint class for the current synthesizer.
Expand Down
3 changes: 2 additions & 1 deletion sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ class CopulaGANSynthesizer(CTGANSynthesizer):
Args:
metadata (sdv.metadata.SingleTableMetadata):
metadata (sdv.metadata.Metadata):
Single table metadata representing the data that this synthesizer will be used for.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
enforce_min_max_values (bool):
Specify whether or not to clip the data returned by ``reverse_transform`` of
the numerical transformer, ``FloatFormatter``, to the min and max values seen
Expand Down
3 changes: 2 additions & 1 deletion sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer):
"""Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula.
Args:
metadata (sdv.metadata.SingleTableMetadata):
metadata (sdv.metadata.Metadata):
Single table metadata representing the data that this synthesizer will be used for.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
enforce_min_max_values (bool):
Specify whether or not to clip the data returned by ``reverse_transform`` of
the numerical transformer, ``FloatFormatter``, to the min and max values seen
Expand Down
6 changes: 4 additions & 2 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``CTGAN`` model.
Args:
metadata (sdv.metadata.SingleTableMetadata):
metadata (sdv.metadata.Metadata):
Single table metadata representing the data that this synthesizer will be used for.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
enforce_min_max_values (bool):
Specify whether or not to clip the data returned by ``reverse_transform`` of
the numerical transformer, ``FloatFormatter``, to the min and max values seen
Expand Down Expand Up @@ -316,8 +317,9 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``TVAE`` model.
Args:
metadata (sdv.metadata.SingleTableMetadata):
metadata (sdv.metadata.Metadata):
Single table metadata representing the data that this synthesizer will be used for.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
enforce_min_max_values (bool):
Specify whether or not to clip the data returned by ``reverse_transform`` of
the numerical transformer, ``FloatFormatter``, to the min and max values seen
Expand Down
27 changes: 14 additions & 13 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError
from sdv.metadata import SingleTableMetadata
from sdv.metadata.metadata import Metadata


class TestDataProcessor:
Expand Down Expand Up @@ -155,12 +156,12 @@ def test_with_primary_key_numerical(self):
"""
# Load metadata and data
data, _ = download_demo('single_table', 'adult')
adult_metadata = SingleTableMetadata()
adult_metadata.detect_from_dataframe(data=data)
adult_metadata = Metadata()
adult_metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('id', sdtype='id')
adult_metadata.set_primary_key('id')
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')

# Add id
size = len(data)
Expand All @@ -169,7 +170,7 @@ def test_with_primary_key_numerical(self):
data['id'] = ids

# Instance ``DataProcessor``
dp = DataProcessor(adult_metadata)
dp = DataProcessor(adult_metadata._convert_to_single_table())

# Fit
dp.fit(data)
Expand All @@ -195,17 +196,17 @@ def test_with_alternate_keys(self):
# Load metadata and data
data, _ = download_demo('single_table', 'adult')
data['fnlwgt'] = data['fnlwgt'].astype(str)
adult_metadata = SingleTableMetadata()
adult_metadata.detect_from_dataframe(data=data)
adult_metadata = Metadata()
adult_metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('id', sdtype='id')
adult_metadata.set_primary_key('id')
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')

adult_metadata.add_column('secondary_id', sdtype='id')
adult_metadata.update_column('fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')
adult_metadata.add_column('adult', 'secondary_id', sdtype='id')
adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')

adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'])
adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt'])

# Add id
size = len(data)
Expand All @@ -215,7 +216,7 @@ def test_with_alternate_keys(self):
data['secondary_id'] = ids

# Instance ``DataProcessor``
dp = DataProcessor(adult_metadata)
dp = DataProcessor(adult_metadata._convert_to_single_table())

# Fit
dp.fit(data)
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/evaluation/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata


def test_evaluation():
Expand All @@ -18,7 +17,7 @@ def test_evaluation():
'table1': table,
'table2': slightly_different_table,
}
metadata = MultiTableMetadata().load_from_dict({
metadata = Metadata().load_from_dict({
'tables': {
'table1': {
'columns': {
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from sdv.datasets.demo import download_demo
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic
from sdv.metadata.metadata import Metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer


def test_evaluation():
"""Test ``evaluate_quality`` and ``run_diagnostic``."""
# Setup
data = pd.DataFrame({'col': [1, 2, 3]})
metadata = SingleTableMetadata()
metadata.add_column('col', sdtype='numerical')
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('table', 'col', sdtype='numerical')
synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm')

# Run and Assert
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/io/local/test_local.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd

from sdv.io.local import CSVHandler, ExcelHandler
from sdv.metadata import MultiTableMetadata
from sdv.metadata import Metadata


class TestCSVHandler:
Expand All @@ -27,7 +27,7 @@ def test_integration_write_and_read(self, tmpdir):
assert len(data) == 2
assert 'table1' in data
assert 'table2' in data
assert isinstance(metadata, MultiTableMetadata) is True
assert isinstance(metadata, Metadata) is True

# Check if the dataframes match the original synthetic data
pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1'])
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_integration_write_and_read(self, tmpdir):
assert len(data) == 2
assert 'table1' in data
assert 'table2' in data
assert isinstance(metadata, MultiTableMetadata) is True
assert isinstance(metadata, Metadata) is True

# Check if the dataframes match the original synthetic data
pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1'])
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_integration_write_and_read_append_mode(self, tmpdir):
assert len(data) == 2
assert 'table1' in data
assert 'table2' in data
assert isinstance(metadata, MultiTableMetadata) is True
assert isinstance(metadata, Metadata) is True

# Check if the dataframes match the original synthetic data
expected_table_one = pd.concat(
Expand Down
14 changes: 7 additions & 7 deletions tests/integration/lite/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from sdv.lite import SingleTablePreset
from sdv.metadata import SingleTableMetadata
from sdv.metadata.metadata import Metadata


def test_sample():
Expand All @@ -12,8 +12,8 @@ def test_sample():
data = pd.DataFrame({'a': [1, 2, 3, np.nan]})

# Run
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata = Metadata()
metadata.detect_from_dataframes({'adult': data})
preset = SingleTablePreset(metadata, name='FAST_ML')
preset.fit(data)
samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5)
Expand All @@ -29,8 +29,8 @@ def test_sample_with_constraints():
data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Run
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata = Metadata()
metadata.detect_from_dataframes({'table': data})
preset = SingleTablePreset(metadata, name='FAST_ML')
constraints = [
{
Expand All @@ -57,8 +57,8 @@ def test_warnings_are_shown():
data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Run
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata = Metadata()
metadata.detect_from_dataframes({'table': data})

with pytest.warns(FutureWarning, match=warn_message):
preset = SingleTablePreset(metadata, name='FAST_ML')
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from sdv.metadata import MultiTableMetadata, SingleTableMetadata
from sdv.metadata.metadata import Metadata
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer

Expand All @@ -9,8 +9,8 @@ def test_visualize_graph_for_single_table():
"""Test it runs when a column name contains symbols."""
# Setup
data = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata = Metadata()
metadata.detect_from_dataframes({'table': data})
model = GaussianCopulaSynthesizer(metadata)

# Run
Expand All @@ -26,7 +26,7 @@ def test_visualize_graph_for_multi_table():
data1 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
tables = {'1': data1, '2': data2}
metadata = MultiTableMetadata()
metadata = Metadata()
metadata.detect_from_dataframes(tables)
metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
Expand Down
Loading

0 comments on commit 66876e7

Please sign in to comment.