NHITS max_steps reset to default during model fitting #1236

Open
elenanicora opened this issue Dec 27, 2024 · 0 comments
What happened + What you expected to happen

I am using the NHITS model inside a custom wrapper and cannot get max_steps to take effect.

The trainer_kwargs are initialized correctly during model instantiation. However, during fitting (when the forward method is called), the trainer object reverts to its default settings, ignoring the user-specified trainer_kwargs. I have hit the same problem when setting early-stopping criteria or validation-check steps: whatever parameters I pass, every run completes 100 epochs and no early-stopping callback fires.

This does not happen when using NHITS through the NeuralForecast wrapper (NeuralForecast(models=[NHITS(...)], freq=...)), where the trainer_kwargs are passed and applied correctly, but unfortunately that approach is not suitable for my situation.

Has anyone run into this before? Is there a workaround?

Versions / Dependencies

  • neuralforecast==1.7.5
  • pytorch-lightning==2.4.0
  • python 3.10.15

Reproduction script

These are the methods involved; this script alone cannot reproduce the issue, since the code is part of a larger framework that extends the Darts Python library.
Model fitting itself runs fine, but whatever kwargs I pass to the trainer are reset to their defaults by the time forward is called.

def _create_model(self) -> NHITS:
    """
    Creates and returns the underlying Nixtla NHITS model.
    """
    model_kwargs = self.kwargs.copy()
    model_kwargs['h'] = self.output_chunk_length
    model_kwargs['input_size'] = self.input_chunk_length
    model_kwargs['loss'] = self.train_criterion
    model_kwargs['n_pool_kernel_size'] = [1, 1, 1]
    model_kwargs['max_steps'] = 5
    model_kwargs['early_stop_patience_steps'] = 3

    self.model = NHITS(**model_kwargs)

    return self.model

Inspecting the model right after instantiation shows the kwargs stored as expected:

self.model.trainer_kwargs = {'max_steps': 5, 'callbacks': [<pytorch_lightning.callbacks.early_stopping.EarlyStopping object at 0x7f0c15edc160>], 'accelerator': 'gpu', 'devices': -1, 'enable_checkpointing': False}

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass through the Nixtla NHITS model.

    Parameters:
    - x: Input tensor of shape [batch_size, input_chunk_length, num_features]
    """

    return self.model(x)

During fitting, however, trainer_kwargs still holds the user values while the trainer itself has been rebuilt with its defaults:

self.model.trainer_kwargs = {'max_steps': 5, 'callbacks': [<pytorch_lightning.callbacks.early_stopping.EarlyStopping object at 0x7f0c15edc160>], 'accelerator': 'gpu', 'devices': -1, 'enable_checkpointing': False}
self.model.trainer = <pytorch_lightning.trainer.trainer.Trainer object at 0x7f0bfc1e20b0>
self.model.trainer.max_steps = -1
self.model.trainer.early_stopping_callbacks = None
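The mismatch in the debug output above can be reproduced in isolation with a toy sketch. These are hypothetical stand-in classes, not the actual neuralforecast or Lightning internals: the point is that if fit() constructs a fresh Trainer without forwarding the stored trainer_kwargs, the user's settings are silently dropped.

```python
class Trainer:
    """Stand-in for pytorch_lightning.Trainer (default max_steps is -1)."""
    def __init__(self, max_steps=-1, callbacks=None):
        self.max_steps = max_steps
        self.callbacks = callbacks or []


class BuggyModel:
    """Stores trainer_kwargs at init time but never forwards them."""
    def __init__(self, **trainer_kwargs):
        self.trainer_kwargs = trainer_kwargs  # stored correctly...

    def fit(self):
        self.trainer = Trainer()              # ...but ignored here


class FixedModel(BuggyModel):
    """Same model, but fit() forwards the stored kwargs to the Trainer."""
    def fit(self):
        self.trainer = Trainer(**self.trainer_kwargs)


buggy = BuggyModel(max_steps=5)
buggy.fit()
print(buggy.trainer.max_steps)  # -1: reverted to the default

fixed = FixedModel(max_steps=5)
fixed.fit()
print(fixed.trainer.max_steps)  # 5: user setting honored
```

This matches the symptom exactly: trainer_kwargs looks correct on the model object while the live Trainer carries max_steps=-1 and no callbacks.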

Issue Severity

High: It blocks me from completing my task.
