Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow me to access loss values for GAN-based synthesizers #1681

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ def get_learned_distributions(self, table_name):
f"table because it uses the '{synthesizer.__class__.__name__}'."
)

def get_loss_values(self, table_name):
"""Get the loss values from a model for a table.

Return a pandas dataframe mapping of the loss values per epoch of GAN
based synthesizers

Args:
table_name (str):
Table name for which the parameters should be retrieved.

Returns:
pd.DataFrame:
Dataframe of loss values per epoch
"""
synthesizer = self._table_synthesizers[table_name]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise an error if the table name is invalid? I think there is an error message we use in a lot of places for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we usually use _validate_table_name but the full error message doesn't really make sense in this context, so I used an edited version.

if hasattr(synthesizer, 'get_loss_values'):
return synthesizer.get_loss_values()

raise SynthesizerInputError(
f"Loss values are not available for table '{table_name}' "
'because the table does not use a GAN-based model.'
)

def _validate_constraints_to_be_added(self, constraints):
for constraint_dict in constraints:
if 'table_name' not in constraint_dict.keys():
Expand Down
21 changes: 19 additions & 2 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from deepecho import PARModel
from deepecho.sequences import assemble_sequences

from sdv.errors import SamplingError, SynthesizerInputError
from sdv.errors import NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.utils import GANMixin
from sdv.utils import cast_to_iterable, groupby_list

LOGGER = logging.getLogger(__name__)


class PARSynthesizer(BaseSynthesizer):
class PARSynthesizer(GANMixin, BaseSynthesizer):
"""Synthesizer for sequential data.

This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model.
Expand Down Expand Up @@ -266,6 +267,22 @@ def _fit(self, processed_data):
LOGGER.debug(f'Fitting {self.__class__.__name__} model to table')
self._fit_sequence_columns(processed_data)

def get_loss_values(self):
    """Return the loss values recorded by the underlying model.

    Returns:
        pd.DataFrame:
            Copy of the model's per-epoch loss values, so callers cannot
            modify the frame stored on the model.

    Raises:
        NotFittedError:
            If the synthesizer has not been fitted yet.
    """
    if self._fitted:
        return self._model.loss_values.copy()

    raise NotFittedError(
        'Loss values are not available yet. Please fit your synthesizer first.'
    )
frances-h marked this conversation as resolved.
Show resolved Hide resolved

def _sample_from_par(self, context, sequence_length=None):
"""Sample new sequences.

Expand Down
6 changes: 3 additions & 3 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from ctgan import CTGAN, TVAE

from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import detect_discrete_columns
from sdv.single_table.utils import GANMixin, detect_discrete_columns


