Skip to content

Commit

Permalink
Warn users to save their metadata file after auto-detecting/updating …
Browse files Browse the repository at this point in the history
…it (#1786)
  • Loading branch information
R-Palazzo authored Feb 13, 2024
1 parent 798ac51 commit 946d2c5
Show file tree
Hide file tree
Showing 10 changed files with 464 additions and 5 deletions.
22 changes: 22 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ class MultiTableMetadata:
def __init__(self):
self.tables = {}
self.relationships = []
self._multi_table_updated = False

def _check_updated_flag(self):
is_single_table_updated = any(table._updated for table in self.tables.values())
if is_single_table_updated or self._multi_table_updated:
return True

return False

def _reset_updated_flag(self):
for table in self.tables.values():
table._updated = False

self._multi_table_updated = False

def _validate_missing_relationship_keys(self, parent_table_name, parent_primary_key,
child_table_name, child_foreign_key):
Expand Down Expand Up @@ -264,6 +278,7 @@ def add_relationship(self, parent_table_name, child_table_name,
'parent_primary_key': deepcopy(parent_primary_key),
'child_foreign_key': deepcopy(child_foreign_key),
})
self._multi_table_updated = True

def remove_relationship(self, parent_table_name, child_table_name):
"""Remove the relationship between two tables.
Expand Down Expand Up @@ -291,6 +306,8 @@ def remove_relationship(self, parent_table_name, child_table_name):
for relation in relationships_to_remove:
self.relationships.remove(relation)

self._multi_table_updated = True

def remove_primary_key(self, table_name):
"""Remove the primary key from the given table.
Expand Down Expand Up @@ -320,6 +337,8 @@ def remove_primary_key(self, table_name):
LOGGER.info(info_msg)
self.relationships.remove(relationship)

self._multi_table_updated = True

def _validate_table_exists(self, table_name):
if table_name not in self.tables:
raise InvalidMetadataError(f"Unknown table name ('{table_name}').")
Expand Down Expand Up @@ -799,6 +818,7 @@ def add_table(self, table_name):
)

self.tables[table_name] = SingleTableMetadata()
self._multi_table_updated = True

def visualize(self, show_table_details='full', show_relationship_labels=True,
output_filepath=None):
Expand Down Expand Up @@ -947,6 +967,8 @@ def save_to_json(self, filepath):
with open(filepath, 'w', encoding='utf-8') as metadata_file:
json.dump(metadata, metadata_file, indent=4)

self._reset_updated_flag()

@classmethod
def load_from_json(cls, filepath):
"""Create a ``MultiTableMetadata`` instance from a ``json`` file.
Expand Down
14 changes: 14 additions & 0 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(self):
self.sequence_index = None
self.column_relationships = []
self._version = self.METADATA_SPEC_VERSION
self._updated = False

def _validate_unexpected_kwargs(self, column_name, sdtype, **kwargs):
expected_kwargs = self._SDTYPE_KWARGS.get(sdtype, ['pii'])
Expand Down Expand Up @@ -260,6 +261,7 @@ def add_column(self, column_name, **kwargs):
pii = column_kwargs.get('pii', True)
column_kwargs['pii'] = pii

self._updated = True
self.columns[column_name] = column_kwargs

def _validate_column_exists(self, column_name):
Expand Down Expand Up @@ -297,6 +299,7 @@ def update_column(self, column_name, **kwargs):

self._validate_column_args(column_name, sdtype, **kwargs)
self.columns[column_name] = _kwargs
self._updated = True

def to_dict(self):
"""Return a python ``dict`` representation of the ``SingleTableMetadata``."""
Expand Down Expand Up @@ -464,6 +467,8 @@ def _detect_columns(self, data):
if self.primary_key is None and first_pii_field:
self.primary_key = first_pii_field

self._updated = True

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.
Expand Down Expand Up @@ -561,13 +566,15 @@ def set_primary_key(self, column_name):
' This key will be removed.'
)

self._updated = True
self.primary_key = column_name

def remove_primary_key(self):
"""Remove the metadata primary key."""
if self.primary_key is None:
warnings.warn('No primary key exists to remove.')

self._updated = True
self.primary_key = None

def set_sequence_key(self, column_name):
Expand All @@ -584,6 +591,7 @@ def set_sequence_key(self, column_name):
' This key will be removed.'
)

self._updated = True
self.sequence_key = column_name

def _validate_alternate_keys(self, column_names):
Expand Down Expand Up @@ -626,6 +634,8 @@ def add_alternate_keys(self, column_names):
else:
self.alternate_keys.append(column)

self._updated = True

def _validate_sequence_index(self, column_name):
if not isinstance(column_name, str):
raise InvalidMetadataError("'sequence_index' must be a string.")
Expand All @@ -651,6 +661,7 @@ def set_sequence_index(self, column_name):
"""
self._validate_sequence_index(column_name)
self.sequence_index = column_name
self._updated = True

