Improving communication overlap for the case of multi kernel queue usage (#1308)

* draft implementation

Signed-off-by: Youngeun Kwon <[email protected]>

* compile error fix

Signed-off-by: Youngeun Kwon <[email protected]>

* fix compile error

Signed-off-by: Youngeun Kwon <[email protected]>

* remove print

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Edit comments

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* edit the bulk-overlap test case

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add version guard

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add runtime version guard

Signed-off-by: Youngeun Kwon <[email protected]>

* fix the version guard

Signed-off-by: Youngeun Kwon <[email protected]>

---------

Signed-off-by: Youngeun Kwon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and phu0ngng committed Dec 3, 2024
1 parent 1d1c3a6 commit 8f599ce
Showing 5 changed files with 157 additions and 43 deletions.
34 changes: 27 additions & 7 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):


@pytest.mark.parametrize(
"comm_type,fp8",
"comm_type, fp8, connections",
[
("AG", False),
("RS", False),
("RS", True),
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
],
ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
ids=[
"ALL-GATHER - BF16 - 1 connections",
"REDUCE-SCATTER - BF16 - 1 connections",
"REDUCE-SCATTER - FP8 - 1 connections",
"ALL-GATHER - BF16 - 8 connections",
"REDUCE-SCATTER - BF16 - 8 connections",
"REDUCE-SCATTER - FP8 - 8 connections",
],
)
def test_bulk_overlaps(comm_type, fp8):
def test_bulk_overlaps(comm_type, fp8, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
if connections == 8:
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip(
"CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)


@pytest.mark.parametrize(
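As background for the new "8 connections" cases above: they run only on compute capability 9.0 (Hopper) devices and toggle CUDA_DEVICE_MAX_CONNECTIONS around the overlap run. A minimal C++ sketch of the equivalent condition, mirroring the runtime guard added to the CommOverlapCore constructor in the next file (the helper name is made up for illustration):

// Illustration only (not test or library code): the condition under which the new
// launch-ordering path is taken. Mirrors the guard added in comm_gemm_overlap.cpp below.
#include <cstdlib>
#include <cuda_runtime.h>

bool ordered_comm_launch_enabled() {
  // TransformerEngine treats an unset CUDA_DEVICE_MAX_CONNECTIONS as 8.
  const char *env = std::getenv("CUDA_DEVICE_MAX_CONNECTIONS");
  int max_connections = env ? std::atoi(env) : 8;

  int runtime_version = 0;
  cudaRuntimeGetVersion(&runtime_version);  // e.g. 12030 for CUDA 12.3

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  // Launch-completion events need CUDA >= 12.3 and are only used on Hopper (SM 9.x).
  return runtime_version >= 12030 && prop.major == 9 && max_connections > 1;
}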
30 changes: 27 additions & 3 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -90,13 +90,31 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);

/*
Defining the launch order between the communication and GEMM kernels
using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1.
The event is used to schedule the communication kernel before the GEMM.
This is needed only for Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
} else {
_comm_launch_event = 0;
}
}

CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);

if (_atomic_gemm) cudaFree(_counter.dptr());

@@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
@@ -178,13 +197,18 @@
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm);
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
}
}

assert(pre_gelu_out.numel() == 0);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main);
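The constructor guard and the cudaStreamWaitEvent call above rely on CUDA's launch-completion-event attribute. The following standalone sketch (illustration only, not TransformerEngine code; kernel names and launch dimensions are placeholders) shows the primitive in isolation: a kernel on one stream records a launch-completion event, and a second stream waits on that event, so the second kernel cannot be scheduled ahead of the first even when CUDA_DEVICE_MAX_CONNECTIONS > 1 exposes multiple hardware queues.

// Minimal illustration (not TransformerEngine code). Requires CUDA >= 12.3 and matters
// mostly on Hopper, where persistent-CTA kernels compete for SMs.
#include <cuda_runtime.h>

__global__ void comm_kernel() { /* stands in for a userbuffers collective */ }
__global__ void gemm_kernel() { /* stands in for the overlapped GEMM */ }

int main() {
  cudaStream_t comm_stream, main_stream;
  cudaStreamCreate(&comm_stream);
  cudaStreamCreate(&main_stream);

  cudaEvent_t comm_launch_event;
  cudaEventCreateWithFlags(&comm_launch_event, cudaEventDisableTiming);

  // Attach a launch-completion event to the communication kernel. The event fires
  // once the kernel's launch has completed (its blocks have started), not when the
  // kernel finishes, so compute and communication can still overlap.
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeLaunchCompletionEvent;
  attrs[0].val.launchCompletionEvent.event = comm_launch_event;

  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = dim3(16);
  cfg.blockDim = dim3(128);
  cfg.stream = comm_stream;
  cfg.attrs = attrs;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, comm_kernel);

  // The GEMM stream waits on the comm kernel's *launch*, fixing the launch order
  // even when multiple hardware work queues are available.
  cudaStreamWaitEvent(main_stream, comm_launch_event, 0);
  gemm_kernel<<<16, 128, 0, main_stream>>>();

  cudaDeviceSynchronize();
  return 0;
}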
116 changes: 90 additions & 26 deletions transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
cfg.attrs = attribute_ub; \
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;

#if (CUDART_VERSION >= 12030)
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \
attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3
#else
#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
#endif

#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \
ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;

#define callranks_ag(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
}

void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
} else {
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
}
}

@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
}

void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
} else {
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
}
}
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream) {
cudaStream_t stream, cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
} else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
}
}
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
comm_launch_event);
}

template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream) {
communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_firstgpu =
@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;

SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
}

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);

template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) {
const int elements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event) {
reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
comm, stream);
comm, stream, comm_launch_event);
}

template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream);
cudaStream_t stream,
cudaEvent_t comm_launch_event);

template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
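For readability, this is what SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT resolves to on CUDA 12.3 or newer (hand-expanded here for illustration; `sms`, `threads`, `stream`, `comm_launch_event`, and `comm` are the values supplied by the launchers above):

// Hand expansion of SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, event)
// with CUDART_VERSION >= 12030; on older toolkits the completion-event attribute is
// compiled out and numAttrs stays at 2, matching the original SETUP_LAUNCH_CONFIG.
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};
cudaLaunchAttribute attribute_ub[3] = {};
attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent;
attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
attribute_ub[1].id = cudaLaunchAttributeClusterDimension;
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;
attribute_ub[1].val.clusterDim.y = 1;
attribute_ub[1].val.clusterDim.z = 1;
attribute_ub[0].id = cudaLaunchAttributeCooperative;
cfg.attrs = attribute_ub;
cfg.numAttrs = 3;  // cooperative + cluster dimension + launch-completion event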
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
@@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *

// for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
/*
each Rank input is
allgather2_userbuff_inplace: offset+myrank*elements
@@ -228,21 +229,26 @@ for(int slice=0;slice<ncslices;slice++)
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
const int colelements, const int strideelements,
communicator *comm, cudaStream_t stream = 0);
communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream = 0);
const int elements, communicator *comm, cudaStream_t stream = 0,
cudaEvent_t comm_launch_event = 0);
template <typename fp8type>
void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
const int offset, const int rowelements,
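Since the new comm_launch_event parameter is a trailing argument that defaults to 0 in these declarations, existing call sites compile unchanged and keep the original launch path; only callers that opt into launch ordering pass an event. A hypothetical call site (variable names are placeholders, not code from this commit):

// Illustration only: the added defaulted parameter is backward compatible.
void example_caller(int handler, int offset, int elements, communicator *comm,
                    cudaStream_t stream, cudaEvent_t comm_launch_event) {
  // Pre-existing callers are unchanged; comm_launch_event defaults to 0 (no ordering).
  reducescatter2_userbuff_inplace(handler, offset, elements, comm, stream);
  // Callers that want the ordered launch pass the launch-completion event explicitly.
  reducescatter2_userbuff_inplace(handler, offset, elements, comm, stream, comm_launch_event);
}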
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
@@ -62,7 +62,7 @@ class CommOverlapCore {
bool _ubuf_scale_inv_initialized{false};

std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;

public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
