From ebea115e9f730b8ee25105d361cda38df8ed261d Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Sat, 11 Jan 2025 11:53:59 +0000
Subject: [PATCH] Add a configurable master weight dtype for the gradient accumulator

Add `master_weight_dtype` to `OptimizerArgs` (default: `torch.float32`) and
thread it through `init_optimizer_and_grad_accumulator` into
`FP32GradientAccumulator`, which now keeps a "master" copy of each parameter
in the configured dtype instead of a hard-coded fp32 copy. Also add
post-scaling mean all-reduce helpers to `fp8/distributed.py`, and temporarily
fall back to plain `torch.optim.AdamW` and `DistributedDataParallel` instead
of the FP8-specific paths (left commented out for now). Usage sketches for
the new option and helpers are appended after the diff.
---
 src/nanotron/config/config.py              |  1 +
 src/nanotron/fp8/distributed.py            | 19 ++++++
 src/nanotron/helpers.py                    | 63 +++++++++++++---------
 src/nanotron/optim/gradient_accumulator.py | 43 ++++++++--------
 src/nanotron/trainer.py                    | 33 +++++++----
 5 files changed, 101 insertions(+), 58 deletions(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index c5bd2d99..6689cf56 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -316,6 +316,7 @@ class OptimizerArgs:
     clip_grad: Optional[float]
     accumulate_grad_in_fp32: bool
     learning_rate_scheduler: LRSchedulerArgs
+    master_weight_dtype: torch.dtype = torch.float32
 
 
 @dataclass
diff --git a/src/nanotron/fp8/distributed.py b/src/nanotron/fp8/distributed.py
index c5aacb98..3ca10ce0 100644
--- a/src/nanotron/fp8/distributed.py
+++ b/src/nanotron/fp8/distributed.py
@@ -9,6 +9,25 @@
 from nanotron.parallel.parameters import NanotronParameter
 
 
+def post_scaling_all_reduce_mean(
+    tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None, async_op=False
+) -> torch.Tensor:
+    # "Post-scaling" mean: all-reduce the raw sum first, then divide by the group size locally.
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
+    tensor.div_(dist.get_world_size(group))
+    return tensor
+
+
+def post_scaling_all_reduce_coalesced_mean(
+    tensors: list[torch.Tensor], group: Optional[dist.ProcessGroup] = None, async_op=False
+) -> list[torch.Tensor]:
+    # Same post-scaling mean, but reducing a list of tensors in a single coalesced call.
+    dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
+    for tensor in tensors:
+        tensor.div_(dist.get_world_size(group))
+    return tensors
+
+
 def all_reduce(
     tensor: Union[torch.Tensor, NanotronParameter],
     op: dist.ReduceOp = dist.ReduceOp.SUM,
diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index c2b1d768..1135ae62 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -20,7 +20,6 @@
 from nanotron import logging
 from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
 from nanotron.distributed import ProcessGroup
-from nanotron.fp8.optim import FP8AdamW
 from nanotron.logging import LogItem, log_rank
 from nanotron.models.base import NanotronModel
 from nanotron.optim.base import BaseOptimizer, Optimizer
@@ -295,6 +294,7 @@ def merge_named_param_groups(
 def init_optimizer_and_grad_accumulator(
     parametrization_method: ParametrizationMethod,
     model: nn.Module,
+    master_weight_dtype: torch.dtype,
     optimizer_args: OptimizerArgs,
     parallel_context: ParallelContext,
 ) -> Tuple[BaseOptimizer, GradientAccumulator]:
@@ -327,34 +327,44 @@ def basic_optimizer_builder(named_param_groups):
 
     optimizer = None
     if optimizer_args.optimizer_factory.name == "adamW":
-        from nanotron import constants
 
         def optimizer(param_groups):
             # if has_fp8_params(param_groups) is False:
-            if constants.CONFIG.model.dtype != torch.int8:
-                return torch.optim.AdamW(
-                    param_groups,
-                    lr=optimizer_args.learning_rate_scheduler.learning_rate,
-                    weight_decay=optimizer_args.weight_decay,
-                    eps=optimizer_args.optimizer_factory.adam_eps,
-                    betas=(
-                        optimizer_args.optimizer_factory.adam_beta1,
-                        optimizer_args.optimizer_factory.adam_beta2,
-                    ),
-                    fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
-                )
-            else:
-                return FP8AdamW(
-                    param_groups,
-                    lr=optimizer_args.learning_rate_scheduler.learning_rate,
-                    weight_decay=optimizer_args.weight_decay,
-                    eps=optimizer_args.optimizer_factory.adam_eps,
-                    betas=(
-                        optimizer_args.optimizer_factory.adam_beta1,
-                        
optimizer_args.optimizer_factory.adam_beta2, - ), - recipe=constants.CONFIG.fp8.optim, - ) + # if constants.CONFIG.model.dtype != torch.int8: + # return torch.optim.AdamW( + # param_groups, + # lr=optimizer_args.learning_rate_scheduler.learning_rate, + # weight_decay=optimizer_args.weight_decay, + # eps=optimizer_args.optimizer_factory.adam_eps, + # betas=( + # optimizer_args.optimizer_factory.adam_beta1, + # optimizer_args.optimizer_factory.adam_beta2, + # ), + # fused=optimizer_args.optimizer_factory.torch_adam_is_fused, + # ) + # else: + # return FP8AdamW( + # param_groups, + # lr=optimizer_args.learning_rate_scheduler.learning_rate, + # weight_decay=optimizer_args.weight_decay, + # eps=optimizer_args.optimizer_factory.adam_eps, + # betas=( + # optimizer_args.optimizer_factory.adam_beta1, + # optimizer_args.optimizer_factory.adam_beta2, + # ), + # recipe=constants.CONFIG.fp8.optim, + # ) + return torch.optim.AdamW( + param_groups, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + weight_decay=optimizer_args.weight_decay, + eps=optimizer_args.optimizer_factory.adam_eps, + betas=( + optimizer_args.optimizer_factory.adam_beta1, + optimizer_args.optimizer_factory.adam_beta2, + ), + fused=optimizer_args.optimizer_factory.torch_adam_is_fused, + ) elif optimizer_args.optimizer_factory.name == "sgd": @@ -384,6 +394,7 @@ def grad_optimizer_builder(named_param_groups): gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator( named_parameters=named_params, grad_buckets_named_params=named_parameters, + master_dtype=master_weight_dtype, ), named_params_or_groups=named_param_groups, optimizer_builder=basic_optimizer_builder, diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index c5dc9990..ee532eb1 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -62,6 +62,7 @@ def __init__( self, named_parameters: Iterator[Tuple[str, NanotronParameter]], grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None, + master_dtype: torch.dtype = torch.float32, ): """Create a gradient accumulator that will accumulate gradients in fp32. 
@@ -84,14 +85,14 @@ def __init__( # Assign big buffer for weights + grad in fp32 segment_index = {} length = 0 - fp32_params = [] + master_params = [] for name, param in named_parameters: # NOTE: FP8 Parameter by default has requires_grad=False, # because we want to do the backward ourself, so here we only skip # if the parameter isn't fp8, and doesn't require grad if self._is_not_required_master_weights(param): - fp32_params.append((name, param)) + master_params.append((name, param)) continue start = length @@ -100,10 +101,10 @@ def __init__( segment_index[name] = (start, end_weight, param) length = end_weight - big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda") + big_flat_buffer = torch.empty(length, dtype=master_dtype, device="cuda") self.parameters = { name: { - "fp32": big_flat_buffer[start_weight:end_weight].view_as(param), + "master": big_flat_buffer[start_weight:end_weight].view_as(param), "half": param, } for name, (start_weight, end_weight, param) in segment_index.items() @@ -111,10 +112,10 @@ def __init__( self.parameters.update( { name: { - "fp32": param, + "master": param, "half": param, } - for name, param in fp32_params + for name, param in master_params } ) @@ -124,15 +125,15 @@ def __init__( with torch.inference_mode(): for _, elt in self.parameters.items(): - fp32_param = elt["fp32"] + master_param = elt["master"] half_param = elt["half"] # Check that fp32 weights have the same memory representation as half precision weights - assert fp32_param.stride() == half_param.stride() + assert master_param.stride() == half_param.stride() # Copy weights from half precision to full precision if not isinstance(half_param.data, FP8Tensor): - fp32_param.copy_(half_param) + master_param.copy_(half_param) else: from nanotron import constants @@ -147,13 +148,13 @@ def find_param_name(param, named_parameters): p_data = constants.CPU_WEIGHTS[p_name] assert p_data.dtype == torch.float32, f"Expected {p_name} to be float32, but got {p_data.dtype}" - fp32_param.copy_(constants.CPU_WEIGHTS[p_name]) + master_param.copy_(constants.CPU_WEIGHTS[p_name]) del constants.CPU_WEIGHTS[p_name] del p_name # Set requires_grad=True - fp32_param.requires_grad = True + master_param.requires_grad = True self._is_accumulation_sync_step = False # We need the last allreduce handle to make sure it finishes before the optimizer step @@ -298,7 +299,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: # In the case an optimizer decides to set it to None, we need to re-assign previous buffer if name in self.parameters: - fp32_param = self.parameters[name]["fp32"] + master_param = self.parameters[name]["master"] if hasattr(self, "param_name_to_offsets"): if name not in self.param_name_to_offsets: # When `name` isn't in `param_name_to_offsets` it means the slice is empty. @@ -307,8 +308,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: grad = fp32_grad.view(-1)[start_offset:end_offset] else: grad = fp32_grad - fp32_param.grad = grad - assert is_overflow_underflow_nan(fp32_param.grad) is False + master_param.grad = grad + assert is_overflow_underflow_nan(master_param.grad) is False @contextmanager def no_sync(self): @@ -331,15 +332,15 @@ def step(self): We need to update only the parameters that were updated by the optimizer. 
""" for name in self.parameters.keys(): - fp32_param = self.parameters[name]["fp32"] + master_param = self.parameters[name]["master"] half_param = self.parameters[name]["half"] # TODO @nouamane: should we use a fused kernel to copy? # Copy weights from full precision to half precision if half_param.data.__class__ == FP8Tensor: - half_param.data.set_data(fp32_param, sync=False) + half_param.data.set_data(master_param, sync=False) else: - half_param.copy_(fp32_param) + half_param.copy_(master_param) def zero_grad(self): # Full precision gradients are reset to zero/none after the underlying `optimiser.step`, so no need to reset. @@ -355,22 +356,22 @@ def zero_grad(self): self._contiguous_fp32_grad_buffer.zero_() def get_parameter_for_optimizer(self, name: str) -> NanotronParameter: - return self.parameters[name]["fp32"] + return self.parameters[name]["master"] def get_grad_buffer(self, name: str) -> torch.Tensor: """Returns the gradient of the parameter from the appropriate grad bucket.""" return self.fp32_grad_buffers[name]["fp32_grad"] def state_dict(self) -> Dict[str, torch.Tensor]: - # We consider `fp32` parameters as a state of the gradient accumulator - return {name: elt["fp32"] for name, elt in self.parameters.items()} + # We consider master parameters as a state of the gradient accumulator + return {name: elt["master"] for name, elt in self.parameters.items()} def load_state_dict(self, state_dict: Dict[str, torch.Tensor]): assert set(state_dict.keys()) == set(self.parameters.keys()) with torch.inference_mode(): for name, elt in self.parameters.items(): - elt["fp32"].copy_(state_dict[name]) + elt["master"].copy_(state_dict[name]) @dataclasses.dataclass @@ -405,6 +406,9 @@ def get_fp32_accum_hook( # s = torch.cuda.Stream() def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]: + import pydevd + + pydevd.settrace(suspend=True, trace_only_current_thread=True) # nonlocal s # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation. # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details. diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 9b20f615..44ef559f 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -36,7 +36,6 @@ ) from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.dataloader import sanity_check_dataloader -from nanotron.fp8.parallel import DistributedDataParallel as FP8DistributedDataParallel from nanotron.fp8.tensor import FP8Tensor from nanotron.fp8.utils import convert_model_to_fp8 from nanotron.helpers import ( @@ -240,9 +239,17 @@ def __init__( self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator( parametrization_method=parametrization_method, model=self.model, + master_weight_dtype=self.config.optimizer.master_weight_dtype, optimizer_args=self.config.optimizer, parallel_context=self.parallel_context, ) + # NOTE: quantize optimizer states + # add hook to dequantize optimizer states before .step() + # add hook step to recompute lr + # add post_step hook to quantize optimizer states + + assert 1 == 1 + if self.init_checkpoint_path is not None: load_optimizer( optimizer=self.optimizer, @@ -883,16 +890,22 @@ def _init_model( # Check that the model has at least one grad. 
Necessary for DDP check_model_has_grad(model=model, parallel_context=parallel_context) # TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway) - if self.config.model.dtype == torch.int8: - raise NotImplementedError - model = FP8DistributedDataParallel(model, self.parallel_context) - else: - model = DistributedDataParallel( - model, - process_group=parallel_context.dp_pg, - broadcast_buffers=False, - bucket_cap_mb=config.model.ddp_bucket_cap_mb, - ) + # if self.config.model.dtype == torch.int8: + # raise NotImplementedError + # model = FP8DistributedDataParallel(model, self.parallel_context) + # else: + # model = DistributedDataParallel( + # model, + # process_group=parallel_context.dp_pg, + # broadcast_buffers=False, + # bucket_cap_mb=config.model.ddp_bucket_cap_mb, + # ) + model = DistributedDataParallel( + model, + process_group=parallel_context.dp_pg, + broadcast_buffers=False, + bucket_cap_mb=config.model.ddp_bucket_cap_mb, + ) # Sanity check the model, all parameters must be NanotronParameter (either tied or sharded) sanity_check(root_module=model)
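
For context, here is a minimal, self-contained sketch (not part of the patch) of what `master_weight_dtype` controls: `FP32GradientAccumulator` carves per-parameter "master" views out of one flat buffer, and the patch makes that buffer's dtype configurable instead of hard-coding `torch.float`. The `build_master_params` helper, the CPU device, and the toy parameters are simplifications for illustration; the real accumulator allocates on CUDA and also manages fp32 gradient buckets.

import torch

def build_master_params(named_params, master_dtype=torch.float32, device="cpu"):
    # Lay every parameter out in one flat buffer, mirroring the segment_index logic in the patch.
    segment_index = {}
    length = 0
    for name, param in named_params:
        start = length
        length += param.numel()
        segment_index[name] = (start, length, param)

    # The patch's key change: the buffer dtype comes from `master_weight_dtype`
    # instead of being hard-coded to torch.float.
    big_flat_buffer = torch.empty(length, dtype=master_dtype, device=device)

    parameters = {}
    for name, (start, end, param) in segment_index.items():
        master_param = big_flat_buffer[start:end].view_as(param)
        master_param.copy_(param)  # initialize the master copy from the half-precision weight
        parameters[name] = {"master": master_param, "half": param}
    return parameters

if __name__ == "__main__":
    halves = [
        ("w", torch.randn(4, 4, dtype=torch.bfloat16)),
        ("b", torch.randn(4, dtype=torch.bfloat16)),
    ]
    params = build_master_params(halves, master_dtype=torch.float32)
    assert params["w"]["master"].dtype == torch.float32
    assert torch.allclose(params["w"]["master"], params["w"]["half"].float())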
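
The new helpers in `src/nanotron/fp8/distributed.py` compute a mean by all-reducing the raw sum and dividing by the process-group size afterwards ("post-scaling"), instead of pre-dividing each rank's tensor before the reduction. Below is a rough standalone sketch of the same idea against plain `torch.distributed`; the gloo backend, the two-rank setup, and the `torchrun` launch line are assumptions for illustration only.

import torch
import torch.distributed as dist

def post_scaling_all_reduce_mean(tensor, group=None):
    # Sum across ranks first, then scale down locally by the group size
    # (equivalent to an average, without scaling the tensor before the reduction).
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group))
    return tensor

if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 post_scaling_mean_sketch.py
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    t = torch.full((4,), float(rank))
    post_scaling_all_reduce_mean(t)
    # With 2 ranks holding 0.0 and 1.0, every rank now holds 0.5.
    print(f"rank {rank}: {t}")
    dist.destroy_process_group()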
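
Finally, a hypothetical end-to-end sketch of the cycle the accumulator implements around `optimizer.step()`: gradients are handed to the master copies, the optimizer only updates the master (full-precision) weights, and the result is copied back into the half-precision weights, as in `FP32GradientAccumulator.step()`. The `training_step` helper, the cast of the gradient to the master dtype, and the single-parameter setup are illustrative stand-ins, not nanotron APIs.

import torch

def training_step(half_params, master_params, optimizer):
    # 1. Hand the (already accumulated) gradient to the master copy, in the master dtype.
    for name, half in half_params.items():
        master_params[name].grad = half.grad.to(master_params[name].dtype)

    # 2. The optimizer only ever sees the master weights.
    optimizer.step()
    optimizer.zero_grad()

    # 3. Mirror FP32GradientAccumulator.step(): copy master weights back into the half copies.
    with torch.no_grad():
        for name, half in half_params.items():
            half.copy_(master_params[name])
            half.grad = None

if __name__ == "__main__":
    half = {"w": torch.nn.Parameter(torch.randn(8, dtype=torch.bfloat16))}
    master = {"w": half["w"].detach().clone().float().requires_grad_(True)}
    optimizer = torch.optim.AdamW(list(master.values()), lr=1e-3)

    half["w"].grad = torch.ones_like(half["w"])  # stand-in for a real backward pass
    training_step(half, master, optimizer)
    print(half["w"][:4])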