diff --git a/README.md b/README.md
index 0fb9d42b..afd23a68 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
-📖 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
+📖 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[B]eginners**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
@@ -17,7 +17,7 @@
- / = not supported now.
- ✔️ = known to work and already supported now.
- ❔ = in my plan, but not coming soon; maybe a few weeks later.
-- **workflow**: custom **CUDA** kernel impl -> **Torch** python binding -> Run tests.
+- **workflow**: custom **CUDA** kernel impl -> **PyTorch** python binding -> Run tests.
|📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level |
|:---|:---|:---|:---|:---|
@@ -75,6 +75,9 @@
| ✔️ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
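
The workflow bullet above (custom **CUDA** kernel impl -> **PyTorch** python binding -> run tests) is the pattern every entry in this table follows. The sketch below shows that pattern end to end with a hypothetical elementwise kernel `scale_f32`; the kernel, its name, and the `alpha` parameter are illustrative assumptions and not taken from this repo, but the binding style (raw `data_ptr()` casts, `PYBIND11_MODULE`, pre-allocated output tensor) mirrors the code changed in this patch.

```cuda
#include <torch/extension.h>
#include <cuda_runtime.h>

// Hypothetical elementwise kernel: y[i] = alpha * x[i].
// Illustrative only -- not one of the kernels listed in the table above.
__global__ void scale_f32_kernel(const float* x, float* y, float alpha, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = alpha * x[idx];
}

// PyTorch binding: validate dtype, derive the launch shape, launch the kernel.
// The output tensor is pre-allocated by the caller, as in the bindings below.
void scale_f32(torch::Tensor x, torch::Tensor y, float alpha) {
  if (x.options().dtype() != torch::kFloat32) {
    throw std::runtime_error("values must be torch::kFloat32");
  }
  const int N = x.numel();
  dim3 block(256);
  dim3 grid((N + 255) / 256);
  scale_f32_kernel<<<grid, block>>>(
      reinterpret_cast<float*>(x.data_ptr()),
      reinterpret_cast<float*>(y.data_ptr()), alpha, N);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_f32", &scale_f32, "scale_f32");
}
```

Built with `torch.utils.cpp_extension.load(...)`, such a binding can then be exercised from a small Python script the way `softmax.py` does further down in this patch.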
diff --git a/layer-norm/layer_norm.cu b/layer-norm/layer_norm.cu
index eda644de..14efa9f0 100644
--- a/layer-norm/layer_norm.cu
+++ b/layer-norm/layer_norm.cu
@@ -433,9 +433,12 @@ if(((T).options().dtype() != (th_type))) { \
throw std::runtime_error("values must be "#th_type); \
}
-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
- throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+assert((T1).dim() == (T2).dim()); \
+for (int i = 0; i < (T1).dim(); ++i) { \
+ if ((T2).size(i) != (T1).size(i)) { \
+ throw std::runtime_error("Tensor size mismatch!"); \
+ } \
}
// fp32
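
One note on the rewritten `CHECK_TORCH_TENSOR_SHAPE` macro above (the same change recurs in `rms_norm.cu` and `softmax.cu` below): it expands to an `assert` statement followed by a `for` loop, so it is only safe in contexts where a multi-statement expansion is acceptable (it would misbehave inside an unbraced `if`/`else`), and the rank check disappears under `NDEBUG`. A conventional hardening, shown here as a sketch rather than as the patch's code, wraps the body in `do { ... } while (0)` and throws on a rank mismatch as well:

```cpp
// Sketch only: same dimension/size check as the patch, wrapped in
// do { ... } while (0) so the macro behaves as a single statement,
// and throwing (rather than assert) on a rank mismatch.
#define CHECK_TORCH_TENSOR_SHAPE(T1, T2)                      \
do {                                                          \
  if ((T1).dim() != (T2).dim()) {                             \
    throw std::runtime_error("Tensor dim mismatch!");         \
  }                                                           \
  for (int i = 0; i < (T1).dim(); ++i) {                      \
    if ((T2).size(i) != (T1).size(i)) {                       \
      throw std::runtime_error("Tensor size mismatch!");      \
    }                                                         \
  }                                                           \
} while (0)
```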
diff --git a/rms-norm/rms_norm.cu b/rms-norm/rms_norm.cu
index b32c1aba..b3faf206 100644
--- a/rms-norm/rms_norm.cu
+++ b/rms-norm/rms_norm.cu
@@ -382,9 +382,12 @@ if(((T).options().dtype() != (th_type))) { \
throw std::runtime_error("values must be "#th_type); \
}
-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
- throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+assert((T1).dim() == (T2).dim()); \
+for (int i = 0; i < (T1).dim(); ++i) { \
+ if ((T2).size(i) != (T1).size(i)) { \
+ throw std::runtime_error("Tensor size mismatch!"); \
+ } \
}
#define LANUCH_RMS_NORM_F32_KERNEL(K) \
diff --git a/softmax/README.md b/softmax/README.md
index 702313cb..1fb6d451 100755
--- a/softmax/README.md
+++ b/softmax/README.md
@@ -5,11 +5,14 @@
Contains the following:
- [X] softmax_f32_kernel (grid level memory fence)
-- [X] softmax_f32x4_kernel(grid level memory fence, float4 vectorized version)
+- [X] softmax_f32x4_kernel(grid level memory fence)
- [X] softmax_f32_per_token_kernel(per token)
-- [X] softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] softmax_f32x4_per_token_kernel(per token)
- [X] safe_softmax_f32_per_token_kernel(per token)
-- [X] safe_softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] safe_softmax_f32x4_per_token_kernel(per token)
+- [X] safe_softmax_f16_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
- [X] PyTorch bindings
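
For reference, the only difference between the plain and the "safe" kernels in the checklist above is the row-max subtraction. Stated generically (standard notation, not copied from this repo), for one token $x \in \mathbb{R}^{H}$:

```latex
% naive softmax: e^{x_i} can overflow once x_i is large
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{H} e^{x_j}}

% safe softmax: subtract the row max m first; the result is identical
% because the factor e^{-m} cancels between numerator and denominator
m = \max_{1 \le j \le H} x_j, \qquad
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{H} e^{x_j - m}}
```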
@@ -24,25 +27,84 @@ python3 softmax.py
Output:
```bash
---------------------------------------------------------------------------------
- out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms
- out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms
- out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms
---------------------------------------------------------------------------------
- out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms
- out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms
- out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms
---------------------------------------------------------------------------------
- out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms
- out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms
- out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms
- out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms
- out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms
---------------------------------------------------------------------------------
- out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms
- out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms
- out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms
- out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms
- out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms
---------------------------------------------------------------------------------
-```
\ No newline at end of file
+----------------------------------------------------------------------------------------------------
+ N=16384
+----------------------------------------------------------------------------------------------------
+ out_f32(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01040053ms
+ out_f32x4(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01053643ms
+ out_f32_th: ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.00582504ms
+----------------------------------------------------------------------------------------------------
+ S=4096, H=256
+----------------------------------------------------------------------------------------------------
+ out_f32(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00627208ms
+ out_f32x4(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00394082ms
+ out_f32(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00941491ms
+ out_f32x4(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00413442ms
+ out_f32_th(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00602674ms
+----------------------------------------------------------------------------------------------------
+ out_f16f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00912046ms
+ out_f16x2f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00522232ms
+ out_f16x8packf32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00413895ms
+ out_f16_th(per): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00605321ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+ S=4096, H=512
+----------------------------------------------------------------------------------------------------
+ out_f32(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01139641ms
+ out_f32x4(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00515914ms
+ out_f32(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01834297ms
+ out_f32x4(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00574923ms
+ out_f32_th(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00657558ms
+----------------------------------------------------------------------------------------------------
+ out_f16f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.01782560ms
+ out_f16x2f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00919509ms
+ out_f16x8packf32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00415683ms
+ out_f16_th(per): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00634599ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+ S=4096, H=1024
+----------------------------------------------------------------------------------------------------
+ out_f32(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.03191805ms
+ out_f32x4(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.00862813ms
+ out_f32(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.04873967ms
+ out_f32x4(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01027441ms
+ out_f32_th(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01181388ms
+----------------------------------------------------------------------------------------------------
+ out_f16f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.04671884ms
+ out_f16x2f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01810408ms
+ out_f16x8packf32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.00601912ms
+ out_f16_th(per): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01047063ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+ S=4096, H=2048
+----------------------------------------------------------------------------------------------------
+ out_f32x4(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.01605988ms
+ out_f32x4(safe): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.02089310ms
+ out_f32_th(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.06726241ms
+----------------------------------------------------------------------------------------------------
+ out_f16x2f32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.04824972ms
+ out_f16x8packf32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.01086283ms
+ out_f16_th(per): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.07232165ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+ S=4096, H=4096
+----------------------------------------------------------------------------------------------------
+ out_f32x4(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18465948ms
+ out_f32x4(safe): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18565655ms
+ out_f32_th(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18744922ms
+----------------------------------------------------------------------------------------------------
+ out_f16x8packf32(safe): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.02254891ms
+ out_f16_th(per): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.08283138ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+ S=4096, H=8192
+----------------------------------------------------------------------------------------------------
+ out_f16x8packf32(safe): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19313049ms
+ out_f16_th(per): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19356799ms
+----------------------------------------------------------------------------------------------------
+ S=8192, H=8192
+----------------------------------------------------------------------------------------------------
+ out_f16x8packf32(safe): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.39828229ms
+ out_f16_th(per): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.40599036ms
+----------------------------------------------------------------------------------------------------
+```
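
A note on why the larger-H tables above list only some kernels: each per-token kernel launches one block per row with `H / n` threads (`n` = elements handled per thread), and a CUDA block is capped at 1024 threads, so the f32 kernel tops out at H=1024, f32x4 at H=4096, f16x2 at H=2048, and the f16x8 pack variant at H=8192. The relationship, written as a hedged sketch (the hypothetical helper `per_token_launch_shape` is not part of this repo; the actual code uses the dispatch macros in `softmax.cu`):

```cuda
#include <cuda_runtime.h>
#include <cassert>

// Sketch of the per-token launch shape used by the benchmarks above:
// one block per row (token), each thread handling elems_per_thread
// contiguous elements, so H / elems_per_thread must not exceed 1024.
void per_token_launch_shape(int S, int H, int elems_per_thread,
                            dim3* grid, dim3* block) {
  assert(H % elems_per_thread == 0);
  assert(H / elems_per_thread <= 1024);  // CUDA block-size limit
  *block = dim3(H / elems_per_thread);   // threads per row
  *grid  = dim3(S);                      // one block per row
}
```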
diff --git a/softmax/softmax.cu b/softmax/softmax.cu
index c99df618..a1096aab 100644
--- a/softmax/softmax.cu
+++ b/softmax/softmax.cu
@@ -13,6 +13,9 @@
#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
@@ -100,17 +103,17 @@ __global__ void softmax_f32x4_kernel(float* x, float* y, float* total, int N) {
float4 reg_x = FLOAT4(x[idx]);
float4 reg_exp;
- reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f;
- reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f;
- reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f;
- reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f;
+ reg_exp.x = (idx + 0 < N) ? expf(reg_x.x) : 0.0f;
+ reg_exp.y = (idx + 1 < N) ? expf(reg_x.y) : 0.0f;
+ reg_exp.z = (idx + 2 < N) ? expf(reg_x.z) : 0.0f;
+ reg_exp.w = (idx + 3 < N) ? expf(reg_x.w) : 0.0f;
float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
  float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val);
// get the total sum of all blocks.
if (tid == 0) atomicAdd(total, exp_sum);
__threadfence(); // grid level memory fence
// e^x_i/sum(e^x_0,...,e^x_n-1)
- if (idx < N) {
+ if (idx + 3 < N) {
float4 reg_y;
reg_y.x = reg_exp.x / (*total);
reg_y.y = reg_exp.y / (*total);
@@ -145,15 +148,15 @@ __global__ void softmax_f32x4_per_token_kernel(float* x, float* y, int N) {
float4 reg_x = FLOAT4(x[idx]);
float4 reg_exp;
- reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f;
- reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f;
- reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f;
- reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f;
+ reg_exp.x = (idx + 0 < N) ? expf(reg_x.x) : 0.0f;
+ reg_exp.y = (idx + 1 < N) ? expf(reg_x.y) : 0.0f;
+ reg_exp.z = (idx + 2 < N) ? expf(reg_x.z) : 0.0f;
+ reg_exp.w = (idx + 3 < N) ? expf(reg_x.w) : 0.0f;
float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
  float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val); // block sum
// e^x_i/sum(e^x_0,...,e^x_n-1)
- if (idx < N) {
+ if (idx + 3 < N) {
float4 reg_y;
reg_y.x = reg_exp.x / (exp_sum);
reg_y.y = reg_exp.y / (exp_sum);
@@ -183,10 +186,10 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) {
const int idx = (blockIdx.x * blockDim.x + tid) * 4;
float4 reg_x = FLOAT4(x[idx]);
- reg_x.x = (idx < N) ? reg_x.x : -FLT_MAX;
- reg_x.y = (idx < N) ? reg_x.y : -FLT_MAX;
- reg_x.z = (idx < N) ? reg_x.z : -FLT_MAX;
- reg_x.w = (idx < N) ? reg_x.w : -FLT_MAX;
+ reg_x.x = (idx + 0 < N) ? reg_x.x : -FLT_MAX;
+ reg_x.y = (idx + 1 < N) ? reg_x.y : -FLT_MAX;
+ reg_x.z = (idx + 2 < N) ? reg_x.z : -FLT_MAX;
+ reg_x.w = (idx + 3 < N) ? reg_x.w : -FLT_MAX;
float val = reg_x.x;
val = fmaxf(val, reg_x.y);
val = fmaxf(val, reg_x.z);
@@ -194,15 +197,15 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) {
  float max_val = block_reduce_max_f32<NUM_THREADS>(val); // block max
float4 reg_exp;
- reg_exp.x = (idx < N) ? expf(reg_x.x - max_val) : 0.0f;
- reg_exp.y = (idx < N) ? expf(reg_x.y - max_val) : 0.0f;
- reg_exp.z = (idx < N) ? expf(reg_x.z - max_val) : 0.0f;
- reg_exp.w = (idx < N) ? expf(reg_x.w - max_val) : 0.0f;
+ reg_exp.x = (idx + 0 < N) ? expf(reg_x.x - max_val) : 0.0f;
+ reg_exp.y = (idx + 1 < N) ? expf(reg_x.y - max_val) : 0.0f;
+ reg_exp.z = (idx + 2 < N) ? expf(reg_x.z - max_val) : 0.0f;
+ reg_exp.w = (idx + 3 < N) ? expf(reg_x.w - max_val) : 0.0f;
float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
  float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val); // block sum
// e^x_i/sum(e^x_0,...,e^x_n-1)
- if (idx < N) {
+ if (idx + 3 < N) {
float4 reg_y;
reg_y.x = reg_exp.x / (exp_sum);
reg_y.y = reg_exp.y / (exp_sum);
@@ -212,72 +215,131 @@ __global__ void safe_softmax_f32x4_per_token_kernel(float* x, float* y, int N) {
}
}
+template<const int NUM_THREADS = 256>
+__global__ void safe_softmax_f16_f32_per_token_kernel(half* x, half* y, int N) {
+ const int tid = threadIdx.x;
+ const int idx = blockIdx.x * blockDim.x + tid;
+
+ float val = (idx < N) ? __half2float(x[idx]) : -FLT_MAX;
+  float max_val = block_reduce_max_f32<NUM_THREADS>(val); // block max
+  float exp_val = (idx < N) ? expf(val - max_val) : 0.0f;
+  float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val); // block sum
+ // e^x_i/sum(e^x_0,...,e^x_n-1)
+ if (idx < N) y[idx] = __float2half_rn(exp_val / exp_sum);
+}
+
+template<const int NUM_THREADS = 256>
+__global__ void safe_softmax_f16x2_f32_per_token_kernel(half* x, half* y, int N) {
+ const int tid = threadIdx.x;
+ const int idx = (blockIdx.x * blockDim.x + tid) * 2;
+
+ float2 reg_x = __half22float2(HALF2(x[idx]));
+ float max_val = -FLT_MAX;
+ max_val = ((idx + 0) < N) ? fmaxf(reg_x.x, max_val): -FLT_MAX;
+ max_val = ((idx + 1) < N) ? fmaxf(reg_x.y, max_val): -FLT_MAX;
+  max_val = block_reduce_max_f32<NUM_THREADS>(max_val); // block max
+
+ float2 reg_exp;
+ reg_exp.x = ((idx + 0) < N) ? expf(reg_x.x - max_val) : 0.0f;
+ reg_exp.y = ((idx + 1) < N) ? expf(reg_x.y - max_val) : 0.0f;
+
+ float exp_val = reg_exp.x + reg_exp.y;
+  float exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_val); // block sum
+
+ float2 reg_y;
+ reg_y.x = reg_exp.x / (exp_sum);
+ reg_y.y = reg_exp.y / (exp_sum);
+
+ // e^x_i/sum(e^x_0,...,e^x_n-1)
+ if ((idx + 1) < N) HALF2(y[idx]) = __float22half2_rn(reg_y);
+}
+
+template<const int NUM_THREADS = 256>
+__global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, int N) {
+ const int tid = threadIdx.x;
+ const int idx = (blockIdx.x * blockDim.x + tid) * 8;
+ // temporary register(memory), .local space in ptx, addressable
+ half pack_x[8], pack_y[8]; // 8x16 bits=128 bits.
+ // reinterpret as float4 and load 128 bits in 1 memory issue.
+ LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+ float max_val = -FLT_MAX;
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ max_val = fmaxf(__half2float(pack_x[i]), max_val);
+ }
+  max_val = block_reduce_max_f32<NUM_THREADS>(max_val); // block max
+
+ float exp_sum = 0.0f;
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ float exp_val = expf(__half2float(pack_x[i]) - max_val);
+ exp_sum += (((idx + i) < N) ? exp_val : 0.0f);
+ }
+  exp_sum = block_reduce_sum_f32<NUM_THREADS>(exp_sum); // block sum
+
+ #pragma unroll
+ for (int i = 0; i < 8; ++i) {
+ // e^x_i/sum(e^x_0,...,e^x_n-1)
+ float exp_val = expf(__half2float(pack_x[i]) - max_val);
+ pack_y[i] = __float2half_rn(exp_val / exp_sum);
+ }
+ // reinterpret as float4 and store 128 bits in 1 memory issue.
+ if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+ // TODO: support non 8-multiple K here
+}
+
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
-// naive softmax
-#define TORCH_BINDING_SOFTMAX(packed_type, th_type, element_type, n_elements) \
-torch::Tensor softmax_##packed_type(torch::Tensor x) { \
- if((x.options().dtype() != (th_type))) { \
- std::cout << "x Tensor Info:" << x.options() << std::endl; \
- throw std::runtime_error("values must be "#th_type); \
- } \
- auto options = torch::TensorOptions().dtype((th_type)).device( \
- torch::kCUDA, 0); \
- const int N = x.size(0); \
- auto y = torch::zeros({N}, options); \
- auto total = torch::zeros({1}, options); \
- static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
- const int NUM_BLOCKS = (N + 256 - 1) / 256; \
- dim3 block(NUM_THREADS_PER_BLOCK); \
- dim3 grid(NUM_BLOCKS); \
-  softmax_##packed_type##_kernel<<<grid, block>>>( \
-    reinterpret_cast<element_type*>(x.data_ptr()), \
-    reinterpret_cast<element_type*>(y.data_ptr()), \
-    reinterpret_cast<element_type*>(total.data_ptr()), N); \
- return y; \
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
+if(((T).options().dtype() != (th_type))) { \
+ std::cout << "Tensor Info:" << (T).options() << std::endl; \
+ throw std::runtime_error("values must be "#th_type); \
+}
+
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+assert((T1).dim() == (T2).dim()); \
+for (int i = 0; i < (T1).dim(); ++i) { \
+ if ((T2).size(i) != (T1).size(i)) { \
+ throw std::runtime_error("Tensor size mismatch!"); \
+ } \
}
-#define TORCH_BINDING_SOFTMAX_V2(packed_type, th_type, element_type, n_elements) \
-void softmax_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \
- if((x.options().dtype() != (th_type))) { \
- std::cout << "x Tensor Info:" << x.options() << std::endl; \
- throw std::runtime_error("values must be "#th_type); \
- } \
- auto options = torch::TensorOptions().dtype((th_type)).device( \
- torch::kCUDA, 0); \
+// grid memory fence
+#define TORCH_BINDING_SOFTMAX(packed_type, th_type, element_type, n_elements) \
+void softmax_##packed_type(torch::Tensor x, torch::Tensor y) { \
+ CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
+ CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
+ auto options = torch::TensorOptions().dtype((th_type)).device(torch::kCUDA, 0);\
const int N = x.size(0); \
- if (y.size(0) != N) {throw std::runtime_error("y size mismatch!"); } \
+ CHECK_TORCH_TENSOR_SHAPE(x, y) \
auto total = torch::zeros({1}, options); \
- static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
- const int NUM_BLOCKS = (N + 256 - 1) / 256; \
- dim3 block(NUM_THREADS_PER_BLOCK); \
- dim3 grid(NUM_BLOCKS); \
-  softmax_##packed_type##_kernel<<<grid, block>>>( \
+  dim3 block(256); \
+  dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+  softmax_##packed_type##_kernel<256><<<grid, block>>>( \
    reinterpret_cast<element_type*>(x.data_ptr()), \
    reinterpret_cast<element_type*>(y.data_ptr()), \
    reinterpret_cast<element_type*>(total.data_ptr()), N); \
}
-TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
-TORCH_BINDING_SOFTMAX_V2(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_SOFTMAX_V2(f32x4, torch::kFloat32, float, 4)
-
// softmax per token
-#define LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(_H_) \
-softmax_f32_per_token_kernel<(_H_)><<<grid, block>>>( \
-      reinterpret_cast<float*>(x.data_ptr()), \
-      reinterpret_cast<float*>(y.data_ptr()), \
+#define LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
+softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
+      reinterpret_cast<float*>(x.data_ptr()), \
+      reinterpret_cast<float*>(y.data_ptr()), \
N);
#define DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
dim3 block((H)); \
- dim3 grid((S)); \
+ dim3 grid((S)); \
switch ((H)) \
{ \
+ case 32: \
+ LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
+ break; \
case 64: \
LANUCH_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
break; \
@@ -299,44 +361,54 @@ softmax_f32_per_token_kernel<(_H_)><<<grid, block>>>( \
break; \
}
-#define LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(_H_) \
-softmax_f32x4_per_token_kernel<(_H_)/4><<< \
- grid, block>>>( \
-  reinterpret_cast<float*>(x.data_ptr()), \
-  reinterpret_cast<float*>(y.data_ptr()), \
+#define LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
+softmax_f32x4_per_token_kernel<(H)/4><<< \
+  grid, block>>>( \
+  reinterpret_cast<float*>(x.data_ptr()), \
+  reinterpret_cast<float*>(y.data_ptr()), \
N);
-#define DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \
- dim3 block((H)/4); \
- dim3 grid((S)); \
- switch ((H)) \
- { \
- case 64: \
- LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \
- break; \
- case 128: \
- LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(128) \
- break; \
- case 256: \
- LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(256) \
- break; \
- case 512: \
- LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(512) \
- break; \
- case 1024: \
- LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \
- break; \
- default: \
- throw std::runtime_error( \
- "only support H: 64/128/256/512/1024"); \
- break; \
+#define DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \
+ const int NT = (H)/4; \
+ dim3 block(NT); \
+ dim3 grid((S)); \
+ switch (H) \
+ { \
+ case 32: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(32) \
+ break; \
+ case 64: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \
+ break; \
+ case 128: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(128) \
+ break; \
+ case 256: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(256) \
+ break; \
+ case 512: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(512) \
+ break; \
+ case 1024: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \
+ break; \
+ case 2048: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(4096) \
+ break; \
+ default: \
+ throw std::runtime_error( \
+ "only support H: 64/128/.../1024*4"); \
+ break; \
}
// safe softmax per token
-#define LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(_H_) \
-safe_softmax_f32_per_token_kernel<(_H_)><<<grid, block>>>( \
-      reinterpret_cast<float*>(x.data_ptr()), \
-      reinterpret_cast<float*>(y.data_ptr()), \
+#define LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
+safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
+      reinterpret_cast<float*>(x.data_ptr()), \
+      reinterpret_cast<float*>(y.data_ptr()), \
N);
#define DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
@@ -344,6 +416,9 @@ safe_softmax_f32_per_token_kernel<(_H_)><<<grid, block>>>( \
dim3 grid((S)); \
switch ((H)) \
{ \
+ case 32: \
+ LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
+ break; \
case 64: \
LANUCH_SAFE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
break; \
@@ -365,18 +440,22 @@ safe_softmax_f32_per_token_kernel<(_H_)><<<grid, block>>>( \
break; \
}
-#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(_H_) \
-safe_softmax_f32x4_per_token_kernel<(_H_)/4><<< \
- grid, block>>>( \
-  reinterpret_cast<float*>(x.data_ptr()), \
-  reinterpret_cast<float*>(y.data_ptr()), \
+#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
+safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
+  grid, block>>>( \
+  reinterpret_cast<float*>(x.data_ptr()), \
+  reinterpret_cast<float*>(y.data_ptr()), \
N);
#define DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H) \
- dim3 block((H)/4); \
+ const int NT = (H)/4; \
+ dim3 block(NT); \
dim3 grid((S)); \
- switch ((H)) \
+ switch (H) \
{ \
+ case 32: \
+ LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(32) \
+ break; \
case 64: \
LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(64) \
break; \
@@ -392,154 +471,221 @@ safe_softmax_f32x4_per_token_kernel<(_H_)/4><<< \
case 1024: \
LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(1024) \
break; \
+ case 2048: \
+ LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(4096) \
+ break; \
default: \
throw std::runtime_error( \
- "only support H: 64/128/256/512/1024"); \
+ "only support H: 64/128/.../1024*4"); \
break; \
}
-// softmax per token
-torch::Tensor softmax_f32_per_token(torch::Tensor x) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
- auto options = torch::TensorOptions().dtype((torch::kFloat32)).device(
- torch::kCUDA, 0);
- const int S = x.size(0); // seqlens
- const int H = x.size(1); // head size/kv_len
- const int N = S * H;
- auto y = torch::zeros({S, H}, options).contiguous(); // [S,H]
+#define LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(H) \
+safe_softmax_f16_f32_per_token_kernel<(H)><<<grid, block>>>( \
+      reinterpret_cast<half*>(x.data_ptr()), \
+      reinterpret_cast<half*>(y.data_ptr()), \
+ N);
- DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
- return y;
-}
+#define DISPATCH_SATE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(S, H) \
+ dim3 block((H)); \
+ dim3 grid((S)); \
+ switch ((H)) \
+ { \
+ case 32: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(32) \
+ break; \
+ case 64: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(64) \
+ break; \
+ case 128: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(128) \
+ break; \
+ case 256: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(256) \
+ break; \
+ case 512: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(512) \
+ break; \
+ case 1024: \
+ LANUCH_SAFE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(1024) \
+ break; \
+ default: \
+ throw std::runtime_error( \
+ "only support H: 64/128/256/512/1024"); \
+ break; \
+ }
+
+#define LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(H) \
+safe_softmax_f16x2_f32_per_token_kernel<(H)/2><<<grid, block>>>( \
+      reinterpret_cast<half*>(x.data_ptr()), \
+      reinterpret_cast<half*>(y.data_ptr()), \
+ N);
+
+#define DISPATCH_SATE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(S, H) \
+ const int NT = (H)/2; \
+ dim3 block(NT); \
+ dim3 grid((S)); \
+ switch (H) \
+ { \
+ case 32: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(32) \
+ break; \
+ case 64: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(64) \
+ break; \
+ case 128: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(128) \
+ break; \
+ case 256: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(256) \
+ break; \
+ case 512: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(512) \
+ break; \
+ case 1024: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(1024) \
+ break; \
+ case 2048: \
+ LANUCH_SAFE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(2048) \
+ break; \
+ default: \
+ throw std::runtime_error( \
+ "only support H: 64/128/.../1024*2"); \
+ break; \
+ }
-// no copy for y Tensor
-void softmax_f32_per_token_v2(torch::Tensor x, torch::Tensor y) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
+#define LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(H) \
+safe_softmax_f16x8_pack_f32_per_token_kernel<(H)/8><<<grid, block>>>( \
+      reinterpret_cast<half*>(x.data_ptr()), \
+      reinterpret_cast<half*>(y.data_ptr()), \
+ N);
+
+#define DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H) \
+ const int NT = (H)/8; \
+ dim3 block(NT); \
+ dim3 grid((S)); \
+ switch (H) \
+ { \
+ case 32: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(32) \
+ break; \
+ case 64: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(64) \
+ break; \
+ case 128: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(128) \
+ break; \
+ case 256: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(256) \
+ break; \
+ case 512: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(512) \
+ break; \
+ case 1024: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(1024) \
+ break; \
+ case 2048: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(2048) \
+ break; \
+ case 4096: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(4096) \
+ break; \
+ case 8192: \
+ LANUCH_SAFE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(8192) \
+ break; \
+ default: \
+ throw std::runtime_error( \
+ "only support H: 64/128/.../1024*8"); \
+ break; \
+ }
+
+// per token fp32
+void softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
- if ((y.size(0) != S) || (y.size(1) != H)) {
- throw std::runtime_error("y Tensor size mismatch!");
- }
-
DISPATCH_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
}
-torch::Tensor softmax_f32x4_per_token(torch::Tensor x) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
- auto options = torch::TensorOptions().dtype((torch::kFloat32)).device(
- torch::kCUDA, 0);
+void softmax_f32x4_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
- const int H = x.size(1); // head size/kv_len
+ const int H = x.size(1); // head size/kv_len
const int N = S * H;
- auto y = torch::zeros({S, H}, options).contiguous(); // [S,H]
-
- DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
- return y;
+ DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
}
-// no copy for y Tensor
-void softmax_f32x4_per_token_v2(torch::Tensor x, torch::Tensor y) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
+void safe_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
- if ((y.size(0) != S) || (y.size(1) != H)) {
- throw std::runtime_error("y Tensor size mismatch!");
- }
-
- DISPATCH_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
+ DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
}
-// safe_softmax per token
-torch::Tensor safe_softmax_f32_per_token(torch::Tensor x) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
- auto options = torch::TensorOptions().dtype((torch::kFloat32)).device(
- torch::kCUDA, 0);
+void safe_softmax_f32x4_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
- auto y = torch::zeros({S, H}, options).contiguous(); // [S,H]
-
- DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
- return y;
+ DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
}
-// no copy for y Tensor
-void safe_softmax_f32_per_token_v2(torch::Tensor x, torch::Tensor y) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
+// per token fp16
+void safe_softmax_f16_f32_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
- if ((y.size(0) != S) || (y.size(1) != H)) {
- throw std::runtime_error("y Tensor size mismatch!");
- }
-
- DISPATCH_SATE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
+ DISPATCH_SATE_SOFTMAX_F16_F32_PER_TOKEN_KERNEL(S, H)
}
-torch::Tensor safe_softmax_f32x4_per_token(torch::Tensor x) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
- auto options = torch::TensorOptions().dtype((torch::kFloat32)).device(
- torch::kCUDA, 0);
+void safe_softmax_f16x2_f32_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
- const int H = x.size(1); // head size/kv_len
+ const int H = x.size(1); // head size/kv_len
const int N = S * H;
- auto y = torch::zeros({S, H}, options).contiguous(); // [S,H]
-
- DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
- return y;
+ DISPATCH_SATE_SOFTMAX_F16x2_F32_PER_TOKEN_KERNEL(S, H)
}
-// no copy for y Tensor
-void safe_softmax_f32x4_per_token_v2(torch::Tensor x, torch::Tensor y) {
- if((x.options().dtype() != (torch::kFloat32))) {
- std::cout << "x Tensor Info:" << x.options() << std::endl;
- throw std::runtime_error("values must be torch::kFloat32");
- }
+void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
+ CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+ CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
- if ((y.size(0) != S) || (y.size(1) != H)) {
- throw std::runtime_error("y Tensor size mismatch!");
- }
-
- DISPATCH_SATE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(S, H)
+ DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
}
+// grid memory fence fp32
+TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
+TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(softmax_f32)
- TORCH_BINDING_COMMON_EXTENSION(softmax_f32_v2)
TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4)
- TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_v2)
TORCH_BINDING_COMMON_EXTENSION(softmax_f32_per_token)
- TORCH_BINDING_COMMON_EXTENSION(softmax_f32_per_token_v2)
TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_per_token)
- TORCH_BINDING_COMMON_EXTENSION(softmax_f32x4_per_token_v2)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32_per_token)
- TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32_per_token_v2)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32x4_per_token)
- TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f32x4_per_token_v2)
+ TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
+ TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
+ TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
}
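
The kernels in `softmax.cu` above lean on `block_reduce_max_f32` / `block_reduce_sum_f32`, whose definitions sit above the first hunk and are therefore not shown in this patch. For readers following along, here is a minimal warp-shuffle block reduction in the same spirit; it is a sketch under the assumption of at most 1024 threads (i.e. at most 32 warps) per block, not the repo's exact code:

```cuda
#include <cfloat>
#define WARP_SIZE 32

// Warp-level max via shuffle-down; lane 0 ends up holding the warp max
// (other lanes hold partial results, which is enough here).
__device__ __forceinline__ float warp_reduce_max_f32(float val) {
  #pragma unroll
  for (int offset = WARP_SIZE >> 1; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
  }
  return val;
}

// Block-level max: reduce within each warp, stage the per-warp results in
// shared memory, reduce those <=32 partials, then broadcast to all threads.
template <const int NUM_THREADS = 256>
__device__ float block_reduce_max_f32(float val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float smem[NUM_WARPS];
  const int warp = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;
  val = warp_reduce_max_f32(val);
  if (lane == 0) smem[warp] = val;
  __syncthreads();
  val = (lane < NUM_WARPS) ? smem[lane] : -FLT_MAX;
  val = warp_reduce_max_f32(val);
  // broadcast the block max from thread 0 through shared memory
  if (threadIdx.x == 0) smem[0] = val;
  __syncthreads();
  return smem[0];
}
```

A `block_reduce_sum_f32` follows the same structure with `+` in place of `fmaxf` and `0.0f` as the padding value.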
diff --git a/softmax/softmax.py b/softmax/softmax.py
index d5bc20f6..b2769161 100644
--- a/softmax/softmax.py
+++ b/softmax/softmax.py
@@ -50,40 +50,146 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
out_info = f"out_{tag}"
out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
out_val = [round(v, 8) for v in out_val]
- print(f"{out_info:>20}: {out_val}, time:{mean_time:.8f}ms")
+ out_val = [f"{v:<12}" for v in out_val]
+ print(f"{out_info:>24}: {out_val}, time:{mean_time:.8f}ms")
if show_all: print(out)
return out, mean_time
+# grid memory fence
+print("-" * 100)
+N = 128 * 128
+print(" " * 45 + f"N={N}")
+print("-" * 100)
+x = torch.randn((N)).cuda().float()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32, x, "f32(fence)", out)
+run_benchmark(lib.softmax_f32x4, x, "f32x4(fence)", out)
+run_benchmark(partial(torch.softmax, dim=0, out=out), x, "f32_th")
+
+# per token softmax
+print("-" * 100)
+S, H = 4096, 256
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
+run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
+run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
+run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
+
+print("-" * 100)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)
-print("-" * 80)
-N_ELEMENTS = 256*48
-x = torch.randn((N_ELEMENTS)).cuda().float()
-run_benchmark(lib.softmax_f32, x, "f32")
-run_benchmark(lib.softmax_f32x4, x, "f32x4")
-run_benchmark(partial(torch.softmax, dim=0), x, "f32_th")
+# per token softmax
+print("-" * 100)
+S, H = 4096, 512
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
+run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
+run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
+run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
+
+print("-" * 100)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)
+
+# per token softmax
+print("-" * 100)
+S, H = 4096, 1024
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32_per_token, x, "f32(per)", out)
+run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
+run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
+run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
-print("-" * 80)
-# v2: no copy for out Tensor
+print("-" * 100)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16_f32_per_token, x_f16, "f16f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)
+
+# per token softmax
+print("-" * 100)
+S, H = 4096, 2048
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
+run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
+
+print("-" * 100)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16x2_f32_per_token, x_f16, "f16x2f32(safe)", out_f16)
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)
+
+# per token softmax
+print("-" * 100)
+S, H = 4096, 4096
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
+out = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
+run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
+
+print("-" * 100)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)
+
+# per token softmax
+print("-" * 100)
+S, H = 4096, 8192
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
+x = torch.randn((S, H)).cuda().float().contiguous()
out = torch.zeros_like(x).cuda().float().contiguous()
-run_benchmark(lib.softmax_f32_v2, x, "f32(v2)", out)
-run_benchmark(lib.softmax_f32x4_v2, x, "f32x4(v2)", out)
-run_benchmark(partial(torch.softmax, dim=0, out=out), x, "f32_th(v2)")
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
-print("-" * 80)
-S, H = 1024, 512
+# per token softmax
+print("-" * 100)
+S, H = 8192, 8192
+print(" " * 45 + f"S={S}, H={H}")
+print("-" * 100)
x = torch.randn((S, H)).cuda().float().contiguous()
-run_benchmark(lib.softmax_f32_per_token, x, "f32(per)")
-run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)")
-run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)")
-run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)")
-run_benchmark(partial(torch.softmax, dim=1), x, "f32_th(per)")
-
-print("-" * 80)
-# v2: no copy for out Tensor
out = torch.zeros_like(x).cuda().float().contiguous()
-run_benchmark(lib.softmax_f32_per_token_v2, x, "f32(per v2)", out)
-run_benchmark(lib.softmax_f32x4_per_token_v2, x, "f32x4(per v2)", out)
-run_benchmark(lib.safe_softmax_f32_per_token_v2, x, "f32(safe v2)", out)
-run_benchmark(lib.safe_softmax_f32x4_per_token_v2, x, "f32x4(safe v2)", out)
-run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per v2)")
-print("-" * 80)
+x_f16 = x.half().contiguous()
+out_f16 = out.half().contiguous()
+run_benchmark(lib.safe_softmax_f16x8_pack_f32_per_token, x_f16, "f16x8packf32(safe)", out_f16)
+run_benchmark(partial(torch.softmax, dim=1, out=out_f16), x_f16, "f16_th(per)")
+print("-" * 100)