
Commit

refactor
xrsrke committed Dec 10, 2024
1 parent e4b2ea9 commit 67be7c1
Showing 1 changed file with 79 additions and 82 deletions.
161 changes: 79 additions & 82 deletions src/nanotron/serialize/optimizer.py
@@ -9,8 +9,9 @@
from tqdm import tqdm

from nanotron import distributed as dist
from nanotron import optim
from nanotron import logging, optim
from nanotron.constants import OPTIMIZER_CONFIG_FILE_NAME
from nanotron.logging import log_rank
from nanotron.optim.zero import (
ZeroDistributedOptimizer,
extract_parallel_ranks_from_shard_path,
@@ -23,6 +24,8 @@
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors

logger = logging.get_logger(__name__)


def get_optimizer_filename(
tp_topology: Tuple[int, int],
@@ -360,23 +363,6 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# NOTE: if you resume training

def round_robin_map(numbers, min_val, max_val):
    """
    Maps a list of numbers to a round-robin pattern within a configurable range.
    Args:
        numbers (list): List of numbers to map.
        min_val (int): Minimum value in the round-robin range.
        max_val (int): Maximum value in the round-robin range.
    Returns:
        list: Mapped list of numbers.
    """
    range_size = max_val - min_val + 1
    return [(num - 1) % range_size + min_val for num in numbers]
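# Illustrative, self-contained sketch of the helper above with a usage check
# (inputs are hypothetical):
def _round_robin_map_sketch(numbers, min_val, max_val):
    range_size = max_val - min_val + 1
    return [(num - 1) % range_size + min_val for num in numbers]

# e.g. mapping 1..6 into the range [0, 2] cycles through 0, 1, 2 repeatedly
assert _round_robin_map_sketch([1, 2, 3, 4, 5, 6], 0, 2) == [0, 1, 2, 0, 1, 2]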

# NOTE: since here we only load the optimizer states,
# we shard them according to the current data parallel dimension
# TODO @thomasw21: Load optimizer type and check that it's compatible, otherwise we might be loading something else completely
@@ -397,97 +383,108 @@ def round_robin_map(numbers, min_val, max_val):
map_location=map_location,
)

def create_merged_optim_states(param_shapes, map_location):
    merged_states = {}
    for name, p_shape in param_shapes.items():
        p_shape = tuple(int(x) for x in p_shape)
        merged_states[name] = {
            "exp_avg": torch.zeros(p_shape).view(-1).to(map_location),
            "exp_avg_sq": torch.zeros(p_shape).view(-1).to(map_location),
        }
    return merged_states

def create_merged_gradients(param_shapes, map_location):
    merged_grads = {}
    for name, p_shape in param_shapes.items():
        p_shape = tuple(int(x) for x in p_shape)
        merged_grads[name] = torch.zeros(p_shape).view(-1).to(map_location)
    return merged_grads

def load_sharded_states(shard_paths, map_location, load_type="state"):
    sharded_states = {}
    for shard_path in shard_paths:
        pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
        checkpoint = torch.load(shard_path, map_location=map_location)
        sharded_states[(tp_rank, dp_rank)] = checkpoint[load_type]
    return sharded_states

def get_key_by_value(d, target_value):
    return next((key for key, value in d.items() if value == target_value), None)

def apply_offsets(merged_tensor, sharded_states, param_name, offsets, tp_rank, state_keys=None):
    # NOTE: dp_rank and state_dict are taken from the enclosing scope at call time
    if state_keys:
        for key in state_keys:
            p_idx = get_key_by_value(state_dict["names"], param_name)
            merged_tensor[param_name][key][int(offsets[0]) : int(offsets[1])] = sharded_states[
                (int(tp_rank), int(dp_rank))
            ][p_idx][key]
    else:
        merged_tensor[param_name][int(offsets[0]) : int(offsets[1])] = sharded_states[
            (int(tp_rank), int(dp_rank))
        ][param_name]
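# Illustrative sketch (hypothetical parameter names): state_dict["names"] maps a parameter
# index to its name, while param_name_to_dp_rank_offsets is keyed by name, so apply_offsets
# recovers the index from the name via a reverse lookup like get_key_by_value above.
def _names_reverse_lookup_sketch():
    names = {0: "model.embed.weight", 1: "model.lm_head.weight"}

    def get_key_by_value(d, target_value):
        return next((key for key, value in d.items() if value == target_value), None)

    assert get_key_by_value(names, "model.lm_head.weight") == 1
    assert get_key_by_value(names, "not.a.param") is None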

if isinstance(optimizer, ZeroDistributedOptimizer):
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}_exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)
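# Illustrative sketch: the shard filenames matched by the glob above encode the (pp, dp, tp)
# ranks; extract_parallel_ranks_from_shard_path is nanotron's parser, and the regex below is
# only a hypothetical stand-in to show the naming scheme.
import re

def _parse_shard_path_sketch(path: str):
    match = re.search(r"pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+", path)
    assert match is not None, "unexpected shard filename"
    pp_rank, dp_rank, tp_rank = (int(g) for g in match.groups())
    return pp_rank, dp_rank, tp_rank

# e.g. _parse_shard_path_sketch("optimizer_pp-0-of-1_dp-2-of-4_tp-1-of-2_exp-0-of-1.pt") == (0, 2, 1)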

# NOTE: data parallel agnostic loading for optimizer states
if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
if int(ckp_dp_size) != parallel_context.dp_pg.size():
log_rank(
f"[Optimizer Loading] Detect new data parallelism topology in ZeRO-1, resharding optimizer states and gradient accumulator's states", # noqa
logger=logger,
level=logging.INFO,
rank=0,
)

current_dp_rank = dist.get_rank(parallel_context.dp_pg)
tp_rank = dist.get_rank(parallel_context.tp_pg)
OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
param_shapes = ckp_optimizer_config["configs"]["orig_param_shapes"]
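# Illustrative sketch (hypothetical Adam state layout): dict views support set operations,
# so subtracting ["step"] above keeps only the per-parameter tensor states.
def _optimizer_state_names_sketch():
    adam_param_state = {"step": 100, "exp_avg": None, "exp_avg_sq": None}
    state_names = adam_param_state.keys() - ["step"]
    assert state_names == {"exp_avg", "exp_avg_sq"}
    return state_names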

ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
ckp_sharded_optim_states[(tp_rank, dp_rank)] = torch.load(shard_path, map_location=map_location)[
"state"
]

merged_optim_states = {}
for name, p_shape in ckp_optimizer_config["configs"]["orig_param_shapes"].items():
p_shape = tuple(int(x) for x in p_shape)
merged_optim_states[name] = {
"exp_avg": torch.zeros(p_shape).view(-1).to(map_location),
"exp_avg_sq": torch.zeros(p_shape).view(-1).to(map_location),
}
# Handle optimizer states
ckp_sharded_optim_states = load_sharded_states(shard_paths, map_location, "state")
merged_optim_states = create_merged_optim_states(param_shapes, map_location)

def get_key_by_value(d, target_value):
return next((key for key, value in d.items() if value == target_value), None)

tp_rank = dist.get_rank(parallel_context.tp_pg)
for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
for dp_rank, offset in offsets.items():
# offset = [int(x) for x in offset]
for key in ["exp_avg", "exp_avg_sq"]:
p_idx = get_key_by_value(state_dict["names"], p_name)
merged_optim_states[p_name][key][int(offset[0]) : int(offset[1])] = ckp_sharded_optim_states[
(int(tp_rank), int(dp_rank))
][p_idx][key]

# NOTE: now merge optimizer states across data parallel dimension
apply_offsets(
merged_optim_states, ckp_sharded_optim_states, p_name, offset, tp_rank, OPTIMIZER_STATE_NAMES
)

# Update state dict with new sliced tensors
for param_index in state_dict["state"]:
param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
for state_name in OPTIMIZER_STATE_NAMES:
current_offsets = optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank]
sliced_tensor = get_sliced_tensor(
# param=state_dict["state"][param_index][state_name],
param=merged_optim_states[param_name][state_name],
start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
start_offset=current_offsets[0],
end_offset=current_offsets[1],
)
assert sliced_tensor.numel() > 0

state_dict["state"][param_index][state_name] = sliced_tensor

# NOTE: reshard gradient_accumulator if different dp size from checkpoint
if int(ckp_dp_size) != parallel_context.dp_pg.size():
# Handle gradient accumulator if DP size changed
assert int(ckp_tp_size) == parallel_context.tp_pg.size(), "Don't support changing TP size for ZeRO-1"
ckp_sharded_grad_accum = {}
for shard_path in shard_paths:
pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
ckp_sharded_grad_accum[(tp_rank, dp_rank)] = torch.load(shard_path, map_location=map_location)[
"gradient_accumulator"
]

assert len(ckp_sharded_grad_accum) == len(shard_paths)

merged_grad_accumulator = {}
for name, p_shape in ckp_optimizer_config["configs"]["orig_param_shapes"].items():
p_shape = tuple(int(x) for x in p_shape)
merged_grad_accumulator[name] = torch.zeros(p_shape).view(-1).to(map_location)
ckp_sharded_grad_accum = load_sharded_states(shard_paths, map_location, "gradient_accumulator")
merged_grad_accumulator = create_merged_gradients(param_shapes, map_location)

# NOTE: start to merge dp ranks across the same tp shard
# ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"]
tp_rank = dist.get_rank(parallel_context.tp_pg)
for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
for dp_rank, offset in offsets.items():
# offset = [int(x) for x in offset]
merged_grad_accumulator[p_name][int(offset[0]) : int(offset[1])] = ckp_sharded_grad_accum[
(int(tp_rank), int(dp_rank))
][p_name]

assert 1 == 1
for p_name in state_dict["gradient_accumulator"].keys():
new_offset = optimizer.param_name_to_dp_rank_offsets[p_name][int(dp_rank)]
assert state_dict["gradient_accumulator"][p_name].device == merged_grad_accumulator[p_name].device
state_dict["gradient_accumulator"][p_name] = merged_grad_accumulator[p_name][
int(new_offset[0]) : int(new_offset[1])
]

optimizer.load_state_dict(state_dict, map_location=map_location)
apply_offsets(merged_grad_accumulator, ckp_sharded_grad_accum, p_name, offset, tp_rank)

# Update gradient accumulator with new slices
for p_name in state_dict["gradient_accumulator"].keys():
new_offset = optimizer.param_name_to_dp_rank_offsets[p_name][int(dp_rank)]
assert state_dict["gradient_accumulator"][p_name].device == merged_grad_accumulator[p_name].device
state_dict["gradient_accumulator"][p_name] = merged_grad_accumulator[p_name][
int(new_offset[0]) : int(new_offset[1])
]
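# Illustrative sketch (hypothetical buffers): the gradient accumulator is resharded with the
# same per-rank offsets, and the assert above requires the merged buffer to live on the same
# device as the local gradient buffer it replaces.
import torch

def _replace_grad_accum_slice_sketch(local_grad: torch.Tensor, merged_flat_grad: torch.Tensor, start: int, end: int) -> torch.Tensor:
    new_slice = merged_flat_grad[start:end]
    assert local_grad.device == new_slice.device, "resharded slice must stay on the original device"
    return new_slice

# e.g. _replace_grad_accum_slice_sketch(torch.zeros(8), torch.randn(32), 8, 16) returns elements 8..15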

optimizer.load_state_dict(state_dict, map_location=map_location)


def load_lr_scheduler(
