Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/onnx support #2620

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added more resampling methods to `TimeSeries.resample()`. This allows to aggregate values when down-sampling and to fill or keep the holes when up-sampling. [#2654](https://github.com/unit8co/darts/pull/2654) by [Jonas Blanc](https://github.com/jonasblanc)
- Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich).
- Added new time aggregated metric `wmape()` (Weighted Mean Absolute Percentage Error). [#2544](https://github.com/unit8co/darts/pull/2648) by [He Weilin](https://github.com/cnhwl).
- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou)
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

**Fixed**
- Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC).
Expand Down
55 changes: 45 additions & 10 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
When subclassing this class, please make sure to add the following methods with the given signatures:
- :func:`PLForecastingModule.__init__()`
- :func:`PLForecastingModule.forward()`
- :func:`PLForecastingModule._process_input_batch()`
- :func:`PLForecastingModule._produce_train_output()`
- :func:`PLForecastingModule._get_batch_prediction()`

Expand Down Expand Up @@ -583,6 +584,13 @@ def to_dtype(self, dtype):
logger,
)

def to_onnx(self, file_path, input_sample=None, **kwargs):
    """Export the module to ONNX by delegating to the parent class' ``to_onnx()``.

    Parameters
    ----------
    file_path
        Destination of the exported model, forwarded to the parent implementation.
    input_sample
        Sample input forwarded to the parent implementation. When it is not given, a warning
        recommends going through `TorchForecastingModel.to_onnx` instead.
    **kwargs
        Additional keyword arguments forwarded unchanged to the parent implementation.

    Returns
    -------
    Whatever the parent ``to_onnx()`` returns.
    """
    # explicit `None` check: the truth value of a multi-element tensor is ambiguous, so
    # `not input_sample` would raise when a tensor is passed as the sample
    if input_sample is None:
        logger.warning(
            "It is recommended to use `TorchForecastingModel.to_onnx` method instead."
        )
    return super().to_onnx(file_path=file_path, input_sample=input_sample, **kwargs)

madtoinou marked this conversation as resolved.
Show resolved Hide resolved
@property
def epochs_trained(self):
current_epoch = self.current_epoch
Expand Down Expand Up @@ -632,17 +640,48 @@ def _produce_train_output(self, input_batch: tuple):
input_batch
``(past_target, past_covariates, static_covariates)``
"""
past_target, past_covariates, static_covariates = input_batch
return self(self._process_input_batch(input_batch))

def _process_input_batch(
    self, input_batch: tuple
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Converts a batch into the ``(x_past, x_static)`` input expected by the model's forward pass.

    Parameters
    ----------
    input_batch
        Either ``(past_target, past_covariates, static_covariates)`` or
        ``(past_target, past_covariates, future_past_covariates, static_covariates)``.
        In the second form, ``future_past_covariates`` is not used to build the input.

    Returns
    -------
    tuple
        ``(x_past, x_static)`` where ``x_past`` is the past target concatenated with the past
        covariates along dimension 2 (or the past target alone when there are no past
        covariates), and ``x_static`` is the static covariates entry of the batch, unchanged.
    """
    # because of future past covariates, the batch shape is different during training and prediction
    if len(input_batch) == 3:
        past_target, past_covariates, static_covariates = input_batch
    else:
        # `future_past_covariates` (third entry) is not needed to build the input chunk
        past_target, past_covariates, _, static_covariates = input_batch

    # Currently all our PastCovariates models require past target and covariates concatenated
    x_past = (
        torch.cat([past_target, past_covariates], dim=2)
        if past_covariates is not None
        else past_target
    )
    return x_past, static_covariates

def _get_batch_prediction(
self, n: int, input_batch: tuple, roll_size: int
Expand Down Expand Up @@ -674,12 +713,9 @@ def _get_batch_prediction(
past_covariates.shape[dim_component] if past_covariates is not None else 0
)

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=dim_component,
)
input_past, input_static = self._process_input_batch(input_batch)

