Commit: upgrade schedulers

fschlatt committed Jul 20, 2024
1 parent 3a5f9e7 commit 6e97248
Showing 4 changed files with 221 additions and 90 deletions.
12 changes: 6 additions & 6 deletions lightning_ir/lightning_utils/__init__.py
@@ -6,18 +6,18 @@
WarmupLRScheduler,
)
from .schedulers import (
ConstantSchedulerWithLinearWarmup,
ConstantSchedulerWithQuadraticWarmup,
LinearSchedulerWithLinearWarmup,
GenericConstantSchedulerWithLinearWarmup,
GenericConstantSchedulerWithQuadraticWarmup,
GenericLinearSchedulerWithLinearWarmup,
)

__all__ = [
"ConstantLRSchedulerWithLinearWarmup",
"ConstantSchedulerWithQuadraticWarmup",
"ConstantSchedulerWithLinearWarmup",
"GenericConstantSchedulerWithQuadraticWarmup",
"GenericConstantSchedulerWithLinearWarmup",
"IndexCallback",
"LinearLRSchedulerWithLinearWarmup",
"LinearSchedulerWithLinearWarmup",
"GenericLinearSchedulerWithLinearWarmup",
"LR_SCHEDULERS",
"RankCallback",
"WarmupLRScheduler",
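With this change, the package re-exports the renamed Generic* schedulers alongside the LR schedulers. A short import sketch, assuming the export surface matches the __all__ shown above:

from lightning_ir.lightning_utils import (
    GenericConstantSchedulerWithLinearWarmup,
    GenericConstantSchedulerWithQuadraticWarmup,
    GenericLinearSchedulerWithLinearWarmup,
)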
44 changes: 19 additions & 25 deletions lightning_ir/lightning_utils/lr_schedulers.py
@@ -1,42 +1,36 @@
from abc import ABC, abstractmethod

import torch

from .schedulers import ConstantSchedulerWithLinearWarmup, LambdaWarmupScheduler, LinearSchedulerWithLinearWarmup


class WarmupLRScheduler(torch.optim.lr_scheduler.LambdaLR, ABC):
class WarmupLRScheduler(LambdaWarmupScheduler, torch.optim.lr_scheduler.LambdaLR):
def __init__(
self,
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
num_training_steps: int,
*args,
verbose: bool = False,
**kwargs,
) -> None:
last_epoch = -1
self.interval = "step"
self.num_warmup_steps = num_warmup_steps
self.num_training_steps = num_training_steps
super().__init__(optimizer, self.lr_lambda, last_epoch, verbose)

@abstractmethod
def lr_lambda(self, current_step: int) -> float:
...


class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler):
def lr_lambda(self, current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / self.num_warmup_steps
return max(
0.0,
(self.num_training_steps - current_step) / (self.num_training_steps - self.num_warmup_steps),
super().__init__(
*args,
optimizer=optimizer,
lr_lambda=self.value_lambda,
num_warmup_steps=num_warmup_steps,
last_epoch=last_epoch,
verbose=verbose,
**kwargs,
)


class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler):
def lr_lambda(self, current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / self.num_warmup_steps
return 1.0
class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler, LinearSchedulerWithLinearWarmup):
pass


class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler, ConstantSchedulerWithLinearWarmup):
pass


LR_SCHEDULERS = [LinearLRSchedulerWithLinearWarmup, ConstantLRSchedulerWithLinearWarmup]
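The reworked LR schedulers now mix the shared warmup logic from schedulers.py into torch's LambdaLR. A minimal usage sketch, assuming the combined class takes the optimizer plus num_warmup_steps and num_training_steps as the diff suggests (the model and step counts are placeholders):

import torch

from lightning_ir.lightning_utils.lr_schedulers import LinearLRSchedulerWithLinearWarmup

model = torch.nn.Linear(16, 1)  # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# warm up linearly for 100 steps, then decay linearly toward 0 over the remaining steps
scheduler = LinearLRSchedulerWithLinearWarmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for _ in range(1000):
    optimizer.step()
    scheduler.step()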
122 changes: 82 additions & 40 deletions lightning_ir/lightning_utils/schedulers.py
@@ -6,28 +6,81 @@
from ..base import LightningIRModule


class LambdaWarmupScheduler(Callback, ABC):
class LambdaWarmupScheduler(ABC):
def __init__(
self,
keys: Sequence[str],
num_warmup_steps: int,
num_delay_steps: int = 0,
num_training_steps: int = -1,
*args,
**kwargs,
) -> None:
super().__init__()
self.keys = keys
self.num_warmup_steps = num_warmup_steps
self.num_delay_steps = num_delay_steps
self.num_training_steps = num_training_steps
self.values: Dict[str, float] = {}
super().__init__(*args, **kwargs)

@abstractmethod
def value_lambda(self, current_step: int) -> float: ...

def check_delay(self, current_step: int) -> bool:
return current_step < self.num_delay_steps

def check_warmup(self, current_step: int) -> bool:
return current_step < self.num_warmup_steps + self.num_delay_steps


class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):

def __init__(
self,
num_warmup_steps: int,
num_training_steps: int,
final_value: float = 0.0,
num_delay_steps: int = 0,
*args,
**kwargs,
) -> None:
self.num_training_steps = num_training_steps
self.final_value = final_value
super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

def value_lambda(self, current_step: int) -> float:
...
if self.check_delay(current_step):
return 0.0
if self.check_warmup(current_step):
return (current_step - self.num_delay_steps) / self.num_warmup_steps
current_step = current_step - self.num_delay_steps - self.num_warmup_steps
remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
step_size = (1 - self.final_value) / remaining_steps
return max(self.final_value, 1 - step_size * current_step)

def step(self, key: str, current_step: int) -> float:
value = self.values[key]
return value * self.value_lambda(current_step)

class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
def value_lambda(self, current_step: int) -> float:
if self.check_delay(current_step):
return 0.0
if self.check_warmup(current_step):
return (current_step - self.num_delay_steps) / self.num_warmup_steps
return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
def value_lambda(self, current_step: int) -> float:
if self.check_delay(current_step):
return 0.0
if self.check_warmup(current_step):
return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
return 1.0


class GenericScheduler(Callback, ABC):

def __init__(self, keys: Sequence[str], *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.keys = keys
self.values: Dict[str, float] = {}

@abstractmethod
def step(self, key: str, current_step: int) -> float: ...

def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
for sub_key in sub_keys:
@@ -42,8 +95,6 @@ def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
setattr(obj, sub_keys[-1], value)

def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
if self.num_training_steps == -1:
self.num_training_steps = trainer.estimated_stepping_batches
for key in self.keys:
sub_keys = key.split(".")
self.values[key] = float(self.get_value(sub_keys, pl_module))
@@ -55,35 +106,26 @@ def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, bat
sub_keys = key.split(".")
self.set_value(sub_keys, pl_module, value)

def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
for key in self.keys:
value = self.values[key]
sub_keys = key.split(".")
self.set_value(sub_keys, pl_module, value)

class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):
def value_lambda(self, current_step: int) -> float:
if current_step <= self.num_delay_steps:
return 0.0
current_step -= self.num_delay_steps
if current_step <= self.num_warmup_steps:
return current_step / self.num_warmup_steps
return max(
0.0,
(self.num_training_steps - current_step) / (self.num_training_steps - self.num_warmup_steps),
)

class GenericLinearSchedulerWithLinearWarmup(GenericScheduler, LinearSchedulerWithLinearWarmup):
def step(self, key: str, current_step: int) -> float:
value = self.values[key]
return value * self.value_lambda(current_step)

class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
def value_lambda(self, current_step: int) -> float:
if current_step <= self.num_delay_steps:
return 0.0
current_step -= self.num_delay_steps
if current_step <= self.num_warmup_steps:
return current_step / self.num_warmup_steps
return 1.0

class GenericConstantSchedulerWithLinearWarmup(GenericScheduler, ConstantSchedulerWithLinearWarmup):
def step(self, key: str, current_step: int) -> float:
value = self.values[key]
return value * self.value_lambda(current_step)

class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
def value_lambda(self, current_step: int) -> float:
if current_step <= self.num_delay_steps:
return 0.0
current_step -= self.num_delay_steps
if current_step <= self.num_warmup_steps:
return (current_step / self.num_warmup_steps) ** 2
return 1.0

class GenericConstantSchedulerWithQuadraticWarmup(GenericScheduler, ConstantSchedulerWithQuadraticWarmup):
def step(self, key: str, current_step: int) -> float:
value = self.values[key]
return value * self.value_lambda(current_step)
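The Generic* variants are Lightning callbacks that apply the same warmup curves to arbitrary module attributes addressed by dotted key paths: the initial value is captured in on_train_start and rescaled every batch via value_lambda. A hedged sketch of how such a callback might be wired up; the key path and attribute name below are hypothetical, not taken from this diff:

from lightning_ir.lightning_utils.schedulers import GenericLinearSchedulerWithLinearWarmup

# "loss_function.temperature" is a hypothetical dotted attribute path on the LightningModule
callback = GenericLinearSchedulerWithLinearWarmup(
    keys=["loss_function.temperature"],
    num_warmup_steps=100,
    num_training_steps=1000,
)
# trainer = Trainer(callbacks=[callback]); trainer.fit(module, datamodule)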
