optimizers.py
import logging


def get_optimizer_parameters(config, module):
    lr = config["params"]["lr"]
    finetune_lr_multiplier = config["params"].get("finetune_lr_multiplier", 1)
    logging.info(f"Using lr {lr} finetune_lr_multiplier {finetune_lr_multiplier}")

    parameters = []
    for name, submodule in module.named_children():
        # Copy the base optimizer settings so per-submodule overrides do not
        # mutate the shared config dict across iterations.
        submodule_parameters = dict(config["params"])
        if name.startswith("attn_layer"):
            # Attention layers get their own learning rate.
            submodule_parameters["lr"] = config["params"]["attention_lr"]
        elif name.startswith("classifier") or name.startswith("beta"):
            # Classifier/beta heads train at the base learning rate.
            submodule_parameters["lr"] = lr
        else:
            # Everything else is fine-tuned at a scaled learning rate;
            # a multiplier of 0 freezes those submodules entirely.
            submodule_parameters["lr"] = lr * finetune_lr_multiplier
            if finetune_lr_multiplier == 0:
                for p in submodule.parameters():
                    p.requires_grad = False
        submodule_parameters = {"params": submodule.parameters(), **submodule_parameters}
        parameters.append(submodule_parameters)
    return parameters
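

# A minimal usage sketch (not part of the original file): the returned list of
# per-submodule dicts follows the parameter-group format accepted by
# torch.optim optimizers. The module layout and config values below are
# assumptions chosen only so the child names match the prefixes checked above.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    # Hypothetical model whose child names trigger each branch.
    model = nn.Sequential()
    model.add_module("attn_layer_0", nn.Linear(16, 16))
    model.add_module("classifier", nn.Linear(16, 4))
    model.add_module("backbone", nn.Linear(16, 16))

    # Hypothetical config; only the keys read by get_optimizer_parameters matter here.
    config = {
        "params": {
            "lr": 1e-3,
            "attention_lr": 5e-4,
            "finetune_lr_multiplier": 0.1,
        }
    }

    param_groups = get_optimizer_parameters(config, model)
    optimizer = torch.optim.Adam(param_groups)
    for group in optimizer.param_groups:
        print(group["lr"])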