def get_loss_values(self, table_name):
    """Return the loss values recorded by the synthesizer of one table.

    The request is forwarded to the synthesizer stored for ``table_name``;
    whatever its own ``get_loss_values`` returns is handed back unchanged.

    Args:
        table_name (str):
            Name of the table whose loss values should be retrieved.

    Returns:
        pd.DataFrame:
            Dataframe of loss values per epoch, as produced by the table's
            synthesizer.

    Raises:
        ValueError:
            If ``table_name`` has no entry in ``self._table_synthesizers``.
        SynthesizerInputError:
            If the table's synthesizer has no ``get_loss_values`` attribute.
    """
    if table_name not in self._table_synthesizers:
        raise ValueError(f"Table '{table_name}' is not present in the metadata.")

    table_synthesizer = self._table_synthesizers[table_name]

    # NOTE(review): support is detected purely by the presence of the method, while
    # the message below speaks of "GAN-based" models -- confirm the wording is intended.
    if not hasattr(table_synthesizer, 'get_loss_values'):
        raise SynthesizerInputError(
            f"Loss values are not available for table '{table_name}' "
            'because the table does not use a GAN-based model.'
        )

    return table_synthesizer.get_loss_values()
class LossValuesMixin:
    """Mixin exposing the loss values stored on a synthesizer's model."""

    def get_loss_values(self):
        """Return a copy of the loss values held by the underlying model.

        Returns:
            pd.DataFrame:
                Dataframe containing the loss values per epoch.

        Raises:
            NotFittedError:
                If the synthesizer has not been fitted.
        """
        if self._fitted:
            # Hand out a copy so callers cannot modify the model's own record.
            return self._model.loss_values.copy()

        raise NotFittedError(
            'Loss values are not available yet. Please fit your synthesizer first.'
        )
Args: diff --git a/setup.py b/setup.py index 39b0dd92a..cf2eef61a 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "pandas>=1.5.0;python_version>='3.11'", 'tqdm>=4.15,<5', 'copulas>=0.9.0,<0.10', - 'ctgan>=0.7.4,<0.8', + 'ctgan>=0.8,<0.9', 'deepecho>=0.5,<0.6', 'rdt>=1.9.0,<2', 'sdmetrics>=0.12.1,<0.13', diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 64dd2ca8a..3846565c5 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -49,6 +49,9 @@ def test_par(): assert sampled.shape == data.shape assert (sampled.dtypes == data.dtypes).all() assert (sampled.notna().sum(axis=1) != 0).all() + loss_values = model.get_loss_values() + assert len(loss_values) == 1 + assert list(loss_values.columns) == ['Epoch', 'Loss'] def test_column_after_date_simple(): diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index a8bb1579a..2b5ec01de 100644 --- a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -105,6 +105,12 @@ def test_synthesize_table_ctgan(tmp_path): assert len(synthetic_data) == 500 for column in sensitive_columns: assert synthetic_data[column].isin(real_data[column]).sum() == 0 + loss_values = synthesizer.get_loss_values() + assert list(loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] + assert len(loss_values) == 300 + custom_loss_values = custom_synthesizer.get_loss_values() + assert list(custom_loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] + assert len(custom_loss_values) == 100 # Assert - evaluate assert quality_report.get_score() > 0 diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index cbcac3572..5c4ba737c 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -7,7 +7,7 @@ import pandas as pd import pytest -from sdv.errors import 
def test_get_loss_values_bad_table_name(self):
    """Test that ``get_loss_values`` errors if a bad ``table_name`` is provided."""
    # Setup
    metadata = get_multi_table_metadata()
    instance = BaseMultiTableSynthesizer(metadata)

    # Run and Assert
    # ``match`` is applied as a regular expression, so the literal message is
    # escaped to keep '.' from acting as a wildcard.
    error_msg = re.escape("Table 'bad_table' is not present in the metadata.")
    with pytest.raises(ValueError, match=error_msg):
        instance.get_loss_values('bad_table')


def test_get_loss_values_unfitted_error(self):
    """Test that ``get_loss_values`` errors if the synthesizer has not been fitted."""
    # Setup
    metadata = get_multi_table_metadata()
    instance = BaseMultiTableSynthesizer(metadata)
    instance._table_synthesizers['nesreca'] = CTGANSynthesizer(metadata.tables['nesreca'])

    # Run and Assert
    # Escaped for the same reason as above: the message contains literal '.' characters.
    error_msg = re.escape(
        'Loss values are not available yet. Please fit your synthesizer first.'
    )
    with pytest.raises(NotFittedError, match=error_msg):
        instance.get_loss_values('nesreca')
def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``PARSynthesizer``."""
    # Setup
    expected_loss_values = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    instance = PARSynthesizer(SingleTableMetadata())
    instance._model = Mock(loss_values=expected_loss_values)
    instance._fitted = True

    # Run
    returned_loss_values = instance.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(returned_loss_values, expected_loss_values)
def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``CTGANSynthesizer``."""
    # Setup
    expected_loss_values = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    instance = CTGANSynthesizer(SingleTableMetadata())
    instance._model = Mock(loss_values=expected_loss_values)
    instance._fitted = True

    # Run
    returned_loss_values = instance.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(returned_loss_values, expected_loss_values)
def test_get_loss_values(self):
    """Test the ``get_loss_values`` method from ``TVAESynthesizer``."""
    # Setup
    expected_loss_values = pd.DataFrame({
        'Epoch': [0, 1, 2],
        'Loss': [0.8, 0.6, 0.5]
    })
    instance = TVAESynthesizer(SingleTableMetadata())
    instance._model = Mock(loss_values=expected_loss_values)
    instance._fitted = True

    # Run
    returned_loss_values = instance.get_loss_values()

    # Assert
    pd.testing.assert_frame_equal(returned_loss_values, expected_loss_values)


def test_get_loss_values_error(self):
    """Test that ``get_loss_values`` errors if the synthesizer has not been fitted."""
    # Setup
    instance = TVAESynthesizer(SingleTableMetadata())

    # Run / Assert
    expected_message = 'Loss values are not available yet. Please fit your synthesizer first.'
    with pytest.raises(NotFittedError, match=expected_message):
        instance.get_loss_values()