class CTGANSynthesizer(BaseSingleTableSynthesizer):
class CTGANSynthesizer(GANMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``CTGAN`` model.

Args:
Expand Down Expand Up @@ -209,7 +209,7 @@ def _sample(self, num_rows, conditions=None):
raise NotImplementedError("CTGANSynthesizer doesn't support conditional sampling.")


class TVAESynthesizer(BaseSingleTableSynthesizer):
class TVAESynthesizer(GANMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``TVAE`` model.

Args:
Expand Down
22 changes: 21 additions & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from sdv.errors import SynthesizerInputError
from sdv.errors import NotFittedError, SynthesizerInputError

TMP_FILE_NAME = '.sample.csv.temp'
DISABLE_TMP_FILE = 'disable'
Expand Down Expand Up @@ -349,3 +349,23 @@ def log_numerical_distributions_error(numerical_distributions, processed_data_co
f"cannot be applied to column '{column}' because it no longer "
'exists after preprocessing.'
)


class GANMixin:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you feel about putting this class in the ctgan file, since all the GANs are already there? Also, I think VAEs aren't technically GANs. Maybe we should change the name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to LossValuesMixin instead

"""Mixin class for GAN-based synthesizers."""

def get_loss_values(self):
    """Return the loss values recorded by the wrapped model.

    Returns:
        pd.DataFrame:
            Copy of the per-epoch loss values held in ``self._model.loss_values``.

    Raises:
        NotFittedError:
            If the synthesizer has not been fitted yet.
    """
    if self._fitted:
        # Hand back a copy so the frame stored on the model stays untouched.
        return self._model.loss_values.copy()

    raise NotFittedError(
        'Loss values are not available yet. Please fit your synthesizer first.'
    )
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"pandas>=1.5.0;python_version>='3.11'",
'tqdm>=4.15,<5',
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.4,<0.8',
'ctgan>=0.8,<0.9',
'deepecho>=0.5,<0.6',
'rdt>=1.9.0,<2',
'sdmetrics>=0.12.1,<0.13',
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_par():
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()
assert (sampled.notna().sum(axis=1) != 0).all()
loss_values = model.get_loss_values()
assert len(loss_values) == 1
assert list(loss_values.columns) == ['Epoch', 'Loss']


def test_column_after_date_simple():
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def test_synthesize_table_ctgan(tmp_path):
assert len(synthetic_data) == 500
for column in sensitive_columns:
assert synthetic_data[column].isin(real_data[column]).sum() == 0
loss_values = synthesizer.get_loss_values()
assert list(loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(loss_values) == 300
custom_loss_values = custom_synthesizer.get_loss_values()
assert list(custom_loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(custom_loss_values) == 100

# Assert - evaluate
assert quality_report.get_score() > 0
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import pytest

from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SynthesizerInputError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
Expand Down Expand Up @@ -850,6 +850,33 @@ def test_get_learned_distributions_raises_non_parametric_error(self):
with pytest.raises(SynthesizerInputError, match=msg):
instance.get_learned_distributions('nesreca')

def test_get_loss_values_unfitted_error(self):
    """Test ``get_loss_values`` errors if the table synthesizer is not fitted."""
    # Setup
    metadata = get_multi_table_metadata()
    instance = BaseMultiTableSynthesizer(metadata)
    instance._table_synthesizers['nesreca'] = CTGANSynthesizer(metadata.tables['nesreca'])

    # Run and Assert
    # ``match`` is applied as a regular expression, so escape the message to
    # make its periods match literally instead of acting as wildcards.
    error_msg = re.escape(
        'Loss values are not available yet. Please fit your synthesizer first.'
    )
    with pytest.raises(NotFittedError, match=error_msg):
        instance.get_loss_values('nesreca')

def test_get_loss_values_unsupported_synthesizer_error(self):
    """Test ``get_loss_values`` errors when the table synthesizer is not GAN-based."""
    # Setup
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())
    synthesizer._table_synthesizers['nesreca']._fitted = True
    expected_message = re.escape(
        "Loss values are not available for table 'nesreca' "
        'because the table does not use a GAN-based model.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        synthesizer.get_loss_values('nesreca')

def test_add_constraint_warning(self):
"""Test a warning is raised when the synthesizer had already been fitted."""
# Setup
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rdt.transformers import FloatFormatter, UnixTimestampEncoder

from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
Expand Down Expand Up @@ -505,6 +505,37 @@ def test__fit_without_sequence_key(self):
par._fit_context_model.assert_not_called()
par._fit_sequence_columns.assert_called_once_with(data)

def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``PARSynthesizer``."""
    # Setup
    expected = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    synthesizer = PARSynthesizer(SingleTableMetadata())
    synthesizer._fitted = True
    synthesizer._model = Mock(loss_values=expected)

    # Run
    result = synthesizer.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(result, expected)

def test_get_loss_values_error(self):
    """Test ``get_loss_values`` raises ``NotFittedError`` on an unfitted ``PARSynthesizer``."""
    # Setup
    synthesizer = PARSynthesizer(SingleTableMetadata())
    expected_message = 'Loss values are not available yet. Please fit your synthesizer first.'

    # Run / Assert
    with pytest.raises(NotFittedError, match=expected_message):
        synthesizer.get_loss_values()

@patch('sdv.sequential.par.tqdm')
def test__sample_from_par(self, tqdm_mock):
"""Test that the method properly samples from the underlying ``PAR`` model.
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import numpy as np
import pandas as pd
import pytest

from sdv.errors import NotFittedError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer

Expand Down Expand Up @@ -226,6 +228,37 @@ def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``CTGANSynthesizer``."""
    # Setup
    expected = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    synthesizer = CTGANSynthesizer(SingleTableMetadata())
    synthesizer._fitted = True
    synthesizer._model = Mock(loss_values=expected)

    # Run
    result = synthesizer.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(result, expected)

def test_get_loss_values_error(self):
    """Test ``get_loss_values`` raises ``NotFittedError`` on an unfitted ``CTGANSynthesizer``."""
    # Setup
    synthesizer = CTGANSynthesizer(SingleTableMetadata())
    expected_message = 'Loss values are not available yet. Please fit your synthesizer first.'

    # Run / Assert
    with pytest.raises(NotFittedError, match=expected_message):
        synthesizer.get_loss_values()


class TestTVAESynthesizer:

Expand Down Expand Up @@ -359,3 +392,34 @@ def test__fit(self, mock_detect_discrete_columns, mock_tvae):
processed_data,
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``TVAESynthesizer``."""
    # Setup
    expected = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    synthesizer = TVAESynthesizer(SingleTableMetadata())
    synthesizer._fitted = True
    synthesizer._model = Mock(loss_values=expected)

    # Run
    result = synthesizer.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(result, expected)

def test_get_loss_values_error(self):
    """Test ``get_loss_values`` raises ``NotFittedError`` on an unfitted ``TVAESynthesizer``."""
    # Setup
    synthesizer = TVAESynthesizer(SingleTableMetadata())
    expected_message = 'Loss values are not available yet. Please fit your synthesizer first.'

    # Run / Assert
    with pytest.raises(NotFittedError, match=expected_message):
        synthesizer.get_loss_values()
Loading