diff --git a/pyproject.toml b/pyproject.toml index 026cebf80..6ca371918 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ 'copulas>=0.11.0', 'ctgan>=0.10.0', 'deepecho>=0.6.0', - 'rdt>=1.12.0', + 'rdt @ git+https://github.com/sdv-dev/RDT@main', 'sdmetrics>=0.14.0', 'platformdirs>=4.0', 'pyyaml>=6.0.1', @@ -75,7 +75,6 @@ dev = [ # docs 'docutils>=0.12,<1', - 'm2r2>=0.2.5,<1', 'nbsphinx>=0.5.0,<1', 'sphinx_toolbox>=2.5,<4', 'Sphinx>=3,<8', diff --git a/sdv/_utils.py b/sdv/_utils.py index 9138dbcce..3ec466537 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -10,10 +10,16 @@ import pandas as pd from pandas.core.tools.datetimes import _guess_datetime_format_for_array +from rdt.transformers.utils import _GENERATORS from sdv import version from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError +try: + from re import _parser as sre_parse +except ImportError: + import sre_parse + def _cast_to_iterable(value): """Return a ``list`` if the input object is not a ``list`` or ``tuple``.""" @@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer): synth_version = version.public unique_id = ''.join(str(uuid.uuid4()).split('-')) return f'{class_name}_{synth_version}_{unique_id}' + + +def _get_chars_for_option(option, params): + if option not in _GENERATORS: + raise ValueError(f'REGEX operation: {option} is not supported by SDV.') + + if option == sre_parse.MAX_REPEAT: + new_option, new_params = params[2][0] # The value at the second index is the nested option + return _get_chars_for_option(new_option, new_params) + + return list(_GENERATORS[option](params, 1)[0]) + + +def get_possible_chars(regex, num_subpatterns=None): + """Get the list of possible characters a regex can create. + + Args: + regex (str): + The regex to parse. + num_subpatterns (int): + The number of sub-patterns from the regex to find characters for. 
+ """ + parsed = sre_parse.parse(regex) + parsed = [p for p in parsed if p[0] != sre_parse.AT] + num_subpatterns = num_subpatterns or len(parsed) + possible_chars = [] + for option, params in parsed[:num_subpatterns]: + possible_chars += _get_chars_for_option(option, params) + + return possible_chars diff --git a/sdv/data_processing/numerical_formatter.py b/sdv/data_processing/numerical_formatter.py index fe9f72881..1d7100ba1 100644 --- a/sdv/data_processing/numerical_formatter.py +++ b/sdv/data_processing/numerical_formatter.py @@ -3,8 +3,8 @@ import logging import sys -import numpy as np import pandas as pd +from rdt.transformers.utils import learn_rounding_digits LOGGER = logging.getLogger(__name__) @@ -51,34 +51,6 @@ def __init__( self.enforce_min_max_values = enforce_min_max_values self.computer_representation = computer_representation - @staticmethod - def _learn_rounding_digits(data): - """Check if data has any decimals.""" - name = data.name - data = np.array(data) - roundable_data = data[~(np.isinf(data) | pd.isna(data))] - - # Doesn't contain numbers - if len(roundable_data) == 0: - return None - - # Doesn't contain decimal digits - if ((roundable_data % 1) == 0).all(): - return 0 - - # Try to round to fewer digits - if (roundable_data == roundable_data.round(MAX_DECIMALS)).all(): - for decimal in range(MAX_DECIMALS + 1): - if (roundable_data == roundable_data.round(decimal)).all(): - return decimal - - # Can't round, not equal after MAX_DECIMALS digits of precision - LOGGER.info( - f"No rounding scheme detected for column '{name}'." - ' Synthetic data will not be rounded.' - ) - return None - def learn_format(self, column): """Learn the format of a column. 
@@ -92,7 +64,7 @@ def learn_format(self, column): self._max_value = column.max() if self.enforce_rounding: - self._rounding_digits = self._learn_rounding_digits(column) + self._rounding_digits = learn_rounding_digits(column) def format_data(self, column): """Format a column according to the learned format. @@ -105,20 +77,17 @@ def format_data(self, column): numpy.ndarray: containing the formatted data. """ - column = column.copy().to_numpy() + column = column.copy() if self.enforce_min_max_values: column = column.clip(self._min_value, self._max_value) - elif self.computer_representation != 'Float': + elif not self.computer_representation.startswith('Float'): min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation] column = column.clip(min_bound, max_bound) - is_integer = np.dtype(self._dtype).kind == 'i' + is_integer = pd.api.types.is_integer_dtype(self._dtype) if self.enforce_rounding and self._rounding_digits is not None: column = column.round(self._rounding_digits) elif is_integer: column = column.round(0) - if pd.isna(column).any() and is_integer: - return column - return column.astype(self._dtype) diff --git a/sdv/evaluation/multi_table.py b/sdv/evaluation/multi_table.py index e258b40db..302b971a6 100644 --- a/sdv/evaluation/multi_table.py +++ b/sdv/evaluation/multi_table.py @@ -77,8 +77,8 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name 1D marginal distribution plot (i.e. a histogram) of the columns. """ metadata = metadata.tables[table_name] - real_data = real_data[table_name] - synthetic_data = synthetic_data[table_name] + real_data = real_data[table_name] if real_data else None + synthetic_data = synthetic_data[table_name] if synthetic_data else None return single_table_visualization.get_column_plot( real_data, synthetic_data, @@ -118,8 +118,8 @@ def get_column_pair_plot( 2D bivariate distribution plot (i.e. a scatterplot) of the columns. 
""" metadata = metadata.tables[table_name] - real_data = real_data[table_name] - synthetic_data = synthetic_data[table_name] + real_data = real_data[table_name] if real_data else None + synthetic_data = synthetic_data[table_name] if synthetic_data else None return single_table_visualization.get_column_pair_plot( real_data, synthetic_data, metadata, column_names, sample_size, plot_type ) diff --git a/sdv/logging/logger.py b/sdv/logging/logger.py index 7ce51854c..0d073241c 100644 --- a/sdv/logging/logger.py +++ b/sdv/logging/logger.py @@ -33,7 +33,7 @@ def __init__(self, filename=None): def format(self, record): # noqa: A003 """Format the record and write to CSV.""" - row = record.msg + row = record.msg.copy() row['LEVEL'] = record.levelname self.writer.writerow(row) data = self.output.getvalue() diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index ac77f9ff9..799f7ab2b 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -55,6 +55,8 @@ class SingleTableMetadata: } _NUMERICAL_REPRESENTATIONS = frozenset([ + 'Float32', + 'Float64', 'Float', 'Int64', 'Int32', diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 0966c71f3..7de30efff 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -28,6 +28,7 @@ from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata +from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -372,9 +373,17 @@ def preprocess(self, data): processed_data = {} pbar_args = self._get_pbar_args(desc='Preprocess Tables') for table_name, table_data in tqdm(data.items(), **pbar_args): - synthesizer = self._table_synthesizers[table_name] - self._assign_table_transformers(synthesizer, table_name, table_data) - processed_data[table_name] = 
synthesizer._preprocess(table_data) + try: + synthesizer = self._table_synthesizers[table_name] + self._assign_table_transformers(synthesizer, table_name, table_data) + processed_data[table_name] = synthesizer._preprocess(table_data) + except SynthesizerInputError as e: + if INT_REGEX_ZERO_ERROR_MESSAGE in str(e): + raise SynthesizerInputError( + f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}' + ) from e + + raise for table in list_of_changed_tables: data[table].columns = self._original_table_columns[table] diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 9614ed705..c807030c4 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -24,6 +24,7 @@ check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, + get_possible_chars, ) from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor @@ -43,6 +44,10 @@ COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 +INT_REGEX_ZERO_ERROR_MESSAGE = ( + 'is stored as an int but the Regex allows it to start with "0". Please remove the Regex ' + 'or update it to correspond to valid ints.' +) DEPRECATION_MSG = ( "The 'SingleTableMetadata' is deprecated. Please use the new " @@ -177,6 +182,17 @@ def _validate(self, data): """ return [] + def _validate_primary_key(self, data): + primary_key = self.metadata.primary_key + is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key]) + regex = self.metadata.columns.get(primary_key, {}).get('regex_format') + if is_int and regex: + possible_characters = get_possible_chars(regex, 1) + if '0' in possible_characters: + raise SynthesizerInputError( + f'Primary key "{primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}' + ) + def validate(self, data): """Validate data. 
@@ -198,6 +214,7 @@ def validate(self, data): * values of a column don't satisfy their sdtype """ self._validate_metadata(data) + self._validate_primary_key(data) self._validate_constraints(data) # Retaining the logic of returning errors and raising them here to maintain consistency diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index f086ceca8..30174f43a 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -1,5 +1,7 @@ """Wrapper around CTGAN model.""" +import warnings + import numpy as np import pandas as pd import plotly.express as px @@ -285,7 +287,9 @@ def _fit(self, processed_data): transformers = self._data_processor._hyper_transformer.field_transformers discrete_columns = detect_discrete_columns(self.metadata, processed_data, transformers) self._model = CTGAN(**self._model_kwargs) - self._model.fit(processed_data, discrete_columns=discrete_columns) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='.*Attempting to run cuBLAS.*') + self._model.fit(processed_data, discrete_columns=discrete_columns) def _sample(self, num_rows, conditions=None): """Sample the indicated number of rows from the model. 
diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index a47ff8f4d..c46fde334 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -4,6 +4,7 @@ import math import re import warnings +from unittest.mock import Mock import faker import numpy as np @@ -2021,3 +2022,143 @@ def test_hma_synthesizer_with_fixed_combinations(): assert len(sampled['users']) > 1 assert len(sampled['records']) > 1 assert len(sampled['locations']) > 1 + + +REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w'] + + +@pytest.mark.parametrize('regex', REGEXES) +def test_fit_int_primary_key_regex_includes_zero(regex): + """Test that sdv errors if the primary key has a regex, is an int, and can start with 0.""" + # Setup + parent_data = pd.DataFrame({ + 'parent_id': [1, 2, 3, 4, 5, 6], + 'col': ['a', 'b', 'a', 'b', 'a', 'b'], + }) + child_data = pd.DataFrame({'id': [1, 2, 3, 4, 5, 6], 'parent_id': [1, 2, 3, 4, 5, 6]}) + data = { + 'parent_data': parent_data, + 'child_data': child_data, + } + metadata = MultiTableMetadata() + metadata.detect_from_dataframes(data) + metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) + metadata.set_primary_key('parent_data', 'parent_id') + + # Run and Assert + instance = HMASynthesizer(metadata) + message = ( + 'Primary key for table "parent_data" is stored as an int but the Regex allows it to start ' + 'with "0". Please remove the Regex or update it to correspond to valid ints.' + ) + with pytest.raises(SynthesizerInputError, match=message): + instance.fit(data) + + +def test__estimate_num_columns_to_be_modeled_various_sdtypes(): + """Test the estimated number of columns is correct for various sdtypes. + + To check that the number columns is correct we Mock the ``_finalize`` method + and compare its output with the estimated number of columns. 
+ + The dataset used follows the structure below: + R1 R2 + | / + GP + | + P + """ + # Setup + root1 = pd.DataFrame({'R1': [0, 1, 2]}) + root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) + grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}) + parent = pd.DataFrame({ + 'P': [0, 1, 2], + 'GP': [0, 1, 2], + 'numerical': [0.1, 0.5, np.nan], + 'categorical': ['a', np.nan, 'c'], + 'datetime': [None, '2019-01-02', '2019-01-03'], + 'boolean': [float('nan'), False, True], + 'id': [0, 1, 2], + }) + data = { + 'root1': root1, + 'root2': root2, + 'grandparent': grandparent, + 'parent': parent, + } + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'root1': { + 'primary_key': 'R1', + 'columns': { + 'R1': {'sdtype': 'id'}, + }, + }, + 'root2': { + 'primary_key': 'R2', + 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, + }, + 'grandparent': { + 'primary_key': 'GP', + 'columns': { + 'GP': {'sdtype': 'id'}, + 'R1': {'sdtype': 'id'}, + 'R2': {'sdtype': 'id'}, + }, + }, + 'parent': { + 'primary_key': 'P', + 'columns': { + 'P': {'sdtype': 'id'}, + 'GP': {'sdtype': 'id'}, + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime': {'sdtype': 'datetime'}, + 'boolean': {'sdtype': 'boolean'}, + 'id': {'sdtype': 'id'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'root1', + 'parent_primary_key': 'R1', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R1', + }, + { + 'parent_table_name': 'root2', + 'parent_primary_key': 'R2', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'R2', + }, + { + 'parent_table_name': 'grandparent', + 'parent_primary_key': 'GP', + 'child_table_name': 'parent', + 'child_foreign_key': 'GP', + }, + ], + }) + synthesizer = HMASynthesizer(metadata) + synthesizer._finalize = Mock(return_value=data) + + # Run estimation + estimated_num_columns = synthesizer._estimate_num_columns(metadata) + + # Run actual modeling + 
synthesizer.fit(data) + synthesizer.sample() + + # Assert estimated number of columns is correct + tables = synthesizer._finalize.call_args[0][0] + for table_name, table in tables.items(): + # Subtract all the id columns present in the data, as those are not estimated + num_table_cols = len(table.columns) + if table_name in {'parent', 'grandparent'}: + num_table_cols -= 3 + if table_name in {'root1', 'root2'}: + num_table_cols -= 1 + + assert num_table_cols == estimated_num_columns[table_name] diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 2dfb8aa38..4832a2746 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -810,3 +810,31 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class): # Assert assert sample.columns.tolist() == data.columns.tolist() + + +REGEXES = ['[0-9]{3,4}', '0HQ-[a-z]', '0+', r'\d', r'\d{1,5}', r'\w'] + + +@pytest.mark.parametrize('regex', REGEXES) +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex): + """Test that sdv errors if the primary key has a regex, is an int, and can start with 0.""" + # Setup + data = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6], + 'c': ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + metadata.update_column('a', sdtype='id', regex_format=regex) + metadata.set_primary_key('a') + + # Run and Assert + instance = synthesizer_class(metadata) + message = ( + 'Primary key "a" is stored as an int but the Regex allows it to start with "0". Please ' + 'remove the Regex or update it to correspond to valid ints.' 
+ ) + with pytest.raises(SynthesizerInputError, match=message): + instance.fit(data) diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 22fb568b4..84f864318 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -469,3 +469,37 @@ def test_datetime_values_inside_real_data_range(): assert check_in_synthetic.max() <= check_in_real.max() assert check_out_synthetic.min() >= check_out_real.min() assert check_out_synthetic.max() <= check_out_real.max() + + +def test_support_new_pandas_dtypes(): + """Test that the synthesizer supports the nullable numerical pandas dtypes.""" + # Setup + data = pd.DataFrame({ + 'Int8': pd.Series([1, 2, -3, pd.NA], dtype='Int8'), + 'Int16': pd.Series([1, 2, -3, pd.NA], dtype='Int16'), + 'Int32': pd.Series([1, 2, -3, pd.NA], dtype='Int32'), + 'Int64': pd.Series([1, 2, pd.NA, -3], dtype='Int64'), + 'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'), + 'Float64': pd.Series([1.113, 2.22, 3.3, pd.NA], dtype='Float64'), + }) + metadata = SingleTableMetadata().load_from_dict({ + 'columns': { + 'Int8': {'sdtype': 'numerical', 'computer_representation': 'Int8'}, + 'Int16': {'sdtype': 'numerical', 'computer_representation': 'Int16'}, + 'Int32': {'sdtype': 'numerical', 'computer_representation': 'Int32'}, + 'Int64': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'Float32': {'sdtype': 'numerical', 'computer_representation': 'Float32'}, + 'Float64': {'sdtype': 'numerical', 'computer_representation': 'Float64'}, + } + }) + + synthesizer = GaussianCopulaSynthesizer(metadata) + + # Run + synthesizer.fit(data) + synthetic_data = synthesizer.sample(10) + + # Assert + assert (synthetic_data.dtypes == data.dtypes).all() + assert (synthetic_data['Float32'] == synthetic_data['Float32'].round(1)).all(skipna=True) + assert (synthetic_data['Float64'] == synthetic_data['Float64'].round(3)).all(skipna=True) diff --git 
a/tests/unit/data_processing/test_numerical_formatter.py b/tests/unit/data_processing/test_numerical_formatter.py index 826599d9e..ec214db58 100644 --- a/tests/unit/data_processing/test_numerical_formatter.py +++ b/tests/unit/data_processing/test_numerical_formatter.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock import numpy as np import pandas as pd @@ -20,96 +20,6 @@ def test___init__(self): assert formatter.enforce_min_max_values is True assert formatter.computer_representation == 'Int8' - @patch('sdv.data_processing.numerical_formatter.LOGGER') - def test__learn_rounding_digits_more_than_15_decimals(self, log_mock): - """Test the ``_learn_rounding_digits`` method with more than 15 decimals. - - If the data has more than 15 decimals, return None and use ``LOGGER`` to inform the user. - """ - # Setup - data = pd.Series(np.random.random(size=10).round(20), name='col') - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - log_msg = ( - "No rounding scheme detected for column 'col'. Synthetic data will not be rounded." - ) - log_mock.info.assert_called_once_with(log_msg) - assert output is None - - def test__learn_rounding_digits_less_than_15_decimals(self): - """Test the ``_learn_rounding_digits`` method with less than 15 decimals. - - If the data has less than 15 decimals, the maximum number of decimals should be returned. - - Input: - - an array that contains floats with a maximum of 3 decimals and a NaN. - - Output: - - 3 - """ - # Setup - data = pd.Series(np.array([10, 0.0, 0.1, 0.12, 0.123, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 3 - - def test__learn_rounding_digits_negative_decimals_float(self): - """Test the ``_learn_rounding_digits`` method with floats multiples of powers of 10. - - If the data has all multiples of 10 the output should be None. 
- - Input: - - an array that contains floats that are multiples of 10, 100 and 1000 and a NaN. - """ - # Setup - data = pd.Series(np.array([1230.0, 12300.0, 123000.0, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 0 - - def test__learn_rounding_digits_negative_decimals_integer(self): - """Test the ``_learn_rounding_digits`` method with integers multiples of powers of 10. - - If the data has all multiples of 10 the output should be None. - - Input: - - an array that contains integers that are multiples of 10, 100 and 1000 and a NaN. - """ - # Setup - data = pd.Series(np.array([1230, 12300, 123000, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output == 0 - - def test__learn_rounding_digits_all_nans(self): - """Test the ``_learn_rounding_digits`` method with data that is all NaNs. - - If the data is all NaNs, expect that the output is 0. - - Input: - - an array of NaN. - """ - # Setup - data = pd.Series(np.array([np.nan, np.nan, np.nan, np.nan])) - - # Run - output = NumericalFormatter._learn_rounding_digits(data) - - # Assert - assert output is None - def test_learn_format(self): """Test that ``learn_format`` method. 
diff --git a/tests/unit/evaluation/test_multi_table.py b/tests/unit/evaluation/test_multi_table.py index 4cdea58a6..2995f9dd5 100644 --- a/tests/unit/evaluation/test_multi_table.py +++ b/tests/unit/evaluation/test_multi_table.py @@ -73,9 +73,31 @@ def test_get_column_plot(mock_plot): assert plot == 'plot' +@patch('sdv.evaluation.single_table.get_column_plot') +def test_get_column_plot_only_real_or_synthetic(mock_plot): + """Test that ``get_column_plot`` works when only real or synthetic data is provided.""" + # Setup + table1 = pd.DataFrame({'col': [1, 2, 3]}) + data1 = {'table': table1} + metadata = MultiTableMetadata() + metadata.detect_table_from_dataframe('table', table1) + mock_plot.return_value = 'plot' + + # Run + get_column_plot(data1, None, metadata, 'table', 'col') + get_column_plot(None, data1, metadata, 'table', 'col') + + # Assert + call_metadata = metadata.tables['table'] + mock_plot.assert_has_calls([ + ((table1, None, call_metadata, 'col', None), {}), + ((None, table1, call_metadata, 'col', None), {}), + ]) + + @patch('sdv.evaluation.single_table.get_column_pair_plot') def test_get_column_pair_plot(mock_plot): - """Test that ``get_column_pair`` plot is being called with the expected objects.""" + """Test that ``get_column_pair_plot`` is being called with the expected objects.""" # Setup table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]}) table2 = pd.DataFrame({'col1': [2, 1, 3], 'col2': [1, 2, 3]}) @@ -94,6 +116,28 @@ def test_get_column_pair_plot(mock_plot): assert plot == 'plot' +@patch('sdv.evaluation.single_table.get_column_pair_plot') +def test_get_column_pair_plot_only_real_or_synthetic(mock_plot): + """Test that ``get_column_pair_plot`` works when only real or synthetic data is provided.""" + # Setup + table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]}) + data1 = {'table': table1} + metadata = MultiTableMetadata() + metadata.detect_table_from_dataframe('table', table1) + mock_plot.return_value = 'plot' + + # Run + 
get_column_pair_plot(data1, None, metadata, 'table', ['col1', 'col2'], 2) + get_column_pair_plot(None, data1, metadata, 'table', ['col1', 'col2'], 2) + + # Assert + call_metadata = metadata.tables['table'] + mock_plot.assert_has_calls([ + ((table1, None, call_metadata, ['col1', 'col2'], None, 2), {}), + ((None, table1, call_metadata, ['col1', 'col2'], None, 2), {}), + ]) + + @patch('sdmetrics.visualization.get_cardinality_plot') def test_get_cardinality_plot(mock_plot): """Test it calls ``get_column_cardinality_plot`` in sdmetrics with the parent primary key.""" diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index c5b979d4f..337451f80 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -322,7 +322,8 @@ def test__validate_unexpected_kwargs_invalid(self, column_name, sdtype, kwargs, with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_unexpected_kwargs(column_name, sdtype, **kwargs) - def test__validate_column_invalid_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test__validate_column_invalid_sdtype(self, mock_is_faker_function): """Test the method with an invalid sdtype. If the sdtype isn't one of the supported types, anonymized types or Faker functions, @@ -330,6 +331,7 @@ def test__validate_column_invalid_sdtype(self): """ # Setup instance = SingleTableMetadata() + mock_is_faker_function.return_value = False # Run and Assert error_msg = re.escape( @@ -340,11 +342,13 @@ def test__validate_column_invalid_sdtype(self): instance._validate_column_args('column', 'fake_type') error_msg = re.escape( - 'Invalid sdtype: None is not a string. Please use one of the ' 'supported SDV sdtypes.' + 'Invalid sdtype: None is not a string. Please use one of the supported SDV sdtypes.' 
) with pytest.raises(InvalidMetadataError, match=error_msg): instance._validate_column_args('column', None) + mock_is_faker_function.assert_called_once_with('fake_type') + @patch('sdv.metadata.single_table.SingleTableMetadata._validate_unexpected_kwargs') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_numerical') def test__validate_column_numerical(self, mock__validate_numerical, mock__validate_kwargs): @@ -599,7 +603,8 @@ def test_add_column_sdtype_not_in_kwargs(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.add_column('synthetic') - def test_add_column_invalid_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_add_column_invalid_sdtype(self, mock_is_faker_function): """Test the method with an invalid sdtype. If the sdtype isn't one of the supported types, anonymized types or Faker functions, @@ -607,6 +612,7 @@ def test_add_column_invalid_sdtype(self): """ # Setup instance = SingleTableMetadata() + mock_is_faker_function.return_value = False # Run and Assert error_msg = re.escape( @@ -616,6 +622,8 @@ def test_add_column_invalid_sdtype(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.add_column('column', sdtype='fake_type') + mock_is_faker_function.assert_called_once_with('fake_type') + def test_add_column(self): """Test ``add_column`` method. @@ -794,13 +802,15 @@ def test_update_columns_sdtype_in_kwargs_error(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True) - def test_update_columns_multiple_errors(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_update_columns_multiple_errors(self, mock_is_faker_function): """Test the ``update_columns`` method. Test that ``update_columns`` with multiple errors. Should raise an ``InvalidMetadataError`` with a summary of all the errors. 
""" # Setup + mock_is_faker_function.return_value = True instance = SingleTableMetadata() instance.columns = { 'col_1': {'sdtype': 'country_code'}, @@ -817,6 +827,8 @@ def test_update_columns_multiple_errors(self): with pytest.raises(InvalidMetadataError, match=error_msg): instance.update_columns(['col_1', 'col_2', 'col_3'], pii=True) + mock_is_faker_function.assert_called_once_with('country_code') + def test_update_columns(self): """Test the ``update_columns`` method.""" # Setup @@ -839,9 +851,11 @@ def test_update_columns(self): 'salary': {'sdtype': 'categorical'}, } - def test_update_columns_kwargs_without_sdtype(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_update_columns_kwargs_without_sdtype(self, mock_is_faker_function): """Test the ``update_columns`` method when there is no ``sdtype`` in the kwargs.""" # Setup + mock_is_faker_function.return_value = True instance = SingleTableMetadata() instance.columns = { 'col_1': {'sdtype': 'country_code'}, @@ -859,6 +873,11 @@ def test_update_columns_kwargs_without_sdtype(self): 'col_3': {'sdtype': 'longitude', 'pii': True}, } assert instance._updated is True + mock_is_faker_function.assert_has_calls([ + call('country_code'), + call('latitude'), + call('longitude'), + ]) def test_update_columns_metadata(self): """Test the ``update_columns_metadata`` method.""" @@ -1620,7 +1639,8 @@ def test_set_primary_key_validation_columns(self): instance.set_primary_key('b') # NOTE: used to be ('a', 'b', 'd', 'c') - def test_set_primary_key_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_set_primary_key_validation_categorical(self, mock_is_faker_function): """Test that ``set_primary_key`` crashes when its sdtype is categorical. Input: @@ -1630,6 +1650,7 @@ def test_set_primary_key_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. 
""" # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1640,6 +1661,8 @@ def test_set_primary_key_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_primary_key('column1') + mock_is_faker_function.assert_called_once_with('categorical') + def test_set_primary_key(self): """Test that ``set_primary_key`` sets the ``_primary_key`` value.""" # Setup @@ -1776,7 +1799,8 @@ def test_set_sequence_key_validation_columns(self): instance.set_sequence_key('b') # NOTE: used to be ('a', 'b', 'd', 'c') - def test_set_sequence_key_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_set_sequence_key_validation_categorical(self, mock_is_faker_function): """Test that ``set_sequence_key`` crashes when its sdtype is categorical. Input: @@ -1786,6 +1810,7 @@ def test_set_sequence_key_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. 
""" # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1796,6 +1821,8 @@ def test_set_sequence_key_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.set_sequence_key('column1') + mock_is_faker_function.assert_called_once_with('categorical') + def test_set_sequence_key(self): """Test that ``set_sequence_key`` sets the ``_sequence_key`` value.""" # Setup @@ -1887,7 +1914,8 @@ def test_add_alternate_keys_validation_columns(self): instance.add_alternate_keys(['abc', '123']) # NOTE: used to be ['abc', ('123', '213', '312'), 'bca'] - def test_add_alternate_keys_validation_categorical(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test_add_alternate_keys_validation_categorical(self, mock_is_faker_function): """Test that ``add_alternate_keys`` crashes when its sdtype is categorical. Input: @@ -1897,6 +1925,7 @@ def test_add_alternate_keys_validation_categorical(self): - An ``InvalidMetadataError`` should be raised. """ # Setup + mock_is_faker_function.return_value = False instance = SingleTableMetadata() instance.add_column('column1', sdtype='categorical') instance.add_column('column2', sdtype='categorical') @@ -1909,6 +1938,8 @@ def test_add_alternate_keys_validation_categorical(self): with pytest.raises(InvalidMetadataError, match=err_msg): instance.add_alternate_keys(['column1', 'column2', 'column3']) + mock_is_faker_function.assert_has_calls([call('categorical'), call('categorical')]) + def test_add_alternate_keys_validation_primary_key(self): """Test that ``add_alternate_keys`` crashes when the key is a primary key. 
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 5cfa9f132..0fbe4123d 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -24,6 +24,7 @@ from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer from sdv.multi_table.hma import HMASynthesizer +from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata @@ -159,9 +160,11 @@ def test___init___deprecated(self): with pytest.warns(FutureWarning, match=deprecation_msg): BaseMultiTableSynthesizer(metadata) - def test__init__column_relationship_warning(self): + @patch('sdv.metadata.single_table.is_faker_function') + def test__init__column_relationship_warning(self, mock_is_faker_function): """Test that a warning is raised only once when the metadata has column relationships.""" # Setup + mock_is_faker_function.return_value = True metadata = get_multi_table_metadata() metadata.add_column('nesreca', 'lat', sdtype='latitude') metadata.add_column('nesreca', 'lon', sdtype='longitude') @@ -184,6 +187,10 @@ def test__init__column_relationship_warning(self): warning for warning in caught_warnings if expected_warning in str(warning.message) ] assert len(column_relationship_warnings) == 1 + mock_is_faker_function.assert_has_calls([ + call('latitude'), + call('longitude'), + ]) def test___init___synthesizer_kwargs_deprecated(self): """Test that the ``synthesizer_kwargs`` method is deprecated.""" @@ -930,6 +937,84 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." 
) + def test_preprocess_single_table_preprocess_raises_error_0_int_regex(self): + """Test that if the single table synthesizer raises a specific error, it is reformatted. + + If a single table synthesizer raises an error about the primary key being an integer + with a regex that can start with zero, the error should be reformatted to include the + table name. + """ + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': np.arange(0, 20, 2), + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(10), + 'id_nesreca': np.arange(10), + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + synth_nesreca = Mock() + synth_oseba = Mock() + synth_upravna_enota = Mock() + synth_nesreca._preprocess.side_effect = SynthesizerInputError(INT_REGEX_ZERO_ERROR_MESSAGE) + instance._table_synthesizers = { + 'nesreca': synth_nesreca, + 'oseba': synth_oseba, + 'upravna_enota': synth_upravna_enota, + } + + # Run + message = f'Primary key for table "nesreca" {INT_REGEX_ZERO_ERROR_MESSAGE}' + with pytest.raises(SynthesizerInputError, match=message): + instance.preprocess(data) + + def test_preprocess_single_table_preprocess_raises_error(self): + """Test that if the single table synthesizer raises any other error, it is raised. + + If a single table synthesizer raises an error besides the one concerning int primary keys + starting with 0 and having a regex, then the error should be raised as is. 
+ """ + # Setup + metadata = get_multi_table_metadata() + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + data = { + 'nesreca': pd.DataFrame({ + 'id_nesreca': np.arange(0, 20, 2), + 'upravna_enota': np.arange(10), + }), + 'oseba': pd.DataFrame({ + 'upravna_enota': np.arange(10), + 'id_nesreca': np.arange(10), + }), + 'upravna_enota': pd.DataFrame({ + 'id_upravna_enota': np.arange(10), + }), + } + + synth_nesreca = Mock() + synth_oseba = Mock() + synth_upravna_enota = Mock() + synth_nesreca._preprocess.side_effect = SynthesizerInputError('blah') + instance._table_synthesizers = { + 'nesreca': synth_nesreca, + 'oseba': synth_oseba, + 'upravna_enota': synth_upravna_enota, + } + + # Run + with pytest.raises(SynthesizerInputError, match='blah'): + instance.preprocess(data) + @patch('sdv.multi_table.base.datetime') def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 51d5aa1a4..0db06ea82 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -1119,111 +1119,3 @@ def test__estimate_num_columns_to_be_modeled(self): num_table_cols -= 1 assert num_table_cols == estimated_num_columns[table_name] - - def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): - """Test the estimated number of columns is correct for various sdtypes. - - To check that the number columns is correct we Mock the ``_finalize`` method - and compare its output with the estimated number of columns. 
- - The dataset used follows the structure below: - R1 R2 - | / - GP - | - P - """ - # Setup - root1 = pd.DataFrame({'R1': [0, 1, 2]}) - root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]}) - grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]}) - parent = pd.DataFrame({ - 'P': [0, 1, 2], - 'GP': [0, 1, 2], - 'numerical': [0.1, 0.5, np.nan], - 'categorical': ['a', np.nan, 'c'], - 'datetime': [None, '2019-01-02', '2019-01-03'], - 'boolean': [float('nan'), False, True], - 'id': [0, 1, 2], - }) - data = { - 'root1': root1, - 'root2': root2, - 'grandparent': grandparent, - 'parent': parent, - } - metadata = MultiTableMetadata.load_from_dict({ - 'tables': { - 'root1': { - 'primary_key': 'R1', - 'columns': { - 'R1': {'sdtype': 'id'}, - }, - }, - 'root2': { - 'primary_key': 'R2', - 'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}}, - }, - 'grandparent': { - 'primary_key': 'GP', - 'columns': { - 'GP': {'sdtype': 'id'}, - 'R1': {'sdtype': 'id'}, - 'R2': {'sdtype': 'id'}, - }, - }, - 'parent': { - 'primary_key': 'P', - 'columns': { - 'P': {'sdtype': 'id'}, - 'GP': {'sdtype': 'id'}, - 'numerical': {'sdtype': 'numerical'}, - 'categorical': {'sdtype': 'categorical'}, - 'datetime': {'sdtype': 'datetime'}, - 'boolean': {'sdtype': 'boolean'}, - 'id': {'sdtype': 'id'}, - }, - }, - }, - 'relationships': [ - { - 'parent_table_name': 'root1', - 'parent_primary_key': 'R1', - 'child_table_name': 'grandparent', - 'child_foreign_key': 'R1', - }, - { - 'parent_table_name': 'root2', - 'parent_primary_key': 'R2', - 'child_table_name': 'grandparent', - 'child_foreign_key': 'R2', - }, - { - 'parent_table_name': 'grandparent', - 'parent_primary_key': 'GP', - 'child_table_name': 'parent', - 'child_foreign_key': 'GP', - }, - ], - }) - synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock(return_value=data) - - # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) - - # Run actual modeling - 
- synthesizer.fit(data) - synthesizer.sample() - - # Assert estimated number of columns is correct - tables = synthesizer._finalize.call_args[0][0] - for table_name, table in tables.items(): - # Subract all the id columns present in the data, as those are not estimated - num_table_cols = len(table.columns) - if table_name in {'parent', 'grandparent'}: - num_table_cols -= 3 - if table_name in {'root1', 'root2'}: - num_table_cols -= 1 - - assert num_table_cols == estimated_num_columns[table_name] diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 1571f6a2e..5fd111960 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -658,6 +658,47 @@ def test_validate_raises_invalid_data_for_metadata(self): instance._validate_constraints.assert_called_once_with(data) instance._validate.assert_not_called() + def test_validate_int_primary_key_regex_starts_with_zero(self): + """Test that an error is raised if the primary key is an int that can start with 0. + + If the primary key is stored as an int, but a regex is used with it, it is possible + that the first character can be a 0. If this happens, then we can get duplicate primary + key values since two different strings can be the same when converted to ints + (e.g. '00123' and '0123'). + """ + # Setup + data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']}) + metadata = Mock() + metadata.primary_key = 'key' + metadata.column_relationships = [] + metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[0-9]{3,4}'}} + instance = BaseSingleTableSynthesizer(metadata) + + # Run and Assert + message = ( + 'Primary key "key" is stored as an int but the Regex allows it to start with ' + '"0". Please remove the Regex or update it to correspond to valid ints.' 
+ ) + with pytest.raises(SynthesizerInputError, match=message): + instance.validate(data) + + def test_validate_int_primary_key_regex_does_not_start_with_zero(self): + """Test that no error is raised if the primary key is an int that can't start with 0. + + If the primary key is stored as an int, but a regex is used with it, it is possible + that the first character can be a 0. If it isn't possible, then no error should be raised. + """ + # Setup + data = pd.DataFrame({'key': [1, 2, 3], 'info': ['a', 'b', 'c']}) + metadata = Mock() + metadata.primary_key = 'key' + metadata.column_relationships = [] + metadata.columns = {'key': {'sdtype': 'id', 'regex_format': '[1-9]{3,4}'}} + instance = BaseSingleTableSynthesizer(metadata) + + # Run and Assert + instance.validate(data) + def test_update_transformers_invalid_keys(self): """Test error is raised if passed transformer doesn't match key column. diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index 87520c5df..1cbcf3416 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -1,5 +1,6 @@ import operator import re +import string from datetime import datetime from unittest.mock import Mock, patch @@ -12,6 +13,7 @@ _compare_versions, _convert_to_timedelta, _create_unique_name, + _get_chars_for_option, _get_datetime_format, _get_root_tables, _is_datetime_type, @@ -19,12 +21,18 @@ check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, + get_possible_chars, ) from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.base import BaseSingleTableSynthesizer from tests.utils import SeriesMatcher +try: + from re import _parser as sre_parse +except ImportError: + import sre_parse + @patch('sdv._utils.pd.to_timedelta') def test__convert_to_timedelta(to_timedelta_mock): @@ -626,3 +634,82 @@ def test_generate_synthesizer_id(mock_version, mock_uuid): # Assert assert result 
== 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + + +@patch('sdv._utils._get_chars_for_option') +def test_get_possible_chars_excludes_at(mock_get_chars): + """Test that 'at' regex operations aren't included when getting chars.""" + # Setup + regex = '^[1-9]{1,2}$' + mock_get_chars.return_value = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + mock_get_chars.assert_called_once() + mock_call = mock_get_chars.mock_calls[0] + assert mock_call[1][0] == sre_parse.MAX_REPEAT + assert mock_call[1][1][0] == 1 + assert mock_call[1][1][1] == 2 + assert mock_call[1][1][2].data == [(sre_parse.IN, [(sre_parse.RANGE, (49, 57))])] + assert possible_chars == [str(i) for i in range(10)] + + +def test_get_possible_chars_unsupported_regex(): + """Test that an error is raised if the regex contains unsupported options.""" + # Setup + regex = '(ab)*' + + # Run and assert + message = 'REGEX operation: SUBPATTERN is not supported by SDV.' + with pytest.raises(ValueError, match=message): + get_possible_chars(regex) + + +@patch('sdv._utils._get_chars_for_option') +def test_get_possible_chars_handles_max_repeat(mock_get_chars): + """Test that MAX_REPEATS are handled by recursively finding the first non MAX_REPEAT. + + One valid regex option is a MAX_REPEAT. Getting all possible values for this could be slow, + so we just look for the first nested option that isn't a max_repeat to get the possible + characters instead. 
+ """ + # Setup + regex = '[1-9]{1,2}' + mock_get_chars.side_effect = lambda x, y: _get_chars_for_option(x, y) + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + assert len(mock_get_chars.mock_calls) == 2 + assert mock_get_chars.mock_calls[1][1] == mock_get_chars.mock_calls[0][1][1][2][0] + assert possible_chars == [str(i) for i in range(1, 10)] + + +def test_get_possible_chars_num_subpatterns(): + """Test that only characters for first x subpatterns are returned.""" + # Setup + regex = 'HID_[0-9]{3}_[a-z]{3}' + + # Run + possible_chars = get_possible_chars(regex, 3) + + # Assert + assert possible_chars == ['H', 'I', 'D'] + + +def test_get_possible_chars(): + """Test that all characters for regex are returned.""" + # Setup + regex = 'HID_[0-9]{3}_[a-z]{3}' + + # Run + possible_chars = get_possible_chars(regex) + + # Assert + prefix = ['H', 'I', 'D', '_'] + nums = [str(i) for i in range(10)] + lowercase_letters = list(string.ascii_lowercase) + assert possible_chars == prefix + nums + ['_'] + lowercase_letters