
Commit

add stable stage in scheduler
blahBlahhhJ committed May 9, 2024
1 parent 2516d06 commit 127ad4e
Showing 1 changed file with 11 additions and 2 deletions.
src/levanter/optim/config.py (11 additions, 2 deletions)
@@ -20,8 +20,11 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
 
     min_lr_ratio: float = 0.1
     warmup_ratio: Optional[float] = None  # Deprecated. fraction of training steps to use as warmup
+    """The lr scheduler operates on 4 stages: [warmup] - [stable] - [decay] - [cooldown]"""
     warmup: float = 0.01
     """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
+    stable: float = 0.00
+    """fraction of training steps to use as stable, or steps to use. 0.0 means no stable"""
     cooldown: float = 0.0
     """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
     lr_schedule: str = "cosine"  # constant, cosine, linear
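The docstrings above use a fraction-or-steps convention: a value below 1.0 is read as a fraction of the total training steps, a larger value as an absolute step count. The `_convert_ratio_or_steps` helper that applies this is not shown in the diff; a minimal sketch of the assumed convention (not the file's actual implementation) looks like:

def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int) -> int:
    # Assumed convention: values below 1.0 are a fraction of num_train_steps,
    # values of 1.0 or more are an absolute number of steps.
    if ratio_or_steps < 1.0:
        return int(ratio_or_steps * num_train_steps)
    return int(ratio_or_steps)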
@@ -61,8 +64,9 @@ def mask_fn(model):
 
     def lr_scheduler(self, num_train_steps):
         warmup_steps = self._convert_warmup(num_train_steps)
+        stable_steps = _convert_ratio_or_steps(self.stable, num_train_steps)
         cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
-        lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps
+        lr_decay_steps = num_train_steps - warmup_steps - stable_steps - cooldown_steps
         min_lr = self.learning_rate * self.min_lr_ratio
 
         match self.lr_schedule:
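To make the new split concrete, here is a small worked example with hypothetical settings (warmup=0.01, stable=0.1, cooldown=0.05 are illustrative values, not defaults):

num_train_steps = 10_000
warmup_steps = int(0.01 * num_train_steps)    # 100
stable_steps = int(0.10 * num_train_steps)    # 1_000
cooldown_steps = int(0.05 * num_train_steps)  # 500
lr_decay_steps = num_train_steps - warmup_steps - stable_steps - cooldown_steps  # 8_400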
@@ -71,7 +75,7 @@ def lr_scheduler(self, num_train_steps):
             case "cosine":
                 schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio)
             case "linear":
-                schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps)
+                schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps)
             case "inv_sqrt":
                 schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000)
             case _:
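The change to the "linear" case fixes the decay length: lr_decay_steps already excludes warmup (and now stable and cooldown), so subtracting warmup_steps again made the linear schedule hit min_lr before the decay stage ends. With the hypothetical numbers above:

old_transition_steps = lr_decay_steps - warmup_steps  # 8_300: min_lr reached 100 steps before the cooldown boundary
new_transition_steps = lr_decay_steps                 # 8_400: min_lr reached exactly at the cooldown boundary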
Expand All @@ -85,6 +89,11 @@ def lr_scheduler(self, num_train_steps):
schedules.append(warmup)
boundaries.append(warmup_steps)

if stable_steps != 0:
stable = optax.constant_schedule(self.learning_rate)
schedules.append(stable)
boundaries.append(warmup_steps + stable_steps)

schedules.append(schedule)

if cooldown_steps != 0:
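For readers unfamiliar with how the pieces above are combined: optax.join_schedules takes a list of schedules plus cumulative step boundaries at which to switch from one to the next. A minimal standalone sketch of the resulting warmup-stable-decay-cooldown schedule follows, using the hypothetical numbers from earlier; the cooldown shape is an assumption, since that hunk is cut off here.

import optax

learning_rate, min_lr = 1e-3, 1e-4
warmup_steps, stable_steps, lr_decay_steps, cooldown_steps = 100, 1_000, 8_400, 500

schedules = [
    optax.linear_schedule(0.0, learning_rate, warmup_steps),           # warmup: 0 -> lr
    optax.constant_schedule(learning_rate),                            # stable: hold lr
    optax.cosine_decay_schedule(learning_rate, lr_decay_steps, 0.1),   # decay: alpha = min_lr_ratio
    optax.linear_schedule(min_lr, 0.0, cooldown_steps),                # cooldown (assumed shape)
]
boundaries = [
    warmup_steps,
    warmup_steps + stable_steps,
    warmup_steps + stable_steps + lr_decay_steps,
]
schedule = optax.join_schedules(schedules, boundaries)
print(schedule(0), schedule(warmup_steps + 500), schedule(9_999))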
