From 551f9839a297f3abeb9d2ee61b881800ae59c850 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Wed, 17 May 2023 17:36:52 +0200 Subject: [PATCH 1/9] feat: wrapping around pl.to_onnx to export models to ONNX, still require testing --- darts/models/forecasting/rnn_model.py | 58 +++++++ .../forecasting/torch_forecasting_model.py | 143 ++++++++++++++++++ 2 files changed, 201 insertions(+) diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index e3996927ea..f48bd898df 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -470,6 +470,64 @@ def _verify_train_dataset_type(self, train_dataset: TrainingDataset): "RNNModel requires a shifted training dataset with shift=1.", ) + def to_onnx(self, path: Optional[str] = None, **kwargs): + """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's + :func:`torch.onnx.export` method ()`official documentation `_). + + Note: onnx library (optionnal dependency) must be installed in order to call this method + + Parameters + ---------- + path + Path under which to save the model at its current state. Please avoid path starting with "last-" or + "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. + **kwargs + Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a + description of the model being exported to stdout. + For more information, read the `official documentation `_. + """ + raise_if_not( + self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger + ) + + if path is None: + # default path + path = self._default_save_path() + ".onnx" + + # mimic preprocessing performed by RNNModule_._get_batch_prediction() + ( + past_target, + past_covariates, + historic_future_covariates, + future_covariates, + future_past_covariates, + ) = ( + torch.Tensor(x).unsqueeze(0) if x is not None else None + for x in self.train_sample + ) + + if historic_future_covariates is not None: + # RNNs need as inputs (target[t] and covariates[t+1]) so here we shift the covariates + all_covariates = torch.cat( + [historic_future_covariates[:, 1:, :], future_covariates], dim=1 + ) + cov_past, _ = ( + all_covariates[:, : past_target.shape[1], :], + all_covariates[:, past_target.shape[1] :, :], + ) + input_past = torch.cat([past_target, cov_past], dim=2) + else: + input_past = past_target + + input_sample = [ + input_past.double(), + future_covariates.double() if future_covariates is not None else None, + ] + self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) + @property def min_train_series_length(self) -> int: return self.training_length + 1 diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index cf5f777f11..e213c3efef 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -590,6 +590,11 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): """ pass + @abstractmethod + def to_onnx(self, path: Optional[str] = None, **kwargs): + """In charge of generating a dummy input sample and exporting the model to ONNX.""" + pass + def _verify_static_covariates(self, static_covariates: Optional[pd.DataFrame]): """ Verify that all static covariates are numeric. @@ -2077,6 +2082,62 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only past_covariates.", ) + def to_onnx(self, path: Optional[str] = None, **kwargs): + """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's + :func:`torch.onnx.export` method ()`official documentation `_). + + Note: onnx library (optionnal dependency) must be installed in order to call this method + + Example for exporting a :class:`TCNModel`: + + .. highlight:: python + .. code-block:: python + + from darts.models import TCNModel + from darts import TimeSeries + import numpy as np + + train_ts = TimeSeries.from_values(np.arange(0,100)) + model = TCNModel(input_chunk_length=4, output_chunk_length=1) + model.fit(train_ts, epochs=1) + model.to_onnx("my_model.onnx") + .. + + Parameters + ---------- + path + Path under which to save the model at its current state. Please avoid path starting with "last-" or + "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. + **kwargs + Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a + description of the model being exported to stdout. + For more information, read the `official documentation `_. + """ + raise_if_not( + self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger + ) + + if path is None: + # default path + path = self._default_save_path() + ".onnx" + + # mimic preprocessing performed by PLPastCovariatesModule._get_batch_prediction() + (past_target, past_covariates, future_past_covariates, static_covariates,) = ( + torch.Tensor(x).unsqueeze(0) if x is not None else None + for x in self.train_sample + ) + + input_past = torch.cat( + [tensor for tensor in [past_target, past_covariates] if tensor is not None], + dim=2, + ) + + input_sample = [input_past.double(), static_covariates.double()] + self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) + @property def _model_encoder_settings( self, @@ -2175,6 +2236,9 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only future_covariates.", ) + def to_onnx(self, path: Optional[str] = None, **kwargs): + raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") + @property def _model_encoder_settings( self, @@ -2264,6 +2328,11 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only future_covariates.", ) + def to_onnx(self, path: Optional[str] = None, **kwargs): + raise NotImplementedError( + "TBD: The only DualCovariatesModel is an RNN with a specific implementation." + ) + @property def _model_encoder_settings( self, @@ -2350,6 +2419,77 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): # both covariates are supported; do nothing pass + def to_onnx(self, path: Optional[str] = None, **kwargs): + """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's + :func:`torch.onnx.export` method ()`official documentation `_). + + Note: onnx library (optionnal dependency) must be installed in order to call this method + + Example for exporting a :class:`DLinearModel`: + + .. highlight:: python + .. code-block:: python + + from darts.models import DLinearModel + from darts import TimeSeries + import numpy as np + + train_ts = TimeSeries.from_values(np.arange(0,100)) + model = DLinearModel(input_chunk_length=4, output_chunk_length=1) + model.fit(train_ts, epochs=1) + model.to_onnx("my_model.onnx") + .. + + Parameters + ---------- + path + Path under which to save the model at its current state. Please avoid path starting with "last-" or + "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. + **kwargs + Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a + description of the model being exported to stdout. + For more information, read the `official documentation `_. + """ + raise_if_not( + self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger + ) + + if path is None: + # default path + path = self._default_save_path() + ".onnx" + + # mimic preprocessing performed by PLPastCovariatesModule._get_batch_prediction() + ( + past_target, + past_covariates, + historic_future_covariates, + future_covariates, + future_past_covariates, + static_covariates, + ) = ( + torch.Tensor(x).unsqueeze(0) if x is not None else None + for x in self.train_sample + ) + + input_past = torch.cat( + [ + tensor + for tensor in [past_target, past_covariates, historic_future_covariates] + if tensor is not None + ], + dim=2, + ) + + input_sample = [ + input_past.double(), + future_covariates.double() if future_covariates is not None else None, + static_covariates.double(), + ] + self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) + @property def _model_encoder_settings( self, @@ -2437,6 +2577,9 @@ def _verify_predict_sample(self, predict_sample: Tuple): # TODO: we have to check both past and future covariates raise NotImplementedError() + def to_onnx(self, path: Optional[str] = None, **kwargs): + raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") + @property def _model_encoder_settings( self, From 59a5849b6840fac49c0dc4fe254f1d391f67383e Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 16 Dec 2024 22:30:10 +0100 Subject: [PATCH 2/9] feat: cleaned implementation of the to_onnx method --- .../forecasting/pl_forecasting_module.py | 30 ++- .../forecasting/torch_forecasting_model.py | 221 ++++++------------ 2 files changed, 97 insertions(+), 154 deletions(-) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 9c572e562d..7ca114cb38 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -93,6 +93,7 @@ def __init__( When subclassing this class, please make sure to add the following methods with the given signatures: - :func:`PLForecastingModule.__init__()` - :func:`PLForecastingModule.forward()` + - :func:`PLForecastingModule._process_input_batch()` - :func:`PLForecastingModule._produce_train_output()` - :func:`PLForecastingModule._get_batch_prediction()` @@ -632,9 +633,28 @@ def _produce_train_output(self, input_batch: tuple): input_batch ``(past_target, past_covariates, static_covariates)`` """ + return self(self._process_input_batch(input_batch)) + + def _process_input_batch( + self, input_batch: tuple + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Converts output of PastCovariatesDataset (training dataset) into an input/past- and + output/future chunk. + + Parameters + ---------- + input_batch + ``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``. + + Returns + ------- + tuple + ``(x_past, x_static)`` the input/past and output/future chunks. + """ past_target, past_covariates, static_covariates = input_batch # Currently all our PastCovariates models require past target and covariates concatenated - inpt = ( + return ( ( torch.cat([past_target, past_covariates], dim=2) if past_covariates is not None @@ -642,7 +662,6 @@ def _produce_train_output(self, input_batch: tuple): ), static_covariates, ) - return self(inpt) def _get_batch_prediction( self, n: int, input_batch: tuple, roll_size: int @@ -674,12 +693,9 @@ def _get_batch_prediction( past_covariates.shape[dim_component] if past_covariates is not None else 0 ) - input_past = torch.cat( - [ds for ds in [past_target, past_covariates] if ds is not None], - dim=dim_component, - ) + input_past, input_static = self._process_input_batch(input_batch) - out = self._produce_predict_output(x=(input_past, static_covariates))[ + out = self._produce_predict_output(x=(input_past, input_static))[ :, self.first_prediction_index :, : ] diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 20b035eea9..fcd640919e 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -644,10 +644,80 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): logger=logger, ) - @abstractmethod - def to_onnx(self, path: Optional[str] = None, **kwargs): - """In charge of generating a dummy input sample and exporting the model to ONNX.""" - pass + def to_onnx( + self, + path: Optional[str] = None, + input_sample: Optional[tuple] = None, + randomize_input_sample: bool = False, + **kwargs, + ): + """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's + :func:`torch.onnx.export` method ()`official documentation `_). + + Note: onnx library (optionnal dependency) must be installed in order to call this method + + Example for exporting a :class:`DLinearModel`: + + .. highlight:: python + .. code-block:: python + + from darts.models import DLinearModel + from darts import TimeSeries + import numpy as np + + train_ts = TimeSeries.from_values(np.arange(0,100)) + model = DLinearModel(input_chunk_length=4, output_chunk_length=1) + model.fit(train_ts, epochs=1) + model.to_onnx("my_model.onnx") + .. + + Parameters + ---------- + path + Path under which to save the model at its current state. If no path is specified, the model + is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``. + input_sample + Tuple of Tensor corresponding to the inputs of the model forward pass. + randomize_input_sample + Wether to randomize the values in the `input_sample` to avoid leaking data. + **kwargs + Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a + description of the model being exported to stdout. + For more information, read the `official documentation `_. + """ + if not self._fit_called: + raise_log( + ValueError("`fit()` needs to be called before `to_onnx()`."), logger + ) + + if not self.train_sample and not input_sample: + raise_log( + ValueError( + "Either the `input_sample` argument or the `train_sample` attribute must be provided." + ), + logger, + ) + + if path is None: + path = self._default_save_path() + ".onnx" + + if not input_sample: + mock_batch = tuple( + torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None + for shape in self.model.train_sample_shape[:-1] + ) + input_sample = self.model._process_input_batch(mock_batch) + elif randomize_input_sample: + input_sample = tuple( + torch.rand(tensor.shape, dtype=self.model.dtype) + if tensor is not None + else None + for tensor in input_sample + ) + # TODO: define input names, depending on the class of the model + self.model.to_onnx(file_path=path, input_sample=(input_sample,), **kwargs) @random_method def fit( @@ -2545,67 +2615,6 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only past_covariates.", ) - def to_onnx(self, path: Optional[str] = None, **kwargs): - """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's - :func:`torch.onnx.export` method ()`official documentation `_). - - Note: onnx library (optionnal dependency) must be installed in order to call this method - - Example for exporting a :class:`TCNModel`: - - .. highlight:: python - .. code-block:: python - - from darts.models import TCNModel - from darts import TimeSeries - import numpy as np - - train_ts = TimeSeries.from_values(np.arange(0,100)) - model = TCNModel(input_chunk_length=4, output_chunk_length=1) - model.fit(train_ts, epochs=1) - model.to_onnx("my_model.onnx") - .. - - Parameters - ---------- - path - Path under which to save the model at its current state. Please avoid path starting with "last-" or - "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model - is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. - **kwargs - Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a - description of the model being exported to stdout. - For more information, read the `official documentation `_. - """ - raise_if_not( - self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger - ) - - if path is None: - # default path - path = self._default_save_path() + ".onnx" - - # mimic preprocessing performed by PLPastCovariatesModule._get_batch_prediction() - ( - past_target, - past_covariates, - future_past_covariates, - static_covariates, - ) = ( - torch.Tensor(x).unsqueeze(0) if x is not None else None - for x in self.train_sample - ) - - input_past = torch.cat( - [tensor for tensor in [past_target, past_covariates] if tensor is not None], - dim=2, - ) - - input_sample = [input_past.double(), static_covariates.double()] - self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) - @property def _model_encoder_settings( self, @@ -2706,9 +2715,6 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only future_covariates.", ) - def to_onnx(self, path: Optional[str] = None, **kwargs): - raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") - @property def _model_encoder_settings( self, @@ -2810,11 +2816,6 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): "support only future_covariates.", ) - def to_onnx(self, path: Optional[str] = None, **kwargs): - raise NotImplementedError( - "TBD: The only DualCovariatesModel is an RNN with a specific implementation." - ) - @property def _model_encoder_settings( self, @@ -2913,77 +2914,6 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates): # both covariates are supported; do nothing pass - def to_onnx(self, path: Optional[str] = None, **kwargs): - """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's - :func:`torch.onnx.export` method ()`official documentation `_). - - Note: onnx library (optionnal dependency) must be installed in order to call this method - - Example for exporting a :class:`DLinearModel`: - - .. highlight:: python - .. code-block:: python - - from darts.models import DLinearModel - from darts import TimeSeries - import numpy as np - - train_ts = TimeSeries.from_values(np.arange(0,100)) - model = DLinearModel(input_chunk_length=4, output_chunk_length=1) - model.fit(train_ts, epochs=1) - model.to_onnx("my_model.onnx") - .. - - Parameters - ---------- - path - Path under which to save the model at its current state. Please avoid path starting with "last-" or - "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model - is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. - **kwargs - Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a - description of the model being exported to stdout. - For more information, read the `official documentation `_. - """ - raise_if_not( - self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger - ) - - if path is None: - # default path - path = self._default_save_path() + ".onnx" - - # mimic preprocessing performed by PLPastCovariatesModule._get_batch_prediction() - ( - past_target, - past_covariates, - historic_future_covariates, - future_covariates, - future_past_covariates, - static_covariates, - ) = ( - torch.Tensor(x).unsqueeze(0) if x is not None else None - for x in self.train_sample - ) - - input_past = torch.cat( - [ - tensor - for tensor in [past_target, past_covariates, historic_future_covariates] - if tensor is not None - ], - dim=2, - ) - - input_sample = [ - input_past.double(), - future_covariates.double() if future_covariates is not None else None, - static_covariates.double(), - ] - self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) - @property def _model_encoder_settings( self, @@ -3079,9 +3009,6 @@ def _verify_predict_sample(self, predict_sample: tuple): # TODO: we have to check both past and future covariates raise NotImplementedError() - def to_onnx(self, path: Optional[str] = None, **kwargs): - raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") - @property def _model_encoder_settings( self, From 0c98eba02e804a51964e548fdfe1cafdaa39c61f Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 01:00:54 +0100 Subject: [PATCH 3/9] fix: generation of input name, shape of input_batch for PastCov torch module --- .../forecasting/pl_forecasting_module.py | 18 +++++++++++++++--- .../forecasting/torch_forecasting_model.py | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 7ca114cb38..4582bf0647 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -652,7 +652,20 @@ def _process_input_batch( tuple ``(x_past, x_static)`` the input/past and output/future chunks. """ - past_target, past_covariates, static_covariates = input_batch + # because of future past covariates, the batch shape is different during training and prediction + if len(input_batch) == 3: + ( + past_target, + past_covariates, + static_covariates, + ) = input_batch + else: + ( + past_target, + past_covariates, + future_past_covariates, + static_covariates, + ) = input_batch # Currently all our PastCovariates models require past target and covariates concatenated return ( ( @@ -812,7 +825,6 @@ def _process_input_batch( future_covariates, static_covariates, ) = input_batch - dim_variable = 2 x_past = torch.cat( [ @@ -824,7 +836,7 @@ def _process_input_batch( ] if tensor is not None ], - dim=dim_variable, + dim=2, ) return x_past, future_covariates, static_covariates diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index fcd640919e..f5c236c929 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -704,6 +704,7 @@ def to_onnx( path = self._default_save_path() + ".onnx" if not input_sample: + # last dimension in train_sample_shape is the expected target mock_batch = tuple( torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None for shape in self.model.train_sample_shape[:-1] @@ -716,8 +717,20 @@ def to_onnx( else None for tensor in input_sample ) - # TODO: define input names, depending on the class of the model - self.model.to_onnx(file_path=path, input_sample=(input_sample,), **kwargs) + + # torch models necessarily use historic target values as features in current implementation + input_names = ["x_past"] + if self._uses_future_covariates: + input_names.append("x_future") + if self._uses_static_covariates: + input_names.append("x_static") + + self.model.to_onnx( + file_path=path, + input_sample=(input_sample,), + input_names=input_names, + **kwargs, + ) @random_method def fit( From d569b1dfb7cefce11be06def7e09bba0e5b40d6b Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 01:41:51 +0100 Subject: [PATCH 4/9] feat: adding example of onnx usage in userguide --- .../forecasting/pl_forecasting_module.py | 7 ++ docs/userguide/torch_forecasting_models.md | 90 +++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 4582bf0647..1ad0b60a23 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -584,6 +584,13 @@ def to_dtype(self, dtype): logger, ) + def to_onnx(self, file_path, input_sample=None, **kwargs): + if not input_sample: + logger.warning( + "It is recommended to use `TorchForecastingModel.to_onnx` method instead." + ) + super().to_onnx(file_path=file_path, input_sample=input_sample, **kwargs) + @property def epochs_trained(self): current_epoch = self.current_epoch diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index 54aa6ebbc2..cb3fa59d06 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -22,6 +22,7 @@ We assume that you already know about covariates in Darts. If you're new to the - [Manual saving / loading](#manual-saving--loading) - [Train & save on GPU, load on CPU](#trainingsaving-on-gpu-and-loading-on-cpu) - [Load pre-trained model for fine-tuning](#re-training-or-fine-tuning-a-pre-trained-model) + - [Exporting model to ONNX format for inference](#exporting-model-to-ONNX-format-for-inference) - [Callbacks](#callbacks) - [Early Stopping](#example-with-early-stopping) - [Custom Callback](#example-of-custom-callback-to-store-losses) @@ -350,6 +351,95 @@ model_finetune = SomeTorchForecastingModel(..., # use identical parameters & va model_finetune.load_weights("/your/path/to/save/model.pt") ``` +#### Exporting model to ONNX format for inference + +It is also possible to export the model weights to the ONNX format to run inference in a lightweight environment. This example assumes that the model is trained using a future covariates that extends far enough into the future. Note that the user must align and slice the series manually and it will not be possible to forecast `n > output_chunk_length` without implementing the auto-regression logic + +```python +model = SomeTorchForecastingModel(...) +model.fit(...) + +# make sure to have onnx installed +onnx_filename = "example_onnx.onnx" +model.to_onnx(onnx_filename, export_params=True) +``` + +Now, to load the model and predict steps after the end of the series: + +```python +import onnx +import onnxruntime as ort + +def prepare_onnx_inputs( + model, + series: TimeSeries, + past_covariates : Optional[TimeSeries] = None, + future_covariates : Optional[TimeSeries] = None, + ) -> tuple[Optional[np.ndarray]]: + """Helper function to slice and concatenate the input features""" + past_feats, future_feats, static_feats = None, None, None + if forecast_start_position > 0: + raise_log( + ValueError("`forecast_start_position` must be <= 0"), + logger=logger + ) + + # convert and concatenate the historic features (target, past and future covariates) + past_feats = series.values()[-model.input_chunk_length:] + if past_covariates: + past_feats = np.concatenate( + [ + past_feats, + past_covariates.values()[-model.input_chunk_length:] + ], + axis=1 + ) + if future_covariates: + past_feats = np.concatenate( + [ + past_feats, + future_covariates.values()[-model.input_chunk_length:] + ], + axis=1 + ) + past_feats = np.expand_dims(past_feats, axis=0) + + # convert the future covariates + if model._uses_future_covariates: + if future_covariates: + future_feats = np.expand_dims(future_covariates.values()[ + len(series):len(series)+model.output_chunk_length + ], axis=0) + else: + future_feats = None + + # convert static covariates + if series.has_static_covariates: + static_feats = np.expand_dims(series.static_covariates_values(), axis=0) + + return past_feats, future_feats, static_feats + +onnx_model = onnx.load(onnx_filename) +onnx.checker.check_model(onnx_model) +ort_session = ort.InferenceSession(onnx_filename) + +# use helper function to extract the features from the series +past_feats, future_feats, static_feats = prepare_input_feats( + model=model, + series=series, + past_covariates = None, + future_covariates = ts_future, +) + +# extract only the features expected by the model +ort_inputs = { + k:v for k, v in zip(['x_past', 'x_future', 'x_static'], [past_feats, future_feats, static_feats]) if k in [inp.name for inp in list(ort_session.get_inputs())] + } +ort_outs = ort_session.run(None, ort_inputs) +``` + +Note that the forecasts might be slightly different due to rounding errors. + ### Callbacks Callbacks are a powerful way to monitor or control the behavior of the model during the training process. Some examples: From c2c00f3aa1c2090af923f92bc10b765f969c1b80 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 09:53:49 +0100 Subject: [PATCH 5/9] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2428a90127..e0f2ebcf6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader). - Added `data_transformers` argument to `historical_forecasts`, `backtest`, `residuals`, and `gridsearch` that allow to automatically apply `DataTransformer` and/or `Pipeline` to the input series without data-leakage (fit on historic window of input series, transform the input series, and inverse transform the forecasts). [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) and [Jan Fidor](https://github.com/JanFidor) - Added `series_idx` argument to `DataTransformer` that allows users to use only a subset of the transformers when `global_fit=False` and severals series are used. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) +- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou) - Updated the Documentation URL of `Statsforecast` models. [#2610](https://github.com/unit8co/darts/pull/2610) by [He Weilin](https://github.com/cnhwl). **Fixed** From f376f09154f669f075ba299462f625b4784dba2c Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 10:15:56 +0100 Subject: [PATCH 6/9] fix: revert some changes --- darts/models/forecasting/rnn_model.py | 58 ------------------- .../forecasting/torch_forecasting_model.py | 25 -------- 2 files changed, 83 deletions(-) diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index 20347a67be..e7d55ac9c6 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -587,64 +587,6 @@ def _verify_train_dataset_type(self, train_dataset: TrainingDataset): "RNNModel requires a shifted training dataset with shift=1.", ) - def to_onnx(self, path: Optional[str] = None, **kwargs): - """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's - :func:`torch.onnx.export` method ()`official documentation `_). - - Note: onnx library (optionnal dependency) must be installed in order to call this method - - Parameters - ---------- - path - Path under which to save the model at its current state. Please avoid path starting with "last-" or - "best-" to avoid collision with Pytorch-Ligthning checkpoints. If no path is specified, the model - is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt"``. - **kwargs - Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a - description of the model being exported to stdout. - For more information, read the `official documentation `_. - """ - raise_if_not( - self._fit_called, "`fit()` needs to be called before `to_onnx()`.", logger - ) - - if path is None: - # default path - path = self._default_save_path() + ".onnx" - - # mimic preprocessing performed by RNNModule_._get_batch_prediction() - ( - past_target, - past_covariates, - historic_future_covariates, - future_covariates, - future_past_covariates, - ) = ( - torch.Tensor(x).unsqueeze(0) if x is not None else None - for x in self.train_sample - ) - - if historic_future_covariates is not None: - # RNNs need as inputs (target[t] and covariates[t+1]) so here we shift the covariates - all_covariates = torch.cat( - [historic_future_covariates[:, 1:, :], future_covariates], dim=1 - ) - cov_past, _ = ( - all_covariates[:, : past_target.shape[1], :], - all_covariates[:, past_target.shape[1] :, :], - ) - input_past = torch.cat([past_target, cov_past], dim=2) - else: - input_past = past_target - - input_sample = [ - input_past.double(), - future_covariates.double() if future_covariates is not None else None, - ] - self.model.to_onnx(file_path=path, input_sample=input_sample, **kwargs) - @property def supports_multivariate(self) -> bool: return True diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index f5c236c929..274c12b7e9 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -2621,13 +2621,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset): def _verify_predict_sample(self, predict_sample: tuple): _basic_compare_sample(self.train_sample, predict_sample) - def _verify_past_future_covariates(self, past_covariates, future_covariates): - raise_if_not( - future_covariates is None, - "Some future_covariates have been provided to a PastCovariates model. These models " - "support only past_covariates.", - ) - @property def _model_encoder_settings( self, @@ -2721,13 +2714,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset): def _verify_predict_sample(self, predict_sample: tuple): _basic_compare_sample(self.train_sample, predict_sample) - def _verify_past_future_covariates(self, past_covariates, future_covariates): - raise_if_not( - past_covariates is None, - "Some past_covariates have been provided to a PastCovariates model. These models " - "support only future_covariates.", - ) - @property def _model_encoder_settings( self, @@ -2822,13 +2808,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset): def _verify_predict_sample(self, predict_sample: tuple): _basic_compare_sample(self.train_sample, predict_sample) - def _verify_past_future_covariates(self, past_covariates, future_covariates): - raise_if_not( - past_covariates is None, - "Some past_covariates have been provided to a DualCovariates Torch model. These models " - "support only future_covariates.", - ) - @property def _model_encoder_settings( self, @@ -2923,10 +2902,6 @@ def _verify_inference_dataset_type(self, inference_dataset: InferenceDataset): def _verify_predict_sample(self, predict_sample: tuple): _mixed_compare_sample(self.train_sample, predict_sample) - def _verify_past_future_covariates(self, past_covariates, future_covariates): - # both covariates are supported; do nothing - pass - @property def _model_encoder_settings( self, From bb766296f4510eec92b02167e9e71a86b480a6f0 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 10:47:48 +0100 Subject: [PATCH 7/9] fix: export to onnx for RNNModel --- darts/models/forecasting/rnn_model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index e7d55ac9c6..d51b0a3838 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -104,6 +104,12 @@ def forward( pass def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: + # only return the forecast, not the hidden state + return self(self._process_input_batch(input_batch))[0] + + def _process_input_batch( + self, input_batch: tuple + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ( past_target, historic_future_covariates, @@ -112,7 +118,7 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ) = input_batch # For the RNN we concatenate the past_target with the future_covariates # (they have the same length because we enforce a Shift dataset for RNNs) - model_input = ( + return ( ( torch.cat([past_target, future_covariates], dim=2) if future_covariates is not None @@ -120,7 +126,6 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor: ), static_covariates, ) - return self(model_input)[0] def _produce_predict_output( self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None From 22ea061258c55b377f80ede90fa57335105b10bc Mon Sep 17 00:00:00 2001 From: madtoinou Date: Tue, 17 Dec 2024 11:10:52 +0100 Subject: [PATCH 8/9] feat: added a comment about RNNModel for onnx inference --- docs/userguide/torch_forecasting_models.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index cb3fa59d06..565c50b552 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -378,12 +378,6 @@ def prepare_onnx_inputs( ) -> tuple[Optional[np.ndarray]]: """Helper function to slice and concatenate the input features""" past_feats, future_feats, static_feats = None, None, None - if forecast_start_position > 0: - raise_log( - ValueError("`forecast_start_position` must be <= 0"), - logger=logger - ) - # convert and concatenate the historic features (target, past and future covariates) past_feats = series.values()[-model.input_chunk_length:] if past_covariates: @@ -410,8 +404,6 @@ def prepare_onnx_inputs( future_feats = np.expand_dims(future_covariates.values()[ len(series):len(series)+model.output_chunk_length ], axis=0) - else: - future_feats = None # convert static covariates if series.has_static_covariates: @@ -438,7 +430,7 @@ ort_inputs = { ort_outs = ort_session.run(None, ort_inputs) ``` -Note that the forecasts might be slightly different due to rounding errors. +Note that the forecasts might be slightly different due to rounding errors. Also, due to its specificities, `RNNModel` requires different pre-processing of the series to obtain the input arrays (notably because of `training_length`). ### Callbacks From 2e1b9aff5345c05156407075877ce4b84f40bec6 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 30 Dec 2024 12:14:59 +0100 Subject: [PATCH 9/9] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ad1b44168..9481897187 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** - New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl). +- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou) **Fixed** @@ -47,7 +48,6 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Interval Coverage `ic()` (binary if observation is within the quantile interval), and Mean Interval Coverage `mic()` (time-aggregated) - Interval Non-Conformity Score for Quantile Regression `incs_qr()`, and Mean ... `mincs_qr()` (time-aggregated) ([source](https://arxiv.org/pdf/1905.03222)) - Added `series_idx` argument to `DataTransformer` that allows users to use only a subset of the transformers when `global_fit=False` and severals series are used. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) -- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou) - Updated the Documentation URL of `Statsforecast` models. [#2610](https://github.com/unit8co/darts/pull/2610) by [He Weilin](https://github.com/cnhwl). **Fixed**