diff --git a/CHANGELOG.md b/CHANGELOG.md index 5febc3e7a1..fdf0bfe1a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Added more resampling methods to `TimeSeries.resample()`. This allows to aggregate values when down-sampling and to fill or keep the holes when up-sampling. [#2654](https://github.com/unit8co/darts/pull/2654) by [Jonas Blanc](https://github.com/jonasblanc) - Added general function `darts.slice_intersect()` to intersect a sequence of `TimeSeries` along the time index. [#2592](https://github.com/unit8co/darts/pull/2592) by [Yoav Matzkevich](https://github.com/ymatzkevich). - Added new time aggregated metric `wmape()` (Weighted Mean Absolute Percentage Error). [#2544](https://github.com/unit8co/darts/pull/2648) by [He Weilin](https://github.com/cnhwl). +- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou) **Fixed** - Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC). diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 9c572e562d..1ad0b60a23 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -93,6 +93,7 @@ def __init__( When subclassing this class, please make sure to add the following methods with the given signatures: - :func:`PLForecastingModule.__init__()` - :func:`PLForecastingModule.forward()` + - :func:`PLForecastingModule._process_input_batch()` - :func:`PLForecastingModule._produce_train_output()` - :func:`PLForecastingModule._get_batch_prediction()` @@ -583,6 +584,13 @@ def to_dtype(self, dtype): logger, ) + def to_onnx(self, file_path, input_sample=None, **kwargs): + if not input_sample: + logger.warning( + "It is recommended to use `TorchForecastingModel.to_onnx` method instead." + ) + super().to_onnx(file_path=file_path, input_sample=input_sample, **kwargs) + @property def epochs_trained(self): current_epoch = self.current_epoch @@ -632,9 +640,41 @@ def _produce_train_output(self, input_batch: tuple): input_batch ``(past_target, past_covariates, static_covariates)`` """ - past_target, past_covariates, static_covariates = input_batch + return self(self._process_input_batch(input_batch)) + + def _process_input_batch( + self, input_batch: tuple + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Converts output of PastCovariatesDataset (training dataset) into an input/past- and + output/future chunk. + + Parameters + ---------- + input_batch + ``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``. + + Returns + ------- + tuple + ``(x_past, x_static)`` the input/past and output/future chunks. + """ + # because of future past covariates, the batch shape is different during training and prediction + if len(input_batch) == 3: + ( + past_target, + past_covariates, + static_covariates, + ) = input_batch + else: + ( + past_target, + past_covariates, + future_past_covariates, + static_covariates, + ) = input_batch # Currently all our PastCovariates models require past target and covariates concatenated - inpt = ( + return ( ( torch.cat([past_target, past_covariates], dim=2) if past_covariates is not None @@ -642,7 +682,6 @@ def _produce_train_output(self, input_batch: tuple): ), static_covariates, ) - return self(inpt) def _get_batch_prediction( self, n: int, input_batch: tuple, roll_size: int @@ -674,12 +713,9 @@ def _get_batch_prediction( past_covariates.shape[dim_component] if past_covariates is not None else 0 ) - input_past = torch.cat( - [ds for ds in [past_target, past_covariates] if ds is not None], - dim=dim_component, - ) + input_past, input_static = self._process_input_batch(input_batch) - out = self._produce_predict_output(x=(input_past, static_covariates))[ + out = self._produce_predict_output(x=(input_past, input_static))[ :, self.first_prediction_index :, : ] @@ -796,7 +832,6 @@ def _process_input_batch( future_covariates, static_covariates, ) = input_batch - dim_variable = 2 x_past = torch.cat( [ @@ -808,7 +843,7 @@ def _process_input_batch( ] if tensor is not None ], - dim=dim_variable, + dim=2, ) return x_past, future_covariates, static_covariates diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index e7d55ac9c6..d51b0a3838 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -104,6 +104,12 @@ def forward( pass def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: + # only return the forecast, not the hidden state + return self(self._process_input_batch(input_batch))[0] + + def _process_input_batch( + self, input_batch: tuple + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ( past_target, historic_future_covariates, @@ -112,7 +118,7 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ) = input_batch # For the RNN we concatenate the past_target with the future_covariates # (they have the same length because we enforce a Shift dataset for RNNs) - model_input = ( + return ( ( torch.cat([past_target, future_covariates], dim=2) if future_covariates is not None @@ -120,7 +126,6 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ), static_covariates, ) - return self(model_input)[0] def _produce_predict_output( self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 89ca19401f..cbed45d86f 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -641,6 +641,94 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): logger=logger, ) + def to_onnx( + self, + path: Optional[str] = None, + input_sample: Optional[tuple] = None, + randomize_input_sample: bool = False, + **kwargs, + ): + """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's + :func:`torch.onnx.export` method ()`official documentation `_). + + Note: onnx library (optionnal dependency) must be installed in order to call this method + + Example for exporting a :class:`DLinearModel`: + + .. highlight:: python + .. code-block:: python + + from darts.models import DLinearModel + from darts import TimeSeries + import numpy as np + + train_ts = TimeSeries.from_values(np.arange(0,100)) + model = DLinearModel(input_chunk_length=4, output_chunk_length=1) + model.fit(train_ts, epochs=1) + model.to_onnx("my_model.onnx") + .. + + Parameters + ---------- + path + Path under which to save the model at its current state. If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``. + input_sample + Tuple of Tensor corresponding to the inputs of the model forward pass. + randomize_input_sample + Wether to randomize the values in the `input_sample` to avoid leaking data. + **kwargs + Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a + description of the model being exported to stdout. + For more information, read the `official documentation `_. + """ + if not self._fit_called: + raise_log( + ValueError("`fit()` needs to be called before `to_onnx()`."), logger + ) + + if not self.train_sample and not input_sample: + raise_log( + ValueError( + "Either the `input_sample` argument or the `train_sample` attribute must be provided." + ), + logger, + ) + + if path is None: + path = self._default_save_path() + ".onnx" + + if not input_sample: + # last dimension in train_sample_shape is the expected target + mock_batch = tuple( + torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None + for shape in self.model.train_sample_shape[:-1] + ) + input_sample = self.model._process_input_batch(mock_batch) + elif randomize_input_sample: + input_sample = tuple( + torch.rand(tensor.shape, dtype=self.model.dtype) + if tensor is not None + else None + for tensor in input_sample + ) + + # torch models necessarily use historic target values as features in current implementation + input_names = ["x_past"] + if self._uses_future_covariates: + input_names.append("x_future") + if self._uses_static_covariates: + input_names.append("x_static") + + self.model.to_onnx( + file_path=path, + input_sample=(input_sample,), + input_names=input_names, + **kwargs, + ) + @random_method def fit( self, diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index 54aa6ebbc2..565c50b552 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -22,6 +22,7 @@ We assume that you already know about covariates in Darts. If you're new to the - [Manual saving / loading](#manual-saving--loading) - [Train & save on GPU, load on CPU](#trainingsaving-on-gpu-and-loading-on-cpu) - [Load pre-trained model for fine-tuning](#re-training-or-fine-tuning-a-pre-trained-model) + - [Exporting model to ONNX format for inference](#exporting-model-to-ONNX-format-for-inference) - [Callbacks](#callbacks) - [Early Stopping](#example-with-early-stopping) - [Custom Callback](#example-of-custom-callback-to-store-losses) @@ -350,6 +351,87 @@ model_finetune = SomeTorchForecastingModel(..., # use identical parameters & va model_finetune.load_weights("/your/path/to/save/model.pt") ``` +#### Exporting model to ONNX format for inference + +It is also possible to export the model weights to the ONNX format to run inference in a lightweight environment. This example assumes that the model is trained using a future covariates that extends far enough into the future. Note that the user must align and slice the series manually and it will not be possible to forecast `n > output_chunk_length` without implementing the auto-regression logic + +```python +model = SomeTorchForecastingModel(...) +model.fit(...) + +# make sure to have onnx installed +onnx_filename = "example_onnx.onnx" +model.to_onnx(onnx_filename, export_params=True) +``` + +Now, to load the model and predict steps after the end of the series: + +```python +import onnx +import onnxruntime as ort + +def prepare_onnx_inputs( + model, + series: TimeSeries, + past_covariates : Optional[TimeSeries] = None, + future_covariates : Optional[TimeSeries] = None, + ) -> tuple[Optional[np.ndarray]]: + """Helper function to slice and concatenate the input features""" + past_feats, future_feats, static_feats = None, None, None + # convert and concatenate the historic features (target, past and future covariates) + past_feats = series.values()[-model.input_chunk_length:] + if past_covariates: + past_feats = np.concatenate( + [ + past_feats, + past_covariates.values()[-model.input_chunk_length:] + ], + axis=1 + ) + if future_covariates: + past_feats = np.concatenate( + [ + past_feats, + future_covariates.values()[-model.input_chunk_length:] + ], + axis=1 + ) + past_feats = np.expand_dims(past_feats, axis=0) + + # convert the future covariates + if model._uses_future_covariates: + if future_covariates: + future_feats = np.expand_dims(future_covariates.values()[ + len(series):len(series)+model.output_chunk_length + ], axis=0) + + # convert static covariates + if series.has_static_covariates: + static_feats = np.expand_dims(series.static_covariates_values(), axis=0) + + return past_feats, future_feats, static_feats + +onnx_model = onnx.load(onnx_filename) +onnx.checker.check_model(onnx_model) +ort_session = ort.InferenceSession(onnx_filename) + +# use helper function to extract the features from the series +past_feats, future_feats, static_feats = prepare_input_feats( + model=model, + series=series, + past_covariates = None, + future_covariates = ts_future, +) + +# extract only the features expected by the model +ort_inputs = { + k:v for k, v in zip(['x_past', 'x_future', 'x_static'], [past_feats, future_feats, static_feats]) if k in [inp.name for inp in list(ort_session.get_inputs())] + } +ort_outs = ort_session.run(None, ort_inputs) +``` + +Note that the forecasts might be slightly different due to rounding errors. Also, due to its specificities, `RNNModel` requires different pre-processing of the series to obtain the input arrays (notably because of `training_length`). + ### Callbacks Callbacks are a powerful way to monitor or control the behavior of the model during the training process. Some examples: