From 4687e1d007971575dd0757d7ccb309277d38dc0e Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Thu, 19 Dec 2024 15:40:49 +0800
Subject: [PATCH] =?UTF-8?q?[FA2]=20flash-attn-mma=20get=20rid=20of=20trans?=
=?UTF-8?q?pose-k=E2=9C=94=EF=B8=8F=20(#169)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Update flash_attn_mma_split_kv.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_share_kv.cu
* Update flash_attn_mma_share_qkv.cu
* Update flash_attn_mma.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
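
Below is a minimal standalone sketch (not the kernels' actual code) of what "get rid of transpose-k" means in practice: K now stays in its natural [B,H,N,D] row-major layout, so the gmem -> smem copy and the smem tile address K by (Bc, d) instead of the old pre-transposed (d, Bc). The names Bc, kHeadDim, kPad and QKV_seqlen follow the kernels; the helper functions and example values are illustrative assumptions only.

```C++
// Offset math only, compiled as plain host code; assumes Br=Bc=64, kPad=8.
#include <cstdio>

constexpr int kHeadDim   = 64;   // d
constexpr int Bc         = 64;   // K tile length along seqlen
constexpr int kPad       = 8;    // smem padding to reduce bank conflicts
constexpr int QKV_seqlen = 8192; // N

// Old path: K was pre-transposed to [B,H,D,N], so the logical element
// K^T[d,n] sat at d*seqlen + n in gmem and d*(Bc+kPad) + n in the smem tile.
int old_K_gmem(int d, int n) { return d * QKV_seqlen + n; }
int old_K_smem(int d, int n) { return d * (Bc + kPad) + n; }

// New path (this PR): K stays [B,H,N,D] row major; the same logical element
// is just K[n,d], so no transpose pass over K is needed before the kernel.
int new_K_gmem(int n, int d) { return n * kHeadDim + d; }
int new_K_smem(int n, int d) { return n * (kHeadDim + kPad) + d; }

int main() {
  // Address the same logical K^T[d=3, n=17] element under both layouts.
  printf("old: gmem=%d smem=%d\n", old_K_gmem(3, 17), old_K_smem(3, 17));
  printf("new: gmem=%d smem=%d\n", new_K_gmem(17, 3), new_K_smem(17, 3));
  return 0;
}
```

On the MMA side the same idea shows up in the diffs below as LDMATRIX_X2_T -> LDMATRIX_X2: a row-major K[Bc,d] smem tile already presents K^T[d,Bc] to the m16n8k16 NT layout, so the transposed ldmatrix variant is no longer needed.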
---
README.md | 30 ++--
kernels/flash-attn/README.md | 30 ++--
kernels/flash-attn/flash_attn_mma.py | 26 ++--
.../flash-attn/mma/flash_attn_mma_share_kv.cu | 132 +++++++++---------
.../mma/flash_attn_mma_share_qkv.cu | 124 ++++++++--------
.../flash-attn/mma/flash_attn_mma_split_kv.cu | 115 ++++++++-------
.../flash-attn/mma/flash_attn_mma_split_q.cu | 123 ++++++++--------
7 files changed, 300 insertions(+), 280 deletions(-)
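
As a reading aid (not part of the patch itself), here is a rough sketch that recomputes the per-block shared-memory budget with the new K tile shape, using the same expressions as the launch_* functions in the diffs below; Br=Bc=64 and kPad=8 match the kernels, while kStage=2 and kHeadDim=64 are just example parameters.

```C++
// Recomputes the smem-size expressions from the launch_* functions; host-only sketch.
#include <cstdio>

int main() {
  constexpr int Br = 64, Bc = 64, kPad = 8;
  constexpr int kHeadDim = 64;      // example head dim
  constexpr int kStage   = 2;       // double-buffered K tiles
  constexpr int kBytesPerHalf = 2;

  // The K tile is now Bc x (d + pad), the same shape as the V tile,
  // instead of the old transposed d x (Bc + pad) tile.
  constexpr int Q_tile  = Br * (kHeadDim + kPad);
  constexpr int KV_tile = Bc * (kHeadDim + kPad);
  constexpr int S_tile  = Br * (Bc + kPad);

  // split-kv: Q + staged K + V + S (launch_flash_attn_mma_stages_split_kv).
  printf("split-kv smem: %d bytes\n",
         (Q_tile + kStage * KV_tile + KV_tile + S_tile) * kBytesPerHalf);
  // share-kv: Q + staged K, with V reusing the K smem
  // (launch_flash_attn_mma_stages_split_q_shared_kv).
  printf("share-kv smem: %d bytes\n",
         (Q_tile + kStage * KV_tile) * kBytesPerHalf);
  return 0;
}
```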
diff --git a/README.md b/README.md
index e22c6fab..d8368ce2 100644
--- a/README.md
+++ b/README.md
@@ -60,12 +60,12 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than official FA2 on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
- torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+ torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -74,8 +74,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
- (flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
- (sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+ (flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
------------------------------------------------------------------------------------------------------------------------
```
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all QKV across MMA (Warps), is slower than the `Split Q` policy, which splits only Q across MMA (Warps) and keeps KV accessible to all MMA (Warps).
@@ -93,7 +92,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
- half* K, // [B, H, D, N] K^T transposed
+ half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -113,7 +112,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
- half* K, // [B, H, D, N] K^T transposed
+ half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -125,10 +124,10 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
```C++
// K, V share the same shared memory to improve block occupancy.
__global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
- half* K,
- half* V,
- half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+ half* K, // [B, H, N, D]
+ half* V, // [B, H, N, D]
+ half* O, // [B, H, N, D]
int QKV_seqlen);
```
- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
@@ -136,12 +135,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully share the same shared memory and prefetch Q s2r, improving block occupancy
+// and reducing Q SMEM IO-Access.
__global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
- half* K,
- half* V,
- half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+ half* K, // [B, H, N, D]
+ half* V, // [B, H, N, D]
+ half* O, // [B, H, N, D]
int QKV_seqlen);
```
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index 7ad01e9f..6bb0b05b 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -16,12 +16,12 @@ This repository's implementation of FlashAttention is intended solely for learni
- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa # NVIDIA RTX 3080 Laptop
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
- torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+ torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -30,8 +30,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
- (flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
- (sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+ (flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
------------------------------------------------------------------------------------------------------------------------
```
@@ -67,7 +66,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
- half* K, // [B, H, D, N] K^T transposed
+ half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -87,7 +86,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
- half* K, // [B, H, D, N] K^T transposed
+ half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -99,10 +98,10 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
```C++
// K, V share the same shared memory to improve block occupancy.
__global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
- half* K,
- half* V,
- half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+ half* K, // [B, H, N, D]
+ half* V, // [B, H, N, D]
+ half* O, // [B, H, N, D]
int QKV_seqlen);
```
- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
@@ -110,12 +109,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully share the same shared memory and prefetch Q s2r, improving block occupancy
+// and reducing Q SMEM IO-Access.
__global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
- half* K,
- half* V,
- half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+ half* K, // [B, H, N, D]
+ half* V, // [B, H, N, D]
+ half* O, // [B, H, N, D]
int QKV_seqlen);
```
diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py
index 1435caaf..a578a15d 100644
--- a/kernels/flash-attn/flash_attn_mma.py
+++ b/kernels/flash-attn/flash_attn_mma.py
@@ -176,7 +176,7 @@ def run_benchmark(perf_func: callable,
out_val = out_val_first[:2]
out_val.append(out_val_last[-1])
out_val = [f"{v:<12}" for v in out_val]
- print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
+ print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
if show_all:
print(out)
time.sleep(args.sleep)
@@ -203,12 +203,12 @@ def get_qkvo(B, H, N, D):
v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
- tk = k.transpose(-2, -1).contiguous()
+ # transpose (B,H,N,D) -> (B,N,H,D) for the official FA2 API.
fq = q.transpose(1, 2).contiguous()
fk = k.transpose(1, 2).contiguous()
fv = v.transpose(1, 2).contiguous()
- return q, k, v, o, tk, fq, fk, fv
+ return q, k, v, o, fq, fk, fv
# un-fused naive attn
@@ -233,7 +233,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
print("-" * 120)
diff = torch.abs(out_flash.float() - out_mma.float())
all_close = str(torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2))
- print(f"out_flash vs {tag:<20}, all close: {all_close:<6}, "
+ print(f"out_flash vs {tag:<18}, all close: {all_close:<6}, "
f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
f"mean diff: {diff.mean().item():.6f}")
@@ -254,19 +254,19 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
for (B, H, N, D) in BHNDs:
print("-" * 120)
print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
- q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
+ q, k, v, o, fq, fk, fv = get_qkvo(B, H, N, D)
torch.cuda.synchronize()
if args.run_torch_unfused:
out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
- out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
- out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
- out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
- out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
- out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
- out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2)
- out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
- out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+ out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage1)", o, stages=1)
+ out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage2)", o, stages=2)
+ out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage1)", o, stages=1)
+ out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage2)", o, stages=2)
+ out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
+ out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+ out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage1)", o, stages=1)
+ out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
index 9b9d652e..a9a563e7 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu
@@ -65,7 +65,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
half* V,
half* O,
int QKV_seqlen) {
- // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
@@ -74,7 +75,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
- // TODO: support stages for shared kv smem kernel.
static_assert(kStage < 3 && kStage > 0);
static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
@@ -82,7 +82,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
static_assert(Br >= Bc); // for shared memory reuse.
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
// Now, N must be a multiple of Bc (32/64) for KV tiling across seqlen.
- const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc]
+ const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d]
const float scale = 1.0f / sqrt((float) kHeadDim);
// Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma)
@@ -113,17 +113,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
(QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
- const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) +
- (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen]
+ const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
// Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads.
int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,...
- // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads.
- int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64
- int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,...
+ // Mapping K gmem -> tid -> smem, K[Bc,d]=[64,64 or 128], 128 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
// Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads.
int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
@@ -135,16 +135,11 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
int load_gmem_K_Bc_offset = 0;
int load_gmem_V_Bc_offset = 0;
- // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1
+ // Shared memory for Q,K,V; we do not need additional smem for the O
+ // collective store, which is performed via register reuse and warp shuffle.
extern __shared__ half smem[];
- constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
- // constexpr int KV_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
- // constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
- constexpr int KV_tile_size = (
- ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
- ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
- );
- // K multi-stages: currently, only apply multi stages for K across seq_len.
+ constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d]
half* Q_tile_smem = smem; // 8M/16M
half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M
half* V_tile_smem = K_tile_smem; // KV shared the same smem
@@ -172,7 +167,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
// Then we can load Q from smem only once and reuse it for
// processes. This will reduce large io-access for Q smem while N is large.
- // constexpr bool kCanPrefetchQs2r = false; // d <= 128
// FIXME(DefTruth): why can't we get good performance for headdim >= 64?
// Will enable it once I have figured out the performance issues.
constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64);
@@ -214,22 +208,25 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
// TODO: process last tile_K_seqlen ? pad to multiple of 8.
// Load K tile from gmem -> smem, always use smem part 0, send g2s
- // memory issues before Prefetch Q s2r to enable time overlap.
+ // memory issues before Prefetch Q s2r.
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
- CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
+
// Now, we have to wait curr K tile ready for Q@K^T MMA.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
@@ -254,15 +251,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
}
} else {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -281,10 +280,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
- // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
- // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
- // Then we can load Q from smem only once and reuse it for
- // processes. This will reduce large io-access for Q smem while N is large.
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
@@ -304,8 +299,9 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
} // end if kCanPrefetchQs2r
// : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
- // Matmul with NN layout, Q row major, K row major.
- // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ // Matmul with NT layout, Q row major, K^T col major.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d]
//
fill_3D_regs(R_S, 0);
#pragma unroll
@@ -329,18 +325,20 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
}
// smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
- // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
- #pragma unroll
- for (int j = 0; j < kWarpTileSeqLenK; ++j) {
- int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
- int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
- int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
uint32_t lane_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- lane_smem_K_d * (Bc + kPad) +
- lane_smem_K_Bc) * sizeof(half)
+ lane_smem_K_Bc * (kHeadDim + kPad) +
+ lane_smem_K_d) * sizeof(half)
);
- LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
} // end for kWarpTileSeqLenK
if constexpr (kCanPrefetchQs2r) {
@@ -396,15 +394,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -738,15 +738,12 @@ void launch_flash_attn_mma_stages_split_q_shared_kv(
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
- // constexpr int kMmaTileSeqLenQ = 4;
constexpr int kMmaTileSeqLenQ = 8;
constexpr int kMmaTileSeqLenK = 1;
- // constexpr int kMmaTileSeqLenP = 4;
constexpr int kMmaTileSeqLenP = 8;
constexpr int kMmaTileHeadDimV = 1;
constexpr int kWarpTileSeqLenQ = 1;
constexpr int kWarpTileSeqLenK = 8;
- // constexpr int kWarpTileSeqLenK = 16;
constexpr int kWarpTileSeqLenP = 1;
constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,....
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
@@ -754,13 +751,10 @@ void launch_flash_attn_mma_stages_split_q_shared_kv(
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
constexpr int kPad = 8;
- // static int kMaxSramPerBlock;
- // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+ static int kMaxSramPerBlock;
+ cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
// Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem.
- constexpr int KV_tile_size = (
- ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
- ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
- );
+ constexpr int KV_tile_size = (Bc * (kHeadDim + kPad));
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
(kStage * KV_tile_size)) * sizeof(half);
@@ -791,8 +785,8 @@ void launch_flash_attn_mma_stages_split_q_shared_kv(
kPad
>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- // kMaxSramPerBlock
- 98304
+ kMaxSramPerBlock
+ // 98304
);
flash_attn_mma_stages_split_q_shared_kv_kernel<
@@ -824,10 +818,10 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
torch::Tensor V,
torch::Tensor O,
int stages) {
- CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
- CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
if (stages > 1) {
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
index c00eedcd..e946ec76 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu
@@ -65,7 +65,8 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
half* V,
half* O,
int QKV_seqlen) {
- // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
@@ -81,7 +82,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
static_assert(Br >= Bc); // for shared memory reuse.
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
// Now, N must be a multiple of Bc (32/64) for KV tiling across seqlen.
- const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc]
+ const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d]
const float scale = 1.0f / sqrt((float) kHeadDim);
// Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma)
@@ -112,17 +113,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
(QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
- const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) +
- (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen]
+ const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
// Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads.
int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,...
- // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads.
- int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64
- int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,...
+ // Mapping K gmem -> tid -> smem, K[Bc,d]=[64,64 or 128], 128 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
// Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads.
int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
@@ -134,16 +135,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int load_gmem_K_Bc_offset = 0;
int load_gmem_V_Bc_offset = 0;
- // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1
+ // Shared memory for Q,K,V; we do not need additional smem for the O
+ // collective store, which is performed via register reuse and warp shuffle.
extern __shared__ half smem[];
- constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
- // constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
- // constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M
- constexpr int KV_tile_size = (
- ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
- ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
- );
- // K multi-stages: currently, only apply multi stages for K across seq_len.
+ constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d]
half* Q_tile_smem = smem; // 8M/16M
half* K_tile_smem = Q_tile_smem; // QKV shared the same smem
half* V_tile_smem = Q_tile_smem; // QKV shared the same smem
@@ -249,18 +245,21 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
- CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
+
// Now, we have to wait curr K tile ready for Q@K^T MMA.
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
@@ -285,15 +284,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
}
} else {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -303,8 +304,9 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
}
// : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
- // Matmul with NN layout, Q row major, K row major.
- // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ // Matmul with NT layout, Q row major, K^T col major.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d]
//
fill_3D_regs(R_S, 0);
#pragma unroll
@@ -326,19 +328,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_smem_Q_ptr); // now, R_Q[1][1][4]
}
}
+
// smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
- // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N]
#pragma unroll
for (int j = 0; j < kWarpTileSeqLenK; ++j) {
- int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
- int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
- int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
uint32_t lane_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- lane_smem_K_d * (Bc + kPad) +
- lane_smem_K_Bc) * sizeof(half)
- );
- LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ lane_smem_K_Bc * (kHeadDim + kPad) +
+ lane_smem_K_d) * sizeof(half)
+ );
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
} // end for kWarpTileSeqLenK
if constexpr (kCanPrefetchQs2r) {
@@ -394,21 +399,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
if constexpr (kCanPrefetchKVg2s) {
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
}
}
-
// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
// | 64x64 | warp_KV 0 |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
@@ -734,15 +740,12 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv(
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
- // constexpr int kMmaTileSeqLenQ = 4;
constexpr int kMmaTileSeqLenQ = 8;
constexpr int kMmaTileSeqLenK = 1;
- // constexpr int kMmaTileSeqLenP = 4;
constexpr int kMmaTileSeqLenP = 8;
constexpr int kMmaTileHeadDimV = 1;
constexpr int kWarpTileSeqLenQ = 1;
constexpr int kWarpTileSeqLenK = 8;
- // constexpr int kWarpTileSeqLenK = 16;
constexpr int kWarpTileSeqLenP = 1;
constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,....
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
@@ -753,15 +756,12 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv(
static_assert(((Br / Bc) >= 2));
}
- // static int kMaxSramPerBlock;
- // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+ static int kMaxSramPerBlock;
+ cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
// Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem.
- constexpr int KV_tile_size = (
- ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ?
- ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad))
- );
+ constexpr int KV_tile_size = Bc * (kHeadDim + kPad);
int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M
- if constexpr (kStage > 1) {
+ if constexpr (kStage > 1) { // make sure kStage > 1 works
smem_max_size = smem_max_size > 2 * KV_tile_size * sizeof(half) ?
smem_max_size : 2 * KV_tile_size * sizeof(half);
}
@@ -793,8 +793,8 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv(
kPad
>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- // kMaxSramPerBlock
- 98304
+ kMaxSramPerBlock
+ // 98304
);
flash_attn_mma_stages_split_q_shared_qkv_kernel<
@@ -826,10 +826,10 @@ void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q,
torch::Tensor V,
torch::Tensor O,
int stages) {
- CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
- CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
if (stages > 1) {
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
index fb07e72d..789a77c4 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu
@@ -33,7 +33,8 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
half* V,
half* O,
int QKV_seqlen) {
- // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
static_assert(kMmaTileSeqLenQ == 2 && kMmaTileSeqLenK == 4); // Q@K^T
static_assert(kMmaTileSeqLenP == 2 && kMmaTileHeadDimV == 4); // P@V
@@ -73,17 +74,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
// gridDim.y = head_num, gridDim.z = N/Br = Tr.
const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
(QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
- const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) +
- (QKV_head_id * kHeadDim * QKV_seqlen)); // transpose K, [d,seqlen]
+ const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
// Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 256 threads.
int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 4, row 0~64
int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 4) * 16, 0,16,32,48
- // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 256 threads.
- int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 4, row 0~64
- int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 4) * 16, 0,16,32,48
+ // Mapping K gmem -> tid -> smem, K[Bc,d]=[64,64 or 128], 256 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 4, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48
// Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 256 threads.
int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 4, row 0~64
int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48
@@ -95,17 +96,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
int load_gmem_K_Bc_offset = 0;
int load_gmem_V_Bc_offset = 0;
- // Shared memory for Q,K,V,O, d=64->24M, d=128=48M
+ // Shared memory for Q,K,V,S; we do not need additional smem for the O
+ // collective store, which is performed via register reuse and warp shuffle.
extern __shared__ half smem[];
- constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
- constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
- constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
- constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
+ constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M
// K multi-stages: currently, only apply multi stages for K across seq_len.
half* Q_tile_smem = smem; // 8M/16M
half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M
- half* V_tile_smem = K_tile_smem + kStage * K_tile_size;
- half* S_tile_smem = V_tile_smem + V_tile_size; // for temp S=Q@K^T
+ half* V_tile_smem = K_tile_smem + kStage * KV_tile_size;
+ half* S_tile_smem = V_tile_smem + KV_tile_size; // for temp S=Q@K^T
// stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M, shared KV smem: 24M
// stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M=64M, shared KV smem: 48M
// stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M=128M, shared KV smem: 96M
@@ -164,15 +165,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
for (int stage = 0; stage < (kStage - 1); ++stage) {
// update the offset of n according to stages
load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (stage * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (stage * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -222,17 +225,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
// Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc]
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
int load_gmem_K_addr = (
- K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel_next * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half)
+ smem_K_base_ptr + (smem_sel_next * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
);
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -244,17 +247,20 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
// First, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
{
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
+
CP_ASYNC_COMMIT_GROUP();
}
@@ -282,8 +288,9 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
}
// : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
- // Matmul with NN layout, Q row major, K row major.
- // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ // Matmul with NT layout, Q row major, K^T col major.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d]
fill_3D_regs(R_S, 0);
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
@@ -303,18 +310,20 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
}
// smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
- // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
- #pragma unroll
- for (int j = 0; j < kWarpTileSeqLenK; ++j) {
- int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
- int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
- int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
uint32_t lane_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * K_tile_size +
- lane_smem_K_d * (Bc + kPad) +
- lane_smem_K_Bc) * sizeof(half)
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ lane_smem_K_Bc * (kHeadDim + kPad) +
+ lane_smem_K_d) * sizeof(half)
);
- LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
} // end for kWarpTileSeqLenK
// MMA compute
@@ -701,6 +710,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
for (int i = 0; i < kWarpTileSeqLenP; ++i) {
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) {
+ static_assert(kWarpTileSeqLenQ >= 2);
R_Q[0][0] = R_D[i][j][0]; R_Q[1][0] = R_D[i][j][1]; // warp_size 4
R_Q[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
R_Q[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
@@ -733,6 +743,8 @@ flash_attn_mma_stages_split_kv_kernel(half* Q,
template
void launch_flash_attn_mma_stages_split_kv(
torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Now: fixed tile BrxBc=64x64
+ // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
@@ -748,9 +760,11 @@ void launch_flash_attn_mma_stages_split_kv(
constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64
constexpr int kPad = 8;
+ static int kMaxSramPerBlock;
+ cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
// Calculate SRAM size needed per block, Q,K,V,S smem size
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
- (kStage * kHeadDim * (Bc + kPad)) +
+ (kStage * Bc * (kHeadDim + kPad)) +
(Bc * (kHeadDim + kPad)) +
(Br * (Bc + kPad))) * sizeof(half);
@@ -780,7 +794,8 @@ void launch_flash_attn_mma_stages_split_kv(
kPad
>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- 98304
+ kMaxSramPerBlock
+ // 98304
);
flash_attn_mma_stages_split_kv_kernel<
@@ -812,10 +827,10 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q,
torch::Tensor V,
torch::Tensor O,
int stages) {
- CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
- CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
if (stages == 2) {
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
index 92c1cd2e..2cd3a7e6 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu
@@ -54,11 +54,12 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
half* V,
half* O,
int QKV_seqlen) {
- // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major.
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] stored in row major is exactly K^T[d,Bc] in col major.
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
- static_assert(kMmaTileSeqLenQ == 4 && kMmaTileSeqLenK == 1); // Q@K^T
- static_assert(kMmaTileSeqLenP == 4 && kMmaTileHeadDimV == 1); // P@V
- static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK == 8); // Q@K^T
+ static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
+ static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
+ static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
// kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc.
// e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
@@ -68,7 +69,6 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
- constexpr int kNumStoreBatchO = 2; // batch size for O collective store.
// Now, N must be a multiple of Bc (32/64) for KV tiling across seqlen.
const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc]
const float scale = 1.0f / sqrt((float) kHeadDim);
@@ -101,17 +101,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
(QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
- const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) +
- (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen]
+ const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
// Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads.
int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,...
- // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads.
- int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64
- int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,...
+ // Mapping K gmem -> tid -> smem, K[Bc,d]=[64,64 or 128], 128 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
// Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads.
int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
@@ -123,15 +123,15 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
int load_gmem_K_Bc_offset = 0;
int load_gmem_V_Bc_offset = 0;
- // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1
+ // Shared memory for Q,K,V; we do not need additional smem for the O
+ // collective store, which is performed via register reuse and warp shuffle.
extern __shared__ half smem[];
- constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
- constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
- constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M
+ constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
+ constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M
// K multi-stages: currently, only apply multi stages for K across seq_len.
half* Q_tile_smem = smem; // 8M/16M
half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M
- half* V_tile_smem = K_tile_smem + kStage * K_tile_size;
+ half* V_tile_smem = K_tile_smem + kStage * KV_tile_size;
// TODO: KV may share the same smem to reduce smem usage for kStage 1
// stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M
// stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M =64M
@@ -164,12 +164,9 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
// registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
- // Helper regs for collective store, [2][4]. may use smem to reduce regs?
- uint32_t R_Z[kNumStoreBatchO][4]; // [2][4]
fill_3D_regs(R_S, 0);
fill_3D_regs(R_D, 0);
fill_3D_regs(R_O, 0);
- fill_2D_regs(R_Z, 0);
// load Q from gmem -> smem, only load once.
{
@@ -190,15 +187,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
for (int stage = 0; stage < (kStage - 1); ++stage) {
// update the offset of n according to stages
load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (stage * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (stage * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
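With the new mapping, each thread copies kHeadDim/(kNumThreads/Bc) = 64/2 = 32 contiguous halves of its K row, so the loop above issues four 16-byte cp.async.cg transactions per thread. A stripped-down, standalone version of just that copy pattern (illustrative only; the CP_ASYNC_* macros are re-declared here in the style of the kernel's helpers and are assumptions, not the patch's exact code; requires sm_80+):

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Local stand-ins for the kernel's cp.async helper macros.
#define CP_ASYNC_CG(dst, src, bytes)                                      \
  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::           \
               "r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_ALL()     asm volatile("cp.async.wait_all;\n" ::)

// One block of 128 threads stages a 64x64 half K tile as K[Bc,d] row major:
// each thread owns 32 halves of one row -> 4 x 16B cp.async.cg per thread.
// The +kPad row padding (8 halves = 16 bytes) keeps every row 16B aligned.
__global__ void stage_k_tile(const half* __restrict__ K, half* __restrict__ out) {
  constexpr int Bc = 64, kHeadDim = 64, kNumThreads = 128, kPad = 8;
  __shared__ __align__(16) half K_tile[Bc * (kHeadDim + kPad)];
  const int tid = threadIdx.x;
  const int row = tid / (kNumThreads / Bc);                                     // 0..63
  const int col = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // 0/32
  uint32_t smem_ptr = static_cast<uint32_t>(
      __cvta_generic_to_shared(&K_tile[row * (kHeadDim + kPad) + col]));
  #pragma unroll
  for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { // 4 iters x 8 halves
    CP_ASYNC_CG(smem_ptr + i * 2, &K[row * kHeadDim + col + i], 16);
  }
  CP_ASYNC_COMMIT_GROUP();
  CP_ASYNC_WAIT_ALL();
  __syncthreads();
  // Write the staged tile back out so the result can be checked on the host.
  for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); ++i) {
    out[row * kHeadDim + col + i] = K_tile[row * (kHeadDim + kPad) + col + i];
  }
}

int main() {
  const int n = 64 * 64;
  half *K, *out;
  cudaMalloc(&K, n * sizeof(half));
  cudaMalloc(&out, n * sizeof(half));
  // Fill K as desired (omitted); launch one block of 128 threads.
  stage_k_tile<<<1, 128>>>(K, out);
  cudaDeviceSynchronize();
  cudaFree(K); cudaFree(out);
  return 0;
}
```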
@@ -248,17 +247,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
// Then, prefetch next stage K (tile_K_seqlen + 1) [Bc,d]
if ((tile_K_seqlen + 1) < Tc) {
load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
- int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
int load_gmem_K_addr = (
- K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel_next * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half)
+ smem_K_base_ptr + (smem_sel_next * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
);
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
CP_ASYNC_COMMIT_GROUP();
@@ -270,17 +269,20 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
// First, prefetch curr K tile_K_seqlen [Bc,d] (no stages)
{
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
- int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+ int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * K_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
#pragma unroll
- for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
}
+
CP_ASYNC_COMMIT_GROUP();
}
@@ -308,8 +310,9 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
}
// <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
- // Matmul with NN layout, Q row major, K row major.
- // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+ // Matmul with NT layout: Q row major, K^T col major.
+ // NOTE: K[Bc,d] in row major is exactly K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K^T[d,Bc]
fill_3D_regs(R_S, 0);
#pragma unroll
for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
@@ -329,18 +332,20 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
}
// smem -> reg, load k16n8 from smem K, offset in d according to tile_K_d.
- // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
- #pragma unroll
- for (int j = 0; j < kWarpTileSeqLenK; ++j) {
- int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
- int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
- int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[N,K]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] in row major is exactly K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
uint32_t lane_smem_K_ptr = (
- smem_K_base_ptr + (smem_sel * K_tile_size +
- lane_smem_K_d * (Bc + kPad) +
- lane_smem_K_Bc) * sizeof(half)
+ smem_K_base_ptr + (smem_sel * KV_tile_size +
+ lane_smem_K_Bc * (kHeadDim + kPad) +
+ lane_smem_K_d) * sizeof(half)
);
- LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
} // end for kWarpTileSeqLenK
// MMA compute
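The payoff of keeping K as [Bc,d] row major shows up in the loop above: the B-fragment of each m16n8k16 MMA can be pulled with a plain ldmatrix.x2 (8 row pointers per 8x8 sub-tile of K), whereas the old [d,Bc] layout needed ldmatrix.x2.trans. Both paths produce bit-identical register fragments; the following standalone sketch checks that for one n8k16 slice (illustrative; the LDMATRIX_* macros are re-declared locally to mirror the kernel's helpers and their exact bodies are an assumption; requires sm_75+):

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Local stand-ins for the kernel's ldmatrix helper macros.
#define LDMATRIX_X2(R0, R1, addr)                                                 \
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"       \
               : "=r"(R0), "=r"(R1) : "r"(addr))
#define LDMATRIX_X2_T(R0, R1, addr)                                               \
  asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" \
               : "=r"(R0), "=r"(R1) : "r"(addr))

// One warp loads the same logical K slice (Bc=8 rows, kMmaAtomK=16 cols) twice:
// (a) from K[Bc,d] row major with plain ldmatrix.x2   (this patch),
// (b) from K^T[d,Bc] row major with ldmatrix.x2.trans (the old transposed-K path),
// and counts register mismatches; the B-fragments should be bit-identical.
__global__ void k_fragment_equivalence(int* mismatches) {
  __shared__ __align__(16) half Kr[8 * 16];  // K  [Bc=8, d=16]
  __shared__ __align__(16) half Kt[16 * 8];  // K^T[d=16, Bc=8]
  const int lane = threadIdx.x;
  for (int i = lane; i < 8 * 16; i += 32) {
    int bc = i / 16, d = i % 16;
    half v = __float2half(static_cast<float>(i));
    Kr[bc * 16 + d] = v;
    Kt[d * 8 + bc]  = v;
  }
  __syncthreads();

  // (a) new path: lanes 0..7 point at the 8 K rows (d cols 0..7),
  //     lanes 8..15 point at the same rows shifted to d cols 8..15.
  uint32_t a0, a1;
  uint32_t Kr_addr = static_cast<uint32_t>(__cvta_generic_to_shared(
      &Kr[(lane % 8) * 16 + ((lane / 8) % 2) * 8]));
  LDMATRIX_X2(a0, a1, Kr_addr);

  // (b) old path: lanes 0..15 point at the 16 rows of K^T, transposed on load.
  uint32_t b0, b1;
  uint32_t Kt_addr = static_cast<uint32_t>(__cvta_generic_to_shared(
      &Kt[(lane % 16) * 8]));
  LDMATRIX_X2_T(b0, b1, Kt_addr);

  if (a0 != b0 || a1 != b1) atomicAdd(mismatches, 1);
}

int main() {
  int* mismatches;
  cudaMalloc(&mismatches, sizeof(int));
  cudaMemset(mismatches, 0, sizeof(int));
  k_fragment_equivalence<<<1, 32>>>(mismatches);
  int h = -1;
  cudaMemcpy(&h, mismatches, sizeof(int), cudaMemcpyDeviceToHost);
  printf("fragment mismatches: %d (expect 0)\n", h);
  cudaFree(mismatches);
  return 0;
}
```

This is why the host side no longer has to pre-transpose K: the row-major [Bc,d] tile already is K^T in the column-major sense the MMA B operand expects.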
@@ -620,6 +625,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
+ uint32_t R_Z[2][4]; // scope-local helper regs for the O collective store, [2][4]
R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
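Declaring R_Z inside the loop body makes it short-lived scratch: each group of four lanes pools its two-half accumulator fragments through width-4 shuffles, and the group leader writes 16 bytes in one store, so O needs neither smem nor long-lived helper registers. A self-contained sketch of just that gather-and-store pattern (a simplified stand-in, not the kernel's exact R_D indexing):

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Sketch of the width-4 shuffle gather + 128-bit store used for O above.
// Each lane holds one half2 "accumulator fragment"; the group leader
// (lane % 4 == 0) collects its 3 neighbors' fragments and stores 16 bytes.
__global__ void collective_store_demo(half* __restrict__ out) {
  const int lane = threadIdx.x % 32;
  half2 frag = __floats2half2_rn(2.0f * lane, 2.0f * lane + 1.0f);
  uint32_t r = *reinterpret_cast<uint32_t*>(&frag);

  __align__(16) uint32_t z[4];                     // scope-local, like R_Z
  z[0] = r;
  z[1] = __shfl_sync(0xffffffff, r, lane + 1, 4);  // width 4: logical lane + 1
  z[2] = __shfl_sync(0xffffffff, r, lane + 2, 4);
  z[3] = __shfl_sync(0xffffffff, r, lane + 3, 4);

  if (lane % 4 == 0) {                             // one 16B store per 4-lane group
    *reinterpret_cast<uint4*>(&out[lane * 2]) = *reinterpret_cast<uint4*>(z);
  }
}

int main() {
  const int n = 64;                                // 8 group leaders x 8 halves
  half* out;
  cudaMalloc(&out, n * sizeof(half));
  collective_store_demo<<<1, 32>>>(out);
  half h[n];
  cudaMemcpy(h, out, n * sizeof(half), cudaMemcpyDeviceToHost);
  printf("out[0..7] = ");
  for (int i = 0; i < 8; ++i) printf("%.0f ", __half2float(h[i])); // 0 1 2 ... 7
  printf("\n");
  cudaFree(out);
  return 0;
}
```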
@@ -651,6 +657,8 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
template
void launch_flash_attn_mma_stages_split_q(
torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Now: fixed tile BrxBc=64x64
+ // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
constexpr int kMmaAtomM = 16;
constexpr int kMmaAtomN = 8;
constexpr int kMmaAtomK = 16;
@@ -667,9 +675,11 @@ void launch_flash_attn_mma_stages_split_q(
constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
constexpr int kPad = 8;
+ static int kMaxSramPerBlock;
+ cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
// Calculate SRAM size needed per block, Q,K,V smem size
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
- (kStage * kHeadDim * (Bc + kPad)) +
+ (kStage * Bc * (kHeadDim + kPad)) +
(Bc * (kHeadDim + kPad))) * sizeof(half);
const int QKV_batch = Q.size(0);
@@ -678,7 +688,7 @@ void launch_flash_attn_mma_stages_split_q(
assert(QKV_seqlen % Bc == 0); // multiple of Bc=64
dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
- dim3 block(kNumThreads); // 4 warps per block
+ dim3 block(kNumThreads); // 4/8 warps per block
cudaFuncSetAttribute(
flash_attn_mma_stages_split_q_kernel<
@@ -698,7 +708,8 @@ void launch_flash_attn_mma_stages_split_q(
kPad
>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- 98304
+ kMaxSramPerBlock
+ // 98304
);
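NOTE: cudaDevAttrMaxSharedMemoryPerBlock reports the default per-block limit (typically 48 KB), while the larger ceiling that cudaFuncAttributeMaxDynamicSharedMemorySize can opt into is reported by cudaDevAttrMaxSharedMemoryPerBlockOptin. Since the d=128, kStage=2 configuration needs about 68 KB of Q/K/V smem, the opt-in attribute is likely the value wanted here. A small host-side sketch (illustrative, not part of this patch) that prints both limits:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Print both shared-memory ceilings before deciding what to pass to
// cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...).
int main() {
  int dev = 0, smem_default = 0, smem_optin = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&smem_default, cudaDevAttrMaxSharedMemoryPerBlock, dev);
  cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  printf("smem per block: default=%d bytes, opt-in max=%d bytes\n",
         smem_default, smem_optin);
  // The default is usually 49152; the opt-in value depends on the architecture.
  return 0;
}
```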
flash_attn_mma_stages_split_q_kernel<
@@ -730,10 +741,10 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q,
torch::Tensor V,
torch::Tensor O,
int stages) {
- CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
- CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
- CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
if (stages > 1) {