From d60ef8702b3a9520c43de3b04306bdbd9bebbf4b Mon Sep 17 00:00:00 2001 From: madtoinou <32447896+madtoinou@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:50:38 +0200 Subject: [PATCH] Fix/deprec nn (#2593) * Fix deprecated usage of torch.nn.utils.weight_norm The previous implementation in darts.darts.models.forecasting.tcn_mode was using `torch.nn.utils.weight_norm`, which is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`. This commit replaces two occurrences of `torch.nn.utils.weight_norm` with the recommended `torch.nn.utils.parametrizations.weight_norm` to resolve the deprecation warning. * Update torch_forecasting_model.py Corrected file saving process for checkpoint files (ckpt) to filter out occurrences of the string '.pt' from the previous file path." * fix: revert changes * update changelog --------- Co-authored-by: Saeed Foroutan --- CHANGELOG.md | 3 ++- darts/models/forecasting/tcn_model.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66f6f90538..4bf257e993 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - fixed failing docker deployment - removed `gradle` dependency in favor of native GitHub action plugins. - Updated ruff to v0.7.2 and target-version to python39, also fixed various typos [#2589](https://github.com/unit8co/darts/pull/2589) by [Greg DeVosNouri](https://github.com/gdevos010) and [Antoine Madrona](https://github.com/madtoinou). +- Replaced the deprecated `torch.nn.utils.weight_norm` function with `torch.nn.utils.parametrizations.weight_norm` [#2593](https://github.com/unit8co/darts/pull/2593) by [Saeed Foroutan](https://github.com/SaeedForoutan). ## [0.31.0](https://github.com/unit8co/darts/tree/0.31.0) (2024-10-13) @@ -40,7 +41,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Improvements to `metrics`: - Added support for computing metrics on one or multiple quantiles `q`, either from probabilistic or quantile forecasts. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader). - - Added quantile interval metrics `miw` (Mean Interval Width, time aggregated) and `iw` (Interval Width, per time step / non-aggregated) which compute the width of quantile intervals `q_intervals` (expected to be a tuple or sequence of tuples with (lower quantile, upper quantile). [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader). + - Added quantile interval metrics `miw` (Mean Interval Width, time aggregated) and `iw` (Interval Width, per time step / non-aggregated) which compute the width of quantile intervals `q_intervals` (expected to be a tuple or sequence of tuples with (lower quantile, upper quantile)). [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader). - Improvements to `backtest()` and `residuals()`: - Added support for computing backtest and residuals on one or multiple quantiles `q` in the `metric_kwargs`, either from probabilistic or quantile forecasts. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader). - Added support for parameters `enable_optimization` and `predict_likelihood_parameters`. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader). diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index a8d7d93e9a..3d66b4613d 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -99,8 +99,8 @@ def __init__( ) if weight_norm: self.conv1, self.conv2 = ( - nn.utils.weight_norm(self.conv1), - nn.utils.weight_norm(self.conv2), + nn.utils.parametrizations.weight_norm(self.conv1), + nn.utils.parametrizations.weight_norm(self.conv2), ) if input_dim != output_dim: