Set default values for gradient clipping
Niklas Böhm committed Jan 5, 2024
1 parent f17d1d0 commit f387b49
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tsimcne/tsimcne.py
@@ -245,7 +245,13 @@ class TSimCNE:
          about it.
      :param dict | None trainer_kwargs: The keyword arguments to pass
-         to the Trainer, to use during training.
+         to the Trainer, to use during training.  By default
+         ``gradient_clip_val=4`` and
+         ``gradient_clip_algorithm="value"`` are set.  These
+         defaults are merged into any dict you pass, so to disable
+         gradient clipping you need to override these keys
+         explicitly.
      :param int=8 num_workers: The number of workers for creating the
          dataloader. Will be passed to the pytorch DataLoader
@@ -377,8 +383,13 @@ def _handle_parameters(self):
                "for how to set the learning rate when using multiple devices"
            )

+        trainer_kwargs = dict(
+            gradient_clip_val=4, gradient_clip_algorithm="value"
+        )
         if self.trainer_kwargs is None:
-            self.trainer_kwargs = dict()
+            self.trainer_kwargs = trainer_kwargs
+        else:
+            # dict.update() returns None, so merge instead of assigning
+            # its return value; user-supplied keys take precedence.
+            self.trainer_kwargs = {**trainer_kwargs, **self.trainer_kwargs}

@staticmethod
def check_ffcv(use_ffcv):
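The merge in `_handle_parameters` means user-supplied keys win over the new defaults, while untouched defaults are always applied. A minimal sketch of that behavior, standalone rather than as a method (the helper name `merge_trainer_kwargs` is hypothetical, for illustration only):

```python
def merge_trainer_kwargs(user_kwargs):
    """Merge user-supplied Trainer kwargs over the clipping defaults.

    Hypothetical helper mirroring the merge logic in _handle_parameters.
    """
    # Defaults introduced by this commit.
    defaults = dict(gradient_clip_val=4, gradient_clip_algorithm="value")
    if user_kwargs is None:
        return defaults
    # User keys take precedence; defaults fill in the rest.
    return {**defaults, **user_kwargs}


# No kwargs passed: the clipping defaults apply.
assert merge_trainer_kwargs(None) == {
    "gradient_clip_val": 4,
    "gradient_clip_algorithm": "value",
}

# Overriding gradient_clip_val disables clipping (PyTorch Lightning
# performs no clipping when gradient_clip_val is None).
assert merge_trainer_kwargs({"gradient_clip_val": None}) == {
    "gradient_clip_val": None,
    "gradient_clip_algorithm": "value",
}
```

Note that `dict.update()` mutates in place and returns `None`, which is why the merge must not assign its return value; the `{**a, **b}` expression avoids that pitfall entirely.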
