[FA2] Update flash-attn-mma shared-qkv🎉 (#168)
* Update flash_attn_mma_share_qkv.cu

* Update flash_attn_mma_share_kv.cu

* Update flash_attn_mma_share_kv.cu

* Update flash_attn_mma_share_qkv.cu

* Update flash_attn_mma.py

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 19, 2024
1 parent db8b8e8 commit 9324ddf
Showing 5 changed files with 49 additions and 43 deletions.
README.md: 4 changes (2 additions & 2 deletions)
@@ -119,7 +119,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
int QKV_seqlen);
```

- - 📚 Split Q + Shared KV SMEM (Faster+)
+ - 📚 Split Q + Shared KV SMEM (**1/2 SRAM** vs FA2)
<div id="mma-share-kv"></div>

```C++
@@ -131,7 +131,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
half* O,
int QKV_seqlen);
```
- - 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+ - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
<div id="mma-share-qkv"></div>
kernels/flash-attn/README.md: 5 changes (3 additions & 2 deletions)
@@ -93,7 +93,7 @@ flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
int QKV_seqlen);
```

- - 📚 Split Q + Shared KV SMEM (Faster+)
+ - 📚 Split Q + Shared KV SMEM (**1/2 SRAM** vs FA2)
<div id="mma-share-kv"></div>

```C++
@@ -105,7 +105,7 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
half* O,
int QKV_seqlen);
```
- - 📚 Split Q + Fully Shared QKV SMEM (Faster++)
+ - 📚 Split Q + Fully Shared QKV SMEM (**1/4 SRAM** vs FA2)
<div id="mma-share-qkv"></div>
@@ -119,6 +119,7 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
int QKV_seqlen);
```

+
## 📖 Prerequisites
<div id="prerequisites"></div>

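The two README hunks above replace the vague "Faster+ / Faster++" labels with concrete SMEM budgets: the shared-KV kernel keeps K and V in one shared-memory buffer, and the shared-QKV kernel additionally lets that buffer reuse Q's bytes once Q has been prefetched into registers. Below is a minimal sketch of that buffer reuse, not the kernels' actual tiling code; the 256-element tile, the `__hmul` stand-ins for the MMA steps, and all names are assumptions made for illustration.

```C++
#include <cuda_fp16.h>

// Toy sketch of the smem reuse behind the "1/2 SRAM" / "1/4 SRAM" claims:
// one shared buffer serves the Q, K and V tiles in turn, with barriers
// between phases. Launch with 256 threads; real tiles are [Br|Bc, d].
__global__ void smem_reuse_sketch(const half* Q, const half* K, const half* V,
                                  half* out) {
  __shared__ half tile[256];         // one buffer instead of three
  const int tid = threadIdx.x;

  tile[tid] = Q[tid];                // phase 1: Q tile lives in smem
  __syncthreads();
  half r_q = tile[tid];              // Q prefetched smem -> regs
  __syncthreads();                   // all warps done reading Q ...

  tile[tid] = K[tid];                // phase 2: ... so K may overwrite it
  __syncthreads();
  half r_s = __hmul(r_q, tile[tid]); // stand-in for S = Q @ K^T
  __syncthreads();                   // all warps done reading K ...

  tile[tid] = V[tid];                // phase 3: ... so V may overwrite it
  __syncthreads();
  out[tid] = __hmul(r_s, tile[tid]); // stand-in for O = P @ V
}
```

In the shared-KV variant only K and V alias each other while Q keeps its own buffer; in the shared-QKV variant all three alias, which is where the extra saving comes from.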
kernels/flash-attn/flash_attn_mma.py: 7 changes (4 additions & 3 deletions)
@@ -45,10 +45,11 @@ def get_args():
parser.add_argument("--N", type=int, default=None)
parser.add_argument("--D", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
+ parser.add_argument("--sleep", type=float, default=0.05)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--verbose", '--v', action="store_true")
- parser.add_argument("--warmup", type=int, default=1)
- parser.add_argument("--iters", type=int, default=5)
+ parser.add_argument("--warmup", "--w", type=int, default=1)
+ parser.add_argument("--iters", "--i", type=int, default=5)
parser.add_argument("--range-k", '--gk', action="store_true")
return parser.parse_args()

@@ -178,7 +179,7 @@ def run_benchmark(perf_func: callable,
print(f"{out_info:>30}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
if show_all:
print(out)
- time.sleep(0.05)
+ time.sleep(args.sleep)
torch.cuda.synchronize()
return out.clone(), mean_time

kernels/flash-attn/mma/flash_attn_mma_share_kv.cu: 66 changes (34 additions & 32 deletions)
@@ -213,38 +213,8 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
// TODO: process last tile_K_seqlen ? pad to multiple of 8.

- // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
- if constexpr (kCanPrefetchQs2r) {
- // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
- // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
- if (tile_K_seqlen == 0) {
- CP_ASYNC_WAIT_GROUP(0);
- __syncthreads();
-
- #pragma unroll
- for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
- // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
- // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
- // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
- // processes. This will reduce large io-access for Q smem while N is large.
- #pragma unroll
- for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
- int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
- int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
- int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
- uint32_t lane_smem_Q_ptr = (
- smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
- lane_smem_Q_d) * sizeof(half)
- );
- LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
- R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
- lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
- }
- }
- } // end if tile_K_seqlen == 0
- } // end if kCanPrefetchQs2r

- // Load K tile from gmem -> smem, always use smem part 0.
+ // Load K tile from gmem -> smem, always use smem part 0; issue the g2s
+ // memory copies before <Prefetch Q s2r> to enable time overlap.
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
@@ -301,6 +271,38 @@ flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
__syncthreads();
}

+ // <Prefetch Q s2r>: Load Q tile from smem -> regs, before Q@K^T.
+ if constexpr (kCanPrefetchQs2r) {
+ // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+ // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+ if (tile_K_seqlen == 0) {
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // Allocate R_Q[(kHeadDim / kMmaAtomK)][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it for <loop over K seqlen>
+ // processes. This will reduce large io-access for Q smem while N is large.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+ }
+ }
+ __syncthreads(); // wait all warps ready.
+ } // end if tile_K_seqlen == 0
+ } // end if kCanPrefetchQs2r

// <loop over K d>: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
// Matmul with NN layout, Q row major, K row major.
// S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc]
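The two hunks above are a move plus one addition: the `<Prefetch Q s2r>` block now runs after the K tile's g2s copy has been issued, and it gains a trailing `__syncthreads()`. Since `CP_ASYNC_CG` copies are asynchronous, issuing them first lets the `LDMATRIX_X4` reads of Q proceed while K is still in flight. A hedged sketch of the pattern, using the standard `__pipeline_*` intrinsics in place of the kernel's `CP_ASYNC_*` macros; the buffer sizes, the names, and the plain loads standing in for `ldmatrix` are assumptions.

```C++
#include <cuda_fp16.h>
#include <cuda_pipeline.h>

// Sketch of "issue g2s first, prefetch s2r while it flies, wait last".
// Assumes 256 threads and that Q_smem was filled by an earlier phase.
__global__ void overlap_sketch(const half* K_gmem, half* out) {
  __shared__ __align__(16) half K_smem[2048];
  __shared__ __align__(16) half Q_smem[2048];
  const int tid = threadIdx.x;
  half R_Q[8]; // stand-in for the kernel's R_Q ldmatrix fragments

  // 1) Kick off the asynchronous gmem -> smem copy of K (16B per thread).
  __pipeline_memcpy_async(&K_smem[tid * 8], &K_gmem[tid * 8], 16);
  __pipeline_commit();

  // 2) Overlapped work: read Q from smem into registers while K copies.
#pragma unroll
  for (int i = 0; i < 8; ++i) R_Q[i] = Q_smem[tid * 8 + i];

  // 3) Only now block on the copy, then sync the whole thread block.
  __pipeline_wait_prior(0);
  __syncthreads();

  out[tid] = __hadd(R_Q[0], K_smem[tid]); // consume both results
}
```

The reordering only pays off because nothing between steps 1 and 3 touches `K_smem`; the trailing barrier (`// wait all warps ready.`) then keeps the warps in step before the Q@K^T loop begins.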
kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu: 10 changes (6 additions & 4 deletions)
@@ -240,20 +240,22 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
}
}
+ __syncthreads(); // wait all warps ready.
} // end if tile_K_seqlen == 0
} // end if kCanPrefetchQs2r

- // Load K tile from gmem -> smem, always use smem part 0.
+ // Load K tile from gmem -> smem, always use smem part 0.
+ // Must come after <Prefetch Q s2r> so that Q's smem can be reused for K.
if constexpr (kCanPrefetchKVg2s) {
if (tile_K_seqlen == 0) {
load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen]
int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen
int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc);
uint32_t load_smem_K_ptr = (
- smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
- load_smem_K_d * (Bc + kPad) +
- load_smem_K_Bc) * sizeof(half));
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_d * (Bc + kPad) +
+ load_smem_K_Bc) * sizeof(half));
#pragma unroll
for (int i = 0; i < (Bc / (kNumThreads / kHeadDim)); i += 8) {
CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
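For the shared-QKV kernel the ordering constraint runs the other way: K's g2s copy lands in the very smem that Q occupies, so the added `__syncthreads()` after the Q s2r prefetch ensures no warp starts overwriting Q's bytes with K before every warp has finished reading them (the re-indented `load_smem_K_ptr` continuation lines are a whitespace-only cleanup). A toy sketch of the hazard the barrier prevents; the names are illustrative, and the buffer is assumed to hold Q from an earlier phase.

```C++
#include <cuda_fp16.h>

// Without the barrier, a fast warp could overwrite QK_smem with K while a
// slow warp is still reading its Q fragment out of the same bytes.
__global__ void qkv_reuse_sketch(const half* K_gmem, half* out) {
  __shared__ half QK_smem[256];  // one buffer: Q first, then K
  const int tid = threadIdx.x;

  half r_q = QK_smem[tid];       // "prefetch Q s2r" stand-in
  __syncthreads();               // the added barrier: wait all warps ready

  QK_smem[tid] = K_gmem[tid];    // now safe: every warp has its Q in regs
  __syncthreads();

  out[tid] = __hadd(r_q, QK_smem[tid]);
}
```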
