Skip to content

Commit

Permalink
Support null foreign keys in get_random_subset (#2082)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored and rwedge committed Jul 12, 2024
1 parent e4bb850 commit c0c14eb
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 80 deletions.
69 changes: 38 additions & 31 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Utility functions for the MultiTable models."""

import math
import warnings
from collections import defaultdict
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv._utils import _get_root_tables
from sdv.errors import InvalidDataError, SamplingError
from sdv.multi_table import HMASynthesizer
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS

Expand Down Expand Up @@ -449,22 +448,22 @@ def _drop_rows(data, metadata, drop_missing_values):
])


def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_missing_values):
    """Subsample the disconnected roots tables and their descendants.

    Every root table that has no relationship path to ``table`` is randomly
    downsampled to ``ratio_to_keep`` of its rows, then descendant rows that
    lost their referenced parents are dropped to restore referential integrity.

    Args:
        data (dict):
            Dictionary mapping each table name to a ``pandas.DataFrame``.
            Modified in place.
        metadata (MultiTableMetadata):
            Metadata of the datasets, providing the relationships.
        table (str):
            Name of the table whose disconnected roots should be subsampled.
        ratio_to_keep (float):
            Fraction of rows to keep in each disconnected root table.
        drop_missing_values (bool):
            Whether to also drop rows whose foreign keys are null when
            enforcing referential integrity.
    """
    relationships = metadata.relationships
    roots = _get_disconnected_roots_from_table(relationships, table)
    for root in roots:
        data[root] = data[root].sample(frac=ratio_to_keep)

    _drop_rows(data, metadata, drop_missing_values)


def _subsample_table_and_descendants(data, metadata, table, num_rows):
def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values):
"""Subsample the table and its descendants.
The logic is to first subsample all the NaN foreign keys of the table.
We raise an error if we cannot reach referential integrity while keeping the number of rows.
Then, we drop rows of the descendants to ensure referential integrity.
The logic is to first subsample all the NaN foreign keys of the table when ``drop_missing_values``
is True. We raise an error if we cannot reach referential integrity while keeping
the number of rows. Then, we drop rows of the descendants to ensure referential integrity.
Args:
data (dict):
Expand All @@ -474,19 +473,26 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows):
Metadata of the datasets.
table (str):
Name of the table.
num_rows (int):
Number of rows to keep in the table.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values.
If True, drop rows with missing values in the foreign keys.
Defaults to False.
"""
idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table)
num_rows_to_drop = len(data[table]) - num_rows
if len(idx_nan_fk) > num_rows_to_drop:
raise SamplingError(
f"Referential integrity cannot be reached for table '{table}' while keeping "
f'{num_rows} rows. Please try again with a bigger number of rows.'
)
else:
data[table] = data[table].drop(idx_nan_fk)
if drop_missing_values:
idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table)
num_rows_to_drop = len(data[table]) - num_rows
if len(idx_nan_fk) > num_rows_to_drop:
raise SamplingError(
f"Referential integrity cannot be reached for table '{table}' while keeping "
f'{num_rows} rows. Please try again with a bigger number of rows.'
)
else:
data[table] = data[table].drop(idx_nan_fk)

data[table] = data[table].sample(num_rows)
_drop_rows(data, metadata, drop_missing_values=True)
_drop_rows(data, metadata, drop_missing_values)


def _get_primary_keys_referenced(data, metadata):
Expand Down Expand Up @@ -593,7 +599,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
_subsample_ancestors(data, metadata, parent, primary_keys_referenced)


def _subsample_data(data, metadata, main_table_name, num_rows):
def _subsample_data(data, metadata, main_table_name, num_rows, drop_missing_values=False):
"""Subsample multi-table table based on a table and a number of rows.
The strategy is to:
Expand All @@ -613,6 +619,10 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
Name of the main table.
num_rows (int):
Number of rows to keep in the main table.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values.
If True, drop rows with missing values in the foreign keys.
Defaults to False.
Returns:
dict:
Expand All @@ -621,20 +631,17 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
result = deepcopy(data)
primary_keys_referenced = _get_primary_keys_referenced(result, metadata)
ratio_to_keep = num_rows / len(result[main_table_name])
try:
_validate_foreign_keys_not_null(metadata, result)
except SynthesizerInputError:
warnings.warn(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils'
' to drop these rows before using ``get_random_subset``.'
)

