Skip to content

Commit

Permalink
Provide a friendlier error if data is stored as dtype 'category' (CTG…
Browse files Browse the repository at this point in the history
…AN, TVAE) (sdv-dev#1746)
  • Loading branch information
frances-h authored Jan 23, 2024
1 parent 39903e5 commit e6e508b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 6 deletions.
4 changes: 4 additions & 0 deletions sdv/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,9 @@ def __str__(self):
)


class InvalidDataTypeError(Exception):
"""Error to raise if data type is not valid."""


class VisualizationUnavailableError(Exception):
"""Exception to indicate that a visualization is unavailable."""
29 changes: 28 additions & 1 deletion sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
"""Wrapper around CTGAN model."""
import numpy as np
import pandas as pd
from ctgan import CTGAN, TVAE

from sdv.errors import NotFittedError
from sdv.errors import InvalidDataTypeError, NotFittedError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import detect_discrete_columns


def _validate_no_category_dtype(data):
"""Check that given data has no 'category' dtype columns.
Args:
data (pd.DataFrame):
Data to check.
Raises:
- ``InvalidDataTypeError`` if any columns in the data have 'category' dtype.
"""
category_cols = [
col for col, dtype in data.dtypes.items() if pd.api.types.is_categorical_dtype(dtype)
]
if category_cols:
categoricals = "', '".join(category_cols)
error_msg = (
f"Columns ['{categoricals}'] are stored as a 'category' type, which is not "
"supported. Please cast these columns to an 'object' to continue."
)
raise InvalidDataTypeError(error_msg)


class LossValuesMixin:
"""Mixin for accessing loss values from synthesizers."""

Expand Down Expand Up @@ -200,6 +223,8 @@ def _fit(self, processed_data):
processed_data (pandas.DataFrame):
Data to be learned.
"""
_validate_no_category_dtype(processed_data)

transformers = self._data_processor._hyper_transformer.field_transformers
discrete_columns = detect_discrete_columns(
self.get_metadata(),
Expand Down Expand Up @@ -303,6 +328,8 @@ def _fit(self, processed_data):
processed_data (pandas.DataFrame):
Data to be learned.
"""
_validate_no_category_dtype(processed_data)

transformers = self._data_processor._hyper_transformer.field_transformers
discrete_columns = detect_discrete_columns(
self.get_metadata(),
Expand Down
28 changes: 27 additions & 1 deletion tests/integration/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import re

import numpy as np
import pandas as pd
import pytest
from rdt.transformers import FloatFormatter, LabelEncoder

from sdv.datasets.demo import download_demo
from sdv.errors import InvalidDataTypeError
from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer


def test__estimate_num_columns():
Expand Down Expand Up @@ -217,6 +221,28 @@ def test_categorical_metadata_with_int_data():
assert len(recycled_categories_for_c) == 50


def test_category_dtype_errors():
"""Test CTGAN and TVAE error if data has 'category' dtype."""
# Setup
data, metadata = download_demo('single_table', 'fake_hotel_guests')
data['room_type'] = data['room_type'].astype('category')
data['has_rewards'] = data['has_rewards'].astype('category')

ctgan = CTGANSynthesizer(metadata)
tvae = TVAESynthesizer(metadata)

# Run and Assert
expected_msg = re.escape(
"Columns ['has_rewards', 'room_type'] are stored as a 'category' type, which is not "
"supported. Please cast these columns to an 'object' to continue."
)
with pytest.raises(InvalidDataTypeError, match=expected_msg):
ctgan.fit(data)

with pytest.raises(InvalidDataTypeError, match=expected_msg):
tvae.fit(data)


def test_ctgansynthesizer_with_constraints_generating_categorical_values():
"""Test that ``CTGANSynthesizer`` does not crash when using constraints.
Expand Down
31 changes: 27 additions & 4 deletions tests/unit/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import re
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

from sdv.errors import NotFittedError
from sdv.errors import InvalidDataTypeError, NotFittedError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer, _validate_no_category_dtype


def test__validate_no_category_dtype():
"""Test that 'category' dtype causes error."""
# Setup
data = pd.DataFrame({
'category1': pd.Categorical(['a', 'a', 'b']),
'value': [0, 1, 2],
'category2': pd.Categorical([0, 1, 2])
})

# Run and Assert
expected = re.escape(
"Columns ['category1', 'category2'] are stored as a 'category' type, which is not "
"supported. Please cast these columns to an 'object' to continue."
)
with pytest.raises(InvalidDataTypeError, match=expected):
_validate_no_category_dtype(data)


class TestCTGANSynthesizer:
Expand Down Expand Up @@ -209,7 +228,8 @@ def test_preprocessing_few_categories(self, capfd):

@patch('sdv.single_table.ctgan.CTGAN')
@patch('sdv.single_table.ctgan.detect_discrete_columns')
def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
@patch('sdv.single_table.ctgan._validate_no_category_dtype')
def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_ctgan):
"""Test the ``_fit`` from ``CTGANSynthesizer``.
Test that when we call ``_fit`` a new instance of ``CTGAN`` is created as a model
Expand All @@ -225,6 +245,7 @@ def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
instance._fit(processed_data)

# Assert
mock_category_validate.assert_called_once_with(processed_data)
mock_detect_discrete_columns.assert_called_once_with(
metadata,
processed_data,
Expand Down Expand Up @@ -380,7 +401,8 @@ def test_get_parameters(self):

@patch('sdv.single_table.ctgan.TVAE')
@patch('sdv.single_table.ctgan.detect_discrete_columns')
def test__fit(self, mock_detect_discrete_columns, mock_tvae):
@patch('sdv.single_table.ctgan._validate_no_category_dtype')
def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_tvae):
"""Test the ``_fit`` from ``TVAESynthesizer``.
Test that when we call ``_fit`` a new instance of ``TVAE`` is created as a model
Expand All @@ -396,6 +418,7 @@ def test__fit(self, mock_detect_discrete_columns, mock_tvae):
instance._fit(processed_data)

# Assert
mock_category_validate.assert_called_once_with(processed_data)
mock_detect_discrete_columns.assert_called_once_with(
metadata,
processed_data,
Expand Down

0 comments on commit e6e508b

Please sign in to comment.