Skip to content

Commit

Permalink
Allow me to access loss values for GAN-based synthesizers (#1681)
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h authored Nov 27, 2023
1 parent 8d94c1e commit ae7321f
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 6 deletions.
28 changes: 28 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,34 @@ def get_learned_distributions(self, table_name):
f"table because it uses the '{synthesizer.__class__.__name__}'."
)

def get_loss_values(self, table_name):
"""Get the loss values from a model for a table.
Return a pandas dataframe mapping of the loss values per epoch of GAN
based synthesizers
Args:
table_name (str):
Table name for which the parameters should be retrieved.
Returns:
pd.DataFrame:
Dataframe of loss values per epoch
"""
if table_name not in self._table_synthesizers:
raise ValueError(
f"Table '{table_name}' is not present in the metadata."
)

synthesizer = self._table_synthesizers[table_name]
if hasattr(synthesizer, 'get_loss_values'):
return synthesizer.get_loss_values()

raise SynthesizerInputError(
f"Loss values are not available for table '{table_name}' "
'because the table does not use a GAN-based model.'
)

def _validate_constraints_to_be_added(self, constraints):
for constraint_dict in constraints:
if 'table_name' not in constraint_dict.keys():
Expand Down
3 changes: 2 additions & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.ctgan import LossValuesMixin
from sdv.utils import cast_to_iterable, groupby_list

LOGGER = logging.getLogger(__name__)


class PARSynthesizer(BaseSynthesizer):
class PARSynthesizer(LossValuesMixin, BaseSynthesizer):
"""Synthesizer for sequential data.
This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model.
Expand Down
25 changes: 23 additions & 2 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,32 @@
import numpy as np
from ctgan import CTGAN, TVAE

from sdv.errors import NotFittedError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import detect_discrete_columns


class CTGANSynthesizer(BaseSingleTableSynthesizer):
class LossValuesMixin:
"""Mixin for accessing loss values from synthesizers."""

def get_loss_values(self):
"""Get the loss values from the model.
Raises:
- ``NotFittedError`` if synthesizer has not been fitted.
Returns:
pd.DataFrame:
Dataframe containing the loss values per epoch.
"""
if not self._fitted:
err_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
raise NotFittedError(err_msg)

return self._model.loss_values.copy()


class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``CTGAN`` model.
Args:
Expand Down Expand Up @@ -209,7 +230,7 @@ def _sample(self, num_rows, conditions=None):
raise NotImplementedError("CTGANSynthesizer doesn't support conditional sampling.")


class TVAESynthesizer(BaseSingleTableSynthesizer):
class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):
"""Model wrapping ``TVAE`` model.
Args:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"pandas>=1.5.0;python_version>='3.11'",
'tqdm>=4.15,<5',
'copulas>=0.9.0,<0.10',
'ctgan>=0.7.4,<0.8',
'ctgan>=0.8,<0.9',
'deepecho>=0.5,<0.6',
'rdt>=1.9.0,<2',
'sdmetrics>=0.12.1,<0.13',
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_par():
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()
assert (sampled.notna().sum(axis=1) != 0).all()
loss_values = model.get_loss_values()
assert len(loss_values) == 1
assert list(loss_values.columns) == ['Epoch', 'Loss']


def test_column_after_date_simple():
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def test_synthesize_table_ctgan(tmp_path):
assert len(synthetic_data) == 500
for column in sensitive_columns:
assert synthetic_data[column].isin(real_data[column]).sum() == 0
loss_values = synthesizer.get_loss_values()
assert list(loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(loss_values) == 300
custom_loss_values = custom_synthesizer.get_loss_values()
assert list(custom_loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
assert len(custom_loss_values) == 100

# Assert - evaluate
assert quality_report.get_score() > 0
Expand Down
40 changes: 39 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import pytest

from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SynthesizerInputError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
Expand Down Expand Up @@ -938,6 +938,44 @@ def test_get_learned_distributions_raises_non_parametric_error(self):
with pytest.raises(SynthesizerInputError, match=msg):
instance.get_learned_distributions('nesreca')

def test_get_loss_values_bad_table_name(self):
"""Test the ``get_loss_values`` errors if bad ``table_name`` provided."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)

# Run and Assert
error_msg = "Table 'bad_table' is not present in the metadata."
with pytest.raises(ValueError, match=error_msg):
instance.get_loss_values('bad_table')

def test_get_loss_values_unfitted_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance._table_synthesizers['nesreca'] = CTGANSynthesizer(metadata.tables['nesreca'])

# Run and Assert
error_msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=error_msg):
instance.get_loss_values('nesreca')

def test_get_loss_values_unsupported_synthesizer_error(self):
"""Test the ``get_loss_values`` errors if synthesizer not GAN-based."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance._table_synthesizers['nesreca']._fitted = True

# Run and Assert
msg = re.escape(
"Loss values are not available for table 'nesreca' "
'because the table does not use a GAN-based model.'
)
with pytest.raises(SynthesizerInputError, match=msg):
instance.get_loss_values('nesreca')

def test_add_constraint_warning(self):
"""Test a warning is raised when the synthesizer had already been fitted."""
# Setup
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rdt.transformers import FloatFormatter, UnixTimestampEncoder

from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
Expand Down Expand Up @@ -505,6 +505,37 @@ def test__fit_without_sequence_key(self):
par._fit_context_model.assert_not_called()
par._fit_sequence_columns.assert_called_once_with(data)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``PARSynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = PARSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = PARSynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()

@patch('sdv.sequential.par.tqdm')
def test__sample_from_par(self, tqdm_mock):
"""Test that the method properly samples from the underlying ``PAR`` model.
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import numpy as np
import pandas as pd
import pytest

from sdv.errors import NotFittedError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer

Expand Down Expand Up @@ -226,6 +228,37 @@ def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``CTGANSynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = CTGANSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = CTGANSynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()


class TestTVAESynthesizer:

Expand Down Expand Up @@ -359,3 +392,34 @@ def test__fit(self, mock_detect_discrete_columns, mock_tvae):
processed_data,
discrete_columns=mock_detect_discrete_columns.return_value
)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``TVAESynthesizer."""
# Setup
mock_model = Mock()
loss_values = pd.DataFrame({
'Epoch': [0, 1, 2],
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
instance = TVAESynthesizer(metadata)
instance._model = mock_model
instance._fitted = True

# Run
actual_loss_values = instance.get_loss_values()

# Assert
pd.testing.assert_frame_equal(actual_loss_values, loss_values)

def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
instance = TVAESynthesizer(metadata)

# Run / Assert
msg = 'Loss values are not available yet. Please fit your synthesizer first.'
with pytest.raises(NotFittedError, match=msg):
instance.get_loss_values()

0 comments on commit ae7321f

Please sign in to comment.