Skip to content

Commit

Permalink
fix: replaced exception with warning when multiple series, retrain=Tr…
Browse files Browse the repository at this point in the history
…ue and data transformer defined with global_fit=True
  • Loading branch information
madtoinou committed Nov 25, 2024
1 parent 4daf863 commit a633564
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 40 deletions.
63 changes: 31 additions & 32 deletions darts/tests/utils/historical_forecasts/test_historical_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2860,8 +2860,9 @@ def test_historical_forecasts_with_scaler(self, params):
else:
self.helper_compare_hf(hf_auto, [manual_hf_0, manual_hf_1])

def test_historical_forecasts_with_scaler_errors(self):
"""Check that the appropriate exception is raised when providing incorrect parameters."""
def test_historical_forecasts_with_scaler_errors(self, caplog):
"""Check that the appropriate exception is raised when providing incorrect parameters or the expected
warning is display in the corner cases."""
ocl = 2
hf_args = {
"start": -ocl - 1,
Expand All @@ -2874,7 +2875,7 @@ def test_historical_forecasts_with_scaler_errors(self):

# retrain=False and unfitted data transformers
with pytest.raises(ValueError) as err:
_ = model.historical_forecasts(
model.historical_forecasts(
**hf_args,
series=self.sine_univariate1,
data_transformers={"series": Scaler()},
Expand All @@ -2886,7 +2887,7 @@ def test_historical_forecasts_with_scaler_errors(self):

# retrain=False, multiple series not matching the fitted data transformers dimensions
with pytest.raises(ValueError) as err:
_ = model.historical_forecasts(
model.historical_forecasts(
**hf_args,
series=[self.sine_univariate1] * 2,
data_transformers={
Expand All @@ -2900,17 +2901,37 @@ def test_historical_forecasts_with_scaler_errors(self):
)

# retrain=True, multiple series and unfitted data transformers with global_fit=True
with pytest.raises(ValueError) as err:
_ = model.historical_forecasts(
expected_warning = (
"When `retrain=True` and multiple series are provided, the fittable `data_transformers` "
"are trained on each series independently (`global_fit=True` will be ignored)."
)
with caplog.at_level(logging.WARNING):
model.historical_forecasts(
**hf_args,
series=[self.sine_univariate1, self.sine_univariate2],
data_transformers={"series": Scaler(global_fit=True)},
retrain=True,
)
assert str(err.value).startswith(
"When `retrain=True` and multiple series are provided, all the fittable `data_transformers` "
"must be defined with `global_fit=False"
assert expected_warning in caplog.text

# data transformer (global_fit=False) prefitted on several series but only series is forecasted
expected_warning = (
"Provided only a single series, but at least one of the `data_transformers` "
"that use `global_fit=False` was fitted on multiple `TimeSeries`."
)
with caplog.at_level(logging.WARNING):
model.historical_forecasts(
**hf_args,
series=[self.sine_univariate2],
data_transformers={
"series": Scaler(global_fit=False).fit([
self.sine_univariate1,
self.sine_univariate2,
])
},
retrain=False,
)
assert expected_warning in caplog.text

@pytest.mark.parametrize("params", product([True, False], [True, False]))
def test_historical_forecasts_with_scaler_multiple_series(self, params):
Expand Down Expand Up @@ -2942,29 +2963,6 @@ def get_scaler(fit: bool):
else:
return Scaler(global_fit=global_fit)

# global fit is not supported with retrain and multiple series
if retrain and global_fit:
expected_msg = (
"When `retrain=True` and multiple series are provided, all the fittable `data_transformers` must "
"be defined with `global_fit=False`."
)
with pytest.raises(ValueError) as err:
_ = model.historical_forecasts(
**hf_args,
series=series,
data_transformers={"series": get_scaler(fit=True)},
)
assert str(err.value) == expected_msg

with pytest.raises(ValueError) as err:
_ = model.historical_forecasts(
**hf_args,
series=series,
data_transformers={"series": get_scaler(fit=False)},
)
assert str(err.value) == expected_msg
return

# using all the series used to fit the scaler
hf = model.historical_forecasts(
**hf_args,
Expand Down Expand Up @@ -3033,6 +3031,7 @@ def get_scaler(fit: bool):
)
self.helper_compare_hf(hf, [manual_hf_2])

# data_transformers are not pre-fitted
if retrain:
hf = model.historical_forecasts(
**hf_args,
Expand Down
13 changes: 5 additions & 8 deletions darts/utils/historical_forecasts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,10 @@ def _historical_forecasts_general_checks(model, series, kwargs):

if n.retrain:
# if more than one series is passed and the pipelines are retrained, they cannot be global
if len(series) > 1 and len(global_fit_pipelines) > 0:
raise_log(
ValueError(
"When `retrain=True` and multiple series are provided, all the fittable "
"`data_transformers` must be defined with `global_fit=False`."
),
logger,
if n.show_warnings and len(series) > 1 and len(global_fit_pipelines) > 0:
logger.warning(
"When `retrain=True` and multiple series are provided, the fittable `data_transformers` "
"are trained on each series independently (`global_fit=True` will be ignored)."
)
else:
# must already be fitted without retraining
Expand Down Expand Up @@ -283,7 +280,7 @@ def _historical_forecasts_general_checks(model, series, kwargs):
else:
# at least one pipeline was fitted on several series with `global_fit=False` but only
# one series was passed
if max(fitted_params_pipelines) > 1:
if n.show_warnings and max(fitted_params_pipelines) > 1:
logger.warning(
"Provided only a single series, but at least one of the `data_transformers` "
"that use `global_fit=False` was fitted on multiple `TimeSeries`."
Expand Down

0 comments on commit a633564

Please sign in to comment.