Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide user-friendly error messages when there are missing values in conditional sampling #1791

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,13 +794,6 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
show_progress_bar=show_progress_bar
)

def _validate_conditions(self, conditions):
"""Validate the user-passed conditions."""
for column in conditions.columns:
if column not in self._data_processor.get_sdtypes():
raise ValueError(f"Unexpected column name '{column}'. "
f'Use a column name that was present in the original data.')

def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
progress_bar=None, output_file_path=None):
"""Sample rows with conditions.
Expand Down Expand Up @@ -904,6 +897,27 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,

return all_sampled_rows

def _validate_conditions_unseen_columns(self, conditions):
"""Validate the user-passed conditions."""
for column in conditions.columns:
if column not in self._data_processor.get_sdtypes():
raise ValueError(f"Unexpected column name '{column}'. "
f'Use a column name that was present in the original data.')

@staticmethod
def _raise_condition_with_nans():
raise SynthesizerInputError(
'Missing values are not yet supported for conditional sampling. '
'Please include only non-null values in your Condition objects.'
)

def _validate_conditions(self, conditions):
"""Validate the user-passed conditions."""
for condition_dataframe in conditions:
self._validate_conditions_unseen_columns(condition_dataframe)
if condition_dataframe.isna().any().any():
self._raise_condition_with_nans()

def sample_from_conditions(self, conditions, max_tries_per_batch=100,
batch_size=None, output_file_path=None):
"""Sample rows from this table with the given conditions.
Expand Down Expand Up @@ -939,8 +953,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100,
lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)

conditions = self._make_condition_dfs(conditions)
for condition_dataframe in conditions:
self._validate_conditions(condition_dataframe)
self._validate_conditions(conditions)

sampled = pd.DataFrame()
try:
Expand Down Expand Up @@ -974,6 +987,17 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100,

return sampled

def _validate_known_columns(self, conditions):
"""Validate the user-passed conditions."""
self._validate_conditions_unseen_columns(conditions)
if conditions.dropna().empty:
self._raise_condition_with_nans()
elif conditions.isna().any().any():
warnings.warn(
'Missing values are not yet supported. '
'Rows with any missing values will not be created.'
)

