From e6e508ba197b104dea3fca96fea34d9bbf9a8cd7 Mon Sep 17 00:00:00 2001
From: Frances Hartwell
Date: Tue, 23 Jan 2024 10:35:33 -0500
Subject: [PATCH] Provide a friendlier error if data is stored as dtype 'category' (CTGAN, TVAE) (#1746)

---
 sdv/errors.py                                |  4 +++
 sdv/single_table/ctgan.py                    | 29 +++++++++++++++++-
 tests/integration/single_table/test_ctgan.py | 28 +++++++++++++++++-
 tests/unit/single_table/test_ctgan.py        | 31 +++++++++++++++++---
 4 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/sdv/errors.py b/sdv/errors.py
index 21797b809..bb7b8b039 100644
--- a/sdv/errors.py
+++ b/sdv/errors.py
@@ -56,5 +56,9 @@ def __str__(self):
         )
 
 
+class InvalidDataTypeError(Exception):
+    """Error to raise if data type is not valid."""
+
+
 class VisualizationUnavailableError(Exception):
     """Exception to indicate that a visualization is unavailable."""
diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py
index f5653d7f8..363d97cf8 100644
--- a/sdv/single_table/ctgan.py
+++ b/sdv/single_table/ctgan.py
@@ -1,12 +1,35 @@
 """Wrapper around CTGAN model."""
 import numpy as np
+import pandas as pd
 from ctgan import CTGAN, TVAE
 
-from sdv.errors import NotFittedError
+from sdv.errors import InvalidDataTypeError, NotFittedError
 from sdv.single_table.base import BaseSingleTableSynthesizer
 from sdv.single_table.utils import detect_discrete_columns
 
 
+def _validate_no_category_dtype(data):
+    """Check that given data has no 'category' dtype columns.
+
+    Args:
+        data (pd.DataFrame):
+            Data to check.
+
+    Raises:
+        - ``InvalidDataTypeError`` if any columns in the data have 'category' dtype.
+    """
+    category_cols = [
+        col for col, dtype in data.dtypes.items() if isinstance(dtype, pd.CategoricalDtype)
+    ]
+    if category_cols:
+        categoricals = "', '".join(map(str, category_cols))
+        error_msg = (
+            f"Columns ['{categoricals}'] are stored as a 'category' type, which is not "
+            "supported. Please cast these columns to an 'object' to continue."
+        )
+        raise InvalidDataTypeError(error_msg)
+
+
 class LossValuesMixin:
     """Mixin for accessing loss values from synthesizers."""
 
@@ -200,6 +223,8 @@ def _fit(self, processed_data):
             processed_data (pandas.DataFrame):
                 Data to be learned.
         """
+        _validate_no_category_dtype(processed_data)
+
         transformers = self._data_processor._hyper_transformer.field_transformers
         discrete_columns = detect_discrete_columns(
             self.get_metadata(),
@@ -303,6 +328,8 @@ def _fit(self, processed_data):
             processed_data (pandas.DataFrame):
                 Data to be learned.
         """
