[Softmax][FP16] Pack f16x8 softmax kernel (#49)
* Update README.md

* Update softmax.cu

* Update softmax.py

* Update README.md

* Update layer_norm.cu

* Update README.md

* Update rms_norm.cu
DefTruth authored Sep 26, 2024
1 parent 93636df commit 5901796
Showing 6 changed files with 589 additions and 266 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -9,15 +9,15 @@
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
</div>

-🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
+🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[B]eginners**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).

<img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">

## 0x00 📖 CUDA Kernel Directory (common interview questions)
- / = not supported yet.
- ✔️ = known to work and already supported.
- ❔ = in my plan, but not coming soon; maybe a few weeks later.
-- **workflow**: custom **CUDA** kernel impl -> **Torch** python binding -> Run tests.
+- **workflow**: custom **CUDA** kernel impl -> **PyTorch** python binding -> Run tests.

|📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level |
|:---|:---|:---|:---|:---|
@@ -75,6 +75,9 @@
| ✔️ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
9 changes: 6 additions & 3 deletions layer-norm/layer_norm.cu
@@ -433,9 +433,12 @@ if(((T).options().dtype() != (th_type))) { \
throw std::runtime_error("values must be "#th_type); \
}

-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
-  throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+  assert((T1).dim() == (T2).dim()); \
+  for (int i = 0; i < (T1).dim(); ++i) { \
+    if ((T2).size(i) != (T1).size(i)) { \
+      throw std::runtime_error("Tensor size mismatch!"); \
+    } \
}

// fp32
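A side note on the new check: it generalizes the old 2-D-only comparison to tensors of any rank. Below is a minimal standalone sketch of the same idea; the `do { ... } while (0)` wrapper is extra hardening added here for illustration (the diff above does not use it), so the macro expands to a single statement even inside an unbraced if/else.

```cuda
#include <cassert>
#include <stdexcept>
#include <torch/extension.h>

// Ranks must match first, then every dimension; throws on mismatch.
// do { ... } while (0) makes the multi-statement macro act as one statement.
#define CHECK_TORCH_TENSOR_SHAPE(T1, T2)                    \
  do {                                                      \
    assert((T1).dim() == (T2).dim());                       \
    for (int i = 0; i < (T1).dim(); ++i) {                  \
      if ((T2).size(i) != (T1).size(i)) {                   \
        throw std::runtime_error("Tensor size mismatch!");  \
      }                                                     \
    }                                                       \
  } while (0)
```

Note that `assert` compiles out under `NDEBUG`, so release builds rely only on the per-dimension throw.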
9 changes: 6 additions & 3 deletions rms-norm/rms_norm.cu
@@ -382,9 +382,12 @@ if(((T).options().dtype() != (th_type))) { \
throw std::runtime_error("values must be "#th_type); \
}

-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
-  throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+  assert((T1).dim() == (T2).dim()); \
+  for (int i = 0; i < (T1).dim(); ++i) { \
+    if ((T2).size(i) != (T1).size(i)) { \
+      throw std::runtime_error("Tensor size mismatch!"); \
+    } \
}

#define LANUCH_RMS_NORM_F32_KERNEL(K) \
112 changes: 87 additions & 25 deletions softmax/README.md
@@ -5,11 +5,14 @@
Includes the following:

- [X] softmax_f32_kernel (grid level memory fence)
-- [X] softmax_f32x4_kernel(grid level memory fence, float4 vectorized version)
+- [X] softmax_f32x4_kernel(grid level memory fence)
- [X] softmax_f32_per_token_kernel(per token)
-- [X] softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] softmax_f32x4_per_token_kernel(per token)
- [X] safe_softmax_f32_per_token_kernel(per token)
-- [X] safe_softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] safe_softmax_f32x4_per_token_kernel(per token)
+- [X] safe_softmax_f16_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
- [X] PyTorch bindings
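The headline addition in this list is the f16x8 pack variant. Its core trick is to move 8 halves per thread in one 128-bit transaction; a minimal sketch of that access pattern follows (the helper name is illustrative, not the kernel's actual code, and it assumes H is a multiple of 8 with 16-byte-aligned rows):

```cuda
#include <cuda_fp16.h>

// Pack trick: reinterpret 8 contiguous halves (16 bytes) as one float4 so
// the compiler emits a single 128-bit LDG/STG instead of eight 16-bit ones.
// Caller must guarantee src/dst are 16-byte aligned (thread offset % 8 == 0).
__device__ __forceinline__ void copy_f16x8(const half* src, half* dst) {
  *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}
```

The same reinterpret pattern lets one thread cover 8 columns, which is why the pack kernel scales to wider rows than the scalar and x2 variants, as the benchmark output below shows.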


@@ -24,25 +27,84 @@ python3 softmax.py
Output:

```bash
---------------------------------------------------------------------------------
-out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms
-out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms
-out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms
---------------------------------------------------------------------------------
-out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms
-out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms
-out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms
---------------------------------------------------------------------------------
-out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms
-out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms
-out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms
-out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms
-out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms
---------------------------------------------------------------------------------
-out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms
-out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms
-out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms
-out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms
-out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms
---------------------------------------------------------------------------------
-```
+----------------------------------------------------------------------------------------------------
+N=16384
+----------------------------------------------------------------------------------------------------
+out_f32(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01040053ms
+out_f32x4(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01053643ms
+out_f32_th: ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.00582504ms
+----------------------------------------------------------------------------------------------------
+S=4096, H=256
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00627208ms
+out_f32x4(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00394082ms
+out_f32(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00941491ms
+out_f32x4(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00413442ms
+out_f32_th(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00602674ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00912046ms
+out_f16x2f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00522232ms
+out_f16x8packf32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00413895ms
+out_f16_th(per): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00605321ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=512
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01139641ms
+out_f32x4(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00515914ms
+out_f32(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01834297ms
+out_f32x4(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00574923ms
+out_f32_th(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00657558ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.01782560ms
+out_f16x2f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00919509ms
+out_f16x8packf32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00415683ms
+out_f16_th(per): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00634599ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=1024
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.03191805ms
+out_f32x4(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.00862813ms
+out_f32(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.04873967ms
+out_f32x4(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01027441ms
+out_f32_th(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01181388ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.04671884ms
+out_f16x2f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01810408ms
+out_f16x8packf32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.00601912ms
+out_f16_th(per): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01047063ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=2048
+----------------------------------------------------------------------------------------------------
+out_f32x4(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.01605988ms
+out_f32x4(safe): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.02089310ms
+out_f32_th(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.06726241ms
+----------------------------------------------------------------------------------------------------
+out_f16x2f32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.04824972ms
+out_f16x8packf32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.01086283ms
+out_f16_th(per): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.07232165ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=4096
+----------------------------------------------------------------------------------------------------
+out_f32x4(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18465948ms
+out_f32x4(safe): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18565655ms
+out_f32_th(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18744922ms
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.02254891ms
+out_f16_th(per): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.08283138ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=8192
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19313049ms
+out_f16_th(per): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19356799ms
+----------------------------------------------------------------------------------------------------
+S=8192, H=8192
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.39828229ms
+out_f16_th(per): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.40599036ms
+----------------------------------------------------------------------------------------------------
```
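Two patterns stand out in these numbers: the f16x8 pack kernel pulls further ahead as H grows, and it is the only custom per-token kernel listed at H=8192, since one row per block with 8 elements per thread fits 8192 columns within the 1024-thread block limit (8 x 1024). For orientation, here is a hedged skeleton of a per-token safe softmax with f16 storage and f32 accumulation; it is a sketch assuming blockDim.x == H/8 and blockDim.x a multiple of 32, and the reduction helpers and names are illustrative rather than the repo's actual code.

```cuda
#include <cuda_fp16.h>
#include <math.h>

// Warp-wide reductions via XOR shuffle (5 folds over 32 lanes).
__device__ __forceinline__ float warp_max_f32(float v) {
  #pragma unroll
  for (int o = 16; o > 0; o >>= 1)
    v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, o));
  return v;
}

__device__ __forceinline__ float warp_sum_f32(float v) {
  #pragma unroll
  for (int o = 16; o > 0; o >>= 1)
    v += __shfl_xor_sync(0xffffffffu, v, o);
  return v;
}

// One row (token) per block; each thread owns 8 contiguous halves.
// Assumes H % 8 == 0, blockDim.x == H / 8, blockDim.x % 32 == 0.
__global__ void safe_softmax_f16x8_pack_f32_sketch(const half* x, half* y,
                                                   int H) {
  __shared__ float warp_res[32];  // per-warp partial reductions
  __shared__ float row_stat;      // broadcast slot for row max / row sum
  const int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
  const int num_warps = blockDim.x >> 5;
  const int idx = blockIdx.x * H + threadIdx.x * 8;

  float4 v = *reinterpret_cast<const float4*>(x + idx);  // 8 halves, 1 load
  half* p = reinterpret_cast<half*>(&v);

  // Pass 1: row max, accumulated in f32.
  float m = -INFINITY;
  #pragma unroll
  for (int i = 0; i < 8; ++i) m = fmaxf(m, __half2float(p[i]));
  m = warp_max_f32(m);
  if (lane == 0) warp_res[warp] = m;
  __syncthreads();
  if (warp == 0) {
    m = (lane < num_warps) ? warp_res[lane] : -INFINITY;
    m = warp_max_f32(m);
    if (lane == 0) row_stat = m;
  }
  __syncthreads();
  m = row_stat;

  // Pass 2: row sum of exp(x - max), accumulated in f32.
  float s = 0.0f;
  #pragma unroll
  for (int i = 0; i < 8; ++i) s += expf(__half2float(p[i]) - m);
  s = warp_sum_f32(s);
  if (lane == 0) warp_res[warp] = s;
  __syncthreads();
  if (warp == 0) {
    s = (lane < num_warps) ? warp_res[lane] : 0.0f;
    s = warp_sum_f32(s);
    if (lane == 0) row_stat = s;
  }
  __syncthreads();
  s = row_stat;

  // Normalize in f32, store back as f16 with one 128-bit transaction.
  #pragma unroll
  for (int i = 0; i < 8; ++i)
    p[i] = __float2half(expf(__half2float(p[i]) - m) / s);
  *reinterpret_cast<float4*>(y + idx) = v;
}
```

A launch shaped like `safe_softmax_f16x8_pack_f32_sketch<<<S, H / 8>>>(x, y, H)` would process one row per block, which is where the H <= 8192 per-token ceiling in the table comes from.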