What happened + What you expected to happen
I am using the NHITS model inside a custom wrapper and have problems when trying to set max_steps.
The trainer_kwargs are initialized correctly during model instantiation. However, during fitting (when the forward method is called), the trainer object reverts to its default settings and ignores the user-specified trainer_kwargs.
I have run into the same problem when setting early stopping criteria or validation check steps: whatever parameters I set, every run completes 100 epochs with no early stopping callbacks.
This does not happen when NHITS is used through the NeuralForecast wrapper (NeuralForecast(models=[NHITS(...)], freq=...)), where the trainer_kwargs are passed and applied correctly; unfortunately, that approach is not suitable for my situation.
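For comparison, a minimal sketch of the wrapper-based path that does honor the settings (the series, h, input_size, and max_steps values here are illustrative; my real data comes from the surrounding framework):

```python
import numpy as np
import pandas as pd
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS

# Toy series in the long format NeuralForecast expects.
Y_df = pd.DataFrame({
    "unique_id": "series_1",
    "ds": pd.date_range("2020-01-01", periods=200, freq="D"),
    "y": np.random.randn(200),
})

# Through the wrapper, max_steps=5 is respected and training stops after 5 steps.
nf = NeuralForecast(models=[NHITS(h=12, input_size=24, max_steps=5)], freq="D")
nf.fit(df=Y_df)
```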
Has anyone else run into this problem? Is there a workaround?
Versions / Dependencies
neuralforecast==1.7.5
pytorch-lightning==2.4.0
python 3.10.15
Reproduction script
A sketch of the methods involved is below; this snippet alone does not reproduce the issue.
The code is part of a larger framework that extends the Darts Python library.
Model fitting completes without errors, but whatever kwargs I pass to the trainer, the forward method resets them to their defaults.
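The relevant methods, stripped down to a minimal sketch (the class name, the EarlyStopping arguments, and the dataset handling are placeholders, not the framework's real API):

```python
from pytorch_lightning.callbacks import EarlyStopping
from neuralforecast.models import NHITS


class CustomNHITSWrapper:
    """Placeholder for the Darts-style wrapper in our framework."""

    def __init__(self, h: int, input_size: int, max_steps: int):
        self.model = NHITS(h=h, input_size=input_size, max_steps=max_steps)
        # The kwargs below mirror the log output further down;
        # the EarlyStopping arguments are illustrative.
        self.model.trainer_kwargs.update(
            callbacks=[EarlyStopping(monitor="train_loss", patience=3)],
            accelerator="gpu",
            devices=-1,
            enable_checkpointing=False,
        )

    def fit(self, dataset):
        print(f"{self.model.trainer_kwargs = }")  # before fit: correct values
        self.model.fit(dataset=dataset)
        print(f"{self.model.trainer_kwargs = }")  # after fit: kwargs still intact...
        print(f"{self.model.trainer = }")
        print(f"{self.model.trainer.max_steps = }")  # ...but the trainer ignored them
        print(f"{self.model.trainer.early_stopping_callbacks = }")
```

The prints above correspond to the debug output below: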
self.model.trainer_kwargs = {'max_steps': 5, 'callbacks': [<pytorch_lightning.callbacks.early_stopping.EarlyStopping object at 0x7f0c15edc160>], 'accelerator': 'gpu', 'devices': -1, 'enable_checkpointing': False}
self.model.trainer_kwargs = {'max_steps': 5, 'callbacks': [<pytorch_lightning.callbacks.early_stopping.EarlyStopping object at 0x7f0c15edc160>], 'accelerator': 'gpu', 'devices': -1, 'enable_checkpointing': False}
self.model.trainer = <pytorch_lightning.trainer.trainer.Trainer object at 0x7f0bfc1e20b0>
self.model.trainer.max_steps = -1
self.model.trainer.early_stopping_callbacks = None
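For what it's worth, building a Trainer directly from equivalent kwargs honors them, so the kwargs themselves appear valid (minimal sketch; the GPU settings are omitted so it runs anywhere, and the EarlyStopping arguments are again illustrative):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

trainer = pl.Trainer(
    max_steps=5,
    callbacks=[EarlyStopping(monitor="train_loss", patience=3)],
    enable_checkpointing=False,
)
assert trainer.max_steps == 5             # honored here
assert trainer.early_stopping_callbacks   # callback registered
```

So whatever resets the settings seems to happen inside the model's own fit/forward path, not in the Trainer construction.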
Issue Severity
High: It blocks me from completing my task.