From b855656b20bc1cb3df4327fc017acdbd240100c9 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Thu, 21 Mar 2024 12:02:53 -0700 Subject: [PATCH] TP-RS overlap with send/recv ring-exchange (#724) * TP-RS overlap with send/recv Atomic GEMM based TP-RS overlap with send/recv Signed-off-by: Sangkug Lym Specify userbuffer overlap method of each overlap instance Signed-off-by: Sangkug Lym P2P TP-RS overlap with fp8 GEMM outputs Signed-off-by: Sangkug Lym Fix TP-RS overlap with send/recv Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * linting Signed-off-by: Sangkug Lym * fix typo Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 18 +- .../pytorch/cpp_extensions/gemm.py | 30 +- .../pytorch/csrc/comm_gemm_overlap.h | 267 ++++++++++++++++-- .../pytorch/csrc/extensions/pybind.cpp | 26 +- .../pytorch/csrc/userbuffers/userbuffers.cu | 31 ++ .../pytorch/csrc/userbuffers/userbuffers.h | 4 + transformer_engine/pytorch/module/base.py | 43 ++- .../pytorch/module/layernorm_linear.py | 52 ++-- .../pytorch/module/layernorm_mlp.py | 137 ++++----- transformer_engine/pytorch/module/linear.py | 122 ++++---- transformer_engine/pytorch/transformer.py | 35 +-- 11 files changed, 497 insertions(+), 268 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 26ab9f3283..924e2bb97d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3175,10 +3175,8 @@ def __init__( qkv_weight_interleaved: bool = True, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -3265,9 +3263,8 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_ag=ub_split_ag, + ub_overlap_ag=ub_overlap_ag, normalization=normalization, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_name="qkv", **common_gemm_kwargs, ) @@ -3297,9 +3294,8 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_ag=ub_split_ag, + ub_overlap_ag=ub_overlap_ag, normalization=normalization, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_name="qkv", **common_gemm_kwargs, ) @@ -3347,10 +3343,8 @@ def __init__( bias=bias, return_bias=return_bias, parallel_mode="row" if set_parallel_mode else None, - ub_split_rs=ub_split_rs, - ub_split_ag=ub_split_ag, - ub_atomic_gemm_rs=ub_atomic_gemm_rs, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, + ub_overlap_rs=ub_overlap_rs, + ub_overlap_ag=ub_overlap_ag, ub_name="proj", **common_gemm_kwargs, ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 4ddab0e5a1..df571a0e6b 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -101,14 +101,14 @@ def fp8_gemm( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (0, extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: - fn = ub.split_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + 
fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG: - fn = ub.atomic_gemm_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + fn = ub.atomic_gemm_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) @@ -119,12 +119,24 @@ def fp8_gemm( extra_output_tensor is not None ), 'SPLIT_PIPELINED_RS requires extra output tensor' args = tuple(args + (True, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + fn = ub.split_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs assert ( extra_output_tensor is not None ), 'ATOMIC_GEMM_RS requires extra output tensor' args = tuple(args + (True, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + fn = ub.atomic_gemm_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) _ = fn(*args) return out, gelu_input @@ -217,8 +229,8 @@ def gemm( elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap args = tuple(args + (0, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: - fn = ub.split_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) @@ -229,6 +241,12 @@ def gemm( extra_output_tensor is not None ), 'SPLIT_PIPELINED_RS requires extra output tensor' args = tuple(args + (False, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + fn = ub.split_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) _ = fn(*args) return out, grad_bias, gelu_input diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 5f8ccab334..817a3ef366 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -41,10 +41,12 @@ enum class COMM_TYPE { RS = 0, AG = 1 }; enum class UBOverlapAlgo { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG = 2, + SPLIT_PIPELINED_AG_P2P = 2, SPLIT_PIPELINED_RS = 3, - ATOMIC_GEMM_RS = 4, - ATOMIC_GEMM_AG = 5 + SPLIT_PIPELINED_RS_P2P = 4, + ATOMIC_GEMM_RS = 5, + ATOMIC_GEMM_AG_P2P = 6, + ATOMIC_GEMM_RS_P2P = 7 }; struct UbufBase { @@ -70,9 +72,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int comm_sms; int cga_size; int use_ce; + bool _atomic_gemm; UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams, + int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, torch::Tensor empty_tensor) { // Initialize userbuf communicator if (!comm_created) { @@ -116,9 +119,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); output_tensor = torch::Tensor(); - auto counter_options = 
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({num_splits * 2}, counter_options); - counter.index_put_({Slice(None, num_splits)}, 1); + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({num_splits * 2}, counter_options); + counter.index_put_({Slice(None, num_splits)}, 1); + } // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); @@ -519,12 +525,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); return output_tensor; } + + bool is_atomic_gemm() { return _atomic_gemm; } + bool is_p2p_overlap() { return false; } }; // UbufCommOverlap struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int _tp_id; int _tp_size; - int _ub_reg; + int _ub_reg, _ub_reg2; int _next_rank, _prev_rank, _rank, _rank_round_tp; int _aggregate2; int _math_sms; @@ -533,18 +542,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor _ubuf; torch::Tensor counter; torch::Tensor _empty_tensor; + torch::Tensor _ubuf_scale_inv; + bool _ubuf_scale_inv_initialized; std::vector _ubufs; at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; int use_ce; int sms; int cga_size; + bool _atomic_gemm; UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams, - torch::Tensor empty_tensor) { + bool is_reduce_scatter, bool atomic_gemm, torch::Tensor empty_tensor) { // Initialize userbuf communicator if (!comm_created) { if (rank == 0) { @@ -561,16 +573,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Create workspace tensor with userbuffer int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_chunk_bytes = ubuf_bytes / tp_size; + int num_ubuf_chunks = tp_size; + if (is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); + num_ubuf_chunks = static_cast(tp_size * 2 - 1); + } _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, _ub_comm, true); if (rank == 0) { printf("!!! 
[UBP2P] Register UBuf %d\n", _ub_reg); } - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + + _ubuf = torch::from_blob( + _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); - for (int i = 0; i < tp_size; i++) { + for (int i = 0; i < num_ubuf_chunks; i++) { torch::Tensor ubuf_chunk = torch::from_blob( ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); _ubufs.push_back(ubuf_chunk); @@ -599,30 +620,37 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _rank_round_tp = (rank / tp_size) * tp_size; _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; + _ubuf_scale_inv_initialized = false; - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({tp_size * 2}, counter_options); - counter.index_put_({Slice(None, tp_size)}, 1); - _self_chunk_id = _tp_id; - - const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); - if (rank == 0 && env_p != nullptr) { - if (env_p[0] == '1') { - printf("!!userbuffers_sendrecv_atomic\n"); - } else if (env_p[0] == '2') { - printf("!!userbuffers_sendrecv_multiatomic\n"); - } else if (env_p[0] == '3') { - printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); - _self_chunk_id = 0; - } else { - printf("!!userbuffers_sendrecv\n"); + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({tp_size * 2}, counter_options); + counter.index_put_({Slice(None, tp_size)}, 1); + _self_chunk_id = _tp_id; + + if (!is_reduce_scatter) { + const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); + if (rank == 0 && env_p != nullptr) { + if (env_p[0] == '1') { + printf("!!userbuffers_sendrecv_atomic\n"); + } else if (env_p[0] == '2') { + printf("!!userbuffers_sendrecv_multiatomic\n"); + } else if (env_p[0] == '3') { + printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); + _self_chunk_id = 0; + } else { + printf("!!userbuffers_sendrecv\n"); + } + } + counter.index_put_({_self_chunk_id}, 0); } } - counter.index_put_({_self_chunk_id}, 0); // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_send, 0); cudaEventCreateWithFlags(&_stop_recv, 0); } @@ -758,7 +786,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); return D; - } // split_overlap_ag + } // atomic_gemm_overlap_ag + /* ** Split AllGather + GEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. 
This is @@ -948,6 +977,174 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { return D; } // split_overlap_ag +/* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; + int k = A.size(1); + int n = B.size(0); + + // Get communication and GEMM input chunk sizes + int n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int *counter_ptr = reinterpret_cast(counter.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + + // Atomic GEMM + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + _ubuf, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, counter); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, send_rank, (cudaStream_t) _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + } + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType 
A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; + int k = A.size(1); + int n = B.size(0); + + // Get communication and GEMM input chunk sizes + int n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr()); + char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char* input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); + // Store the last GEMM chunk output to the receive buffer.
+ torch::Tensor workspace_chunk = torch::from_blob( + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + if (i == _tp_size - 1) { + at::cuda::setCurrentCUDAStream(stream_main); + } else { + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + } + te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, + _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t) _stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_send, _start_comm, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, send_rank, (cudaStream_t) _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + } + } + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, + _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + } + /* ** Copy input to _ubufs[0] */ @@ -970,6 +1167,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { (cudaStream_t)stream_main)); } } + torch::Tensor get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); @@ -981,6 +1179,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int output_c_dim1 = _ubuf.size(1); return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); } + + void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } + bool is_atomic_gemm() { return _atomic_gemm; } + bool is_p2p_overlap() { return true; } }; // UbufP2PCommOverlap } // namespace ubuf diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index b5aa10b150..328bf1dcb4 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -109,26 +109,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_RS", 
ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG) + .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) + .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG); + .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) + .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); py::class_(m, "UbufCommOverlap") - .def(py::init()) + .def(py::init()) .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); + .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output) + .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) - .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) - .def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) + .def(py::init()) + .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) + .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) + .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs) .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); + .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output) + .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf) + .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap) + .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv); #else // NVTE_WITH_USERBUFFERS m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 76e9453efc..0ec89b0bb7 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -3666,3 +3666,34 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); } + +template +__global__ void __launch_bounds__(MAX_THREADS / 4) +reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, + const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + fp8type *inputs_fp8 = reinterpret_cast(inputs); + float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); + #pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); + } + half 
*output_half = reinterpret_cast<half *>(output); + output_half[tid] = (half) accum_buf; +} + +template <typename fp8type> +void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, + int input_size, cudaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_fp8_in_bf16_out_cuda<fp8type><<<grid, block, 0, stream>>>( + inputs, output, scale, num_inputs, input_size); +} + +template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>( + void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>( + void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 2d030a1409..407f9479c3 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -305,4 +305,8 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0); void destroy_communicator(communicator *comm); +template <typename fp8type> +void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, + int input_size, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ab24789549..59e5949e06 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -129,13 +129,14 @@ def initialize_ub( "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))): - fp8_buf.append ("proj_fprop") + fp8_buf += ["proj_fprop", "fc2_fprop"] # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "pipeline":["proj_fprop", "fc2_fprop"], "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] def get_method(name): for method, names in methods.items(): @@ -151,7 +152,28 @@ def add_ub( set_sm_margin: int = 0, num_splits: int = 4, aggregate: int = 0, + atomic_gemm: int = 0, + is_reduce_scatter: int = 0, ) -> None: + if atomic_gemm: + warnings.warn( + "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." + ) + assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + if is_reduce_scatter and method == "ring_exchange": + raise ValueError( + "Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method." + ) + if method == 'bulk': + warnings.warn( + "Atomic GEMM is not supported for a bulk overlap. " + "Defaulting to `atomic_gemm=False`." + ) + atomic_gemm = 0 + if not is_reduce_scatter and method == 'pipeline': + raise ValueError( + "`pipeline` overlap method is not supported for AllGather."
+ ) sample_buffer = torch.empty( shape, dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, @@ -166,6 +188,8 @@ def add_ub( set_sm_margin, # Set SM margin aggregate, # Aggregate 2X GEMM chunks _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + is_reduce_scatter, # overlap with reduce scatter + atomic_gemm, # use a single GEMM with atomic-counters torch.Tensor(), # empty tensor to pass to counters ) else: @@ -178,6 +202,7 @@ def add_ub( num_splits, # Number of communication splits set_sm_margin, # Set SM margin _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + atomic_gemm, # use a single GEMM with atomic-counters torch.Tensor(), # empty tensor to pass to counters ) _ub_communicators[name] = ub_obj @@ -191,6 +216,8 @@ def add_ub( num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 + atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 + is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 add_ub( name, method, @@ -198,7 +225,9 @@ def add_ub( cga_size, set_sm_margin, num_splits, - aggregate + aggregate, + atomic_gemm, + is_reduce_scatter, ) else: method = get_method(name) @@ -632,12 +661,10 @@ def grad_output_preprocess( grad_output_mat = grad_output.view((-1, grad_output.shape[-1])) gather_grad_output = row_parallel_mode and ctx.sequence_parallel - if gather_grad_output: - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag # No-FP8 case: bgrad is fused with wgrad for this case. if not ctx.fp8: if gather_grad_output: - if not ub_overlap_ag: + if not ctx.ub_overlap_ag: grad_output_mat, _ = gather_along_first_dim( grad_output_mat, ctx.tp_group ) @@ -656,7 +683,7 @@ def grad_output_preprocess( and ctx.fp8_meta["recipe"].override_linear_precision.wgrad ): assert ( - not ub_overlap_ag + not ctx.ub_overlap_ag ), "override_linear_precision.wgrad not supported with UB AG overlap" grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather @@ -665,7 +692,7 @@ def grad_output_preprocess( grad_bias = grad_output_mat.sum(dim=0) else: grad_bias = None - if ub_overlap_ag: + if ctx.ub_overlap_ag: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) else: grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) @@ -676,7 +703,7 @@ def grad_output_preprocess( fp8_dtype_backward, out=grad_output_c, ) - if not ub_overlap_ag: + if not ctx.ub_overlap_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) else: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index eecd908e51..3711d9898f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -86,8 +86,7 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_split_ag: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_ag: bool, ub_name: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -106,12 +105,11 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag or ub_atomic_gemm_ag: + if ub_overlap_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not 
is_grad_enabled) or return_layernorm_output: - ub_split_ag = False - ub_atomic_gemm_ag = False - if ub_split_ag or ub_atomic_gemm_ag: + ub_overlap_ag = False + if ub_overlap_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub(ub_name+"_fprop") @@ -119,8 +117,6 @@ def forward( else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) - if ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -138,9 +134,13 @@ def forward( # Column Parallel Linear ln_out_gathered = False - if ub_split_ag or ub_atomic_gemm_ag: + if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) + if ub_obj_lnout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif parallel_mode == "column" and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -201,8 +201,6 @@ def forward( ) weight_t_fp8 = None - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out, _ = tex.fp8_gemm( weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -217,9 +215,9 @@ def forward( bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo, - ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None, - extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None, + ub_algo=ub_algo if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) else: # Cast for native AMP @@ -243,9 +241,9 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) if is_grad_enabled: @@ -624,7 +622,6 @@ def backward( None, None, None, - None, ) @@ -737,8 +734,7 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_ag: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -758,23 +754,16 @@ def __init__( self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]): + self.ub_overlap_ag = ub_overlap_ag + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]): assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name - - if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag: + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." - if ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." 
- ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1098,8 +1087,7 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_split_ag, - self.ub_atomic_gemm_ag, + self.ub_overlap_ag, self.ub_name, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bf9e6fe558..7d86658260 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -117,10 +117,8 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_split_rs: bool, - ub_atomic_gemm_rs: bool, - ub_split_ag: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_rs: bool, + ub_overlap_ag: bool, gemm_gelu_fusion: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -142,25 +140,17 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag or ub_atomic_gemm_ag: - tp_world_size = get_distributed_world_size(tp_group) + tp_world_size = get_distributed_world_size(tp_group) + if ub_overlap_ag: if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_split_ag = False - ub_atomic_gemm_ag = False - ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag + ub_overlap_ag = False if ub_overlap_ag: ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) - if ub_split_rs or ub_atomic_gemm_rs: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1: - ub_split_rs = False - ub_atomic_gemm_rs = False - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
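With ub_split_ag/ub_atomic_gemm_ag and ub_split_rs/ub_atomic_gemm_rs collapsed into the single ub_overlap_ag and ub_overlap_rs flags, the pipelined-vs-atomic choice now lives in the userbuffer configuration built by initialize_ub() rather than in each module. A rough configuration sketch follows; it assumes the existing initialize_ub(shape, tp_size, use_fp8, dtype, ub_cfgs) entry point and an already-initialized tensor-parallel process group, and it only uses config keys that appear in the add_ub() changes above ("atomic_gemm", "num_splits"); it is an illustration, not code from this diff.

import torch
import transformer_engine.pytorch as te

# Illustrative sizes; a real setup takes these from the model/parallel config.
seq_len, micro_batch, hidden = 2048, 1, 4096
tp_size = 8

# One sample-buffer shape is shared by all userbuffer instances.
te.module.base.initialize_ub(
    shape=[seq_len * micro_batch, hidden],
    tp_size=tp_size,
    use_fp8=True,                      # required if any buffer sets atomic_gemm=1
    dtype=torch.bfloat16,
    ub_cfgs={
        # proj_fprop / fc2_fprop default to the "pipeline" RS method (see `methods` above);
        # atomic_gemm=1 switches them to the single atomic-counter GEMM path.
        "proj_fprop": {"atomic_gemm": 1, "num_splits": 4},
        "fc2_fprop": {"atomic_gemm": 1, "num_splits": 4},
    },
)

# Modules now take only the consolidated overlap flags.
mlp = te.LayerNormMLP(
    hidden, 4 * hidden,
    set_parallel_mode=True, sequence_parallel=True,
    ub_overlap_ag=True, ub_overlap_rs=True,
    ub_bulk_wgrad=True, ub_bulk_dgrad=True,
)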
+ ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -181,6 +171,10 @@ def forward( if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) + if ub_obj_lnout.is_atomic_gemm(): + ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif set_parallel_mode and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -267,9 +261,6 @@ def forward( ) fc2_weight_t_fp8 = None - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo - # Perform FP8 GEMM fp8_gemm_args = [ fc1_weight_fp8._data, @@ -287,7 +278,7 @@ def forward( bias=fc1_bias, use_bias=use_fc1_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo, + ub_algo=ub_algo_ag if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -321,13 +312,23 @@ def forward( fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( None, None, None, activation_dtype) - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + if ub_obj_fc2out.is_p2p_overlap(): + if ub_obj_fc2out.is_atomic_gemm(): + ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_fc2out.is_atomic_gemm(): + ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_fc2out.is_fp8_ubuf(): fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT @@ -340,8 +341,6 @@ def forward( dim_size[1] = fc2_weight.size(0) fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( fc2_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -357,9 +356,9 @@ def forward( use_bias=use_fc2_bias, use_split_accumulator=_2X_ACC_FPROP, out=fc2_out, - ub_algo=ub_algo, - ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None, - extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None, + ub_algo=ub_algo_rs if ub_overlap_rs else None, + ub=ub_obj_fc2out if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=fc2_out_index, fp8_meta_tensor = fc2_meta_tensor, D_dtype = fc2_te_type, @@ -395,9 +394,9 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion and (activation == 'gelu'), - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) if not is_grad_enabled: clear_tensor_data(ln_out_total) @@ -427,13 +426,17 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = 
\ torch.max(-amin, amax).float() - if ub_split_rs: + if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + if ub_obj_fc2out.is_p2p_overlap(): + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) @@ -446,9 +449,9 @@ def forward( bias=fc2_bias, use_bias=use_fc2_bias, out=fc2_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_fc2out if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo_rs if ub_overlap_rs else None, + ub=ub_obj_fc2out if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, ) if not is_grad_enabled: clear_tensor_data(gelu_out) @@ -515,13 +518,12 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_split_ag = ub_split_ag - ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag + ctx.ub_overlap_ag = ub_overlap_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization # Row Parallel Linear - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) @@ -590,18 +592,19 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("fc1_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag - if ub_overlap_ag: + if ctx.ub_overlap_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: - ctx.ub_split_ag = False ctx.ub_overlap_ag = False - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag - if ub_overlap_ag: + if ctx.ub_overlap_ag: dim_size = list(grad_outputs[0].size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( @@ -645,8 +648,6 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( fc2_weight_t_fp8._data, @@ -660,10 +661,10 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ctx.ub_obj_gradout if ub_overlap_ag else None, + ub_algo=ub_algo if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) - if ub_overlap_ag: + if ctx.ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) clear_tensor_data(grad_output_c) @@ -801,8 +802,9 @@ def backward( gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'), grad=True, gelu_input=fc1_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ + if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if 
ctx.ub_overlap_ag else None, ) # FC2 WGRAD @@ -1070,8 +1072,6 @@ def backward( None, None, None, - None, - None, ) @@ -1194,10 +1194,8 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_rs: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, ) -> None: super().__init__() @@ -1218,29 +1216,18 @@ def __init__( self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_split_rs = ub_split_rs - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_rs = ub_atomic_gemm_rs - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag + self.ub_overlap_rs = ub_overlap_rs + self.ub_overlap_ag = ub_overlap_ag # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap - self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and - self.activation == 'gelu' and self.ub_split_ag) - - if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions - or ub_bulk_dgrad - or ub_split_rs - or ub_split_ag - or ub_atomic_gemm_rs - or ub_atomic_gemm_ag): + self.gemm_gelu_fusion = \ + (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and + self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm()) + + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." - ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1490,10 +1477,8 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_split_rs, - self.ub_atomic_gemm_rs, - self.ub_split_ag, - self.ub_atomic_gemm_ag, + self.ub_overlap_rs, + self.ub_overlap_ag, self.gemm_gelu_fusion, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d22242abb4..1f7898a592 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,7 +3,6 @@ # See LICENSE for license information. """Linear API""" -import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -79,10 +78,8 @@ def forward( parallel_mode: Union[str, None], is_grad_enabled: bool, primary_weights_in_fp8: bool, - ub_split_rs: bool, - ub_split_ag: bool, - ub_atomic_gemm_rs: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_rs: bool, + ub_overlap_ag: bool, ub_name: str ) -> torch.Tensor: # Make sure input dimensions are compatible @@ -94,14 +91,8 @@ def forward( assert_dim_for_fp8_exec(weight) update_fp8_weights = is_first_microbatch is None or is_first_microbatch - - if ub_split_rs or ub_atomic_gemm_rs: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1: - ub_split_rs = False - ub_atomic_gemm_rs = False - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
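The same four-way choice of reduce-scatter overlap algorithm recurs below in Linear.forward and above in LayerNormMLP: the userbuffer object now advertises is_p2p_overlap() and is_atomic_gemm(), and the caller maps those onto the UbufOverlapAlgo enum. A plain-Python restatement of that mapping (not a helper that exists in this diff; `tex` stands for the C++ extension module these files already import, passed in explicitly here):

def pick_rs_overlap_algo(ub_obj, tex):
    """Map userbuffer properties to the reduce-scatter overlap algorithm,
    mirroring the branches added in Linear/LayerNormMLP forward."""
    if ub_obj.is_p2p_overlap():          # ring-exchange (send/recv) buffer
        if ub_obj.is_atomic_gemm():
            return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
        return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
    if ub_obj.is_atomic_gemm():          # collective ("pipeline") buffer
        return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
    return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS

The all-gather side only has the P2P variants after this change, so it keys off is_atomic_gemm() alone (ATOMIC_GEMM_AG_P2P vs. SPLIT_PIPELINED_AG_P2P).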
+ tp_world_size = get_distributed_world_size(tp_group) + ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -180,14 +171,23 @@ def forward( proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - + if ub_obj_projout.is_p2p_overlap(): + if ub_obj_projout.is_atomic_gemm(): + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_projout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_projout.is_fp8_ubuf(): proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] @@ -199,8 +199,6 @@ def forward( dim_size[1] = weight.size(0) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -216,9 +214,9 @@ def forward( use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=ub_algo, - ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None, - extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None, + ub_algo=ub_algo if ub_overlap_rs else None, + ub=ub_obj_projout if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=proj_out_index, fp8_meta_tensor = meta_tensor, D_dtype = proj_out_tetype, @@ -238,13 +236,17 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.max(-amin, amax).float() - if ub_split_rs: - ub_obj_projout = get_ub("proj_fprop") + if ub_overlap_rs: + ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size + dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + if ub_obj_projout.is_p2p_overlap(): + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) @@ -258,9 +260,9 @@ def forward( bias=bias, use_bias=use_bias, out=out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_projout if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo if ub_overlap_rs else None, + ub=ub_obj_projout if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, ) if is_grad_enabled: @@ -307,14 +309,13 @@ def forward( ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_split_ag = ub_split_ag - ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag + ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel 
Linear - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: out = rs_out elif parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) @@ -350,16 +351,16 @@ def backward( weight_t_fp8 = weight.transpose( update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", ) - - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_split_ag = False - ctx.ub_atomic_gemm_ag = False - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: + tp_world_size = get_distributed_world_size(ctx.tp_group) + ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag + if ctx.ub_overlap_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ( grad_output, grad_output_c, @@ -397,8 +398,6 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( @@ -413,8 +412,8 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None, + ub_algo=ub_algo if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) else: dgrad, _, _ = gemm( @@ -424,8 +423,9 @@ def backward( get_workspace(), layout="NN", grad=True, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ + if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) # Overlap dgrad-RS/AR with wgrad @@ -442,7 +442,7 @@ def backward( if ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: + if ctx.ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) if inputmat_t_total is None: inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward) @@ -542,8 +542,6 @@ def backward( None, None, None, - None, - None, ) @@ -629,10 +627,8 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_split_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -645,28 +641,18 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - self.ub_split_rs = ub_split_rs - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_rs = ub_atomic_gemm_rs - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]): + self.ub_overlap_rs = ub_overlap_rs + self.ub_overlap_ag = ub_overlap_ag + if ub_overlap_rs or ub_overlap_ag: assert ub_name is not None, "Userbuffer name [string] is not set." 
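When the reduce-scatter output buffer holds FP8 data (is_fp8_ubuf(), e.g. the proj_fprop/fc2_fprop buffers with NVTE_UB_FP8_RS=1), the P2P split_overlap_rs path reduces the received partial outputs with the new reduce_fp8_in_bf16_out kernel instead of torch::sum_out: each FP8 partial is dequantized with the buffer's scale_inv and the partials are accumulated in FP32 before being stored as 16-bit values. A small PyTorch sketch of that reduction, with made-up shapes and a stand-in for the FP8 chunk values (an illustration, not code from this PR):

import torch

tp_size, chunk_elems = 4, 1024
scale_inv = torch.tensor(0.5)                                    # plays the role of _ubuf_scale_inv
partials = torch.randint(-8, 8, (tp_size, chunk_elems)).float()  # stand-in for the FP8 chunk values

# Dequantize every partial with the same scale_inv, accumulate, store in 16-bit.
rs_output = (partials * scale_inv).sum(dim=0).to(torch.bfloat16)

# The non-FP8 fallback in split_overlap_rs is simply a sum over the chunk axis.
reference = (partials.sum(dim=0) * scale_inv).to(torch.bfloat16)
assert torch.equal(rs_output, reference)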
+ assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." self.ub_name = ub_name self.get_rng_state_tracker = get_rng_state_tracker if device == 'meta': assert parameters_split is None, ("Cannot split module parameters " "on 'meta' device.") - - if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: - assert ( - tex.userbuf_comm_available() - ), "Userbuffer communication backend not available." - - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." - ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -930,10 +916,8 @@ def forward( self.parallel_mode, torch.is_grad_enabled(), self.primary_weights_in_fp8, - self.ub_split_rs, - self.ub_split_ag, - self.ub_atomic_gemm_rs, - self.ub_atomic_gemm_ag, + self.ub_overlap_rs, + self.ub_overlap_ag, self.ub_name, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ab69f8e690..a0fd231913 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -259,10 +259,8 @@ def __init__( ub_tp_comm_overlap: bool = False, ub_bulk_wgrad: bool = True, ub_bulk_dgrad: bool = True, - ub_split_ag: bool = True, - ub_split_rs: bool = True, - ub_atomic_gemm_ag: bool = False, - ub_atomic_gemm_rs: bool = False, + ub_overlap_ag: bool = True, + ub_overlap_rs: bool = True, bias: bool = True, activation: str = 'gelu', normalization: str = "LayerNorm", @@ -282,21 +280,8 @@ def __init__( params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad - ub_split_ag = ub_tp_comm_overlap and ub_split_ag - ub_split_rs = ub_tp_comm_overlap and ub_split_rs - ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs - assert ( - not (ub_split_rs and ub_atomic_gemm_rs) - ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled." - ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag - assert ( - not (ub_split_ag and ub_atomic_gemm_ag) - ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled." - - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." - ) + ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag + ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1"))) self.layer_number = layer_number @@ -370,10 +355,8 @@ def __init__( "qkv_weight_interleaved" : qkv_weight_interleaved, "ub_bulk_wgrad" : ub_bulk_wgrad, "ub_bulk_dgrad" : ub_bulk_dgrad, - "ub_split_ag" : ub_split_ag, - "ub_split_rs" : ub_split_rs, - "ub_atomic_gemm_rs" : ub_atomic_gemm_rs, - "ub_atomic_gemm_ag" : ub_atomic_gemm_ag, + "ub_overlap_ag" : ub_overlap_ag, + "ub_overlap_rs" : ub_overlap_rs, "qkv_format" : self.attn_input_format, } @@ -427,10 +410,8 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_rs=ub_split_rs, - ub_split_ag=ub_split_ag, - ub_atomic_gemm_rs=ub_atomic_gemm_rs, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, + ub_overlap_rs=ub_overlap_rs, + ub_overlap_ag=ub_overlap_ag, activation=activation, normalization=normalization, device=device,
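At the layer level, enabling the overlaps after this change involves only ub_tp_comm_overlap plus the consolidated ub_overlap_ag/ub_overlap_rs flags (and the existing bulk flags); whether a given buffer uses pipelined or atomic GEMM, and collective or send/recv ring-exchange communication, is decided by the userbuffer setup in initialize_ub, as sketched earlier. An illustrative construction (argument values are placeholders; it assumes torch.distributed and initialize_ub have already been set up, and tp_group is an existing tensor-parallel process group):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    set_parallel_mode=True,
    sequence_parallel=True,
    tp_group=tp_group,            # existing tensor-parallel process group (placeholder)
    ub_tp_comm_overlap=True,
    ub_overlap_ag=True,           # replaces ub_split_ag / ub_atomic_gemm_ag
    ub_overlap_rs=True,           # replaces ub_split_rs / ub_atomic_gemm_rs
    ub_bulk_wgrad=True,
    ub_bulk_dgrad=True,
)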