diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 080e37dd..76ad6f83 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -2,14 +2,13 @@ from abc import ABC, abstractmethod from collections import OrderedDict from contextlib import contextmanager -from typing import Callable, Dict, Iterator, Optional, Tuple, cast +from typing import Callable, Dict, Iterator, Optional, Tuple import torch from torch.distributed import GradBucket import nanotron.distributed as dist from nanotron import logging -from nanotron.optim.zero import SlicedFlatTensor from nanotron.parallel.parameters import NanotronParameter from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage @@ -91,57 +90,20 @@ def __init__( if not param.requires_grad: continue - global_buffer_start_idx = length - global_buffer_end_idx = global_buffer_start_idx + param.numel() - + start = length + end_weight = start + param.numel() assert name not in segment_index - param = cast(SlicedFlatTensor, param) - segment_index[name] = ( - (global_buffer_start_idx, global_buffer_end_idx), - (param.start_offset, param.end_offset), - param, - ) - length = global_buffer_end_idx + segment_index[name] = (start, end_weight, param) + length = end_weight big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda") - - self.parameters = {} - for name, ( - (global_start_idx, global_end_idx), - (dp_weight_start_idx, dp_weight_end_idx), - param, - ) in segment_index.items(): - # if name == "model.final_layer_norm.pp_block.weight": - # assert 1 == 1 - - fp32_p = big_flat_buffer[global_start_idx:global_end_idx].view_as(param) - # sliced_fp32_p = get_sliced_tensor( - # fp32_p, - # start_offset=dp_weight_start_idx, - # end_offset=dp_weight_end_idx, - # is_sharded=True, - # ) - # assert ( - # sliced_fp32_p.numel() == param.numel() - # ), f"Expected {name} to have the same number of elements, dp_weight_start_idx: {dp_weight_start_idx}, dp_weight_end_idx: {dp_weight_end_idx}, param.numel(): {param.numel()}, sliced_fp32_p.numel(): {sliced_fp32_p.numel()}" - self.parameters[name] = { - "fp32": fp32_p, + self.parameters = { + name: { + "fp32": big_flat_buffer[start_weight:end_weight].view_as(param), "half": param, } - - # self.parameters = { - # name: { - # # "fp32": big_flat_buffer[global_start_idx:global_end_idx].view_as(param), - # # NOTE: save the way we shard stuff in dp for zero-1, so we can reshard it - # "fp32": get_sliced_tensor( - # big_flat_buffer[global_start_idx:global_end_idx].view_as(param), - # start_offset=dp_weight_start_idx, - # end_offset=dp_weight_end_idx, - # ), - # "half": param, - # } - # for name, ((global_start_idx, global_end_idx), (dp_weight_start_idx, dp_weight_end_idx), param) in segment_index.items() - # } + for name, (start_weight, end_weight, param) in segment_index.items() + } with torch.inference_mode(): for _, elt in self.parameters.items(): diff --git a/src/nanotron/optim/inherit_from_other_optimizer.py b/src/nanotron/optim/inherit_from_other_optimizer.py index b4afa555..d67cecdf 100644 --- a/src/nanotron/optim/inherit_from_other_optimizer.py +++ b/src/nanotron/optim/inherit_from_other_optimizer.py @@ -50,28 +50,6 @@ def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, to return self.optimizer.load_state_dict(state_dict, map_location=map_location) def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: - # NOTE: error: RuntimeError: params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout - # NOTE: add assert device, dtype, layout are the same - - params = [p for group in self.optimizer.param_groups for p in group["params"]] - [p.grad for p in params if p.grad is not None] - [state["exp_avg"] for state in self.optimizer.state_dict()["state"].values()] - [state["exp_avg_sq"] for state in self.optimizer.state_dict()["state"].values()] - - # Check if all required attributes have the same device, dtype, and layout - # ref_device = params[0].device - # ref_dtype = params[0].dtype - # ref_layout = params[0].layout - - # for attr_list, name in zip( - # [params, grads, exp_avgs, exp_avg_sqs], - # ["params", "grads", "exp_avgs", "exp_avg_sqs"] - # ): - # for idx, attr in enumerate(attr_list): - # assert attr.device == ref_device, f"{name}[{idx}] has device {attr.device}, expected {ref_device}" - # assert attr.dtype == ref_dtype, f"{name}[{idx}] has dtype {attr.dtype}, expected {ref_dtype}" - # assert attr.layout == ref_layout, f"{name}[{idx}] has layout {attr.layout}, expected {ref_layout}" - return self.optimizer.step(closure=closure) def get_base_optimizer(self): diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 62a79de8..f74517e6 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -23,13 +23,6 @@ from nanotron.serialize.metadata import TensorMetadata from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors -# TODO(xrsrke): take rank instead of parallel_context -# def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): -# if is_zero is True: -# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" -# else: -# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" - def get_optimizer_filename( tp_topology: Tuple[int, int], @@ -129,7 +122,6 @@ def convert_to_string(input_item): torch.save( optimizer.state_dict(), root_folder - # / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), / get_optimizer_filename( tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()), pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()), @@ -385,19 +377,9 @@ def round_robin_map(numbers, min_val, max_val): range_size = max_val - min_val + 1 return [(num - 1) % range_size + min_val for num in numbers] - # if int(ckp_dp_size) != int(parallel_context.dp_pg.size()): - # pass - # else: - - # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely - # state_dict = torch.load( - # root_folder - # / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), - # map_location=map_location, - # ) # NOTE: since here we only load the optimizer states, # then we shard it according to the current data parallel dimension - + # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely state_dict = torch.load( root_folder / get_optimizer_filename( @@ -416,24 +398,6 @@ def round_robin_map(numbers, min_val, max_val): ) if isinstance(optimizer, ZeroDistributedOptimizer): - - # NOTE: optimizer state topology-agnostic loading - # NOTE: only reshard after merging tp shards - # or we get a new dp_size - # if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size(): - # # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension - # current_dp_rank = dist.get_rank(parallel_context.dp_pg) - # OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"] - # for param_index in state_dict["state"]: - # param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0] - # for state_name in OPTIMIZER_STATE_NAMES: - # sliced_tensor = get_sliced_tensor( - # param=state_dict["state"][param_index][state_name], - # start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0], - # end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1], - # ) - # state_dict["state"][param_index][state_name] = sliced_tensor - shard_paths = list( root_folder.glob( f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}_exp-*-of-{ckpt_expert_parallel_size}.pt" @@ -474,8 +438,6 @@ def get_key_by_value(d, target_value): (int(tp_rank), int(dp_rank)) ][p_idx][key] - assert 1 == 1 - # NOTE: now merge optimizer states across data parallel dimension for param_index in state_dict["state"]: param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0] @@ -493,27 +455,6 @@ def get_key_by_value(d, target_value): # NOTE: reshard gradient_accumulator if different dp size from checkpoint if int(ckp_dp_size) != parallel_context.dp_pg.size(): assert int(ckp_tp_size) == parallel_context.tp_pg.size(), "Don't support changing TP size for ZeRO-1" - # merged_grad_accumulator = {} - # for name, param in state_dict["gradient_accumulator"].items(): - # # NOTE: assume that we shard a parameter evenly across all DPs - # # TODO: ideally refactor a map between sharding and resharding, so - # # we don't have to assume things - # # merged_p = torch.zeros(param.numel()*int(ckp_dp_size), device="cuda") - # # merged_p = [torch.zeros_like(param) for _ in range(int(ckp_dp_size))] - # # dist.all_gather(merged_p, param.to("cuda"), group=parallel_context.dp_pg) - # # merged_grad_accumulator[name] = torch.cat(merged_p, dim=-1).to(map_location) - - # merged_p_shape = ckp_optimizer_config["configs"]["orig_param_shapes"][name] - # merged_p_shape = tuple(int(x) for x in merged_p_shape) - # merged_p = torch.zeros(merged_p_shape).view(-1) - # dp_rank = dist.get_rank(parallel_context.dp_pg) - # dp_offset = ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"][name][str(dp_rank)] - # merged_p[int(dp_offset[0]):int(dp_offset[1])] = param.view(-1) - # dist.all_reduce(merged_p, group=parallel_context.dp_pg) - # merged_p = merged_p.view(merged_p_shape) - # merged_grad_accumulator[name] = merged_p.to(map_location) - - assert 1 == 1 ckp_sharded_grad_accum = {} for shard_path in shard_paths: pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True) @@ -546,27 +487,8 @@ def get_key_by_value(d, target_value): int(new_offset[0]) : int(new_offset[1]) ] - # NOTE: reshard the gradient_accumulator - - try: - assert state_dict["state"][0]["exp_avg"].numel() > 0 - except: - assert 1 == 1 - optimizer.load_state_dict(state_dict, map_location=map_location) - try: - assert state_dict["state"][0]["exp_avg"].numel() > 0 - except: - assert 1 == 1 - - try: - assert optimizer.state_dict()["state"][0]["exp_avg"].numel() > 0 - except: - assert 1 == 1 - - assert 1 == 1 - def load_lr_scheduler( lr_scheduler, diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 9075c0c9..3cb0e4c7 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -173,21 +173,6 @@ def __init__( ) self.model = self.init_model() # Defines self.model - # from torch import nn - # def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]: - # """ - # Return all the leaf modules (modules without any child modules) in a PyTorch module. - # """ - # leaf_modules = [] - # for n, m in module.named_modules(): - # if not list(m.children()): - # leaf_modules.append((n, m)) - # return leaf_modules - - # leaf_modules = get_leaf_modules(self.model) - # for name, param in self.model.named_parameters(): - # print(name, param.shape) - self.unwrapped_model: NanotronModel = ( self.model.module if isinstance(self.model, DistributedDataParallel) else self.model ) diff --git a/tests/test_zero.py b/tests/test_zero.py index bdf2c216..f1127f94 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -543,46 +543,3 @@ def _test_sliced_flat_tensor(parallel_context: ParallelContext): assert not isinstance(c, SlicedFlatTensor) parallel_context.destroy() - - -@rerun_if_address_is_in_use() -def test_wrap_slice_tensor_around_a_sharded_tensor(): - init_distributed(1, 1, 1)(_test_wrap_slice_tensor_around_a_sharded_tensor)() - - -def _test_wrap_slice_tensor_around_a_sharded_tensor(parallel_context: ParallelContext): - a = torch.randn(2, 3, requires_grad=True) - grad = torch.randn(2, 3) - a.grad = grad - - start_offset, end_offset = 1, 5 - b = SlicedFlatTensor(a, start_offset=start_offset, end_offset=end_offset) - - torch.testing.assert_close(a.grad, grad, atol=0, rtol=0) - torch.testing.assert_close(b.grad, grad.view(-1)[start_offset:end_offset]) - - # Deallocate the gradient by setting it to None - a.grad = None - - assert a.grad is None - assert b.grad is None - - # Setting gradient to None on the sliced tensor works - a.grad = grad - assert a.grad is not None - assert b.grad is not None - b.grad = None - assert b.grad is None - assert a.grad is None - - with assert_fail_with(NotImplementedError): - b.grad = torch.randn(1, 5) - - with assert_fail_with(NotImplementedError): - del b.grad - - c = b[:3] - # It's important not to contaminate everyone. - assert not isinstance(c, SlicedFlatTensor) - - parallel_context.destroy()