Skip to content

Commit

Permalink
[FEAT] support providing DataLoader arguments to optimize GPU usage (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jasminerienecker authored Nov 8, 2024
1 parent d79cecf commit 5e3ad97
Show file tree
Hide file tree
Showing 75 changed files with 287 additions and 889 deletions.
14 changes: 13 additions & 1 deletion nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,25 @@
" datamodule_constructor = TimeSeriesDataModule\n",
" else:\n",
" datamodule_constructor = _DistributedTimeSeriesDataModule\n",
" \n",
" dataloader_kwargs = self.dataloader_kwargs if self.dataloader_kwargs is not None else {}\n",
" \n",
" if self.num_workers_loader != 0: # value is not at its default\n",
" warnings.warn(\n",
" \"The `num_workers_loader` argument is deprecated and will be removed in a future version. \"\n",
" \"Please provide num_workers through `dataloader_kwargs`, e.g. \"\n",
" f\"`dataloader_kwargs={{'num_workers': {self.num_workers_loader}}}`\",\n",
" category=FutureWarning,\n",
" )\n",
" dataloader_kwargs['num_workers'] = self.num_workers_loader\n",
"\n",
" datamodule = datamodule_constructor(\n",
" dataset=dataset, \n",
" batch_size=batch_size,\n",
" valid_batch_size=valid_batch_size,\n",
" num_workers=self.num_workers_loader,\n",
" drop_last=self.drop_last_loader,\n",
" shuffle_train=shuffle_train,\n",
" **dataloader_kwargs\n",
" )\n",
"\n",
" if self.val_check_steps > self.max_steps:\n",
Expand Down
2 changes: 2 additions & 0 deletions nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
Expand Down Expand Up @@ -173,6 +174,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
Expand Down
4 changes: 3 additions & 1 deletion nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
Expand Down Expand Up @@ -172,6 +173,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
Expand Down Expand Up @@ -536,7 +538,7 @@
" self._check_exog(dataset)\n",
" self._restart_seed(random_seed)\n",
" data_module_kwargs = self._set_quantile_for_iqloss(**data_module_kwargs)\n",
"\n",
" \n",
" if step_size > 1:\n",
" raise Exception('Recurrent models do not support step_size > 1')\n",
"\n",
Expand Down
2 changes: 2 additions & 0 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
Expand Down Expand Up @@ -188,6 +189,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
Expand Down
9 changes: 7 additions & 2 deletions nbs/docs/tutorials/18_adding_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@
" step_size: int = 1,\n",
" scaler_type: str = 'identity',\n",
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" **trainer_kwargs):\n",
" # Inherit BaseWindows class\n",
Expand Down Expand Up @@ -415,7 +414,13 @@
]
}
],
"metadata": {},
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions nbs/models.autoformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
Expand Down Expand Up @@ -511,6 +512,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(Autoformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand Down Expand Up @@ -539,6 +541,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.bitcn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br> \n",
Expand Down Expand Up @@ -224,6 +225,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BiTCN, self).__init__(\n",
" h=h,\n",
Expand Down Expand Up @@ -253,6 +255,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.deepar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br>\n",
Expand Down Expand Up @@ -234,6 +235,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" if exclude_insample_y:\n",
Expand Down Expand Up @@ -276,6 +278,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.horizon_backup = self.h # Used because h=0 during training\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.deepnpts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br>\n",
Expand Down Expand Up @@ -169,6 +170,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" if exclude_insample_y:\n",
Expand Down Expand Up @@ -208,6 +210,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.h = h\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.dilated_rnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br> \n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
" \"\"\"\n",
" # Class attributes\n",
Expand Down Expand Up @@ -433,6 +434,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DilatedRNN, self).__init__(\n",
" h=h,\n",
Expand All @@ -458,6 +460,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.dlinear.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
Expand Down Expand Up @@ -206,6 +207,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(DLinear, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand Down Expand Up @@ -234,6 +236,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
" \n",
" # Architecture\n",
Expand Down
5 changes: 4 additions & 1 deletion nbs/models.fedformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
" \"\"\"\n",
Expand Down Expand Up @@ -503,6 +504,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(FEDformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand All @@ -529,7 +531,8 @@
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs, \n",
" **trainer_kwargs)\n",
" # Architecture\n",
" self.label_len = int(np.ceil(input_size * decoder_input_size_multiplier))\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.gru.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
" \"\"\"\n",
" # Class attributes\n",
Expand Down Expand Up @@ -168,6 +169,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(GRU, self).__init__(\n",
" h=h,\n",
Expand All @@ -193,6 +195,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
Expand Down
3 changes: 3 additions & 0 deletions nbs/models.informer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
Expand Down Expand Up @@ -359,6 +360,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(Informer, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand Down Expand Up @@ -387,6 +389,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
Expand Down
Loading

0 comments on commit 5e3ad97

Please sign in to comment.