lr_scheduler.py
import torch


class NoamOpt:
    """
    Copied from https://nlp.seas.harvard.edu/2018/04/03/attention.html#hardware-and-schedule
    A wrapper class for the Adam optimizer (or others) that implements the Noam
    learning rate schedule.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """
        Update the learning rate, then step the wrapped optimizer.
        """
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """
        Noam schedule:
        lrate = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
        """
        if step is None:
            step = self._step
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )


def get_std_opt(model):
    """Wrap Adam in a NoamOpt with factor=2 and warmup=4000, keyed to the encoder's hidden size."""
    return NoamOpt(
        model.encoder.hidden_dim,
        2,
        4000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
    )
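
A minimal usage sketch, not part of the original file: it assumes only that the model exposes `encoder.hidden_dim`, which is what `get_std_opt` reads; the `TinyEncoder`/`TinyModel` classes below are hypothetical stand-ins for illustration.

# --- usage sketch (assumptions: a toy model exposing encoder.hidden_dim) ---
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, hidden_dim=512):
        super().__init__()
        self.hidden_dim = hidden_dim          # get_std_opt reads this attribute
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.proj(x)

class TinyModel(nn.Module):
    def __init__(self, hidden_dim=512):
        super().__init__()
        self.encoder = TinyEncoder(hidden_dim)

    def forward(self, x):
        return self.encoder(x).mean()         # scalar "loss" for the example

model = TinyModel()
opt = get_std_opt(model)                      # Adam wrapped in the Noam schedule

x = torch.randn(8, 512)
for _ in range(3):
    opt.optimizer.zero_grad()                 # zero grads on the wrapped optimizer
    loss = model(x)
    loss.backward()
    opt.step()                                # sets the scheduled lr, then calls Adam's step()

print(opt.rate())                             # current learning rate after 3 steps

During the first `warmup` steps the rate grows roughly linearly, then decays proportionally to step^(-0.5), which is why the wrapped Adam is constructed with lr=0 and the schedule overwrites it on every step.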