try:
_subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
_subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
_subsample_disconnected_roots(
result, metadata, main_table_name, ratio_to_keep, drop_missing_values
)
_subsample_table_and_descendants(
result, metadata, main_table_name, num_rows, drop_missing_values
)
_subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced)
_drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys
_drop_rows(result, metadata, drop_missing_values)

except InvalidDataError as error:
if 'All references in table' not in str(error.args[0]):
raise error
Expand Down
21 changes: 9 additions & 12 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def data():
)

child = pd.DataFrame(
data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maye', 'No', 'No']}
data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maybe', 'No', 'No']}
)

return {'parent': parent, 'child': child}
Expand Down Expand Up @@ -229,20 +229,17 @@ def test_get_random_subset_disconnected_schema():


def test_get_random_subset_with_missing_values(metadata, data):
"""Test ``get_random_subset`` when there is missing values in the foreign keys."""
"""Test ``get_random_subset`` when there are missing values in the foreign keys.
Here there should be at least one missing value in the random subset.
"""
# Setup
data = deepcopy(data)
data['child'].loc[4, 'parent_id'] = np.nan
expected_warning = re.escape(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils'
' to drop these rows before using ``get_random_subset``.'
)
data['child'].loc[[2, 3, 4], 'parent_id'] = np.nan

# Run
with pytest.warns(UserWarning, match=expected_warning):
cleaned_data = get_random_subset(data, metadata, 'child', 3)
result = get_random_subset(data, metadata, 'child', 3)

# Assert
assert len(cleaned_data['child']) == 3
assert not pd.isna(cleaned_data['child']['parent_id']).any()
assert len(result['child']) == 3
assert result['child']['parent_id'].isnull().sum() > 0
100 changes: 63 additions & 37 deletions tests/unit/multi_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,13 +1302,15 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo
expected_result = deepcopy(data)

# Run
_subsample_disconnected_roots(data, metadata, 'disconnected_root', ratio_to_keep)
_subsample_disconnected_roots(
data, metadata, 'disconnected_root', ratio_to_keep, drop_missing_values=False
)

# Assert
mock_get_disconnected_roots_from_table.assert_called_once_with(
metadata.relationships, 'disconnected_root'
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
mock_drop_rows.assert_called_once_with(data, metadata, False)
for table_name in metadata.tables:
if table_name not in {'grandparent', 'other_root'}:
pd.testing.assert_frame_equal(data[table_name], expected_result[table_name])
Expand All @@ -1317,8 +1319,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo


@patch('sdv.multi_table.utils._drop_rows')
@patch('sdv.multi_table.utils._get_nan_fk_indices_table')
def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, mock_drop_rows):
def test__subsample_table_and_descendants(mock_drop_rows):
"""Test the ``_subsample_table_and_descendants`` method."""
# Setup
data = {
Expand All @@ -1339,40 +1340,17 @@ def test__subsample_table_and_descendants(mock_get_nan_fk_indices_table, mock_dr
'col_8': [6, 7, 8, 9, 10],
}),
}
mock_get_nan_fk_indices_table.return_value = {0}
metadata = Mock()
metadata.relationships = Mock()

# Run
_subsample_table_and_descendants(data, metadata, 'parent', 3)
_subsample_table_and_descendants(data, metadata, 'parent', 3, drop_missing_values=False)

# Assert
mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent')
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
mock_drop_rows.assert_called_once_with(data, metadata, False)
assert len(data['parent']) == 3


