-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlr.py
135 lines (113 loc) · 5.19 KB
/
lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#######
## The local version of cyclic lr, which corrects the mistake in the official implementation
#######
import math
import types
import warnings
from bisect import bisect_right
from functools import partial, wraps

import torch
from torch._six import inf
class LocalCyclicLR(torch.optim.lr_scheduler._LRScheduler):
    """Cyclical learning-rate scheduler meant to be stepped once per batch.

    For every parameter group the learning rate ramps linearly from
    ``base_lr`` towards ``max_lr`` during ``step_size_up`` steps and back
    down during ``step_size_down`` steps. The amplitude of each ramp is
    multiplied by a scaling function chosen through ``mode`` or supplied
    as ``scale_fn``. When ``cycle_momentum`` is enabled, the momentum of
    each group is moved in the opposite direction, between
    ``max_momentum`` and ``base_momentum``.

    Args:
        optimizer: Wrapped ``torch.optim.Optimizer``.
        base_lr: Lower learning-rate bound; a scalar, or a list/tuple with
            one value per parameter group.
        max_lr: Upper learning-rate bound; scalar or one value per group.
        step_size_up: Number of steps in the increasing part of a cycle.
            Must be positive.
        step_size_down: Number of steps in the decreasing part of a cycle.
            ``None`` means "same as ``step_size_up``". Must not be negative.
        mode: ``'triangular'`` (constant amplitude), ``'triangular2'``
            (amplitude halved every cycle) or ``'exp_range'`` (amplitude
            multiplied by ``gamma ** iteration``). Only consulted when
            ``scale_fn`` is ``None``.
        gamma: Base of the exponential used by ``'exp_range'``.
        scale_fn: Optional one-argument callable returning the amplitude
            factor. When given it replaces the built-in ``mode`` functions.
        scale_mode: Used only together with a custom ``scale_fn``:
            ``'cycle'`` passes the cycle number to ``scale_fn``; any other
            value passes ``last_epoch``. The built-in modes set their own
            scale mode and ignore this argument.
        cycle_momentum: Whether to cycle momentum inversely to the
            learning rate.
        base_momentum: Lower momentum bound; scalar or one value per group.
        max_momentum: Upper momentum bound; scalar or one value per group.
        last_epoch: Index of the last step. ``-1`` starts a fresh schedule
            and writes ``base_lr`` (and ``base_momentum``) into the
            optimizer's parameter groups.

    Raises:
        TypeError: If ``optimizer`` is not a ``torch.optim.Optimizer``.
        ValueError: If a per-group sequence has the wrong length, the step
            sizes are invalid, ``mode`` is unknown while ``scale_fn`` is
            ``None``, or ``cycle_momentum`` is requested for an optimizer
            whose defaults have no ``'momentum'`` entry.
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1):
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            # Fresh schedule: start every group at its base learning rate.
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        # A zero-length ramp-up makes get_lr() evaluate 0 / 0 at the start
        # of every cycle, so reject it here with a clear message instead of
        # failing later with a bare ZeroDivisionError.
        if step_size_up <= 0:
            raise ValueError(
                'step_size_up must be positive, got {}'.format(step_size_up))
        if step_size_down < 0:
            raise ValueError(
                'step_size_down must not be negative, got {}'.format(step_size_down))
        self.total_size = step_size_up + step_size_down
        # Fraction of a cycle spent increasing the learning rate.
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            # Each built-in mode fixes both the scaling function and what
            # that function is fed (cycle number vs. iteration count).
            builtin_modes = {
                'triangular': (self._triangular_scale_fn, 'cycle'),
                'triangular2': (self._triangular2_scale_fn, 'cycle'),
                'exp_range': (self._exp_range_scale_fn, 'iterations'),
            }
            self.scale_fn, self.scale_mode = builtin_modes[self.mode]
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum
            # Read the values back from the optimizer so that a resumed
            # schedule (last_epoch != -1) keeps whatever momentum the
            # parameter groups currently hold.
            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(LocalCyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return a list holding one lr/momentum value per param group.

        A list or tuple must contain exactly one entry per parameter group
        and is copied, so later mutation of the caller's sequence cannot
        alter the schedule. Any other value is repeated for every group.

        Raises:
            ValueError: If a sequence has the wrong number of entries.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return list(param)
        return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        """Constant amplitude: every cycle reaches ``max_lr``."""
        return 1.

    def _triangular2_scale_fn(self, x):
        """Halve the amplitude with every cycle (``x`` is the cycle number)."""
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Scale the amplitude by ``gamma ** x`` (``x`` is the iteration)."""
        return self.gamma**(x)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.

        Returns:
            A list with one learning rate per parameter group.
        """
        progress = self.last_epoch / self.total_size
        cycle = math.floor(1 + progress)
        # Position inside the current cycle, in [0, 1).
        x = 1. + progress - cycle
        if x <= self.step_ratio:
            # Rising edge: 0 at the start of the cycle, 1 at the peak.
            scale_factor = x / self.step_ratio
        else:
            # Falling edge: 1 at the peak, back to 0 at the end of the cycle.
            scale_factor = (x - 1) / (self.step_ratio - 1)

        # The scaling function is fed either the cycle number or the raw
        # iteration count, depending on the scale mode.
        scale_arg = cycle if self.scale_mode == 'cycle' else self.last_epoch

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            lrs.append(base_lr + base_height * self.scale_fn(scale_arg))

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                # Momentum moves opposite to the learning rate: it sits at
                # max_momentum when the learning rate is at base_lr.
                momentums.append(max_momentum - base_height * self.scale_fn(scale_arg))
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs