diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py
index a7a21abf4..3a05f281e 100644
--- a/sdv/sequential/par.py
+++ b/sdv/sequential/par.py
@@ -15,12 +15,13 @@
 from sdv.metadata.single_table import SingleTableMetadata
 from sdv.single_table import GaussianCopulaSynthesizer
 from sdv.single_table.base import BaseSynthesizer
+from sdv.single_table.utils import GANMixin
 from sdv.utils import cast_to_iterable, groupby_list
 
 LOGGER = logging.getLogger(__name__)
 
 
-class PARSynthesizer(BaseSynthesizer):
+class PARSynthesizer(GANMixin, BaseSynthesizer):
     """Synthesizer for sequential data.
 
     This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model.
diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py
index 04b188daa..668457c0f 100644
--- a/sdv/single_table/ctgan.py
+++ b/sdv/single_table/ctgan.py
@@ -2,12 +2,11 @@
 import numpy as np
 from ctgan import CTGAN, TVAE
 
-from sdv.errors import NotFittedError
 from sdv.single_table.base import BaseSingleTableSynthesizer
-from sdv.single_table.utils import detect_discrete_columns
+from sdv.single_table.utils import GANMixin, detect_discrete_columns
 
 
-class CTGANSynthesizer(BaseSingleTableSynthesizer):
+class CTGANSynthesizer(GANMixin, BaseSingleTableSynthesizer):
     """Model wrapping ``CTGAN`` model.
 
     Args:
@@ -189,22 +188,6 @@ def _fit(self, processed_data):
         self._model = CTGAN(**self._model_kwargs)
         self._model.fit(processed_data, discrete_columns=discrete_columns)
 
-    def get_loss_values(self):
-        """Get the loss values from the model.
-
-        Raises:
-            - ``NotFittedError`` if synthesizer has not been fitted.
-
-        Returns:
-            pd.DataFrame:
-                Dataframe containing the loss values per epoch.
-        """
-        if not self._fitted:
-            err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
-            raise NotFittedError(err_msg)
-
-        return self._model.loss_values.copy()
-
     def _sample(self, num_rows, conditions=None):
         """Sample the indicated number of rows from the model.
 
@@ -226,7 +209,7 @@ def _sample(self, num_rows, conditions=None):
         raise NotImplementedError("CTGANSynthesizer doesn't support conditional sampling.")
 
 
-class TVAESynthesizer(BaseSingleTableSynthesizer):
+class TVAESynthesizer(GANMixin, BaseSingleTableSynthesizer):
     """Model wrapping ``TVAE`` model.
 
     Args:
@@ -308,22 +291,6 @@ def _fit(self, processed_data):
         self._model = TVAE(**self._model_kwargs)
         self._model.fit(processed_data, discrete_columns=discrete_columns)
 
-    def get_loss_values(self):
-        """Get the loss values from the model.
-
-        Raises:
-            - ``NotFittedError`` if synthesizer has not been fitted.
-
-        Returns:
-            pd.DataFrame:
-                Dataframe containing the loss values per epoch.
-        """
-        if not self._fitted:
-            err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
-            raise NotFittedError(err_msg)
-
-        return self._model.loss_values.copy()
-
     def _sample(self, num_rows, conditions=None):
         """Sample the indicated number of rows from the model.
 
diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py
index 11043ca2f..25d4516be 100644
--- a/sdv/single_table/utils.py
+++ b/sdv/single_table/utils.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 
-from sdv.errors import SynthesizerInputError
+from sdv.errors import NotFittedError, SynthesizerInputError
 
 TMP_FILE_NAME = '.sample.csv.temp'
 DISABLE_TMP_FILE = 'disable'
@@ -349,3 +349,23 @@ def log_numerical_distributions_error(numerical_distributions, processed_data_co
             f"cannot be applied to column '{column}' because it no longer "
             'exists after preprocessing.'
         )
+
+
+class GANMixin:
+    """Mixin exposing the wrapped model's loss values (used by GAN and non-GAN synthesizers)."""
+
+    def get_loss_values(self):
+        """Get the loss values from the model.
+
+        Raises:
+            - ``NotFittedError`` if synthesizer has not been fitted.
+
+        Returns:
+            pd.DataFrame:
+                Dataframe containing the loss values per epoch.
+        """
+        if not self._fitted:
+            err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
+            raise NotFittedError(err_msg)
+
+        return self._model.loss_values.copy()