diff --git a/README.md b/README.md index 0fb9d42b..afd23a68 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ -🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc). +🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[B]eginners**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc). image @@ -17,7 +17,7 @@ - / = not supported now. - ✔ī¸ = known work and already supported now. - ❔ = in my plan, but not coming soon, maybe a few weeks later. -- **workflow**: custom **CUDA** kernel impl -> **Torch** python binding -> Run tests. +- **workflow**: custom **CUDA** kernel impl -> **PyTorch** python binding -> Run tests. |📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level | |:---|:---|:---|:---|:---| @@ -75,6 +75,9 @@ | ✔ī¸ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐ī¸â­ī¸| | ✔ī¸ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐ī¸â­ī¸| | ✔ī¸ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐ī¸â­ī¸| +| ✔ī¸ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐ī¸â­ī¸| +| ✔ī¸ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐ī¸â­ī¸| +| ✔ī¸ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐ī¸â­ī¸| | ✔ī¸ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐ī¸â­ī¸| | ✔ī¸ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐ī¸â­ī¸| | ✔ī¸ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐ī¸â­ī¸| diff --git a/layer-norm/layer_norm.cu b/layer-norm/layer_norm.cu index eda644de..14efa9f0 100644 --- a/layer-norm/layer_norm.cu +++ b/layer-norm/layer_norm.cu @@ -433,9 +433,12 @@ if(((T).options().dtype() != (th_type))) { \ throw std::runtime_error("values must be "#th_type); \ } -#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ -if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \ - throw std::runtime_error("Tensor size mismatch!"); \ +#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ +assert((T1).dim() == (T2).dim()); \ +for (int i = 0; i < (T1).dim(); ++i) { \ + if ((T2).size(i) != (T1).size(i)) { \ + throw std::runtime_error("Tensor size mismatch!"); \ + } \ } // fp32 diff --git a/rms-norm/rms_norm.cu b/rms-norm/rms_norm.cu index b32c1aba..b3faf206 100644 --- a/rms-norm/rms_norm.cu +++ b/rms-norm/rms_norm.cu @@ -382,9 +382,12 @@ if(((T).options().dtype() != (th_type))) { \ throw std::runtime_error("values must be "#th_type); \ } -#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ -if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \ - throw std::runtime_error("Tensor size mismatch!"); \ +#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ +assert((T1).dim() == 
(T2).dim()); \ +for (int i = 0; i < (T1).dim(); ++i) { \ + if ((T2).size(i) != (T1).size(i)) { \ + throw std::runtime_error("Tensor size mismatch!"); \ + } \ } #define LANUCH_RMS_NORM_F32_KERNEL(K) \ diff --git a/softmax/README.md b/softmax/README.md index 702313cb..1fb6d451 100755 --- a/softmax/README.md +++ b/softmax/README.md @@ -5,11 +5,14 @@ 包åĢäģĨ下内厚īŧš - [X] softmax_f32_kernel (grid level memory fence) -- [X] softmax_f32x4_kernel(grid level memory fence, float4向量化į‰ˆæœŦ) +- [X] softmax_f32x4_kernel(grid level memory fence) - [X] softmax_f32_per_token_kernel(per token) -- [X] softmax_f32x4_per_token_kernel(per token, float4向量化į‰ˆæœŦ) +- [X] softmax_f32x4_per_token_kernel(per token) - [X] safe_softmax_f32_per_token_kernel(per token) -- [X] safe_softmax_f32x4_per_token_kernel(per token, float4向量化į‰ˆæœŦ) +- [X] safe_softmax_f32x4_per_token_kernel(per token) +- [X] safe_softmax_f16_f32_per_token_kernel(per token) +- [X] safe_softmax_f16x2_f32_per_token_kernel(per token) +- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token) - [X] PyTorch bindings @@ -24,25 +27,84 @@ python3 softmax.py 输å‡ē: ```bash --------------------------------------------------------------------------------- - out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms - out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms - out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms --------------------------------------------------------------------------------- - out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms - out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms - out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms --------------------------------------------------------------------------------- - out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms - out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms - out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms - out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms - out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms --------------------------------------------------------------------------------- - out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms - out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms - out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms - out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms - out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms --------------------------------------------------------------------------------- -``` \ No newline at end of file +---------------------------------------------------------------------------------------------------- + N=16384 +---------------------------------------------------------------------------------------------------- + out_f32(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01040053ms + out_f32x4(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01053643ms + out_f32_th: ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.00582504ms +---------------------------------------------------------------------------------------------------- + S=4096, H=256 +---------------------------------------------------------------------------------------------------- + out_f32(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00627208ms + out_f32x4(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], 
time:0.00394082ms + out_f32(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00941491ms + out_f32x4(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00413442ms + out_f32_th(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00602674ms +---------------------------------------------------------------------------------------------------- + out_f16f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00912046ms + out_f16x2f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00522232ms + out_f16x8packf32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00413895ms + out_f16_th(per): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00605321ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + S=4096, H=512 +---------------------------------------------------------------------------------------------------- + out_f32(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01139641ms + out_f32x4(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00515914ms + out_f32(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01834297ms + out_f32x4(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00574923ms + out_f32_th(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00657558ms +---------------------------------------------------------------------------------------------------- + out_f16f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.01782560ms + out_f16x2f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00919509ms + out_f16x8packf32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00415683ms + out_f16_th(per): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00634599ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + S=4096, H=1024 +---------------------------------------------------------------------------------------------------- + out_f32(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.03191805ms + out_f32x4(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.00862813ms + out_f32(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.04873967ms + out_f32x4(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01027441ms + out_f32_th(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01181388ms +---------------------------------------------------------------------------------------------------- + out_f16f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.04671884ms + out_f16x2f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01810408ms + out_f16x8packf32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.00601912ms + out_f16_th(per): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01047063ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + S=4096, H=2048 +---------------------------------------------------------------------------------------------------- + out_f32x4(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.01605988ms + out_f32x4(safe): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], 
time:0.02089310ms + out_f32_th(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.06726241ms +---------------------------------------------------------------------------------------------------- + out_f16x2f32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.04824972ms + out_f16x8packf32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.01086283ms + out_f16_th(per): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.07232165ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + S=4096, H=4096 +---------------------------------------------------------------------------------------------------- + out_f32x4(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18465948ms + out_f32x4(safe): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18565655ms + out_f32_th(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18744922ms +---------------------------------------------------------------------------------------------------- + out_f16x8packf32(safe): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.02254891ms + out_f16_th(per): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.08283138ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + S=4096, H=8192 +---------------------------------------------------------------------------------------------------- + out_f16x8packf32(safe): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19313049ms + out_f16_th(per): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19356799ms +---------------------------------------------------------------------------------------------------- + S=8192, H=8192 +---------------------------------------------------------------------------------------------------- + out_f16x8packf32(safe): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.39828229ms + out_f16_th(per): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.40599036ms +---------------------------------------------------------------------------------------------------- +``` diff --git a/softmax/softmax.cu b/softmax/softmax.cu index c99df618..a1096aab 100644 --- a/softmax/softmax.cu +++ b/softmax/softmax.cu @@ -13,6 +13,9 @@ #define WARP_SIZE 32 #define INT4(value) (reinterpret_cast(&(value))[0]) #define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) // -------------------------------------- FP32 -------------------------------------- // Warp Reduce Sum @@ -100,17 +103,17 @@ __global__ void softmax_f32x4_kernel(float* x, float* y, float* total, int N) { float4 reg_x = FLOAT4(x[idx]); float4 reg_exp; - reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f; - reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f; - reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f; - reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f; + reg_exp.x = (idx + 0 < N) ? expf(reg_x.x) : 0.0f; + reg_exp.y = (idx + 1 < N) ? expf(reg_x.y) : 0.0f; + reg_exp.z = (idx + 2 < N) ? expf(reg_x.z) : 0.0f; + reg_exp.w = (idx + 3 < N) ? 
expf(reg_x.w) : 0.0f; float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w); float exp_sum = block_reduce_sum_f32(exp_val); // get the total sum of all blocks. if (tid == 0) atomicAdd(total, exp_sum); __threadfence(); // grid level memory fence // e^x_i/sum(e^x_0,...,e^x_n-1) - if (idx < N) { + if (idx + 3 < N) { float4 reg_y; reg_y.x = reg_exp.x / (*total); reg_y.y = reg_exp.y / (*total); @@ -145,15 +148,15 @@ __global__ void softmax_f32x4_per_token_kernel(float* x, float* y, int N) { float4 reg_x = FLOAT4(x[idx]); float4 reg_exp; - reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f; - reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f; - reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f; - reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f; + reg_exp.x = (idx + 0 < N) ? expf(reg_x.x) : 0.0f; + reg_exp.y = (idx + 1 < N) ? expf(reg_x.y) : 0.0f; + reg_exp.z = (idx + 2 < N) ? expf(reg_x.z) : 0.0f; + reg_exp.w = (idx + 3 < N) ? expf(reg_x.w) : 0.0f; float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w); float exp_sum = block_reduce_sum_f32(exp_val); // block sum // e^x_i/sum(e^x_0,...,e^x_n-1) - if (idx < N) { + if (idx + 3 < N) { float4 reg_y; reg_y.x = reg_exp.x / (exp_sum); reg_y.y = reg_exp.y / (exp_sum); @@ -183,10 +186,10 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) { const int idx = (blockIdx.x * blockDim.x + tid) * 4; float4 reg_x = FLOAT4(x[idx]); - reg_x.x = (idx < N) ? reg_x.x : -FLT_MAX; - reg_x.y = (idx < N) ? reg_x.y : -FLT_MAX; - reg_x.z = (idx < N) ? reg_x.z : -FLT_MAX; - reg_x.w = (idx < N) ? reg_x.w : -FLT_MAX; + reg_x.x = (idx + 0 < N) ? reg_x.x : -FLT_MAX; + reg_x.y = (idx + 1 < N) ? reg_x.y : -FLT_MAX; + reg_x.z = (idx + 2 < N) ? reg_x.z : -FLT_MAX; + reg_x.w = (idx + 3 < N) ? reg_x.w : -FLT_MAX; float val = reg_x.x; val = fmaxf(val, reg_x.y); val = fmaxf(val, reg_x.z); @@ -194,15 +197,15 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) { float max_val = block_reduce_max_f32(val); // block max float4 reg_exp; - reg_exp.x = (idx < N) ? expf(reg_x.x - max_val) : 0.0f; - reg_exp.y = (idx < N) ? expf(reg_x.y - max_val) : 0.0f; - reg_exp.z = (idx < N) ? expf(reg_x.z - max_val) : 0.0f; - reg_exp.w = (idx < N) ? expf(reg_x.w - max_val) : 0.0f; + reg_exp.x = (idx + 0 < N) ? expf(reg_x.x - max_val) : 0.0f; + reg_exp.y = (idx + 1 < N) ? expf(reg_x.y - max_val) : 0.0f; + reg_exp.z = (idx + 2 < N) ? expf(reg_x.z - max_val) : 0.0f; + reg_exp.w = (idx + 3 < N) ? expf(reg_x.w - max_val) : 0.0f; float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w); float exp_sum = block_reduce_sum_f32(exp_val); // block sum // e^x_i/sum(e^x_0,...,e^x_n-1) - if (idx < N) { + if (idx + 3 < N) { float4 reg_y; reg_y.x = reg_exp.x / (exp_sum); reg_y.y = reg_exp.y / (exp_sum); @@ -212,72 +215,131 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) { } } +template +__global__ void safe_softmax_f16_f32_per_token_kernel(half* x, half* y, int N) { + const int tid = threadIdx.x; + const int idx = blockIdx.x * blockDim.x + tid; + + float val = (idx < N) ? __half2float(x[idx]) : -FLT_MAX; + float max_val = block_reduce_max_f32(val); // block max + float exp_val = (idx < N) ? 
expf(val - max_val) : 0.0f; + float exp_sum = block_reduce_sum_f32(exp_val); // block sum + // e^x_i/sum(e^x_0,...,e^x_n-1) + if (idx < N) y[idx] = __float2half_rn(exp_val / exp_sum); +} + +template +__global__ void safe_softmax_f16x2_f32_per_token_kernel(half* x, half* y, int N) { + const int tid = threadIdx.x; + const int idx = (blockIdx.x * blockDim.x + tid) * 2; + + float2 reg_x = __half22float2(HALF2(x[idx])); + float max_val = -FLT_MAX; + max_val = ((idx + 0) < N) ? fmaxf(reg_x.x, max_val): -FLT_MAX; + max_val = ((idx + 1) < N) ? fmaxf(reg_x.y, max_val): -FLT_MAX; + max_val = block_reduce_max_f32(max_val); // block max + + float2 reg_exp; + reg_exp.x = ((idx + 0) < N) ? expf(reg_x.x - max_val) : 0.0f; + reg_exp.y = ((idx + 1) < N) ? expf(reg_x.y - max_val) : 0.0f; + + float exp_val = reg_exp.x + reg_exp.y; + float exp_sum = block_reduce_sum_f32(exp_val); // block sum + + float2 reg_y; + reg_y.x = reg_exp.x / (exp_sum); + reg_y.y = reg_exp.y / (exp_sum); + + // e^x_i/sum(e^x_0,...,e^x_n-1) + if ((idx + 1) < N) HALF2(y[idx]) = __float22half2_rn(reg_y); +} + +template +__global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, int N) { + const int tid = threadIdx.x; + const int idx = (blockIdx.x * blockDim.x + tid) * 8; + // temporary register(memory), .local space in ptx, addressable + half pack_x[8], pack_y[8]; // 8x16 bits=128 bits. + // reinterpret as float4 and load 128 bits in 1 memory issue. + LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits + + float max_val = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 8; ++i) { + max_val = fmaxf(__half2float(pack_x[i]), max_val); + } + max_val = block_reduce_max_f32(max_val); // block max + + float exp_sum = 0.0f; + #pragma unroll + for (int i = 0; i < 8; ++i) { + float exp_val = expf(__half2float(pack_x[i]) - max_val); + exp_sum += (((idx + i) < N) ? exp_val : 0.0f); + } + exp_sum = block_reduce_sum_f32(exp_sum); // block sum + + #pragma unroll + for (int i = 0; i < 8; ++i) { + // e^x_i/sum(e^x_0,...,e^x_n-1) + float exp_val = expf(__half2float(pack_x[i]) - max_val); + pack_y[i] = __float2half_rn(exp_val / exp_sum); + } + // reinterpret as float4 and store 128 bits in 1 memory issue. 
+ if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); } + // TODO: support non 8-multiple K here +} + // --------------------- PyTorch bindings for custom kernel ----------------------- #define STRINGFY(str) #str #define TORCH_BINDING_COMMON_EXTENSION(func) \ m.def(STRINGFY(func), &func, STRINGFY(func)); -// naive softmax -#define TORCH_BINDING_SOFTMAX(packed_type, th_type, element_type, n_elements) \ -torch::Tensor softmax_##packed_type(torch::Tensor x) { \ - if((x.options().dtype() != (th_type))) { \ - std::cout << "x Tensor Info:" << x.options() << std::endl; \ - throw std::runtime_error("values must be "#th_type); \ - } \ - auto options = torch::TensorOptions().dtype((th_type)).device( \ - torch::kCUDA, 0); \ - const int N = x.size(0); \ - auto y = torch::zeros({N}, options); \ - auto total = torch::zeros({1}, options); \ - static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \ - const int NUM_BLOCKS = (N + 256 - 1) / 256; \ - dim3 block(NUM_THREADS_PER_BLOCK); \ - dim3 grid(NUM_BLOCKS); \ - softmax_##packed_type##_kernel<<>>( \ - reinterpret_cast(x.data_ptr()), \ - reinterpret_cast(y.data_ptr()), \ - reinterpret_cast(total.data_ptr()), N); \ - return y; \ +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ +assert((T1).dim() == (T2).dim()); \ +for (int i = 0; i < (T1).dim(); ++i) { \ + if ((T2).size(i) != (T1).size(i)) { \ + throw std::runtime_error("Tensor size mismatch!"); \ + } \ } -#define TORCH_BINDING_SOFTMAX_V2(packed_type, th_type, element_type, n_elements) \ -void softmax_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \ - if((x.options().dtype() != (th_type))) { \ - std::cout << "x Tensor Info:" << x.options() << std::endl; \ - throw std::runtime_error("values must be "#th_type); \ - } \ - auto options = torch::TensorOptions().dtype((th_type)).device( \ - torch::kCUDA, 0); \ +// grid memory fence +#define TORCH_BINDING_SOFTMAX(packed_type, th_type, element_type, n_elements) \ +void softmax_##packed_type(torch::Tensor x, torch::Tensor y) { \ + CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \ + CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \ + auto options = torch::TensorOptions().dtype((th_type)).device(torch::kCUDA, 0);\ const int N = x.size(0); \ - if (y.size(0) != N) {throw std::runtime_error("y size mismatch!"); } \ + CHECK_TORCH_TENSOR_SHAPE(x, y) \ auto total = torch::zeros({1}, options); \ - static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \ - const int NUM_BLOCKS = (N + 256 - 1) / 256; \ - dim3 block(NUM_THREADS_PER_BLOCK); \ - dim3 grid(NUM_BLOCKS); \ - softmax_##packed_type##_kernel<<>>( \ + dim3 block(256); \ + dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \ + softmax_##packed_type##_kernel<256><<>>( \ reinterpret_cast(x.data_ptr()), \ reinterpret_cast(y.data_ptr()), \ reinterpret_cast(total.data_ptr()), N); \ } -TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1) -TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4) -TORCH_BINDING_SOFTMAX_V2(f32, torch::kFloat32, float, 1) -TORCH_BINDING_SOFTMAX_V2(f32x4, torch::kFloat32, float, 4) - // softmax per token -#define LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(_H_) \ -softmax_f32_per_token_kernel<(_H_)><<>>( \ - reinterpret_cast(x.data_ptr()), \ - reinterpret_cast(y.data_ptr()), \ +#define LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \ +softmax_f32_per_token_kernel<(H)><<>>( \ + 
reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ N); #define DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \ dim3 block((H)); \ - dim3 grid((S)); \ + dim3 grid((S)); \ switch ((H)) \ { \ + case 32: \ + LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \ + break; \ case 64: \ LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \ break; \ @@ -299,44 +361,54 @@ softmax_f32_per_token_kernel<(_H_)><<>>( \ break; \ } -#define LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(_H_) \ -softmax_f32x4_per_token_kernel<(_H_)/4><<< \ - grid, block>>>( \ - reinterpret_cast(x.data_ptr()), \ - reinterpret_cast(y.data_ptr()), \ +#define LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \ +softmax_f32x4_per_token_kernel<(H)/4><<< \ + grid, block>>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ N); -#define DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \ - dim3 block((H)/4); \ - dim3 grid((S)); \ - switch ((H)) \ - { \ - case 64: \ - LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \ - break; \ - case 128: \ - LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(128) \ - break; \ - case 256: \ - LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(256) \ - break; \ - case 512: \ - LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(512) \ - break; \ - case 1024: \ - LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \ - break; \ - default: \ - throw std::runtime_error( \ - "only support H: 64/128/256/512/1024"); \ - break; \ +#define DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \ + const int NT = (H)/4; \ + dim3 block(NT); \ + dim3 grid((S)); \ + switch (H) \ + { \ + case 32: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(32) \ + break; \ + case 64: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \ + break; \ + case 128: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(128) \ + break; \ + case 256: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(256) \ + break; \ + case 512: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(512) \ + break; \ + case 1024: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \ + break; \ + case 2048: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(2048) \ + break; \ + case 4096: \ + LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(4096) \ + break; \ + default: \ + throw std::runtime_error( \ + "only support H: 64/128/.../1024*4"); \ + break; \ } // safe softmax per token -#define LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(_H_) \ -safe_softmax_f32_per_token_kernel<(_H_)><<>>( \ - reinterpret_cast(x.data_ptr()), \ - reinterpret_cast(y.data_ptr()), \ +#define LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \ +safe_softmax_f32_per_token_kernel<(H)><<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ N); #define DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \ @@ -344,6 +416,9 @@ safe_softmax_f32_per_token_kernel<(_H_)><<>>( \ dim3 grid((S)); \ switch ((H)) \ { \ + case 32: \ + LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \ + break; \ case 64: \ LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \ break; \ @@ -365,18 +440,22 @@ safe_softmax_f32_per_token_kernel<(_H_)><<>>( \ break; \ } -#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(_H_) \ -safe_softmax_f32x4_per_token_kernel<(_H_)/4><<< \ - grid, block>>>( \ - reinterpret_cast(x.data_ptr()), \ - reinterpret_cast(y.data_ptr()), \ +#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \ +safe_softmax_f32x4_per_token_kernel<(H)/4><<< \ + grid, block>>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ N); #define DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \ - dim3 block((H)/4); \ + const int NT = (H)/4; \ + dim3 block(NT); \ dim3 grid((S)); \ - switch ((H)) \ + switch (H) \ { \ + 
case 32: \ + LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(32) \ + break; \ case 64: \ LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \ break; \ @@ -392,154 +471,221 @@ safe_softmax_f32x4_per_token_kernel<(_H_)/4><<< \ case 1024: \ LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \ break; \ + case 2048: \ + LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(2048) \ + break; \ + case 4096: \ + LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(4096) \ + break; \ default: \ throw std::runtime_error( \ - "only support H: 64/128/256/512/1024"); \ + "only support H: 64/128/.../1024*4"); \ break; \ } -// softmax per token -torch::Tensor softmax_f32_per_token(torch::Tensor x) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } - auto options = torch::TensorOptions().dtype((torch::kFloat32)).device( - torch::kCUDA, 0); - const int S = x.size(0); // seqlens - const int H = x.size(1); // head size/kv_len - const int N = S * H; - auto y = torch::zeros({S, H}, options).contiguous(); // [S,H] +#define LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(H) \ +safe_softmax_f16_f32_per_token_kernel<(H)><<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ + N); - DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) - return y; -} +#define DISPATCH_SATE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(S, H) \ + dim3 block((H)); \ + dim3 grid((S)); \ + switch ((H)) \ + { \ + case 32: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(32) \ + break; \ + case 64: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(64) \ + break; \ + case 128: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(128) \ + break; \ + case 256: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(256) \ + break; \ + case 512: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(512) \ + break; \ + case 1024: \ + LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(1024) \ + break; \ + default: \ + throw std::runtime_error( \ + "only support H: 64/128/256/512/1024"); \ + break; \ + } + +#define LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(H) \ +safe_softmax_f16x2_f32_per_token_kernel<(H)/2><<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ + N); + +#define DISPATCH_SATE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(S, H) \ + const int NT = (H)/2; \ + dim3 block(NT); \ + dim3 grid((S)); \ + switch (H) \ + { \ + case 32: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(32) \ + break; \ + case 64: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(64) \ + break; \ + case 128: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(128) \ + break; \ + case 256: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(256) \ + break; \ + case 512: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(512) \ + break; \ + case 1024: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(1024) \ + break; \ + case 2048: \ + LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(2048) \ + break; \ + default: \ + throw std::runtime_error( \ + "only support H: 64/128/.../1024*2"); \ + break; \ + } -// no copy for y Tensor -void softmax_f32_per_token_v2(torch::Tensor x, torch::Tensor y) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } +#define LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(H) \ +safe_softmax_f16x8_pack_f32_per_token_kernel<(H)/8><<>>( \ + reinterpret_cast(x.data_ptr()), \ + reinterpret_cast(y.data_ptr()), \ + N); + 
+#define DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H) \ + const int NT = (H)/8; \ + dim3 block(NT); \ + dim3 grid((S)); \ + switch (H) \ + { \ + case 32: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(32) \ + break; \ + case 64: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(64) \ + break; \ + case 128: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(128) \ + break; \ + case 256: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(256) \ + break; \ + case 512: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(512) \ + break; \ + case 1024: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(1024) \ + break; \ + case 2048: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(2048) \ + break; \ + case 4096: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(4096) \ + break; \ + case 8192: \ + LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(8192) \ + break; \ + default: \ + throw std::runtime_error( \ + "only support H: 64/128/.../1024*8"); \ + break; \ + } + +// per token fp32 +void softmax_f32_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens const int H = x.size(1); // head size/kv_len const int N = S * H; - if ((y.size(0) != S) || (y.size(1) != H)) { - throw std::runtime_error("y Tensor size mismatch!"); - } - DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) } -torch::Tensor softmax_f32x4_per_token(torch::Tensor x) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } - auto options = torch::TensorOptions().dtype((torch::kFloat32)).device( - torch::kCUDA, 0); +void softmax_f32x4_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens - const int H = x.size(1); // head size/kv_len + const int H = x.size(1); // head size/kv_len const int N = S * H; - auto y = torch::zeros({S, H}, options).contiguous(); // [S,H] - - DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) - return y; + DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) } -// no copy for y Tensor -void softmax_f32x4_per_token_v2(torch::Tensor x, torch::Tensor y) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } +void safe_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens const int H = x.size(1); // head size/kv_len const int N = S * H; - if ((y.size(0) != S) || (y.size(1) != H)) { - throw std::runtime_error("y Tensor size mismatch!"); - } - - DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) + DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) } -// safe_softmax per token -torch::Tensor safe_softmax_f32_per_token(torch::Tensor x) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } - auto options = torch::TensorOptions().dtype((torch::kFloat32)).device( - torch::kCUDA, 0); +void safe_softmax_f32x4_per_token(torch::Tensor x, torch::Tensor 
y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens const int H = x.size(1); // head size/kv_len const int N = S * H; - auto y = torch::zeros({S, H}, options).contiguous(); // [S,H] - - DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) - return y; + DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) } -// no copy for y Tensor -void safe_softmax_f32_per_token_v2(torch::Tensor x, torch::Tensor y) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } +// per token fp16 +void safe_softmax_f16_f32_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens const int H = x.size(1); // head size/kv_len const int N = S * H; - if ((y.size(0) != S) || (y.size(1) != H)) { - throw std::runtime_error("y Tensor size mismatch!"); - } - - DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) + DISPATCH_SATE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(S, H) } -torch::Tensor safe_softmax_f32x4_per_token(torch::Tensor x) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } - auto options = torch::TensorOptions().dtype((torch::kFloat32)).device( - torch::kCUDA, 0); +void safe_softmax_f16x2_f32_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens - const int H = x.size(1); // head size/kv_len + const int H = x.size(1); // head size/kv_len const int N = S * H; - auto y = torch::zeros({S, H}, options).contiguous(); // [S,H] - - DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) - return y; + DISPATCH_SATE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(S, H) } -// no copy for y Tensor -void safe_softmax_f32x4_per_token_v2(torch::Tensor x, torch::Tensor y) { - if((x.options().dtype() != (torch::kFloat32))) { - std::cout << "x Tensor Info:" << x.options() << std::endl; - throw std::runtime_error("values must be torch::kFloat32"); - } +void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) { + CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf) + CHECK_TORCH_TENSOR_SHAPE(x, y) const int S = x.size(0); // seqlens const int H = x.size(1); // head size/kv_len const int N = S * H; - if ((y.size(0) != S) || (y.size(1) != H)) { - throw std::runtime_error("y Tensor size mismatch!"); - } - - DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) + DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H) } +// grid memory fence fp32 +TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1) +TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(softmax_f32) - TORCH_BINDING_COMMON_EXTENSION(softmax_f32_v2) TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4) - TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_v2) TORCH_BINDING_COMMON_EXTENSION(softmax_f32_per_token) - TORCH_BINDING_COMMON_EXTENSION(softmax_f32_per_token_v2) TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_per_token) - TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_per_token_v2) 
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32_per_token) - TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32_per_token_v2) TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32x4_per_token) - TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32x4_per_token_v2) + TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token) + TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token) + TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token) } diff --git a/softmax/softmax.py b/softmax/softmax.py index d5bc20f6..b2769161 100644 --- a/softmax/softmax.py +++ b/softmax/softmax.py @@ -50,40 +50,146 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, out_info = f"out_{tag}" out_val = out.flatten().detach().cpu().numpy().tolist()[:3] out_val = [round(v, 8) for v in out_val] - print(f"{out_info:>20}: {out_val}, time:{mean_time:.8f}ms") + out_val = [f"{v:<12}" for v in out_val] + print(f"{out_info:>24}: {out_val}, time:{mean_time:.8f}ms") if show_all: print(out) return out, mean_time +# grid memory fence +print("-" * 100) +N = 128 * 128 +print(" " * 45 + f"N={N}") +print("-" * 100) +x = torch.randn((N)).cuda().float() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32, x, "f32(fence)", out) +run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out) +run_benchmark(partial(torch.softmax, dim=0, out=out), x, "f32_th") + +# per token softmax +print("-" * 100) +S, H = 4096, 256 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) +run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) +run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out) +run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") + +print("-" * 100) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100) -print("-" * 80) -N_ELEMENTS = 256*48 -x = torch.randn((N_ELEMENTS)).cuda().float() -run_benchmark(lib.softmax_f32, x, "f32") -run_benchmark(lib.softmax_f32x4, x, "f32x4") -run_benchmark(partial(torch.softmax, dim=0), x, "f32_th") +# per token softmax +print("-" * 100) +S, H = 4096, 512 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) +run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) +run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out) +run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") + +print("-" * 100) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) 
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100) + +# per token softmax +print("-" * 100) +S, H = 4096, 1024 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out) +run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) +run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out) +run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") -print("-" * 80) -# v2: no copy for out Tensor +print("-" * 100) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100) + +# per token softmax +print("-" * 100) +S, H = 4096, 2048 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) +run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") + +print("-" * 100) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16) +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100) + +# per token softmax +print("-" * 100) +S, H = 4096, 4096 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() +out = torch.zeros_like(x).cuda().float().contiguous() +run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out) +run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out) +run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)") + +print("-" * 100) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100) + +# per token softmax +print("-" * 100) +S, H = 4096, 8192 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) +x = torch.randn((S, H)).cuda().float().contiguous() out = torch.zeros_like(x).cuda().float().contiguous() -run_benchmark(lib.softmax_f32_v2, x, "f32(v2)", out) -run_benchmark(lib.softmax_f32x4_v2, x, "f32x4(v2)", out) -run_benchmark(partial(torch.softmax, dim=0, out=out), x, "f32_th(v2)") +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") -print("-" * 80) -S, H = 1024, 512 +# per token softmax +print("-" * 100) +S, H = 8192, 8192 +print(" " * 45 + f"S={S}, H={H}") +print("-" * 100) x = torch.randn((S, H)).cuda().float().contiguous() 
-run_benchmark(lib.softmax_f32_per_token, x, "f32(per)") -run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)") -run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)") -run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)") -run_benchmark(partial(torch.softmax, dim=1), x, "f32_th(per)") - -print("-" * 80) -# v2: no copy for out Tensor out = torch.zeros_like(x).cuda().float().contiguous() -run_benchmark(lib.softmax_f32_per_token_v2, x, "f32(per v2)", out) -run_benchmark(lib.softmax_f32x4_per_token_v2, x, "f32x4(per v2)", out) -run_benchmark(lib.safe_softmax_f32_per_token_v2, x, "f32(safe v2)", out) -run_benchmark(lib.safe_softmax_f32x4_per_token_v2, x, "f32x4(safe v2)", out) -run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per v2)") -print("-" * 80) +x_f16 = x.half().contiguous() +out_f16 = out.half().contiguous() +run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16) +run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)") +print("-" * 100)
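For a quick end-to-end check of the new fp16 per-token bindings, the sketch below builds the extension on the fly and compares `safe_softmax_f16x8_pack_f32_per_token` against `torch.softmax`. This is a minimal, hypothetical smoke test, not part of this PR: the module name `softmax_lib`, the nvcc flags, and the tolerance are placeholder assumptions; only the exported function name and the `(x, y)` out-parameter convention come from the bindings above.

```python
# Minimal smoke test (sketch): build softmax/softmax.cu as an ad-hoc extension
# and compare the f16x8-pack kernel against torch.softmax on one shape.
import torch
from torch.utils.cpp_extension import load

# Assumed build step; module name and nvcc flags are placeholders.
lib = load(name="softmax_lib",
           sources=["softmax/softmax.cu"],
           extra_cuda_cflags=["-O3", "--use_fast_math"],
           verbose=False)

S, H = 4096, 1024  # H must be a multiple of 8 and one of the dispatched sizes
x = torch.randn(S, H, device="cuda", dtype=torch.half).contiguous()
y = torch.zeros_like(x)

lib.safe_softmax_f16x8_pack_f32_per_token(x, y)  # writes the result into y, no extra copy
ref = torch.softmax(x.float(), dim=1).half()

# fp16 storage with fp32 accumulation should stay close to the fp32 reference.
print("max abs err:", (y.float() - ref.float()).abs().max().item())
assert torch.allclose(y.float(), ref.float(), atol=1e-3)
```

Per-token rows are independent, so a single (S, H) shape is enough to exercise the block-level max/sum reductions; other H values simply select different branches of the dispatch macros shown above.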