def _validate_sequence_index_not_in_sequence_key(self):
"""Check that ``_sequence_index`` and ``_sequence_key`` don't overlap."""
Expand Down Expand Up @@ -790,6 +801,7 @@ def add_column_relationship(self, relationship_type, column_names):
self._validate_all_column_relationships(to_check)

self.column_relationships.append(relationship)
self._updated = True

def validate(self):
"""Validate the metadata.
Expand Down Expand Up @@ -1028,6 +1040,8 @@ def save_to_json(self, filepath):
with open(filepath, 'w', encoding='utf-8') as metadata_file:
json.dump(metadata, metadata_file, indent=4)

self._updated = False

@classmethod
def load_from_dict(cls, metadata_dict):
"""Create a ``SingleTableMetadata`` instance from a python ``dict``.
Expand Down
10 changes: 10 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,21 @@ def _print(self, text='', **kwargs):
if self.verbose:
print(text, **kwargs) # noqa: T001

def _check_metadata_updated(self):
    """Warn once when the attached metadata has unsaved changes, then clear its flags."""
    if not self.metadata._check_updated_flag():
        return

    self.metadata._reset_updated_flag()
    warnings.warn(
        "We strongly recommend saving the metadata using 'save_to_json' for replicability"
        ' in future SDV versions.'
    )

def __init__(self, metadata, locales=None, synthesizer_kwargs=None):
self.metadata = metadata
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()

self._check_metadata_updated()
self.locales = locales
self.verbose = False
self.extended_columns = defaultdict(dict)
Expand Down Expand Up @@ -364,6 +373,7 @@ def fit(self, data):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
self._check_metadata_updated()
self._fitted = False
processed_data = self.preprocess(data)
self._print(text='\n', end='')
Expand Down
10 changes: 10 additions & 0 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,19 @@ def _update_default_transformers(self):
for sdtype, transformer in self._model_sdtype_transformers.items():
self._data_processor._update_transformers_by_sdtypes(sdtype, transformer)

