From ae393e81ea5816bc0e53af47cf49acc588394ba8 Mon Sep 17 00:00:00 2001
From: buptzyb
Date: Mon, 25 Nov 2024 16:01:57 +0800
Subject: [PATCH] Support CUDA Graph for MoE models (#1233)

* Align RNG tracker with megatron

Signed-off-by: Robin Zhang
Co-authored-by: Yifei Song

* Fix module_params order and warmup bug in cudagraph

Signed-off-by: Robin Zhang
Co-authored-by: Yifei Song

* Add fp8_group argument and fix fp8 accuracy issue for cudagraph

Signed-off-by: Robin Zhang
Co-authored-by: Yifei Song

* Add TE modules and weights filters to support MoE models

Signed-off-by: Robin Zhang
Co-authored-by: Yifei Song

* Revert self.fp8

Signed-off-by: Robin Zhang

* Use hooks to filter module params

Signed-off-by: Robin Zhang

* Filter all TE modules in hooks

Signed-off-by: Robin Zhang
Co-authored-by: Yifei Song

* Format code

Signed-off-by: Robin Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update graph.py

Signed-off-by: Xin Yao

* Revert CudaRNGStatesTracker

Signed-off-by: Robin Zhang

* Format Update

Signed-off-by: Yifei Song

* Revert "Use hooks to filter module params"

This reverts commit 73a22e2e8bcf43ec84c23bc844b8d16d06626e26.

Signed-off-by: Yifei Song

* Remove filtering module params

Signed-off-by: Robin Zhang

---------

Signed-off-by: Robin Zhang
Signed-off-by: Xin Yao
Signed-off-by: Yifei Song
Co-authored-by: Yifei Song
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao
Co-authored-by: Xin Yao
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 transformer_engine/pytorch/fp8.py           | 12 +--
 transformer_engine/pytorch/graph.py         | 86 +++++++++++++++++--
 .../pytorch/module/layernorm_linear.py      |  5 +-
 .../pytorch/module/layernorm_mlp.py         |  5 +-
 transformer_engine/pytorch/module/linear.py |  6 +-
 5 files changed, 97 insertions(+), 17 deletions(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 2a909dabc6..15f20c81e5 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -442,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
         stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()
 
         # Replace amaxes and scales with stashed values for phase 2 forward
-        fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
-        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
-        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
+        fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
+        fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
+        fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2])
 
     @staticmethod
     def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
         """Restore latest scaling factors and amaxes after recompute forward run."""
-        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
-        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
-        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
+        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
+        fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
+        fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"])
 
 
 @contextmanager
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index c47b792a95..6c33cc72b9 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -12,6 +12,7 @@
 from torch._C import _graph_pool_handle
 
 from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.pytorch.constants import dist_group_type
 from .fp8 import (
     fp8_autocast,
     FP8GlobalStateManager,
@@ -173,11 +174,14 @@ def _make_graphed_callables(
         ]
     else:
         per_callable_module_params = []
-        for c in callables:
-            for i in range(num_microbatches):
-                per_callable_module_params.append(
-                    tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
-                )
+        for m_chunk in range(num_model_chunks):
+            for _ in range(num_microbatches):
+                for l_no in range(num_layers):
+                    per_callable_module_params.append(
+                        tuple(callables[m_chunk * num_layers + l_no].parameters())
+                        if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
+                        else ()
+                    )
     assert len(per_callable_module_params) == len(flatten_sample_args)
     per_callable_static_input_surfaces = [
         flatten_sample_args[i] + per_callable_module_params[i]
@@ -201,13 +205,55 @@ def _make_graphed_callables(
     # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
     # from ending up in any captures.
     torch.cuda.synchronize()
-    with torch.cuda.stream(torch.cuda.Stream()):
+
+    # Get warmup func and func_idx.
+    warmup_func_idx = []
+    warmup_func = []
+    if _order is None:
         for func_idx, func in enumerate(callables):
+            warmup_func_idx.append(func_idx)
+            warmup_func.append(func)
+    else:
+        fwd_idx = [0] * num_model_chunks
+        for c_id in _order:
+            if c_id > 0:
+                m_chunk = c_id - 1
+                for l_no in range(num_layers):
+                    func = callables[m_chunk * num_layers + l_no]
+                    func_idx = (m_chunk * num_microbatches * num_layers) + (
+                        fwd_idx[m_chunk] * num_layers + l_no
+                    )
+                    warmup_func_idx.append(func_idx)
+                    warmup_func.append(func)
+                fwd_idx[m_chunk] += 1
+    assert len(warmup_func) == len(
+        sample_args
+    ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}."
+    assert len(warmup_func_idx) == len(
+        set(warmup_func_idx)
+    ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
+
+    # Filter the TE modules that cudagraph can access.
+    visited_te_modules = set()
+
+    def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
+        if isinstance(module, TransformerEngineBaseModule):
+            visited_te_modules.add(module)
+
+    # Run warmup and do the above filtering.
+    with torch.cuda.stream(torch.cuda.Stream()):
+        for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
             for _ in range(num_warmup_iters):
+                hooks = []
+                for module in func.modules():
+                    hook = module.register_forward_hook(hook_fn)
+                    hooks.append(hook)
                 outputs, _ = _tree_flatten(func(*args, **kwargs))
+                for hook in hooks:
+                    hook.remove()
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -216,6 +262,11 @@ def _make_graphed_callables(
                     allow_unused=allow_unused_input,
                 )
                 del outputs, grad_inputs
+            # The following code is added specifically for MCore's special requirements,
+            # aimed at preventing warmup from altering the control flow.
+            for module in func.modules():
+                if hasattr(module, "is_first_microbatch"):
+                    module.is_first_microbatch = True
     torch.cuda.synchronize()
 
     # All captures here share a mempool. To avoid replays corrupting each other's memory,
@@ -462,6 +513,19 @@ def new_fwd(*user_args, **user_kwargs):
                                 isinstance(m, TransformerEngineBaseModule)
                                 and FP8GlobalStateManager.is_fp8_enabled()
                             ):
+                                if m not in visited_te_modules:
+                                    # Only Set the FP8 meta for the modules included by forward
+                                    continue
+                                fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                                from transformer_engine.pytorch.attention import DotProductAttention
+
+                                if (
+                                    isinstance(m, DotProductAttention)
+                                    and not fp8_recipe.fp8_mha
+                                    and not fp8_recipe.fp8_dpa
+                                ):
+                                    # Don't need to update FP8 meta for non-FP8 DPA
+                                    continue
                                 m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                                 m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                                 FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
@@ -538,6 +602,7 @@ def make_graphed_callables(
     fp8_enabled: bool = False,
     fp8_calibrating: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
+    fp8_group: Optional[dist_group_type] = None,
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
@@ -579,6 +644,9 @@ def make_graphed_callables(
                      using a higher precision.
     fp8_recipe: recipe.DelayedScaling, default = `None`
                 recipe used for FP8 training.
+    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
+               distributed group over which amaxes for the fp8 tensors
+               are reduced at the end of each training step.
     fp8_weight_caching: bool, default = `False`
                         Whether or not to cache FP8 weights across microbatches. if set to `True`,
                         the `is_first_microbatch` boolean argument must be passed into the forward
@@ -607,7 +675,11 @@ def wrap_autocast(block):
 
         def forward_func(*args, **kwargs):
             with fp8_autocast(
-                enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True
+                enabled=fp8_enabled,
+                calibrating=fp8_calibrating,
+                fp8_recipe=fp8_recipe,
+                fp8_group=fp8_group,
+                _graph=True,
             ):
                 outputs = old_forward(*args, **kwargs)
             return outputs
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index fbf1b97704..92b37fcb07 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -1152,7 +1152,10 @@ def forward(
                               produced)
         """
 
-        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        if FP8GlobalStateManager.fp8_graph_capturing():
+            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        else:
+            skip_fp8_weight_update = None
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 64e8c9ce36..1a651474bf 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1484,7 +1484,10 @@ def forward(
                               produced)
         """
 
-        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        if FP8GlobalStateManager.fp8_graph_capturing():
+            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        else:
+            skip_fp8_weight_update = None
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 1fed467210..9492725f56 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -938,8 +938,10 @@ def forward(
                               first microbatch (since it is the first gradient being
                               produced)
         """
-
-        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        if FP8GlobalStateManager.fp8_graph_capturing():
+            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        else:
+            skip_fp8_weight_update = None
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
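Usage note: the sketch below illustrates how the new `fp8_group` argument of
`make_graphed_callables` might be used. It is not part of the patch; the layer size, batch
shape, recipe settings, and the `amax_group` variable are illustrative assumptions.

# Hedged sketch: capture a CUDA graph for a single TE layer under FP8, scoping amax
# reduction to a torch.distributed process group via fp8_group (added by this patch).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

torch.cuda.set_device(0)
layer = te.Linear(1024, 1024)  # TE modules default to the current CUDA device

# Static-shaped sample input reused for warmup and graph capture.
sample_input = torch.randn(16, 1024, device="cuda", requires_grad=True)

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# Reduce amaxes over the default world group when torch.distributed is initialized;
# otherwise fall back to None, which is the argument's default.
amax_group = torch.distributed.group.WORLD if torch.distributed.is_initialized() else None

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    num_warmup_iters=3,
    fp8_enabled=True,
    fp8_recipe=fp8_recipe,
    fp8_group=amax_group,
)

out = graphed_layer(sample_input)
out.sum().backward()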