From 67be7c1e3621f17a97fb408ee6d04f19ee28f4af Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Tue, 10 Dec 2024 17:09:05 +0000
Subject: [PATCH] refactor ZeRO-1 optimizer state resharding in serialize/optimizer.py

---
 src/nanotron/serialize/optimizer.py | 161 ++++++++++++++--------------
 1 file changed, 79 insertions(+), 82 deletions(-)

diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py
index f74517e6..b4b3c978 100644
--- a/src/nanotron/serialize/optimizer.py
+++ b/src/nanotron/serialize/optimizer.py
@@ -9,8 +9,9 @@
 from tqdm import tqdm
 
 from nanotron import distributed as dist
-from nanotron import optim
+from nanotron import logging, optim
 from nanotron.constants import OPTIMIZER_CONFIG_FILE_NAME
+from nanotron.logging import log_rank
 from nanotron.optim.zero import (
     ZeroDistributedOptimizer,
     extract_parallel_ranks_from_shard_path,
@@ -23,6 +24,8 @@
 from nanotron.serialize.metadata import TensorMetadata
 from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors
 
+logger = logging.get_logger(__name__)
+
 
 def get_optimizer_filename(
     tp_topology: Tuple[int, int],
@@ -360,23 +363,6 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) ->
         new_optim_state_dict["names"] = new_optim_state_param_names
         state_dict = new_optim_state_dict
     else:
-        # NOTE: if you resume from training
-
-        def round_robin_map(numbers, min_val, max_val):
-            """
-            Maps a list of numbers to a round-robin pattern within a configurable range.
-
-            Args:
-                numbers (list): List of numbers to map.
-                min_val (int): Minimum value in the round-robin range.
-                max_val (int): Maximum value in the round-robin range.
-
-            Returns:
-                list: Mapped list of numbers.
-            """
-            range_size = max_val - min_val + 1
-            return [(num - 1) % range_size + min_val for num in numbers]
-
         # NOTE: since here we only load the optimizer states,
         # then we shard it according to the current data parallel dimension
         # TODO @thomasw21: Load optimizer type and check that it's compatible otherwise we might be be loading something else completely
@@ -397,6 +383,46 @@ def round_robin_map(numbers, min_val, max_val):
             map_location=map_location,
         )
 
+        def create_merged_optim_states(param_shapes, map_location):
+            merged_states = {}
+            for name, p_shape in param_shapes.items():
+                p_shape = tuple(int(x) for x in p_shape)
+                merged_states[name] = {
+                    "exp_avg": torch.zeros(p_shape).view(-1).to(map_location),
+                    "exp_avg_sq": torch.zeros(p_shape).view(-1).to(map_location),
+                }
+            return merged_states
+
+        def create_merged_gradients(param_shapes, map_location):
+            merged_grads = {}
+            for name, p_shape in param_shapes.items():
+                p_shape = tuple(int(x) for x in p_shape)
+                merged_grads[name] = torch.zeros(p_shape).view(-1).to(map_location)
+            return merged_grads
+
+        def load_sharded_states(shard_paths, map_location, load_type="state"):
+            sharded_states = {}
+            for shard_path in shard_paths:
+                pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
+                checkpoint = torch.load(shard_path, map_location=map_location)
+                sharded_states[(tp_rank, dp_rank)] = checkpoint[load_type]
+            return sharded_states
+
+        def get_key_by_value(d, target_value):
+            return next((key for key, value in d.items() if value == target_value), None)
+
+        def apply_offsets(merged_tensor, sharded_states, param_name, offsets, tp_rank, state_keys=None):
+            if state_keys:
+                for key in state_keys:
+                    p_idx = get_key_by_value(state_dict["names"], param_name)
+                    merged_tensor[param_name][key][int(offsets[0]) : int(offsets[1])] = sharded_states[
+                        (int(tp_rank), int(dp_rank))
+                    ][p_idx][key]
+            else:
+                merged_tensor[param_name][int(offsets[0]) : int(offsets[1])] = sharded_states[
+                    (int(tp_rank), int(dp_rank))
+                ][param_name]
+
         if isinstance(optimizer, ZeroDistributedOptimizer):
             shard_paths = list(
                 root_folder.glob(
@@ -404,90 +430,61 @@ def round_robin_map(numbers, min_val, max_val):
                 )
             )
 
-            # NOTE: data parallel agnostic loading for optimizer states
-            if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
-                # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
+            if int(ckp_dp_size) != parallel_context.dp_pg.size():
+                log_rank(
+                    "[Optimizer Loading] Detected a new data parallelism topology for ZeRO-1, resharding optimizer states and gradient accumulator states",  # noqa
+                    logger=logger,
+                    level=logging.INFO,
+                    rank=0,
+                )
+
                 current_dp_rank = dist.get_rank(parallel_context.dp_pg)
+                tp_rank = dist.get_rank(parallel_context.tp_pg)
                 OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
+                param_shapes = ckp_optimizer_config["configs"]["orig_param_shapes"]
 
-                ckp_sharded_optim_states = {}
-                for shard_path in shard_paths:
-                    pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
-                    ckp_sharded_optim_states[(tp_rank, dp_rank)] = torch.load(shard_path, map_location=map_location)[
-                        "state"
-                    ]
-
-                merged_optim_states = {}
-                for name, p_shape in ckp_optimizer_config["configs"]["orig_param_shapes"].items():
-                    p_shape = tuple(int(x) for x in p_shape)
-                    merged_optim_states[name] = {
-                        "exp_avg": torch.zeros(p_shape).view(-1).to(map_location),
-                        "exp_avg_sq": torch.zeros(p_shape).view(-1).to(map_location),
-                    }
+                # Handle optimizer states
+                ckp_sharded_optim_states = load_sharded_states(shard_paths, map_location, "state")
+                merged_optim_states = create_merged_optim_states(param_shapes, map_location)
 
-                def get_key_by_value(d, target_value):
-                    return next((key for key, value in d.items() if value == target_value), None)
-
-                tp_rank = dist.get_rank(parallel_context.tp_pg)
                 for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
                     for dp_rank, offset in offsets.items():
-                        # offset = [int(x) for x in offset]
-                        for key in ["exp_avg", "exp_avg_sq"]:
-                            p_idx = get_key_by_value(state_dict["names"], p_name)
-                            merged_optim_states[p_name][key][int(offset[0]) : int(offset[1])] = ckp_sharded_optim_states[
-                                (int(tp_rank), int(dp_rank))
-                            ][p_idx][key]
-
-                # NOTE: now merge optimizer states across data parallel dimension
+                        apply_offsets(
+                            merged_optim_states, ckp_sharded_optim_states, p_name, offset, tp_rank, OPTIMIZER_STATE_NAMES
+                        )
+
+                # Update state dict with new sliced tensors
                 for param_index in state_dict["state"]:
                     param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
                     for state_name in OPTIMIZER_STATE_NAMES:
+                        current_offsets = optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank]
                         sliced_tensor = get_sliced_tensor(
-                            # param=state_dict["state"][param_index][state_name],
                             param=merged_optim_states[param_name][state_name],
-                            start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
-                            end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
+                            start_offset=current_offsets[0],
+                            end_offset=current_offsets[1],
                         )
                         assert sliced_tensor.numel() > 0
-
                         state_dict["state"][param_index][state_name] = sliced_tensor
 
-            # NOTE: reshard gradient_accumulator if different dp size from checkpoint
-            if int(ckp_dp_size) != parallel_context.dp_pg.size():
+                # Handle gradient accumulator if DP size changed
                 assert int(ckp_tp_size) == parallel_context.tp_pg.size(), "Don't support changing TP size for ZeRO-1"
-                ckp_sharded_grad_accum = {}
-                for shard_path in shard_paths:
-                    pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
-                    ckp_sharded_grad_accum[(tp_rank, dp_rank)] = torch.load(shard_path, map_location=map_location)[
-                        "gradient_accumulator"
-                    ]
-
-                assert len(ckp_sharded_grad_accum) == len(shard_paths)
 
-                merged_grad_accumulator = {}
-                for name, p_shape in ckp_optimizer_config["configs"]["orig_param_shapes"].items():
-                    p_shape = tuple(int(x) for x in p_shape)
-                    merged_grad_accumulator[name] = torch.zeros(p_shape).view(-1).to(map_location)
+                ckp_sharded_grad_accum = load_sharded_states(shard_paths, map_location, "gradient_accumulator")
+                merged_grad_accumulator = create_merged_gradients(param_shapes, map_location)
 
-                # NOTE: start to merge dp ranks across the same tp shard
-                # ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"]
-                tp_rank = dist.get_rank(parallel_context.tp_pg)
                 for p_name, offsets in ckp_optimizer_config["configs"]["param_name_to_dp_rank_offsets"].items():
                     for dp_rank, offset in offsets.items():
-                        # offset = [int(x) for x in offset]
-                        merged_grad_accumulator[p_name][int(offset[0]) : int(offset[1])] = ckp_sharded_grad_accum[
-                            (int(tp_rank), int(dp_rank))
-                        ][p_name]
-
-                assert 1 == 1
-                for p_name in state_dict["gradient_accumulator"].keys():
-                    new_offset = optimizer.param_name_to_dp_rank_offsets[p_name][int(dp_rank)]
-                    assert state_dict["gradient_accumulator"][p_name].device == merged_grad_accumulator[p_name].device
-                    state_dict["gradient_accumulator"][p_name] = merged_grad_accumulator[p_name][
-                        int(new_offset[0]) : int(new_offset[1])
-                    ]
-
-        optimizer.load_state_dict(state_dict, map_location=map_location)
+                        apply_offsets(merged_grad_accumulator, ckp_sharded_grad_accum, p_name, offset, tp_rank)
+
+                # Update gradient accumulator with new slices
+                for p_name in state_dict["gradient_accumulator"].keys():
+                    new_offset = optimizer.param_name_to_dp_rank_offsets[p_name][current_dp_rank]  # slice with this process's DP rank
+                    assert state_dict["gradient_accumulator"][p_name].device == merged_grad_accumulator[p_name].device
+                    state_dict["gradient_accumulator"][p_name] = merged_grad_accumulator[p_name][
+                        int(new_offset[0]) : int(new_offset[1])
+                    ]
+
+        optimizer.load_state_dict(state_dict, map_location=map_location)
 
 
 def load_lr_scheduler(
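
For context, a minimal standalone sketch of the resharding idea this patch implements: ZeRO-1 shards saved under the old data-parallel size are stitched back into a full flat tensor using the checkpoint's per-rank offsets, then re-sliced with the offsets of the new data-parallel layout. The names, shapes, and the reshard() helper below are illustrative only, not nanotron APIs.

# Illustrative sketch (not nanotron code): re-shard a flat ZeRO-1 tensor when
# the data-parallel size changes between saving and resuming.
import torch


def reshard(shards, old_offsets, new_offsets, numel):
    """shards: {old_dp_rank: 1-D tensor}; *_offsets: {dp_rank: (start, end)}."""
    merged = torch.zeros(numel)
    for dp_rank, (start, end) in old_offsets.items():
        merged[start:end] = shards[dp_rank]  # stitch the old DP shards together
    # cut new shards according to the new DP layout
    return {dp_rank: merged[start:end].clone() for dp_rank, (start, end) in new_offsets.items()}


# Toy example: a 12-element flat state saved with DP=3, resumed with DP=2.
old_shards = {0: torch.arange(0.0, 4.0), 1: torch.arange(4.0, 8.0), 2: torch.arange(8.0, 12.0)}
new_shards = reshard(
    old_shards,
    old_offsets={0: (0, 4), 1: (4, 8), 2: (8, 12)},
    new_offsets={0: (0, 6), 1: (6, 12)},
    numel=12,
)
assert torch.equal(new_shards[1], torch.arange(6.0, 12.0))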