def _check_metadata_updated(self):
    """Warn once when the metadata has unsaved changes, then reset its flag."""
    if not self.metadata._updated:
        return

    self.metadata._updated = False
    warnings.warn(
        "We strongly recommend saving the metadata using 'save_to_json' for replicability"
        ' in future SDV versions.'
    )

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=None):
self._validate_inputs(enforce_min_max_values, enforce_rounding)
self.metadata = metadata
self.metadata.validate()
self._check_metadata_updated()
self.enforce_min_max_values = enforce_min_max_values
self.enforce_rounding = enforce_rounding
self.locales = locales
Expand Down Expand Up @@ -379,6 +388,7 @@ def fit(self, data):
data (pandas.DataFrame):
The raw data (before any transformations) to fit the model to.
"""
self._check_metadata_updated()
self._fitted = False
self._data_processor.reset_sampling()
self._random_state_set = False
Expand Down
156 changes: 156 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import re
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1246,3 +1247,158 @@ def test__extract_parameters(self):
'scale': 0.
}
assert result == expected_result

def test_metadata_updated_no_warning(self, tmp_path):
    """Test scenarios where no warning about unsaved metadata should be raised.

    Run 1 - The metadata is loaded from our demo datasets.
    Run 2 - The metadata uses ``detect_from_dataframes`` but is saved to a file
    before defining the synthesizer.
    Run 3 - The metadata is updated with a new column after the synthesizer
    initialization, but is saved to a file before fitting.
    """
    # Setup
    data, metadata = download_demo('multi_table', 'got_families')

    # Run 1
    with warnings.catch_warnings(record=True) as captured_warnings:
        warnings.simplefilter('always')
        instance = HMASynthesizer(metadata)
        instance.fit(data)

    # Assert
    assert len(captured_warnings) == 0

    # Run 2
    metadata_detect = MultiTableMetadata()
    metadata_detect.detect_from_dataframes(data)

    # Align the detected metadata with the demo metadata so fitting succeeds.
    metadata_detect.relationships = metadata.relationships
    for table_name, table_metadata in metadata.tables.items():
        metadata_detect.tables[table_name].columns = table_metadata.columns
        metadata_detect.tables[table_name].primary_key = table_metadata.primary_key

    # Saving resets the updated flags, so no warning is expected afterwards.
    file_name = tmp_path / 'multitable_1.json'
    metadata_detect.save_to_json(file_name)
    with warnings.catch_warnings(record=True) as captured_warnings:
        warnings.simplefilter('always')
        instance = HMASynthesizer(metadata_detect)
        instance.fit(data)

    # Assert
    assert len(captured_warnings) == 0

    # Run 3
    instance = HMASynthesizer(metadata_detect)
    metadata_detect.update_column(
        table_name='characters', column_name='age', sdtype='categorical')
    file_name = tmp_path / 'multitable_2.json'
    metadata_detect.save_to_json(file_name)
    with warnings.catch_warnings(record=True) as captured_warnings:
        warnings.simplefilter('always')
        instance.fit(data)

    # Assert
    assert len(captured_warnings) == 0

def test_metadata_updated_warning_detect(self):
    """Test that ``detect_from_dataframes`` without saving the metadata raises a warning.

    The warning is expected to be raised only once, at synthesizer initialization,
    and not again when ``fit`` is called.
    """
    # Setup
    data, metadata = download_demo('multi_table', 'got_families')
    detected = MultiTableMetadata()
    detected.detect_from_dataframes(data)

    # Align the detected metadata with the demo metadata so fitting succeeds.
    detected.relationships = metadata.relationships
    for name, source_table in metadata.tables.items():
        detected.tables[name].columns = source_table.columns
        detected.tables[name].primary_key = source_table.primary_key

    expected_message = re.escape(
        "We strongly recommend saving the metadata using 'save_to_json' for replicability"
        ' in future SDV versions.'
    )

    # Run
    with pytest.warns(UserWarning, match=expected_message) as record:
        synthesizer = HMASynthesizer(detected)
        synthesizer.fit(data)

    # Assert
    assert len(record) == 1


# Each entry is a (MultiTableMetadata method name, keyword arguments) pair used by
# ``test_metadata_updated_warning`` below to modify the metadata; every one of
# these updates is expected to trigger the unsaved-metadata warning.
parametrization = [
    ('update_column', {
        'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical'
    }),
    ('set_primary_key', {'table_name': 'arrival', 'column_name': 'id_flight'}),
    (
        'add_column_relationship', {
            'table_name': 'departure',
            'relationship_type': 'address',
            'column_names': ['city', 'country']
        }
    ),
    ('add_alternate_keys', {'table_name': 'departure', 'column_names': ['city', 'country']}),
    ('set_sequence_key', {'table_name': 'departure', 'column_name': 'city'}),
    ('add_column', {
        'table_name': 'departure', 'column_name': 'postal_code', 'sdtype': 'postal_code'
    }),
]


@pytest.mark.parametrize(('method', 'kwargs'), parametrization)
def test_metadata_updated_warning(method, kwargs):
    """Test that modifying metadata without saving it raises a warning.

    The warning should be raised during synthesizer initialization, and both the
    multi-table and per-table updated flags should be cleared afterwards.
    """
    # Setup
    metadata = MultiTableMetadata().load_from_dict({
        'tables': {
            'departure': {
                'primary_key': 'id',
                'columns': {
                    'id': {'sdtype': 'id'},
                    'date': {'sdtype': 'datetime'},
                    'city': {'sdtype': 'city'},
                    'country': {'sdtype': 'country'}
                },
            },
            'arrival': {
                'foreign_key': 'id',
                'columns': {
                    'id': {'sdtype': 'id'},
                    'date': {'sdtype': 'datetime'},
                    'city': {'sdtype': 'city'},
                    'country': {'sdtype': 'country'},
                    'id_flight': {'sdtype': 'id'}
                },
            },
        },
        'relationships': [
            {
                'parent_table_name': 'departure',
                'parent_primary_key': 'id',
                'child_table_name': 'arrival',
                'child_foreign_key': 'id'
            }
        ]
    })
    expected_message = re.escape(
        "We strongly recommend saving the metadata using 'save_to_json' for replicability"
        ' in future SDV versions.'
    )

    # Run
    # Use getattr for dynamic dispatch rather than calling __getattribute__ directly.
    getattr(metadata, method)(**kwargs)
    with pytest.warns(UserWarning, match=expected_message):
        HMASynthesizer(metadata)

    # Assert
    # Initializing the synthesizer should have reset every updated flag.
    assert metadata._multi_table_updated is False
    for table_metadata in metadata.tables.values():
        assert table_metadata._updated is False
Loading

0 comments on commit 946d2c5

Please sign in to comment.