Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Jan 31, 2024
1 parent 4dba291 commit c436cf3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
36 changes: 23 additions & 13 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import uuid
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
Expand All @@ -15,10 +14,10 @@

from sdv.errors import SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling import Condition
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.ctgan import LossValuesMixin
from sdv.sampling import Condition
from sdv.utils import cast_to_iterable, groupby_list

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,7 +77,7 @@ def _get_context_metadata(self):

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]

for column, column_metadata in self._extra_context_columns.items():
context_columns_dict[column] = column_metadata

Expand Down Expand Up @@ -178,12 +177,14 @@ def _preprocess(self, data):
self.auto_assign_transformers(data)

self.update_transformers(sequence_key_transformers)

preprocessed = super()._preprocess(data)

if self._sequence_index:
sequence_index = preprocessed[self._sequence_key + [self._sequence_index]]
sequence_index_context = sequence_index.groupby(self._sequence_key).agg('first')
sequence_index_context.rename(columns={self._sequence_index: f'{self._sequence_index}.context'}, inplace=True)
sequence_index_context = sequence_index_context.rename(
columns={self._sequence_index: f'{self._sequence_index}.context'}
)
if all(sequence_index[self._sequence_key].nunique() == 1):
sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill()
else:
Expand All @@ -192,12 +193,19 @@ def _preprocess(self, data):
).droplevel(1).reset_index()

preprocessed[self._sequence_index] = sequence_index_sequence[self._sequence_index]
preprocessed = preprocessed.merge(sequence_index_context, left_on=self._sequence_key, right_index=True)
preprocessed = preprocessed.merge(
sequence_index_context,
left_on=self._sequence_key,
right_index=True)

self.extended_columns[self._sequence_index] = FloatFormatter(
enforce_min_max_values=True)
self.extended_columns[self._sequence_index].fit(
sequence_index_sequence, self._sequence_index)
self._extra_context_columns[f'{self._sequence_index}.context'] = {
'sdtype': 'numerical'
}

self.extended_columns[self._sequence_index] = FloatFormatter(enforce_min_max_values=True)
self.extended_columns[self._sequence_index].fit(sequence_index_sequence, self._sequence_index)
self._extra_context_columns[f'{self._sequence_index}.context'] = {'sdtype': 'numerical'}

return preprocessed

def update_transformers(self, column_name_to_transformer):
Expand All @@ -221,7 +229,8 @@ def _fit_context_model(self, transformed):
LOGGER.debug(f'Fitting context synthesizer {self._context_synthesizer.__class__.__name__}')
if self.context_columns or self._extra_context_columns:
context_cols = (
self._sequence_key + self.context_columns + list(self._extra_context_columns.keys())
self._sequence_key + self.context_columns +
list(self._extra_context_columns.keys())
)
context = transformed[context_cols]
else:
Expand All @@ -248,7 +257,8 @@ def _fit_sequence_columns(self, timeseries_data):
column
for column in timeseries_data.columns
if column not in (
self._sequence_key + self.context_columns + list(self._extra_context_columns.keys())
self._sequence_key + self.context_columns +
list(self._extra_context_columns.keys())
)
]

Expand Down Expand Up @@ -405,7 +415,7 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
"This synthesizer does not have any context columns. Please use 'sample()' "
'to sample new sequences.'
)

condition_columns = list(set.intersection(
set(context_columns.columns), set(self._context_synthesizer._model.columns)
))
Expand Down
23 changes: 14 additions & 9 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling import Condition
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from tests.utils import DataFrameMatcher
from sdv.sampling import Condition


class TestPARSynthesizer:
Expand Down Expand Up @@ -313,12 +313,12 @@ def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
par = PARSynthesizer(metadata, context_columns=['gender'])
initial_synthesizer = Mock()
context_metadata = SingleTableMetadata.load_from_dict({
"columns": {
"gender": {
"sdtype": "categorical"
'columns': {
'gender': {
'sdtype': 'categorical'
},
"name": {
"sdtype": "id"
'name': {
'sdtype': 'id'
}
}
})
Expand Down Expand Up @@ -434,7 +434,10 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock
)
sequences = [
{'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]},
{'context': np.array(['M'], dtype=object), 'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]]},
{
'context': np.array(['M'], dtype=object),
'data': [[2, 2, 3], [65, 65, 70], [3, 3, 3]]
},
]
assemble_sequences_mock.return_value = sequences

Expand Down Expand Up @@ -664,7 +667,8 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock):
model_mock = Mock()
par._model = model_mock
mock_transformer = Mock()
mock_transformer.reverse_transform.return_value = pd.DataFrame({'time': [1000, 2000, 2000]})
mock_transformer.reverse_transform.return_value = pd.DataFrame(
{'time': [1000, 2000, 2000]})
par.extended_columns = {'time': mock_transformer}
par._data_columns = ['time', 'measurement']
par._output_columns = ['time', 'gender', 'name', 'measurement']
Expand All @@ -673,7 +677,8 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock):
[1000, 2000, 2000],
[55, 60, 65]
]
context_columns = pd.DataFrame({'name': ['John'], 'gender': ['M'], 'time.context': [18000]})
context_columns = pd.DataFrame(
{'name': ['John'], 'gender': ['M'], 'time.context': [18000]})
tqdm_mock.tqdm.return_value = context_columns.set_index('name').iterrows()

# Run
Expand Down

0 comments on commit c436cf3

Please sign in to comment.