[FA2] flash-attn-mma get rid of transpose-k✔️ (#169)
* Update flash_attn_mma_split_kv.cu

* Update flash_attn_mma_split_q.cu

* Update flash_attn_mma_share_kv.cu

* Update flash_attn_mma_share_qkv.cu

* Update flash_attn_mma.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 19, 2024
1 parent 9324ddf commit 4687e1d
Showing 7 changed files with 300 additions and 280 deletions.
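
In short, this change drops the pre-transposed K input: all four MMA kernels now take K in the same `[B, H, N, D]` layout as Q and V, and the Python driver no longer builds `tk = k.transpose(-2, -1).contiguous()`. The sketch below illustrates the host-side call convention before and after; the direct positional signature of the `lib.*` bindings (`q, k, v, o, stages`) is an assumption inferred from the `run_benchmark(...)` wrapper calls in this diff, not taken verbatim from the bindings.

```python
# Illustrative sketch only. It mirrors the tensor setup from get_qkvo() in
# flash_attn_mma.py; the direct lib.* calls in the comments are hypothetical,
# since the diff only shows them invoked through run_benchmark().
import torch

B, H, N, D = 1, 8, 8192, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

# Before this commit: K had to be materialized as K^T in [B, H, D, N].
# tk = k.transpose(-2, -1).contiguous()
# lib.flash_attn_mma_stages_split_q(q, tk, v, o, 1)   # hypothetical direct call

# After this commit: K is passed in [B, H, N, D], same as Q and V.
# lib.flash_attn_mma_stages_split_q(q, k, v, o, 1)    # hypothetical direct call
```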
30 changes: 15 additions & 15 deletions README.md
@@ -60,12 +60,12 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than the official FA2 on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):

```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -74,8 +74,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-(sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
------------------------------------------------------------------------------------------------------------------------
```
The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K, V across MMA warps, is slower than the `Split Q` policy, which splits only Q across MMA warps while keeping K and V accessible to all warps.
@@ -93,7 +92,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
-half* K, // [B, H, D, N] K^T transposed
+half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -113,7 +112,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-half* K, // [B, H, D, N] K^T transposed
+half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -125,23 +124,24 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
```C++
// K, V shared the same shared memory, improve block occupancy.
__global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
-half* K,
-half* V,
-half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+half* K, // [B, H, N, D]
+half* V, // [B, H, N, D]
+half* O, // [B, H, N, D]
int QKV_seqlen);
```
- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
<div id="mma-share-qkv"></div>
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy
+// and reduce Q SMEM IO-Access.
__global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
-half* K,
-half* V,
-half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+half* K, // [B, H, N, D]
+half* V, // [B, H, N, D]
+half* O, // [B, H, N, D]
int QKV_seqlen);
```

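Why the transpose can go away: the score matrix S = Q·K^T only needs K row by row (S[i, j] = Σ_d Q[i, d]·K[j, d]), so the kernels can consume K in `[B, H, N, D]` and form K^T on the fly while loading tiles, rather than requiring a pre-transposed `[B, H, D, N]` copy from the caller. The PyTorch snippet below is a minimal numerical sketch of that equivalence, not the CUDA kernel itself.

```python
# Minimal sketch: Q @ K^T computed from K stored row-major ([..., N, D]) matches
# the result obtained via an explicit transpose; no materialized K^T is needed.
import torch

B, H, N, D = 1, 2, 128, 64
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)

s_explicit_t = q @ k.transpose(-2, -1)                    # uses a K^T view
s_row_major  = torch.einsum("bhid,bhjd->bhij", q, k)      # reads K rows as-is

print(torch.allclose(s_explicit_t, s_row_major, atol=1e-5))  # True
```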
30 changes: 15 additions & 15 deletions kernels/flash-attn/README.md
@@ -16,12 +16,12 @@ This repository's implementation of FlashAttention is intended solely for learni

- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
```bash
-python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa # NVIDIA RTX 3080 Laptop
+python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 805, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
-torch(unfused): ['-0.00887299 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
+torch(unfused): ['-0.0088729 ', '-0.00307083 ', '0.00674438 '], time:19.318247ms, TFLOPS:7.25
mma(split-kv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.330205ms, TFLOPS:26.29
mma(split-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:5.058098ms, TFLOPS:27.70
mma(split-q+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:3.639126ms, TFLOPS:38.50
@@ -30,8 +30,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch --sdpa
mma(split-q+share-kv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.584863ms, TFLOPS:54.21
mma(split-q+share-qkv+stage1): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.691698ms, TFLOPS:52.06
mma(split-q+share-qkv+stage2): ['-0.0089035 ', '-0.00307846 ', '0.00675964 '], time:2.569842ms, TFLOPS:54.52
-(flash): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
-(sdpa): ['-0.00886536 ', '-0.0030632 ', '0.00675201 '], time:3.542566ms, TFLOPS:39.55
+(flash): ['-0.0088653 ', '-0.00307836 ', '0.00675201 '], time:3.734636ms, TFLOPS:37.52
------------------------------------------------------------------------------------------------------------------------
```

@@ -67,7 +66,7 @@ The `Split KV` and `Split Q` implementations have been carried out in [flash-att
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
-half* K, // [B, H, D, N] K^T transposed
+half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -87,7 +86,7 @@ flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
-half* K, // [B, H, D, N] K^T transposed
+half* K, // [B, H, N, D]
half* V, // [B, H, N, D]
half* O, // [B, H, N, D]
int QKV_seqlen);
@@ -99,23 +98,24 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
```C++
// K, V shared the same shared memory, improve block occupancy.
__global__ void
-flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
-half* K,
-half* V,
-half* O,
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, // [B, H, N, D]
+half* K, // [B, H, N, D]
+half* V, // [B, H, N, D]
+half* O, // [B, H, N, D]
int QKV_seqlen);
```
- 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
<div id="mma-share-qkv"></div>
```C++
-// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy & reduce Q SMEM IO-Access.
+// Q, K, V fully shared the same shared memory and prefetch Q s2r, improve block occupancy
+// and reduce Q SMEM IO-Access.
__global__ void
-flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
-half* K,
-half* V,
-half* O,
+flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, // [B, H, N, D]
+half* K, // [B, H, N, D]
+half* V, // [B, H, N, D]
+half* O, // [B, H, N, D]
int QKV_seqlen);
```

26 changes: 13 additions & 13 deletions kernels/flash-attn/flash_attn_mma.py
@@ -176,7 +176,7 @@ def run_benchmark(perf_func: callable,
out_val = out_val_first[:2]
out_val.append(out_val_last[-1])
out_val = [f"{v:<12}" for v in out_val]
print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
if show_all:
print(out)
time.sleep(args.sleep)
@@ -203,12 +203,12 @@ def get_qkvo(B, H, N, D):
v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
-tk = k.transpose(-2, -1).contiguous()
+# transpose (H,N) -> (N,H) for FA2.
fq = q.transpose(1, 2).contiguous()
fk = k.transpose(1, 2).contiguous()
fv = v.transpose(1, 2).contiguous()

-return q, k, v, o, tk, fq, fk, fv
+return q, k, v, o, fq, fk, fv


# un-fused naive attn
@@ -233,7 +233,7 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
print("-" * 120)
diff = torch.abs(out_flash.float() - out_mma.float())
all_close = str(torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2))
print(f"out_flash vs {tag:<20}, all close: {all_close:<6}, "
print(f"out_flash vs {tag:<18}, all close: {all_close:<6}, "
f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
f"mean diff: {diff.mean().item():.6f}")

@@ -254,19 +254,19 @@ def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
for (B, H, N, D) in BHNDs:
print("-" * 120)
print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
-q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
+q, k, v, o, fq, fk, fv = get_qkvo(B, H, N, D)
torch.cuda.synchronize()

if args.run_torch_unfused:
out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
-out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
-out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
-out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
-out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
-out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage1)", o, stages=1)
-out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, tk, v, "mma(split-q+share-kv+stage2)", o, stages=2)
-out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
-out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, tk, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage1)", o, stages=1)
+out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage2)", o, stages=2)
+out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage1)", o, stages=1)
+out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage2)", o, stages=2)
+out_mma_share_qkv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage1)", o, stages=1)
+out_mma_share_qkv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_qkv, q, k, v, "mma(split-q+share-qkv+stage2)", o, stages=2)
+out_mma_share_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage1)", o, stages=1)
+out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
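
One layout detail worth keeping in mind when reading `get_qkvo`: the `flash_attn_func` reference path consumes `[B, N, H, D]` tensors, which is why `fq`, `fk`, `fv` are built with `transpose(1, 2)`, while this repo's MMA kernels keep `[B, H, N, D]`. Below is a hedged sketch of that conversion; the final transpose back to `[B, H, N, D]` for comparison is an assumption about how the outputs are lined up, not code taken from this diff.

```python
# Hedged sketch of the FA2 reference layout handling (assumes the flash-attn
# package is installed and a CUDA device is available).
import torch
from flash_attn import flash_attn_func

B, H, N, D = 1, 8, 1024, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)

# flash_attn_func expects [B, N, H, D] (seqlen before heads).
fq, fk, fv = (x.transpose(1, 2).contiguous() for x in (q, k, v))
out_fa2 = flash_attn_func(fq, fk, fv)          # -> [B, N, H, D]

# Bring the result back to [B, H, N, D] before comparing against the MMA kernels.
out_fa2_bhnd = out_fa2.transpose(1, 2).contiguous()
```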