diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py
index 1047042..1fe8385 100644
--- a/tsimcne/tsimcne.py
+++ b/tsimcne/tsimcne.py
@@ -245,7 +245,13 @@ class TSimCNE:
         about it.
 
     :param dict | None trainer_kwargs: The keyword arguments to pass
-        to the Trainer, to use during training.
+        to the Trainer, to use during training.  By default the keys
+        ``gradient_clip_val=4`` and
+        ``gradient_clip_algorithm="value"`` will be set, but can be
+        overridden by passing in a custom dict.  The values will be
+        set regardless of whether you pass in a ``dict`` or not, so if
+        you want to disable gradient clipping you need to override the
+        values.
 
     :param int=8 num_workers: The number of workers for creating the
         dataloader.  Will be passed to the pytorch DataLoader
@@ -377,8 +383,13 @@ def _handle_parameters(self):
                 "for how to set the learning rate when using multiple devices"
             )
 
+        # Gradient-clipping defaults; user-supplied keys take precedence.
+        trainer_kwargs = dict(
+            gradient_clip_val=4, gradient_clip_algorithm="value"
+        )
         if self.trainer_kwargs is None:
-            self.trainer_kwargs = dict()
+            self.trainer_kwargs = trainer_kwargs
+            # NOTE: dict.update returns None, so merge via unpacking instead.
+            self.trainer_kwargs = {**trainer_kwargs, **self.trainer_kwargs}
 
     @staticmethod
     def check_ffcv(use_ffcv):