+        _validate_no_category_dtype(processed_data)
+
         transformers = self._data_processor._hyper_transformer.field_transformers
         discrete_columns = detect_discrete_columns(
             self.get_metadata(),
diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py
index 6d70c72ba..bcdb18a25 100644
--- a/tests/integration/single_table/test_ctgan.py
+++ b/tests/integration/single_table/test_ctgan.py
@@ -1,11 +1,15 @@
+import re
+
 import numpy as np
 import pandas as pd
+import pytest
 from rdt.transformers import FloatFormatter, LabelEncoder
 
 from sdv.datasets.demo import download_demo
+from sdv.errors import InvalidDataTypeError
 from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot
 from sdv.metadata import SingleTableMetadata
-from sdv.single_table import CTGANSynthesizer
+from sdv.single_table import CTGANSynthesizer, TVAESynthesizer
 
 
 def test__estimate_num_columns():
@@ -217,6 +221,28 @@ def test_categorical_metadata_with_int_data():
     assert len(recycled_categories_for_c) == 50
 
 
+def test_category_dtype_errors():
+    """Test CTGAN and TVAE error if data has 'category' dtype."""
+    # Setup
+    data, metadata = download_demo('single_table', 'fake_hotel_guests')
+    data['room_type'] = data['room_type'].astype('category')
+    data['has_rewards'] = data['has_rewards'].astype('category')
+
+    ctgan = CTGANSynthesizer(metadata)
+    tvae = TVAESynthesizer(metadata)
+
+    # Run and Assert
+    expected_msg = re.escape(
+        "Columns ['has_rewards', 'room_type'] are stored as a 'category' type, which is not "
+        "supported. Please cast these columns to an 'object' to continue."
+    )
+    with pytest.raises(InvalidDataTypeError, match=expected_msg):
+        ctgan.fit(data)
+
+    with pytest.raises(InvalidDataTypeError, match=expected_msg):
+        tvae.fit(data)
+
+
 def test_ctgansynthesizer_with_constraints_generating_categorical_values():
     """Test that ``CTGANSynthesizer`` does not crash when using constraints.
 
diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py
index 1626836e7..1dbafa9a8 100644
--- a/tests/unit/single_table/test_ctgan.py
+++ b/tests/unit/single_table/test_ctgan.py
@@ -1,12 +1,31 @@
+import re
 from unittest.mock import Mock, patch
 
 import numpy as np
 import pandas as pd
 import pytest
 
-from sdv.errors import NotFittedError
+from sdv.errors import InvalidDataTypeError, NotFittedError
 from sdv.metadata.single_table import SingleTableMetadata
-from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer
+from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer, _validate_no_category_dtype
+
+
+def test__validate_no_category_dtype():
+    """Test that 'category' dtype causes error."""
+    # Setup
+    data = pd.DataFrame({
+        'category1': pd.Categorical(['a', 'a', 'b']),
+        'value': [0, 1, 2],
+        'category2': pd.Categorical([0, 1, 2])
+    })
+
+    # Run and Assert
+    expected = re.escape(
+        "Columns ['category1', 'category2'] are stored as a 'category' type, which is not "
+        "supported. Please cast these columns to an 'object' to continue."
+    )
+    with pytest.raises(InvalidDataTypeError, match=expected):
+        _validate_no_category_dtype(data)
 
 
 class TestCTGANSynthesizer:
@@ -209,7 +228,8 @@ def test_preprocessing_few_categories(self, capfd):
 
     @patch('sdv.single_table.ctgan.CTGAN')
     @patch('sdv.single_table.ctgan.detect_discrete_columns')
-    def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
+    @patch('sdv.single_table.ctgan._validate_no_category_dtype')
+    def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_ctgan):
         """Test the ``_fit`` from ``CTGANSynthesizer``.
 
         Test that when we call ``_fit`` a new instance of ``CTGAN`` is created as a model
@@ -225,6 +245,7 @@ def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
         instance._fit(processed_data)
 
         # Assert
+        mock_category_validate.assert_called_once_with(processed_data)
         mock_detect_discrete_columns.assert_called_once_with(
             metadata,
             processed_data,
@@ -380,7 +401,8 @@ def test_get_parameters(self):
 
     @patch('sdv.single_table.ctgan.TVAE')
     @patch('sdv.single_table.ctgan.detect_discrete_columns')
-    def test__fit(self, mock_detect_discrete_columns, mock_tvae):
+    @patch('sdv.single_table.ctgan._validate_no_category_dtype')
+    def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_tvae):
         """Test the ``_fit`` from ``TVAESynthesizer``.
 
         Test that when we call ``_fit`` a new instance of ``TVAE`` is created as a model
@@ -396,6 +418,7 @@ def test__fit(self, mock_detect_discrete_columns, mock_tvae):
         instance._fit(processed_data)
 
         # Assert
+        mock_category_validate.assert_called_once_with(processed_data)
         mock_detect_discrete_columns.assert_called_once_with(
             metadata,
             processed_data,