From 0c6785f0d023fe85fae618ad80359d340e334644 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Tue, 17 Dec 2024 11:40:37 +0800
Subject: [PATCH] =?UTF-8?q?[FA2]=20Release=20flash-attn-mma=20shared-kv/qk?=
=?UTF-8?q?v=F0=9F=8E=89=20(#162)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update flash_attn_mma_share_kv.cu
* Create flash_attn_mma_share_qkv.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_split_kv.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update flash_attn_mma_share_qkv.cu
---
README.md | 101 ++-
kernels/flash-attn/README.md | 196 +++--
kernels/flash-attn/flash_attn_mma.py | 31 +-
.../flash-attn/mma/flash_attn_mma_share_kv.cu | 414 +++++----
.../mma/flash_attn_mma_share_qkv.cu | 814 ++++++++++++++++++
.../flash-attn/mma/flash_attn_mma_split_kv.cu | 6 +
.../flash-attn/mma/flash_attn_mma_split_q.cu | 6 +
kernels/flash-attn/pybind/flash_attn.cc | 27 +-
8 files changed, 1280 insertions(+), 315 deletions(-)
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
diff --git a/README.md b/README.md
index d0dd4f2c..3d7a7972 100644
--- a/README.md
+++ b/README.md
@@ -42,24 +42,49 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
-I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Fully Sahred QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop.
![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)
+- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+```bash
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
+------------------------------------------------------------------------------------------------------------------------
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
+------------------------------------------------------------------------------------------------------------------------
+ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
+ mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
+ mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
+ mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
+ mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
+ mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
+ mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
+ (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
+-----------------------------------------------------------------------------------------------------------------------
+```
+
+However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
-|CUDA Cores|Sliced K (Loop over N/D)|Tile Block (Br, Bc, Bd)|MMA (m16n8k16)|
+|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
-|Pack LDST (128 bits)|SMEM Padding|Copy Async |Tile MMAs (More Threads)
+|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)
|✔️|✔️|✔️|✔️|
-|Tile Warps (More Values)|Multi Stages (1/2)| Collective Store (Shfl)| **Split KV/Q** |
+|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
+|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|✔️|✔️|✔️|?|
-The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
+The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
-![flash-attn](https://github.com/user-attachments/assets/11490fbc-2a4a-4630-abe8-91a9d1251cba)
+
- 📚 Split KV (Basic, FlashAttention-1)
+
```C++
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
@@ -69,22 +94,6 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
-template<
- const int kHeadDim, // Headdim, 32,64,128
- const int kMmaAtomM, // MMA Atom M, 16
- const int kMmaAtomN, // MMA Atom N, 8
- const int kMmaAtomK, // MMA Atom K, 16
- const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M
- const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N
- const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M
- const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|...
- const int kStage, // only support 1 or 2 now.
- const int kPad // 0,8
- >
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
@@ -94,6 +103,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
```
- 📚 Split Q (Faster, FlashAttention-2)
+
```C++
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
@@ -104,22 +114,6 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
-template<
- const int kHeadDim, // Headdim, 32,64,128
- const int kMmaAtomM, // MMA Atom M, 16
- const int kMmaAtomN, // MMA Atom N, 8
- const int kMmaAtomK, // MMA Atom K, 16
- const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
- const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
- const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
- const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
- const int kStage, // only support 1 or 2 now.
- const int kPad // 0,8
- >
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
@@ -127,6 +121,33 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
```
+
+- 📚 Split Q + Shared KV SMEM (Faster+)
+
+
+```C++
+// K, V shared the same shared memory, improve block occupancy.
+__global__ void
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen);
+```
+- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+
+
+
+```C++
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
+__global__ void
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen);
+```
+
## ©️Citations🎉🎉
```BibTeX
@@ -144,11 +165,13 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-|📖 CUDA Kernel| 📖 Elem dtype| 📖 Acc dtype| 📖 Docs | 📖 Level |
+|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
|:---|:---|:---|:---|:---|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
| ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️|
| ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index dcc4726d..5717d994 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -2,33 +2,56 @@
![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)
-|CUDA Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc, Bd)|MMA (m16n8k16)|
+|Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
+|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
+|✔️|✔️|✔️|?|
-This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention (SeqLen <= 4096), the flash-attention-mma implemented in this repository matches the performance of the official FA. However, for large-scale attention computations, there remains a significant performance gap. Performance optimizations are ongoing; stay tuned for updates.
+This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
+
+- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
+```bash
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
+------------------------------------------------------------------------------------------------------------------------
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
+------------------------------------------------------------------------------------------------------------------------
+ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
+ mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
+ mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
+ mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
+ mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
+ mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
+ mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
+ (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
+-----------------------------------------------------------------------------------------------------------------------
+```
## 📖 Contents
-- [📖 Split KV](#mma-split-kv)
-- [📖 Split Q](#mma)
+- [📖 FlashAttetion MMA Kernels](#mma)
+ - [📚 Split KV](#mma-split-kv)
+ - [📚 Split Q ](#mma-split-q)
+ - [📚 Shared KV SMEM](#mma-share-kv)
+ - [📚 Fully Shared QKV SMEM](#mma-share-qkv)
- [📖 Prerequisites](#prerequisites)
- [📖 Installation](#install)
- [📖 Performance](#perf)
- [📖 Python Testing](#test)
-
+
## 📖 FlashAttetion MMA Kernels
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](.) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps) using a naive matmul (MMA) and Warp tiling policy, is slower compared to the `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
-
+
+- 📚 Split KV (Basic, FlashAttention-1)
```C++
@@ -39,22 +62,6 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
-template<
- const int kHeadDim, // Headdim, 32,64,128
- const int kMmaAtomM, // MMA Atom M, 16
- const int kMmaAtomN, // MMA Atom N, 8
- const int kMmaAtomK, // MMA Atom K, 16
- const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M
- const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N
- const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M
- const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|...
- const int kStage, // only support 1 or 2 now.
- const int kPad // 0,8
- >
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
@@ -63,7 +70,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
int QKV_seqlen);
```
-## 📖 Split Q (Faster, FlashAttention-2)
+- 📚 Split Q (Faster, FlashAttention-2)
```C++
@@ -75,22 +82,6 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
-template<
- const int kHeadDim, // Headdim, 32,64,128
- const int kMmaAtomM, // MMA Atom M, 16
- const int kMmaAtomN, // MMA Atom N, 8
- const int kMmaAtomK, // MMA Atom K, 16
- const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
- const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
- const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
- const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
- const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
- const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
- const int kStage, // only support 1 or 2 now.
- const int kPad // 0,8
- >
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
half* K, // [B, H, D, N] K^T transposed
@@ -99,6 +90,31 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
int QKV_seqlen);
```
+- 📚 Split Q + Shared KV SMEM (Faster+)
+
+
+```C++
+// K, V shared the same shared memory, improve block occupancy.
+__global__ void
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen);
+```
+- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+
+
+
+```C++
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy.
+__global__ void
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen);
+```
## 📖 Prerequisites
@@ -117,7 +133,7 @@ pip install flash-attn --no-build-isolation # need offical flash-attention for c
## 📖 Performance
-Currently, for small-scale attention (SeqLen <= 4096), the flash-attention-mma implemented in this repository matches the performance of the official FA version. However, for large-scale attention computations, there remains a significant performance gap. Performance optimizations are ongoing; stay tuned for updates.
+Currently, for small-scale attention (B<=4, H <=48, SeqLen <= 8192), the flash-attention-mma implemented in this repository matches the performance of the official FA version. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
## 📖 Python Testing
@@ -134,61 +150,91 @@ python3 flash_attn_mma.py --D 64 # test all default settings for D=64
- B=2, H=2, N=4096, D=64
```bash
-python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 4096 # NVIDIA RTX 3080 Laptop
+python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 4096 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
- B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 6827, Warmup: 2, Iters: 10
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 9655, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
- B=2, H=2, N=4096, D=64, Warmup: 2, Iters: 10
- mma(split-kv+stage1): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.767565ms, TFLOPS:22.82
- mma(split-kv+stage2): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.730205ms, TFLOPS:23.99
- mma(split-q+stage1): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.524163ms, TFLOPS:33.41
- mma(split-q+stage2): ['-0.02688599 ', '0.03140259 ', '-0.03656006 '], time:0.622582ms, TFLOPS:28.13
- (flash): ['-0.02687073 ', '0.03143311 ', '-0.03656006 '], time:0.610447ms, TFLOPS:28.69
+ B=2, H=2, N=4096, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.765753ms, TFLOPS:22.87
+ mma(split-kv+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.731516ms, TFLOPS:23.94
+ mma(split-q+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.526834ms, TFLOPS:33.24
+ mma(split-q+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.660753ms, TFLOPS:26.51
+ mma(split-q+share-kv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.460815ms, TFLOPS:38.01
+ mma(split-q+share-qkv+stage1): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.465345ms, TFLOPS:37.64
+ mma(split-q+share-qkv+stage2): ['0.01901245 ', '-0.02037048 ', '-0.01722717 '], time:0.474334ms, TFLOPS:36.92
+ (flash): ['0.01904297 ', '-0.02037048 ', '-0.01724243 '], time:0.596189ms, TFLOPS:29.38
------------------------------------------------------------------------------------------------------------------------
```
- B=2, H=2, N=8192, D=64
```bash
-python3 flash_attn_mma.py --B 2 --H 2 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop
+ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
- B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1923, Warmup: 2, Iters: 10
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 5669, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
- B=2, H=2, N=8192, D=64, Warmup: 2, Iters: 10
- mma(split-kv+stage1): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.870488ms, TFLOPS:24.41
- mma(split-kv+stage2): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.599239ms, TFLOPS:26.95
- mma(split-q+stage1): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:1.826215ms, TFLOPS:38.36
- mma(split-q+stage2): ['-0.01074219 ', '-0.00759125 ', '0.02301025 '], time:2.142096ms, TFLOPS:32.71
- (flash): ['-0.01076508 ', '-0.0075798 ', '0.02301025 '], time:2.061176ms, TFLOPS:33.99
+ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:5.572367ms, TFLOPS:25.15
+ mma(split-kv+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:5.295920ms, TFLOPS:26.46
+ mma(split-q+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:3.607082ms, TFLOPS:38.85
+ mma(split-q+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:4.600883ms, TFLOPS:30.45
+ mma(split-q+share-kv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.744508ms, TFLOPS:51.05
+ mma(split-q+share-qkv+stage1): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.700114ms, TFLOPS:51.89
+ mma(split-q+share-qkv+stage2): ['-0.0087738 ', '0.012146 ', '-0.01319122 '], time:2.692103ms, TFLOPS:52.05
+ (flash): ['-0.00882721 ', '0.01213074 ', '-0.01314545 '], time:3.778219ms, TFLOPS:37.09
------------------------------------------------------------------------------------------------------------------------
```
- B=1, H=8, N=8192, D=64
```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop
-------------------------------------------------------------------------------------------------------------------------
- B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4374, Warmup: 2, Iters: 10
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
- B=1, H=8, N=8192, D=64, Warmup: 2, Iters: 10
- mma(split-kv+stage1): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:5.583835ms, TFLOPS:25.09
- mma(split-kv+stage2): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:5.325174ms, TFLOPS:26.31
- mma(split-q+stage1): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:3.675842ms, TFLOPS:38.12
- mma(split-q+stage2): ['-0.01475525 ', '-0.01394653 ', '-0.02441406 '], time:4.370213ms, TFLOPS:32.06
- (flash): ['-0.01470184 ', '-0.01394653 ', '-0.02435303 '], time:3.680992ms, TFLOPS:38.07
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
+ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
+ mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
+ mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
+ mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
+ mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
+ mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
+ mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
+ (flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
+-----------------------------------------------------------------------------------------------------------------------
```
- B=1, H=48, N=8192, D=64
```bash
-python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 # NVIDIA RTX 3080 Laptop
+python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
+------------------------------------------------------------------------------------------------------------------------
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4669, Warmup: 1, Iters: 10
+------------------------------------------------------------------------------------------------------------------------
+ B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:42.534423ms, TFLOPS:19.77
+ mma(split-kv+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:42.349815ms, TFLOPS:19.85
+ mma(split-q+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:35.657477ms, TFLOPS:23.58
+ mma(split-q+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:36.065412ms, TFLOPS:23.31
+ mma(split-q+share-kv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.619652ms, TFLOPS:35.59
+ mma(split-q+share-qkv+stage1): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.893070ms, TFLOPS:35.19
+ mma(split-q+share-qkv+stage2): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:23.590446ms, TFLOPS:35.64
+ (flash): ['-0.01280212 ', '-0.02825928 ', '0.0146637 '], time:22.385812ms, TFLOPS:37.56
+------------------------------------------------------------------------------------------------------------------------
+```
+
+- B=1, H=8, N=8192, D=32
+```bash
+python3 flash_attn_mma.py --B 1 --H 8 --D 32 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
- B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 8331, Warmup: 2, Iters: 10
+ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 2322, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
- B=1, H=48, N=8192, D=64, Warmup: 2, Iters: 10
- mma(split-kv+stage1): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:42.588711ms, TFLOPS:19.74
- mma(split-kv+stage2): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:42.275143ms, TFLOPS:19.89
- mma(split-q+stage1): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:37.420964ms, TFLOPS:22.47
- mma(split-q+stage2): ['-0.01500702 ', '0.00946045 ', '0.03683472 '], time:37.678123ms, TFLOPS:22.31
- (flash): ['-0.0150528 ', '0.00946045 ', '0.0368042 '], time:22.342849ms, TFLOPS:37.63
+ B=1, H=8, N=8192, D=32, Warmup: 1, Iters: 10
+ mma(split-kv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:3.930807ms, TFLOPS:18.16
+ mma(split-kv+stage2): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:3.901839ms, TFLOPS:18.30
+ mma(split-q+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.839685ms, TFLOPS:38.81
+ mma(split-q+stage2): ['-0.00607681 ', '-0.00229454 ', '0.02029419 '], time:1.511669ms, TFLOPS:47.23
+ mma(split-q+share-kv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.400948ms, TFLOPS:50.97
+ mma(split-q+share-qkv+stage1): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.393318ms, TFLOPS:51.25
+ mma(split-q+share-qkv+stage2): ['-0.00616074 ', '-0.00230789 ', '0.02029419 '], time:1.322961ms, TFLOPS:53.97
+ (flash): ['-0.00617599 ', '-0.00231934 ', '0.02029419 '], time:1.810646ms, TFLOPS:39.43
------------------------------------------------------------------------------------------------------------------------
```
diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py
index 9a1d1f40..9ae87ddb 100644
--- a/kernels/flash-attn/flash_attn_mma.py
+++ b/kernels/flash-attn/flash_attn_mma.py
@@ -47,8 +47,8 @@ def get_args():
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--verbose", '--v', action="store_true")
- parser.add_argument("--warmup", type=int, default=2)
- parser.add_argument("--iters", type=int, default=10)
+ parser.add_argument("--warmup", type=int, default=1)
+ parser.add_argument("--iters", type=int, default=5)
parser.add_argument("--range-k", '--gk', action="store_true")
return parser.parse_args()
@@ -62,6 +62,8 @@ def get_args():
sources=[
'./mma/flash_attn_mma_split_kv.cu',
'./mma/flash_attn_mma_split_q.cu',
+ './mma/flash_attn_mma_share_kv.cu',
+ './mma/flash_attn_mma_share_qkv.cu',
'./pybind/flash_attn.cc'
],
extra_cuda_cflags=[
@@ -94,7 +96,7 @@ def get_mha_tflops(B, H, N, D, T=1.0):
flops_subtract_max = B * H * N * N # sub max
flops_exp = B * H * N * N # pointwise exp
flops_row_sum = B * H * N * (N - 1) # row sum
- flops_normalization = B * H * N * N # 归一化
+ flops_normalization = B * H * N * N # normalization
flops_safe_softmax = flops_row_max + flops_subtract_max + flops_exp + flops_row_sum + flops_normalization
@@ -118,7 +120,7 @@ def run_benchmark(perf_func: callable,
v: torch.Tensor,
tag: str,
out: Optional[torch.Tensor] = None,
- s: Optional[torch.Tensor] = None, # BUDEG
+ s: Optional[torch.Tensor] = None, # DEBUG
stages: int = -1,
warmup: int = args.warmup,
iters: int = args.iters,
@@ -173,7 +175,7 @@ def run_benchmark(perf_func: callable,
out_val = out_val_first[:2]
out_val.append(out_val_last[-1])
out_val = [f"{v:<12}" for v in out_val]
- print(f"{out_info:>25}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
+ print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
if show_all:
print(out)
time.sleep(0.05)
@@ -234,7 +236,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
Bs = [1, 2, 4] if not args.B else [args.B]
Hs = [1, 4, 8] if not args.H else [args.H]
-Ns = [1024, 2048] if not args.N else [args.N]
+Ns = [1024, 2048, 4096] if not args.N else [args.N]
Ds = [64, 128] if not args.D else [args.D]
# batch_size, n_head, seq_len, head_dim (B,H,N,D)
BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds]
@@ -252,14 +254,17 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
torch.cuda.synchronize()
if args.run_torch_unfused:
- out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
- out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
- out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
- out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
- out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
- out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
+ out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
+ out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
+ out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
+ out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
+ out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
+ out_mma_share_kv, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
+ out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
+ out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+ out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
- out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
+ out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
print("-" * 120)
torch.cuda.synchronize()
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
index b8e32226..5ad6bfa7 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
@@ -31,6 +31,17 @@
// | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps
+// | 128x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x8) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x8) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |
+
template<
const int kHeadDim, // Headdim, 32,64,128
const int kMmaAtomM, // MMA Atom M, 16
@@ -56,19 +67,20 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
int QKV_seqlen) {
// Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
- static_assert(kMmaTileSeqLenQ == 4 && kMmaTileSeqLenK == 1); // Q@K^T
- static_assert(kMmaTileSeqLenP == 4 && kMmaTileHeadDimV == 1); // P@V
- static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK == 8); // Q@K^T
+ static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
+ static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
+ static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
// kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc.
// e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
- static_assert(kStage > 0 && kStage < 3); // 1,2
+ // TODO: support stages for shared kv smem kernel.
+ static_assert(kStage < 3 && kStage > 0);
static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
+ static_assert(Br >= Bc); // for shared memory reuse.
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
- constexpr int kNumStoreBatchO = 2; // batch size for O collective store.
// Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen.
const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc]
const float scale = 1.0f / sqrt((float) kHeadDim);
@@ -126,19 +138,18 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1
extern __shared__ half smem[];
constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
- constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
- constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
+ constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
+ constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
// K multi-stages: currently, only apply multi stages for K across seq_len.
half* Q_tile_smem = smem; // 8M/16M
half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M
- half* V_tile_smem = K_tile_smem + kStage * K_tile_size;
- // TODO: KV may shared same smem to reduce smem usage for kStage 1
- // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M
- // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M =64M
- // stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M =128M
- // stage 1, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)+8M =24M
- // stage 1, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*1+16M =48M
- // stage 1, no shared KV smem, Br=Bc=32, d=256: 16M+(16M)*1+16M =48M
+ half* V_tile_smem = K_tile_smem; // KV shared the same smem
+ // NOTE: KV may shared same smem to reduce smem usage for kStage 1
+ // stage 1, w shared KV smem, Br=Bc=64, d=64: 8M+(8M) =16M, +Pad(2M) = 18M
+ // stage 1, w shared KV smem, Br=Bc=128, d=64: 16M+16M =32M, +Pad(2M) = 34M
+ // stage 1, w shared KV smem, Br=Bc=64, d=128: 16M+16M =32M, +Pad(4M) = 36M
+ // stage 1, w shared KV smem, Br=Bc=128, d=128: 32M+32M =64M, +Pad(4M) = 68M
+ // stage 1, w shared KV smem, Br=Bc=32, d=256: 16M+16M =32M, +Pad(1M) = 34M
uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem);
uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
@@ -153,7 +164,16 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// ---------------------- Registers for S=Q@K^T/O=P@V ----------------------------
// registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d].
- uint32_t R_Q[kWarpTileSeqLenQ][ 4]; // [1][4]
+ // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it for
+ // processes. This will reduce large io-access for Q smem while N is large.
+ // constexpr bool kCanPrefetchQs2r = false; // d <= 128
+ // FIXME(DefTruth): why can not get good performance for headdim >= 64 ?
+ // Will enable it untill I have figure out the performance issues.
+ constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64);
+ constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
+ uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2]
// registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc]
@@ -164,12 +184,9 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
// registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
- // Helper regs for collective store, [2][4]. may use smem to reduce regs?
- uint32_t R_Z[kNumStoreBatchO][4]; // [2][4]
fill_3D_regs(R_S, 0);
fill_3D_regs(R_D, 0);
fill_3D_regs(R_O, 0);
- fill_2D_regs(R_Z, 0);
// load Q from gmem -> smem, only load once.
{
@@ -184,33 +201,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
CP_ASYNC_COMMIT_GROUP();
}
- // load K from gmem -> smem, (kStage - 1) K^T tiles, [d,Bc]
- if constexpr (kStage > 1) {
- #pragma unroll
- for (int stage = 0; stage < (kStage - 1); ++stage) {
- // update the offset of n according to stages
- load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
- uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (stage * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
- #pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
- CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
- }
- CP_ASYNC_COMMIT_GROUP();
- }
- }
-
- // wait Q and at least (kStage - 1) for K ready.
- if constexpr (kStage > 1) {
- CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2
- __syncthreads();
- }
-
// : for K^T[d,seqlen] with K^T_tile[d,Bc]
// tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
#pragma unroll 1
@@ -226,87 +216,61 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// and smem_sel_next will always equal 0, thus, we can not
// prefetch KV from gmem to smem before tile_K_seqlen MMA done.
- if constexpr (kStage > 1) {
- // First, prefetch curr V tile_K_seqlen [Bc,d] (no stages)
- {
- load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
- int load_gmem_V_d = load_smem_V_d;
- int load_gmem_V_addr = (
- V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
- uint32_t load_smem_V_ptr = (
- smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) +
- load_smem_V_d) * sizeof(half)
- );
- #pragma unroll
- for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
- CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
- }
- CP_ASYNC_COMMIT_GROUP();
+ // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
+ if (tile_K_seqlen == 0) {
+ load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * K_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
+ CP_ASYNC_COMMIT_GROUP();
+ }
- // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc]
- if ((tile_K_seqlen + 1) < Tc) {
- load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
- int load_gmem_K_addr = (
- K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
- uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel_next * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half)
- );
- #pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
- CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
- }
- CP_ASYNC_COMMIT_GROUP();
- }
- } else {
- // If no stages, kStage = 1, we have to load current K tile
- // from gmem to smem and have to wait it ready for Q@K^T MMA.
-
- // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
- {
- load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
- uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
- #pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
- CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
- }
- CP_ASYNC_COMMIT_GROUP();
- }
+ if constexpr (kCanPrefetchQs2r) {
+ // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+ // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+ if (tile_K_seqlen == 0) {
+ // TODO: Full share QKV smem after Q is ready load to regs.
+ CP_ASYNC_WAIT_GROUP(1);
+ __syncthreads();
- // Then, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
- {
- load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
- int load_gmem_V_d = load_smem_V_d;
- int load_gmem_V_addr = (
- V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
- uint32_t load_smem_V_ptr = (
- smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) +
- load_smem_V_d) * sizeof(half)
- );
#pragma unroll
- for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
- CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it for
+ // processes. This will reduce large io-access for Q smem while N is large.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+ }
}
- CP_ASYNC_COMMIT_GROUP();
- }
-
- // Wait curr Q and K tile ready and let curr V tile copy async.
- CP_ASYNC_WAIT_GROUP(1);
+ } // end if tile_K_seqlen == 0
+ // Now, we have to wait curr K tile ready for Q@K^T MMA.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ } else {
+ // Wait curr Q and K tile ready.
+ CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
}
-
+
// : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
// Matmul with NN layout, Q row major, K row major.
// S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
@@ -315,17 +279,19 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
// smem -> reg, load m16k16 smem Q, offset d according tile_K_d.
// ldmatrix.x4 for Q_tile_smem.
- #pragma unroll
- for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
- int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
- int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
- int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
- uint32_t lane_smem_Q_ptr = (
- smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
- lane_smem_Q_d) * sizeof(half)
- );
- LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3],
- lane_smem_Q_ptr); // now, R_Q
+ if constexpr (!kCanPrefetchQs2r) {
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1][1][4]
+ }
}
// smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
@@ -342,21 +308,55 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
);
LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
} // end for kWarpTileSeqLenK
-
- // MMA compute
- #pragma unroll
- for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+
+ if constexpr (kCanPrefetchQs2r) {
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
+ }
+ } else {
+ // MMA compute
#pragma unroll
- for (int j = 0; j < kWarpTileSeqLenK; ++j) {
- HMMA16816(R_S[i][j][0], R_S[i][j][1],
- R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3],
- R_K[j][0], R_K[j][1],
- R_S[i][j][0], R_S[i][j][1]);
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
}
}
} // end loop over d, S=Q@K^T
__syncthreads();
+ // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages),
+ // before rowmax and rowsum, load V from gmem -> smem.
+ {
+ load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) +
+ load_smem_V_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// | 64x64 | warp_KV 0 |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
@@ -402,9 +402,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// Warp level reduce max, warp_size = 4
// Each thread contains the maximum of 2 rows of Br,
// and only the values of T0, T4, ..., T28 are used.
- // Br, row_id = warp_QP<0~3> * 32 + i<0> * 16 + 0 * 8 + (lane / 4) <0~7>
lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]);
- // Br, row_id = warp_QP<0~3> * 32 + i<0> * 16 + 1 * 8 + (lane / 4) <8~15>
lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]);
} // end for kWarpTileSeqLenQ
__syncthreads();
@@ -519,6 +517,23 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
} // end for V Bc.
__syncthreads();
+ // NOTE: Load next K tile async before rescale O
+ if ((tile_K_seqlen + 1) < Tc) {
+ load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * K_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
// Rescale O -> Update row sum Exp -> then, Update row max.
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1
@@ -591,6 +606,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// Finaly, we still have to rescale O once more.
// O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
+ // NOTE: Here, we choose to reuse R_O as final output
+ // in order to reduce regs usage.
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]);
@@ -614,29 +631,57 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
- R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
- R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
- R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
- R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
- R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
- R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
- R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
-
- // st.global.v4 128 bits. [Br,d]
- if (lane_id % 4 == 0) {
- // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
- int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
- int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
- // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
- int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
- int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
- int store_gmem_O_addr_0 = (
- O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
- int store_gmem_O_addr_1 = (
- O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
- LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]);
- LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]);
- }
+
+ if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) {
+ // reuse R_Q[4/8][1][4] for collective store.
+ R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4
+ R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]);
+ }
+ } else {
+ // we have to use new R_Z regs for collective store.
+ uint32_t R_Z[2][4];
+ R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
+ R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]);
+ }
+ } // end if kCanPrefetchQs2r
} // end for kWarpTileHeadDimV
} // end for kWarpTileSeqLenQ
}
@@ -645,26 +690,32 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
template
void launch_flash_attn_mma_stages_split_q_shared_kv(
torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Tile BrxBc=128x64
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
- constexpr int kMmaTileSeqLenQ = 4;
+ // constexpr int kMmaTileSeqLenQ = 4;
+ constexpr int kMmaTileSeqLenQ = 8;
constexpr int kMmaTileSeqLenK = 1;
- constexpr int kMmaTileSeqLenP = 4;
+ // constexpr int kMmaTileSeqLenP = 4;
+ constexpr int kMmaTileSeqLenP = 8;
constexpr int kMmaTileHeadDimV = 1;
constexpr int kWarpTileSeqLenQ = 1;
constexpr int kWarpTileSeqLenK = 8;
+ // constexpr int kWarpTileSeqLenK = 16;
constexpr int kWarpTileSeqLenP = 1;
constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,....
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
constexpr int kPad = 8;
+
+ // static int kMaxSramPerBlock;
+ // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
- // Calculate SRAM size needed per block, Q,K,V smem size
+ // Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem.
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
- (kStage * kHeadDim * (Bc + kPad)) +
- (Bc * (kHeadDim + kPad))) * sizeof(half);
+ (kStage * kHeadDim * (Bc + kPad))) * sizeof(half);
const int QKV_batch = Q.size(0);
const int QKV_head = Q.size(1);
@@ -692,6 +743,7 @@ void launch_flash_attn_mma_stages_split_q_shared_kv(
kPad
>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
+ // kMaxSramPerBlock
98304
);
@@ -730,25 +782,15 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
- if (stages == 2) {
- switch (d)
- {
- case 64:
- launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O);
- break;
- case 96:
- launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O);
- break;
- case 128:
- launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O);
- break;
- default:
- throw std::runtime_error("headdim not support!");
- break;
- }
+ if (stages > 1) {
+ throw std::runtime_error(
+ "split_q_shared_kv not support stages>1 now!");
} else {
switch (d)
{
+ case 32:
+ launch_flash_attn_mma_stages_split_q_shared_kv<32, 1>(Q, K, V, O);
+ break;
case 64:
launch_flash_attn_mma_stages_split_q_shared_kv<64, 1>(Q, K, V, O);
break;
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
new file mode 100644
index 00000000..aad3d5f2
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
@@ -0,0 +1,814 @@
+#include "utils.h"
+
+// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction.
+// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].
+// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].
+
+// The FlashAttention-2 algorithm is described in the following paper:
+// https://arxiv.org/pdf/2307.08691
+
+// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
+// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
+
+// Split Q across MMA(Warps) and keep access KV for all MMA(Warps),
+// in order to reduce the comm between warps via smem and warp shuffle.
+
+// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+// | 64x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+// | 128x128 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps
+// | 128x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x8) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x8) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |
+
+template<
+ const int kHeadDim, // Headdim, 32,64,128
+ const int kMmaAtomM, // MMA Atom M, 16
+ const int kMmaAtomN, // MMA Atom N, 8
+ const int kMmaAtomK, // MMA Atom K, 16
+ const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
+ const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
+ const int kStage,
+ const int kPad
+ >
+__global__ void __launch_bounds__(
+ WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK)
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen) {
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
+ static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
+ static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
+ static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
+ static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
+ // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc.
+ // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
+ static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
+ kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
+ static_assert(kStage > 0 && kStage < 3); // 1 or 2
+ static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
+ static_assert(Br >= Bc); // for shared memory reuse.
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
+ // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen.
+ const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc]
+ const float scale = 1.0f / sqrt((float) kHeadDim);
+
+ // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma)
+ const int QKV_batch_id = blockIdx.x; // Batch size, bx
+ const int QKV_head_id = blockIdx.y; // Head num, by
+ const int Q_tile_id = blockIdx.z; // Q tile_id, range [0, Tr), bz.
+ const int O_tile_id = Q_tile_id; // O tile_id, same as Q.
+ const int tid = threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_QP = warp_id; // 0,1,2,3 or 0~7
+ const int warp_KV = 0; // 0
+ // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+ // | 128x128 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+ // | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+ // | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+ // | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+ // | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+ const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
+ const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) +
+ (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen]
+ const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
+ const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
+
+ // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads.
+ int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
+ int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,...
+ // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads.
+ int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64
+ int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,...
+ // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads.
+ int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
+ // global Q row of current head for tile [Br,d] per block.
+ int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
+ if (load_gmem_Q_Br >= QKV_seqlen) return;
+ // KV tile gmem load index starts from 0 and increments with
+ // each iteration as we loop over seqlen.
+ int load_gmem_K_Bc_offset = 0;
+ int load_gmem_V_Bc_offset = 0;
+
+ // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1
+ extern __shared__ half smem[];
+ constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ // constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
+ // constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
+ constexpr int KV_tile_size = (
+ ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
+ ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
+ );
+ // K multi-stages: currently, only apply multi stages for K across seq_len.
+ half* Q_tile_smem = smem; // 8M/16M
+ half* K_tile_smem = Q_tile_smem; // QKV shared the same smem
+ half* V_tile_smem = Q_tile_smem; // QKV shared the same smem
+ // NOTE: KV may shared same smem to reduce smem usage for kStage 1
+ // stage 1, w shared KV smem, Br=Bc=64, d=64: 8M=8M, +Pad(1M) = 9M
+ // stage 1, w shared KV smem, Br=Bc=64, d=128: 16M=16M, +Pad(1M) = 17M
+ // stage 1, w shared KV smem, Br=Bc=128, d=64: 32M=32M, +Pad(2M) = 35M
+
+ uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem);
+ uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
+ uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
+
+ // --------------------- Registers/SMEM for thread block -------------------------
+ // block m_old, l_old, store in lane, use float to keep precision.
+ float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_block_row_max_old, -INFINITY);
+ fill_2D_regs(lane_block_row_sum_old, 0.0f);
+
+ // ---------------------- Registers for S=Q@K^T/O=P@V ----------------------------
+ // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d].
+ // TODO: Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it for
+ // processes. This will reduce large io-access for Q smem while N is large.
+ static_assert(kHeadDim <= 128, "shared_qkv only support headdim<=128");
+ static_assert(kHeadDim >= 32, "shared_qkv only support headdim>=32");
+ constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8); // always true.
+ // Use kStage and (Br / Bc) to control multi-stage policy for K g2s.
+ constexpr bool kCanPrefetchKg2s = (
+ ((Q_tile_size / KV_tile_size) >= 2) && (kStage >= 2)); // for d<=64 is true.
+ constexpr int kPrefetchStageKg2s = kCanPrefetchKg2s ? 2 : 1; // only apply stage 2 for k prefetch.
+ constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
+ uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
+ uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
+ uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2]
+ // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc]
+ // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs.
+ uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2]
+ // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs.
+ // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ.
+ uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
+ uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ fill_3D_regs(R_S, 0);
+ fill_3D_regs(R_D, 0);
+ fill_3D_regs(R_O, 0);
+
+ // load Q from gmem -> smem, only load once.
+ {
+ int load_gmem_Q_d = load_smem_Q_d;
+ int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d);
+ uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + (
+ load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Br)); i += 8) {
+ CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
+ // : for K^T[d,seqlen] with K^T_tile[d,Bc]
+ // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
+ #pragma unroll 1
+ for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
+ // TODO: process last tile_K_seqlen ? pad to multiple of 8.
+ // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
+ int smem_sel = (tile_K_seqlen) % kPrefetchStageKg2s;
+ // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
+ int smem_sel_next = (tile_K_seqlen + (kPrefetchStageKg2s - 1)) % kPrefetchStageKg2s;
+
+ // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+ // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+ static_assert(kCanPrefetchQs2r); // always prefetch Q s2r.
+ if (tile_K_seqlen == 0) {
+ // TODO: Full share QKV smem after Q is ready load to regs.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it for
+ // processes. This will reduce large io-access for Q smem while N is large.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+ }
+ }
+ }
+
+ // Load K tile from gmem -> smem
+ if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if (tile_K_seqlen == 0) {
+ load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ // Now, we have to wait curr K tile ready for Q@K^T MMA.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+ } else {
+ load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ // Now, we have to wait curr K tile ready for Q@K^T MMA.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+
+ // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
+ // Matmul with NN layout, Q row major, K row major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ fill_3D_regs(R_S, 0);
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
+ // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
+ int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+ uint32_t lane_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ lane_smem_K_d * (Bc + kPad) +
+ lane_smem_K_Bc) * sizeof(half)
+ );
+ LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ } // end for kWarpTileSeqLenK
+
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
+ }
+ } // end loop over d, S=Q@K^T
+ __syncthreads();
+
+ // Then, async prefetch curr V tile_K_seqlen [Bc,d] (no stages),
+ // before rowmax and rowsum, load V from gmem -> smem.
+ // TODO: Can we support stages 2 for V g2s?
+ {
+ load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (smem_sel * KV_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPad) +
+ load_smem_V_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
+ if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel_next * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+ }
+
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+ // Online safe softmax, warp/block reduce max/sum, row wise
+ float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_row_max_new, -INFINITY);
+ fill_2D_regs(lane_row_sum_new, 0.0f);
+
+ // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for C. (m16n8k16)
+ // Row\Col 0 1 2 3 4 5 6 7
+ // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1}
+ // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1}
+ // 2 ...
+ // ...
+ // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1}
+ // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3}
+ // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3}
+ // 10 ...
+ // ...
+ // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3}
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // This should be the row max after S = (Q @ K^T) / sqrt(d)
+ float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale;
+ float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale;
+ lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0);
+ lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce max, warp_size = 4
+ // Each thread contains the maximum of 2 rows of Br,
+ // and only the values of T0, T4, ..., T28 are used.
+ lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]);
+ lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+ // __syncthreads();
+
+ // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Use latest global row max without update.
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55;
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ // Br 1, row_id, 8~15, 24~31, 40~47, 56~63;
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // Apply m_new = max(m_old, m_new) here.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z;
+ t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0));
+ t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0));
+ t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1));
+ t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1));
+ lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y);
+ lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y);
+ // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
+ HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0);
+ HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce sum, warp_size = 4
+ lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]);
+ lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+
+ // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention.
+ // Here, we have to wait V ready before compute O = P @ V
+ if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ CP_ASYNC_WAIT_GROUP(1);
+ } else {
+ CP_ASYNC_WAIT_GROUP(0);
+ }
+ } else {
+ CP_ASYNC_WAIT_GROUP(0);
+ }
+ __syncthreads();
+
+ // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention.
+ // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major.
+ // Make sure to clear the states in R_O before MMA for P@V for each step.
+
+ // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these
+ // registers for P(A) matrix directly ? How to do that ?
+ // according to the A matrix layout for MMA m16n8k16 instruction.
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for A matrix with .f16.
+ // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5}
+ // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5}
+ // 2 (dashed arrow pointing right)
+ // ...
+ // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5}
+ // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7}
+ // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7}
+ // 10 (dashed arrow pointing right)
+ // ...
+ // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
+
+ fill_3D_regs(R_O, 0);
+ #pragma unroll
+ for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
+ // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) {
+ int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N
+ int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K
+ int lane_smem_V_d = warp_smem_V_d; // 0
+ uint32_t lane_smem_V_ptr = (
+ smem_V_base_ptr + (smem_sel * KV_tile_size +
+ lane_smem_V_Bc * (kHeadDim + kPad) +
+ lane_smem_V_d) * sizeof(half)
+ );
+ LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V
+ }
+
+ // For R_S[1][8][2], mapping the layout below of P matrix.
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs.
+ // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs.
+ // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs.
+ // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs.
+ int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ HMMA16816(R_O[i][j][0], R_O[i][j][1],
+ R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1],
+ R_V[j][0], R_V[j][1],
+ R_O[i][j][0], R_O[i][j][1]);
+ }
+ }
+ } // end for V Bc.
+ __syncthreads();
+
+ // Rescale O -> Update row sum Exp -> then, Update row max.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1
+ // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper)
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+ float block_row_sum_new_0 = lane_row_sum_new[i][0];
+ float block_row_sum_new_1 = lane_row_sum_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // NOTE: max(-inf, val) = val.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+ // Avoid inf value while using m_old for rescaling O.
+ block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 :
+ block_row_max_new_0);
+ block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 :
+ block_row_max_new_1);
+
+ // rescale factor for O and l, exp(m_old - m)
+ float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0);
+ float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1);
+ // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old.
+ // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3}
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+ // Note that the formula in the FA2 paper is incorrect; here,
+ // the inverse of the exp function should not be taken, as it
+ // would result in an error during rescaling, namely, you have
+ // use exp(m_old - m_new), not 1/(m_old - m_new).
+ // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V
+ t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x);
+ t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y);
+ t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x);
+ t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y);
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ } // end for kWarpTileHeadDimV.
+
+ // Now, we can update m, l after O has been scaled.
+ // 1. First, update block row sum Exp for each lane which
+ // need both m_new and m_old.
+ float block_row_sum_old_0 = lane_block_row_sum_old[i][0];
+ float block_row_sum_old_1 = lane_block_row_sum_old[i][1];
+ // Update l = exp(m_old - m_new) * l_old + row_sum(P).
+ lane_block_row_sum_old[i][0] = (__fmaf_rn(
+ rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0));
+ lane_block_row_sum_old[i][1] = (__fmaf_rn(
+ rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1));
+ // 2. Then, update block row max for each lane.
+ lane_block_row_max_old[i][0] = block_row_max_new_0;
+ lane_block_row_max_old[i][1] = block_row_max_new_1;
+ }
+
+ if constexpr (kCanPrefetchKg2s && kPrefetchStageKg2s > 1) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+ }
+
+ } // end loop over N
+ __syncthreads();
+
+ // Finaly, we still have to rescale O once more.
+ // O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
+ // NOTE: Here, we choose to reuse R_O as final output
+ // in order to reduce regs usage.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]);
+ float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]);
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+ t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x;
+ t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y;
+ t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x;
+ t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y;
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ }
+ }
+
+ // Store O(D): Write O[Br,d] from regs -> gmem, collective store
+ // with reg reuse & warp shuffle. need R_Z[2][4].
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
+
+ if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { // always true for shared qkv kernel
+ // reuse R_Q[4/8][1][4] for collective store.
+ R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4
+ R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]);
+ }
+ } else {
+ // we have to use new R_Z regs for collective store.
+ uint32_t R_Z[2][4];
+ R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
+ R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]);
+ }
+ } // end if kCanPrefetchQs2r
+ } // end for kWarpTileHeadDimV
+ } // end for kWarpTileSeqLenQ
+}
+
+// Launch kernel for flash_attn_mma_stages_split_q
+template
+void launch_flash_attn_mma_stages_split_q_shared_qkv(
+ torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Tile BrxBc=128x64
+ constexpr int kMmaAtomM = 16;
+ constexpr int kMmaAtomN = 8;
+ constexpr int kMmaAtomK = 16;
+ // constexpr int kMmaTileSeqLenQ = 4;
+ constexpr int kMmaTileSeqLenQ = 8;
+ constexpr int kMmaTileSeqLenK = 1;
+ // constexpr int kMmaTileSeqLenP = 4;
+ constexpr int kMmaTileSeqLenP = 8;
+ constexpr int kMmaTileHeadDimV = 1;
+ constexpr int kWarpTileSeqLenQ = 1;
+ constexpr int kWarpTileSeqLenK = 8;
+ // constexpr int kWarpTileSeqLenK = 16;
+ constexpr int kWarpTileSeqLenP = 1;
+ constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,....
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
+ constexpr int kPad = 8;
+ if constexpr (kStage > 1) {
+ static_assert(((Br / Bc) >= 2));
+ }
+
+ // static int kMaxSramPerBlock;
+ // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+
+ // Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem.
+ const int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M
+
+ const int QKV_batch = Q.size(0);
+ const int QKV_head = Q.size(1);
+ const int QKV_seqlen = Q.size(2); // QKV_seqlen
+ assert(QKV_seqlen % Bc == 0); // multiple of Bc=64
+
+ dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
+ dim3 block(kNumThreads); // 4/8 warps per block
+
+ cudaFuncSetAttribute(
+ flash_attn_mma_stages_split_q_shared_qkv_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPad
+ >,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ // kMaxSramPerBlock
+ 98304
+ );
+
+ flash_attn_mma_stages_split_q_shared_qkv_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPad
+ ><<>>(
+ reinterpret_cast(Q.data_ptr()),
+ reinterpret_cast(K.data_ptr()),
+ reinterpret_cast(V.data_ptr()),
+ reinterpret_cast(O.data_ptr()),
+ QKV_seqlen
+ );
+}
+
+void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages) {
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ const int d = Q.size(3); // B, H, N, d
+
+ if (stages > 1) {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<32, 2>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<64, 2>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<96, 2>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<128, 2>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not support!");
+ break;
+ }
+ } else {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<32, 1>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<64, 1>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<96, 1>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_shared_qkv<128, 1>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not support!");
+ break;
+ }
+ }
+}
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
index 1d37916c..6fbebb8f 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
@@ -816,6 +816,9 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q,
if (stages == 2) {
switch (d)
{
+ case 32:
+ launch_flash_attn_mma_stages_split_kv<32, 2>(Q, K, V, O);
+ break;
case 64:
launch_flash_attn_mma_stages_split_kv<64, 2>(Q, K, V, O);
break;
@@ -832,6 +835,9 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q,
} else {
switch (d)
{
+ case 32:
+ launch_flash_attn_mma_stages_split_kv<32, 1>(Q, K, V, O);
+ break;
case 64:
launch_flash_attn_mma_stages_split_kv<64, 1>(Q, K, V, O);
break;
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
index ede19899..a5f96c23 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
@@ -733,6 +733,9 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q,
if (stages == 2) {
switch (d)
{
+ case 32:
+ launch_flash_attn_mma_stages_split_q<32, 2>(Q, K, V, O);
+ break;
case 64:
launch_flash_attn_mma_stages_split_q<64, 2>(Q, K, V, O);
break;
@@ -749,6 +752,9 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q,
} else {
switch (d)
{
+ case 32:
+ launch_flash_attn_mma_stages_split_q<32, 1>(Q, K, V, O);
+ break;
case 64:
launch_flash_attn_mma_stages_split_q<64, 1>(Q, K, V, O);
break;
diff --git a/kernels/flash-attn/pybind/flash_attn.cc b/kernels/flash-attn/pybind/flash_attn.cc
index a7343f21..cf5e030b 100644
--- a/kernels/flash-attn/pybind/flash_attn.cc
+++ b/kernels/flash-attn/pybind/flash_attn.cc
@@ -5,10 +5,33 @@
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
-void flash_attn_mma_stages_split_kv(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages);
-void flash_attn_mma_stages_split_q(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages);
+void flash_attn_mma_stages_split_kv(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages);
+
+void flash_attn_mma_stages_split_q(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages);
+
+void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages);
+
+void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_kv)
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q)
+ TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_kv)
+ TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_qkv)
}