Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow me to access loss values for GAN-based synthesizers #1681

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,34 @@ def get_learned_distributions(self, table_name):
f"table because it uses the '{synthesizer.__class__.__name__}'."
)

def get_loss_values(self, table_name):
"""Get the loss values from a model for a table.

Return a pandas dataframe mapping of the loss values per epoch of GAN
based synthesizers

Args:
table_name (str):
Table name for which the parameters should be retrieved.

Returns:
pd.DataFrame:
Dataframe of loss values per epoch
"""
if table_name not in self._table_synthesizers:
raise ValueError(
f"Table '{table_name}' is not present in the metadata."
)

synthesizer = self._table_synthesizers[table_name]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we error if the table name is bad? I think there is an error message we use in a lot of places for that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we usually use _validate_table_name but the full error message doesn't really make sense in this context, so I used an edited version.

if hasattr(synthesizer, 'get_loss_values'):
return synthesizer.get_loss_values()

raise SynthesizerInputError(
f"Loss values are not available for table '{table_name}' "
'because the table does not use a GAN-based model.'
)

def _validate_constraints_to_be_added(self, constraints):
for constraint_dict in constraints:
if 'table_name' not in constraint_dict.keys():
Expand Down
3 changes: 2 additions & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.ctgan import LossValuesMixin
from sdv.utils import cast_to_iterable, groupby_list

LOGGER = logging.getLogger(__name__)


class PARSynthesizer(BaseSynthesizer):
class PARSynthesizer(LossValuesMixin, BaseSynthesizer):
"""Synthesizer for sequential data.

This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model.
Expand Down
25 changes: 23 additions & 2 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,32 @@
import numpy as np
from ctgan import CTGAN, TVAE

from sdv.errors import NotFittedError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import detect_discrete_columns


class CTGANSynthesizer(BaseSingleTableSynthesizer):
class LossValuesMixin:
"""Mixin for accessing loss values from synthesizers."""

def get_loss_values(self):
"""Get the loss values from the model.

Raises:
- ``NotFittedError`` if synthesizer has not been fitted.

Returns:
pd.DataFrame:
Dataframe containing the loss values per epoch.
"""
if not self._fitted:
err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
raise NotFittedError(err_msg)

return self._model.loss_values.copy()


class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``CTGAN`` model.

Args:
Expand Down Expand Up @@ -209,7 +230,7 @@ def _sample(self, num_rows, conditions=None):
raise NotImplementedError("CTGANSynthesizer doesn't support conditional sampling.")


class TVAESynthesizer(BaseSingleTableSynthesizer):
class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``TVAE`` model.

Args:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"pandas>=1.5.0;python_version>='3.11'",
'tqdm>=4.15,<5',
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.4,<0.8',
'ctgan>=0.8,<0.9',
'deepecho>=0.5,<0.6',
'rdt>=1.9.0,<2',
'sdmetrics>=0.12.1,<0.13',
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_par():
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()
assert (sampled.notna().sum(axis=1) != 0).all()
loss_values = model.get_loss_values()
assert len(loss_values) == 1
assert list(loss_values.columns) == ['Epoch', 'Loss']


def test_column_after_date_simple():
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def test_synthesize_table_ctgan(tmp_path):
assert len(synthetic_data) == 500
for column in sensitive_columns:
assert synthetic_data[column].isin(real_data[column]).sum() == 0
loss_values = synthesizer.get_loss_values()
assert list(loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(loss_values) == 300
custom_loss_values = custom_synthesizer.get_loss_values()
assert list(custom_loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(custom_loss_values) == 100

# Assert - evaluate
assert quality_report.get_score() > 0
Expand Down
40 changes: 39 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import pytest

from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SynthesizerInputError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
Expand Down Expand Up @@ -850,6 +850,44 @@ def test_get_learned_distributions_raises_non_parametric_error(self):
with pytest.raises(SynthesizerInputError, match=msg):
instance.get_learned_distributions('nesreca')

def test_get_loss_values_bad_table_name(self):
"""Test the ``get_loss_values`` errors if bad ``table_name`` provided."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)

# Run and Assert
error_msg = "Table 'bad_table' is not present in the metadata."
with pytest.raises(ValueError, match=error_msg):
instance.get_loss_values('bad_table')

def test_get_loss_values_unfitted_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance._table_synthesizers['nesreca'] = CTGANSynthesizer(metadata.tables['nesreca'])

# Run and Assert
error_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=error_msg):
instance.get_loss_values('nesreca')

def test_get_loss_values_unsupported_synthesizer_error(self):
"""Test the ``get_loss_values`` errors if synthesizer not GAN-based."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance._table_synthesizers['nesreca']._fitted = True

# Run and Assert
msg = re.escape(
"Loss values are not available for table 'nesreca' "
'because the table does not use a GAN-based model.'
)
with pytest.raises(SynthesizerInputError, match=msg):
instance.get_loss_values('nesreca')

def test_add_constraint_warning(self):
"""Test a warning is raised when the synthesizer had already been fitted."""
# Setup
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rdt.transformers import FloatFormatter, UnixTimestampEncoder

from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
Expand Down Expand Up @@ -505,6 +505,37 @@ def test__fit_without_sequence_key(self):
par._fit_context_model.assert_not_called()
par._fit_sequence_columns.assert_called_once_with(data)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``PARSynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = PARSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = PARSynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()

@patch('sdv.sequential.par.tqdm')
def test__sample_from_par(self, tqdm_mock):
"""Test that the method properly samples from the underlying ``PAR`` model.
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import numpy as np
import pandas as pd
import pytest

from sdv.errors import NotFittedError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer

Expand Down Expand Up @@ -226,6 +228,37 @@ def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``CTGANSynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = CTGANSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = CTGANSynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()


class TestTVAESynthesizer:

Expand Down Expand Up @@ -359,3 +392,34 @@ def test__fit(self, mock_detect_discrete_columns, mock_tvae):
processed_data,
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``TVAESynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = TVAESynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = TVAESynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()
Loading