Enable TP-AG overlap with return_layernorm_output (NVIDIA#727)
* Enable TP-AG overlap with return_layernorm_output

Signed-off-by: Jaemin Choi <[email protected]>

* Use ub_overlap_ag

Signed-off-by: Jaemin Choi <[email protected]>

---------

Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
minitu and Jaemin Choi authored Mar 22, 2024
1 parent 8e672ff commit c1a68f6
Showing 1 changed file with 25 additions and 9 deletions.
transformer_engine/pytorch/module/layernorm_linear.py
@@ -107,13 +107,18 @@ def forward(
 
         if ub_overlap_ag:
             tp_world_size = get_distributed_world_size(tp_group)
-            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
+            if tp_world_size == 1 or (not is_grad_enabled):
                 ub_overlap_ag = False
         if ub_overlap_ag:
             dim_size = list(inputmat.size())
             dim_size[0] = dim_size[0] * tp_world_size
             ub_obj_lnout = get_ub(ub_name+"_fprop")
-            ln_out = ub_obj_lnout.get_ubuf_output(0)
+            if return_layernorm_output:
+                # First prepare LN output in higher precision,
+                # which will be later copied to a FP8 UB
+                ln_out = torch.empty_like(inputmat)
+            else:
+                ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
             ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
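
This first hunk is the heart of the change: with return_layernorm_output, the LayerNorm output can no longer be written straight into the FP8 userbuffer, because the caller needs it in its original precision. A minimal plain-PyTorch sketch of that staging decision follows; ubuf_fp8 and pick_ln_out_buffer are hypothetical stand-ins for the buffer returned by ub_obj_lnout.get_ubuf_output(0) and the allocation logic above, not Transformer Engine API.

import torch

def pick_ln_out_buffer(inputmat: torch.Tensor,
                       ubuf_fp8: torch.Tensor,
                       return_layernorm_output: bool) -> torch.Tensor:
    if return_layernorm_output:
        # Stage the LN output in the input's higher precision; it is cast
        # into the FP8 userbuffer only later (see the third hunk below).
        return torch.empty_like(inputmat)
    # Otherwise LayerNorm can write directly into the communication buffer.
    return ubuf_fp8

x = torch.randn(4, 8)                         # fake LN input
ubuf = torch.empty(4, 8, dtype=torch.uint8)   # emulates the FP8 UB payload
ln_out = pick_ln_out_buffer(x, ubuf, return_layernorm_output=True)
assert ln_out.dtype == x.dtype and ln_out.data_ptr() != ubuf.data_ptr()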
@@ -136,7 +141,8 @@
         ln_out_gathered = False
         if ub_overlap_ag:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
-            ln_out = torch.empty_like(ln_out)
+            if not return_layernorm_output:
+                ln_out = torch.empty_like(ln_out)
             if ub_obj_lnout.is_atomic_gemm():
                 ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
             else:
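
The second hunk follows from the first: when the LN output is being returned, the staged high-precision tensor must survive past the overlapped all-gather rather than being swapped for a fresh scratch buffer. A toy illustration of the guard (maybe_rotate is a hypothetical name, not from the source):

import torch

def maybe_rotate(ln_out: torch.Tensor, return_layernorm_output: bool) -> torch.Tensor:
    # Without return_layernorm_output, ln_out only feeds the overlapped
    # all-gather and may be replaced by a fresh buffer; with it, the staged
    # tensor is handed back to the caller and must be kept.
    if not return_layernorm_output:
        return torch.empty_like(ln_out)
    return ln_out

staged = torch.randn(2, 4)
assert maybe_rotate(staged, True) is staged        # kept for the caller
assert maybe_rotate(staged, False) is not staged   # safe to rotate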
@@ -153,12 +159,22 @@
         if return_layernorm_output:
             ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
         if fp8:
-            ln_out = tex.cast_to_fp8(
-                ln_out,
-                fp8_meta["scaling_fwd"],
-                tex.FP8FwdTensors.GEMM1_INPUT,
-                fp8_dtype_forward,
-            )
+            if ub_overlap_ag:
+                ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0)
+                tex.cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                    out=ln_out_fp8)
+                ln_out = ln_out_fp8
+            else:
+                ln_out = tex.cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
 
         if fp8:
             bias_dtype = (
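
The third hunk closes the loop: under ub_overlap_ag, the staged high-precision tensor is quantized into the FP8 userbuffer through the out= argument of tex.cast_to_fp8, so the GEMM consumes the communication buffer while ln_out_return keeps full precision. A self-contained sketch of the pattern in plain PyTorch; fake_cast_to_fp8 is a hypothetical stand-in for the real kernel, and the uint8 buffer merely emulates an FP8 payload.

import torch

def fake_cast_to_fp8(src: torch.Tensor, scale: float, out: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for tex.cast_to_fp8(..., out=...): quantize the
    # source into the preallocated uint8 buffer and return that buffer.
    out.copy_(torch.clamp(src * scale + 128.0, 0.0, 255.0).round())
    return out

ln_out = torch.randn(4, 8)                        # staged high-precision LN output
ubuf_fp8 = torch.empty(4, 8, dtype=torch.uint8)   # emulates get_ubuf_output(0)

ln_out_return = ln_out                            # what the caller receives
ln_out = fake_cast_to_fp8(ln_out, scale=16.0, out=ubuf_fp8)

# The GEMM input now aliases the userbuffer; the returned LN output is intact.
assert ln_out.data_ptr() == ubuf_fp8.data_ptr()
assert ln_out_return.dtype == torch.float32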
