diff --git a/run_generate.py b/run_generate.py index f389770d..8408f80d 100644 --- a/run_generate.py +++ b/run_generate.py @@ -59,6 +59,7 @@ def get_args(): parser.add_argument("--dp", type=int, default=1) parser.add_argument("--pp", type=int, default=0) parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--ep", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") return parser.parse_args() @@ -79,6 +80,7 @@ def main(): pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, + expert_parallel_size=args.ep or config.parallelism.expert_parallel_size, ) # Initialise all process groups @@ -86,6 +88,7 @@ def main(): data_parallel_size=parallel_config.dp, pipeline_parallel_size=parallel_config.pp, tensor_parallel_size=parallel_config.tp, + expert_parallel_size=parallel_config.expert_parallel_size, ) # Set log levels diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d5b9976f..6f19de0a 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -161,6 +161,7 @@ class GeneralArgs: Args: project: Name of the project (a project gather several runs in common tensorboard/hub-folders) + entity: Weights and bias entity name (optional) run: Name of the run step: Global step (updated when we save the checkpoint) consumed_train_samples: Number of samples consumed during training (should be actually just step*batch_size) @@ -168,6 +169,7 @@ class GeneralArgs: """ project: str + entity: Optional[str] = None run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 57225243..df7eef22 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -51,6 +51,11 @@ class LlamaConfig: tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 + # MoE specific + moe_num_experts: int = 1 + num_experts_per_tok: int = 1 + moe_loss_weight: float = 0.01 + moe_z_loss_weight: float = 0.001 def __post_init__(self): # NOTE: user don't set self._init_method, ModelArgs will set it diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 6ab71fad..069743c6 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -257,7 +257,12 @@ def decode_text( sharded_logits = model( input_ids=state.new_input_ids, input_mask=state.new_input_mask, - ) + aux_loss=( + torch.zeros(1, device=state.new_input_ids.device) + if is_decoder_input_rank + else TensorPointer(decoder_input_rank) + ), + )["sharded_logits"] else: if isinstance(state.new_input_ids, torch.Tensor): batch_generated_ids = torch.cat(state.generation_ids, dim=-1) @@ -268,7 +273,12 @@ def decode_text( sharded_logits = model( input_ids=batch_generated_ids, input_mask=batch_generated_mask, - ) + aux_loss=( + torch.zeros(1, device=state.new_input_ids.device) + if is_decoder_input_rank + else TensorPointer(decoder_input_rank) + ), + )["sharded_logits"] if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first: sharded_logits = sharded_logits.transpose(0, 1) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca8894b9..c106e222 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. 
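# --- Illustrative sketch, not part of the patch ------------------------------
# The four MoE knobs added to LlamaConfig in models_config.py above
# (moe_num_experts, num_experts_per_tok, moe_loss_weight, moe_z_loss_weight)
# set how many experts each MoE layer stores, how many are active per token,
# and how strongly the two auxiliary router losses are weighted. The sizes
# below are made up; only the field names come from the diff.
moe_num_experts = 8         # experts stored per MoE layer
num_experts_per_tok = 2     # top-k experts activated per token
hidden_size, intermediate_size = 4096, 11008
dense_mlp_params = 3 * hidden_size * intermediate_size     # w1, w2, w3
stored_per_layer = moe_num_experts * dense_mlp_params      # parameters held in memory
active_per_token = num_experts_per_tok * dense_mlp_params  # parameters used per token
print(f"stored per MoE layer: {stored_per_layer:,}, active per token: {active_per_token:,}")
# ------------------------------------------------------------------------------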
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -26,6 +26,9 @@ from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel +from nanotron.models.moe import ( + dMoE, +) from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext @@ -40,7 +43,10 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates -from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator +from nanotron.scaling.parametrization import ( + SpectralMupParametrizator, + StandardParametrizator, +) from nanotron.utils import checkpoint_method logger = logging.get_logger(__name__) @@ -173,7 +179,12 @@ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] class CoreAttention(nn.Module): - def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + layer_idx: int, + ): super().__init__() # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentically a `d_qkv` assert ( @@ -197,10 +208,28 @@ def forward( from flash_attn.flash_attn_interface import flash_attn_varlen_func # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) + cu_seqlens_q = torch.zeros( + (q_sequence_mask.shape[0] + 1), + dtype=torch.int32, + device=query_states.device, + ) + cu_seqlens_k = torch.zeros( + (kv_sequence_mask.shape[0] + 1), + dtype=torch.int32, + device=query_states.device, + ) + torch.cumsum( + q_sequence_mask.sum(-1, dtype=torch.int32), + dim=0, + dtype=torch.int32, + out=cu_seqlens_q[1:], + ) + torch.cumsum( + kv_sequence_mask.sum(-1, dtype=torch.int32), + dim=0, + dtype=torch.int32, + out=cu_seqlens_k[1:], + ) # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
@@ -604,6 +633,7 @@ def __init__( self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, tp_pg: dist.ProcessGroup, layer_idx: int, ): @@ -617,13 +647,23 @@ def __init__( ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + if config.moe_num_experts > 1: + self.mlp = dMoE( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + self._is_moe = True + else: + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self._is_moe = False def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ) -> Dict[str, Union[torch.Tensor, TensorPointer, Dict[str, Union[torch.Tensor, TensorPointer]]],]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -633,24 +673,35 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] - hidden_states = hidden_states + residual + mlp_output = self.mlp(hidden_states=hidden_states) + hidden_states = mlp_output["hidden_states"] + residual + + if self._is_moe: + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value return { "hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], + "aux_losses": aux_losses, } class Embedding(nn.Module, AttachableStore): - def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + def __init__( + self, + tp_pg: dist.ProcessGroup, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + ): super().__init__() self.token_embedding = TensorParallelEmbedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=config.pad_token_id, pg=tp_pg, - mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + mode=(parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE), ) self.pg = tp_pg @@ -669,7 +720,9 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ # Format input in `[seq_length, batch_size]` to support high TP with low batch_size input_ids = input_ids.transpose(0, 1) input_embeds = self.token_embedding(input_ids) - return {"input_embeds": input_embeds} + return { + "input_embeds": input_embeds, + } class LlamaModel(nn.Module): @@ -714,10 +767,11 @@ def __init__( "config": config, "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, + "parallel_context": parallel_context, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, + module_input_keys={"hidden_states", "sequence_mask", "aux_losses"}, + module_output_keys={"hidden_states", "sequence_mask", "aux_losses"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -726,7 +780,10 @@ def __init__( self.final_layer_norm = PipelineBlock( p2p=self.p2p, module_builder=TritonRMSNorm, - module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_kwargs={ + "hidden_size": config.hidden_size, + "eps": config.rms_norm_eps, + }, 
module_input_keys={"input"}, module_output_keys={"hidden_states"}, ) # TODO @@ -760,22 +817,34 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + sharded_logits, hidden_states, aux_losses = self.forward_with_hidden_states( + input_ids=input_ids, + input_mask=input_mask, + aux_losses=aux_losses, + ) + return {"sharded_logits": sharded_logits, "aux_losses": aux_losses} def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], ): # all tensors are optional as most ranks don't need anything from the dataloader. - output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + output = self.token_position_embeddings( + input_ids=input_ids, + input_mask=input_mask, + ) hidden_encoder_states = { "hidden_states": output["input_embeds"], "sequence_mask": input_mask, + "aux_losses": aux_losses, } + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) @@ -785,17 +854,21 @@ def forward_with_hidden_states( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - return fp32_sharded_logits, hidden_states + return fp32_sharded_logits, hidden_states, hidden_encoder_states["aux_losses"] def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" model_config = self.config d_ff = model_config.intermediate_size d_qkv = model_config.hidden_size // model_config.num_attention_heads + attention_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + mlp_cost = 3 * d_ff * model_config.hidden_size + if model_config.moe_num_experts > 1: + mlp_cost *= model_config.num_experts_per_tok # active experts + mlp_cost += model_config.hidden_size * model_config.moe_num_experts # routing block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 3 * d_ff * model_config.hidden_size, + LlamaDecoderLayer: attention_cost + mlp_cost, # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } @@ -846,7 +919,10 @@ def forward( # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 loss = sharded_cross_entropy( - sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + sharded_logits, + label_ids.transpose(0, 1).contiguous(), + group=self.tp_pg, + dtype=torch.float, ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. 
loss = masked_mean(loss, label_mask, dtype=torch.float) @@ -864,7 +940,11 @@ def __init__( random_states: Optional[RandomStates] = None, ): super().__init__() - self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + self.model = LlamaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, @@ -887,16 +967,35 @@ def forward( label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - sharded_logits = self.model( + # aux_losses are used for load balancing in case of MoEs + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + output = self.model( input_ids=input_ids, input_mask=input_mask, + aux_losses=aux_losses, ) loss = self.loss( - sharded_logits=sharded_logits, + sharded_logits=output["sharded_logits"], label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + + # add all aux_losses to the main loss dictionary + if self.config.moe_num_experts > 1: + for key, value in output["aux_losses"].items(): + loss[key] = value + return loss @torch.no_grad() def init_model_randomly(self, config: Config): @@ -952,16 +1051,21 @@ def init_model_randomly(self, config: Config): initialized_parameters.add(full_param_name) assert initialized_parameters == { - param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - if param.is_tied - else name + ( + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + ) for name, param in model.named_parameters() }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" def get_embeddings_lm_head_tied_names(self): """Get the names of the tied embeddings and lm_head weights""" if self.config.tie_word_embeddings is True: - return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + return [ + "model.token_position_embeddings.pp_block.token_embedding.weight", + "model.lm_head.pp_block.weight", + ] else: return [] diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py new file mode 100644 index 00000000..75c9e3a8 --- /dev/null +++ b/src/nanotron/models/moe.py @@ -0,0 +1,688 @@ +""" LlaMa model with MoEs""" + +import warnings +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from torch import nn + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import LlamaConfig as Config +from nanotron.config import ParallelismArgs +from nanotron.nn.activations import ACT2FN +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) +from nanotron.parallel.tensor_parallel.enum import 
TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +try: + import megablocks.ops as ops + from megablocks.layers.all_to_all import all_to_all +except ImportError: + warnings.warn("Please install megablocks to use MoEs: `pip install megablocks`") + + +logger = logging.get_logger(__name__) + + +def log_mean(x, dim): + return torch.logsumexp(x, dim=dim) - torch.log(torch.tensor(x.shape[dim], dtype=torch.float32)) + + +def load_balancing_loss(router_logits, tokens_per_expert, config: Config) -> torch.Tensor: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + logits: logits assigned to each expert per token. Shape: + [batch_size * sequence_length, num_experts]. + tokens_per_expert: [num_selected_experts] + + config: Config + + Returns: + The auxiliary loss. + """ + # tokens = batch_size * sequence_length + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + moe_loss_weight = config.moe_loss_weight + num_experts_per_token = config.num_experts_per_tok + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert tokens_per_expert.ndim == 1 and tokens_per_expert.numel() == moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + # compute router probability per expert in log space for numerical stability + logprobs = F.log_softmax(router_logits, dim=-1) + # take mean probability over batch + # shape [num_experts] + logprobs = log_mean(logprobs, dim=0) + expert_scores = torch.exp(logprobs) + + tokens_per_expert = tokens_per_expert.to(expert_scores.dtype) + + # Calculate the total scale across all factors. + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = moe_num_experts * moe_loss_weight + scale_denominator = num_hidden_layers * tokens * num_experts_per_token + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +def router_z_loss(router_logits, config: Config) -> torch.Tensor: + """ + The router z-loss was introduced in ST-MoE + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability. + + Args: + router_logits: [batch_size * sequence_length, num_experts] + router logits + config: Config + + Returns: + Scalar router z-loss. 
+ """ + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + z_loss_weight = config.moe_z_loss_weight + + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + + scale_numerator = z_loss_weight + scale_denominator = num_hidden_layers * tokens * moe_num_experts + scale = scale_numerator / scale_denominator + + return scale * z_loss.sum(dim=0) + + +class dMoE(torch.nn.Module): + def __init__( + self, + config: Config, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER: + logging.warn_once( + logger=logger, + msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.", + rank=0, + ) + + # Token router. + self.gate = LearnedRouter(config) + + # Expert computation helper. + self.experts = ParallelDroplessMLP( + config, + use_bias=False, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + def forward(self, hidden_states: torch.Tensor): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + """ + # Compute the expert scores and assignments. + # TODO: support sequence parallelism + batch_size, sequence_length, _ = hidden_states.size() + x = hidden_states.view(-1, self.config.hidden_size) + router_logits, expert_weights, top_experts = self.gate(x) + + # Compute the experts. + x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) + return { + "hidden_states": x.reshape(batch_size, sequence_length, -1), + "load_balancing_loss": lbl_loss, + "z_loss": z_loss, + } + + +# Adapted from megablocks.layers.router.LearnedRouter +class LearnedRouter(torch.nn.Module): + def __init__(self, config: Config): + super().__init__() + self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.config = config + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + router_logits = self.layer(x) # (batch * sequence_length, n_experts) + scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? 
+ + if self.config.num_experts_per_tok == 1: + expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) + else: + expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) + # IMPORTANT step to normalize, otherwise weights are very low + expert_weights = expert_weights / torch.norm( + expert_weights, + p=1, + dim=-1, + keepdim=True, + ) + return router_logits, expert_weights, expert_indices.int() + + +# Adapted from megablocks.layers.mlp.ParallelDroplessMLP +class ParallelDroplessMLP(torch.nn.Module): + def __init__( + self, + config: Config, + use_bias: bool, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.use_bias = use_bias + + self.expert_pg_size = parallel_context.expert_pg.size() + self.expert_parallel_group = parallel_context.expert_pg + + self.hidden_sharding_degree = self.expert_pg_size // min(self.expert_pg_size, self.config.moe_num_experts) + self.experts_per_rank = self.config.moe_num_experts // min(self.expert_pg_size, self.config.moe_num_experts) + + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = self.config.num_experts_per_tok + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + if use_bias: + self.bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + + # Select the forward function for the operating mode. + self.forward_fn = self.parallel_forward_once if self.expert_pg_size > 1 else self.forward_once + + self.blocking = 128 + + if self.experts_per_rank == 1: + self.mlp = MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + self.mlp = SparseGLU( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + + max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking + self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1) + + def indices_and_bins(self, top_expert): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_expert = top_expert.int() + bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, tokens_per_expert + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Calculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Calculate the bin bounds for the sorted tokens. 
+ bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def forward_once(self, x, expert_weights, top_experts): # TODO: sparse + with torch.no_grad(): + ( + indices, + bin_ids, + bins, + padded_bins, + tokens_per_expert, + ) = self.indices_and_padded_bins(top_experts) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok) + + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.num_experts_per_tok, + -1, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x, expert_weights, top_experts): + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts) + repeated_tokens_per_expert = ops.repeat(tokens_per_expert, (self.hidden_sharding_degree,)) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.expert_parallel_group, + async_op=True, + ) + + x = ops.gather(x, indices, bin_ids, bins, self.num_experts_per_tok) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to [expert_pg_size, num_experts_per_rank]. + repeated_tokens_per_expert = repeated_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + parallel_tokens_per_expert = parallel_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + x = ops.repeat(x, (self.hidden_sharding_degree, 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, recv_counts, send_counts, self.expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + replicate_bins = inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received + ).flatten() + + parallel_bin_ids, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. 
+ parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + num_experts_per_tok=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all(parallel_x, send_counts, recv_counts, self.expert_parallel_group) + + # Reduce along the hidden sharding to get the final outputs. + shape = (self.hidden_sharding_degree, -1, self.config.hidden_size) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + self.num_experts_per_tok, + ) + return x, tokens_per_expert.flatten() + + def forward(self, x, router_logits, expert_weights, top_experts): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + router_logits: tensor of shape [sequence_length * batch_size, n_experts] + expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + """ + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + if self.training: + lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) + z_loss = router_z_loss(router_logits, self.config) + else: + lbl_loss = torch.zeros(1, device=x.device) + z_loss = torch.zeros(1, device=x.device) + + if self.use_bias: + return x + self.bias + return x, lbl_loss, z_loss + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + num_experts_per_tok, + ): + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, num_experts_per_tok) + + # Perform the expert computation. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter(x, indices, bin_ids, expert_weights, bins, padded_bins, num_experts_per_tok) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit) + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + assert self.config.intermediate_size % self.blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. 
+ block_rows = padded_tokens // self.blocking + blocks_per_row = self.config.intermediate_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=x.dtype, + device="meta", + ) + shape = (padded_tokens, self.config.intermediate_size * self.experts_per_rank) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, row_indices, column_indices, offsets + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + +class ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, scale): + ctx.scale = scale + return x + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +class ExpertParallel(nn.Module): + """ + ExpertParallel serves to scale the gradients of the expert weights because unlike DP the gradients are not averaged across the expert parallel group. + """ + + def __init__(self, module, expert_parallel_size: int): + super().__init__() + self.module = module + self.expert_parallel_size = expert_parallel_size + + def forward(self, *args, **kwargs): + self.scale_gradients() + return self.module(*args, **kwargs) + + def scale_gradients(self): + scale_gradient(self.module, 1 / self.expert_parallel_size) + + +class SparseMLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__() + + self.expert_pg_size = parallel_config.expert_parallel_size if parallel_config is not None else 1 + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + self.tp_pg = parallel_context.tp_pg + + self.w1 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + self.w2 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + + if self.tp_pg.size() == 1: + self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous() + + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + self.sdd = partial(wp.sdd_nt, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.sdd + self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd + + def forward(self, x, topo): + self.w1.scale_gradients(), self.w2.scale_gradients() + x = self.sdd(x.contiguous(), self.w1.module.weight, topo) + activation_fn_out = act_fn(x, self.act) + return self.dsd(activation_fn_out, self.w2.module.weight) + + +class MLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: 
dist.ProcessGroup, + ): + super().__init__() + + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.expert_pg_size = parallel_config.expert_parallel_size + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + + assert self.experts_per_rank == 1, "moe.MLP only supports 1 expert per rank, otherwise use moe.SparseMLP" + + self.w1 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w2 = ExpertParallel( + TensorParallelRowLinear( + config.intermediate_size * self.experts_per_rank, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication + and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w3 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + merged_states = self.w1(hidden_states) + hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states)) + return hidden_states + + +def inclusive_cumsum(x, dim): + scalar = ops.inclusive_cumsum(x, dim) + return scalar.view(1) if not len(scalar.size()) else scalar + + +class SparseGLU(SparseMLP): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__(config, parallel_config, parallel_context) + self.w3 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + if self.tp_pg.size() == 1: + self.w3.module.weight.data = self.w3.module.weight.data.T.contiguous() + + mark_all_parameters_in_module_as_sharded( + self, + pg=parallel_context.tp_and_expert_pg, + split_config=SplitConfig(split_dim=0), + ) + + def forward(self, x, topo): + # We need to scale gradients manually since we don't call the linears forward + self.w1.scale_gradients(), self.w2.scale_gradients(), self.w3.scale_gradients() + x = x.contiguous() + x1 = self.sdd(x, self.w1.module.weight, topo) + x2 = self.sdd(x, self.w3.module.weight, topo) + x = stk.ops.mul(act_fn(x1, self.act), x2) + return self.dsd(x, self.w2.module.weight) diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..d4859d35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,18 +2,21 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup from nanotron.logging import log_rank from 
nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd -from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model +from nanotron.parallel.pipeline_parallel.context_manager import ( + attach_pipeline_state_to_model, +) from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -48,13 +51,18 @@ def forward( output = {"loss": output} # We normalize our loss - if not isinstance(output["loss"], TensorPointer): - output["loss"] = output["loss"] / self.nb_microbatches - - # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): - assert output["loss"].requires_grad - state.register_activation_requiring_backward(output["loss"]) + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v / self.nb_microbatches + + # the outputs are either + # - token prediction loss ["loss"] + # - auxiliary losses ["load_balancing_loss", "z_loss"] + # that we need to backpropagate through, so register activations + for loss_key, output_tensor in output.items(): + if not isinstance(output_tensor, TensorPointer): + assert output_tensor.requires_grad + state.register_activation_requiring_backward(output_tensor) return output @staticmethod @@ -65,7 +73,10 @@ def _get_fwd_context(model: torch_nn.Module): return context def backward( - self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator] + self, + context: ContextManagers, + state: PipelineTrainBatchState, + grad_accumulator: Optional[GradientAccumulator], ): # Increment the number of backwards state.nb_backwards += 1 @@ -154,7 +165,7 @@ def validate_batch_iter( if not isinstance(output, dict): output = {"loss": output} - # Store the loss for each microbatch + # Store the loss(es) for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} outputs.append(output) @@ -269,8 +280,9 @@ def train_batch_iter( send_activation() # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) for micro_batch in batch: @@ -282,8 +294,9 @@ def train_batch_iter( output = {"loss": output} # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) # One backward diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..380ad460 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator): def __init__(self, config: ModelArgs): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { + nn.Linear: self._parametrize_column_linear, TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: 
self._parametrize_layer_norm, diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0eda00dc..6c6c0a6b 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -131,7 +131,9 @@ def __init__( super().__init__() self.config = get_config_from_file( - config_or_config_file, config_class=config_class, model_config_class=model_config_class + config_or_config_file, + config_class=config_class, + model_config_class=model_config_class, ) self.model_config = self.config.model.model_config if model_class is not None: @@ -212,7 +214,8 @@ def __init__( # Define iteration start state if self.init_checkpoint_path is not None: checkpoint_metadata = load_meta( - parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) assert isinstance(checkpoint_metadata.metas, TrainingMetadata) log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) @@ -276,7 +279,8 @@ def pre_training(self, *args, **kwargs): if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.init( project=self.config.general.project, - name=f"{current_time}_{self.config.general.run}", + name=f"{current_time}_{self.config.general.project}_{self.config.general.run}", + entity=self.config.general.entity, config={"nanotron_config": self.config.as_dict()}, ) @@ -307,7 +311,9 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da else: dataloader = dataloaders self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + dataloader=dataloader, + parallel_context=self.parallel_context, + config=self.config, ) return elif isinstance(dataloaders, Generator): @@ -385,13 +391,19 @@ def find_stage_idx_to_resume(): if dataloader is not None: self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + dataloader=dataloader, + parallel_context=self.parallel_context, + config=self.config, ) def train( self, dataloader_or_dls: Dict[ - str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + str, + Union[ + Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], + Tuple[Iterator, ...], + ], ], **kwargs, ) -> None: @@ -424,7 +436,7 @@ def train( self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step - outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + outputs, loss_avg, aux_losses = self.training_step(dataloader=self.current_dataloader) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -435,7 +447,11 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs( + outputs=outputs, + loss_avg=loss_avg, + aux_losses=aux_losses, + ) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -448,7 +464,12 @@ def train( def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] ) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]: - before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + before_tbi_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + 
self.grad_accumulator, + ) if self.iteration_step < 5: log_memory(logger=logger) @@ -464,7 +485,12 @@ def training_step( if self.iteration_step < 5: log_memory(logger=logger) - after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + after_tbi_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, + ) if isinstance(self.model, DistributedDataParallel) and self.grad_accumulator is not None: # Wait for fp32 grads allreduce to finish to make sure grads are synced across DP @@ -508,9 +534,13 @@ def training_step( ) before_optim_step_sanity_checks( - self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, ) + aux_losses = {} # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. @@ -518,7 +548,24 @@ def training_step( [output["loss"] for output in outputs] ).sum() # already divided by n_micro_batches_per_batch # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + handle = dist.all_reduce( + loss_avg, + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) + for k in outputs[0].keys(): + if k != "loss": + aux_losses[k] = torch.stack( + [output[k] for output in outputs] + ).sum() # already divided by n_micro_batches_per_batch + # sync loss across DP + handle = dist.all_reduce( + aux_losses[k], + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) else: loss_avg = None handle = None @@ -530,14 +577,19 @@ def training_step( # Update the learning rate self.lr_scheduler.step() - after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + after_optim_step_sanity_checks( + self.config, + self.parallel_context, + self.unwrapped_model, + self.grad_accumulator, + ) if handle is not None: handle.wait() self.post_train_step() - return outputs, loss_avg + return outputs, loss_avg, aux_losses def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: outputs = self.pipeline_engine.validate_batch_iter( @@ -551,6 +603,7 @@ def train_step_logs( self, outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + aux_losses: Optional[dict] = {}, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -577,13 +630,20 @@ def train_step_logs( self.metadata.consumed_train_samples * self.config.tokens.sequence_length, "human_format", ), # , "12d"), - LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), + LogItem( + "elapsed_time_per_iteration_ms", + elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem( - "tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format" + "tokens_per_sec_per_gpu", + tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), + *[LogItem(k, v.item(), "human_format") for k, v in aux_losses.items()], LogItem("lr", lr, "human_format"), # , ".3E"), LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), @@ -598,10 +658,14 @@ def train_step_logs( log_entries.extend( [ LogItem( - "cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format" + "cuda_memory_allocated", + torch.cuda.memory_allocated(), + "human_format", ), # / 1024**2, ".2f"), LogItem( - "cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format" + "cuda_max_memory_reserved", + torch.cuda.max_memory_reserved(), + "human_format", ), # / 1024**2, ".2f"), LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"), LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"), @@ -691,9 +755,16 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: reloaded_from_checkpoint = False if self.init_checkpoint_path is not None: # Reload from a training checkpoint - log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + log_rank( + f"Loading weights from {self.init_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) self.param_shard_metadata = load_weights( - model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) reloaded_from_checkpoint = True if not reloaded_from_checkpoint: @@ -756,17 +827,41 @@ def _init_model( module.init_rotary_embeddings() # Mark some parameters as tied - self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + self._mark_tied_parameters( + model=model, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) # count number of parameters num_params = sum(p.numel() for p in model.parameters()) size_params = sum(p.numel() * p.element_size() for p in model.parameters()) total_params = torch.tensor(num_params, device="cuda") total_size = torch.tensor(size_params, device="cuda") - dist.all_reduce(total_params, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP - dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP - dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) - dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, 
op=dist.ReduceOp.SUM) + dist.all_reduce( + total_params, + group=parallel_context.tp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # TP + dist.all_reduce( + total_params, + group=parallel_context.pp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # PP + dist.all_reduce( + total_size, + group=parallel_context.tp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) + dist.all_reduce( + total_size, + group=parallel_context.pp_pg, + async_op=False, + op=dist.ReduceOp.SUM, + ) # TODO @nouamanetazi: better memory logs log_rank( @@ -872,7 +967,9 @@ def save_checkpoint(self) -> Path: config=self.config, ) save_random_states( - random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path + random_states=self.random_states, + parallel_context=self.parallel_context, + root_folder=checkpoint_path, ) with open(checkpoints_path / "latest.txt", mode="w") as fo: fo.write(f"{self.iteration_step}") @@ -893,11 +990,17 @@ def _mark_tied_parameters( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None, ): - mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + mark_tied_parameters( + model=model, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) def mark_tied_parameters( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, ): # Tie embeddings embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names() @@ -917,7 +1020,10 @@ def mark_tied_parameters( for target in embeddings_lm_head_tied_names ] tie_parameters( - root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM + root_module=model, + ties=shared_embeddings, + parallel_context=parallel_context, + reduce_op=dist.ReduceOp.SUM, ) # Tie custom params @@ -932,7 +1038,9 @@ def mark_tied_parameters( def mark_unsharded_params_as_tied_across_tp( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs" + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: "ParallelismArgs", ): for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): @@ -968,12 +1076,17 @@ def mark_unsharded_params_as_tied_across_tp( reduce_op = dist.ReduceOp.SUM tie_parameters( - root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + root_module=model, + ties=shared_weights, + parallel_context=parallel_context, + reduce_op=reduce_op, ) def mark_unsharded_params_as_tied_across_expert( - model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs" + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: "ParallelismArgs", ): for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): @@ -1002,5 +1115,8 @@ def mark_unsharded_params_as_tied_across_expert( reduce_op = None tie_parameters( - root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + root_module=model, + ties=shared_weights, + parallel_context=parallel_context, + reduce_op=reduce_op, )
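For reference, a self-contained sketch of the two auxiliary losses introduced by the new moe.py module: the Switch Transformer load-balancing loss and the ST-MoE router z-loss, written for a single MoE layer so the scaling factors are easy to follow. It only illustrates the math; the patch itself averages router log-probabilities in log space for numerical stability and obtains tokens_per_expert from megablocks' histogram op, and the sizes below are made up (only the 0.01 / 0.001 weights match the config defaults in the diff).

import torch
import torch.nn.functional as F

tokens, num_experts, top_k, num_layers = 16, 4, 2, 1
lb_weight, z_weight = 0.01, 0.001  # moe_loss_weight / moe_z_loss_weight defaults

router_logits = torch.randn(tokens, num_experts)

# number of routed slots each expert received (the patch uses ops.histogram)
top_experts = router_logits.topk(top_k, dim=-1).indices
tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=num_experts).float()

# mean routing probability per expert
expert_scores = F.softmax(router_logits, dim=-1).mean(dim=0)

lb_scale = num_experts * lb_weight / (num_layers * tokens * top_k)
load_balancing_loss = lb_scale * torch.dot(tokens_per_expert, expert_scores)

z_scale = z_weight / (num_layers * tokens * num_experts)
z_loss = z_scale * torch.logsumexp(router_logits, dim=-1).pow(2).sum()

print(load_balancing_loss.item(), z_loss.item())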