out = self._produce_predict_output(x=(input_past, static_covariates))[
out = self._produce_predict_output(x=(input_past, input_static))[
:, self.first_prediction_index :, :
]

Expand Down Expand Up @@ -796,7 +832,6 @@ def _process_input_batch(
future_covariates,
static_covariates,
) = input_batch
dim_variable = 2
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

x_past = torch.cat(
[
Expand All @@ -808,7 +843,7 @@ def _process_input_batch(
]
if tensor is not None
],
dim=dim_variable,
dim=2,
)
return x_past, future_covariates, static_covariates

Expand Down
9 changes: 7 additions & 2 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def forward(
pass

def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
    """Run the forward pass on the processed batch and return the first element of its output."""
    model_input = self._process_input_batch(input_batch)
    model_output = self(model_input)
    # only return the forecast, not the hidden state
    return model_output[0]

def _process_input_batch(
self, input_batch: tuple
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
(
past_target,
historic_future_covariates,
Expand All @@ -112,15 +118,14 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
) = input_batch
# For the RNN we concatenate the past_target with the future_covariates
# (they have the same length because we enforce a Shift dataset for RNNs)
model_input = (
return (
(
torch.cat([past_target, future_covariates], dim=2)
if future_covariates is not None
else past_target
),
static_covariates,
)
return self(model_input)[0]

def _produce_predict_output(
self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None
Expand Down
88 changes: 88 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,94 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
logger=logger,
)

def to_onnx(
    self,
    path: Optional[str] = None,
    input_sample: Optional[tuple] = None,
    randomize_input_sample: bool = False,
    **kwargs,
):
    """Export the model to the ONNX format for optimized inference, wrapping around PyTorch Lightning's
    ``to_onnx()`` method, which relies on :func:`torch.onnx.export` (`official documentation
    <https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#to-onnx>`_).

    Note: the ``onnx`` library (optional dependency) must be installed in order to call this method.

    Example for exporting a :class:`DLinearModel`:

    .. highlight:: python
    .. code-block:: python

        from darts.models import DLinearModel
        from darts import TimeSeries
        import numpy as np

        train_ts = TimeSeries.from_values(np.arange(0, 100))
        model = DLinearModel(input_chunk_length=4, output_chunk_length=1)
        model.fit(train_ts, epochs=1)
        model.to_onnx("my_model.onnx")
    ..

    Parameters
    ----------
    path
        Path under which to save the model at its current state. If no path is specified, the model
        is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``.
    input_sample
        Tuple of Tensor corresponding to the inputs of the model forward pass. If not specified, a
        random sample is generated from the shapes of the training sample.
    randomize_input_sample
        Whether to randomize the values in the `input_sample` to avoid leaking data.
    **kwargs
        Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a
        description of the model being exported to stdout.
        For more information, read the `official documentation
        <https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.

    Raises
    ------
    ValueError
        If the model was not fitted, or if neither `input_sample` nor a training sample is available.
    """
    if not self._fit_called:
        raise_log(
            ValueError("`fit()` needs to be called before `to_onnx()`."), logger
        )

    # explicit `None` checks on `input_sample`: relying on truthiness raises for multi-element tensors
    if input_sample is None and not self.train_sample:
        raise_log(
            ValueError(
                "Either the `input_sample` argument or the `train_sample` attribute must be provided."
            ),
            logger,
        )

    if path is None:
        path = self._default_save_path() + ".onnx"

    if input_sample is None:
        # last dimension in train_sample_shape is the expected target
        mock_batch = tuple(
            torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None
            for shape in self.model.train_sample_shape[:-1]
        )
        input_sample = self.model._process_input_batch(mock_batch)
    elif randomize_input_sample:
        # keep the shapes (and the `None` placeholders) but drop the actual values
        input_sample = tuple(
            torch.rand(tensor.shape, dtype=self.model.dtype)
            if tensor is not None
            else None
            for tensor in input_sample
        )

    # torch models necessarily use historic target values as features in current implementation
    input_names = ["x_past"]
    if self._uses_future_covariates:
        input_names.append("x_future")
    if self._uses_static_covariates:
        input_names.append("x_static")

    # the sample is wrapped in a 1-tuple so that it is passed as a single forward argument
    self.model.to_onnx(
        file_path=path,
        input_sample=(input_sample,),
        input_names=input_names,
        **kwargs,
    )

@random_method
def fit(
self,
Expand Down
82 changes: 82 additions & 0 deletions docs/userguide/torch_forecasting_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ We assume that you already know about covariates in Darts. If you're new to the
- [Manual saving / loading](#manual-saving--loading)
- [Train & save on GPU, load on CPU](#trainingsaving-on-gpu-and-loading-on-cpu)
- [Load pre-trained model for fine-tuning](#re-training-or-fine-tuning-a-pre-trained-model)
- [Exporting model to ONNX format for inference](#exporting-model-to-onnx-format-for-inference)
- [Callbacks](#callbacks)
- [Early Stopping](#example-with-early-stopping)
- [Custom Callback](#example-of-custom-callback-to-store-losses)
Expand Down Expand Up @@ -350,6 +351,87 @@ model_finetune = SomeTorchForecastingModel(..., # use identical parameters & va
model_finetune.load_weights("/your/path/to/save/model.pt")
```

#### Exporting model to ONNX format for inference
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reading this I think again about how nice it would be if all our Datasets returned feature / target arrays at fixed positions (e.g. all return a tuple of past target, past cov, historic future cov, future cov, static cov, ... even if they do not support all covariate types) :D

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, definitely, it would make a lot of things more intuitive and easier to tweak for users.


It is also possible to export the model weights to the ONNX format to run inference in a lightweight environment. This example assumes that the model is trained using future covariates that extend far enough into the future. Note that the user must align and slice the series manually, and it will not be possible to forecast `n > output_chunk_length` without implementing the auto-regression logic.
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

```python
model = SomeTorchForecastingModel(...)
model.fit(...)

# make sure to have onnx installed
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
onnx_filename = "example_onnx.onnx"
model.to_onnx(onnx_filename, export_params=True)
```

Now, to load the model and predict steps after the end of the series:

```python
import onnx
import onnxruntime as ort

def prepare_onnx_inputs(
    model,
    series: TimeSeries,
    past_covariates: Optional[TimeSeries] = None,
    future_covariates: Optional[TimeSeries] = None,
) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Helper function to slice and concatenate the input features.

    This helper works on positions only and does not align the series on their time index.
    It assumes that `past_covariates` ends at the same time step as `series`, and that
    `future_covariates` starts at the same time step as `series` and extends at least
    `output_chunk_length` steps after its end.

    Returns ``(past_feats, future_feats, static_feats)``; each entry is either an array with
    a leading batch axis of size 1, or ``None`` when the feature is not available.
    """
    input_chunk_length = model.input_chunk_length
    n_steps = len(series)
    future_feats, static_feats = None, None

    # convert and concatenate the historic features (target, past and future covariates)
    past_chunks = [series.values()[-input_chunk_length:]]
    if past_covariates is not None:
        past_chunks.append(past_covariates.values()[-input_chunk_length:])
    if future_covariates is not None:
        # the future covariates extend beyond the end of the series: take the window that
        # overlaps with the input chunk, not the last values of the covariates
        past_chunks.append(
            future_covariates.values()[max(n_steps - input_chunk_length, 0):n_steps]
        )
    past_feats = np.expand_dims(np.concatenate(past_chunks, axis=1), axis=0)

    # convert the future covariates
    if model._uses_future_covariates and future_covariates is not None:
        future_feats = np.expand_dims(
            future_covariates.values()[n_steps:n_steps + model.output_chunk_length],
            axis=0,
        )

    # convert static covariates
    if series.has_static_covariates:
        static_feats = np.expand_dims(series.static_covariates_values(), axis=0)

    return past_feats, future_feats, static_feats

# load the exported model and check that it is well-formed before running inference
onnx_model = onnx.load(onnx_filename)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_filename)

# use helper function to extract the features from the series
past_feats, future_feats, static_feats = prepare_onnx_inputs(
    model=model,
    series=series,
    past_covariates=None,
    future_covariates=ts_future,
)

# extract only the features expected by the model
expected_names = {inp.name for inp in ort_session.get_inputs()}
candidate_inputs = {
    "x_past": past_feats,
    "x_future": future_feats,
    "x_static": static_feats,
}
ort_inputs = {
    name: feats for name, feats in candidate_inputs.items() if name in expected_names
}
ort_outs = ort_session.run(None, ort_inputs)
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
```

Note that the forecasts might be slightly different due to rounding errors. Also, due to its specificities, `RNNModel` requires different pre-processing of the series to obtain the input arrays (notably because of `training_length`).

### Callbacks

Callbacks are a powerful way to monitor or control the behavior of the model during the training process. Some examples:
Expand Down
Loading