Add first Step in LR Schedulers #6597

Merged · 9 commits · Oct 14, 2024
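In short, the diff below makes each affected scheduler write its initial learning rate into the optimizer's param groups at construction time (when `last_batch_iteration == -1`), rather than waiting for the first `step()` call, using a new shared `update_lr` helper. The following is a minimal sketch of that pattern with a hypothetical `ToyWarmupLR` class; it is not DeepSpeed's API, only an illustration of the behavior this PR adds.

```python
import torch


def update_lr(param_groups, lrs):
    # Write each scheduled lr into its optimizer param group and return
    # the lrs that are now active in the optimizer.
    for param_group, lr in zip(param_groups, lrs):
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]


class ToyWarmupLR:
    """Illustrative only: linear warmup from base_lr / warmup_num_steps up to base_lr."""

    def __init__(self, optimizer, warmup_num_steps=10, last_batch_iteration=-1):
        self.optimizer = optimizer
        self.warmup_num_steps = warmup_num_steps
        self.last_batch_iteration = last_batch_iteration
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        # The pattern this PR adds: on a fresh scheduler (-1), push the initial
        # lr into the optimizer immediately instead of waiting for step().
        if last_batch_iteration == -1:
            self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

    def get_lr(self):
        step = max(self.last_batch_iteration, 0)
        scale = min(1.0, (step + 1) / self.warmup_num_steps)
        return [base_lr * scale for base_lr in self.base_lrs]

    def get_last_lr(self):
        return self._last_lr

    def step(self, batch_iteration=None):
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())


model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = ToyWarmupLR(opt, warmup_num_steps=10)
# Before any step() call, the optimizer already runs at the scheduler's first lr.
assert opt.param_groups[0]['lr'] == sched.get_lr()[0]
```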
38 changes: 19 additions & 19 deletions deepspeed/runtime/lr_schedules.py
@@ -247,6 +247,12 @@ def get_lr_from_config(config):
return lr_params[WARMUP_MAX_LR], ''


def update_lr(param_groups, lrs):
for param_group, lr in zip(param_groups, lrs):
param_group['lr'] = lr
return [group['lr'] for group in param_groups]


"""
Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
optimizer to see if requirement is satisfied.
@@ -328,7 +334,7 @@ def __init__(self,
self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval

if last_batch_iteration == -1:
self._update_optimizer(self.min_lr)
self._last_lr = update_lr(self.optimizer.param_groups, self.min_lr)

def _staircase_interval(self):
return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
@@ -349,16 +355,11 @@ def get_last_lr(self):
assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
return self._last_lr

def _update_optimizer(self, group_lrs):
for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
param_group['lr'] = lr

def step(self, batch_iteration=None):
if batch_iteration is None:
batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = batch_iteration
self._update_optimizer(self.get_lr())
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}
@@ -615,9 +616,7 @@ def step(self, batch_iteration=None):
batch_iteration = self.last_batch_iteration + 1

self.last_batch_iteration = batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

if self.cycle_momentum:
momentums = self.get_mom()
@@ -675,11 +674,14 @@ def __init__(self,
self.warmup_type = warmup_type
self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
self.last_batch_iteration = last_batch_iteration
# Initialize lr in optimizer
Contributor review comment: These changes are duplicated. Can we extract the code as a function and reuse it? (The diff extracts this into the shared `update_lr` helper; see the usage sketch after this file's diff.)

if last_batch_iteration == -1:
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

def get_lr(self):
if self.last_batch_iteration < 0:
logger.warning("Attempting to get learning rate from scheduler before it has started")
return [0.0]
return self.min_lrs
gamma = self._get_gamma()
return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]

@@ -693,9 +695,7 @@ def step(self, last_batch_iteration=None):
if last_batch_iteration is None:
last_batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = last_batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}
@@ -819,6 +819,10 @@ def __init__(self,
total_num_steps, warmup_num_steps))
self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]

# Initialize lrs in optimizer groups
if last_batch_iteration == -1:
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

def get_lr_ratio(self):
if self.last_batch_iteration < 0:
logger.warning("Attempting to get learning rate from scheduler before it has started")
@@ -844,11 +848,7 @@ def step(self, last_batch_iteration=None):
if last_batch_iteration is None:
last_batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = last_batch_iteration

lrs = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, lrs):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())

def get_lr(self):
if self.last_batch_iteration < 0:
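The reviewer's deduplication request above is resolved by the module-level `update_lr` helper, which each scheduler's `step()` and the `last_batch_iteration == -1` initialization now call. Below is a quick standalone check of that helper's behavior with multiple param groups; the helper is copied from the diff, while the surrounding setup is illustrative only.

```python
import torch


def update_lr(param_groups, lrs):
    # Same helper as in the diff: write the lrs into the groups, return what is now set.
    for param_group, lr in zip(param_groups, lrs):
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]


# Two param groups with different base lrs.
net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
opt = torch.optim.SGD([
    {'params': net[0].parameters(), 'lr': 0.1},
    {'params': net[1].parameters(), 'lr': 0.01},
])

# The return value is what the schedulers cache as _last_lr and report via get_last_lr().
new_lrs = update_lr(opt.param_groups, [0.05, 0.005])
assert new_lrs == [0.05, 0.005]
assert [g['lr'] for g in opt.param_groups] == [0.05, 0.005]
```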
8 changes: 8 additions & 0 deletions tests/unit/runtime/test_lr_schedulers.py
@@ -37,6 +37,9 @@ def _verify_staircase_increase(values, step_size):
(WARMUP_DECAY_LR, {
WARMUP_NUM_STEPS: 10,
TOTAL_NUM_STEPS: 20
}), (WARMUP_COSINE_LR, {
WARMUP_NUM_STEPS: 10,
TOTAL_NUM_STEPS: 20
}), (ONE_CYCLE, {
CYCLE_MIN_LR: 0,
CYCLE_MAX_LR: 0.1
@@ -71,6 +74,11 @@ def test(self, scheduler_type, params):
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)

true_lrs = lr_scheduler.get_lr()
for group, true_lr in zip(model.optimizer.param_groups, true_lrs):
assert group['lr'] == true_lr, f"True lr {true_lr}, optimizer lr {group['lr']}"

for n, batch in enumerate(data_loader):
# get lr before training starts
lr_scheduler.get_lr()
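The test parametrization now also covers WARMUP_COSINE_LR, and the new assertion checks that the optimizer's param-group lrs already equal `lr_scheduler.get_lr()` before the training loop starts. For reference, a config fragment of the kind the parametrized test would exercise for the new case; the key names mirror the constants used in the test (WARMUP_NUM_STEPS, TOTAL_NUM_STEPS), while the surrounding structure is an assumption, not an excerpt from the test harness.

```python
# Illustrative DeepSpeed-style config for the newly covered WarmupCosineLR case.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "warmup_num_steps": 10,
            "total_num_steps": 20,
        },
    },
}
```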