From 0c6785f0d023fe85fae618ad80359d340e334644 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:40:37 +0800 Subject: [PATCH] =?UTF-8?q?[FA2]=20Release=20flash-attn-mma=20shared-kv/qk?= =?UTF-8?q?v=F0=9F=8E=89=20(#162)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update flash_attn_mma_share_kv.cu * Create flash_attn_mma_share_qkv.cu * Update flash_attn_mma_split_q.cu * Update flash_attn_mma_split_q.cu * Update flash_attn_mma_split_kv.cu * Update flash_attn.cc * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update flash_attn_mma_share_qkv.cu --- README.md | 101 ++- kernels/flash-attn/README.md | 196 +++-- kernels/flash-attn/flash_attn_mma.py | 31 +- .../flash-attn/mma/flash_attn_mma_share_kv.cu | 414 +++++---- .../mma/flash_attn_mma_share_qkv.cu | 814 ++++++++++++++++++ .../flash-attn/mma/flash_attn_mma_split_kv.cu | 6 + .../flash-attn/mma/flash_attn_mma_split_q.cu | 6 + kernels/flash-attn/pybind/flash_attn.cc | 27 +- 8 files changed, 1280 insertions(+), 315 deletions(-) create mode 100644 kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu diff --git a/README.md b/README.md index d0dd4f2c..3d7a7972 100644 --- a/README.md +++ b/README.md @@ -42,24 +42,49 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d |Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32| |✔️|✔️|✔️|✔️| -I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details. + +I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Fully Sahred QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop. ![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2) +- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop) +```bash +python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop +------------------------------------------------------------------------------------------------------------------------ + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10 +------------------------------------------------------------------------------------------------------------------------ + B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08 + mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31 + mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54 + mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37 + mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15 + mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01 + mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42 + (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22 +----------------------------------------------------------------------------------------------------------------------- +``` + +However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details. -|CUDA Cores|Sliced K (Loop over N/D)|Tile Block (Br, Bc, Bd)|MMA (m16n8k16)| +|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)| |:---:|:---:|:---:|:---:| |✔️|✔️|✔️|✔️| -|Pack LDST (128 bits)|SMEM Padding|Copy Async |Tile MMAs (More Threads) +|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads) |✔️|✔️|✔️|✔️| -|Tile Warps (More Values)|Multi Stages (1/2)| Collective Store (Shfl)| **Split KV/Q** | +|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**| |✔️|✔️|✔️|✔️| +|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle| +|✔️|✔️|✔️|?| -The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). +The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). -![flash-attn](https://github.com/user-attachments/assets/11490fbc-2a4a-4630-abe8-91a9d1251cba) + - 📚 Split KV (Basic, FlashAttention-1) +
```C++ // Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy. @@ -69,22 +94,6 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| -template< - const int kHeadDim, // Headdim, 32,64,128 - const int kMmaAtomM, // MMA Atom M, 16 - const int kMmaAtomN, // MMA Atom N, 8 - const int kMmaAtomK, // MMA Atom K, 16 - const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M - const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N - const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M - const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|... - const int kStage, // only support 1 or 2 now. - const int kPad // 0,8 - > __global__ void flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] half* K, // [B, H, D, N] K^T transposed @@ -94,6 +103,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] ``` - 📚 Split Q (Faster, FlashAttention-2) +
```C++ // Split Q across MMA(Warps) and keep access KV for all MMA(Warps), @@ -104,22 +114,6 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | -template< - const int kHeadDim, // Headdim, 32,64,128 - const int kMmaAtomM, // MMA Atom M, 16 - const int kMmaAtomN, // MMA Atom N, 8 - const int kMmaAtomK, // MMA Atom K, 16 - const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M - const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N - const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M - const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... - const int kStage, // only support 1 or 2 now. - const int kPad // 0,8 - > __global__ void flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] half* K, // [B, H, D, N] K^T transposed @@ -127,6 +121,33 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] half* O, // [B, H, N, D] int QKV_seqlen); ``` + +- 📚 Split Q + Shared KV SMEM (Faster+) +
+ +```C++ +// K, V shared the same shared memory, improve block occupancy. +__global__ void +flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen); +``` +- 📚 Split Q + Fully Shared QKV SMEM (Faster++) + +
+ +```C++ +// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy. +__global__ void +flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen); +``` + ## ©️Citations🎉🎉 ```BibTeX @@ -144,11 +165,13 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-|📖 CUDA Kernel| 📖 Elem dtype| 📖 Acc dtype| 📖 Docs | 📖 Level | +|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level | |:---|:---|:---|:---|:---| | ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️| | ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️| | ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️| +| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| +| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️| | ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md index dcc4726d..5717d994 100644 --- a/kernels/flash-attn/README.md +++ b/kernels/flash-attn/README.md @@ -2,33 +2,56 @@ ![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2) -|CUDA Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc, Bd)|MMA (m16n8k16)| +|Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)| |:---:|:---:|:---:|:---:| |✔️|✔️|✔️|✔️| |Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads) |✔️|✔️|✔️|✔️| |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**| |✔️|✔️|✔️|✔️| +|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle| +|✔️|✔️|✔️|?| -This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention (SeqLen <= 4096), the flash-attention-mma implemented in this repository matches the performance of the official FA. However, for large-scale attention computations, there remains a significant performance gap. Performance optimizations are ongoing; stay tuned for updates. +This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates. + +- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop) +```bash +python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop +------------------------------------------------------------------------------------------------------------------------ + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10 +------------------------------------------------------------------------------------------------------------------------ + B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08 + mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31 + mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54 + mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37 + mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15 + mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01 + mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42 + (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22 +----------------------------------------------------------------------------------------------------------------------- +``` ## 📖 Contents -- [📖 Split KV](#mma-split-kv) -- [📖 Split Q](#mma) +- [📖 FlashAttetion MMA Kernels](#mma) + - [📚 Split KV](#mma-split-kv) + - [📚 Split Q ](#mma-split-q) + - [📚 Shared KV SMEM](#mma-share-kv) + - [📚 Fully Shared QKV SMEM](#mma-share-qkv) - [📖 Prerequisites](#prerequisites) - [📖 Installation](#install) - [📖 Performance](#perf) - [📖 Python Testing](#test) - + ## 📖 FlashAttetion MMA Kernels
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](.) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps) using a naive matmul (MMA) and Warp tiling policy, is slower compared to the `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). - + +- 📚 Split KV (Basic, FlashAttention-1)
```C++ @@ -39,22 +62,6 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| -template< - const int kHeadDim, // Headdim, 32,64,128 - const int kMmaAtomM, // MMA Atom M, 16 - const int kMmaAtomN, // MMA Atom N, 8 - const int kMmaAtomK, // MMA Atom K, 16 - const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M - const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N - const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M - const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|... - const int kStage, // only support 1 or 2 now. - const int kPad // 0,8 - > __global__ void flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] half* K, // [B, H, D, N] K^T transposed @@ -63,7 +70,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] int QKV_seqlen); ``` -## 📖 Split Q (Faster, FlashAttention-2) +- 📚 Split Q (Faster, FlashAttention-2)
```C++ @@ -75,22 +82,6 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | -template< - const int kHeadDim, // Headdim, 32,64,128 - const int kMmaAtomM, // MMA Atom M, 16 - const int kMmaAtomN, // MMA Atom N, 8 - const int kMmaAtomK, // MMA Atom K, 16 - const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] - const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] - const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M - const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N - const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M - const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... - const int kStage, // only support 1 or 2 now. - const int kPad // 0,8 - > __global__ void flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] half* K, // [B, H, D, N] K^T transposed @@ -99,6 +90,31 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] int QKV_seqlen); ``` +- 📚 Split Q + Shared KV SMEM (Faster+) +
+ +```C++ +// K, V shared the same shared memory, improve block occupancy. +__global__ void +flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen); +``` +- 📚 Split Q + Fully Shared QKV SMEM (Faster++) + +
+ +```C++ +// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy. +__global__ void +flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen); +``` ## 📖 Prerequisites
@@ -117,7 +133,7 @@ pip install flash-attn --no-build-isolation # need offical flash-attention for c ## 📖 Performance
-Currently, for small-scale attention (SeqLen <= 4096), the flash-attention-mma implemented in this repository matches the performance of the official FA version. However, for large-scale attention computations, there remains a significant performance gap. Performance optimizations are ongoing; stay tuned for updates. +Currently, for small-scale attention (B<=4, H <=48, SeqLen <= 8192), the flash-attention-mma implemented in this repository matches the performance of the official FA version. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates. ## 📖 Python Testing
@@ -134,61 +150,91 @@ python3 flash_attn_mma.py --D 64 # test all default settings for D=64 - B=2, H=2, N=4096, D=64 ```bash -python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 4096 # NVIDIA RTX 3080 Laptop +python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 4096 --iters 10 # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ - B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 6827, Warmup: 2, Iters: 10 + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 9655, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ - B=2, H=2, N=4096, D=64, Warmup: 2, Iters: 10 - mma(split-kv+stage1): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.767565ms, TFLOPS:22.82 - mma(split-kv+stage2): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.730205ms, TFLOPS:23.99 - mma(split-q+stage1): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.524163ms, TFLOPS:33.41 - mma(split-q+stage2): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.622582ms, TFLOPS:28.13 - (flash): ['-0.02687073 ', '0.03143311 ', '-0.03656006 '], time:0.610447ms, TFLOPS:28.69 + B=2, H=2, N=4096, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.765753ms, TFLOPS:22.87 + mma(split-kv+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.731516ms, TFLOPS:23.94 + mma(split-q+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.526834ms, TFLOPS:33.24 + mma(split-q+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.660753ms, TFLOPS:26.51 + mma(split-q+share-kv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.460815ms, TFLOPS:38.01 + mma(split-q+share-qkv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.465345ms, TFLOPS:37.64 + mma(split-q+share-qkv+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.474334ms, TFLOPS:36.92 + (flash): ['0.01904297 ', '-0.02037048 ', '-0.01724243 '], time:0.596189ms, TFLOPS:29.38 ------------------------------------------------------------------------------------------------------------------------ ``` - B=2, H=2, N=8192, D=64 ```bash -python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop + python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ - B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1923, Warmup: 2, Iters: 10 + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 5669, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ - B=2, H=2, N=8192, D=64, Warmup: 2, Iters: 10 - mma(split-kv+stage1): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.870488ms, TFLOPS:24.41 - mma(split-kv+stage2): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.599239ms, TFLOPS:26.95 - mma(split-q+stage1): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:1.826215ms, TFLOPS:38.36 - mma(split-q+stage2): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.142096ms, TFLOPS:32.71 - (flash): ['-0.01076508 ', '-0.0075798 ', '0.02301025 '], time:2.061176ms, TFLOPS:33.99 + B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:5.572367ms, TFLOPS:25.15 + mma(split-kv+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:5.295920ms, TFLOPS:26.46 + mma(split-q+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:3.607082ms, TFLOPS:38.85 + mma(split-q+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:4.600883ms, TFLOPS:30.45 + mma(split-q+share-kv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.744508ms, TFLOPS:51.05 + mma(split-q+share-qkv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.700114ms, TFLOPS:51.89 + mma(split-q+share-qkv+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.692103ms, TFLOPS:52.05 + (flash): ['-0.00882721 ', '0.01213074 ', '-0.01314545 '], time:3.778219ms, TFLOPS:37.09 ------------------------------------------------------------------------------------------------------------------------ ``` - B=1, H=8, N=8192, D=64 ```bash -python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------- - B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4374, Warmup: 2, Iters: 10 +python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ - B=1, H=8, N=8192, D=64, Warmup: 2, Iters: 10 - mma(split-kv+stage1): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:5.583835ms, TFLOPS:25.09 - mma(split-kv+stage2): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:5.325174ms, TFLOPS:26.31 - mma(split-q+stage1): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:3.675842ms, TFLOPS:38.12 - mma(split-q+stage2): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:4.370213ms, TFLOPS:32.06 - (flash): ['-0.01470184 ', '-0.01394653 ', '-0.02435303 '], time:3.680992ms, TFLOPS:38.07 + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ + B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08 + mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31 + mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54 + mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37 + mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15 + mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01 + mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42 + (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22 +----------------------------------------------------------------------------------------------------------------------- ``` - B=1, H=48, N=8192, D=64 ```bash -python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop +python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop +------------------------------------------------------------------------------------------------------------------------ + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4669, Warmup: 1, Iters: 10 +------------------------------------------------------------------------------------------------------------------------ + B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:42.534423ms, TFLOPS:19.77 + mma(split-kv+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:42.349815ms, TFLOPS:19.85 + mma(split-q+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:35.657477ms, TFLOPS:23.58 + mma(split-q+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:36.065412ms, TFLOPS:23.31 + mma(split-q+share-kv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.619652ms, TFLOPS:35.59 + mma(split-q+share-qkv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.893070ms, TFLOPS:35.19 + mma(split-q+share-qkv+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.590446ms, TFLOPS:35.64 + (flash): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:22.385812ms, TFLOPS:37.56 +------------------------------------------------------------------------------------------------------------------------ +``` + +- B=1, H=8, N=8192, D=32 +```bash +python3 flash_attn_mma.py --B 1 --H 8 --D 32 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ - B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 8331, Warmup: 2, Iters: 10 + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 2322, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ - B=1, H=48, N=8192, D=64, Warmup: 2, Iters: 10 - mma(split-kv+stage1): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:42.588711ms, TFLOPS:19.74 - mma(split-kv+stage2): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:42.275143ms, TFLOPS:19.89 - mma(split-q+stage1): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:37.420964ms, TFLOPS:22.47 - mma(split-q+stage2): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:37.678123ms, TFLOPS:22.31 - (flash): ['-0.0150528 ', '0.00946045 ', '0.0368042 '], time:22.342849ms, TFLOPS:37.63 + B=1, H=8, N=8192, D=32, Warmup: 1, Iters: 10 + mma(split-kv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:3.930807ms, TFLOPS:18.16 + mma(split-kv+stage2): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:3.901839ms, TFLOPS:18.30 + mma(split-q+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.839685ms, TFLOPS:38.81 + mma(split-q+stage2): ['-0.00607681 ', '-0.00229454 ', '0.02029419 '], time:1.511669ms, TFLOPS:47.23 + mma(split-q+share-kv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.400948ms, TFLOPS:50.97 + mma(split-q+share-qkv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.393318ms, TFLOPS:51.25 + mma(split-q+share-qkv+stage2): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.322961ms, TFLOPS:53.97 + (flash): ['-0.00617599 ', '-0.00231934 ', '0.02029419 '], time:1.810646ms, TFLOPS:39.43 ------------------------------------------------------------------------------------------------------------------------ ``` diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py index 9a1d1f40..9ae87ddb 100644 --- a/kernels/flash-attn/flash_attn_mma.py +++ b/kernels/flash-attn/flash_attn_mma.py @@ -47,8 +47,8 @@ def get_args(): parser.add_argument("--seed", type=int, default=None) parser.add_argument("--debug", action="store_true") parser.add_argument("--verbose", '--v', action="store_true") - parser.add_argument("--warmup", type=int, default=2) - parser.add_argument("--iters", type=int, default=10) + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--iters", type=int, default=5) parser.add_argument("--range-k", '--gk', action="store_true") return parser.parse_args() @@ -62,6 +62,8 @@ def get_args(): sources=[ './mma/flash_attn_mma_split_kv.cu', './mma/flash_attn_mma_split_q.cu', + './mma/flash_attn_mma_share_kv.cu', + './mma/flash_attn_mma_share_qkv.cu', './pybind/flash_attn.cc' ], extra_cuda_cflags=[ @@ -94,7 +96,7 @@ def get_mha_tflops(B, H, N, D, T=1.0): flops_subtract_max = B * H * N * N # sub max flops_exp = B * H * N * N # pointwise exp flops_row_sum = B * H * N * (N - 1) # row sum - flops_normalization = B * H * N * N # 归一化 + flops_normalization = B * H * N * N # normalization flops_safe_softmax = flops_row_max + flops_subtract_max + flops_exp + flops_row_sum + flops_normalization @@ -118,7 +120,7 @@ def run_benchmark(perf_func: callable, v: torch.Tensor, tag: str, out: Optional[torch.Tensor] = None, - s: Optional[torch.Tensor] = None, # BUDEG + s: Optional[torch.Tensor] = None, # DEBUG stages: int = -1, warmup: int = args.warmup, iters: int = args.iters, @@ -173,7 +175,7 @@ def run_benchmark(perf_func: callable, out_val = out_val_first[:2] out_val.append(out_val_last[-1]) out_val = [f"{v:<12}" for v in out_val] - print(f"{out_info:>25}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}") + print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}") if show_all: print(out) time.sleep(0.05) @@ -234,7 +236,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, Bs = [1, 2, 4] if not args.B else [args.B] Hs = [1, 4, 8] if not args.H else [args.H] -Ns = [1024, 2048] if not args.N else [args.N] +Ns = [1024, 2048, 4096] if not args.N else [args.N] Ds = [64, 128] if not args.D else [args.D] # batch_size, n_head, seq_len, head_dim (B,H,N,D) BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds] @@ -252,14 +254,17 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, torch.cuda.synchronize() if args.run_torch_unfused: - out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)") - out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1) - out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2) - out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1) - out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2) - out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") + out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)") + out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1) + out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2) + out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1) + out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2) + out_mma_share_kv, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1) + out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1) + out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2) + out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") if args.run_torch_sdpa: - out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)") + out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)") print("-" * 120) torch.cuda.synchronize() diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu index b8e32226..5ad6bfa7 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu @@ -31,6 +31,17 @@ // | warp_QP 6 | MMA 6 ... MMA 6 (x16) | // | warp_QP 7 | MMA 7 ... MMA 7 (x16) | +// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps +// | 128x64 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x8) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x8) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x8) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x8) | +// | warp_QP 4 | MMA 4 ... MMA 4 (x8) | +// | warp_QP 5 | MMA 5 ... MMA 5 (x8) | +// | warp_QP 6 | MMA 6 ... MMA 6 (x8) | +// | warp_QP 7 | MMA 7 ... MMA 7 (x8) | + template< const int kHeadDim, // Headdim, 32,64,128 const int kMmaAtomM, // MMA Atom M, 16 @@ -56,19 +67,20 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, int QKV_seqlen) { // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 - static_assert(kMmaTileSeqLenQ == 4 && kMmaTileSeqLenK == 1); // Q@K^T - static_assert(kMmaTileSeqLenP == 4 && kMmaTileHeadDimV == 1); // P@V - static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK == 8); // Q@K^T + static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T + static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V + static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc. // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V - static_assert(kStage > 0 && kStage < 3); // 1,2 + // TODO: support stages for shared kv smem kernel. + static_assert(kStage < 3 && kStage > 0); static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + static_assert(Br >= Bc); // for shared memory reuse. constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads - constexpr int kNumStoreBatchO = 2; // batch size for O collective store. // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] const float scale = 1.0f / sqrt((float) kHeadDim); @@ -126,19 +138,18 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 extern __shared__ half smem[]; constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M - constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M // K multi-stages: currently, only apply multi stages for K across seq_len. half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M - half* V_tile_smem = K_tile_smem + kStage * K_tile_size; - // TODO: KV may shared same smem to reduce smem usage for kStage 1 - // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M - // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M =64M - // stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M =128M - // stage 1, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)+8M =24M - // stage 1, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*1+16M =48M - // stage 1, no shared KV smem, Br=Bc=32, d=256: 16M+(16M)*1+16M =48M + half* V_tile_smem = K_tile_smem; // KV shared the same smem + // NOTE: KV may shared same smem to reduce smem usage for kStage 1 + // stage 1, w shared KV smem, Br=Bc=64, d=64: 8M+(8M) =16M, +Pad(2M) = 18M + // stage 1, w shared KV smem, Br=Bc=128, d=64: 16M+16M =32M, +Pad(2M) = 34M + // stage 1, w shared KV smem, Br=Bc=64, d=128: 16M+16M =32M, +Pad(4M) = 36M + // stage 1, w shared KV smem, Br=Bc=128, d=128: 32M+32M =64M, +Pad(4M) = 68M + // stage 1, w shared KV smem, Br=Bc=32, d=256: 16M+16M =32M, +Pad(1M) = 34M uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); @@ -153,7 +164,16 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. - uint32_t R_Q[kWarpTileSeqLenQ][ 4]; // [1][4] + // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. + // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + // constexpr bool kCanPrefetchQs2r = false; // d <= 128 + // FIXME(DefTruth): why can not get good performance for headdim >= 64 ? + // Will enable it untill I have figure out the performance issues. + constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64); + constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; + uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4] uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2] // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] @@ -164,12 +184,9 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] - // Helper regs for collective store, [2][4]. may use smem to reduce regs? - uint32_t R_Z[kNumStoreBatchO][4]; // [2][4] fill_3D_regs(R_S, 0); fill_3D_regs(R_D, 0); fill_3D_regs(R_O, 0); - fill_2D_regs(R_Z, 0); // load Q from gmem -> smem, only load once. { @@ -184,33 +201,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, CP_ASYNC_COMMIT_GROUP(); } - // load K from gmem -> smem, (kStage - 1) K^T tiles, [d,Bc] - if constexpr (kStage > 1) { - #pragma unroll - for (int stage = 0; stage < (kStage - 1); ++stage) { - // update the offset of n according to stages - load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); - uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (stage * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); - #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); - } - } - - // wait Q and at least (kStage - 1) for K ready. - if constexpr (kStage > 1) { - CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2 - __syncthreads(); - } - // : for K^T[d,seqlen] with K^T_tile[d,Bc] // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] #pragma unroll 1 @@ -226,87 +216,61 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // and smem_sel_next will always equal 0, thus, we can not // prefetch KV from gmem to smem before tile_K_seqlen MMA done. - if constexpr (kStage > 1) { - // First, prefetch curr V tile_K_seqlen [Bc,d] (no stages) - { - load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; - int load_gmem_V_d = load_smem_V_d; - int load_gmem_V_addr = ( - V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); - uint32_t load_smem_V_ptr = ( - smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + - load_smem_V_d) * sizeof(half) - ); - #pragma unroll - for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { - CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); + // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + if (tile_K_seqlen == 0) { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } + CP_ASYNC_COMMIT_GROUP(); + } - // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc] - if ((tile_K_seqlen + 1) < Tc) { - load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; - int load_gmem_K_addr = ( - K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); - uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel_next * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half) - ); - #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); - } - } else { - // If no stages, kStage = 1, we have to load current K tile - // from gmem to smem and have to wait it ready for Q@K^T MMA. - - // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) - { - load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); - uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); - #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); - } - CP_ASYNC_COMMIT_GROUP(); - } + if constexpr (kCanPrefetchQs2r) { + // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. + // NOTE: we only need to load Q once from smem -> regs, and then reuse it. + if (tile_K_seqlen == 0) { + // TODO: Full share QKV smem after Q is ready load to regs. + CP_ASYNC_WAIT_GROUP(1); + __syncthreads(); - // Then, prefetch curr K tile_K_seqlen [d,Bc] (no stages) - { - load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; - int load_gmem_V_d = load_smem_V_d; - int load_gmem_V_addr = ( - V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); - uint32_t load_smem_V_ptr = ( - smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + - load_smem_V_d) * sizeof(half) - ); #pragma unroll - for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { - CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs. + // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4] + } } - CP_ASYNC_COMMIT_GROUP(); - } - - // Wait curr Q and K tile ready and let curr V tile copy async. - CP_ASYNC_WAIT_GROUP(1); + } // end if tile_K_seqlen == 0 + // Now, we have to wait curr K tile ready for Q@K^T MMA. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } else { + // Wait curr Q and K tile ready. + CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } - + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] // Matmul with NN layout, Q row major, K row major. // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] @@ -315,17 +279,19 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. // ldmatrix.x4 for Q_tile_smem. - #pragma unroll - for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] - int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; - int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 - int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 - uint32_t lane_smem_Q_ptr = ( - smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + - lane_smem_Q_d) * sizeof(half) - ); - LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], - lane_smem_Q_ptr); // now, R_Q + if constexpr (!kCanPrefetchQs2r) { + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + lane_smem_Q_ptr); // now, R_Q[1][1][4] + } } // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. @@ -342,21 +308,55 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, ); LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K } // end for kWarpTileSeqLenK - - // MMA compute - #pragma unroll - for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + + if constexpr (kCanPrefetchQs2r) { + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } else { + // MMA compute #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { - HMMA16816(R_S[i][j][0], R_S[i][j][1], - R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], - R_K[j][0], R_K[j][1], - R_S[i][j][0], R_S[i][j][1]); + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } } } } // end loop over d, S=Q@K^T __syncthreads(); + // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages), + // before rowmax and rowsum, load V from gmem -> smem. + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps // | 64x64 | warp_KV 0 | // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | @@ -402,9 +402,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // Warp level reduce max, warp_size = 4 // Each thread contains the maximum of 2 rows of Br, // and only the values of T0, T4, ..., T28 are used. - // Br, row_id = warp_QP<0~3> * 32 + i<0> * 16 + 0 * 8 + (lane / 4) <0~7> lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); - // Br, row_id = warp_QP<0~3> * 32 + i<0> * 16 + 1 * 8 + (lane / 4) <8~15> lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); } // end for kWarpTileSeqLenQ __syncthreads(); @@ -519,6 +517,23 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } // end for V Bc. __syncthreads(); + // NOTE: Load next K tile async before rescale O + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + // Rescale O -> Update row sum Exp -> then, Update row max. #pragma unroll for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1 @@ -591,6 +606,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // Finaly, we still have to rescale O once more. // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + // NOTE: Here, we choose to reuse R_O as final output + // in order to reduce regs usage. #pragma unroll for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); @@ -614,29 +631,57 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 #pragma unroll for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8 - R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4 - R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); - R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); - R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); - R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); - R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); - R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); - - // st.global.v4 128 bits. [Br,d] - if (lane_id % 4 == 0) { - // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 - int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; - int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 - // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) - int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; - int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) - int store_gmem_O_addr_0 = ( - O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); - int store_gmem_O_addr_1 = ( - O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); - LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]); - LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]); - } + + if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { + // reuse R_Q[4/8][1][4] for collective store. + R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4 + R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]); + } + } else { + // we have to use new R_Z regs for collective store. + uint32_t R_Z[2][4]; + R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4 + R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]); + } + } // end if kCanPrefetchQs2r } // end for kWarpTileHeadDimV } // end for kWarpTileSeqLenQ } @@ -645,26 +690,32 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, template void launch_flash_attn_mma_stages_split_q_shared_kv( torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + // Tile BrxBc=128x64 constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; - constexpr int kMmaTileSeqLenQ = 4; + // constexpr int kMmaTileSeqLenQ = 4; + constexpr int kMmaTileSeqLenQ = 8; constexpr int kMmaTileSeqLenK = 1; - constexpr int kMmaTileSeqLenP = 4; + // constexpr int kMmaTileSeqLenP = 4; + constexpr int kMmaTileSeqLenP = 8; constexpr int kMmaTileHeadDimV = 1; constexpr int kWarpTileSeqLenQ = 1; constexpr int kWarpTileSeqLenK = 8; + // constexpr int kWarpTileSeqLenK = 16; constexpr int kWarpTileSeqLenP = 1; constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,.... constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads constexpr int kPad = 8; + + // static int kMaxSramPerBlock; + // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); - // Calculate SRAM size needed per block, Q,K,V smem size + // Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem. const int smem_max_size = ((Br * (kHeadDim + kPad)) + - (kStage * kHeadDim * (Bc + kPad)) + - (Bc * (kHeadDim + kPad))) * sizeof(half); + (kStage * kHeadDim * (Bc + kPad))) * sizeof(half); const int QKV_batch = Q.size(0); const int QKV_head = Q.size(1); @@ -692,6 +743,7 @@ void launch_flash_attn_mma_stages_split_q_shared_kv( kPad >, cudaFuncAttributeMaxDynamicSharedMemorySize, + // kMaxSramPerBlock 98304 ); @@ -730,25 +782,15 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] const int d = Q.size(3); // B, H, N, d - if (stages == 2) { - switch (d) - { - case 64: - launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O); - break; - case 96: - launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O); - break; - case 128: - launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O); - break; - default: - throw std::runtime_error("headdim not support!"); - break; - } + if (stages > 1) { + throw std::runtime_error( + "split_q_shared_kv not support stages>1 now!"); } else { switch (d) { + case 32: + launch_flash_attn_mma_stages_split_q_shared_kv<32, 1>(Q, K, V, O); + break; case 64: launch_flash_attn_mma_stages_split_q_shared_kv<64, 1>(Q, K, V, O); break; diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu new file mode 100644 index 00000000..aad3d5f2 --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu @@ -0,0 +1,814 @@ +#include "utils.h" + +// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction. +// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. +// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. + +// The FlashAttention-2 algorithm is described in the following paper: +// https://arxiv.org/pdf/2307.08691 + +// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d] +// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d] + +// Split Q across MMA(Warps) and keep access KV for all MMA(Warps), +// in order to reduce the comm between warps via smem and warp shuffle. + +// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps +// | 64x64 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x8) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x8) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x8) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + +// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps +// | 128x128 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x16) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x16) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x16) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x16) | +// | warp_QP 4 | MMA 4 ... MMA 4 (x16) | +// | warp_QP 5 | MMA 5 ... MMA 5 (x16) | +// | warp_QP 6 | MMA 6 ... MMA 6 (x16) | +// | warp_QP 7 | MMA 7 ... MMA 7 (x16) | + +// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps +// | 128x64 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x8) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x8) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x8) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x8) | +// | warp_QP 4 | MMA 4 ... MMA 4 (x8) | +// | warp_QP 5 | MMA 5 ... MMA 5 (x8) | +// | warp_QP 6 | MMA 6 ... MMA 6 (x8) | +// | warp_QP 7 | MMA 7 ... MMA 7 (x8) | + +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N + const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... + const int kStage, + const int kPad + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen) { + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 + static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T + static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V + static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T + // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc. + // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. + static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( + kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V + static_assert(kStage > 0 && kStage < 3); // 1 or 2 + static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + static_assert(Br >= Bc); // for shared memory reuse. + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] + const float scale = 1.0f / sqrt((float) kHeadDim); + + // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma) + const int QKV_batch_id = blockIdx.x; // Batch size, bx + const int QKV_head_id = blockIdx.y; // Head num, by + const int Q_tile_id = blockIdx.z; // Q tile_id, range [0, Tr), bz. + const int O_tile_id = Q_tile_id; // O tile_id, same as Q. + const int tid = threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_QP = warp_id; // 0,1,2,3 or 0~7 + const int warp_KV = 0; // 0 + // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps + // | 128x128 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x16) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x16) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x16) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x16) | + // | warp_QP 4 | MMA 4 ... MMA 4 (x16) | + // | warp_QP 5 | MMA 5 ... MMA 5 (x16) | + // | warp_QP 6 | MMA 6 ... MMA 6 (x16) | + // | warp_QP 7 | MMA 7 ... MMA 7 (x16) | + const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + + (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen] + const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] + const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] + + // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads. + int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 + int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,... + // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads. + int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64 + int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,... + // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads. + int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... + // global Q row of current head for tile [Br,d] per block. + int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br; + if (load_gmem_Q_Br >= QKV_seqlen) return; + // KV tile gmem load index starts from 0 and increments with + // each iteration as we loop over seqlen. + int load_gmem_K_Bc_offset = 0; + int load_gmem_V_Bc_offset = 0; + + // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 + extern __shared__ half smem[]; + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + // constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + // constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M + constexpr int KV_tile_size = ( + ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? + ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) + ); + // K multi-stages: currently, only apply multi stages for K across seq_len. + half* Q_tile_smem = smem; // 8M/16M + half* K_tile_smem = Q_tile_smem; // QKV shared the same smem + half* V_tile_smem = Q_tile_smem; // QKV shared the same smem + // NOTE: KV may shared same smem to reduce smem usage for kStage 1 + // stage 1, w shared KV smem, Br=Bc=64, d=64: 8M=8M, +Pad(1M) = 9M + // stage 1, w shared KV smem, Br=Bc=64, d=128: 16M=16M, +Pad(1M) = 17M + // stage 1, w shared KV smem, Br=Bc=128, d=64: 32M=32M, +Pad(2M) = 35M + + uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); + uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); + uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem); + + // --------------------- Registers/SMEM for thread block ------------------------- + // block m_old, l_old, store in lane, use float to keep precision. + float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2] + float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_block_row_max_old, -INFINITY); + fill_2D_regs(lane_block_row_sum_old, 0.0f); + + // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- + // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. + // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. + // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + static_assert(kHeadDim <= 128, "shared_qkv only support headdim<=128"); + static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32"); + constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true. + // Use kStage and (Br / Bc) to control multi-stage policy for K g2s. + constexpr bool kCanPrefetchKg2s = ( + ((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true. + constexpr int kPrefetchStageKg2s = kCanPrefetchKg2s ? 2 : 1; // only apply stage 2 for k prefetch. + constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; + uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4] + uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] + uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2] + // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] + // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. + uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2] + // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. + // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ. + uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. + uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + fill_3D_regs(R_S, 0); + fill_3D_regs(R_D, 0); + fill_3D_regs(R_O, 0); + + // load Q from gmem -> smem, only load once. + { + int load_gmem_Q_d = load_smem_Q_d; + int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + ( + load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // : for K^T[d,seqlen] with K^T_tile[d,Bc] + // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] + #pragma unroll 1 + for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { + // TODO: process last tile_K_seqlen ? pad to multiple of 8. + // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s; + // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s; + + // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. + // NOTE: we only need to load Q once from smem -> regs, and then reuse it. + static_assert(kCanPrefetchQs2r); // always prefetch Q s2r. + if (tile_K_seqlen == 0) { + // TODO: Full share QKV smem after Q is ready load to regs. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs. + // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4] + } + } + } + + // Load K tile from gmem -> smem + if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if (tile_K_seqlen == 0) { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + } else { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] + // Matmul with NN layout, Q row major, K row major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + fill_3D_regs(R_S, 0); + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. + // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) + int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); + int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + uint32_t lane_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * KV_tile_size + + lane_smem_K_d * (Bc + kPad) + + lane_smem_K_Bc) * sizeof(half) + ); + LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + } // end for kWarpTileSeqLenK + + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } // end loop over d, S=Q@K^T + __syncthreads(); + + // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages), + // before rowmax and rowsum, load V from gmem -> smem. + // TODO: Can we support stages 2 for V g2s? + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (smem_sel * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel_next * KV_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } + + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + + // Online safe softmax, warp/block reduce max/sum, row wise + float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2] + float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_row_max_new, -INFINITY); + fill_2D_regs(lane_row_sum_new, 0.0f); + + // Row max for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc. + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for C. (m16n8k16) + // Row\Col 0 1 2 3 4 5 6 7 + // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1} + // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1} + // 2 ... + // ... + // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1} + // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3} + // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3} + // 10 ... + // ... + // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3} + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // This should be the row max after S = (Q @ K^T) / sqrt(d) + float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale; + float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale; + lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0); + lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce max, warp_size = 4 + // Each thread contains the maximum of 2 rows of Br, + // and only the values of T0, T4, ..., T28 are used. + lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); + lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); + } // end for kWarpTileSeqLenQ + // __syncthreads(); + + // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Use latest global row max without update. + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; + float block_row_max_new_0 = lane_row_max_new[i][0]; + // Br 1, row_id, 8~15, 24~31, 40~47, 56~63; + float block_row_max_new_1 = lane_row_max_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // Apply m_new = max(m_old, m_new) here. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z; + t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0)); + t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0)); + t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1)); + t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1)); + lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y); + lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y); + // Update R_S for P[Br,Bc] = Exp(S-m), point wise. + HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0); + HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce sum, warp_size = 4 + lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]); + lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]); + } // end for kWarpTileSeqLenQ + + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. + // Here, we have to wait V ready before compute O = P @ V + if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if ((tile_K_seqlen + 1) < Tc) { + CP_ASYNC_WAIT_GROUP(1); + } else { + CP_ASYNC_WAIT_GROUP(0); + } + } else { + CP_ASYNC_WAIT_GROUP(0); + } + __syncthreads(); + + // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention. + // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major. + // Make sure to clear the states in R_O before MMA for P@V for each step. + + // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these + // registers for P(A) matrix directly ? How to do that ? + // according to the A matrix layout for MMA m16n8k16 instruction. + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for A matrix with .f16. + // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5} + // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5} + // 2 (dashed arrow pointing right) + // ... + // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5} + // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7} + // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7} + // 10 (dashed arrow pointing right) + // ... + // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + + fill_3D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans. + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N + int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K + int lane_smem_V_d = warp_smem_V_d; // 0 + uint32_t lane_smem_V_ptr = ( + smem_V_base_ptr + (smem_sel * KV_tile_size + + lane_smem_V_Bc * (kHeadDim + kPad) + + lane_smem_V_d) * sizeof(half) + ); + LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V + } + + // For R_S[1][8][2], mapping the layout below of P matrix. + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs. + // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs. + // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs. + // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs. + int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6 + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + HMMA16816(R_O[i][j][0], R_O[i][j][1], + R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1], + R_V[j][0], R_V[j][1], + R_O[i][j][0], R_O[i][j][1]); + } + } + } // end for V Bc. + __syncthreads(); + + // Rescale O -> Update row sum Exp -> then, Update row max. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1 + // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper) + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63 + float block_row_max_new_0 = lane_row_max_new[i][0]; + float block_row_max_new_1 = lane_row_max_new[i][1]; + float block_row_sum_new_0 = lane_row_sum_new[i][0]; + float block_row_sum_new_1 = lane_row_sum_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // NOTE: max(-inf, val) = val. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + // Avoid inf value while using m_old for rescaling O. + block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 : + block_row_max_new_0); + block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 : + block_row_max_new_1); + + // rescale factor for O and l, exp(m_old - m) + float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0); + float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1); + // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old. + // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3} + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + // Note that the formula in the FA2 paper is incorrect; here, + // the inverse of the exp function should not be taken, as it + // would result in an error during rescaling, namely, you have + // use exp(m_old - m_new), not 1/(m_old - m_new). + // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V + t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x); + t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y); + t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x); + t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y); + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } // end for kWarpTileHeadDimV. + + // Now, we can update m, l after O has been scaled. + // 1. First, update block row sum Exp for each lane which + // need both m_new and m_old. + float block_row_sum_old_0 = lane_block_row_sum_old[i][0]; + float block_row_sum_old_1 = lane_block_row_sum_old[i][1]; + // Update l = exp(m_old - m_new) * l_old + row_sum(P). + lane_block_row_sum_old[i][0] = (__fmaf_rn( + rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0)); + lane_block_row_sum_old[i][1] = (__fmaf_rn( + rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1)); + // 2. Then, update block row max for each lane. + lane_block_row_max_old[i][0] = block_row_max_new_0; + lane_block_row_max_old[i][1] = block_row_max_new_1; + } + + if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) { + if ((tile_K_seqlen + 1) < Tc) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + } + + } // end loop over N + __syncthreads(); + + // Finaly, we still have to rescale O once more. + // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + // NOTE: Here, we choose to reuse R_O as final output + // in order to reduce regs usage. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); + float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]); + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x; + t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y; + t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x; + t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } + } + + // Store O(D): Write O[Br,d] from regs -> gmem, collective store + // with reg reuse & warp shuffle. need R_Z[2][4]. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8 + + if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { // always true for shared qkv kernel + // reuse R_Q[4/8][1][4] for collective store. + R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4 + R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]); + } + } else { + // we have to use new R_Z regs for collective store. + uint32_t R_Z[2][4]; + R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4 + R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]); + } + } // end if kCanPrefetchQs2r + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ +} + +// Launch kernel for flash_attn_mma_stages_split_q +template +void launch_flash_attn_mma_stages_split_q_shared_qkv( + torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + // Tile BrxBc=128x64 + constexpr int kMmaAtomM = 16; + constexpr int kMmaAtomN = 8; + constexpr int kMmaAtomK = 16; + // constexpr int kMmaTileSeqLenQ = 4; + constexpr int kMmaTileSeqLenQ = 8; + constexpr int kMmaTileSeqLenK = 1; + // constexpr int kMmaTileSeqLenP = 4; + constexpr int kMmaTileSeqLenP = 8; + constexpr int kMmaTileHeadDimV = 1; + constexpr int kWarpTileSeqLenQ = 1; + constexpr int kWarpTileSeqLenK = 8; + // constexpr int kWarpTileSeqLenK = 16; + constexpr int kWarpTileSeqLenP = 1; + constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,.... + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + constexpr int kPad = 8; + if constexpr (kStage > 1) { + static_assert(((Br / Bc) >= 2)); + } + + // static int kMaxSramPerBlock; + // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); + + // Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem. + const int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M + + const int QKV_batch = Q.size(0); + const int QKV_head = Q.size(1); + const int QKV_seqlen = Q.size(2); // QKV_seqlen + assert(QKV_seqlen % Bc == 0); // multiple of Bc=64 + + dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br) + dim3 block(kNumThreads); // 4/8 warps per block + + cudaFuncSetAttribute( + flash_attn_mma_stages_split_q_shared_qkv_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + >, + cudaFuncAttributeMaxDynamicSharedMemorySize, + // kMaxSramPerBlock + 98304 + ); + + flash_attn_mma_stages_split_q_shared_qkv_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + ><<>>( + reinterpret_cast(Q.data_ptr()), + reinterpret_cast(K.data_ptr()), + reinterpret_cast(V.data_ptr()), + reinterpret_cast(O.data_ptr()), + QKV_seqlen + ); +} + +void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages) { + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + const int d = Q.size(3); // B, H, N, d + + if (stages > 1) { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_shared_qkv<32, 2>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_shared_qkv<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_shared_qkv<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_shared_qkv<128, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } else { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_shared_qkv<32, 1>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_shared_qkv<64, 1>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_shared_qkv<96, 1>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_shared_qkv<128, 1>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } +} diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu index 1d37916c..6fbebb8f 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu @@ -816,6 +816,9 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q, if (stages == 2) { switch (d) { + case 32: + launch_flash_attn_mma_stages_split_kv<32, 2>(Q, K, V, O); + break; case 64: launch_flash_attn_mma_stages_split_kv<64, 2>(Q, K, V, O); break; @@ -832,6 +835,9 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q, } else { switch (d) { + case 32: + launch_flash_attn_mma_stages_split_kv<32, 1>(Q, K, V, O); + break; case 64: launch_flash_attn_mma_stages_split_kv<64, 1>(Q, K, V, O); break; diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu index ede19899..a5f96c23 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu @@ -733,6 +733,9 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q, if (stages == 2) { switch (d) { + case 32: + launch_flash_attn_mma_stages_split_q<32, 2>(Q, K, V, O); + break; case 64: launch_flash_attn_mma_stages_split_q<64, 2>(Q, K, V, O); break; @@ -749,6 +752,9 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q, } else { switch (d) { + case 32: + launch_flash_attn_mma_stages_split_q<32, 1>(Q, K, V, O); + break; case 64: launch_flash_attn_mma_stages_split_q<64, 1>(Q, K, V, O); break; diff --git a/kernels/flash-attn/pybind/flash_attn.cc b/kernels/flash-attn/pybind/flash_attn.cc index a7343f21..cf5e030b 100644 --- a/kernels/flash-attn/pybind/flash_attn.cc +++ b/kernels/flash-attn/pybind/flash_attn.cc @@ -5,10 +5,33 @@ #define TORCH_BINDING_COMMON_EXTENSION(func) \ m.def(STRINGFY(func), &func, STRINGFY(func)); -void flash_attn_mma_stages_split_kv(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages); -void flash_attn_mma_stages_split_q(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages); +void flash_attn_mma_stages_split_kv(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages); + +void flash_attn_mma_stages_split_q(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages); + +void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages); + +void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_kv) TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_kv) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_qkv) }