Commit
Merge branch 'feature/metadata' into issue_2131_metadata_for_demos
lajohn4747 committed Aug 12, 2024
2 parents 81ca2b2 + 94d7ff1 commit bef38a1
Showing 19 changed files with 585 additions and 256 deletions.
pyproject.toml: 3 changes (1 addition, 2 deletions)
@@ -35,7 +35,7 @@ dependencies = [
'copulas>=0.11.0',
'ctgan>=0.10.0',
'deepecho>=0.6.0',
'rdt>=1.12.0',
'rdt @ git+https://github.com/sdv-dev/RDT@main',
'sdmetrics>=0.14.0',
'platformdirs>=4.0',
'pyyaml>=6.0.1',
@@ -75,7 +75,6 @@ dev = [

# docs
'docutils>=0.12,<1',
'm2r2>=0.2.5,<1',
'nbsphinx>=0.5.0,<1',
'sphinx_toolbox>=2.5,<4',
'Sphinx>=3,<8',
sdv/_utils.py: 36 changes (36 additions, 0 deletions)
@@ -10,10 +10,16 @@

import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from rdt.transformers.utils import _GENERATORS

from sdv import version
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError

try:
from re import _parser as sre_parse
except ImportError:
import sre_parse


def _cast_to_iterable(value):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
@@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer):
synth_version = version.public
unique_id = ''.join(str(uuid.uuid4()).split('-'))
return f'{class_name}_{synth_version}_{unique_id}'


def _get_chars_for_option(option, params):
if option not in _GENERATORS:
raise ValueError(f'REGEX operation: {option} is not supported by SDV.')

if option == sre_parse.MAX_REPEAT:
new_option, new_params = params[2][0] # The value at the second index is the nested option
return _get_chars_for_option(new_option, new_params)

return list(_GENERATORS[option](params, 1)[0])


def get_possible_chars(regex, num_subpatterns=None):
"""Get the list of possible characters a regex can create.
Args:
regex (str):
The regex to parse.
num_subpatterns (int):
The number of sub-patterns from the regex to find characters for.
"""
parsed = sre_parse.parse(regex)
parsed = [p for p in parsed if p[0] != sre_parse.AT]
num_subpatterns = num_subpatterns or len(parsed)
possible_chars = []
for option, params in parsed[:num_subpatterns]:
possible_chars += _get_chars_for_option(option, params)

return possible_chars
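
As a usage note, the new helper can be exercised on its own; a minimal sketch (not part of this diff, and the listed output is the expected result rather than a verified one):

from sdv._utils import get_possible_chars

# For a purely numeric pattern, the first sub-pattern can start with any digit,
# including '0', which is exactly what the new primary-key validation checks for.
chars = get_possible_chars('[0-9]{3}', num_subpatterns=1)
print(chars)  # expected: ['0', '1', '2', ..., '9']
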
sdv/data_processing/numerical_formatter.py: 41 changes (5 additions, 36 deletions)
@@ -3,8 +3,8 @@
import logging
import sys

import numpy as np
import pandas as pd
from rdt.transformers.utils import learn_rounding_digits

LOGGER = logging.getLogger(__name__)

@@ -51,34 +51,6 @@ def __init__(
self.enforce_min_max_values = enforce_min_max_values
self.computer_representation = computer_representation

@staticmethod
def _learn_rounding_digits(data):
"""Check if data has any decimals."""
name = data.name
data = np.array(data)
roundable_data = data[~(np.isinf(data) | pd.isna(data))]

# Doesn't contain numbers
if len(roundable_data) == 0:
return None

# Doesn't contain decimal digits
if ((roundable_data % 1) == 0).all():
return 0

# Try to round to fewer digits
if (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
for decimal in range(MAX_DECIMALS + 1):
if (roundable_data == roundable_data.round(decimal)).all():
return decimal

# Can't round, not equal after MAX_DECIMALS digits of precision
LOGGER.info(
f"No rounding scheme detected for column '{name}'."
' Synthetic data will not be rounded.'
)
return None

def learn_format(self, column):
"""Learn the format of a column.
@@ -92,7 +64,7 @@ def learn_format(self, column):
self._max_value = column.max()

if self.enforce_rounding:
self._rounding_digits = self._learn_rounding_digits(column)
self._rounding_digits = learn_rounding_digits(column)

def format_data(self, column):
"""Format a column according to the learned format.
@@ -105,20 +77,17 @@ def format_data(self, column):
numpy.ndarray:
containing the formatted data.
"""
column = column.copy().to_numpy()
column = column.copy()
if self.enforce_min_max_values:
column = column.clip(self._min_value, self._max_value)
elif self.computer_representation != 'Float':
elif not self.computer_representation.startswith('Float'):
min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation]
column = column.clip(min_bound, max_bound)

is_integer = np.dtype(self._dtype).kind == 'i'
is_integer = pd.api.types.is_integer_dtype(self._dtype)
if self.enforce_rounding and self._rounding_digits is not None:
column = column.round(self._rounding_digits)
elif is_integer:
column = column.round(0)

if pd.isna(column).any() and is_integer:
return column

return column.astype(self._dtype)
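
For reference, learn_rounding_digits from RDT is assumed here to follow the same contract as the removed static method (number of decimal digits, 0 for whole numbers, None when no scheme is detected); an illustrative call:

import pandas as pd
from rdt.transformers.utils import learn_rounding_digits

# Decimal data: the smallest number of digits that round-trips the values.
learn_rounding_digits(pd.Series([1.25, 2.5, 3.0]))  # expected: 2
# Whole numbers only.
learn_rounding_digits(pd.Series([1.0, 2.0, 3.0]))   # expected: 0
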
sdv/evaluation/multi_table.py: 8 changes (4 additions, 4 deletions)
@@ -77,8 +77,8 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
1D marginal distribution plot (i.e. a histogram) of the columns.
"""
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
real_data = real_data[table_name] if real_data else None
synthetic_data = synthetic_data[table_name] if synthetic_data else None
return single_table_visualization.get_column_plot(
real_data,
synthetic_data,
@@ -118,8 +118,8 @@ def get_column_pair_plot(
2D bivariate distribution plot (i.e. a scatterplot) of the columns.
"""
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
real_data = real_data[table_name] if real_data else None
synthetic_data = synthetic_data[table_name] if synthetic_data else None
return single_table_visualization.get_column_pair_plot(
real_data, synthetic_data, metadata, column_names, sample_size, plot_type
)
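
With this change, either side of the comparison can be omitted by passing None for it, where previously indexing a None dataset raised an error. A hypothetical call, with the table, column, and data construction assumed for illustration:

import pandas as pd
from sdv.metadata import MultiTableMetadata
from sdv.evaluation.multi_table import get_column_plot

real_data = {'users': pd.DataFrame({'age': [21, 34, 45, 52]})}
metadata = MultiTableMetadata()
metadata.detect_from_dataframes(real_data)

synthetic_data = {'users': pd.DataFrame({'age': [22, 30, 41, 39]})}

# Passing real_data=None now plots only the synthetic distribution.
fig = get_column_plot(
    real_data=None,
    synthetic_data=synthetic_data,
    metadata=metadata,
    table_name='users',
    column_name='age',
)
fig.show()
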
sdv/logging/logger.py: 2 changes (1 addition, 1 deletion)
@@ -33,7 +33,7 @@ def __init__(self, filename=None):

def format(self, record): # noqa: A003
"""Format the record and write to CSV."""
row = record.msg
row = record.msg.copy()
row['LEVEL'] = record.levelname
self.writer.writerow(row)
data = self.output.getvalue()
sdv/metadata/single_table.py: 2 changes (2 additions, 0 deletions)
@@ -55,6 +55,8 @@ class SingleTableMetadata:
}

_NUMERICAL_REPRESENTATIONS = frozenset([
'Float32',
'Float64',
'Float',
'Int64',
'Int32',
sdv/multi_table/base.py: 15 changes (12 additions, 3 deletions)
@@ -28,6 +28,7 @@
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
@@ -372,9 +373,17 @@ def preprocess(self, data):
processed_data = {}
pbar_args = self._get_pbar_args(desc='Preprocess Tables')
for table_name, table_data in tqdm(data.items(), **pbar_args):
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
try:
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
except SynthesizerInputError as e:
if INT_REGEX_ZERO_ERROR_MESSAGE in str(e):
raise SynthesizerInputError(
f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
)

raise e

for table in list_of_changed_tables:
data[table].columns = self._original_table_columns[table]
sdv/single_table/base.py: 17 changes (17 additions, 0 deletions)
@@ -24,6 +24,7 @@
check_sdv_versions_and_warn,
check_synthesizer_version,
generate_synthesizer_id,
get_possible_chars,
)
from sdv.constraints.errors import AggregateConstraintsError
from sdv.data_processing.data_processor import DataProcessor
@@ -43,6 +44,10 @@

COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251
INT_REGEX_ZERO_ERROR_MESSAGE = (
'is stored as an int but the Regex allows it to start with "0". Please remove the Regex '
'or update it to correspond to valid ints.'
)

DEPRECATION_MSG = (
"The 'SingleTableMetadata' is deprecated. Please use the new "
@@ -177,6 +182,17 @@ def _validate(self, data):
"""
return []

def _validate_primary_key(self, data):
primary_key = self.metadata.primary_key
is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key])
regex = self.metadata.columns.get(primary_key, {}).get('regex_format')
if is_int and regex:
possible_characters = get_possible_chars(regex, 1)
if '0' in possible_characters:
raise SynthesizerInputError(
f'Primary key "{primary_key}" {INT_REGEX_ZERO_ERROR_MESSAGE}.'
)

def validate(self, data):
"""Validate data.
@@ -198,6 +214,7 @@ def validate(self, data):
* values of a column don't satisfy their sdtype
"""
self._validate_metadata(data)
self._validate_primary_key(data)
self._validate_constraints(data)

# Retaining the logic of returning errors and raising them here to maintain consistency
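
For reference, a hypothetical end-to-end reproduction of the condition this check guards against; the metadata and data below are illustrative assumptions, not taken from this commit:

import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

metadata = SingleTableMetadata()
metadata.add_column('id', sdtype='id', regex_format='[0-9]{3}')  # '0' is a possible first character
metadata.add_column('amount', sdtype='numerical')
metadata.set_primary_key('id')

data = pd.DataFrame({'id': [101, 102, 103], 'amount': [9.5, 3.2, 7.1]})

synthesizer = GaussianCopulaSynthesizer(metadata)
# Expected to raise SynthesizerInputError: an int-typed key cannot start with '0'.
synthesizer.validate(data)
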
sdv/single_table/ctgan.py: 6 changes (5 additions, 1 deletion)
@@ -1,5 +1,7 @@
"""Wrapper around CTGAN model."""

import warnings

import numpy as np
import pandas as pd
import plotly.express as px
@@ -285,7 +287,9 @@ def _fit(self, processed_data):
transformers = self._data_processor._hyper_transformer.field_transformers
discrete_columns = detect_discrete_columns(self.metadata, processed_data, transformers)
self._model = CTGAN(**self._model_kwargs)
self._model.fit(processed_data, discrete_columns=discrete_columns)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='.*Attempting to run cuBLAS.*')
self._model.fit(processed_data, discrete_columns=discrete_columns)

def _sample(self, num_rows, conditions=None):
"""Sample the indicated number of rows from the model.
(Diffs for the remaining changed files were not loaded.)