Support CUDA Graph for MoE models (#1233)

* Align RNG tracker with megatron

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

* Fix module_params order and warmup bug in cudagraph

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

* Add fp8_group argument and fix fp8 accuracy issue for cudagraph

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

* Add TE modules and weights filters to support MoE models

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

* Revert self.fp8

Signed-off-by: Robin Zhang <[email protected]>

* Use hooks to filter module params

Signed-off-by: Robin Zhang <[email protected]>

* Filter all TE modules in hooks

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

* Format code

Signed-off-by: Robin Zhang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update graph.py

Signed-off-by: Xin Yao <[email protected]>

* Revert CudaRNGStatesTracker

Signed-off-by: Robin Zhang <[email protected]>

* Format Update

Signed-off-by: Yifei Song <[email protected]>

* Revert "Use hooks to filter module params"

This reverts commit 73a22e2.

Signed-off-by: Yifei Song <[email protected]>

* Remove filtering module params

Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Yifei Song <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
6 people authored Nov 25, 2024
1 parent 8952bc4 commit ae393e8
Showing 5 changed files with 97 additions and 17 deletions.
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/fp8.py
@@ -442,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2])

@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"])


@contextmanager
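
A note on the fp8.py change above: CUDA graph replay reuses the tensor storage captured at record time, so rebinding the fp8_meta attributes to new tensors would leave the captured buffers stale, while copy_() writes the updated values into the buffers the graph already reads. The following standalone sketch (illustrative names, requires a CUDA device, not taken from the diff) shows the difference.

import torch

# Illustrative sketch: CUDA graphs capture tensor storage, so state that must
# change between replays has to be updated in place rather than rebound.
static_scale = torch.ones(4, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    _ = static_scale * 2
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = static_scale * 2  # the capture records reads of static_scale's storage

# Rebinding (static_scale = new_tensor) would not affect the captured buffer.
# Copying in place does, so the next replay sees the updated values.
static_scale.copy_(torch.full((4,), 3.0, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(out)  # tensor([6., 6., 6., 6.], device='cuda:0')
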
86 changes: 79 additions & 7 deletions transformer_engine/pytorch/graph.py
@@ -12,6 +12,7 @@
from torch._C import _graph_pool_handle

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.constants import dist_group_type
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
@@ -173,11 +174,14 @@ def _make_graphed_callables(
]
else:
per_callable_module_params = []
for c in callables:
for i in range(num_microbatches):
per_callable_module_params.append(
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
)
for m_chunk in range(num_model_chunks):
for _ in range(num_microbatches):
for l_no in range(num_layers):
per_callable_module_params.append(
tuple(callables[m_chunk * num_layers + l_no].parameters())
if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
else ()
)
assert len(per_callable_module_params) == len(flatten_sample_args)
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
@@ -201,13 +205,55 @@ def _make_graphed_callables(
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):

# Get warmup func and func_idx.
warmup_func_idx = []
warmup_func = []
if _order is None:
for func_idx, func in enumerate(callables):
warmup_func_idx.append(func_idx)
warmup_func.append(func)
else:
fwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
m_chunk = c_id - 1
for l_no in range(num_layers):
func = callables[m_chunk * num_layers + l_no]
func_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
warmup_func_idx.append(func_idx)
warmup_func.append(func)
fwd_idx[m_chunk] += 1
assert len(warmup_func) == len(
sample_args
), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}."
assert len(warmup_func_idx) == len(
set(warmup_func_idx)
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

# Filter the TE modules that cudagraph can access.
visited_te_modules = set()

def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)

# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]
for _ in range(num_warmup_iters):
hooks = []
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -216,6 +262,11 @@ def _make_graphed_callables(
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
@@ -462,6 +513,19 @@ def new_fwd(*user_args, **user_kwargs):
isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules:
# Only set the FP8 meta for the modules included in the forward pass
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention

if (
isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa
):
# Don't need to update FP8 meta for non-FP8 DPA
continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
@@ -538,6 +602,7 @@ def make_graphed_callables(
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
@@ -579,6 +644,9 @@
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. if set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
@@ -607,7 +675,11 @@ def wrap_autocast(block):

def forward_func(*args, **kwargs):
with fp8_autocast(
enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True
enabled=fp8_enabled,
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=True,
):
outputs = old_forward(*args, **kwargs)
return outputs
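
Two standalone sketches to make the graph.py changes above easier to follow. First, the visited_te_modules bookkeeping relies on ordinary PyTorch forward hooks: registering a hook on every submodule during warmup records exactly which modules a given callable actually runs, which is what later lets the capture path skip FP8 bookkeeping for modules that never executed. A minimal illustration with a hypothetical toy model (not from the diff):

import torch

visited = set()

def record_visit(module, inputs, outputs):  # standard forward-hook signature
    visited.add(module)

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
handles = [m.register_forward_hook(record_visit) for m in toy.modules()]
_ = toy(torch.randn(2, 4))
for handle in handles:
    handle.remove()
# visited now contains only the modules that participated in this forward pass.

Second, a hedged usage sketch of the new fp8_group argument to make_graphed_callables. The model, shapes, and recipe values here are placeholders; in a real job fp8_group would typically be the torch.distributed process group over which FP8 amaxes should be reduced (for example, the data-parallel group).

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

model = te.Linear(1024, 1024).cuda()          # placeholder module
sample_args = (torch.randn(8, 1024, device="cuda"),)
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

graphed_model = te.make_graphed_callables(
    model,
    sample_args,
    fp8_enabled=True,
    fp8_recipe=recipe,
    fp8_group=None,  # pass a torch.distributed ProcessGroup in multi-GPU runs
)
out = graphed_model(torch.randn(8, 1024, device="cuda"))
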
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -1152,7 +1152,10 @@ def forward(
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1484,7 +1484,10 @@ def forward(
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -938,8 +938,10 @@ def forward(
first microbatch (since it is the first gradient being
produced)
"""

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

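
The three module-level changes above all guard the same lookup: the globally stashed skip-weight-update tensor is only consulted while a CUDA graph is being captured, so eager runs keep the usual is_first_microbatch behaviour. In graphed execution that flag is instead supplied per call, as in this hedged sketch (placeholder model and shapes; per the docstring above, fp8_weight_caching=True requires passing is_first_microbatch into the graphed forward, though the exact training-loop wiring will differ in practice):

import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024).cuda()          # placeholder module
graphed = te.make_graphed_callables(
    model,
    (torch.randn(8, 1024, device="cuda"),),
    fp8_enabled=True,
    fp8_weight_caching=True,   # cache FP8 weights across microbatches
)

for microbatch in range(4):
    x = torch.randn(8, 1024, device="cuda")
    # Only the first microbatch re-casts FP8 weights; later ones reuse the cache.
    out = graphed(x, is_first_microbatch=(microbatch == 0))
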
