Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jasminerienecker committed Oct 22, 2024
1 parent 9ea94b7 commit 87bee95
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 3 additions & 1 deletion nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,15 @@
" datamodule_constructor = TimeSeriesDataModule\n",
" else:\n",
" datamodule_constructor = _DistributedTimeSeriesDataModule\n",
" \n",
" dataloader_kwargs = self.dataloader_kwargs if self.dataloader_kwargs is not None else {}\n",
" datamodule = datamodule_constructor(\n",
" dataset=dataset, \n",
" batch_size=batch_size,\n",
" valid_batch_size=valid_batch_size,\n",
" drop_last=self.drop_last_loader,\n",
" shuffle_train=shuffle_train,\n",
" **self.dataloader_kwargs\n",
" **dataloader_kwargs\n",
" )\n",
"\n",
" if self.val_check_steps > self.max_steps:\n",
Expand Down
6 changes: 5 additions & 1 deletion neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,17 @@ def _fit(
datamodule_constructor = TimeSeriesDataModule
else:
datamodule_constructor = _DistributedTimeSeriesDataModule

dataloader_kwargs = (
self.dataloader_kwargs if self.dataloader_kwargs is not None else {}
)
datamodule = datamodule_constructor(
dataset=dataset,
batch_size=batch_size,
valid_batch_size=valid_batch_size,
drop_last=self.drop_last_loader,
shuffle_train=shuffle_train,
**self.dataloader_kwargs,
**dataloader_kwargs,
)

if self.val_check_steps > self.max_steps:
Expand Down

0 comments on commit 87bee95

Please sign in to comment.