From a0fc09df36c39563a523ef31eb80608e33dd3b38 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Sun, 24 Nov 2024 21:12:53 -0800
Subject: [PATCH] add cycle_length (#825)

---
 docs/Configuration-Guide.md    | 60 ++++++++++++++++++-------------
 src/levanter/optim/config.py   | 66 ++++++++++++++++++++++++----------
 tests/test_optimizer_config.py | 63 +++++++++++++++++++++++++++++++-
 3 files changed, 146 insertions(+), 43 deletions(-)

diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md
index 0b00c0800..0df811d79 100644
--- a/docs/Configuration-Guide.md
+++ b/docs/Configuration-Guide.md
@@ -295,21 +295,21 @@ All optimizers in Levanter are based on the [levanter.optim.OptimizerConfig][] d
 which are common to all optimizers (and most have to do with learning rate scheduling):
 
-| Parameter       | Description                                                            | Default  |
-|-----------------|------------------------------------------------------------------------|----------|
-| `weight_decay`  | The weight decay.                                                      | `0.0`    |
-| `learning_rate` | The learning rate.                                                     | `1e-4`   |
-| `lr_schedule`   | The type of learning rate schedule for decay. See below.               | `cosine` |
-| `min_lr_ratio`  | The minimum learning rate ratio.                                       | `0.1`    |
-| `warmup`        | Warmup fraction or number of steps                                     | `0.01`   |
-| `decay`         | Decay fraction or number of steps                                      | `None`   |
-| `cycles`        | The number of cycles for the learning rate, or steps where cycles end  | `None`   |
-| `rewarmup`      | The learning rate re-warmup, if using cycles.                          | `0.0`    |
-
-By default, Levanter uses a cosine learning rate schedule with a warmup. The learning rate is decayed to
+| Parameter       | Description                                                                      | Default  |
+|-----------------|----------------------------------------------------------------------------------|----------|
+| `weight_decay`  | The weight decay.                                                                | `0.0`    |
+| `learning_rate` | The learning rate.                                                               | `1e-4`   |
+| `lr_schedule`   | The type of learning rate schedule for decay. See below.                         | `cosine` |
+| `min_lr_ratio`  | The minimum learning rate ratio.                                                 | `0.1`    |
+| `warmup`        | Warmup fraction or number of steps                                               | `0.01`   |
+| `decay`         | Decay fraction or number of steps                                                | `None`   |
+| `rewarmup`      | The learning rate re-warmup, if using cycles.                                    | `0.0`    |
+| `cycles`        | The number of cycles for the learning rate, or steps where cycles end            | `None`   |
+| `cycle_length`  | How long each cycle should be: steps, a fraction of total steps, or a list       | `None`   |
+
+By default, Levanter uses a cosine learning rate decay with warmup. The learning rate is decayed to
 `min_lr_ratio * learning_rate` over the course of the training run. This is a fairly standard default for LLM training.
 
-
 #### Learning Rate Schedules
 
 The `lr_schedule` parameter specifies the learning rate schedule. The following schedules are supported:
@@ -328,8 +328,11 @@ By default, there is only one cycle, and Levanter's LR schedule looks like this:
 ```
 [warmup] -> [stable] -> [decay]
 ```
 
-But you can specify more with the `cycles` parameter. If you specify an int for `cycles`, the
-learning rate will cycle through the schedule `cycles` times. Levanter's LR schedule looks like this:
+But you can specify more with either the `cycles` or `cycle_length` parameter.
+Specify either the number of cycles (`cycles`) or the length of each cycle (`cycle_length`); the LR will be
+decayed to `min_lr_ratio * learning_rate` at the end of each cycle.
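+
+For example, assuming the usual `optimizer` section of a Levanter training config (the values below are only
+illustrative), a run with `num_train_steps: 40000` and `cycle_length: 10000` trains with four 10,000-step cycles:
+
+```yaml
+optimizer:
+  learning_rate: 4e-4
+  weight_decay: 0.1
+  lr_schedule: cosine
+  warmup: 0.01
+  min_lr_ratio: 0.1
+  cycle_length: 10000
+```
+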
+With cycles, Levanter's LR schedule looks like this:
+
 ```
 [warmup] -> [stable] -> [decay] -> {[rewarmup] -> [stable] -> [decay]} x (cycles - 1)
 ```
@@ -348,27 +351,37 @@ Here's what the phases mean:
 * `decay`: The decay period. The LR will decay to `min_lr_ratio * learning_rate` over this period.
 * `rewarmup`: The re-warmup period. If using cycles, the LR will be re-warmed from the final value of the previous cycle back to the peak value of the next cycle.
 
+Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump
+back to the max LR. This is the default, and works surprisingly well. In addition, the stable
+and decay phases of the first cycle will generally be different from those of the other cycles,
+since rewarmup and warmup are typically different.
+
+`stable` cannot be specified directly. It is the period between `warmup` and `decay` in the first cycle, and the period
+between `rewarmup` and `decay` in subsequent cycles. By default, there is no `stable` period.
+
 All of these parameters can be specified in terms of a fraction of the total number of steps of a cycle or as an absolute number of steps.
 
-If you want to use a learning rate schedule with cycles, you can specify the number of cycles with the `cycles`
-parameter. The LR will be decayed to `min_lr_ratio * learning_rate` at the end of each cycle.
+Here is what the `cycles` and `cycle_length` parameters mean:
+
+* `cycle_length`: If you specify an int or float for `cycle_length`, the learning rate will cycle through the
+schedule with the specified length. This is equivalent to specifying `cycles` as `num_train_steps / cycle_length`.
+If `cycle_length` is a float < 1.0, it is interpreted as a fraction of the total number of steps.
+If you specify a list of ints, the learning rate will cycle through the schedule with the specified cycle lengths.
+* `cycles`: If you specify an int for `cycles`, the learning rate will cycle through the schedule `cycles` times.
+If you specify a list of ints, the learning rate will cycle through the schedule with the specified steps as the minima
+of the cycles.
+
+It is an error to specify both `cycles` and `cycle_length`.
 
 You can also specify `cycles` as a list, e.g. `[10000, 25000, 50000]`. In this case, `cycles` is interpreted as the
 minima for the cycles, with the first and final steps being cycle minima as well. `cycles` as an int is equivalent to
 list `cycles` with the low points evenly spaced at `[num_train_steps / (c + 1)]`.
 
-Also note that if *rewarmup* is 0, there will be no rewarmup period, meaning the LR will jump
-back to the max LR. This is the default. In addition, the stable
-and decay phase of the first cycle will generally be different from the stable and decay phase of the other cycles,
-since rewarmup and warmup are typically different.
-
 See [our paper on WSD-S](https://arxiv.org/pdf/2410.05192) for more information on cyclic LR schedules for training
 LLMs with short or no rewarmup.
 
-
 ### AdamConfig
 
 Additionally, [levanter.optim.AdamConfig][] has the following fields:
@@ -381,7 +394,6 @@ Additionally, [levanter.optim.AdamConfig][] has the following fields:
 | `max_grad_norm` | The maximum gradient norm (for clipping). | `1.0` |
 
-
 ## LM Model Config
 
 [levanter.models.lm_model.LmConfig][] is a Draccus "choice class" that acts as a base class for all autoregressive
diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py
index 7b684efeb..40be5576d 100644
--- a/src/levanter/optim/config.py
+++ b/src/levanter/optim/config.py
@@ -8,6 +8,7 @@
 import draccus
 import equinox as eqx
 import jax
+import numpy as np
 import optax
 from jax import numpy as jnp
@@ -24,16 +25,18 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
     min_lr_ratio: float = 0.1
     """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
-    warmup: float = 0.01
+    warmup: int | float = 0.01
     """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
-    decay: Optional[float] = None
+    decay: int | float | None = None
     """fraction of training steps to use as decay, or steps to use. None means full decay"""
-    rewarmup: float = 0.0
+    rewarmup: int | float = 0.0
     "If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup."
     cooldown: Optional[float] = None
     """Deprecated, as its semantics are confusing."""
-    cycles: int | None | list[int] = None
-    """ Number of cycles to use. If None or 1, use a single cycle. Overriden by haps."""
+    cycle_length: int | float | None | list[int] = None
+    """ Length of cycle. If <= 1, it is treated as a fraction of the total number of steps. None is equivalent to 1.0."""
+    cycles: int | list[int] | None = None
+    """Number of cycles or a list of cycle endpoints. Can use at most one of cycle_length, cycles, or haps."""
 
     lr_schedule: str = "cosine"  # constant, cosine, linear
     haps: Optional[list[int]] = None
@@ -145,16 +148,13 @@ def mask_fn(model):
 
     def lr_scheduler(self, num_train_steps):
         if self.cooldown is not None:
             warnings.warn("cooldown is deprecated. Just use the normal schedule.", DeprecationWarning)
-            cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
+            cooldown_steps = _convert_frac_or_steps(self.cooldown, num_train_steps)
         else:
             cooldown_steps = 0
 
         total_main_steps = num_train_steps - cooldown_steps
         cooldown_points = self._get_cycle_minima(total_main_steps)
-        cooldown_points.insert(0, 0)
-        cooldown_points.append(num_train_steps)
-
         min_lr = self.learning_rate * self.min_lr_ratio
 
         schedules = []
@@ -165,9 +165,9 @@ def lr_scheduler(self, num_train_steps):
         for cycle, (start, end) in enumerate(zip(cooldown_points[:-1], cooldown_points[1:])):
             cycle_steps = end - start
             if cycle == 0:  # warmup
-                warmup_steps = _convert_ratio_or_steps(self.warmup, cycle_steps)
+                warmup_steps = _convert_frac_or_steps(self.warmup, cycle_steps)
             else:
-                warmup_steps = _convert_ratio_or_steps(self.rewarmup, cycle_steps)
+                warmup_steps = _convert_frac_or_steps(self.rewarmup, cycle_steps)
 
             if warmup_steps != 0:
                 warmup = optax.linear_schedule(previous_end, self.learning_rate, warmup_steps)
                 boundaries.append(start + warmup_steps)
 
             lr_decay_steps = (
-                _convert_ratio_or_steps(self.decay, cycle_steps)
+                _convert_frac_or_steps(self.decay, cycle_steps)
                 if self.decay is not None
                 else cycle_steps - warmup_steps
             )
@@ -218,7 +218,31 @@ def lr_scheduler(self, num_train_steps):
 
         return schedule
 
     def _get_cycle_minima(self, total_main_steps):
-        if self.haps is not None:
+        if self.cycle_length is not None:
+            if self.cycles is not None:
+                raise ValueError("Can't use both cycle_length and cycles.")
+            if self.haps is not None:
+                warnings.warn("haps is deprecated. Use cycles instead.", DeprecationWarning)
+                raise ValueError("Can't use both cycle_length and haps.")
+
+            if isinstance(self.cycle_length, int | float):
+                cycle_length = _convert_frac_or_steps(self.cycle_length, total_main_steps)
+                cooldown_points = [i * cycle_length for i in range(1, total_main_steps // cycle_length)]
+                if total_main_steps % cycle_length != 0:
+                    warnings.warn(
+                        "Cycle length does not divide total number of steps. The last cycle will be shorter."
+                    )
+
+            elif isinstance(self.cycle_length, list):
+                lengths = np.array(self.cycle_length)
+                steps = np.cumsum(lengths)
+                if steps[-1] > total_main_steps:
+                    raise ValueError(f"Cycle lengths exceed total number of steps: {steps[-1]} > {total_main_steps}")
+                cooldown_points = steps.tolist()
+            else:
+                raise ValueError("Invalid cycle_length. Must be a fraction, number of steps, or a list of steps.")
+
+        elif self.haps is not None:
             warnings.warn("haps is deprecated. Use cycles instead.", DeprecationWarning)
             cooldown_points = list(self.haps)
         elif isinstance(self.cycles, int):
@@ -228,6 +252,9 @@ def _get_cycle_minima(self, total_main_steps):
             cooldown_points = list(self.cycles)
         else:
             cooldown_points = []
+
+        cooldown_points.insert(0, 0)
+        cooldown_points.append(total_main_steps)
 
         return cooldown_points
@@ -247,11 +274,14 @@ def schedule(count):
 
         return schedule
 
 
-def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int):
-    if ratio_or_steps < 1.0:
-        return int(ratio_or_steps * num_train_steps)
-    else:
-        return int(ratio_or_steps)
+def _convert_frac_or_steps(frac_or_steps: float | int, num_train_steps: int):
+    # if it's greater than 1, it must be a whole number of steps
+    if frac_or_steps < 0.0 or (frac_or_steps > 1.0 and frac_or_steps % 1 != 0):
+        raise ValueError(f"Invalid fraction {frac_or_steps}. Must be between 0 and 1. You can also use (whole) steps.")
+    if frac_or_steps <= 1.0:
+        return int(frac_or_steps * num_train_steps)
+
+    return int(frac_or_steps)
 
 
 @dataclass
diff --git a/tests/test_optimizer_config.py b/tests/test_optimizer_config.py
index 70737df7c..8301152cd 100644
--- a/tests/test_optimizer_config.py
+++ b/tests/test_optimizer_config.py
@@ -122,7 +122,7 @@ def test_wsds_schedule():
     assert np.isclose(sched_fn(659), 1e-3)
     assert sched_fn(661) < 1e-3
 
-    # Thrid cycle
+    # Third cycle
     assert np.isclose(sched_fn(701), 1e-3)
     assert np.isclose(sched_fn(969), 1e-3)
     assert sched_fn(971) < 1e-3
@@ -182,3 +182,64 @@ def test_rewarmup_schedule():
     # Final decay phase
     assert sched_fn(999 - 1) > sched_fn(999)
     assert np.isclose(sched_fn(999), 0.2e-2, atol=1e-4)  # End of second decay
+
+
+def test_linear_schedule_with_cycle_length():
+    optimizer = AdamConfig(
+        learning_rate=5e-4,
+        weight_decay=0.0,
+        warmup=50,
+        min_lr_ratio=0.2,
+        lr_schedule="linear",
+        cycle_length=500,
+    )
+
+    sched_fn = optimizer.lr_scheduler(1000)
+
+    # Warmup phase
+    assert np.isclose(sched_fn(0), 0.0)
+    assert np.isclose(sched_fn(50), 5e-4)
+
+    num_main_steps = 1000
+
+    # First cycle decay
+    assert np.isclose(sched_fn(499), 0.2 * 5e-4, atol=1e-5)
+
+    # Second cycle starts
+    assert np.isclose(sched_fn(500), 5e-4)
+
+    # midway through second cycle
+    midpoint = 500 - 1 + num_main_steps // 4
+    assert np.isclose(sched_fn(midpoint), (5e-4 + 0.2 * 5e-4) / 2, atol=1e-5)
+
+    # Final value
+    assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5)
+
+
+def test_wsds_schedule_with_cycle_points():
+    optimizer = AdamConfig(
+        learning_rate=1e-3,
+        weight_decay=0.0,
+        warmup=0.0,
+        decay=0.1,
+        min_lr_ratio=0.1,
+        lr_schedule="cosine",
+        cycle_length=[300, 400],
+    )
+
+    sched_fn = optimizer.lr_scheduler(1000)
+
+    # First cycle
+    assert np.isclose(sched_fn(0), 1e-3)
+    assert np.isclose(sched_fn(269), 1e-3)
+    assert sched_fn(271) < 1e-3
+
+    # Second cycle
+    assert np.isclose(sched_fn(300), 1e-3)
+    assert np.isclose(sched_fn(659), 1e-3)
+    assert sched_fn(661) < 1e-3
+
+    # Third cycle
+    assert np.isclose(sched_fn(701), 1e-3)
+    assert np.isclose(sched_fn(969), 1e-3)
+    assert sched_fn(971) < 1e-3
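+
+
+def test_cycle_length_and_cycles_are_mutually_exclusive():
+    # Illustrative extra check (a sketch): _get_cycle_minima rejects configs that set both
+    # cycle_length and cycles, so building the schedule should raise.
+    # Local import to keep the sketch self-contained; pytest is the runner for this module.
+    import pytest
+
+    optimizer = AdamConfig(
+        learning_rate=1e-3,
+        weight_decay=0.0,
+        cycle_length=500,
+        cycles=2,
+    )
+
+    with pytest.raises(ValueError):
+        optimizer.lr_scheduler(1000)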