Testing conditional sampling on independent sampler
pvk-developer committed Aug 1, 2024
1 parent a0e0a76 commit b6b3974
Showing 2 changed files with 86 additions and 40 deletions.
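The commit threads conditional sampling through the multi-table independent sampler. As a minimal usage sketch, assuming an already-fitted SDV multi-table synthesizer that mixes in this sampler (the `synthesizer` object plus the table and column names below are hypothetical):

from sdv.sampling import Condition

# Each Condition fixes column values for a requested number of rows.
conditions = {
    'guests': [
        Condition(num_rows=50, column_values={'has_rewards': True}),
        Condition(num_rows=25, column_values={'has_rewards': False}),
    ],
}

# The new sampler-level sample_from_conditions expects a mapping of table name
# to a list of Condition objects, matching the conditions.get(table_name)
# lookup in _sample below; tables without an entry are sampled unconditionally.
synthetic_data = synthesizer.sample_from_conditions(conditions=conditions)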
83 changes: 75 additions & 8 deletions sdv/sampling/independent_sampler.py
@@ -1,8 +1,15 @@
"""Independent Samplers."""

import functools
import logging
import warnings

import pandas as pd
import tqdm
from copulas.multivariate import GaussianMultivariate

from sdv.single_table.utils import check_num_rows, handle_sampling_error

LOGGER = logging.getLogger(__name__)


@@ -40,7 +47,8 @@ def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent
"""
raise NotImplementedError()

    def _sample_table(self, synthesizer, table_name, num_rows, sampled_data, conditions,
                      max_tries_per_batch=100, batch_size=None, output_file_path=None):
"""Sample a single table and all its children.
Args:
@@ -54,9 +62,49 @@ def _sample_table(self, synthesizer, table_name, num_rows, sampled_data):
                A dictionary mapping table names to sampled tables (pd.DataFrame).
        """
        LOGGER.info(f'Sampling {num_rows} rows from table {table_name}')
        if conditions:
            num_rows = functools.reduce(
                lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0
            )
            conditions = synthesizer._make_condition_dfs(conditions)
            synthesizer._validate_conditions(conditions)
            sampled = pd.DataFrame()
            try:
                with tqdm.tqdm(total=num_rows) as progress_bar:
                    progress_bar.set_description('Sampling conditions')
                    for condition_dataframe in conditions:
                        sampled_for_condition = synthesizer._sample_with_conditions(
                            condition_dataframe,
                            max_tries_per_batch,
                            batch_size,
                            progress_bar,
                            output_file_path,
                            keep_extra_columns=True,
                        )
                        # Accumulate the rows sampled for each condition.
                        sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)

                is_reject_sampling = bool(
                    hasattr(synthesizer, '_model')
                    and not isinstance(synthesizer._model, GaussianMultivariate)
                )
                check_num_rows(
                    num_rows=len(sampled),
                    expected_num_rows=num_rows,
                    is_reject_sampling=is_reject_sampling,
                    max_tries_per_batch=max_tries_per_batch,
                )
            except (Exception, KeyboardInterrupt) as error:
                handle_sampling_error(output_file_path, error)

            sampled_data[table_name] = sampled
        else:
            sampled_rows = synthesizer._sample_batch(num_rows, keep_extra_columns=True)
            sampled_data[table_name] = sampled_rows
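
For reference, the functools.reduce above simply totals the rows requested across all conditions. A small standalone check (the condition values are illustrative only):

import functools

from sdv.sampling import Condition

conditions = [
    Condition(num_rows=50, column_values={'has_rewards': True}),
    Condition(num_rows=25, column_values={'has_rewards': False}),
]

# Same reduction as in _sample_table: accumulate get_num_rows() per condition.
total_rows = functools.reduce(
    lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0
)
assert total_rows == 75  # equivalent to sum(c.get_num_rows() for c in conditions)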

    def _connect_tables(self, sampled_data):
        """Connect all related tables.
@@ -123,7 +171,8 @@ def _finalize(self, sampled_data):

        return final_data

    def _sample(self, scale=1.0, conditions=None, max_tries_per_batch=100,
                batch_size=None, output_file_path=None):
"""Sample the entire dataset.
Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
@@ -146,17 +195,23 @@ def _sample(self, scale=1.0):
"""
sampled_data = {}
send_min_sample_warning = False
        conditions = conditions if conditions else {}
        for table_name in self.metadata.tables:
            num_rows = int(self._table_sizes[table_name] * scale)
            if num_rows <= 0:
                send_min_sample_warning = True
                num_rows = 1

            synthesizer = self._table_synthesizers[table_name]
            self._sample_table(
                synthesizer=synthesizer,
                table_name=table_name,
                num_rows=num_rows,
                sampled_data=sampled_data,
                conditions=conditions.get(table_name),
                max_tries_per_batch=max_tries_per_batch,
                batch_size=batch_size,
                output_file_path=output_file_path,
            )

        if send_min_sample_warning:
@@ -168,3 +223,15 @@

        self._connect_tables(sampled_data)
        return self._finalize(sampled_data)

    def sample_from_conditions(
        self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None
    ):
        """Sample rows with the given conditions, keyed by table name."""
        return self._sample(
            conditions=conditions,
            max_tries_per_batch=max_tries_per_batch,
            batch_size=batch_size,
            output_file_path=output_file_path,
        )

43 changes: 11 additions & 32 deletions sdv/single_table/base.py
@@ -644,6 +644,7 @@ def _sample_rows(
"""
if self._model and not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)

        need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns
        if self._model and need_sample:
            if conditions is None:
@@ -653,6 +654,7 @@ def _sample_rows(
                    raw_sampled = self._sample(num_rows, transformed_conditions)
                except NotImplementedError:
                    raw_sampled = self._sample(num_rows)

            sampled = self._data_processor.reverse_transform(raw_sampled)
            if keep_extra_columns:
                input_columns = self._data_processor._hyper_transformer._input_columns
@@ -663,6 +665,7 @@ def _sample_rows(

        if previous_rows is not None:
            sampled = pd.concat([previous_rows, sampled], ignore_index=True)

        sampled = self._data_processor.filter_valid(sampled)

        if conditions is not None:
@@ -817,6 +820,7 @@ def _sample_in_batches(
        float_rtol=0.01,
        progress_bar=None,
        output_file_path=None,
        keep_extra_columns=False,
    ):
        sampled = []
        batch_size = batch_size if num_rows > batch_size else num_rows
@@ -829,6 +833,7 @@
                float_rtol=float_rtol,
                progress_bar=progress_bar,
                output_file_path=output_file_path,
                keep_extra_columns=keep_extra_columns,
            )
            sampled.append(sampled_rows)

@@ -846,6 +851,7 @@ def _conditionally_sample_rows(
        graceful_reject_sampling=True,
        progress_bar=None,
        output_file_path=None,
        keep_extra_columns=False,
    ):
        batch_size = batch_size or len(dataframe)
        sampled_rows = self._sample_in_batches(
@@ -857,6 +863,7 @@
            float_rtol=float_rtol,
            progress_bar=progress_bar,
            output_file_path=output_file_path,
            keep_extra_columns=keep_extra_columns,
        )

        if len(sampled_rows) > 0:
@@ -969,9 +976,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file

        return sampled_data

    def _sample_with_conditions(
        self, conditions, max_tries_per_batch, batch_size,
        progress_bar=None, output_file_path=None, keep_extra_columns=False,
    ):
"""Sample rows with conditions.
Args:
@@ -1038,6 +1044,7 @@ def _sample_with_conditions(
                    batch_size=batch_size,
                    progress_bar=progress_bar,
                    output_file_path=output_file_path,
                    keep_extra_columns=keep_extra_columns,
                )
                all_sampled_rows.append(sampled_rows)
            else:
@@ -1057,6 +1064,7 @@
                        batch_size=batch_size,
                        progress_bar=progress_bar,
                        output_file_path=output_file_path,
                        keep_extra_columns=keep_extra_columns,
                    )
                    all_sampled_rows.append(sampled_rows)
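
Taken together, the signature changes in this file thread a single flag down the sampling stack. Schematically (abridged; `_sample_batch` is collapsed in this diff, so its position in the chain is inferred):

# sample_from_conditions
#   -> _sample_with_conditions(..., keep_extra_columns=...)
#     -> _conditionally_sample_rows(..., keep_extra_columns=...)
#       -> _sample_in_batches(..., keep_extra_columns=...)
#         -> _sample_batch -> _sample_rows(..., keep_extra_columns=...)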

@@ -1096,41 +1104,12 @@ def _validate_conditions(self, conditions):
    def sample_from_conditions(
        self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None
    ):
"""Sample rows from this table with the given conditions.
Args:
conditions (list[sdv.sampling.Condition]):
A list of sdv.sampling.Condition objects, which specify the column
values in a condition, along with the number of rows for that
condition.
max_tries_per_batch (int):
Number of times to retry sampling until the batch size is met. Defaults to 100.
batch_size (int):
The batch size to use per sampling call.
output_file_path (str or None):
The file to periodically write sampled rows to. Defaults to None.
Returns:
pandas.DataFrame:
Sampled data.
Raises:
ConstraintsNotMetError:
If the conditions are not valid for the given constraints.
ValueError:
If any of the following happens:
* any of the conditions' columns are not valid.
* no rows could be generated.
"""
        output_file_path = validate_file_path(output_file_path)

        num_rows = functools.reduce(
            lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0
        )

        conditions = self._make_condition_dfs(conditions)
        self._validate_conditions(conditions)

        sampled = pd.DataFrame()
        try:
            with tqdm.tqdm(total=num_rows) as progress_bar:
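One way to read the keep_extra_columns branch in _sample_rows above: columns the hyper transformer received on input but that reverse_transform does not emit are copied back from the raw model output. The remainder of that branch is collapsed in this diff, so the following is an assumed illustration rather than the commit's literal code:

import pandas as pd

# Assumed stand-ins for the objects inside _sample_rows.
raw_sampled = pd.DataFrame({'age': [31, 45], 'household_id': [7, 7]})
sampled = pd.DataFrame({'age': [31, 45]})  # result of reverse_transform
input_columns = ['age', 'household_id']  # _hyper_transformer._input_columns

# Re-attach any input column that reverse_transform dropped, e.g. foreign keys.
missing = [column for column in input_columns if column not in sampled.columns]
sampled[missing] = raw_sampled[missing]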
