Skip to content

Commit

Permalink
Add period parameter to cosine scheduler (#1413)
Browse files Browse the repository at this point in the history
* Add optional 'period' parameter to cosine_schedule and CosineWarmupScheduler

---------

Co-authored-by: guarin <[email protected]>
  • Loading branch information
JannikWirtz and guarin authored Oct 3, 2023
1 parent 569bf2b commit 5f1cd69
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
32 changes: 29 additions & 3 deletions lightly/utils/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import warnings
from typing import Optional

import numpy as np
import torch


def cosine_schedule(
step: int, max_steps: int, start_value: float, end_value: float
step: int,
max_steps: int,
start_value: float,
end_value: float,
period: Optional[int] = None,
) -> float:
"""Use cosine decay to gradually modify start_value to reach target end_value during
iterations.
Expand All @@ -19,6 +24,9 @@ def cosine_schedule(
Starting value.
end_value:
Target value.
period (optional):
The number of steps over which the cosine function completes a full cycle.
If not provided, it defaults to max_steps.
Returns:
Cosine decay value.
Expand All @@ -28,13 +36,21 @@ def cosine_schedule(
raise ValueError("Current step number can't be negative")
if max_steps < 1:
raise ValueError("Total step number must be >= 1")
if step > max_steps:
if period is None and step > max_steps:
warnings.warn(
f"Current step number {step} exceeds max_steps {max_steps}.",
category=RuntimeWarning,
)
if period is not None and period <= 0:
raise ValueError("Period must be >= 1")

if max_steps == 1:
decay: float
if period is not None: # "cycle" based on period, if provided
decay = (
end_value
- (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
)
elif max_steps == 1:
# Avoid division by zero
decay = end_value
elif step == max_steps:
Expand Down Expand Up @@ -83,12 +99,14 @@ def __init__(
last_epoch: int = -1,
start_value: float = 1.0,
end_value: float = 0.001,
period: Optional[int] = None,
verbose: bool = False,
) -> None:
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.start_value = start_value
self.end_value = end_value
self.period = period
super().__init__(
optimizer=optimizer,
lr_lambda=self.scale_lr,
Expand All @@ -110,6 +128,14 @@ def scale_lr(self, epoch: int) -> float:
"""
if epoch < self.warmup_epochs:
return self.start_value * (epoch + 1) / self.warmup_epochs
elif self.period is not None:
return cosine_schedule(
step=epoch - self.warmup_epochs,
max_steps=1,
start_value=self.start_value,
end_value=self.end_value,
period=self.period,
)
else:
return cosine_schedule(
step=epoch - self.warmup_epochs,
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def test_cosine_schedule(self) -> None:
):
cosine_schedule(11, 10, 0.0, 1.0)

def test_cosine_schedule__period(self) -> None:
self.assertAlmostEqual(cosine_schedule(0, 1, 0, 1.0, period=10), 0.0, 6)
self.assertAlmostEqual(cosine_schedule(3, 1, 0, 2.0, period=10), 1.30901706, 6)
self.assertAlmostEqual(cosine_schedule(10, 1, 0, 1.0, period=10), 0.0, 6)
self.assertAlmostEqual(cosine_schedule(15, 1, 0, 1.0, period=10), 1.0, 6)
with self.assertRaises(ValueError):
cosine_schedule(1, 10, 0.0, 1.0, period=-1)

def test_CosineWarmupScheduler(self) -> None:
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(
Expand Down

0 comments on commit 5f1cd69

Please sign in to comment.