From b6b39744a7986727a54e92cf87a550e136eec481 Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Thu, 1 Aug 2024 14:07:16 +0200
Subject: [PATCH] Testing conditional sampling on independent sampler

---
 sdv/sampling/independent_sampler.py | 83 ++++++++++++++++++++++++++---
 sdv/single_table/base.py            | 43 ++++-----------
 2 files changed, 86 insertions(+), 40 deletions(-)

diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py
index a86b0ddc0..0dcf4eab4 100644
--- a/sdv/sampling/independent_sampler.py
+++ b/sdv/sampling/independent_sampler.py
@@ -1,8 +1,15 @@
 """Independent Samplers."""
 
+import functools
 import logging
 import warnings
 
+import pandas as pd
+import tqdm
+from copulas.multivariate import GaussianMultivariate
+
+from sdv.single_table.utils import check_num_rows, handle_sampling_error
+
 LOGGER = logging.getLogger(__name__)
@@ -40,7 +47,8 @@ def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent
         """
         raise NotImplementedError()
 
-    def _sample_table(self, synthesizer, table_name, num_rows, sampled_data):
+    def _sample_table(self, synthesizer, table_name, num_rows, sampled_data, conditions=None,
+                      max_tries_per_batch=100, batch_size=None, output_file_path=None):
         """Sample a single table and all its children.
 
         Args:
@@ -54,9 +62,49 @@
             A dictionary mapping table names to sampled tables (pd.DataFrame).
         """
         LOGGER.info(f'Sampling {num_rows} rows from table {table_name}')
+        if conditions:
+            num_rows = functools.reduce(
+                lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0
+            )
+            conditions = synthesizer._make_condition_dfs(conditions)
+            synthesizer._validate_conditions(conditions)
+            sampled = pd.DataFrame()
+            try:
+                with tqdm.tqdm(total=num_rows) as progress_bar:
+                    progress_bar.set_description('Sampling conditions')
+                    for condition_dataframe in conditions:
+                        sampled_for_condition = synthesizer._sample_with_conditions(
+                            condition_dataframe,
+                            max_tries_per_batch,
+                            batch_size,
+                            progress_bar,
+                            output_file_path,
+                            keep_extra_columns=True,
+                        )
+                        sampled = pd.concat(
+                            [sampled, sampled_for_condition],
+                            ignore_index=True,
+                        )
-
-        sampled_rows = synthesizer._sample_batch(num_rows, keep_extra_columns=True)
-        sampled_data[table_name] = sampled_rows
+                is_reject_sampling = bool(
+                    hasattr(synthesizer, '_model')
+                    and not isinstance(synthesizer._model, GaussianMultivariate)
+                )
+                check_num_rows(
+                    num_rows=len(sampled),
+                    expected_num_rows=num_rows,
+                    is_reject_sampling=is_reject_sampling,
+                    max_tries_per_batch=max_tries_per_batch,
+                )
+
+            except (Exception, KeyboardInterrupt) as error:
+                handle_sampling_error(output_file_path, error)
+
+            sampled_data[table_name] = sampled
+
+        else:
+            sampled_rows = synthesizer._sample_batch(num_rows, keep_extra_columns=True)
+            sampled_data[table_name] = sampled_rows
 
     def _connect_tables(self, sampled_data):
         """Connect all related tables.
@@ -123,7 +171,8 @@ def _finalize(self, sampled_data):
 
         return final_data
 
-    def _sample(self, scale=1.0):
+    def _sample(self, scale=1.0, conditions=None, max_tries_per_batch=100,
+                batch_size=None, output_file_path=None):
         """Sample the entire dataset.
 
         Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
@@ -146,17 +195,23 @@
         """
         sampled_data = {}
         send_min_sample_warning = False
-        for table in self.metadata.tables:
-            num_rows = int(self._table_sizes[table] * scale)
+        conditions = conditions or {}
+        for table_name in self.metadata.tables:
+            num_rows = int(self._table_sizes[table_name] * scale)
             if num_rows <= 0:
                 send_min_sample_warning = True
                 num_rows = 1
-            synthesizer = self._table_synthesizers[table]
+
+            synthesizer = self._table_synthesizers[table_name]
             self._sample_table(
                 synthesizer=synthesizer,
-                table_name=table,
+                table_name=table_name,
                 num_rows=num_rows,
                 sampled_data=sampled_data,
+                conditions=conditions.get(table_name),
+                max_tries_per_batch=max_tries_per_batch,
+                batch_size=batch_size,
+                output_file_path=output_file_path,
             )
 
         if send_min_sample_warning:
@@ -168,3 +223,14 @@
             self._connect_tables(sampled_data)
 
         return self._finalize(sampled_data)
+
+    def sample_from_conditions(
+        self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None
+    ):
+        """Sample rows with the given conditions."""
+        return self._sample(
+            conditions=conditions,
+            max_tries_per_batch=max_tries_per_batch,
+            batch_size=batch_size,
+            output_file_path=output_file_path,
+        )
diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py
index 77ab31ce0..e4c2680b4 100644
--- a/sdv/single_table/base.py
+++ b/sdv/single_table/base.py
@@ -644,6 +644,7 @@ def _sample_rows(
         """
         if self._model and not self._random_state_set:
             self._set_random_state(FIXED_RNG_SEED)
+
         need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns
         if self._model and need_sample:
             if conditions is None:
@@ -653,6 +654,7 @@
                 raw_sampled = self._sample(num_rows, transformed_conditions)
             except NotImplementedError:
                 raw_sampled = self._sample(num_rows)
+
             sampled = self._data_processor.reverse_transform(raw_sampled)
             if keep_extra_columns:
                 input_columns = self._data_processor._hyper_transformer._input_columns
@@ -663,6 +665,7 @@
 
             if previous_rows is not None:
                 sampled = pd.concat([previous_rows, sampled], ignore_index=True)
+
             sampled = self._data_processor.filter_valid(sampled)
 
             if conditions is not None:
@@ -817,6 +820,7 @@ def _sample_in_batches(
         float_rtol=0.01,
         progress_bar=None,
         output_file_path=None,
+        keep_extra_columns=False,
     ):
         sampled = []
         batch_size = batch_size if num_rows > batch_size else num_rows
@@ -829,6 +833,7 @@
                 float_rtol=float_rtol,
                 progress_bar=progress_bar,
                 output_file_path=output_file_path,
+                keep_extra_columns=keep_extra_columns,
             )
             sampled.append(sampled_rows)
 
@@ -846,6 +851,7 @@ def _conditionally_sample_rows(
         graceful_reject_sampling=True,
         progress_bar=None,
         output_file_path=None,
+        keep_extra_columns=False,
     ):
         batch_size = batch_size or len(dataframe)
         sampled_rows = self._sample_in_batches(
@@ -857,6 +863,7 @@
             float_rtol=float_rtol,
             progress_bar=progress_bar,
             output_file_path=output_file_path,
+            keep_extra_columns=keep_extra_columns,
        )
 
         if len(sampled_rows) > 0:
@@ -969,9 +976,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
 
         return sampled_data
 
-    def _sample_with_conditions(
-        self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None
-    ):
+    def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
+                                progress_bar=None, output_file_path=None, keep_extra_columns=False):
         """Sample rows with conditions.
 
         Args:
@@ -1038,6 +1044,7 @@ def _sample_with_conditions(
                     batch_size=batch_size,
                     progress_bar=progress_bar,
                     output_file_path=output_file_path,
+                    keep_extra_columns=keep_extra_columns,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
@@ -1057,6 +1064,7 @@
                     batch_size=batch_size,
                     progress_bar=progress_bar,
                     output_file_path=output_file_path,
+                    keep_extra_columns=keep_extra_columns,
                 )
                 all_sampled_rows.append(sampled_rows)
 
@@ -1096,41 +1104,12 @@ def _validate_conditions(self, conditions):
     def sample_from_conditions(
         self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None
     ):
-        """Sample rows from this table with the given conditions.
-
-        Args:
-            conditions (list[sdv.sampling.Condition]):
-                A list of sdv.sampling.Condition objects, which specify the column
-                values in a condition, along with the number of rows for that
-                condition.
-            max_tries_per_batch (int):
-                Number of times to retry sampling until the batch size is met. Defaults to 100.
-            batch_size (int):
-                The batch size to use per sampling call.
-            output_file_path (str or None):
-                The file to periodically write sampled rows to. Defaults to None.
-
-        Returns:
-            pandas.DataFrame:
-                Sampled data.
-
-        Raises:
-            ConstraintsNotMetError:
-                If the conditions are not valid for the given constraints.
-            ValueError:
-                If any of the following happens:
-                    * any of the conditions' columns are not valid.
-                    * no rows could be generated.
-        """
         output_file_path = validate_file_path(output_file_path)
-
         num_rows = functools.reduce(
             lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0
         )
-
         conditions = self._make_condition_dfs(conditions)
         self._validate_conditions(conditions)
-
         sampled = pd.DataFrame()
         try:
             with tqdm.tqdm(total=num_rows) as progress_bar:
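
Usage note (a sketch, not part of the patch): the `sample_from_conditions` entry
point added to the independent sampler above expects `conditions` as a
dictionary keyed by table name, because `_sample` resolves each table's
conditions with `conditions.get(table_name)`. A minimal example of how it could
be called, assuming `synthesizer` is an already-fitted multi-table synthesizer
that mixes in `BaseIndependentSampler`, and that its metadata contains a
hypothetical 'users' table with a 'country' column:

    from sdv.sampling import Condition

    # Per-table lists of Condition objects; the table and column names here
    # are illustrative placeholders.
    conditions = {
        'users': [
            Condition(column_values={'country': 'US'}, num_rows=50),
            Condition(column_values={'country': 'ES'}, num_rows=25),
        ],
    }

    # Conditioned tables are sampled per condition via the table synthesizer's
    # _sample_with_conditions and concatenated; unconditioned tables are
    # sampled at their scaled sizes. Returns a dict of table name -> DataFrame.
    sampled_tables = synthesizer.sample_from_conditions(
        conditions,
        max_tries_per_batch=100,
    )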