From f387b4982bbb77a9323c3a99b35b706685c6ff9b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niklas=20B=C3=B6hm?=
Date: Fri, 5 Jan 2024 12:44:05 +0100
Subject: [PATCH] Set default values for gradient clipping

---
 tsimcne/tsimcne.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py
index 1047042..1fe8385 100644
--- a/tsimcne/tsimcne.py
+++ b/tsimcne/tsimcne.py
@@ -245,7 +245,13 @@ class TSimCNE:
         about it.
 
     :param dict | None trainer_kwargs: The keyword arguments to pass
-        to the Trainer, to use during training.
+        to the Trainer, to use during training.  By default the keys
+        ``gradient_clip_val=4`` and
+        ``gradient_clip_algorithm="value"`` will be set, but can be
+        overridden by passing in a custom dict.  The values will be
+        set regardless of whether you pass in a ``dict`` or not, so if
+        you want to disable gradient clipping you need to override the
+        values.
 
     :param int=8 num_workers: The number of workers for creating the
         dataloader.  Will be passed to the pytorch DataLoader
@@ -377,8 +383,15 @@ def _handle_parameters(self):
                 "for how to set the learning rate when using multiple devices"
             )
 
+        trainer_kwargs = dict(
+            gradient_clip_val=4, gradient_clip_algorithm="value"
+        )
         if self.trainer_kwargs is None:
-            self.trainer_kwargs = dict()
+            self.trainer_kwargs = trainer_kwargs
+        else:
+            # dict.update() mutates in place and returns None, so merge
+            # via unpacking; user-supplied keys override the defaults.
+            self.trainer_kwargs = {**trainer_kwargs, **self.trainer_kwargs}
 
     @staticmethod
     def check_ffcv(use_ffcv):