TP-RS overlap with send/recv ring-exchange (NVIDIA#724)
* TP-RS overlap with send/recv

Atomic GEMM based TP-RS overlap with send/recv

Signed-off-by: Sangkug Lym <[email protected]>

Specify userbuffer overlap method of each overlap instance

Signed-off-by: Sangkug Lym <[email protected]>

P2P TP-RS overlap with fp8 GEMM outputs

Signed-off-by: Sangkug Lym <[email protected]>

Fix TP-RS overlap with send/recv

Signed-off-by: Sangkug Lym <[email protected]>

* cleanup

Signed-off-by: Sangkug Lym <[email protected]>

* cleanup

Signed-off-by: Sangkug Lym <[email protected]>

* linting

Signed-off-by: Sangkug Lym <[email protected]>

* fix typo

Signed-off-by: Sangkug Lym <[email protected]>

---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
erhoo82 and timmoon10 authored Mar 21, 2024
1 parent 59bfc17 commit b855656
Showing 11 changed files with 497 additions and 268 deletions.
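For readers skimming the diff, the core idea of this commit is a send/recv ring-exchange reduce-scatter that is overlapped with the chunked GEMM producing the data being reduced (tensor-parallel reduce-scatter, "TP-RS"). The sketch below is a minimal, framework-agnostic illustration of that pattern using plain torch.distributed point-to-point calls. It is not the CUDA userbuffers implementation added by this commit; the function name ring_reduce_scatter_gemm, the a_local/b_local operands, and the even chunking are illustrative assumptions.

import torch
import torch.distributed as dist

def ring_reduce_scatter_gemm(a_local, b_local):
    """Row-parallel GEMM whose output is reduce-scattered with a send/recv ring,
    overlapping each hop with the GEMM chunk computed at that step.

    Assumes dist.init_process_group() was called and the default process group
    is the tensor-parallel group.  a_local is [m, k/tp], b_local is [k/tp, n],
    so the full output is the sum over ranks of a_local @ b_local; rank r keeps
    rows [r*m/tp : (r+1)*m/tp] of that reduced output.  m must divide evenly by tp.
    """
    tp = dist.get_world_size()
    rank = dist.get_rank()
    rows = a_local.shape[0] // tp          # output rows kept by each rank
    send_to, recv_from = (rank + 1) % tp, (rank - 1) % tp

    recv_buf = torch.empty(rows, b_local.shape[1],
                           dtype=a_local.dtype, device=a_local.device)
    send_req, out = None, None
    for step in range(tp):
        # Chunk whose partial sum passes through this rank at this step; it was
        # originated by rank (chunk + 1) % tp and is headed for rank `chunk`.
        chunk = (rank - 1 - step) % tp
        recv_req = dist.irecv(recv_buf, src=recv_from) if step > 0 else None
        # Overlap: run the GEMM for this chunk while the recv is in flight.
        partial = a_local[chunk * rows:(chunk + 1) * rows] @ b_local
        if recv_req is not None:
            recv_req.wait()
            partial = partial + recv_buf   # accumulate the incoming partial sum
        if send_req is not None:
            send_req.wait()                # ensure the previous hop's send completed
        if step == tp - 1:                 # chunk == rank: fully reduced, keep it
            out = partial
        else:
            send_req = dist.isend(partial, dst=send_to)
    return out

Each rank posts the receive for the incoming partial sum, runs the GEMM for the chunk it must contribute at that step while the transfer is in flight, accumulates, and forwards; after tp - 1 hops every rank holds its fully reduced output chunk.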
18 changes: 6 additions & 12 deletions transformer_engine/pytorch/attention.py
@@ -3175,10 +3175,8 @@ def __init__(
         qkv_weight_interleaved: bool = True,
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
-        ub_split_rs: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_rs: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_ag: bool = False,
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
@@ -3265,9 +3263,8 @@ def __init__(
             zero_centered_gamma=zero_centered_gamma,
             ub_bulk_wgrad=ub_bulk_wgrad,
             ub_bulk_dgrad=ub_bulk_dgrad,
-            ub_split_ag=ub_split_ag,
+            ub_overlap_ag=ub_overlap_ag,
             normalization=normalization,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
             ub_name="qkv",
             **common_gemm_kwargs,
         )
@@ -3297,9 +3294,8 @@ def __init__(
             zero_centered_gamma=zero_centered_gamma,
             ub_bulk_wgrad=ub_bulk_wgrad,
             ub_bulk_dgrad=ub_bulk_dgrad,
-            ub_split_ag=ub_split_ag,
+            ub_overlap_ag=ub_overlap_ag,
             normalization=normalization,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
             ub_name="qkv",
             **common_gemm_kwargs,
         )
@@ -3347,10 +3343,8 @@ def __init__(
             bias=bias,
             return_bias=return_bias,
             parallel_mode="row" if set_parallel_mode else None,
-            ub_split_rs=ub_split_rs,
-            ub_split_ag=ub_split_ag,
-            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
+            ub_overlap_rs=ub_overlap_rs,
+            ub_overlap_ag=ub_overlap_ag,
             ub_name="proj",
             **common_gemm_kwargs,
         )
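From the caller's point of view, the change above collapses the four method-specific flags into a single pair of ub_overlap_rs / ub_overlap_ag switches; per the commit message, the overlap method itself is now specified per overlap instance in the userbuffers configuration instead of in the flag name. A hedged before/after sketch, assuming the constructor shown in this hunk is transformer_engine.pytorch.MultiheadAttention and using hidden_size and num_attention_heads as placeholder positional arguments:

# Before this commit: the overlap method was baked into the flag name.
attn = MultiheadAttention(
    hidden_size,
    num_attention_heads,
    ub_split_rs=True,          # pipelined reduce-scatter overlap
    ub_atomic_gemm_ag=True,    # atomic-GEMM all-gather overlap
)

# After this commit: one flag per overlap direction; the method (pipelined,
# atomic GEMM, ring-exchange send/recv) is chosen per overlap instance in the
# userbuffers configuration.
attn = MultiheadAttention(
    hidden_size,
    num_attention_heads,
    ub_overlap_rs=True,
    ub_overlap_ag=True,
)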
30 changes: 24 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -101,14 +101,14 @@ def fp8_gemm(
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
             args = tuple(args + (0, extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-            fn = ub.split_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+            fn = ub.split_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
             args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
-            fn = ub.atomic_gemm_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
+            fn = ub.atomic_gemm_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
@@ -119,12 +119,24 @@ def fp8_gemm(
                 extra_output_tensor is not None
             ), 'SPLIT_PIPELINED_RS requires extra output tensor'
             args = tuple(args + (True, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+            fn = ub.split_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
         elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
             fn = ub.atomic_gemm_overlap_rs
             assert (
                 extra_output_tensor is not None
             ), 'ATOMIC_GEMM_RS requires extra output tensor'
             args = tuple(args + (True, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
+            fn = ub.atomic_gemm_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
     _ = fn(*args)

     return out, gelu_input
@@ -217,8 +229,8 @@ def gemm(
         elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
             fn = ub.bulk_overlap
             args = tuple(args + (0, empty_tensor))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-            fn = ub.split_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+            fn = ub.split_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
@@ -229,6 +241,12 @@ def gemm(
                 extra_output_tensor is not None
             ), 'SPLIT_PIPELINED_RS requires extra output tensor'
             args = tuple(args + (False, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+            fn = ub.split_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
     _ = fn(*args)

     return out, grad_bias, gelu_input
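To make the new dispatch in cpp_extensions/gemm.py easier to follow, here is a distilled sketch of how the P2P (ring-exchange) algorithms map onto userbuffers communicator methods. The UbufOverlapAlgo members and the methods on ub are taken from the diff above; the helper name, its parameters, and its simplified argument handling are illustrative only (the real code extends an args tuple in place and also covers the non-P2P branches).

# A hedged distillation of the fp8_gemm P2P branches above.
# tex: the Transformer Engine C++ extension module imported at the top of gemm.py.
# ub:  the userbuffers communicator object passed into fp8_gemm/gemm.
def _pick_p2p_overlap_fn(tex, ub, ub_algo, extra_output_tensor):
    algo = tex.UbufOverlapAlgo
    if ub_algo == algo.SPLIT_PIPELINED_AG_P2P:    # all-gather overlap, pipelined ring send/recv
        return ub.split_overlap_ag_p2p
    if ub_algo == algo.ATOMIC_GEMM_AG_P2P:        # all-gather overlap, atomic GEMM
        return ub.atomic_gemm_overlap_ag_p2p
    if ub_algo == algo.SPLIT_PIPELINED_RS_P2P:    # reduce-scatter overlap, pipelined ring send/recv
        assert extra_output_tensor is not None, "needs a reduce-scatter output tensor"
        return ub.split_overlap_rs_p2p
    if ub_algo == algo.ATOMIC_GEMM_RS_P2P:        # reduce-scatter overlap, atomic GEMM + ring
        assert extra_output_tensor is not None, "needs a reduce-scatter output tensor"
        return ub.atomic_gemm_overlap_rs_p2p
    raise ValueError(f"unhandled P2P overlap algorithm: {ub_algo}")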
