From 7cb19540334f6bf8ec00f854d393247a6e3c6a91 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 20 Nov 2024 21:49:02 +0000 Subject: [PATCH] fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests Signed-off-by: Alp Dener --- .../distributed/run_layer_with_overlap.py | 58 ++-- .../distributed/test_comm_gemm_overlap.py | 14 +- transformer_engine/pytorch/module/linear.py | 264 +++++++++++------- 3 files changed, 203 insertions(+), 133 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 025061c9d9..9abbb93d29 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -51,18 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): kwargs["ub_overlap_ag"] = not reference if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = config.parallel_mode - kwargs["ub_overlap_rs_fprop"] = not reference and config.parallel_mode == "row" - kwargs["ub_overlap_ag_dgrad"] = not reference and config.parallel_mode == "row" - kwargs["ub_overlap_ag_fprop"] = not reference and config.parallel_mode == "column" - kwargs["ub_overlap_rs_dgrad"] = not reference and config.parallel_mode == "column" - kwargs["ub_name"] = "proj" if config.parallel_mode == "row" else "qkv" + if config.linear_parallel_mode == "row": + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["ub_overlap_rs"] = not reference + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + args.append(3 * hidden_size) + kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode + kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference if config.layer_type is te.LayerNormLinear: args.append(3 * hidden_size) kwargs["parallel_mode"] = "column" @@ -135,6 +140,12 @@ def _parse_args(argv=None, namespace=None): choices=["row", "column"], help="Parallel mode for te.Linear." ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps." + ) parser.add_argument( "--debug", action="store_true", @@ -240,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "proj_dgrad" : {"method" : "ring_exchange"}, + "qkv_dgrad" : {"method" : "ring_exchange"}, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], WORLD_SIZE, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs, ) # Initialize the Transformer Engine layer with overlap @@ -324,27 +342,29 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 + num_grads_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." ) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - if not bool(numerics_failed.item()): + numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") + if not bool(num_grads_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): - break + numerics_failed[i] = int(grad_failed) + return_code = torch.max(numerics_failed) + dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) + else: + return_code = num_grads_failed te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -354,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return numerics_failed[0].item() + return return_code.item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 5d7ec24125..4c757146a9 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -88,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -99,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] + if layer_type == te.Linear.__name__: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") if fp8: if not fp8_available: @@ -229,12 +231,12 @@ def test_bulk_overlaps(comm_type, fp8): @pytest.mark.parametrize( "layer_type,linear_parallel_mode", ( - list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) - + [(te.Linear, "row"), (te.Linear, "column")] + [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] + + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) ), ids=( - [(" " + layer.__name__ + " ") for layer in TE_LAYERS] - + [" te.Linear (row-parallel) ", " te.Linear (column-parallel) "] + [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] + + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] ), ) @pytest.mark.parametrize( @@ -254,4 +256,4 @@ def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index da1313a880..92ceb3d82e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,7 +3,9 @@ # See LICENSE for license information. """Linear API""" -import warnings +import sys +from functools import reduce +from operator import mul as multiply_op from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -81,10 +83,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_ag_fprop: bool, ub_overlap_rs_fprop: bool, ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], @@ -102,8 +106,8 @@ def forward( assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_ag = False if tp_world_size == 1 else ub_overlap_ag - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs + ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop + ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -170,14 +174,14 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, torch.uint8, ) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( None, None, None, @@ -220,7 +224,7 @@ def forward( elif ub_overlap_ag_fprop: ub_obj = get_ub(ub_name + "_fprop") assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." - ub_obj.copy_input_to_ubuf(inputmat_data, 0) + ub_obj.copy_input_to_ubuf(inputmat_data, True) ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P @@ -256,7 +260,7 @@ def forward( ub_algo=ub_algo, ub=ub_obj, extra_output_tensor=rs_out, - out_index=proj_out_index, + out_index=out_index, fp8_meta_tensor=meta_tensor, D_dtype=out_tedtype, ) @@ -304,7 +308,7 @@ def forward( elif ub_overlap_ag_fprop: ub_obj = get_ub(ub_name + "_fprop") - ub_obj.copy_input_to_ubuf(inputmat_total, 0) + ub_obj.copy_input_to_ubuf(inputmat_total, True) dim_size = list(inputmat_total.size()) dim_size[0] *= tp_size # all-gathered sequence length dim_size[1] = out_features @@ -383,7 +387,9 @@ def forward( ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.ub_overlap_ag = ub_overlap_ag_dgrad - ctx.ub_overlap_rs = ub_overlap_rs_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -396,12 +402,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if ub_overlap_rs: - out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if parallel_mode == "row": + if ub_overlap_rs_fprop: + out = rs_out + elif sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp_shape[1:-1], out_features) @@ -440,36 +447,75 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight.main_grad = main_grad tp_world_size = get_distributed_world_size(ctx.tp_group) - ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag_dgrad - ctx.ub_overlap_rs = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad + ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag + ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad + ctx.ub_obj_gradout = None - ub_algo = None + ub_obj_wgrad = None + ub_algo_wgrad = None + ub_algo_dgrad = None rs_out = None - if ctx.ub_overlap_ag_dgrad: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + dgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, + device=grad_output.device) + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - dim_size = list(inputmat_total.size()) - dim_size[0] /= tp_world_size # sequence-parallel - dim_size[1] = weight.shape[0] # out_features - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, - device=inputmat_total.device) + dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) if ctx.ub_obj_gradout.is_p2p_overlap(): if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, + device=grad_output.device) + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + inputmat_data = ( + inputmat._data + if isinstance(inputmat, Float8Tensor) + else inputmat + ) + ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) + inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) + if isinstance(inputmat, Float8Tensor): + inputmat._data = inputmat_ubuf + else: + inputmat = inputmat_ubuf + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = ub_obj_wgrad.get_ubuf_output(1) + + if dgrad is None: + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + dgrad_shape[0] = dgrad_shape[0] * tp_world_size + dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype, + device=grad_output.device) ( grad_output, @@ -480,16 +526,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx, grad_output, ctx.parallel_mode == "row" ) - # Column Parallel Linear - # Overlap input AG with dgrad + # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) inputmat_total = None inputmat_t_total = None - handle = None + inputmat_gather_handle = None if (weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel - and not ctx.ub_overlap_ag): - inputmat_total, handle = gather_along_first_dim( + and not ctx.ub_bulk_dgrad): + inputmat_total, inputmat_gather_handle = gather_along_first_dim( inputmat, ctx.tp_group, async_op=ctx.requires_dgrad ) else: @@ -509,13 +554,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: if ctx.fp8: - if ctx.is_input_fp8: + if (ctx.is_input_fp8 + or (ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf())): out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, ctx.fp8_meta["scaling_bwd"], fp8_dtype_backward, torch.uint8, ) + if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): + ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( None, @@ -523,7 +571,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, ctx.activation_dtype, ) - dgrad, _ = fp8_gemm( + _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, 0, @@ -535,8 +583,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], output_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, + ub_algo=ub_algo_dgrad, ub=ctx.ub_obj_gradout, + out=dgrad, out_index=out_index, fp8_meta_tensor=meta_tensor, D_dtype=output_te_dtype, @@ -546,7 +595,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.ub_overlap_rs_dgrad: dgrad = rs_out - elif output_dtype == torch.uint8: + if output_dtype == torch.uint8: dgrad = Float8Tensor( data=dgrad, fp8_meta=ctx.fp8_meta, @@ -556,32 +605,35 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - dgrad, _, _ = gemm( + _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", grad=True, - ub_algo=ub_algo, + ub_algo=ub_algo_dgrad, ub=ctx.ub_obj_gradout, + out=dgrad, extra_output_tensor=rs_out, ) if ctx.ub_overlap_rs_dgrad: dgrad = rs_out - # Overlap dgrad-RS/AR with wgrad - if (ctx.parallel_mode == "column" - and ctx.sequence_parallel - and not ctx.ub_overlap_rs_dgrad): - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( + if inputmat_gather_handle is not None: + inputmat_gather_handle.wait() + + # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) + dgrad_reduce_handle = None + if ctx.requires_dgrad and ctx.parallel_mode == "column": + if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): + dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel and not ctx.sequence_parallel: + dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) + wgrad = None if weight.requires_grad: @@ -618,6 +670,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: wgrad, _, _ = gemm( @@ -629,6 +683,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: # WGRAD @@ -642,18 +698,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_ubuf_output(0) + # Deallocate input tensor clear_tensor_data(inputmat_total) clear_tensor_data(inputmat_t_total) - # Column Parallel Linear - if (ctx.parallel_mode == "column" - and (ctx.tensor_parallel - or (ctx.sequence_parallel and not ctx.ub_overlap_rs_dgrad)) - and handle is not None): - handle.wait() + # Wait for dgrad reduce-scatter or all-reduce + if dgrad_reduce_handle is not None: + dgrad_reduce_handle.wait() if not ctx.use_bias: grad_bias = None @@ -707,10 +765,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # activation_dtype None, # parallel_mode None, # is_grad_enabled - None, # ub_overlap_ag_fprop None, # ub_overlap_rs_fprop None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fp8_output None, # fsdp_group @@ -806,10 +866,8 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, ub_overlap_rs: bool = False, - ub_overlap_ag_fprop: bool = False, - ub_overlap_rs_fprop: bool = False, - ub_overlap_ag_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -845,49 +903,37 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - assert ( - not ((ub_overlap_rs or ub_overlap_rs_fprop) and ub_overlap_ag_fprop) - ), "Cannot enable both all-gather and reduce-scatter overlaps in a te.Linear forward pass." - assert ( - not ((ub_overlap_ag or ub_overlap_ag_dgrad) and ub_overlap_rs_dgrad) - ), "Cannot enable both all-gather and reduce-scatter overlaps in a te.Linear backward pass." - - if ub_overlap_rs: - warnings.warn( - ( - "`ub_overlap_rs=True` will only enable reduce-scatter overlap in the forward " - + "pass. To enable reduce-scatter overlap in the backward pass, please use " - + "the `ub_overlap_rs_dgrad=True` option instead." - ), - DeprecationWarning - ) - ub_overlap_rs_fprop = True - if ub_overlap_ag: - warnings.warn( - ( - "`ub_overlap_ag=True` will only enable all-gather overlap in the backward " - + "pass. To enable all-gather overlap in the forward pass, please use " - + "the `ub_overlap_rs_dgrad=True` option instead." - ), - DeprecationWarning - ) - ub_overlap_ag_dgrad = True - - self.ub_overlap_ag_fprop = ub_overlap_ag_fprop - self.ub_overlap_rs_fprop = ub_overlap_rs_fprop - self.ub_overlap_ag_dgrad = ub_overlap_ag_dgrad - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - - if ub_overlap_ag_fprop or ub_overlap_rs_fprop or ub_overlap_ag_dgrad or ub_overlap_rs_dgrad: - assert self.tp_size > 1, "TP overlap requires `tp_size > 1`." + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs + self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad + self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad + if self.ub_overlap_rs_dgrad: + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs + self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): assert ub_name is not None, f"Userbuffer name [string] is not set." self.ub_name = ub_name - if self.ub_overlap_ag_fprop or self.ub_overlap_rs_dgrad: - assert sequence_parallel and parallel_mode == "column", ( - "All-gather overlap in the forward pass or reduce-scatter overlap in the backward " - + "pass require `parallel_mode=\'column\'` and `sequence_parallel=True`." - ) + assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!" + assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!" + assert not (self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)), ( + "Internal TE error!" + ) self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name @@ -1134,10 +1180,12 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), - self.ub_overlap_ag_fprop, self.ub_overlap_rs_fprop, self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.ub_name, fp8_output, self.fsdp_group,