From 15c74cf71733854b69d1429de27a62a8a828f000 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 1 Oct 2024 23:31:22 +0000
Subject: [PATCH 1/5] adding lr init in schedulers

---
 deepspeed/runtime/lr_schedules.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index d7f7e15a4dbd..c610b706ee13 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -675,6 +675,11 @@ def __init__(self,
         self.warmup_type = warmup_type
         self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
         self.last_batch_iteration = last_batch_iteration
+        # Initialize lr in optimizer
+        if last_batch_iteration == -1:
+            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+                param_group['lr'] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
     def get_lr(self):
         if self.last_batch_iteration < 0:
@@ -818,6 +823,12 @@ def __init__(self,
             logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                 total_num_steps, warmup_num_steps))
         self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
+
+        # Initialize lrs in optimizer groups
+        if last_batch_iteration == -1:
+            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+                param_group['lr'] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
     def get_lr_ratio(self):
         if self.last_batch_iteration < 0:

From 92bb0ff8f51b89ed3db811eed2ab14dff8c1ca5b Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 2 Oct 2024 21:23:35 +0000
Subject: [PATCH 2/5] adding step init check to unittest

---
 deepspeed/runtime/lr_schedules.py        |  2 +-
 tests/unit/runtime/test_lr_schedulers.py | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index c610b706ee13..8aa23619a6a6 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -684,7 +684,7 @@ def __init__(self,
     def get_lr(self):
         if self.last_batch_iteration < 0:
             logger.warning("Attempting to get learning rate from scheduler before it has started")
-            return [0.0]
+            return self.min_lrs
         gamma = self._get_gamma()
         return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)]
 
diff --git a/tests/unit/runtime/test_lr_schedulers.py b/tests/unit/runtime/test_lr_schedulers.py
index bcfc485f2b8f..ea77b86e9b35 100644
--- a/tests/unit/runtime/test_lr_schedulers.py
+++ b/tests/unit/runtime/test_lr_schedulers.py
@@ -37,7 +37,12 @@ def _verify_staircase_increase(values, step_size):
                              (WARMUP_DECAY_LR, {
                                  WARMUP_NUM_STEPS: 10,
                                  TOTAL_NUM_STEPS: 20
-                             }), (ONE_CYCLE, {
+                             }),
+                             (WARMUP_COSINE_LR, {
+                                 WARMUP_NUM_STEPS: 10,
+                                 TOTAL_NUM_STEPS: 20
+                             }),
+                             (ONE_CYCLE, {
                                  CYCLE_MIN_LR: 0,
                                  CYCLE_MAX_LR: 0.1
                              }), (LR_RANGE_TEST, {})])
@@ -71,6 +76,11 @@ def test(self, scheduler_type, params):
                                         hidden_dim=hidden_dim,
                                         device=model.device,
                                         dtype=torch.float)
+
+        true_lrs = lr_scheduler.get_lr()
+        for group, true_lr in zip(model.optimizer.param_groups, true_lrs):
+            assert group['lr'] == true_lr, f"True lr {true_lr}, optimizer lr {group['lr']}"
+
         for n, batch in enumerate(data_loader):
             # get lr before training starts
             lr_scheduler.get_lr()

From 180ea0cee8a6bbb0ae89313fc441341fdeaea766 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 2 Oct 2024 21:24:50 +0000
Subject: [PATCH 3/5] run pre-commit for formatting

---
 deepspeed/runtime/lr_schedules.py        | 2 +-
 tests/unit/runtime/test_lr_schedulers.py | 6 ++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 8aa23619a6a6..8ea1bca23b9d 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -823,7 +823,7 @@ def __init__(self,
             logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                 total_num_steps, warmup_num_steps))
         self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]
-
+
         # Initialize lrs in optimizer groups
         if last_batch_iteration == -1:
             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
diff --git a/tests/unit/runtime/test_lr_schedulers.py b/tests/unit/runtime/test_lr_schedulers.py
index ea77b86e9b35..47734c0cd864 100644
--- a/tests/unit/runtime/test_lr_schedulers.py
+++ b/tests/unit/runtime/test_lr_schedulers.py
@@ -37,12 +37,10 @@ def _verify_staircase_increase(values, step_size):
                              (WARMUP_DECAY_LR, {
                                  WARMUP_NUM_STEPS: 10,
                                  TOTAL_NUM_STEPS: 20
-                             }),
-                             (WARMUP_COSINE_LR, {
+                             }), (WARMUP_COSINE_LR, {
                                  WARMUP_NUM_STEPS: 10,
                                  TOTAL_NUM_STEPS: 20
-                             }),
-                             (ONE_CYCLE, {
+                             }), (ONE_CYCLE, {
                                  CYCLE_MIN_LR: 0,
                                  CYCLE_MAX_LR: 0.1
                              }), (LR_RANGE_TEST, {})])

From fef345a08a25d23f46ead9a4cdded10cedcf1cbc Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 8 Oct 2024 18:04:17 +0000
Subject: [PATCH 4/5] functions for replicated code

---
 deepspeed/runtime/lr_schedules.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 8ea1bca23b9d..0c86a011d1c4 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -677,9 +677,7 @@ def __init__(self,
         self.last_batch_iteration = last_batch_iteration
         # Initialize lr in optimizer
         if last_batch_iteration == -1:
-            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-                param_group['lr'] = lr
-            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+            self.update_lr()
 
     def get_lr(self):
         if self.last_batch_iteration < 0:
@@ -694,13 +692,17 @@ def get_last_lr(self):
         assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
         return self._last_lr
 
+    def update_lr(self):
+        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            param_group['lr'] = lr
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        return
+
     def step(self, last_batch_iteration=None):
         if last_batch_iteration is None:
             last_batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = last_batch_iteration
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self.update_lr()
 
     def state_dict(self):
         return {'last_batch_iteration': self.last_batch_iteration}
@@ -826,9 +828,7 @@ def __init__(self,
 
         # Initialize lrs in optimizer groups
         if last_batch_iteration == -1:
-            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-                param_group['lr'] = lr
-            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+            self.update_lr()
 
     def get_lr_ratio(self):
         if self.last_batch_iteration < 0:
@@ -851,15 +851,17 @@ def get_lr_ratio(self):
         ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
         return ratio
 
+    def update_lr(self):
+        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            param_group['lr'] = lr
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        return
+
     def step(self, last_batch_iteration=None):
         if last_batch_iteration is None:
             last_batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = last_batch_iteration
-
-        lrs = self.get_lr()
-        for param_group, lr in zip(self.optimizer.param_groups, lrs):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self.update_lr()
 
     def get_lr(self):
         if self.last_batch_iteration < 0:

From 137541f2427522c4683effa9831cacd9cdd61e4e Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Mon, 14 Oct 2024 17:58:56 +0000
Subject: [PATCH 5/5] consolidating function

---
 deepspeed/runtime/lr_schedules.py | 39 +++++++++++--------------------
 1 file changed, 13 insertions(+), 26 deletions(-)

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 0c86a011d1c4..f25a19e8e499 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -247,6 +247,12 @@ def get_lr_from_config(config):
     return lr_params[WARMUP_MAX_LR], ''
 
 
+def update_lr(param_groups, lrs):
+    for param_group, lr in zip(param_groups, lrs):
+        param_group['lr'] = lr
+    return [group['lr'] for group in param_groups]
+
+
 """
 Only optimizers that are subclass of torch.optim.Optimizer are supported.
 So check the passed optimizer and wrapped optimizer to see if requirement is satisfied.
@@ -328,7 +334,7 @@ def __init__(self,
         self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval
 
         if last_batch_iteration == -1:
-            self._update_optimizer(self.min_lr)
+            self._last_lr = update_lr(self.optimizer.param_groups, self.min_lr)
 
     def _staircase_interval(self):
         return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
@@ -349,16 +355,11 @@ def get_last_lr(self):
         assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
         return self._last_lr
 
-    def _update_optimizer(self, group_lrs):
-        for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
-            param_group['lr'] = lr
-
     def step(self, batch_iteration=None):
         if batch_iteration is None:
             batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = batch_iteration
-        self._update_optimizer(self.get_lr())
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
     def state_dict(self):
         return {'last_batch_iteration': self.last_batch_iteration}
@@ -615,9 +616,7 @@ def step(self, batch_iteration=None):
             batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = batch_iteration
 
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
         if self.cycle_momentum:
             momentums = self.get_mom()
@@ -677,7 +676,7 @@ def __init__(self,
         self.last_batch_iteration = last_batch_iteration
         # Initialize lr in optimizer
         if last_batch_iteration == -1:
-            self.update_lr()
+            self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
     def get_lr(self):
         if self.last_batch_iteration < 0:
@@ -692,17 +691,11 @@ def get_last_lr(self):
         assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
         return self._last_lr
 
-    def update_lr(self):
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-        return
-
     def step(self, last_batch_iteration=None):
         if last_batch_iteration is None:
             last_batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = last_batch_iteration
-        self.update_lr()
+        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
     def state_dict(self):
         return {'last_batch_iteration': self.last_batch_iteration}
@@ -828,7 +821,7 @@ def __init__(self,
 
         # Initialize lrs in optimizer groups
        if last_batch_iteration == -1:
-            self.update_lr()
+            self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
     def get_lr_ratio(self):
         if self.last_batch_iteration < 0:
@@ -851,17 +844,11 @@ def get_lr_ratio(self):
         ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
         return ratio
 
-    def update_lr(self):
-        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
-            param_group['lr'] = lr
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
-        return
-
     def step(self, last_batch_iteration=None):
         if last_batch_iteration is None:
             last_batch_iteration = self.last_batch_iteration + 1
         self.last_batch_iteration = last_batch_iteration
-        self.update_lr()
+        self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr())
 
     def get_lr(self):
         if self.last_batch_iteration < 0:
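
For reference, below is a minimal sketch of what the consolidated update_lr helper from PATCH 5/5 does, exercised outside DeepSpeed. The helper body is copied from the diff above; the torch.nn.Linear/SGD setup and the 0.0 starting learning rate are illustrative assumptions, not part of these patches.

import torch


def update_lr(param_groups, lrs):
    # Same helper introduced in PATCH 5/5: write each lr into its param group
    # and return the list of lrs that were actually set.
    for param_group, lr in zip(param_groups, lrs):
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]


# Illustrative setup (not from the patches): one linear layer, SGD with lr=0.1.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# What the new __init__ hook does when last_batch_iteration == -1: seed every
# param group with the scheduler's starting lr (assumed here to be a 0.0
# warmup floor) instead of leaving the construction-time lr untouched until
# the first step() call.
starting_lrs = [0.0 for _ in optimizer.param_groups]
last_lr = update_lr(optimizer.param_groups, starting_lrs)
assert all(group['lr'] == 0.0 for group in optimizer.param_groups)
print(last_lr)  # [0.0]

Because the helper returns the list it just applied, each scheduler can assign the result directly to _last_lr, so get_last_lr() already has a value when the scheduler is constructed with last_batch_iteration == -1 instead of asserting that step() must be called first.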