Skip to content

Commit

Permalink
move to mixin
Browse files: browse the repository at this point in the history
  • Loading branch information
frances-h committed Nov 20, 2023
1 parent 5eaf1f8 commit 5fdae3b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 38 deletions.
3 changes: 2 additions & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.utils import GANMixin
from sdv.utils import cast_to_iterable, groupby_list

LOGGER = logging.getLogger(__name__)


class PARSynthesizer(BaseSynthesizer):
class PARSynthesizer(GANMixin, BaseSynthesizer):
"""Synthesizer for sequential data.
This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model.
Expand Down
39 changes: 3 additions & 36 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import numpy as np
from ctgan import CTGAN, TVAE

from sdv.errors import NotFittedError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import detect_discrete_columns
from sdv.single_table.utils import GANMixin, detect_discrete_columns


class CTGANSynthesizer(BaseSingleTableSynthesizer):
class CTGANSynthesizer(GANMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``CTGAN`` model.
Args:
Expand Down Expand Up @@ -189,22 +188,6 @@ def _fit(self, processed_data):
self._model = CTGAN(**self._model_kwargs)
self._model.fit(processed_data, discrete_columns=discrete_columns)

def get_loss_values(self):
    """Get the loss values from the model.

    Reads the ``_fitted`` flag and, once fitted, ``_model.loss_values``.

    Raises:
        - ``NotFittedError`` if synthesizer has not been fitted.

    Returns:
        pd.DataFrame:
            Dataframe containing the loss values per epoch.
    """
    if not self._fitted:
        err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
        raise NotFittedError(err_msg)

    # Return a copy rather than the object stored on the model itself.
    return self._model.loss_values.copy()

def _sample(self, num_rows, conditions=None):
"""Sample the indicated number of rows from the model.
Expand All @@ -226,7 +209,7 @@ def _sample(self, num_rows, conditions=None):
raise NotImplementedError("CTGANSynthesizer doesn't support conditional sampling.")


class TVAESynthesizer(BaseSingleTableSynthesizer):
class TVAESynthesizer(GANMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``TVAE`` model.
Args:
Expand Down Expand Up @@ -308,22 +291,6 @@ def _fit(self, processed_data):
self._model = TVAE(**self._model_kwargs)
self._model.fit(processed_data, discrete_columns=discrete_columns)

def get_loss_values(self):
    """Get the loss values from the model.

    Reads the ``_fitted`` flag and, once fitted, ``_model.loss_values``.

    Raises:
        - ``NotFittedError`` if synthesizer has not been fitted.

    Returns:
        pd.DataFrame:
            Dataframe containing the loss values per epoch.
    """
    if not self._fitted:
        err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
        raise NotFittedError(err_msg)

    # Return a copy rather than the object stored on the model itself.
    return self._model.loss_values.copy()

def _sample(self, num_rows, conditions=None):
"""Sample the indicated number of rows from the model.
Expand Down
22 changes: 21 additions & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from sdv.errors import SynthesizerInputError
from sdv.errors import NotFittedError, SynthesizerInputError

TMP_FILE_NAME = '.sample.csv.temp'
DISABLE_TMP_FILE = 'disable'
Expand Down Expand Up @@ -349,3 +349,23 @@ def log_numerical_distributions_error(numerical_distributions, processed_data_co
f"cannot be applied to column '{column}' because it no longer "
'exists after preprocessing.'
)


class GANMixin:
    """Mixin class for GAN-based synthesizers.

    The mixin defines no state of its own. ``get_loss_values`` reads the
    ``_fitted`` flag and the ``_model`` attribute from the instance, so the
    class it is mixed into must provide both.
    """

    def get_loss_values(self):
        """Get the loss values from the model.

        Raises:
            - ``NotFittedError`` if synthesizer has not been fitted.

        Returns:
            pd.DataFrame:
                Dataframe containing the loss values per epoch.
        """
        if not self._fitted:
            err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
            raise NotFittedError(err_msg)

        # Return a copy rather than the object stored on the model itself.
        return self._model.loss_values.copy()

0 comments on commit 5fdae3b

Please sign in to comment.