diff --git a/python/text_utils/modules/optimizer.py b/python/text_utils/modules/optimizer.py
index adeb642..5426994 100644
--- a/python/text_utils/modules/optimizer.py
+++ b/python/text_utils/modules/optimizer.py
@@ -5,8 +5,7 @@
 
 
 def _select_params_and_modules(
-    modules: Iterator[Tuple[str, nn.Module]],
-    prefix: str
+    modules: Iterator[Tuple[str, nn.Module]], prefix: str
 ) -> Iterator[Tuple[str, nn.Module, nn.Parameter]]:
     for name, mod in modules:
         for p_name, param in mod.named_parameters(prefix=name, recurse=False):
@@ -17,10 +16,9 @@ def _select_params_and_modules(
 def optimizer_from_config(
     model: nn.Module,
     cfg: Dict[str, Any],
-    additional_optimizer_fn: Optional[Callable[
-        [nn.Module, Dict[str, Any]],
-        optim.Optimizer
-    ]] = None
+    additional_optimizer_fn: Optional[
+        Callable[[nn.Module, Dict[str, Any]], optim.Optimizer]
+    ] = None,
 ) -> optim.Optimizer:
     cfg = copy.deepcopy(cfg)
     opt_type = cfg.pop("type")
@@ -40,8 +38,7 @@ def optimizer_from_config(
     assert len(param_groups) > 0, "param_groups must be non-empty"
 
     weight_decay_modules: dict[str, list[str]] | str = cfg.pop(
-        "weight_decay_modules",
-        "all"
+        "weight_decay_modules", "all"
     )
     all = set(name for name, p in model.named_parameters() if p.requires_grad)
     params = []
@@ -54,28 +51,25 @@ def optimizer_from_config(
         decay = set()
         param_dict = {}
         for name, mod, param in _select_params_and_modules(
-            model.named_modules(),
-            prefix
+            model.named_modules(), prefix
         ):
             if name not in all:
                 # this should only happen for shared
                 # or non-trainable parameters
                 continue
+
             if fix:
                 param.requires_grad = False
                 continue
+
             names.add(name)
             param_dict[name] = param
             mod_name = mod.__class__.__name__
-            if (
-                weight_decay_modules == "all"
-                or (
-                    isinstance(weight_decay_modules, dict)
-                    and mod_name in weight_decay_modules
-                    and any(
-                        name.endswith(suffix)
-                        for suffix in weight_decay_modules[mod_name]
-                    )
+            if weight_decay_modules == "all" or (
+                isinstance(weight_decay_modules, dict)
+                and mod_name in weight_decay_modules
+                and any(
+                    name.endswith(suffix) for suffix in weight_decay_modules[mod_name]
                 )
             ):
                 decay.add(name)
@@ -86,18 +80,23 @@ def optimizer_from_config(
 
         assert len(param_dict.keys() - (decay | no_decay)) == 0
         if len(decay) > 0:
-            params.append({
-                "params": [param_dict[name] for name in sorted(list(decay))],
-                **(cfg | group)
-            })
+            params.append(
+                {
+                    "params": [param_dict[name] for name in sorted(list(decay))],
+                    **(cfg | group),
+                }
+            )
         if len(no_decay) > 0:
-            params.append({
-                "params": [param_dict[name] for name in sorted(list(no_decay))],
-                **(cfg | group | {"weight_decay": 0.0})
-            })
+            params.append(
+                {
+                    "params": [param_dict[name] for name in sorted(list(no_decay))],
+                    **(cfg | group | {"weight_decay": 0.0}),
+                }
+            )
 
     unused = all - names
-    assert len(unused) == 0, \
-        f"parameter groups dont match trainable model parameters: {unused}"
+    assert (
+        len(unused) == 0
+    ), f"parameter groups dont match trainable model parameters: {unused}"
 
     return optim_cls(params, **cfg)
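
The hunks above reformat optimizer_from_config (line wrapping, trailing commas, blank lines) without changing behaviour. For context, here is a minimal usage sketch of the function as it reads after this change. The import path follows the file location in this diff; the concrete config keys and values below (the "type" string, the "param_groups" entries, the learning rate and weight decay numbers) are illustrative assumptions inferred from the hunks, not something this diff defines.

    # Hypothetical usage sketch -- not part of this diff. Keys other than
    # "type" and "weight_decay_modules" (e.g. "param_groups", "prefix", "fix")
    # are inferred from the hunks above; the exact accepted values are assumptions.
    import torch.nn as nn

    from text_utils.modules.optimizer import optimizer_from_config

    model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

    cfg = {
        "type": "adamw",      # popped first and resolved to an optimizer class (not shown in this diff)
        "lr": 3e-4,
        "weight_decay": 0.01,
        # dict form: decay only parameters of the listed module classes whose
        # names end with one of the given suffixes; all other parameters land
        # in a group with weight_decay=0.0. The string "all" decays everything.
        "weight_decay_modules": {"Linear": ["weight"]},
        # one entry per parameter group; per-group settings are merged over the
        # top-level ones via cfg | group, as in the hunks above
        "param_groups": [
            {"prefix": ""},   # a single group covering all trainable parameters
        ],
    }

    optimizer = optimizer_from_config(model, cfg)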