diff --git a/nbs/common.base_model.ipynb b/nbs/common.base_model.ipynb
index 0f1daeeb5..8334c759c 100644
--- a/nbs/common.base_model.ipynb
+++ b/nbs/common.base_model.ipynb
@@ -358,13 +358,15 @@
     "            datamodule_constructor = TimeSeriesDataModule\n",
     "        else:\n",
     "            datamodule_constructor = _DistributedTimeSeriesDataModule\n",
+    "        \n",
+    "        dataloader_kwargs = self.dataloader_kwargs if self.dataloader_kwargs is not None else {}\n",
     "        datamodule = datamodule_constructor(\n",
     "            dataset=dataset, \n",
     "            batch_size=batch_size,\n",
     "            valid_batch_size=valid_batch_size,\n",
     "            drop_last=self.drop_last_loader,\n",
     "            shuffle_train=shuffle_train,\n",
-    "            **self.dataloader_kwargs\n",
+    "            **dataloader_kwargs\n",
     "        )\n",
     "\n",
     "        if self.val_check_steps > self.max_steps:\n",
diff --git a/neuralforecast/common/_base_model.py b/neuralforecast/common/_base_model.py
index b17a90efa..fdad6184f 100644
--- a/neuralforecast/common/_base_model.py
+++ b/neuralforecast/common/_base_model.py
@@ -332,13 +332,17 @@ def _fit(
             datamodule_constructor = TimeSeriesDataModule
         else:
             datamodule_constructor = _DistributedTimeSeriesDataModule
+
+        dataloader_kwargs = (
+            self.dataloader_kwargs if self.dataloader_kwargs is not None else {}
+        )
         datamodule = datamodule_constructor(
            dataset=dataset,
             batch_size=batch_size,
             valid_batch_size=valid_batch_size,
             drop_last=self.drop_last_loader,
             shuffle_train=shuffle_train,
-            **self.dataloader_kwargs,
+            **dataloader_kwargs,
         )
 
         if self.val_check_steps > self.max_steps:
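
Note on the change: the patch guards against `self.dataloader_kwargs` being `None` before unpacking it with `**`, since unpacking `None` raises a `TypeError`. A minimal standalone sketch of the failure mode and of the coalescing pattern applied here; `build_datamodule` is a hypothetical stand-in for the datamodule constructor and is not part of neuralforecast:

# Sketch of the None-coalescing pattern used in the diff above.
def build_datamodule(batch_size, **dataloader_kwargs):
    # Stand-in for datamodule_constructor(...): just collects its arguments.
    return {"batch_size": batch_size, **dataloader_kwargs}

dataloader_kwargs = None  # default when the user passes no extra loader options

# build_datamodule(8, **dataloader_kwargs) would raise:
#   TypeError: ... argument after ** must be a mapping, not NoneType

# Coalesce to an empty dict first, as the patch does:
dataloader_kwargs = dataloader_kwargs if dataloader_kwargs is not None else {}
dm = build_datamodule(8, **dataloader_kwargs)
print(dm)  # {'batch_size': 8}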