From 4687e1d007971575dd0757d7ccb309277d38dc0e Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:40:49 +0800 Subject: [PATCH] =?UTF-8?q?[FA2]=20flash-attn-mma=20get=20rid=20of=20trans?= =?UTF-8?q?pose-k=E2=9C=94=EF=B8=8F=20(#169)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update flash_attn_mma_split_kv.cu * Update flash_attn_mma_split_q.cu * Update flash_attn_mma_share_kv.cu * Update flash_attn_mma_share_qkv.cu * Update flash_attn_mma.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md --- README.md | 30 ++-- kernels/flash-attn/README.md | 30 ++-- kernels/flash-attn/flash_attn_mma.py | 26 ++-- .../flash-attn/mma/flash_attn_mma_share_kv.cu | 132 +++++++++--------- .../mma/flash_attn_mma_share_qkv.cu | 124 ++++++++-------- .../flash-attn/mma/flash_attn_mma_split_kv.cu | 115 ++++++++------- .../flash-attn/mma/flash_attn_mma_split_q.cu | 123 ++++++++-------- 7 files changed, 300 insertions(+), 280 deletions(-) diff --git a/README.md b/README.md index e22c6fab..d8368ce2 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,12 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop): ```bash -python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa +python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 - torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25 + torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25 mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29 mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70 mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50 @@ -74,8 +74,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21 mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06 mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52 - (flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52 - (sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55 + (flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52 ------------------------------------------------------------------------------------------------------------------------ ``` The `Split KV` and `Split Q` implementations have been carried out in 
[flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps). @@ -93,7 +92,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| __global__ void flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] - half* K, // [B, H, D, N] K^T transposed + half* K, // [B, H, N, D] half* V, // [B, H, N, D] half* O, // [B, H, N, D] int QKV_seqlen); @@ -113,7 +112,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | __global__ void flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] - half* K, // [B, H, D, N] K^T transposed + half* K, // [B, H, N, D] half* V, // [B, H, N, D] half* O, // [B, H, N, D] int QKV_seqlen); @@ -125,10 +124,10 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] ```C++ // K, V shared the same shared memory, improve block occupancy. __global__ void -flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, - half* K, - half* V, - half* O, +flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D] + half* K, // [B, H, N, D] + half* V, // [B, H, N, D] + half* O, // [B, H, N, D] int QKV_seqlen); ``` - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2) @@ -136,12 +135,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++ -// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access. +// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy +// and reduce Q SMEM IO-Access. __global__ void -flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, - half* K, - half* V, - half* O, +flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D] + half* K, // [B, H, N, D] + half* V, // [B, H, N, D] + half* O, // [B, H, N, D] int QKV_seqlen); ``` diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md index 7ad01e9f..6bb0b05b 100644 --- a/kernels/flash-attn/README.md +++ b/kernels/flash-attn/README.md @@ -16,12 +16,12 @@ This repository's implementation of FlashAttention is intended solely for learni - Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop) ```bash -python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa # NVIDIA RTX 3080 Laptop +python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop ------------------------------------------------------------------------------------------------------------------------ B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10 ------------------------------------------------------------------------------------------------------------------------ B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10 - torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25 + torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25 mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29 mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70 mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50 @@ -30,8 +30,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21 mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06 mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52 - (flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52 - (sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55 + (flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52 ------------------------------------------------------------------------------------------------------------------------ ``` @@ -67,7 +66,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| __global__ void flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] - half* K, // [B, H, D, N] K^T transposed + half* K, // [B, H, N, D] half* V, // [B, H, N, D] half* O, // [B, H, N, D] int QKV_seqlen); @@ -87,7 +86,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D] // | warp_QP 3 | MMA 3 ... 
MMA 3 (x8) | __global__ void flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] - half* K, // [B, H, D, N] K^T transposed + half* K, // [B, H, N, D] half* V, // [B, H, N, D] half* O, // [B, H, N, D] int QKV_seqlen); @@ -99,10 +98,10 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D] ```C++ // K, V shared the same shared memory, improve block occupancy. __global__ void -flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, - half* K, - half* V, - half* O, +flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D] + half* K, // [B, H, N, D] + half* V, // [B, H, N, D] + half* O, // [B, H, N, D] int QKV_seqlen); ``` - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2) @@ -110,12 +109,13 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
```C++ -// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access. +// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy +// and reduce Q SMEM IO-Access. __global__ void -flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, - half* K, - half* V, - half* O, +flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D] + half* K, // [B, H, N, D] + half* V, // [B, H, N, D] + half* O, // [B, H, N, D] int QKV_seqlen); ``` diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py index 1435caaf..a578a15d 100644 --- a/kernels/flash-attn/flash_attn_mma.py +++ b/kernels/flash-attn/flash_attn_mma.py @@ -176,7 +176,7 @@ def run_benchmark(perf_func: callable, out_val = out_val_first[:2] out_val.append(out_val_last[-1]) out_val = [f"{v:<12}" for v in out_val] - print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}") + print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}") if show_all: print(out) time.sleep(args.sleep) @@ -203,12 +203,12 @@ def get_qkvo(B, H, N, D): v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous() o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous() - tk = k.transpose(-2, -1).contiguous() + # transpose (H,N) -> (N,H) for FA2. fq = q.transpose(1, 2).contiguous() fk = k.transpose(1, 2).contiguous() fv = v.transpose(1, 2).contiguous() - return q, k, v, o, tk, fq, fk, fv + return q, k, v, o, fq, fk, fv # un-fused naive attn @@ -233,7 +233,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, print("-" * 120) diff = torch.abs(out_flash.float() - out_mma.float()) all_close = str(torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2)) - print(f"out_flash vs {tag:<20}, all close: {all_close:<6}, " + print(f"out_flash vs {tag:<18}, all close: {all_close:<6}, " f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, " f"mean diff: {diff.mean().item():.6f}") @@ -254,19 +254,19 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor, for (B, H, N, D) in BHNDs: print("-" * 120) print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}") - q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D) + q, k, v, o, fq, fk, fv = get_qkvo(B, H, N, D) torch.cuda.synchronize() if args.run_torch_unfused: out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)") - out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1) - out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2) - out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1) - out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2) - out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1) - out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2) - out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1) - out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, 
"mma(split-q+share-qkv+stage2)", o, stages=2) + out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage1)", o, stages=1) + out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage2)", o, stages=2) + out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage1)", o, stages=1) + out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage2)", o, stages=2) + out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1) + out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2) + out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage1)", o, stages=1) + out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2) out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") if args.run_torch_sdpa: out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)") diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu index 9b9d652e..a9a563e7 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv.cu @@ -65,7 +65,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, half* V, half* O, int QKV_seqlen) { - // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V @@ -74,7 +75,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V - // TODO: support stages for shared kv smem kernel. static_assert(kStage < 3 && kStage > 0); static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 @@ -82,7 +82,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, static_assert(Br >= Bc); // for shared memory reuse. constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. - const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d] const float scale = 1.0f / sqrt((float) kHeadDim); // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma) @@ -113,17 +113,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // | warp_QP 7 | MMA 7 ... 
MMA 7 (x16) | const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] - const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + - (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads. int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,... - // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads. - int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64 - int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,... + // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads. int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... @@ -135,16 +135,11 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, int load_gmem_K_Bc_offset = 0; int load_gmem_V_Bc_offset = 0; - // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 + // Shared memory for Q,K,V, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. extern __shared__ half smem[]; - constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - // constexpr int KV_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M - // constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M - constexpr int KV_tile_size = ( - ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? - ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) - ); - // K multi-stages: currently, only apply multi stages for K across seq_len. + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d] half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M half* V_tile_smem = K_tile_smem; // KV shared the same smem @@ -172,7 +167,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. // Then we can load Q from smem only once and reuse it for // processes. This will reduce large io-access for Q smem while N is large. - // constexpr bool kCanPrefetchQs2r = false; // d <= 128 // FIXME(DefTruth): why can not get good performance for headdim >= 64 ? // Will enable it untill I have figure out the performance issues. constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64); @@ -214,22 +208,25 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // TODO: process last tile_K_seqlen ? pad to multiple of 8. 
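Since K now stays in its natural [B, H, N, D] layout, each thread's global and shared-memory offsets for a K tile reduce to plain row-major indexing over [Bc, d]. A minimal host-side sketch of that index math, assuming the default launch configuration (Br = Bc = 64, d = 64, kPad = 8, 128 threads); the local names mirror the kernel's, but this is illustration only, not the kernel code:

```C++
#include <cstdio>

// Assumed defaults from the launch functions; purely illustrative.
constexpr int kHeadDim = 64, Bc = 64, kPad = 8, kNumThreads = 128;

// Global index (in halfs) of the first element this thread copies from K[seqlen, d].
int k_gmem_index(int K_gmem_offset, int load_gmem_K_Bc_offset, int tid) {
  int load_smem_K_Bc = tid / (kNumThreads / Bc);                                     // row 0~63
  int load_smem_K_d  = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // col 0,32
  int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
  return K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_smem_K_d;
}

// Matching offset (in halfs) inside the padded [Bc, kHeadDim + kPad] smem tile.
int k_smem_index(int tid) {
  int load_smem_K_Bc = tid / (kNumThreads / Bc);
  int load_smem_K_d  = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc));
  return load_smem_K_Bc * (kHeadDim + kPad) + load_smem_K_d;
}

int main() {
  // Thread 3 of the second KV tile (Bc offset 64): row 65 of K, columns 32..63.
  printf("gmem +%d, smem +%d\n", k_gmem_index(0, 64, 3), k_smem_index(3));
  return 0;
}
```

The removed path instead indexed a pre-transposed [d, seqlen] buffer (`load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc`), which is exactly what the dropped `k.transpose(-2, -1)` on the Python side used to produce.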
// Load K tile from gmem -> smem, always use smem part 0, send g2s - // memory issues before Prefetch Q s2r to enable time overlap. + // memory issues before Prefetch Q s2r. if constexpr (kCanPrefetchKVg2s) { if (tile_K_seqlen == 0) { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. CP_ASYNC_WAIT_GROUP(0); __syncthreads(); @@ -254,15 +251,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } } else { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -281,10 +280,6 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { - // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs. - // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. - // Then we can load Q from smem only once and reuse it for - // processes. This will reduce large io-access for Q smem while N is large. #pragma unroll for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; @@ -304,8 +299,9 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } // end if kCanPrefetchQs2r // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] - // Matmul with NN layout, Q row major, K row major. - // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + // Matmul with NT layout, Q row major, K^T col major. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. 
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d] // fill_3D_regs(R_S, 0); #pragma unroll @@ -329,18 +325,20 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, } // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. - // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] - #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { - int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) - int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); - int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d]. + // K[Bc,d] with row major means K^T[d,Bc] in col major. + int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7 + int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8 uint32_t lane_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - lane_smem_K_d * (Bc + kPad) + - lane_smem_K_Bc) * sizeof(half) + lane_smem_K_Bc * (kHeadDim + kPad) + + lane_smem_K_d) * sizeof(half) ); - LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K } // end for kWarpTileSeqLenK if constexpr (kCanPrefetchQs2r) { @@ -396,15 +394,17 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, if constexpr (kCanPrefetchKVg2s) { if ((tile_K_seqlen + 1) < Tc) { load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -738,15 +738,12 @@ void launch_flash_attn_mma_stages_split_q_shared_kv( constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; - // constexpr int kMmaTileSeqLenQ = 4; constexpr int kMmaTileSeqLenQ = 8; constexpr int kMmaTileSeqLenK = 1; - // constexpr int kMmaTileSeqLenP = 4; constexpr int kMmaTileSeqLenP = 8; constexpr int kMmaTileHeadDimV = 1; constexpr int kWarpTileSeqLenQ = 1; constexpr int kWarpTileSeqLenK = 8; - // constexpr int kWarpTileSeqLenK = 16; constexpr int kWarpTileSeqLenP = 1; constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,.... 
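The `LDMATRIX_X2_T` to `LDMATRIX_X2` switch above is the core of the change: with K kept as [Bc, d] row major, an untransposed 8x8 `ldmatrix` over K rows already yields the col-major K^T operand that the row.col MMA expects. A sketch of that fragment load, assuming the repo's `LDMATRIX_X2` macro wraps `ldmatrix.sync.aligned.x2.m8n8.shared.b16` (written out inline here) and that `smem_K_base_ptr` is a shared-space address:

```C++
#include <cuda_fp16.h>
#include <cstdint>

__device__ __forceinline__ void ldmatrix_x2(uint32_t &r0, uint32_t &r1,
                                            uint32_t smem_addr) {
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(r0), "=r"(r1) : "r"(smem_addr));
}

// Load one k16n8 K fragment (two 8x8 tiles along d) for the j-th MMA column.
template <int kMmaAtomK = 16, int kHeadDim = 64, int kPad = 8>
__device__ void load_K_k16n8_fragment(uint32_t R_K[2], uint32_t smem_K_base_ptr,
                                      int warp_smem_K_Bc, int tile_K_d, int lane_id) {
  int lane_smem_K_Bc = warp_smem_K_Bc + (lane_id % 8);                  // 8 rows along Bc
  int lane_smem_K_d  = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8;  // d offset 0 or 8
  uint32_t lane_smem_K_ptr = smem_K_base_ptr +
      (lane_smem_K_Bc * (kHeadDim + kPad) + lane_smem_K_d) * sizeof(half);
  ldmatrix_x2(R_K[0], R_K[1], lane_smem_K_ptr);  // no .trans variant needed any more
}
```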
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 @@ -754,13 +751,10 @@ void launch_flash_attn_mma_stages_split_q_shared_kv( constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads constexpr int kPad = 8; - // static int kMaxSramPerBlock; - // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); + static int kMaxSramPerBlock; + cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); // Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem. - constexpr int KV_tile_size = ( - ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? - ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) - ); + constexpr int KV_tile_size = (Bc * (kHeadDim + kPad)); const int smem_max_size = ((Br * (kHeadDim + kPad)) + (kStage * KV_tile_size)) * sizeof(half); @@ -791,8 +785,8 @@ void launch_flash_attn_mma_stages_split_q_shared_kv( kPad >, cudaFuncAttributeMaxDynamicSharedMemorySize, - // kMaxSramPerBlock - 98304 + kMaxSramPerBlock + // 98304 ); flash_attn_mma_stages_split_q_shared_kv_kernel< @@ -824,10 +818,10 @@ void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, torch::Tensor V, torch::Tensor O, int stages) { - CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. - CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] const int d = Q.size(3); // B, H, N, d if (stages > 1) { diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu index c00eedcd..e946ec76 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu @@ -65,7 +65,8 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* V, half* O, int QKV_seqlen) { - // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V @@ -81,7 +82,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, static_assert(Br >= Bc); // for shared memory reuse. constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. - const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d] const float scale = 1.0f / sqrt((float) kHeadDim); // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma) @@ -112,17 +113,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // | warp_QP 7 | MMA 7 ... 
MMA 7 (x16) | const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] - const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + - (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads. int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,... - // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads. - int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64 - int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,... + // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads. int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... @@ -134,16 +135,11 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, int load_gmem_K_Bc_offset = 0; int load_gmem_V_Bc_offset = 0; - // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 + // Shared memory for Q,K,V, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. extern __shared__ half smem[]; - constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - // constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M - // constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV shared 8M - constexpr int KV_tile_size = ( - ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? - ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) - ); - // K multi-stages: currently, only apply multi stages for K across seq_len. + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d] half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem; // QKV shared the same smem half* V_tile_smem = Q_tile_smem; // QKV shared the same smem @@ -249,18 +245,21 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, if constexpr (kCanPrefetchKVg2s) { if (tile_K_seqlen == 0) { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
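Aliasing K and V (and, in this kernel, Q as well) onto one smem buffer, together with the simplified `KV_tile_size = Bc * (kHeadDim + kPad)`, is where the SRAM savings quoted in the README come from. A rough per-block budget sketch under the default tile sizes (Br = Bc = 64, d = 64, kPad = 8); the formulas are the same ones used in the launch functions:

```C++
#include <cstdio>

int main() {
  constexpr int Br = 64, Bc = 64, kHeadDim = 64, kPad = 8;
  constexpr int Q_tile_size  = Br * (kHeadDim + kPad);   // 4608 halfs
  constexpr int KV_tile_size = Bc * (kHeadDim + kPad);   // 4608 halfs
  // share-qkv, kStage=1: Q, K and V alias one buffer -> footprint is just the Q tile.
  int share_qkv_bytes = 2 * Q_tile_size;
  // share-kv, kStage=1: Q tile + one buffer shared by K and V.
  int share_kv_bytes  = 2 * (Q_tile_size + KV_tile_size);
  // split-q, kStage=2: Q tile + two staged K tiles + one V tile, all separate.
  int split_q_bytes   = 2 * (Q_tile_size + 2 * KV_tile_size + KV_tile_size);
  printf("share-qkv: %d B, share-kv: %d B, split-q: %d B\n",
         share_qkv_bytes, share_kv_bytes, split_q_bytes);
  return 0;
}
```

For kStage = 2 the shared-qkv launcher additionally requires Br/Bc >= 2, per the static_assert further down in this file.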
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { - CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. CP_ASYNC_WAIT_GROUP(0); __syncthreads(); @@ -285,15 +284,17 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, } } else { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -303,8 +304,9 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, } // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] - // Matmul with NN layout, Q row major, K row major. - // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + // Matmul with NT layout, Q row major, K^T col major. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d] // fill_3D_regs(R_S, 0); #pragma unroll @@ -326,19 +328,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, lane_smem_Q_ptr); // now, R_Q[1][1][4] } } + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. - // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] + // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N] #pragma unroll for (int j = 0; j < kWarpTileSeqLenK; ++j) { - int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) - int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); - int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d]. + // K[Bc,d] with row major means K^T[d,Bc] in col major. 
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7 + int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8 uint32_t lane_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - lane_smem_K_d * (Bc + kPad) + - lane_smem_K_Bc) * sizeof(half) - ); - LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + lane_smem_K_Bc * (kHeadDim + kPad) + + lane_smem_K_d) * sizeof(half) + ); + LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K } // end for kWarpTileSeqLenK if constexpr (kCanPrefetchQs2r) { @@ -394,21 +399,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, if constexpr (kCanPrefetchKVg2s) { if ((tile_K_seqlen + 1) < Tc) { load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); } } - // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps // | 64x64 | warp_KV 0 | // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | @@ -734,15 +740,12 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv( constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; - // constexpr int kMmaTileSeqLenQ = 4; constexpr int kMmaTileSeqLenQ = 8; constexpr int kMmaTileSeqLenK = 1; - // constexpr int kMmaTileSeqLenP = 4; constexpr int kMmaTileSeqLenP = 8; constexpr int kMmaTileHeadDimV = 1; constexpr int kWarpTileSeqLenQ = 1; constexpr int kWarpTileSeqLenK = 8; - // constexpr int kWarpTileSeqLenK = 16; constexpr int kWarpTileSeqLenP = 1; constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,.... constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 @@ -753,15 +756,12 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv( static_assert(((Br / Bc) >= 2)); } - // static int kMaxSramPerBlock; - // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); + static int kMaxSramPerBlock; + cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); // Calculate SRAM size needed per block, QKV smem size, QKV fully shared the same smem. - constexpr int KV_tile_size = ( - ((kHeadDim * (Bc + kPad)) > (Bc * (kHeadDim + kPad))) ? - ((kHeadDim * (Bc + kPad))) : (Bc * (kHeadDim + kPad)) - ); + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); int smem_max_size = (Br * (kHeadDim + kPad)) * sizeof(half); // 128x(32/64/128)x2/1024=8/16/32M - if constexpr (kStage > 1) { + if constexpr (kStage > 1) { // make sure kStage > 1 work smem_max_size = smem_max_size > 2 * KV_tile_size * sizeof(half) ? 
smem_max_size : 2 * KV_tile_size * sizeof(half); } @@ -793,8 +793,8 @@ void launch_flash_attn_mma_stages_split_q_shared_qkv( kPad >, cudaFuncAttributeMaxDynamicSharedMemorySize, - // kMaxSramPerBlock - 98304 + kMaxSramPerBlock + // 98304 ); flash_attn_mma_stages_split_q_shared_qkv_kernel< @@ -826,10 +826,10 @@ void flash_attn_mma_stages_split_q_shared_qkv(torch::Tensor Q, torch::Tensor V, torch::Tensor O, int stages) { - CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. - CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] const int d = Q.size(3); // B, H, N, d if (stages > 1) { diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu index fb07e72d..789a77c4 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv.cu @@ -33,7 +33,8 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, half* V, half* O, int QKV_seqlen) { - // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 static_assert(kMmaTileSeqLenQ == 2 && kMmaTileSeqLenK == 4); // Q@K^T static_assert(kMmaTileSeqLenP == 2 && kMmaTileHeadDimV == 4); // P@V @@ -73,17 +74,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // gridDim.y = head_num, gridDim.z = N/Br = Tr. const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] - const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + - (QKV_head_id * kHeadDim * QKV_seqlen)); // transpose K, [d,seqlen] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 256 threads. int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 4, row 0~64 int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 4) * 16, 0,16,32,48 - // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 256 threads. - int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 4, row 0~64 - int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 4) * 16, 0,16,32,48 + // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48 // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 256 threads. 
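On the host side, each launcher now queries the device's shared-memory limit instead of hard-coding 98304 bytes, and sizes the K tile as Bc*(kHeadDim+kPad) since K is no longer stored transposed. A hedged sketch of that launch pattern; `toy_kernel` is a placeholder, not one of the repo's kernels:

```C++
#include <cuda_runtime.h>
#include <cuda_fp16.h>

__global__ void toy_kernel(half *Q, half *K, half *V, half *O, int QKV_seqlen) {}

void configure_and_launch(dim3 grid, dim3 block,
                          half *Q, half *K, half *V, half *O, int QKV_seqlen) {
  constexpr int Br = 64, Bc = 64, kHeadDim = 64, kPad = 8, kStage = 2;
  int kMaxSramPerBlock = 0;
  cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
  // Q tile + kStage K tiles + V tile (the split-q formula; split-kv adds an S tile).
  const int smem_max_size = ((Br * (kHeadDim + kPad)) +
                             (kStage * Bc * (kHeadDim + kPad)) +
                             (Bc * (kHeadDim + kPad))) * sizeof(half);
  cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       kMaxSramPerBlock);
  toy_kernel<<<grid, block, smem_max_size>>>(Q, K, V, O, QKV_seqlen);
}
```

One caveat worth noting: cudaDevAttrMaxSharedMemoryPerBlock reports the default 48 KB per-block limit, while the larger opt-in ceiling is exposed through cudaDevAttrMaxSharedMemoryPerBlockOptin, so the queried value is not a drop-in replacement for the old 98304 on devices that actually allow more than 48 KB per block.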
int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 4, row 0~64 int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48 @@ -95,17 +96,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, int load_gmem_K_Bc_offset = 0; int load_gmem_V_Bc_offset = 0; - // Shared memory for Q,K,V,O, d=64->24M, d=128=48M + // Shared memory for Q,K,V,S, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. extern __shared__ half smem[]; - constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M - constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M - constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M // K multi-stages: currently, only apply multi stages for K across seq_len. half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M - half* V_tile_smem = K_tile_smem + kStage * K_tile_size; - half* S_tile_smem = V_tile_smem + V_tile_size; // for temp S=Q@K^T + half* V_tile_smem = K_tile_smem + kStage * KV_tile_size; + half* S_tile_smem = V_tile_smem + KV_tile_size; // for temp S=Q@K^T // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M, shared KV smem: 24M // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M=64M, shared KV smem: 48M // stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M=128M, shared KV smem: 96M @@ -164,15 +165,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, for (int stage = 0; stage < (kStage - 1); ++stage) { // update the offset of n according to stages load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (stage * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (stage * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -222,17 +225,17 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc] if ((tile_K_seqlen + 1) < Tc) { load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
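With the new layout, each thread streams a contiguous run of the head dimension for its K row, 16 bytes per cp.async, instead of striding across a transposed [d, seqlen] buffer. A sketch of that copy loop, assuming `CP_ASYNC_CG` and `CP_ASYNC_COMMIT_GROUP` wrap the `cp.async.cg` / `cp.async.commit_group` PTX (spelled out inline here):

```C++
#include <cuda_fp16.h>
#include <cstdint>

__device__ __forceinline__ void cp_async_cg_16B(uint32_t dst_smem, const void *src_gmem) {
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(dst_smem), "l"(src_gmem));
}

// Copy this thread's K row chunk (kHeadDim / (kNumThreads / Bc) halfs) gmem -> smem.
template <int kHeadDim = 64, int Bc = 64, int kNumThreads = 128>
__device__ void copy_K_row_chunk(const half *K, int load_gmem_K_addr,
                                 uint32_t load_smem_K_ptr) {
  #pragma unroll
  for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
    // i halfs forward in gmem == i * 2 bytes forward in smem (half = 2 bytes).
    cp_async_cg_16B(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i]);
  }
  asm volatile("cp.async.commit_group;\n" ::);
}
```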
- int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] - int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d] int load_gmem_K_addr = ( - K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel_next * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half) + smem_K_base_ptr + (smem_sel_next * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -244,17 +247,20 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) { load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (smem_sel * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } + CP_ASYNC_COMMIT_GROUP(); } @@ -282,8 +288,9 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, } // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] - // Matmul with NN layout, Q row major, K row major. - // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + // Matmul with NT layout, Q row major, K^T col major. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d] fill_3D_regs(R_S, 0); #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { @@ -303,18 +310,20 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, } // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. - // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] - #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { - int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) - int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); - int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d]. + // K[Bc,d] with row major means K^T[d,Bc] in col major. 
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7 + int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8 uint32_t lane_smem_K_ptr = ( - smem_K_base_ptr + (smem_sel * K_tile_size + - lane_smem_K_d * (Bc + kPad) + - lane_smem_K_Bc) * sizeof(half) + smem_K_base_ptr + (smem_sel * KV_tile_size + + lane_smem_K_Bc * (kHeadDim + kPad) + + lane_smem_K_d) * sizeof(half) ); - LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K } // end for kWarpTileSeqLenK // MMA compute @@ -701,6 +710,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, for (int i = 0; i < kWarpTileSeqLenP; ++i) { #pragma unroll for (int j = 0; j < kWarpTileHeadDimV; ++j) { + static_assert(kWarpTileSeqLenQ >= 2); R_Q[0][0] = R_D[i][j][0]; R_Q[1][0] = R_D[i][j][1]; // warp_size 4 R_Q[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); R_Q[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); @@ -733,6 +743,8 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, template void launch_flash_attn_mma_stages_split_kv( torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + // Now: fixed tile BrxBc=64x64 + // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size. constexpr int kMmaAtomM = 16; constexpr int kMmaAtomN = 8; constexpr int kMmaAtomK = 16; @@ -748,9 +760,11 @@ void launch_flash_attn_mma_stages_split_kv( constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64 constexpr int kPad = 8; + static int kMaxSramPerBlock; + cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); // Calculate SRAM size needed per block, Q,K,V,S smem size const int smem_max_size = ((Br * (kHeadDim + kPad)) + - (kStage * kHeadDim * (Bc + kPad)) + + (kStage * Bc * (kHeadDim + kPad)) + (Bc * (kHeadDim + kPad)) + (Br * (Bc + kPad))) * sizeof(half); @@ -780,7 +794,8 @@ void launch_flash_attn_mma_stages_split_kv( kPad >, cudaFuncAttributeMaxDynamicSharedMemorySize, - 98304 + kMaxSramPerBlock + // 98304 ); flash_attn_mma_stages_split_kv_kernel< @@ -812,10 +827,10 @@ void flash_attn_mma_stages_split_kv(torch::Tensor Q, torch::Tensor V, torch::Tensor O, int stages) { - CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. - CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] - CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] const int d = Q.size(3); // B, H, N, d if (stages == 2) { diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu index 92c1cd2e..2cd3a7e6 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_split_q.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_split_q.cu @@ -54,11 +54,12 @@ flash_attn_mma_stages_split_q_kernel(half* Q, half* V, half* O, int QKV_seqlen) { - // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. 
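The MMA instruction itself is untouched by these hunks; only the source of its B operand changes. For reference, a sketch of the m16n8k16 f16 MMA the kernels issue, assuming the repo's `HMMA16816` macro wraps this instruction: A is the row-major Q fragment, B is the K fragment loaded with plain `ldmatrix.x2` (i.e. K^T in col major), and C/D accumulate S = Q@K^T in half precision.

```C++
#include <cstdint>

// d += a @ b + c for one m16n8k16 tile; all operands are packed f16x2 registers.
__device__ __forceinline__ void hmma16816_f16(
    uint32_t &d0, uint32_t &d1,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    uint32_t c0, uint32_t c1) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
      : "=r"(d0), "=r"(d1)
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
        "r"(b0), "r"(b1), "r"(c0), "r"(c1));
}
```

Because the instruction is hard-wired to row.col operand order, feeding it col-major K^T straight from row-major K[Bc, d] is what lets the transposed `ldmatrix.x2.trans` path be dropped.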
static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 - static_assert(kMmaTileSeqLenQ == 4 && kMmaTileSeqLenK == 1); // Q@K^T - static_assert(kMmaTileSeqLenP == 4 && kMmaTileHeadDimV == 1); // P@V - static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK == 8); // Q@K^T + static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T + static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V + static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc. // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( @@ -68,7 +69,6 @@ flash_attn_mma_stages_split_q_kernel(half* Q, constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads - constexpr int kNumStoreBatchO = 2; // batch size for O collective store. // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] const float scale = 1.0f / sqrt((float) kHeadDim); @@ -101,17 +101,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // | warp_QP 7 | MMA 7 ... MMA 7 (x16) | const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] - const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + - (QKV_head_id * kHeadDim * QKV_seqlen)); // transposed K, [d,seqlen] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads. int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,... - // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 128 threads. - int load_smem_K_d = (tid / (kNumThreads / kHeadDim)); // d 64, tid / 2, row 0~64 - int load_smem_K_Bc = (tid % (kNumThreads / kHeadDim)) * (Bc / (kNumThreads / kHeadDim)); // (tid % 2) * 32, 0,32,... + // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads. int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... @@ -123,15 +123,15 @@ flash_attn_mma_stages_split_q_kernel(half* Q, int load_gmem_K_Bc_offset = 0; int load_gmem_V_Bc_offset = 0; - // Shared memory for Q,K,V,O, d=64->24M, d=128=48M, kStage 1 + // Shared memory for Q,K,V, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. 
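The collective store that comment describes gathers each m16n8 accumulator tile into 16-byte rows purely in registers: the four lanes of an MMA row group exchange their two packed f16x2 values via __shfl_sync, and the group leader issues one 128-bit store per row. A sketch of the pattern (the real kernels reuse R_Q or a scratch R_Z for this and compute the O addresses from warp/lane ids, which is omitted here):

```C++
#include <cuda_fp16.h>
#include <cstdint>

// Gather the 8 halfs of rows r and r+8 of one m16n8 D tile from a 4-lane group
// and let lane 0 of the group issue a single 16-byte store per row.
__device__ void collective_store_rows(uint32_t d0, uint32_t d1, int lane_id,
                                      half *out_row, half *out_row_plus8) {
  __align__(16) uint32_t R_Z[2][4];
  R_Z[0][0] = d0;  R_Z[1][0] = d1;                          // this lane's own values
  R_Z[0][1] = __shfl_sync(0xffffffff, d0, lane_id + 1, 4);
  R_Z[0][2] = __shfl_sync(0xffffffff, d0, lane_id + 2, 4);
  R_Z[0][3] = __shfl_sync(0xffffffff, d0, lane_id + 3, 4);
  R_Z[1][1] = __shfl_sync(0xffffffff, d1, lane_id + 1, 4);
  R_Z[1][2] = __shfl_sync(0xffffffff, d1, lane_id + 2, 4);
  R_Z[1][3] = __shfl_sync(0xffffffff, d1, lane_id + 3, 4);
  if (lane_id % 4 == 0) {  // group leader: two 128-bit stores, no smem staging
    *reinterpret_cast<int4 *>(out_row)       = *reinterpret_cast<const int4 *>(R_Z[0]);
    *reinterpret_cast<int4 *>(out_row_plus8) = *reinterpret_cast<const int4 *>(R_Z[1]);
  }
}
```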
extern __shared__ half smem[]; - constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M - constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M - constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M // K multi-stages: currently, only apply multi stages for K across seq_len. half* Q_tile_smem = smem; // 8M/16M half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M - half* V_tile_smem = K_tile_smem + kStage * K_tile_size; + half* V_tile_smem = K_tile_smem + kStage * KV_tile_size; // TODO: KV may shared same smem to reduce smem usage for kStage 1 // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M =64M @@ -164,12 +164,9 @@ flash_attn_mma_stages_split_q_kernel(half* Q, uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] - // Helper regs for collective store, [2][4]. may use smem to reduce regs? - uint32_t R_Z[kNumStoreBatchO][4]; // [2][4] fill_3D_regs(R_S, 0); fill_3D_regs(R_D, 0); fill_3D_regs(R_O, 0); - fill_2D_regs(R_Z, 0); // load Q from gmem -> smem, only load once. { @@ -190,15 +187,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q, for (int stage = 0; stage < (kStage - 1); ++stage) { // update the offset of n according to stages load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...) - int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen] int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen - int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + int load_gmem_K_d = load_smem_K_d; // K [Bc,d] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); uint32_t load_smem_K_ptr = ( - smem_K_base_ptr + (stage * K_tile_size + - load_smem_K_d * (Bc + kPad) + - load_smem_K_Bc) * sizeof(half)); + smem_K_base_ptr + (stage * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); #pragma unroll - for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) { + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); } CP_ASYNC_COMMIT_GROUP(); @@ -248,17 +247,17 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc] if ((tile_K_seqlen + 1) < Tc) { load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) 
-        int load_gmem_K_d  = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
-        int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+        int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
+        int load_gmem_K_d  = load_smem_K_d; // K [Bc,d] from [seqlen,d]
         int load_gmem_K_addr = (
-          K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+          K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
         uint32_t load_smem_K_ptr = (
-          smem_K_base_ptr + (smem_sel_next * K_tile_size +
-                             load_smem_K_d * (Bc + kPad) +
-                             load_smem_K_Bc) * sizeof(half)
+          smem_K_base_ptr + (smem_sel_next * KV_tile_size +
+                             load_smem_K_Bc * (kHeadDim + kPad) +
+                             load_smem_K_d) * sizeof(half)
         );
         #pragma unroll
-        for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+        for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
           CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
         }
         CP_ASYNC_COMMIT_GROUP();
@@ -270,17 +269,20 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
      // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages)
      {
        load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
-        int load_gmem_K_d  = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
         int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
-        int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
+        int load_gmem_K_d  = load_smem_K_d; // K [Bc,d] from [seqlen,d]
+        int load_gmem_K_addr = (
+          K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
         uint32_t load_smem_K_ptr = (
-          smem_K_base_ptr + (smem_sel * K_tile_size +
-                             load_smem_K_d * (Bc + kPad) +
-                             load_smem_K_Bc) * sizeof(half));
+          smem_K_base_ptr + (smem_sel * KV_tile_size +
+                             load_smem_K_Bc * (kHeadDim + kPad) +
+                             load_smem_K_d) * sizeof(half)
+        );
         #pragma unroll
-        for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
+        for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
           CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
         }
+
         CP_ASYNC_COMMIT_GROUP();
       }
@@ -308,8 +310,9 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
       } // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]

-      // Matmul with NN layout, Q row major, K row major.
-      // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
+      // Matmul with NT layout, Q row major, K^T col major.
+      // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
+      // S_tile[Br,Bc]=Q_tile[Br,d]@K^T[d,Bc]
       fill_3D_regs(R_S, 0);
       #pragma unroll
       for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
@@ -329,18 +332,20 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
       }

       // smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
-      // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N]
-      #pragma unroll
-      for (int j = 0; j < kWarpTileSeqLenK; ++j) {
-        int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N)
-        int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K);
-        int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N)
+      // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[N,K]
+      #pragma unroll
+      for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+        // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+        // K[Bc,d] with row major means K^T[d,Bc] in col major.
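        // Layout note (illustrative): one n8k16 B-fragment of K^T is two 8x8
        // tiles taken from the same eight K rows at d-offsets 0 and 8. With K
        // kept row major as K[Bc,d], each 8x8 tile is eight rows of 8
        // contiguous halfs in smem, so lanes 0..7 pass the row addresses of
        // the first tile and lanes 8..15 those of the second, and a plain
        // ldmatrix.x2 replaces the previous ldmatrix.x2.trans.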
+        int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+        int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+        int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
         uint32_t lane_smem_K_ptr = (
-            smem_K_base_ptr + (smem_sel * K_tile_size +
-                               lane_smem_K_d * (Bc + kPad) +
-                               lane_smem_K_Bc) * sizeof(half)
+            smem_K_base_ptr + (smem_sel * KV_tile_size +
+                               lane_smem_K_Bc * (kHeadDim + kPad) +
+                               lane_smem_K_d) * sizeof(half)
         );
-        LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+        LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
       } // end for kWarpTileSeqLenK

       // MMA compute
@@ -620,6 +625,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
     for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
       #pragma unroll
       for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
+        uint32_t R_Z[2][4]; // [2][4]
         R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
         R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
         R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
@@ -651,6 +657,8 @@ flash_attn_mma_stages_split_q_kernel(half* Q,
 template
 void launch_flash_attn_mma_stages_split_q(
   torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+  // Now: fixed tile BrxBc=64x64
+  // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
   constexpr int kMmaAtomM = 16;
   constexpr int kMmaAtomN = 8;
   constexpr int kMmaAtomK = 16;
@@ -667,9 +675,11 @@ void launch_flash_attn_mma_stages_split_q(
   constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
   constexpr int kPad = 8;

+  static int kMaxSramPerBlock;
+  cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
   // Calculate SRAM size needed per block, Q,K,V smem size
   const int smem_max_size = ((Br * (kHeadDim + kPad)) +
-                             (kStage * kHeadDim * (Bc + kPad)) +
+                             (kStage * Bc * (kHeadDim + kPad)) +
                              (Bc * (kHeadDim + kPad))) * sizeof(half);

   const int QKV_batch = Q.size(0);
@@ -678,7 +688,7 @@ void launch_flash_attn_mma_stages_split_q(
   assert(QKV_seqlen % Bc == 0); // multiple of Bc=64

   dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br)
-  dim3 block(kNumThreads); // 4 warps per block
+  dim3 block(kNumThreads); // 4/8 warps per block

   cudaFuncSetAttribute(
     flash_attn_mma_stages_split_q_kernel<
@@ -698,7 +708,8 @@ void launch_flash_attn_mma_stages_split_q(
       kPad
     >,
     cudaFuncAttributeMaxDynamicSharedMemorySize,
-    98304
+    kMaxSramPerBlock
+    // 98304
   );

   flash_attn_mma_stages_split_q_kernel<
@@ -730,10 +741,10 @@ void flash_attn_mma_stages_split_q(torch::Tensor Q,
                                    torch::Tensor V,
                                    torch::Tensor O,
                                    int stages) {
-  CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
-  CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed.
-  CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
-  CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+  CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+  CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+  CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+  CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
   const int d = Q.size(3); // B, H, N, d

   if (stages > 1) {
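    // stages > 1 presumably selects the multi-stage (kStage=2) K prefetch path.
    // Example call site (sketch; B, H, N, D are placeholder dims, N % 64 == 0):
    // with this kernel K stays in its natural [B,H,N,D] half layout, so no
    // pre-transposed copy of K is needed:
    //   auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    //   auto Q = torch::randn({B, H, N, D}, opts);
    //   auto K = torch::randn({B, H, N, D}, opts); // same layout as Q/V
    //   auto V = torch::randn({B, H, N, D}, opts);
    //   auto O = torch::zeros_like(Q);
    //   flash_attn_mma_stages_split_q(Q, K, V, O, /*stages=*/2);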