diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py
deleted file mode 100644
index b3a7799d..00000000
--- a/lion_pytorch/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from lion_pytorch.lion_pytorch import Lion
diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py
deleted file mode 100644
index 50d60518..00000000
--- a/lion_pytorch/foreach.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from __future__ import annotations
-from typing import Tuple, Callable
-
-import torch
-from torch.optim.optimizer import Optimizer
-
-# functions
-
-def exists(val):
-    return val is not None
-
-# class
-
-class Lion(Optimizer):
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0,
-        decoupled_weight_decay: bool = False
-    ):
-        assert lr > 0.
-        assert all([0. <= beta <= 1. for beta in betas])
-        assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions'
-
-        self._init_lr = lr
-        self.decoupled_wd = decoupled_weight_decay
-
-        defaults = dict(
-            lr = lr,
-            betas = betas,
-            weight_decay = weight_decay
-        )
-
-        super().__init__(params, defaults)
-
-    @torch.no_grad()
-    def step(
-        self,
-        closure: Callable | None = None
-    ):
-
-        loss = None
-        if exists(closure):
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-
-            lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr
-
-            # maybe decoupled weight decay
-
-            if decoupled_wd:
-                wd /= init_lr
-
-            # accumulate List[Tensor] for foreach inplace updates
-
-            params = []
-            grads = []
-            exp_avgs = []
-
-            for p in filter(lambda p: exists(p.grad), group['params']):
-
-                grad, state = p.grad, self.state[p]
-
-                # init state - exponential moving average of gradient values
-
-                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
-
-                exp_avg = state['exp_avg']
-
-                params.append(p)
-                grads.append(grad)
-                exp_avgs.append(exp_avg)
-
-            # stepweight decay
-
-            torch._foreach_mul_(params, 1. - lr * wd)
-
-            # weight update
-
-            updates = [t.clone() for t in exp_avgs]
-            torch._foreach_lerp_(updates, grads, 1. - beta1)
-            torch._foreach_sign_(updates)
-
-            torch._foreach_add_(params, updates, alpha = -lr)
-
-            # decay momentum running average
-
-            torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)
-
-        return loss
diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
deleted file mode 100644
index b0d3a3f8..00000000
--- a/lion_pytorch/lion_pytorch.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from __future__ import annotations
-from typing import Tuple, Callable, Union
-
-import torch
-from torch.optim.optimizer import Optimizer
-
-
-def exists(val):
-    return val is not None
-
-
-def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
-    # stepweight decay
-    p.data.mul_(1. - lr * wd)
-
-    # weight update
-    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_()
-    p.add_(update, alpha=-lr)
-
-    # decay the momentum running average coefficient
-    exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)
-
-
-class Lion(Optimizer):
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0,
-        use_triton: bool = False,
-        decoupled_weight_decay: bool = False,
-    ):
-        assert lr > 0.
-        assert all([0. <= beta <= 1. for beta in betas])
-
-        self._init_lr = lr
-        self.decoupled_wd = decoupled_weight_decay
-
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            weight_decay=weight_decay
-        )
-
-        super().__init__(params, defaults)
-        self.update_fn = update_fn
-
-        if use_triton:
-            from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = triton_update_fn
-
-    @torch.no_grad()
-    def step(
-        self,
-        closure: Union[Callable, None] = None
-    ):
-
-        loss = None
-        if exists(closure):
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            for p in filter(lambda p: exists(p.grad), group['params']):
-
-                # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
-                grad = p.grad
-                lr = group['lr']
-                wd = group['weight_decay']
-                beta1, beta2 = group['betas']
-                state= self.state[p]
-                decoupled_wd = self.decoupled_wd
-                init_lr = self._init_lr
-
-                # maybe decoupled weight decay
-
-                if decoupled_wd:
-                    wd /= init_lr
-
-                # init state - exponential moving average of gradient values
-                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
-
-                exp_avg = state['exp_avg']
-
-                self.update_fn(
-                    p,
-                    grad,
-                    exp_avg,
-                    lr,
-                    wd,
-                    beta1,
-                    beta2
-                )
-
-        return loss
diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py
deleted file mode 100644
index 1dd4696b..00000000
--- a/lion_pytorch/triton.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import torch
-
-try:
-    import triton
-    import triton.language as tl
-except ImportError as e:
-    print('triton is not installed, please install by running `pip install triton>=2.2.0`')
-    exit()
-
-# triton cuda kernel
-
-@triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
-], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
-@triton.jit
-def update_fn_kernel(
-    p_ptr,
-    grad_ptr,
-    exp_avg_ptr,
-    lr,
-    wd,
-    beta1,
-    beta2,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(axis = 0)
-
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
-    mask = offsets < n_elements
-
-    # offsetted pointers
-
-    offset_p_ptr = p_ptr + offsets
-    offset_grad_ptr = grad_ptr + offsets
-    offset_exp_avg_ptr = exp_avg_ptr + offsets
-
-    # load
-
-    p = tl.load(offset_p_ptr, mask = mask)
-    grad = tl.load(offset_grad_ptr, mask = mask)
-    exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)
-
-    # stepweight decay
-
-    p = p * (1 - lr * wd)
-
-    # diff between momentum running average and grad
-
-    diff = exp_avg - grad
-
-    # weight update
-
-    update = diff * beta1 + grad
-
-    # torch.sign
-
-    can_update = update != 0
-    update_sign = tl.where(update > 0, -lr, lr)
-
-    p = p + update_sign * can_update
-
-    # decay the momentum running average coefficient
-
-    exp_avg = diff * beta2 + grad
-
-    # store new params and momentum running average coefficient
-
-    tl.store(offset_p_ptr, p, mask = mask)
-    tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
-
-def update_fn(
-    p: torch.Tensor,
-    grad: torch.Tensor,
-    exp_avg: torch.Tensor,
-    lr: float,
-    wd: float,
-    beta1: float,
-    beta2: float
-):
-    assert all([t.is_cuda for t in (p, grad, exp_avg)])
-    n_elements = p.numel()
-
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-
-    update_fn_kernel[grid](
-        p,
-        grad,
-        exp_avg,
-        lr,
-        wd,
-        beta1,
-        beta2,
-        n_elements
-    )
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 548c352b..a82f0294 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -44,8 +44,6 @@
 from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
 from nanotron.serialize.metadata import TrainingMetadata
 
-from lion_pytorch import Lion
-
 logger = logging.get_logger(__name__)
 
 
@@ -329,7 +327,9 @@ def init_optimizer_and_grad_accumulator(
     # Basic optimizer builder
     def basic_optimizer_builder(named_param_groups):
         optimizer = None
+
         if optimizer_args.optimizer_factory.name == "adamW":
+
            def optimizer(param_groups):
                return torch.optim.AdamW(
                    param_groups,
@@ -339,20 +339,16 @@ def optimizer(param_groups):
                    betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
                    fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
                )
+
        elif optimizer_args.optimizer_factory.name == "sgd":
+
            def optimizer(param_groups):
                return torch.optim.SGD(
                    param_groups,
                    lr=optimizer_args.learning_rate_scheduler.learning_rate,
                    weight_decay=optimizer_args.weight_decay,
                )
-        elif optimizer_args.optimizer_factory.name == "lion":
-            def optimizer(param_groups):
-                return Lion(
-                    param_groups,
-                    lr=optimizer_args.learning_rate_scheduler.learning_rate,
-                    weight_decay=optimizer_args.weight_decay,
-                )
+
        else:
            raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")
 
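Note: for readers of this patch, the rule implemented by the deleted update_fn in lion_pytorch/lion_pytorch.py is: apply stepweight decay, step in the direction of sign(beta1 * exp_avg + (1 - beta1) * grad) scaled by lr, then decay exp_avg toward grad with beta2. Below is a minimal, self-contained sketch of that rule in plain PyTorch, mirroring the deleted code; the lion_update name and the dummy tensors are illustrative only, not part of this patch.

import torch

def lion_update(p, grad, exp_avg, lr=1e-4, wd=0.0, beta1=0.9, beta2=0.99):
    # stepweight decay, as in the deleted update_fn
    p.data.mul_(1. - lr * wd)
    # weight update: sign of the interpolation between momentum and gradient
    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1. - beta1).sign_()
    p.add_(update, alpha=-lr)
    # decay the momentum running average toward the current gradient
    exp_avg.mul_(beta2).add_(grad, alpha=1. - beta2)

# illustrative usage with dummy tensors
p = torch.randn(4)
grad = torch.randn(4)
exp_avg = torch.zeros_like(p)
lion_update(p, grad, exp_avg)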