From d2a307df39e0c90272a3fd283cef5e2043e32c58 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Mon, 9 Dec 2024 18:10:30 +0000
Subject: [PATCH] add saving sharding's metadata of master weights in grad_accu

---
 src/nanotron/optim/gradient_accumulator.py |  71 +++++++++++---
 src/nanotron/optim/zero.py                 |  31 +++++--
 src/nanotron/serialize/optimizer.py        | 103 +++++++++++++++++++--
 src/nanotron/trainer.py                    |  16 ++++
 4 files changed, 194 insertions(+), 27 deletions(-)

diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py
index 2e940744..ace88a2a 100644
--- a/src/nanotron/optim/gradient_accumulator.py
+++ b/src/nanotron/optim/gradient_accumulator.py
@@ -2,13 +2,14 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Callable, Dict, Iterator, Optional, Tuple
+from typing import Callable, Dict, Iterator, Optional, Tuple, cast
 
 import torch
 from torch.distributed import GradBucket
 
 import nanotron.distributed as dist
 from nanotron import logging
+from nanotron.optim.zero import SlicedFlatTensor, get_sliced_tensor
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
 
@@ -68,7 +69,11 @@ def __init__(
             grad_buckets_named_params: The parameters to accumulate gradients for. If None it defaults to `named_parameters`. In case of Zero 1, this should be all the parameters in the model.
 
         Note: We use `grad_buckets_named_params` to keep grad buffers for all parameters even when Zero 1 is used. This is because we need to accumulate gradients for all parameters without having to reduce in every accumulation step.
-        Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator
+        Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator.
+ + "self.parameters" + - .fp32: the pointer to the full precision weights + - .half: the pointer to the half precision weights """ if grad_buckets_named_params is None: named_parameters = list(named_parameters) @@ -86,20 +91,57 @@ def __init__( if not param.requires_grad: continue - start = length - end_weight = start + param.numel() + global_buffer_start_idx = length + global_buffer_end_idx = global_buffer_start_idx + param.numel() + assert name not in segment_index - segment_index[name] = (start, end_weight, param) - length = end_weight + param = cast(SlicedFlatTensor, param) + segment_index[name] = ( + (global_buffer_start_idx, global_buffer_end_idx), + (param.start_offset, param.end_offset), + param, + ) + length = global_buffer_end_idx big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda") - self.parameters = { - name: { - "fp32": big_flat_buffer[start_weight:end_weight].view_as(param), + + self.parameters = {} + for name, ( + (global_start_idx, global_end_idx), + (dp_weight_start_idx, dp_weight_end_idx), + param, + ) in segment_index.items(): + if name == "model.final_layer_norm.pp_block.weight": + assert 1 == 1 + + fp32_p = big_flat_buffer[global_start_idx:global_end_idx].view_as(param) + sliced_fp32_p = get_sliced_tensor( + fp32_p, + start_offset=dp_weight_start_idx, + end_offset=dp_weight_end_idx, + is_sharded=True, + ) + assert ( + sliced_fp32_p.numel() == param.numel() + ), f"Expected {name} to have the same number of elements, dp_weight_start_idx: {dp_weight_start_idx}, dp_weight_end_idx: {dp_weight_end_idx}, param.numel(): {param.numel()}, sliced_fp32_p.numel(): {sliced_fp32_p.numel()}" + self.parameters[name] = { + "fp32": sliced_fp32_p, "half": param, } - for name, (start_weight, end_weight, param) in segment_index.items() - } + + # self.parameters = { + # name: { + # # "fp32": big_flat_buffer[global_start_idx:global_end_idx].view_as(param), + # # NOTE: save the way we shard stuff in dp for zero-1, so we can reshard it + # "fp32": get_sliced_tensor( + # big_flat_buffer[global_start_idx:global_end_idx].view_as(param), + # start_offset=dp_weight_start_idx, + # end_offset=dp_weight_end_idx, + # ), + # "half": param, + # } + # for name, ((global_start_idx, global_end_idx), (dp_weight_start_idx, dp_weight_end_idx), param) in segment_index.items() + # } with torch.inference_mode(): for _, elt in self.parameters.items(): @@ -108,6 +150,9 @@ def __init__( # Check that fp32 weights have the same memory representation as half precision weights assert fp32_param.stride() == half_param.stride() + assert ( + fp32_param.numel() == half_param.numel() + ), f"There is a size mismatch of {name}, fp32_param: {fp32_param.numel()}, half_param: {half_param.numel()}" # Copy weights from half precision to full precision fp32_param.copy_(half_param) @@ -289,6 +334,10 @@ def state_dict(self) -> Dict[str, torch.Tensor]: def load_state_dict(self, state_dict: Dict[str, torch.Tensor]): assert set(state_dict.keys()) == set(self.parameters.keys()) + # NOTE: double check if the dp size in the checkpoint + # is differ from the current dp size, then we merge the states + # and reshard them again + with torch.inference_mode(): for name, elt in self.parameters.items(): elt["fp32"].copy_(state_dict[name]) diff --git a/src/nanotron/optim/zero.py b/src/nanotron/optim/zero.py index cb61c8b7..70f3c1d7 100644 --- a/src/nanotron/optim/zero.py +++ b/src/nanotron/optim/zero.py @@ -62,6 +62,7 @@ def __init__( # partition model's params across DP ranks. 
diff --git a/src/nanotron/optim/zero.py b/src/nanotron/optim/zero.py
index cb61c8b7..70f3c1d7 100644
--- a/src/nanotron/optim/zero.py
+++ b/src/nanotron/optim/zero.py
@@ -62,6 +62,7 @@ def __init__(
         # partition model's params across DP ranks.
         # `self.param_name_to_dp_rank_offsets` sets mapping between each param inside self.named_params and its rank
         # NOTE: some param_groups may have no params in the current rank. we still keep them in self.optimizer.param_groups
+        # TODO: maybe not shard layernorm params in zero-1, because it is small anyway
         self.param_name_to_dp_rank_offsets = self._partition_parameters()
 
         current_dp_rank = dist.get_rank(self.dp_pg)
@@ -171,6 +172,8 @@ def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
         for name, param in named_params:
             # We assume parameter to be contiguous in order to have an easy way of sharding it.
             assert param.is_contiguous(), f"Parameter {name} is not contiguous"
+            if name == "model.final_layer_norm.pp_block.weight":
+                assert 1 == 1
 
             numel = param.numel()
             padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1
@@ -262,13 +265,18 @@ class SlicedFlatTensor(torch.Tensor):
     __torch_function__ = torch._C._disabled_torch_function_impl
 
     @staticmethod
-    def get_sliced_flat_tensor(data, start_offset, end_offset):
-        with torch.no_grad():
-            return data.view(-1)[start_offset:end_offset]
+    def get_sliced_flat_tensor(data, start_offset: int, end_offset: int, is_sharded: bool):
+        if is_sharded is False:
+            with torch.no_grad():
+                return data.view(-1)[start_offset:end_offset]
+        else:
+            return data
 
     @staticmethod
-    def __new__(cls, data, start_offset, end_offset):
-        sliced_tensor = cls.get_sliced_flat_tensor(data=data, start_offset=start_offset, end_offset=end_offset)
+    def __new__(cls, data, start_offset: int, end_offset: int, is_sharded: bool):
+        sliced_tensor = cls.get_sliced_flat_tensor(
+            data=data, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded
+        )
 
         result = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
             cls,
@@ -283,11 +291,16 @@ def __new__(cls, data, start_offset, end_offset):
         )
         return result
 
-    def __init__(self, data, start_offset, end_offset):
+    def __init__(self, data, start_offset: int, end_offset: int, is_sharded: bool):
+        """
+        is_sharded: whether a tensor is sharded or not
+        Sometimes we already shard a tensor, and just want to wrap it in a `SlicedFlatTensor`
+        so we can save the sharding metadata cleanly.
+ """ super().__init__() # TODO @thomasw21: Make is so that you can never update this value self.sliced_flat_tensor = self.get_sliced_flat_tensor( - data=data, start_offset=start_offset, end_offset=end_offset + data=data, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded ) self.orig_data = data self.start_offset = start_offset @@ -337,9 +350,9 @@ def data_ptr(self): grad = property(_get_grad, _set_grad, _del_grad) -def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int): +def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int, is_sharded: bool = False): # This allows us to create a leaf tensor, despite sharing the underlying storage - result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset) + result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded) return result diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index f9c65bcd..65b7df60 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -22,13 +22,39 @@ from nanotron.serialize.metadata import TensorMetadata from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors - # TODO(xrsrke): take rank instead of parallel_context -def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): +# def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): +# if is_zero is True: +# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" +# else: +# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" + + +def get_optimizer_filename( + tp_topology: Tuple[int, int], + pp_topology: Tuple[int, int], + dp_topology: Optional[Tuple[int, int]] = None, + exp_topology: Optional[Tuple[int, int]] = None, + is_zero: Optional[bool] = None, +): + """ + tp_topology: Tuple[int, int] = (rank, size) + pp_topology: Tuple[int, int] = (rank, size) + dp_topology: Tuple[int, int] = (rank, size) + + NOTE: sometimes we get the checkpoint from a different topology (not the current parallel_context) + """ + assert exp_topology is not None, "exp_topology is required" + assert is_zero is not None, "is_zero is required" + pp_rank, pp_size = pp_topology + tp_rank, tp_size = tp_topology + exp_rank, exp_size = exp_topology + if is_zero is True: - return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" + dp_rank, dp_size = dp_topology + return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_dp-{dp_rank}-of-{dp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt" else: - return 
f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" + return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt" def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool): @@ -102,7 +128,14 @@ def convert_to_string(input_item): torch.save( optimizer.state_dict(), root_folder - / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + # / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + / get_optimizer_filename( + tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()), + pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()), + dp_topology=(dist.get_rank(parallel_context.dp_pg), parallel_context.dp_pg.size()), + exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size), + is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer), + ), ) @@ -330,16 +363,58 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - new_optim_state_dict["names"] = new_optim_state_param_names state_dict = new_optim_state_dict else: + # NOTE: if you resume from training + + def round_robin_map(numbers, min_val, max_val): + """ + Maps a list of numbers to a round-robin pattern within a configurable range. + + Args: + numbers (list): List of numbers to map. + min_val (int): Minimum value in the round-robin range. + max_val (int): Maximum value in the round-robin range. + + Returns: + list: Mapped list of numbers. 
+ """ + range_size = max_val - min_val + 1 + return [(num - 1) % range_size + min_val for num in numbers] + + # if int(ckp_dp_size) != int(parallel_context.dp_pg.size()): + # pass + # else: + # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely + # state_dict = torch.load( + # root_folder + # / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + # map_location=map_location, + # ) + # NOTE: since here we only load the optimizer states, + # then we shard it according to the current data parallel dimension + state_dict = torch.load( root_folder - / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), + / get_optimizer_filename( + tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()), + pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()), + # NOTE(xrsrke): suppose we initially have dp world size of 4, + # then we change to dp world size of 8, then we need to load the optimizer states + # now we do a round-robin mapping of the optimizer states to the new dp world size + # dp=8's ranks: [0, 1, 2, 3, 4, 5, 6, 7] + # maps to: [0, 1, 2, 3, 0, 1, 2, 3] + dp_topology=(int(dist.get_rank(parallel_context.pp_pg)) // int(ckp_dp_size), ckp_dp_size), + exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size), + is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer), + ), map_location=map_location, ) if isinstance(optimizer, ZeroDistributedOptimizer): + + # NOTE: optimizer state topology-agnostic loading # NOTE: only reshard after merging tp shards - # or we get a new dp_Size + # or we get a new dp_size if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size(): # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension current_dp_rank = dist.get_rank(parallel_context.dp_pg) @@ -354,6 +429,20 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - ) state_dict["state"][param_index][state_name] = sliced_tensor + # NOTE: reshard gradient_accumulator if different dp size from checkpoint + if int(ckp_dp_size) != parallel_context.dp_pg.size(): + merged_grad_accumulator = {} + for name, param in state_dict["gradient_accumulator"].items(): + # NOTE: assume that we shard a parameter evenly across all DPs + # TODO: ideally refactor a map between sharding and resharding, so + # we don't have to assume things + # merged_p = torch.zeros(param.numel()*int(ckp_dp_size), device="cuda") + merged_p = [torch.zeros_like(param) for _ in range(int(ckp_dp_size))] + dist.all_gather(merged_p, param.to("cuda"), group=parallel_context.dp_pg) + merged_grad_accumulator[name] = torch.cat(merged_p, dim=-1).to(map_location) + + assert 1 == 1 + optimizer.load_state_dict(state_dict, map_location=map_location) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 94b03c6e..8c105a4b 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -172,6 +172,22 @@ def __init__( parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg ) self.model = self.init_model() # Defines self.model + + # from torch import nn + # def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]: + # """ + # Return all the leaf modules (modules without any child modules) in a PyTorch module. 
+ # """ + # leaf_modules = [] + # for n, m in module.named_modules(): + # if not list(m.children()): + # leaf_modules.append((n, m)) + # return leaf_modules + + # leaf_modules = get_leaf_modules(self.model) + for name, param in self.model.named_parameters(): + print(name, param.shape) + self.unwrapped_model: NanotronModel = ( self.model.module if isinstance(self.model, DistributedDataParallel) else self.model )