From a633564c6d653cb769c069c676469e235e1e7ea8 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 25 Nov 2024 10:32:39 +0100 Subject: [PATCH] fix: replaced exception with warning when multiple series, retrain=True and data transformer defined with global_fit=True --- .../test_historical_forecasts.py | 63 +++++++++---------- darts/utils/historical_forecasts/utils.py | 13 ++-- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/darts/tests/utils/historical_forecasts/test_historical_forecasts.py b/darts/tests/utils/historical_forecasts/test_historical_forecasts.py index 0249b51146..9e54638192 100644 --- a/darts/tests/utils/historical_forecasts/test_historical_forecasts.py +++ b/darts/tests/utils/historical_forecasts/test_historical_forecasts.py @@ -2860,8 +2860,9 @@ def test_historical_forecasts_with_scaler(self, params): else: self.helper_compare_hf(hf_auto, [manual_hf_0, manual_hf_1]) - def test_historical_forecasts_with_scaler_errors(self): - """Check that the appropriate exception is raised when providing incorrect parameters.""" + def test_historical_forecasts_with_scaler_errors(self, caplog): + """Check that the appropriate exception is raised when providing incorrect parameters or the expected + warning is display in the corner cases.""" ocl = 2 hf_args = { "start": -ocl - 1, @@ -2874,7 +2875,7 @@ def test_historical_forecasts_with_scaler_errors(self): # retrain=False and unfitted data transformers with pytest.raises(ValueError) as err: - _ = model.historical_forecasts( + model.historical_forecasts( **hf_args, series=self.sine_univariate1, data_transformers={"series": Scaler()}, @@ -2886,7 +2887,7 @@ def test_historical_forecasts_with_scaler_errors(self): # retrain=False, multiple series not matching the fitted data transformers dimensions with pytest.raises(ValueError) as err: - _ = model.historical_forecasts( + model.historical_forecasts( **hf_args, series=[self.sine_univariate1] * 2, data_transformers={ @@ -2900,17 +2901,37 @@ def test_historical_forecasts_with_scaler_errors(self): ) # retrain=True, multiple series and unfitted data transformers with global_fit=True - with pytest.raises(ValueError) as err: - _ = model.historical_forecasts( + expected_warning = ( + "When `retrain=True` and multiple series are provided, the fittable `data_transformers` " + "are trained on each series independently (`global_fit=True` will be ignored)." + ) + with caplog.at_level(logging.WARNING): + model.historical_forecasts( **hf_args, series=[self.sine_univariate1, self.sine_univariate2], data_transformers={"series": Scaler(global_fit=True)}, retrain=True, ) - assert str(err.value).startswith( - "When `retrain=True` and multiple series are provided, all the fittable `data_transformers` " - "must be defined with `global_fit=False" + assert expected_warning in caplog.text + + # data transformer (global_fit=False) prefitted on several series but only series is forecasted + expected_warning = ( + "Provided only a single series, but at least one of the `data_transformers` " + "that use `global_fit=False` was fitted on multiple `TimeSeries`." ) + with caplog.at_level(logging.WARNING): + model.historical_forecasts( + **hf_args, + series=[self.sine_univariate2], + data_transformers={ + "series": Scaler(global_fit=False).fit([ + self.sine_univariate1, + self.sine_univariate2, + ]) + }, + retrain=False, + ) + assert expected_warning in caplog.text @pytest.mark.parametrize("params", product([True, False], [True, False])) def test_historical_forecasts_with_scaler_multiple_series(self, params): @@ -2942,29 +2963,6 @@ def get_scaler(fit: bool): else: return Scaler(global_fit=global_fit) - # global fit is not supported with retrain and multiple series - if retrain and global_fit: - expected_msg = ( - "When `retrain=True` and multiple series are provided, all the fittable `data_transformers` must " - "be defined with `global_fit=False`." - ) - with pytest.raises(ValueError) as err: - _ = model.historical_forecasts( - **hf_args, - series=series, - data_transformers={"series": get_scaler(fit=True)}, - ) - assert str(err.value) == expected_msg - - with pytest.raises(ValueError) as err: - _ = model.historical_forecasts( - **hf_args, - series=series, - data_transformers={"series": get_scaler(fit=False)}, - ) - assert str(err.value) == expected_msg - return - # using all the series used to fit the scaler hf = model.historical_forecasts( **hf_args, @@ -3033,6 +3031,7 @@ def get_scaler(fit: bool): ) self.helper_compare_hf(hf, [manual_hf_2]) + # data_transformers are not pre-fitted if retrain: hf = model.historical_forecasts( **hf_args, diff --git a/darts/utils/historical_forecasts/utils.py b/darts/utils/historical_forecasts/utils.py index 04ddc6f3ca..86a4ef64b9 100644 --- a/darts/utils/historical_forecasts/utils.py +++ b/darts/utils/historical_forecasts/utils.py @@ -233,13 +233,10 @@ def _historical_forecasts_general_checks(model, series, kwargs): if n.retrain: # if more than one series is passed and the pipelines are retrained, they cannot be global - if len(series) > 1 and len(global_fit_pipelines) > 0: - raise_log( - ValueError( - "When `retrain=True` and multiple series are provided, all the fittable " - "`data_transformers` must be defined with `global_fit=False`." - ), - logger, + if n.show_warnings and len(series) > 1 and len(global_fit_pipelines) > 0: + logger.warning( + "When `retrain=True` and multiple series are provided, the fittable `data_transformers` " + "are trained on each series independently (`global_fit=True` will be ignored)." ) else: # must already be fitted without retraining @@ -283,7 +280,7 @@ def _historical_forecasts_general_checks(model, series, kwargs): else: # at least one pipeline was fitted on several series with `global_fit=False` but only # one series was passed - if max(fitted_params_pipelines) > 1: + if n.show_warnings and max(fitted_params_pipelines) > 1: logger.warning( "Provided only a single series, but at least one of the `data_transformers` " "that use `global_fit=False` was fitted on multiple `TimeSeries`."