Skip to content

Commit

Permalink
#0: refactored flash decode round 2
Browse files Browse the repository at this point in the history
  • Loading branch information
caixunshiren committed Jan 7, 2025
1 parent 26a4797 commit d404f8e
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ void flash_attention_loop(
uint32_t k_chunk_start,
uint32_t k_chunk_end,
bool do_reduce,
bool apply_mask_at_last_chunk, // for causal mode, optionally apply mask at the last chunk
uint32_t qk_chunk_tiles,
uint32_t out_chunk_tiles) {
for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) {
Expand All @@ -502,8 +503,8 @@ void flash_attention_loop(
mul_block_bcast_scalar_inplace(cb_qk_im, cb_scale_in, qk_chunk_tiles);

if constexpr (is_causal) {
// For decode, we only apply mask at the last chunk on reducer core for causal mode
if (k_chunk == k_chunk_end - 1 && do_reduce) {
// For decode, we only apply mask at the last chunk for causal mode
if (k_chunk == k_chunk_end - 1 && apply_mask_at_last_chunk) {
/* QK += MASK */
reconfig_data_format(cb_qk_im, cb_mask_in);
add_block_inplace<false>(cb_qk_im, cb_mask_in, qk_chunk_tiles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ void MAIN {

uint32_t arg_idx = 0;
const bool do_reduce = get_arg_val<uint32_t>(arg_idx++) == 1;
const bool apply_mask_at_last_chunk = do_reduce && is_causal;
const bool do_output = get_arg_val<uint32_t>(arg_idx++) == 1;
const uint32_t cur_head = get_arg_val<uint32_t>(arg_idx++);
const uint32_t cur_batch = get_arg_val<uint32_t>(arg_idx++);
Expand Down Expand Up @@ -170,7 +171,7 @@ void MAIN {
cb_exp_max_diff,
cb_out_o,
cb_out_m,
cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, qk_chunk_tiles, out_chunk_tiles);
cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, apply_mask_at_last_chunk, qk_chunk_tiles, out_chunk_tiles);

// do reduction across intermediates from other cores if this is the reduction core
if (do_reduce) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,86 @@ void worker_compute(
cb_pop_front(cb_out_m, PNHt);
cb_pop_front(cb_out_l, PNHt);
}

/******************************************************************************
* Reader Kernel Specific Functions *
******************************************************************************/

/**
 * Streams K, (optionally) attention-mask, and V tiles for every chunk in
 * [k_chunk_start, k_chunk_end) into their circular buffers.
 *
 * For each chunk, in this order:
 *   1. Reserve k_chunk_tiles in cb_k_in, issue the K tile reads (outer loop
 *      over DHt columns, inner loop over Sk_chunk_t rows, source tile id
 *      stepping by DHt per row), barrier, then push.
 *   2. If use_attention_mask, call read_mask_chunk, whose return value becomes
 *      the next mask start tile id.
 *   3. Reserve k_chunk_tiles in cb_v_in, issue the V tile reads (outer loop
 *      over rows, inner loop over columns, consecutive source tile ids),
 *      barrier, then push.
 *
 * While issuing reads, noc_async_read_barrier() is called every time
 * barrier_threshold reads have been issued since the last barrier; a final
 * barrier always precedes each cb_push_back.
 *
 * Rows with `row > valid_seq_len_tiles` are skipped: no read is issued, but
 * the destination pointer and source tile id still advance past that slot.
 * NOTE(review): `row` is the row index inside the current chunk and the test
 * is inclusive (`<=`) -- confirm this is the intended bound.
 *
 * @param k_chunk_start        first chunk index (inclusive)
 * @param k_chunk_end          last chunk index (exclusive)
 * @param k_start_tile_id      K tile id of the first chunk; advanced locally by
 *                             k_chunk_tiles after each chunk
 * @param v_start_tile_id      V tile id of the first chunk; advanced locally by
 *                             k_chunk_tiles after each chunk
 * @param mask_start_tile_id   mask tile id passed to read_mask_chunk
 * @param valid_seq_len_tiles  row bound used to decide whether a tile is read
 * @param k_reader             address generator used for K tile reads
 * @param v_reader             address generator used for V tile reads
 * @param mask_reader          address generator forwarded to read_mask_chunk
 * @param k_tile_bytes         destination stride per K tile slot
 * @param v_tile_bytes         destination stride per V tile slot
 * @param PSt                  forwarded unchanged to read_mask_chunk
 */
template <
    uint32_t DHt,
    uint32_t Sk_chunk_t,
    uint32_t barrier_threshold,
    uint32_t k_chunk_tiles,
    uint32_t mask_chunk_tiles,
    uint32_t mask_tile_bytes,
    uint32_t PNHt,
    bool use_attention_mask,
    uint32_t cb_k_in,
    uint32_t cb_v_in,
    uint32_t cb_mask_in>
void read_kv_mask_chunks(
    uint32_t k_chunk_start,
    uint32_t k_chunk_end,
    uint32_t k_start_tile_id,
    uint32_t v_start_tile_id,
    uint32_t mask_start_tile_id,
    uint32_t valid_seq_len_tiles,
    const InterleavedAddrGenFast<true>& k_reader,
    const InterleavedAddrGenFast<true>& v_reader,
    const InterleavedAddrGenFast<true>& mask_reader,
    uint32_t k_tile_bytes,
    uint32_t v_tile_bytes,
    uint32_t PSt) {
    for (uint32_t chunk = k_chunk_start; chunk < k_chunk_end; ++chunk) {
        // ---- K: read chunk transposed ----
        cb_reserve_back(cb_k_in, k_chunk_tiles);
        uint32_t k_pending_reads = 0;
        uint32_t k_dst = get_write_ptr(cb_k_in);
        for (uint32_t col = 0; col < DHt; ++col) {
            uint32_t k_src_tile = k_start_tile_id + col;
            // Destination and source advance on every slot, read or skipped.
            for (uint32_t row = 0; row < Sk_chunk_t; ++row, k_src_tile += DHt, k_dst += k_tile_bytes) {
                if (row > valid_seq_len_tiles) {
                    continue;
                }
                noc_async_read_tile(k_src_tile, k_reader, k_dst);
                ++k_pending_reads;
                if (k_pending_reads == barrier_threshold) {
                    // Cap the number of outstanding reads.
                    noc_async_read_barrier();
                    k_pending_reads = 0;
                }
            }
        }
        // All K reads must have landed before the chunk is published.
        noc_async_read_barrier();
        cb_push_back(cb_k_in, k_chunk_tiles);
        k_start_tile_id += k_chunk_tiles;

        // ---- Mask (optional) ----
        if constexpr (use_attention_mask) {
            mask_start_tile_id =
                read_mask_chunk<cb_mask_in, mask_chunk_tiles, mask_tile_bytes, barrier_threshold, PNHt, Sk_chunk_t>(
                    PSt, mask_start_tile_id, mask_reader);
        }

        // ---- V: read chunk ----
        cb_reserve_back(cb_v_in, k_chunk_tiles);
        uint32_t v_pending_reads = 0;
        uint32_t v_dst = get_write_ptr(cb_v_in);
        uint32_t v_src_tile = v_start_tile_id;
        for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
            // The skip decision depends only on the row, so evaluate it once per row.
            const bool row_is_read = row <= valid_seq_len_tiles;
            for (uint32_t col = 0; col < DHt; ++col, ++v_src_tile, v_dst += v_tile_bytes) {
                if (!row_is_read) {
                    continue;
                }
                noc_async_read_tile(v_src_tile, v_reader, v_dst);
                ++v_pending_reads;
                if (v_pending_reads == barrier_threshold) {
                    // Cap the number of outstanding reads.
                    noc_async_read_barrier();
                    v_pending_reads = 0;
                }
            }
        }
        // All V reads must have landed before the chunk is published.
        noc_async_read_barrier();
        cb_push_back(cb_v_in, k_chunk_tiles);
        v_start_tile_id += k_chunk_tiles;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ void kernel_main() {
return;
}
}
const uint32_t valid_seq_len_tiles = (cur_pos + 1 + 32 - 1) / 32;

volatile tt_l1_ptr uint32_t* page_table_ptr;
if constexpr (is_paged_attention) {
Expand Down Expand Up @@ -262,61 +261,30 @@ void kernel_main() {
uint32_t k_start_tile_id = k_batch_offset + k_head_offset + k_chunk_offset;
uint32_t v_start_tile_id = v_batch_offset + v_head_offset + v_chunk_offset;

for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) {
// Read K chunk transposed
cb_reserve_back(cb_k_in, k_chunk_tiles);
uint32_t k_write_ptr = get_write_ptr(cb_k_in);
barrier_count = 0;
for (uint32_t col = 0; col < DHt; ++col) {
uint32_t k_tile_id = k_start_tile_id + col;
for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
if (row <= valid_seq_len_tiles) {
noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
}
k_tile_id += DHt;
k_write_ptr += k_tile_bytes;
}
}
noc_async_read_barrier();
cb_push_back(cb_k_in, k_chunk_tiles);
k_start_tile_id += k_chunk_tiles;

if constexpr (use_attention_mask) {
mask_start_tile_id = read_mask_chunk<
cb_mask_in,
mask_chunk_tiles,
mask_tile_bytes,
barrier_threshold,
PNHt,
Sk_chunk_t>(PSt, mask_start_tile_id, mask_reader);
}

// Read V chunk
cb_reserve_back(cb_v_in, k_chunk_tiles);
uint32_t v_write_ptr = get_write_ptr(cb_v_in);
barrier_count = 0;
uint32_t v_tile_id = v_start_tile_id;
for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
for (uint32_t col = 0; col < DHt; ++col) {
if (row <= valid_seq_len_tiles) {
noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
}
v_tile_id++;
v_write_ptr += v_tile_bytes;
}
}
noc_async_read_barrier();
cb_push_back(cb_v_in, k_chunk_tiles);
v_start_tile_id += k_chunk_tiles;
}
read_kv_mask_chunks<
DHt,
Sk_chunk_t,
barrier_threshold,
k_chunk_tiles,
mask_chunk_tiles,
mask_tile_bytes,
PNHt,
use_attention_mask,
cb_k_in,
cb_v_in,
cb_mask_in>(
k_chunk_start,
k_chunk_end,
k_start_tile_id,
v_start_tile_id,
mask_start_tile_id,
PSt,
k_reader,
v_reader,
mask_reader,
k_tile_bytes,
v_tile_bytes,
PSt);
}
}
}

0 comments on commit d404f8e

Please sign in to comment.