@patch('sdv.multi_table.utils._get_nan_fk_indices_table')
def test__subsample_table_and_descendants_nan_fk(mock_get_nan_fk_indices_table):
"""Test the ``_subsample_table_and_descendants`` when there are too many NaN foreign keys."""
# Setup
data = {'parent': [1, 2, 3, 4, 5, 6]}
mock_get_nan_fk_indices_table.return_value = {0, 1, 2, 3, 4}
metadata = Mock()
metadata.relationships = Mock()
expected_message = re.escape(
"Referential integrity cannot be reached for table 'parent' while keeping "
'3 rows. Please try again with a bigger number of rows.'
)

# Run
with pytest.raises(SamplingError, match=expected_message):
_subsample_table_and_descendants(data, metadata, 'parent', 3)

# Assert
mock_get_nan_fk_indices_table.assert_called_once_with(data, metadata.relationships, 'parent')


def test__get_primary_keys_referenced():
"""Test the ``_get_primary_keys_referenced`` method."""
data = {
Expand Down Expand Up @@ -1930,9 +1908,7 @@ def test__subsample_ancestors_schema_diamond_shape():
@patch('sdv.multi_table.utils._subsample_ancestors')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
@patch('sdv.multi_table.utils._drop_rows')
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data(
mock_validate_foreign_keys_not_null,
mock_drop_rows,
mock_get_primary_keys_referenced,
mock_subsample_ancestors,
Expand All @@ -1954,24 +1930,74 @@ def test__subsample_data(
result = _subsample_data(data, metadata, main_table, num_rows)

# Assert
mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data)
mock_drop_rows.assert_called_once_with(data, metadata, False)
mock_get_primary_keys_referenced.assert_called_once_with(data, metadata)
mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5)
mock_subsample_disconnected_roots.assert_called_once_with(
data, metadata, main_table, 0.5, False
)
mock_subsample_table_and_descendants.assert_called_once_with(
data, metadata, main_table, num_rows
data, metadata, main_table, num_rows, False
)
mock_subsample_ancestors.assert_called_once_with(
data, metadata, main_table, primary_key_reference
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
assert result == data


def test__subsample_data_with_null_foreing_keys():
"""Test the ``_subsample_data`` method when there are null foreign keys."""
# Setup
metadata = MultiTableMetadata.load_from_dict({
'tables': {
'parent': {
'columns': {
'id': {'sdtype': 'id'},
'A': {'sdtype': 'categorical'},
'B': {'sdtype': 'numerical'},
},
'primary_key': 'id',
},
'child': {'columns': {'parent_id': {'sdtype': 'id'}, 'C': {'sdtype': 'categorical'}}},
},
'relationships': [
{
'parent_table_name': 'parent',
'child_table_name': 'child',
'parent_primary_key': 'id',
'child_foreign_key': 'parent_id',
}
],
})

parent = pd.DataFrame(
data={
'id': [0, 1, 2, 3, 4],
'A': [True, True, False, True, False],
'B': [0.434, 0.312, 0.212, 0.339, 0.491],
}
)

child = pd.DataFrame(
data={'parent_id': [0, 1, 2, 2, 5], 'C': ['Yes', 'No', 'Maybe', 'No', 'No']}
)

data = {'parent': parent, 'child': child}
data['child'].loc[[2, 3, 4], 'parent_id'] = np.nan

# Run
result_with_nan = _subsample_data(data, metadata, 'child', 4, drop_missing_values=False)
result_without_nan = _subsample_data(data, metadata, 'child', 2, drop_missing_values=True)

# Assert
assert len(result_with_nan['child']) == 4
assert result_with_nan['child']['parent_id'].isnull().sum() > 0
assert len(result_without_nan['child']) == 2
assert set(result_without_nan['child'].index) == {0, 1}


@patch('sdv.multi_table.utils._subsample_disconnected_roots')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data_empty_dataset(
mock_validate_foreign_keys_not_null,
mock_get_primary_keys_referenced,
mock_subsample_disconnected_roots,
):
Expand Down

0 comments on commit c0c14eb

Please sign in to comment.