fix encoder support for conformal models (#2627)
dennisbader authored Dec 21, 2024
1 parent fc244ac commit 5116c38
Showing 3 changed files with 49 additions and 40 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -166,7 +166,7 @@ series.plot()
 * **Multivariate Support:** `TimeSeries` can be multivariate - i.e., contain multiple time-varying
   dimensions/columns instead of a single scalar value. Many models can consume and produce multivariate series.

-* **Multiple series training (global models):** All machine learning based models (incl. all neural networks)
+* **Multiple Series Training (Global Models):** All machine learning based models (incl. all neural networks)
   support being trained on multiple (potentially multivariate) series. This can scale to large datasets too.

 * **Probabilistic Support:** `TimeSeries` objects can (optionally) represent stochastic
@@ -177,10 +177,10 @@ series.plot()
 * **Conformal Prediction Support:** Our conformal prediction models allow to generate probabilistic forecasts with
   calibrated quantile intervals for any pre-trained global forecasting model.

-* **Past and Future Covariates support:** Many models in Darts support past-observed and/or future-known
+* **Past and Future Covariates Support:** Many models in Darts support past-observed and/or future-known
   covariate (external data) time series as inputs for producing forecasts.

-* **Static Covariates support:** In addition to time-dependent data, `TimeSeries` can also contain
+* **Static Covariates Support:** In addition to time-dependent data, `TimeSeries` can also contain
   static data for each dimension, which can be exploited by some models.

 * **Hierarchical Reconciliation:** Darts offers transformers to perform reconciliation.
@@ -189,7 +189,7 @@ series.plot()
 * **Regression Models:** It is possible to plug-in any scikit-learn compatible model
   to obtain forecasts as functions of lagged values of the target series and covariates.

-* **Training with sample weights:** All global models support being trained with sample weights. They can be
+* **Training with Sample Weights:** All global models support being trained with sample weights. They can be
   applied to each observation, forecasted time step and target column.

 * **Forecast Start Shifting:** All global models support training and prediction on a shifted output window.
@@ -198,7 +198,7 @@ series.plot()

 * **Explainability:** Darts has the ability to *explain* some forecasting models using Shap values.

-* **Data processing:** Tools to easily apply (and revert) common transformations on
+* **Data Processing:** Tools to easily apply (and revert) common transformations on
   time series data (scaling, filling missing values, differencing, boxcox, ...)

 * **Metrics:** A variety of metrics for evaluating time series' goodness of fit;
38 changes: 4 additions & 34 deletions darts/models/forecasting/conformal_models.py
@@ -308,40 +308,6 @@ def predict(
             If `series` is given and is a sequence of several time series, this function returns
             a sequence where each element contains the corresponding `n` points forecasts.
         """
-        if series is None:
-            # then there must be a single TS, and that was saved in super().fit as self.training_series
-            if self.model.training_series is None:
-                raise_log(
-                    ValueError(
-                        "Input `series` must be provided. This is the result either from fitting on multiple series, "
-                        "or from not having fit the model yet."
-                    ),
-                    logger,
-                )
-            series = self.model.training_series
-
-        called_with_single_series = get_series_seq_type(series) == SeriesType.SINGLE
-
-        # guarantee that all inputs are either list of TimeSeries or None
-        series = series2seq(series)
-        if past_covariates is None and self.model.past_covariate_series is not None:
-            past_covariates = [self.model.past_covariate_series] * len(series)
-        if future_covariates is None and self.model.future_covariate_series is not None:
-            future_covariates = [self.model.future_covariate_series] * len(series)
-        past_covariates = series2seq(past_covariates)
-        future_covariates = series2seq(future_covariates)
-
-        super().predict(
-            n,
-            series,
-            past_covariates,
-            future_covariates,
-            num_samples,
-            verbose,
-            predict_likelihood_parameters,
-            show_warnings,
-        )
-
         # call predict to verify that all series have required input times
         _ = self.model.predict(
             n=n,
@@ -355,6 +321,10 @@ def predict(
             **kwargs,
         )

+        series = series or self.model.training_series
+        called_with_single_series = get_series_seq_type(series) == SeriesType.SINGLE
+        series = series2seq(series)
+
         # generate only the required forecasts for calibration (including the last forecast which is the output of
         # `predict()`)
         cal_start, cal_start_format = _get_calibration_hfc_start(
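
Reviewer note: the simplified input handling in `predict()` can be read as three steps: let the wrapped model's own `predict()` validate the inputs, fall back to the stored training series, then normalize to a list. The sketch below mimics only that control flow in plain Python. `Series`, `WrappedModel` and `resolve_series` are hypothetical stand-ins for illustration, not Darts classes or functions, and the real validation in Darts is more involved.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


@dataclass
class Series:
    """Hypothetical stand-in for a time series (not a Darts `TimeSeries`)."""

    values: List[float]


class WrappedModel:
    """Hypothetical forecasting model that stores the single series it was fit on."""

    def __init__(self, training_series: Optional[Series] = None):
        self.training_series = training_series

    def predict(self, n: int, series=None) -> None:
        # Assumption of this sketch: all input validation lives in the wrapped model.
        if series is None and self.training_series is None:
            raise ValueError("Input `series` must be provided.")


def resolve_series(
    model: WrappedModel,
    n: int,
    series: Union[Series, Sequence[Series], None] = None,
) -> Tuple[List[Series], bool]:
    # 1) Let the wrapped model validate the inputs (raises if nothing is usable).
    model.predict(n=n, series=series)
    # 2) Fall back to the stored training series, mirroring `series or ...` in the diff.
    series = series or model.training_series
    # 3) Remember how the caller passed the input, then normalize to a list.
    called_with_single_series = isinstance(series, Series)
    series_seq = [series] if called_with_single_series else list(series)
    return series_seq, called_with_single_series
```

Under these assumptions, calling `resolve_series(WrappedModel(train), n=1)` without a `series` argument returns the training series wrapped in a list, while an unfitted model with no input raises the wrapped model's `ValueError`.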
41 changes: 40 additions & 1 deletion darts/tests/models/forecasting/test_conformal_model.py
@@ -1344,7 +1344,7 @@ def test_too_short_input_predict(self, config):
         else:
             # if `past_covariates` are too short, then it raises error from the forecasting_model.predict()
             assert str(exc.value).startswith(
-                "The `past_covariates` at list/sequence index 0 are not long enough."
+                "The `past_covariates` are not long enough."
             )

     @pytest.mark.parametrize(
@@ -1658,3 +1658,42 @@ def test_calibration_hfc_start_value_hist_fc(self, config):
             start=start,
             start_format="value",
         ) == (start_expected, "value")
+
+    def test_encoders(self):
+        """Tests support of covariates encoders."""
+        n = OUT_LEN + 1
+        min_length = IN_LEN + n
+
+        # create non-overlapping train and val series
+        series = tg.linear_timeseries(length=min_length)
+        val_series = tg.linear_timeseries(
+            start=series.end_time() + series.freq, length=min_length
+        )
+
+        model = train_model(
+            series,
+            model_params={
+                "lags_future_covariates": (IN_LEN, OUT_LEN),
+                "add_encoders": {"datetime_attribute": {"future": ["hour"]}},
+            },
+        )
+
+        cp_model = ConformalNaiveModel(model, quantiles=q)
+        assert (
+            cp_model.model.encoders is not None
+            and cp_model.model.encoders.encoding_available
+        )
+        assert model.uses_future_covariates
+
+        # predict: encoders using stored train series must work
+        _ = cp_model.predict(n=n)
+        # predict: encoding of new series without train overlap must work
+        _ = cp_model.predict(n=n, series=val_series)
+
+        # check the same for hfc
+        _ = cp_model.historical_forecasts(
+            forecast_horizon=n, series=series, overlap_end=True
+        )
+        _ = cp_model.historical_forecasts(
+            forecast_horizon=n, series=val_series, overlap_end=True
+        )
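
Reviewer note: the new `test_encoders` case depends on encoders deriving the future covariates from a series' own time index, which is presumably why a validation series with no overlap with the training data can still be predicted. A rough, library-free sketch of that idea follows. `hour_encoding` and its parameters are hypothetical, the dates and lengths are made up for illustration, and this is not how Darts implements `add_encoders`.

```python
from datetime import datetime, timedelta
from typing import List


def hour_encoding(start: datetime, length: int, n: int, freq: timedelta) -> List[int]:
    """Return the hour-of-day for every step of a series plus `n` forecast steps.

    Hypothetical helper: the covariate depends only on the time index,
    never on values seen during training.
    """
    return [(start + i * freq).hour for i in range(length + n)]


freq = timedelta(hours=1)
train_start = datetime(2000, 1, 1)
# The validation series starts right after a 6-step training series ends.
val_start = train_start + 6 * freq

# Covariates for the training series and a 2-step forecast beyond it.
train_cov = hour_encoding(train_start, length=6, n=2, freq=freq)
# Covariates for the non-overlapping validation series, built the same way.
val_cov = hour_encoding(val_start, length=6, n=2, freq=freq)
```

Because both calls only need a start time, a length and a frequency, the second one works without any reference to the training data, which is the property the test exercises for `predict()` and `historical_forecasts()`.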
