Skip to content

Commit

Permalink
Merge pull request #2124 from fzyzcjy/patch-1
Browse files Browse the repository at this point in the history
Fix super tiny type error
  • Loading branch information
rwightman authored Apr 2, 2024
2 parents 67b0b3d + b44e4e4 commit 59b3d86
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 9 deletions.
3 changes: 2 additions & 1 deletion timm/scheduler/cosine_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math
import numpy as np
import torch
from typing import List

from .scheduler import Scheduler

Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(
else:
self.warmup_steps = [1 for _ in self.base_values]

def _get_lr(self, t):
def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
Expand Down
2 changes: 1 addition & 1 deletion timm/scheduler/multistep_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_curr_decay_steps(self, t):
# assumes self.decay_t is sorted
return bisect.bisect_right(self.decay_t, t + 1)

def _get_lr(self, t):
def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
Expand Down
3 changes: 2 additions & 1 deletion timm/scheduler/plateau_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from typing import List

from .scheduler import Scheduler

Expand Down Expand Up @@ -106,5 +107,5 @@ def _apply_noise(self, epoch):
param_group['lr'] = new_lr
self.restore_lr = restore_lr

def _get_lr(self, t: int) -> float:
def _get_lr(self, t: int) -> List[float]:
assert False, 'should not be called as step is overridden'
3 changes: 2 additions & 1 deletion timm/scheduler/poly_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import math
import logging
from typing import List

import torch

Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
else:
self.warmup_steps = [1 for _ in self.base_values]

def _get_lr(self, t):
def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
Expand Down
6 changes: 3 additions & 3 deletions timm/scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from abc import ABC
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import torch

Expand Down Expand Up @@ -65,10 +65,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)

@abc.abstractmethod
def _get_lr(self, t: int) -> float:
def _get_lr(self, t: int) -> List[float]:
pass

def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
if not proceed:
return None
Expand Down
4 changes: 3 additions & 1 deletion timm/scheduler/step_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
import math
import torch
from typing import List


from .scheduler import Scheduler

Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(
else:
self.warmup_steps = [1 for _ in self.base_values]

def _get_lr(self, t):
def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
Expand Down
3 changes: 2 additions & 1 deletion timm/scheduler/tanh_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math
import numpy as np
import torch
from typing import List

from .scheduler import Scheduler

Expand Down Expand Up @@ -75,7 +76,7 @@ def __init__(
else:
self.warmup_steps = [1 for _ in self.base_values]

def _get_lr(self, t):
def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
Expand Down

0 comments on commit 59b3d86

Please sign in to comment.