Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Imp/Add new model StatsForecastAutoTBATS #2611

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).

**Fixed**

**Dependencies**
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ on bringing more models and features.
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | [Prophet repo](https://github.com/facebook/prophet) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
Expand Down
3 changes: 3 additions & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES
from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS
from darts.models.forecasting.sf_auto_tbats import StatsForecastAutoTBATS
from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta

except ImportError:
Expand All @@ -108,6 +109,7 @@
StatsForecastAutoCES = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoETS = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTheta = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False)

try:
from darts.models.forecasting.xgboost import XGBModel
Expand Down Expand Up @@ -160,6 +162,7 @@
"StatsForecastAutoCES",
"StatsForecastAutoETS",
"StatsForecastAutoTheta",
"StatsForecastAutoTBATS",
"XGBModel",
"GaussianProcessFilter",
"KalmanFilter",
Expand Down
1 change: 1 addition & 0 deletions darts/models/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- :class:`~darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES`
- :class:`~darts.models.forecasting.tbats_model.BATS`
- :class:`~darts.models.forecasting.tbats_model.TBATS`
- :class:`~darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS`
- :class:`~darts.models.forecasting.theta.Theta`
- :class:`~darts.models.forecasting.theta.FourTheta`
- :class:`~darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta`
Expand Down
104 changes: 104 additions & 0 deletions darts/models/forecasting/sf_auto_tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
StatsForecastAutoTBATS
-----------
"""

from statsforecast.models import AutoTBATS as SFAutoTBATS

from darts import TimeSeries
from darts.models.components.statsforecast_utils import (
create_normal_samples,
one_sigma_rule,
unpack_sf_dict,
)
from darts.models.forecasting.forecasting_model import LocalForecastingModel


class StatsForecastAutoTBATS(LocalForecastingModel):
def __init__(self, *autoTBATS_args, **autoTBATS_kwargs):
"""Auto-TBATS based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.

Automatically selects the best TBATS model from all feasible combinations of the parameters `use_boxcox`,
`use_trend`, `use_damped_trend`, and `use_arma_errors`. Selection is made using the AIC.
Default value for `use_arma_errors` is True since this enables the evaluation of models with
and without ARMA errors.
<https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=f3de25596ab60ef0e886366826bf58a02b35a44f>
<https://doi.org/10.4225/03/589299681de3d>

We refer to the `statsforecast AutoTBATS documentation
<https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats>`_
for the exhaustive documentation of the arguments.

Parameters
----------
autoTBATS_args
Positional arguments for ``statsforecasts.models.AutoTBATS``.
autoTBATS_kwargs
Keyword arguments for ``statsforecasts.models.AutoTBATS``.

Examples
--------
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastAutoTBATS
>>> series = AirPassengersDataset().load()
>>> # define StatsForecastAutoTBATS parameters
>>> model = StatsForecastAutoTBATS(season_length=12)
>>> model.fit(series)
>>> pred = model.predict(6)
>>> pred.values()
array([[450.79653684],
[472.09265790],
[497.76948306],
[510.74927369],
[520.92224557],
[570.33881522]])
"""
super().__init__()
self.model = SFAutoTBATS(*autoTBATS_args, **autoTBATS_kwargs)

def fit(self, series: TimeSeries):
super().fit(series)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
)
return self

def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast_dict = self.model.predict(
h=n,
level=(one_sigma_rule,), # ask one std for the confidence interval.
)

mu, std = unpack_sf_dict(forecast_dict)
if num_samples > 1:
samples = create_normal_samples(mu, std, num_samples, n)
else:
samples = mu

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10

@property
def _supports_range_index(self) -> bool:
return True

@property
def supports_probabilistic_prediction(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
StatsForecastAutoARIMA,
StatsForecastAutoCES,
StatsForecastAutoETS,
StatsForecastAutoTBATS,
StatsForecastAutoTheta,
Theta,
)
Expand All @@ -57,6 +58,7 @@
(StatsForecastAutoTheta(season_length=12), 5.5),
(StatsForecastAutoCES(season_length=12, model="Z"), 7.3),
(StatsForecastAutoETS(season_length=12, model="AAZ"), 7.3),
(StatsForecastAutoTBATS(season_length=12), 10),
(Croston(version="classic"), 23),
(Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 23),
(Theta(), 11),
Expand Down
1 change: 1 addition & 0 deletions docs/userguide/covariates.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ GFMs are models that can be trained on multiple target (and covariate) time seri
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | | ✅ | |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | | | |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | | | |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | | | |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | | | |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | | ✅ | |
Expand Down
Loading