forked from oshindow/Transformer-Transducer
optim.py
import torch.optim as optim


class Optimizer(object):
    """Thin wrapper around a torch optimizer that tracks the global step,
    the current epoch, and a manually decayed learning rate."""

    def __init__(self, parameters, config):
        self.config = config
        self.optimizer = build_optimizer(parameters, config)
        self.global_step = 1
        self.current_epoch = 0
        self.lr = config.lr
        self.decay_ratio = config.decay_ratio
        self.epoch_decay_flag = False

    def step(self):
        self.global_step += 1
        self.optimizer.step()

    def epoch(self):
        self.current_epoch += 1

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def decay_lr(self):
        # Multiply the tracked learning rate by decay_ratio and push the new
        # value into every parameter group of the wrapped optimizer.
        self.lr *= self.decay_ratio
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
def build_optimizer(parameters, config):
    """Construct the torch optimizer selected by config.type."""
    if config.type == 'adam':
        return optim.Adam(
            parameters,
            lr=config.lr,
            betas=(0.9, 0.98),
            eps=1e-08,
            weight_decay=config.weight_decay
        )
    elif config.type == 'sgd':
        return optim.SGD(
            params=parameters,
            lr=config.lr,
            momentum=config.momentum,
            nesterov=config.nesterov,
            weight_decay=config.weight_decay
        )
    elif config.type == 'adadelta':
        # Adadelta currently runs with torch defaults; the config-driven
        # hyperparameters remain commented out as in the original source.
        return optim.Adadelta(
            params=parameters
            # lr=config.lr,
            # rho=config.rho,
            # eps=config.eps,
            # weight_decay=config.weight_decay
        )
    else:
        raise NotImplementedError(config.type)
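

# Usage sketch (not part of the original module): the config object only needs
# the attributes read above (type, lr, weight_decay, decay_ratio for Adam).
# A SimpleNamespace stands in here for whatever config class the training
# script actually builds; the model and shapes are illustrative.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 4)
    config = SimpleNamespace(type='adam', lr=1e-3,
                             weight_decay=1e-5, decay_ratio=0.5)
    optimizer = Optimizer(model.parameters(), config)

    loss = model(torch.randn(2, 8)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()      # advances global_step and the wrapped optimizer
    optimizer.epoch()     # bump the epoch counter at the end of an epoch
    optimizer.decay_lr()  # multiply the learning rate by decay_ratio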