def sample_remaining_columns(self, known_columns, max_tries_per_batch=100,
batch_size=None, output_file_path=None):
"""Sample remaining rows from already known columns.
Expand Down Expand Up @@ -1006,7 +1030,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100,
output_file_path = validate_file_path(output_file_path)

known_columns = known_columns.copy()
self._validate_conditions(known_columns)
self._validate_known_columns(known_columns)
sampled = pd.DataFrame()
try:
with tqdm.tqdm(total=len(known_columns)) as progress_bar:
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import warnings
from unittest.mock import patch

import numpy as np
import pandas as pd
import pkg_resources
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError
from sdv.metadata import SingleTableMetadata
from sdv.sampling import Condition
from sdv.single_table import (
Expand Down Expand Up @@ -147,6 +150,81 @@ def test_sample_from_conditions_negative_float():
pd.testing.assert_series_equal(sampled_data['column1'], expected)


def test_sample_from_conditions_with_nans():
"""Test it crashes when condition has nans (GH#1758)."""
# Setup
data, metadata = download_demo(
modality='single_table',
dataset_name='fake_hotel_guests'
)
synthesizer = GaussianCopulaSynthesizer(metadata)
my_condition = Condition(
num_rows=250,
column_values={'room_type': None, 'has_rewards': False}
)

# Run
synthesizer.fit(data)

# Assert
error_msg = (
'Missing values are not yet supported for conditional sampling. '
'Please include only non-null values in your Condition objects.'
)
with pytest.raises(SynthesizerInputError, match=error_msg):
synthesizer.sample_from_conditions(conditions=[my_condition])


def test_sample_remaining_columns_with_all_nans():
"""Test it crashes when every condition row has a nan (GH#1758)."""
# Setup
data, metadata = download_demo(
modality='single_table',
dataset_name='fake_hotel_guests'
)
synthesizer = GaussianCopulaSynthesizer(metadata)
known_columns = pd.DataFrame(data={
'has_rewards': [np.nan, False, True],
'amenities_fee': [5.00, np.nan, None]
})

# Run
synthesizer.fit(data)

# Assert
error_msg = (
'Missing values are not yet supported for conditional sampling. '
'Please include only non-null values in your Condition objects.'
)
with pytest.raises(SynthesizerInputError, match=error_msg):
synthesizer.sample_remaining_columns(known_columns=known_columns)


def test_sample_remaining_columns_with_some_nans():
"""Test it warns when some of the condition rows contain nans (GH#1758)."""
# Setup
data, metadata = download_demo(
modality='single_table',
dataset_name='fake_hotel_guests'
)
synthesizer = GaussianCopulaSynthesizer(metadata)
known_columns = pd.DataFrame(data={
'has_rewards': [True, False, np.nan],
'amenities_fee': [5.00, np.nan, None]
})

# Run
synthesizer.fit(data)

# Assert
warn_msg = (
'Missing values are not yet supported. '
'Rows with any missing values will not be created.'
)
with pytest.warns(UserWarning, match=warn_msg):
synthesizer.sample_remaining_columns(known_columns=known_columns)


def test_multiple_fits():
"""Test the synthesizer refits correctly on new data.

Expand Down
58 changes: 52 additions & 6 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import date, datetime
from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch

import numpy as np
import pandas as pd
import pytest
from copulas.multivariate import GaussianMultivariate
Expand Down Expand Up @@ -1279,7 +1280,7 @@ def test_sample(self):
)
assert result == instance._sample_with_progress_bar.return_value

def test__validate_conditions(self):
def test__validate_conditions_unseen_columns(self):
"""Test that conditions are within the ``data_processor`` fields."""
# Setup
instance = Mock()
Expand All @@ -1290,12 +1291,12 @@ def test__validate_conditions(self):
conditions = pd.DataFrame({'name': ['Johanna'], 'surname': ['Doe']})

# Run
BaseSingleTableSynthesizer._validate_conditions(instance, conditions)
BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions)

# Assert
instance._data_processor.get_sdtypes.assert_called()

def test__validate_conditions_raises_error(self):
def test__validate_conditions_unseen_columns_raises_error(self):
"""Test that conditions are not in the ``data_processor`` fields."""
# Setup
instance = Mock()
Expand All @@ -1311,7 +1312,22 @@ def test__validate_conditions_raises_error(self):
'original data.'
)
with pytest.raises(ValueError, match=error_msg):
BaseSingleTableSynthesizer._validate_conditions(instance, conditions)
BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions)

def test__validate_conditions_nans(self):
"""Test that it raises an error when nans are in the data."""
# Setup
conditions = [pd.DataFrame({'names': [np.nan], 'surname': ['Doe']})]
synthesizer = BaseSingleTableSynthesizer(MagicMock())
synthesizer._validate_conditions_unseen_columns = Mock()

# Run and Assert
error_msg = (
'Missing values are not yet supported for conditional sampling. '
'Please include only non-null values in your Condition objects.'
)
with pytest.raises(SynthesizerInputError, match=error_msg):
synthesizer._validate_conditions(conditions)

def test__sample_with_conditions_constraints_not_met(self):
"""Test when conditions are not met."""
Expand Down Expand Up @@ -1523,7 +1539,7 @@ def test_sample_remaining_columns(self, mock_validate_file_path, mock_tqdm,
instance = BaseSingleTableSynthesizer(metadata)
known_columns = pd.DataFrame({'name': ['Johanna Doe']})

instance._validate_conditions = Mock()
instance._validate_known_columns = Mock()
instance._sample_with_conditions = Mock()
instance._model = GaussianMultivariate()
instance._sample_with_conditions.return_value = pd.DataFrame({'name': ['John Doe']})
Expand Down Expand Up @@ -1563,7 +1579,7 @@ def test_sample_remaining_columns_handles_sampling_error(
instance = BaseSingleTableSynthesizer(metadata)
known_columns = pd.DataFrame({'name': ['Johanna Doe']})

instance._validate_conditions = Mock()
instance._validate_known_columns = Mock()
instance._sample_with_conditions = Mock()
instance._model = GaussianMultivariate()
keyboard_error = KeyboardInterrupt()
Expand All @@ -1585,6 +1601,36 @@ def test_sample_remaining_columns_handles_sampling_error(
pd.testing.assert_frame_equal(result, pd.DataFrame())
mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error)

def test__validate_known_columns_nans(self):
"""Test that it crashes when condition has nans."""
# Setup
conditions = pd.DataFrame({'names': [np.nan], 'surname': ['Doe']})
synthesizer = BaseSingleTableSynthesizer(MagicMock())
synthesizer._validate_conditions_unseen_columns = Mock()

# Run and Assert
error_msg = (
'Missing values are not yet supported for conditional sampling. '
'Please include only non-null values in your Condition objects.'
)
with pytest.raises(SynthesizerInputError, match=error_msg):
synthesizer._validate_known_columns(conditions)

def test__validate_known_columns_a_few_nans(self):
"""Test that it warns when condition has a few nans, but at least a valid row."""
# Setup
conditions = pd.DataFrame({'names': [np.nan, 'Dae'], 'surname': ['Doe', 'Due']})
synthesizer = BaseSingleTableSynthesizer(MagicMock())
synthesizer._validate_conditions_unseen_columns = Mock()

# Run and Assert
warn_msg = (
'Missing values are not yet supported. '
'Rows with any missing values will not be created.'
)
with pytest.warns(UserWarning, match=warn_msg):
synthesizer._validate_known_columns(conditions)

@patch('sdv.single_table.base.cloudpickle')
def test_save(self, cloudpickle_mock, tmp_path):
"""Test that the synthesizer is saved correctly."""
Expand Down
Loading