From f977de37fc803ddc958a3f85075429f993aa37ef Mon Sep 17 00:00:00 2001
From: xiaoyao0115 <1804647152@qq.com>
Date: Tue, 19 Nov 2024 14:04:52 +0000
Subject: [PATCH] format the code

Signed-off-by: xiaoyao0115 <1804647152@qq.com>
---
 .../pytorch/csrc/extensions/attention.cu     | 278 ++++++++----------
 1 file changed, 118 insertions(+), 160 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index b4ae0770dd..98158600e9 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1241,7 +1241,7 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) {
 **************************************************************************************************/
 
 // format of out and lse, ignoring d as it’s always the last dimension.
-enum QKVFormat { SBH, BSH,BHS,HBS, TH, HT};
+enum QKVFormat { SBH, BSH, BHS, HBS, TH, HT };
 
 template <int n>
 struct TensorList {
@@ -1251,100 +1251,77 @@ struct TensorList {
 };
 
 // describe tensor format for simplified computation.
-template<QKVFormat format>
-struct TensorFormat
-{
+template <QKVFormat format>
+struct TensorFormat {
   // store the bsht order for simplified computation, where bsht corresponds to 0, 1, 2, 3, and store_format[3] marks whether bs is fused into t
   int8_t store_format[4];
-  int* cu_seqlens_s;
+  int *cu_seqlens_s;
   // size of tensor, b s h t
   int size[4];
 
-  __forceinline__ __device__ TensorFormat(int size_kernel[4],int* cu_seqlens=nullptr)
-  {
-
-    for(int i=0;i<4;i++)
-    {
-      size[i]=size_kernel[i];
+  __forceinline__ __device__ TensorFormat(int size_kernel[4], int *cu_seqlens = nullptr) {
+    for (int i = 0; i < 4; i++) {
+      size[i] = size_kernel[i];
     }
-    if constexpr(format==QKVFormat::TH)
-    {
-      cu_seqlens_s=cu_seqlens;
-      store_format[0]=3;
-      store_format[1]=2;
-      store_format[3]=1;
-    }
-    else if constexpr(format==QKVFormat::HT)
-    {
-      cu_seqlens_s=cu_seqlens;
-      store_format[0]=2;
-      store_format[1]=3;
-      store_format[3]=1;
-    }
-    else if constexpr(format==QKVFormat::SBH)
-    {
-      store_format[0]=1;
-      store_format[1]=0;
-      store_format[2]=2;
-      store_format[3]=0;
-    }
-    else if constexpr(format==QKVFormat::HBS)
-    {
-      store_format[0]=2;
-      store_format[1]=0;
-      store_format[2]=1;
-      store_format[3]=0;
-    }
-    else if constexpr(format==QKVFormat::BSH)
-    {
-      store_format[0]=0;
-      store_format[1]=1;
-      store_format[2]=2;
-      store_format[3]=0;
-    }
-    else if constexpr(format==QKVFormat::BHS)
-    {
-      store_format[0]=0;
-      store_format[1]=2;
-      store_format[2]=1;
-      store_format[3]=0;
-    }
-  }
+    if constexpr (format == QKVFormat::TH) {
+      cu_seqlens_s = cu_seqlens;
+      store_format[0] = 3;
+      store_format[1] = 2;
+      store_format[3] = 1;
+    } else if constexpr (format == QKVFormat::HT) {
+      cu_seqlens_s = cu_seqlens;
+      store_format[0] = 2;
+      store_format[1] = 3;
+      store_format[3] = 1;
+      size[3] = size[1];
+    } else if constexpr (format == QKVFormat::SBH) {
+      store_format[0] = 1;
+      store_format[1] = 0;
+      store_format[2] = 2;
+      store_format[3] = 0;
+    } else if constexpr (format == QKVFormat::HBS) {
+      store_format[0] = 2;
+      store_format[1] = 0;
+      store_format[2] = 1;
+      store_format[3] = 0;
+    } else if constexpr (format == QKVFormat::BSH) {
+      store_format[0] = 0;
+      store_format[1] = 1;
+      store_format[2] = 2;
+      store_format[3] = 0;
+    } else if constexpr (format == QKVFormat::BHS) {
+      store_format[0] = 0;
+      store_format[1] = 2;
+      store_format[2] = 1;
+      store_format[3] = 0;
+    }
+  }
 
   // calculate address according to index
-  __forceinline__ __device__ int compute_address(int id[4])
-  {
+  __forceinline__ __device__ int compute_address(int id[4]) {
     int address;
-    if(store_format[3]==1)
-    {
-      address=id[store_format[0]]*size[store_format[1]]+id[store_format[1]];
-    }
-    else
-    {
-      address=id[store_format[0]]*size[store_format[1]]+id[store_format[1]];
-      address=address*size[store_format[2]]+id[store_format[2]];
+    if (store_format[3] == 1) {
+      address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]];
+    } else {
+      address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]];
+      address = address * size[store_format[2]] + id[store_format[2]];
     }
     return address;
   }
 
   // compute half right index
-  __forceinline__ __device__ void compute_half_right(int id[4])
-  {
-    if constexpr(format==QKVFormat::TH)
-    {
-      id[1]-=(cu_seqlens_s[id[0]+1]-cu_seqlens_s[id[0]])/2;
-      id[3]-=cu_seqlens_s[id[0]+1]/2;
-    }
-    else if constexpr(format==QKVFormat::BSH || format==QKVFormat::SBH)
-    {
-      id[1]-=size[1]/2;
+  __forceinline__ __device__ void compute_half_right(int id[4]) {
+    if constexpr (format == QKVFormat::TH) {
+      id[1] -= (cu_seqlens_s[id[0] + 1] - cu_seqlens_s[id[0]]) / 2;
+      id[3] -= cu_seqlens_s[id[0] + 1] / 2;
+    } else if constexpr (format == QKVFormat::BSH || format == QKVFormat::SBH) {
+      id[1] -= size[1] / 2;
     }
   }
 };
 
-template<typename dtype, int n, QKVFormat out_format, QKVFormat lse_format, bool causal>
+template <typename dtype, int n, QKVFormat out_format, QKVFormat lse_format, bool causal>
 __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors, float *lse,
                                             int *cu_seqlens, int batch, int num_heads,
                                             int dim_per_head, int lse_seqlen, int cp_size, int rank,
@@ -1353,14 +1330,13 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors,
   int full_num;
   int num_total_tokens;
 
-  if constexpr (out_format == QKVFormat::TH ) {
+  if constexpr (out_format == QKVFormat::TH) {
     for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
       cu_seqlens_s[i] = cu_seqlens[i];
     }
     __syncthreads();
     num_total_tokens = cu_seqlens_s[batch];
-  } else if constexpr (out_format == QKVFormat::SBH || out_format == QKVFormat::BSH)
-  {
+  } else if constexpr (out_format == QKVFormat::SBH || out_format == QKVFormat::BSH) {
     num_total_tokens = lse_seqlen * batch;
   }
 
@@ -1370,8 +1346,8 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors,
     full_num = start + tensors.start_tensor_this_launch;
   }
 
-  int size[4]={batch,lse_seqlen,num_heads,num_total_tokens};
-  TensorFormat<out_format> out_full(size,cu_seqlens_s);
+  int size[4] = {batch, lse_seqlen, num_heads, num_total_tokens};
+  TensorFormat<out_format> out_full(size, cu_seqlens_s);
   TensorFormat<lse_format> lse_full(size);
 
   int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
@@ -1382,44 +1358,39 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors,
   size_t idx_out_full, idx_lse_full, idx_out_half, idx_lse_half;
 
   for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
-
     int head_id = blockIdx.y;
     int id[4];
     if constexpr (out_format == QKVFormat::TH) {
-      id[0]=binary_search(token_id, cu_seqlens_s, batch + 1);
-      id[1]=token_id - cu_seqlens_s[id[0]];
-    }
-    else if constexpr (out_format == QKVFormat::BSH) {
-      id[0]=token_id/lse_seqlen;
-      id[1]=token_id-id[0]*lse_seqlen;
-    }
-    else if constexpr (out_format == QKVFormat::SBH) {
-      id[1]=token_id/batch;
-      id[0]=token_id-id[1]*batch;
+      id[0] = binary_search(token_id, cu_seqlens_s, batch + 1);
+      id[1] = token_id - cu_seqlens_s[id[0]];
+    } else if constexpr (out_format == QKVFormat::BSH) {
+      id[0] = token_id / lse_seqlen;
+      id[1] = token_id - id[0] * lse_seqlen;
+    } else if constexpr (out_format == QKVFormat::SBH) {
+      id[1] = token_id / batch;
+      id[0] = token_id - id[1] * batch;
     }
-    id[2]=head_id;
-    id[3]=token_id;
+
+    id[2] = head_id;
+    id[3] = token_id;
 
-    idx_out_full=out_full.compute_address(id);
-    idx_lse_full=lse_full.compute_address(id);
-
-    dtype *cur_out = out + idx_out_full*dim_per_head;
+    idx_out_full = out_full.compute_address(id);
+    idx_lse_full = lse_full.compute_address(id);
+
+    dtype *cur_out = out + idx_out_full * dim_per_head;
     float lse_temp = lse[idx_lse_full];
-    int end=full_num;
-
-    if (start + tensors.start_tensor_this_launch > full_num)
-    {
-      out_full.compute_half_right(id);
-      if(id[1]>=0)
-      {
-        int size_half[4]={batch,lse_seqlen/2,num_heads,num_total_tokens/2};
-        TensorFormat<out_format> out_half(size_half);
-        TensorFormat<lse_format> lse_half(size_half);
-        idx_out_half=out_half.compute_address(id);
-        idx_lse_half=lse_half.compute_address(id);
-        end=start + tensors.start_tensor_this_launch;
-      }
+    int end = full_num;
+
+    if (start + tensors.start_tensor_this_launch > full_num) {
+      out_full.compute_half_right(id);
+      if (id[1] >= 0) {
+        int size_half[4] = {batch, lse_seqlen / 2, num_heads, num_total_tokens / 2};
+        TensorFormat<out_format> out_half(size_half);
+        TensorFormat<lse_format> lse_half(size_half);
+        idx_out_half = out_half.compute_address(id);
+        idx_lse_half = lse_half.compute_address(id);
+        end = start + tensors.start_tensor_this_launch;
+      }
     }
 
     for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
@@ -1429,20 +1400,16 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors,
       dtype *p = reinterpret_cast<dtype *>(&data);
 
       for (int i = start; i < end; i++) {
-
-        if (id[1]>=0 && start + tensors.start_tensor_this_launch > full_num && i>rank)
-        {
-          idx_out=idx_out_half;
-          idx_lse=idx_lse_half;
-        }
-        else
-        {
-          idx_out=idx_out_full;
-          idx_lse=idx_lse_full;
+        if (id[1] >= 0 && start + tensors.start_tensor_this_launch > full_num && i > rank) {
+          idx_out = idx_out_half;
+          idx_lse = idx_lse_half;
+        } else {
+          idx_out = idx_out_full;
+          idx_lse = idx_lse_full;
         }
 
         dtype *cur_out_per_step =
-            reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out*dim_per_head;
+            reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out * dim_per_head;
         float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
         float lse_corrected_exp =
             exp(reinterpret_cast<float *>(tensors.addresses_lse[i])[idx_lse] - lse_temp);
@@ -1451,19 +1418,17 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<n> tensors,
           p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
         }
       }
-
      reinterpret_cast<float4 *>(cur_out)[j] = data;
    }
  }
 }
 
-
-
 template <typename dtype, bool causal>
 void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                                  const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                                  const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
-                                 int rank, bool softmax_lse_in_packed_format, const at::Tensor *lse_ = nullptr) {
+                                 int rank, bool softmax_lse_in_packed_format,
+                                 const at::Tensor *lse_ = nullptr) {
   int lse_seqlen;
   int batch;
   int num_heads;
@@ -1510,54 +1475,43 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
       tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr();
     }
     if (qkv_format == "sbhd") {
-
-      if(softmax_lse_in_packed_format)
-      {
-        fused_out_correction_kernel<dtype, n, QKVFormat::SBH,QKVFormat::HT, causal>
+      if (softmax_lse_in_packed_format) {
+        fused_out_correction_kernel<dtype, n, QKVFormat::SBH, QKVFormat::HT, causal>
             <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
                 out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
                 batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-      }
-      else
-      {
-        fused_out_correction_kernel<dtype, n, QKVFormat::SBH,QKVFormat::BHS, causal>
+      } else {
+        fused_out_correction_kernel<dtype, n, QKVFormat::SBH, QKVFormat::BHS, causal>
            <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
                out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
                batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
      }
    } else if (qkv_format == "bshd") {
-
-      if(softmax_lse_in_packed_format)
-      {
-        fused_out_correction_kernel<dtype, n, QKVFormat::BSH,QKVFormat::HT, causal>
+      if (softmax_lse_in_packed_format) {
+        fused_out_correction_kernel<dtype, n, QKVFormat::BSH, QKVFormat::HT, causal>
            <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
                out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
                batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-      }
-      else
-      {
-        fused_out_correction_kernel<dtype, n, QKVFormat::BSH,QKVFormat::BHS, causal>
+      } else {
+        fused_out_correction_kernel<dtype, n, QKVFormat::BSH, QKVFormat::BHS, causal>
           <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
               out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
      }
-
-    } else if (qkv_format == "thd"){
+    } else if (qkv_format == "thd") {
      if (softmax_lse_in_packed_format) {
-        fused_out_correction_kernel<dtype, n, QKVFormat::TH,QKVFormat::HT, causal>
+        fused_out_correction_kernel<dtype, n, QKVFormat::TH, QKVFormat::HT, causal>
           <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
              out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
      } else {
-        fused_out_correction_kernel<dtype, n, QKVFormat::TH,QKVFormat::BHS, causal>
+        fused_out_correction_kernel<dtype, n, QKVFormat::TH, QKVFormat::BHS, causal>
          <<<grid, block, sizeof(int) * (batch + 1), stream>>>(
             out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
            batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
      }
    }
  }
@@ -1575,30 +1529,36 @@ void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_pe
     using dtype = at::Half;
     if (causal) {
       fused_out_correction_helper<dtype, true>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                               qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                               qkv_format, cp_size, rank,
+                                               softmax_lse_in_packed_format);
     } else {
       fused_out_correction_helper<dtype, false>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                                qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                                qkv_format, cp_size, rank,
+                                                softmax_lse_in_packed_format);
     }
   } else if (out.scalar_type() == at::ScalarType::BFloat16) {
     using dtype = at::BFloat16;
     if (causal) {
       fused_out_correction_helper<dtype, true>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                               qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                               qkv_format, cp_size, rank,
+                                               softmax_lse_in_packed_format);
     } else {
       fused_out_correction_helper<dtype, false>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                                qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                                qkv_format, cp_size, rank,
+                                                softmax_lse_in_packed_format);
    }
  } else if (out.scalar_type() == at::ScalarType::Float) {
    using dtype = float;
    if (causal) {
      fused_out_correction_helper<dtype, true>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                               qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                               qkv_format, cp_size, rank,
+                                               softmax_lse_in_packed_format);
    } else {
      fused_out_correction_helper<dtype, false>(out, out_per_step, lse, lse_per_step, cu_seqlens,
-                                                qkv_format, cp_size, rank, softmax_lse_in_packed_format);
+                                                qkv_format, cp_size, rank,
+                                                softmax_lse_in_packed_format);
    }
  } else {
    NVTE_ERROR("Unsupported dtype of out\n");
  }
@@ -1745,8 +1705,6 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
   return qkv;
 }
 
-
-
 /***************************************************************************************************
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/
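
Reviewer note, not part of the patch: TensorFormat::compute_address maps an index tuple id = {b, s, h, t} against size = {B, S, H, T} to a flat offset, using store_format to pick the axis order and store_format[3] to mark the packed (t-fused) layouts. A self-contained C++ sketch of the two offsets that fall out for the BSH and packed TH layouts; the helper names below are hypothetical, for illustration only:

#include <cassert>

// Mirrors the arithmetic of the patch's TensorFormat::compute_address for two
// layouts; these helpers are illustrative and do not exist in the patch.
// BSH: store_format = {0, 1, 2, 0} -> offset = (b * S + s) * H + h.
int address_bsh(const int id[4], const int size[4]) {
  return (id[0] * size[1] + id[1]) * size[2] + id[2];
}
// TH (packed): store_format[0] = 3, store_format[1] = 2 -> offset = t * H + h.
int address_th(const int id[4], const int size[4]) { return id[3] * size[2] + id[2]; }

int main() {
  int size[4] = {2, 8, 4, 16};  // B, S, H, T with T = B * S
  int id[4] = {1, 3, 2, 11};    // token 11 == b * S + s == 1 * 8 + 3
  assert(address_bsh(id, size) == (1 * 8 + 3) * 4 + 2);
  assert(address_th(id, size) == 11 * 4 + 2);
  return 0;
}

The same multiply-add chain is what the kernel uses to locate a token's slice in out (scaled by dim_per_head) and its scalar entry in lse.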
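
Reviewer note, not part of the patch: the correction loop in fused_out_correction_kernel rescales each per-CP-step partial output by exp(lse_i - lse) before accumulating, which is exactly the factor the kernel stores in lse_corrected_exp. A minimal host-side C++ sketch of that merge for a single token and head, assuming the combined lse is the log-sum-exp of the per-step values (merge_partial_outputs is a hypothetical name):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// out = sum_i exp(lse_i - lse) * out_i, with lse = log(sum_i exp(lse_i)).
// Illustrative only; mirrors the per-element scaling done in
// fused_out_correction_kernel.
std::vector<float> merge_partial_outputs(const std::vector<std::vector<float>> &outs,
                                         const std::vector<float> &lses) {
  // Combined statistic, computed stably against the running maximum.
  float lse_max = lses[0];
  for (float l : lses) lse_max = std::max(lse_max, l);
  float sum = 0.f;
  for (float l : lses) sum += std::exp(l - lse_max);
  float lse = lse_max + std::log(sum);

  std::vector<float> out(outs[0].size(), 0.f);
  for (size_t i = 0; i < outs.size(); ++i) {
    float scale = std::exp(lses[i] - lse);  // exp(lse_i - lse), in (0, 1]
    for (size_t k = 0; k < out.size(); ++k) out[k] += outs[i][k] * scale;
  }
  return out;
}

int main() {
  // Two steps' partial outputs for one (token, head) pair, dim_per_head = 4.
  std::vector<std::vector<float>> outs = {{0.2f, 0.4f, 0.1f, 0.3f},
                                          {0.5f, 0.1f, 0.2f, 0.2f}};
  std::vector<float> lses = {1.5f, 0.7f};
  for (float v : merge_partial_outputs(outs, lses)) std::printf("%f\n", v);
  return 0;
}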