Skip to content

Commit

Permalink
Remove the need to pass in Metadata to DataProcessor
Browse files: browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Aug 8, 2024
1 parent cf5666c commit 69b9c7b
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 13 deletions.
3 changes: 0 additions & 3 deletions sdv/data_processing/data_processor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,6 @@
from sdv.data_processing.numerical_formatter import NumericalFormatter
from sdv.data_processing.utils import load_module_from_path
from sdv.errors import SynthesizerInputError, log_exc_stacktrace
from sdv.metadata.metadata import Metadata
from sdv.metadata.single_table import SingleTableMetadata

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,8 +114,6 @@ def __init__(
locales=['en_US'],
):
self.metadata = metadata
if isinstance(metadata, Metadata):
self.metadata = metadata._convert_to_single_table()
self._enforce_rounding = enforce_rounding
self._enforce_min_max_values = enforce_min_max_values
self._model_kwargs = model_kwargs or {}
Expand Down
4 changes: 2 additions & 2 deletions sdv/datasets/demo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import warnings
from collections import defaultdict
from pathlib import Path
from zipfile import ZipFile
Expand All @@ -15,8 +16,6 @@
from botocore.client import Config
from botocore.exceptions import ClientError

from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.metadata import Metadata

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,6 +113,7 @@ def _get_metadata(output_folder_name, in_memory_directory, dataset_name):
else:
metadata_path = 'metadata_v2.json'
if metadata_path not in in_memory_directory:
warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.')
metadata_path = 'metadata_v1.json'
metadict = json.loads(in_memory_directory[metadata_path])
metadata = metadata.load_from_dict(metadict, dataset_name)
Expand Down
2 changes: 0 additions & 2 deletions sdv/metadata/metadata.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
"""Metadata."""

from pathlib import Path

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/data_processing/test_data_processor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_with_anonymized_columns(self):
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)

# Instance ``DataProcessor``
dp = DataProcessor(metadata)
dp = DataProcessor(metadata._convert_to_single_table())

# Fit
dp.fit(data)
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_with_anonymized_columns_and_primary_key(self):
data['id'] = np.arange(0, size).astype('O')

# Instance ``DataProcessor``
dp = DataProcessor(metadata)
dp = DataProcessor(metadata._convert_to_single_table())

# Fit
dp.fit(data)
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_prepare_for_fitting(self):
data, metadata = download_demo(
modality='single_table', dataset_name='student_placements_pii'
)
dp = DataProcessor(metadata)
dp = DataProcessor(metadata._convert_to_single_table())

# Run
dp.prepare_for_fitting(data)
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_reverse_transform_with_formatters(self):
"""End to end test using formatters."""
# Setup
data, metadata = download_demo(modality='single_table', dataset_name='student_placements')
dp = DataProcessor(metadata)
dp = DataProcessor(metadata._convert_to_single_table())

# Run
dp.fit(data)
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_refit_hypertransformer(self):
"""Test data processor re-fits _hyper_transformer."""
# Setup
data, metadata = download_demo(modality='single_table', dataset_name='student_placements')
dp = DataProcessor(metadata)
dp = DataProcessor(metadata._convert_to_single_table())

# Run
dp.fit(data)
Expand All @@ -348,7 +348,7 @@ def test_localized_anonymized_columns(self):
data, metadata = download_demo('single_table', 'adult')
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)

dp = DataProcessor(metadata, locales=['en_CA', 'fr_CA'])
dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA'])

# Run
dp.fit(data)
Expand Down

0 comments on commit 69b9c7b

Please sign in to comment.