add saving of sharding metadata for master weights in grad_accu
xrsrke committed Dec 9, 2024
1 parent 2cde8f6 commit d2a307d
Showing 4 changed files with 194 additions and 27 deletions.
71 changes: 60 additions & 11 deletions src/nanotron/optim/gradient_accumulator.py
@@ -2,13 +2,14 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional, Tuple
from typing import Callable, Dict, Iterator, Optional, Tuple, cast

import torch
from torch.distributed import GradBucket

import nanotron.distributed as dist
from nanotron import logging
from nanotron.optim.zero import SlicedFlatTensor, get_sliced_tensor
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage

@@ -68,7 +69,11 @@ def __init__(
grad_buckets_named_params: The parameters to accumulate gradients for. If None it defaults to `named_parameters`. In case of Zero 1, this should be all the parameters in the model.
Note: We use `grad_buckets_named_params` to keep grad buffers for all parameters even when Zero 1 is used. This is because we need to accumulate gradients for all parameters without having to reduce in every accumulation step.
Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator
Note: We make a fp32 copy of parameters during initialization. Therefore parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator.
"self.parameters"
- .fp32: the pointer to the full precision weights
- .half: the pointer to the half precision weights
"""
if grad_buckets_named_params is None:
named_parameters = list(named_parameters)
@@ -86,20 +91,57 @@ def __init__(
if not param.requires_grad:
continue

start = length
end_weight = start + param.numel()
global_buffer_start_idx = length
global_buffer_end_idx = global_buffer_start_idx + param.numel()

assert name not in segment_index
segment_index[name] = (start, end_weight, param)
length = end_weight
param = cast(SlicedFlatTensor, param)
segment_index[name] = (
(global_buffer_start_idx, global_buffer_end_idx),
(param.start_offset, param.end_offset),
param,
)
length = global_buffer_end_idx

big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda")
self.parameters = {
name: {
"fp32": big_flat_buffer[start_weight:end_weight].view_as(param),

self.parameters = {}
for name, (
(global_start_idx, global_end_idx),
(dp_weight_start_idx, dp_weight_end_idx),
param,
) in segment_index.items():
if name == "model.final_layer_norm.pp_block.weight":
assert 1 == 1

fp32_p = big_flat_buffer[global_start_idx:global_end_idx].view_as(param)
sliced_fp32_p = get_sliced_tensor(
fp32_p,
start_offset=dp_weight_start_idx,
end_offset=dp_weight_end_idx,
is_sharded=True,
)
assert (
sliced_fp32_p.numel() == param.numel()
), f"Expected {name} to have the same number of elements, dp_weight_start_idx: {dp_weight_start_idx}, dp_weight_end_idx: {dp_weight_end_idx}, param.numel(): {param.numel()}, sliced_fp32_p.numel(): {sliced_fp32_p.numel()}"
self.parameters[name] = {
"fp32": sliced_fp32_p,
"half": param,
}
for name, (start_weight, end_weight, param) in segment_index.items()
}

# self.parameters = {
# name: {
# # "fp32": big_flat_buffer[global_start_idx:global_end_idx].view_as(param),
# # NOTE: save the way we shard stuff in dp for zero-1, so we can reshard it
# "fp32": get_sliced_tensor(
# big_flat_buffer[global_start_idx:global_end_idx].view_as(param),
# start_offset=dp_weight_start_idx,
# end_offset=dp_weight_end_idx,
# ),
# "half": param,
# }
# for name, ((global_start_idx, global_end_idx), (dp_weight_start_idx, dp_weight_end_idx), param) in segment_index.items()
# }

with torch.inference_mode():
for name, elt in self.parameters.items():
@@ -108,6 +150,9 @@ def __init__(

# Check that fp32 weights have the same memory representation as half precision weights
assert fp32_param.stride() == half_param.stride()
assert (
fp32_param.numel() == half_param.numel()
), f"There is a size mismatch of {name}, fp32_param: {fp32_param.numel()}, half_param: {half_param.numel()}"

# Copy weights from half precision to full precision
fp32_param.copy_(half_param)
@@ -289,6 +334,10 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
assert set(state_dict.keys()) == set(self.parameters.keys())

# NOTE: double-check: if the dp size in the checkpoint
# differs from the current dp size, then we merge the states
# and reshard them again

with torch.inference_mode():
for name, elt in self.parameters.items():
elt["fp32"].copy_(state_dict[name])
31 changes: 22 additions & 9 deletions src/nanotron/optim/zero.py
@@ -62,6 +62,7 @@ def __init__(
# partition model's params across DP ranks.
# `self.param_name_to_dp_rank_offsets` sets mapping between each param inside self.named_params and its rank
# NOTE: some param_groups may have no params in the current rank. we still keep them in self.optimizer.param_groups
# TODO: maybe don't shard layernorm params in zero-1, since they are small anyway
self.param_name_to_dp_rank_offsets = self._partition_parameters()

current_dp_rank = dist.get_rank(self.dp_pg)
@@ -171,6 +172,8 @@ def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
for name, param in named_params:
# We assume parameter to be contiguous in order to have an easy way of sharding it.
assert param.is_contiguous(), f"Parameter {name} is not contiguous"
if name == "model.final_layer_norm.pp_block.weight":
assert 1 == 1

numel = param.numel()
padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1
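
A rough sketch of the ceil-division partitioning used above (illustrative function name, not nanotron's API; the real `_partition_parameters` records comparable offsets per parameter name and per dp rank):

def dp_rank_offsets(numel, dp_size):
    # Each DP rank owns a contiguous [start, end) slice of the flattened parameter;
    # the last rank's slice may be shorter, and some ranks may own nothing.
    padded_numel_per_dp = (numel - 1) // dp_size + 1  # ceil(numel / dp_size)
    offsets = {}
    for rank in range(dp_size):
        start = min(rank * padded_numel_per_dp, numel)
        end = min(start + padded_numel_per_dp, numel)
        if end > start:
            offsets[rank] = (start, end)
    return offsets

print(dp_rank_offsets(numel=10, dp_size=4))
# {0: (0, 3), 1: (3, 6), 2: (6, 9), 3: (9, 10)}
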
@@ -262,13 +265,18 @@ class SlicedFlatTensor(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl

@staticmethod
def get_sliced_flat_tensor(data, start_offset, end_offset):
with torch.no_grad():
return data.view(-1)[start_offset:end_offset]
def get_sliced_flat_tensor(data, start_offset: int, end_offset: int, is_sharded: bool):
if is_sharded is False:
with torch.no_grad():
return data.view(-1)[start_offset:end_offset]
else:
return data

@staticmethod
def __new__(cls, data, start_offset, end_offset):
sliced_tensor = cls.get_sliced_flat_tensor(data=data, start_offset=start_offset, end_offset=end_offset)
def __new__(cls, data, start_offset: int, end_offset: int, is_sharded: bool):
sliced_tensor = cls.get_sliced_flat_tensor(
data=data, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded
)

result = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
@@ -283,11 +291,16 @@ def __new__(cls, data, start_offset, end_offset):
)
return result

def __init__(self, data, start_offset, end_offset):
def __init__(self, data, start_offset: int, end_offset: int, is_sharded: bool):
"""
is_sharded: whether the tensor is already sharded.
Sometimes we have already sharded a tensor and just want to wrap it in a `SlicedFlatTensor`
so we can save the sharding metadata cleanly.
"""
super().__init__()
# TODO @thomasw21: Make it so that you can never update this value
self.sliced_flat_tensor = self.get_sliced_flat_tensor(
data=data, start_offset=start_offset, end_offset=end_offset
data=data, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded
)
self.orig_data = data
self.start_offset = start_offset
@@ -337,9 +350,9 @@ def data_ptr(self):
grad = property(_get_grad, _set_grad, _del_grad)


def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int):
def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int, is_sharded: bool = False):
# This allows us to create a leaf tensor, despite sharing the underlying storage
result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset)
result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset, is_sharded=is_sharded)
return result
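
The `is_sharded` flag only changes how the wrapped data is obtained; a small standalone illustration of the two code paths in `get_sliced_flat_tensor` (plain tensors, no nanotron imports, values made up):

import torch

full = torch.arange(10, dtype=torch.float32)
start_offset, end_offset = 3, 6

# is_sharded=False: `data` is the full tensor, so the flat [start, end) slice is taken here.
shard_from_full = full.view(-1)[start_offset:end_offset]

# is_sharded=True: `data` is already the local ZeRO-1 shard; it is kept as-is and
# (start_offset, end_offset) are only stored as metadata for later resharding.
already_sharded = full[start_offset:end_offset].clone()

assert torch.equal(shard_from_full, already_sharded)
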


103 changes: 96 additions & 7 deletions src/nanotron/serialize/optimizer.py
@@ -22,13 +22,39 @@
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors


# TODO(xrsrke): take rank instead of parallel_context
def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
# def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
# if is_zero is True:
# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
# else:
# return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


def get_optimizer_filename(
tp_topology: Tuple[int, int],
pp_topology: Tuple[int, int],
dp_topology: Optional[Tuple[int, int]] = None,
exp_topology: Optional[Tuple[int, int]] = None,
is_zero: Optional[bool] = None,
):
"""
tp_topology: Tuple[int, int] = (rank, size)
pp_topology: Tuple[int, int] = (rank, size)
dp_topology: Tuple[int, int] = (rank, size)
NOTE: sometimes we get the checkpoint from a different topology (not the current parallel_context)
"""
assert exp_topology is not None, "exp_topology is required"
assert is_zero is not None, "is_zero is required"
pp_rank, pp_size = pp_topology
tp_rank, tp_size = tp_topology
exp_rank, exp_size = exp_topology

if is_zero is True:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
dp_rank, dp_size = dp_topology
return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_dp-{dp_rank}-of-{dp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt"
else:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
return f"{ObjectType.OPTIMIZER.value}_pp-{pp_rank}-of-{pp_size}_tp-{tp_rank}-of-{tp_size}_exp-{exp_rank}-of-{exp_size}.pt"


def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
Expand Down Expand Up @@ -102,7 +128,14 @@ def convert_to_string(input_item):
torch.save(
optimizer.state_dict(),
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
# / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
/ get_optimizer_filename(
tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()),
pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()),
dp_topology=(dist.get_rank(parallel_context.dp_pg), parallel_context.dp_pg.size()),
exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size),
is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer),
),
)


@@ -330,16 +363,58 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# NOTE: if you resume training from a checkpoint

def round_robin_map(numbers, min_val, max_val):
"""
Maps a list of numbers to a round-robin pattern within a configurable range.
Args:
numbers (list): List of numbers to map.
min_val (int): Minimum value in the round-robin range.
max_val (int): Maximum value in the round-robin range.
Returns:
list: Mapped list of numbers.
"""
range_size = max_val - min_val + 1
return [(num - 1) % range_size + min_val for num in numbers]

# if int(ckp_dp_size) != int(parallel_context.dp_pg.size()):
# pass
# else:

# TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be loading something else completely
# state_dict = torch.load(
# root_folder
# / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
# map_location=map_location,
# )
# NOTE: here we only load the optimizer states,
# then shard them according to the current data parallel dimension

state_dict = torch.load(
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
/ get_optimizer_filename(
tp_topology=(dist.get_rank(parallel_context.tp_pg), parallel_context.tp_pg.size()),
pp_topology=(dist.get_rank(parallel_context.pp_pg), parallel_context.pp_pg.size()),
# NOTE(xrsrke): suppose we initially have a dp world size of 4
# and then change to a dp world size of 8; we still need to load the optimizer states,
# so we do a round-robin mapping of the new dp ranks onto the checkpoint's dp shards
# dp=8's ranks: [0, 1, 2, 3, 4, 5, 6, 7]
# map to: [0, 1, 2, 3, 0, 1, 2, 3]
dp_topology=(dist.get_rank(parallel_context.dp_pg) % int(ckp_dp_size), ckp_dp_size),
exp_topology=(dist.get_rank(parallel_context.expert_pg), parallel_context.expert_parallel_size),
is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer),
),
map_location=map_location,
)

if isinstance(optimizer, ZeroDistributedOptimizer):

# NOTE: optimizer state topology-agnostic loading
# NOTE: only reshard after merging tp shards
# or we get a new dp_Size
# or we get a new dp_size
if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
current_dp_rank = dist.get_rank(parallel_context.dp_pg)
Expand All @@ -354,6 +429,20 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
)
state_dict["state"][param_index][state_name] = sliced_tensor

# NOTE: reshard gradient_accumulator if different dp size from checkpoint
if int(ckp_dp_size) != parallel_context.dp_pg.size():
merged_grad_accumulator = {}
for name, param in state_dict["gradient_accumulator"].items():
# NOTE: assume that we shard a parameter evenly across all DPs
# TODO: ideally keep an explicit mapping between sharding and resharding, so
# we don't have to rely on this assumption
# merged_p = torch.zeros(param.numel()*int(ckp_dp_size), device="cuda")
merged_p = [torch.zeros_like(param) for _ in range(int(ckp_dp_size))]
dist.all_gather(merged_p, param.to("cuda"), group=parallel_context.dp_pg)
merged_grad_accumulator[name] = torch.cat(merged_p, dim=-1).to(map_location)

assert 1 == 1

optimizer.load_state_dict(state_dict, map_location=map_location)
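
A minimal sketch of the resharding path described in the NOTEs above: each new dp rank is mapped onto a checkpoint shard round-robin, and the per-rank shards are merged before being re-sliced. Plain tensors stand in for the `dist.all_gather` over `dp_pg`, and `ckp_dp_rank` is an illustrative name:

import torch

def ckp_dp_rank(current_dp_rank, ckp_dp_size):
    # dp=8 ranks [0..7] map onto a dp=4 checkpoint as [0, 1, 2, 3, 0, 1, 2, 3]
    return current_dp_rank % ckp_dp_size

print([ckp_dp_rank(r, ckp_dp_size=4) for r in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]

# Merging the per-rank fp32 shards of one parameter back into the full flat tensor;
# in the real code the shards come from dist.all_gather over parallel_context.dp_pg.
shards = [torch.full((3,), float(r)) for r in range(4)]  # one shard per checkpoint dp rank
merged = torch.cat(shards, dim=-1)
assert merged.numel() == 12
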


16 changes: 16 additions & 0 deletions src/nanotron/trainer.py
@@ -172,6 +172,22 @@ def __init__(
parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
)
self.model = self.init_model() # Defines self.model

# from torch import nn
# def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
# """
# Return all the leaf modules (modules without any child modules) in a PyTorch module.
# """
# leaf_modules = []
# for n, m in module.named_modules():
# if not list(m.children()):
# leaf_modules.append((n, m))
# return leaf_modules

# leaf_modules = get_leaf_modules(self.model)
for name, param in self.model.named_parameters():
print(name, param.shape)

self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
)
