Commit

further refactoring of flash decode and fixed potential hang in writer reducer
caixunshiren committed Jan 7, 2025
1 parent d404f8e commit 135ce4a
Showing 6 changed files with 104 additions and 90 deletions.
@@ -437,6 +437,8 @@ template <
uint32_t DHt,
uint32_t Sq_chunk_t,
uint32_t Sk_chunk_t,
uint32_t qk_chunk_tiles,
uint32_t out_chunk_tiles,
// QK matmul block parameters
uint32_t qk_in0_block_w,
uint32_t qk_subblock_w,
@@ -477,9 +479,8 @@ void flash_attention_loop(
uint32_t k_chunk_start,
uint32_t k_chunk_end,
bool do_reduce,
bool apply_mask_at_last_chunk, // for causal mode, optionally apply mask at the last chunk
uint32_t qk_chunk_tiles,
uint32_t out_chunk_tiles) {
bool apply_mask_at_last_chunk // for causal mode, optionally apply mask at the last chunk
) {
for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) {
/* QK = Q_CHUNK @ K_CHUNK */
reconfig_data_format(cb_q_in, cb_k_in); // DEBUG
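The hunk above moves qk_chunk_tiles and out_chunk_tiles from runtime parameters of flash_attention_loop into its non-type template parameters. A minimal standalone sketch of that pattern follows; process_chunk and its arguments are hypothetical names, not part of this commit. The point is that the loop bound becomes a compile-time constant the compiler can unroll.

#include <cstdint>
#include <cstdio>

// Hypothetical sketch: a chunk size that used to arrive as a runtime argument
// is hoisted into a non-type template parameter, so the loop bound below is a
// compile-time constant.
template <uint32_t chunk_tiles>
void process_chunk(const uint32_t* tiles) {
    for (uint32_t i = 0; i < chunk_tiles; ++i) {  // bound fixed at compile time
        std::printf("tile %u -> %u\n", i, tiles[i]);
    }
}

int main() {
    uint32_t tiles[4] = {10, 20, 30, 40};
    process_chunk<4>(tiles);  // size chosen at the call site, like the kernel's template args
    return 0;
}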
@@ -137,6 +137,8 @@ void MAIN {
DHt,
Sq_chunk_t,
Sk_chunk_t,
qk_chunk_tiles,
out_chunk_tiles,
// QK matmul block parameters
qk_in0_block_w,
qk_subblock_w,
@@ -171,7 +173,7 @@ void MAIN {
cb_exp_max_diff,
cb_out_o,
cb_out_m,
cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, apply_mask_at_last_chunk, qk_chunk_tiles, out_chunk_tiles);
cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, apply_mask_at_last_chunk);

// do reduction across intermediates from other cores if this is the reduction core
if (do_reduce) {
@@ -182,10 +184,6 @@
// This indicates that there are computes done by other workers. Needs to wait for them and send to
// reducer's compute
for (uint32_t i = 0; i < num_cores_to_wait; i++) {
cb_wait_front(cb_out_o, q_chunk_tiles); // o_2
cb_wait_front(cb_m_in, Sq_chunk_t); // m_2
cb_wait_front(cb_l_in, Sq_chunk_t); // l_2

// reconfig_data_format(cb_q_in, cb_q_in); // DEBUG
// pack_reconfig_data_format(cb_out_accumulate_im_2);
copy_block(cb_out_o, cb_out_accumulate_im_2, q_chunk_tiles);
@@ -161,11 +161,10 @@ uint32_t read_mask_chunk(uint32_t PSt, uint32_t mask_start_tile_id, const Interl
return mask_start_tile_id;
}

template <uint32_t cb_mask_in, uint32_t PNHt>
void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
template <uint32_t cb_mask_in, uint32_t PNHt, uint32_t Sk_chunk_t>
void generate_mask(uint32_t k_num_chunks, uint32_t cur_pos) {
/*
example 1: 64 seqlen at cur_pos 40, 2 cores, 32 chunk size
PSt = 2
k_num_chunks = 2
Sk_chunk_t = 1
cur_pos = 40
@@ -174,7 +173,6 @@ void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
cur_pos_in_tile = 8
example 2: 1024 seqlen at cur_pos 990, 2 cores, 128 chunk size
PSt = 32
k_num_chunks = 8
Sk_chunk_t = 4
cur_pos = 990
@@ -183,7 +181,6 @@
cur_pos_in_tile = 30
example 3: 64 seqlen at cur_pos 63, 2 cores, 32 chunk size
PSt = 2
k_num_chunks = 2
Sk_chunk_t = 1
cur_pos = 63
@@ -192,7 +189,6 @@
cur_pos_in_tile = 31
example 4: 64 seqlen at cur_pos 0, 2 cores, 32 chunk size
PSt = 2
k_num_chunks = 2
Sk_chunk_t = 1
cur_pos = 0
@@ -201,7 +197,6 @@
cur_pos_in_tile = 0
*/

uint32_t Sk_chunk_t = PSt / k_num_chunks;
// the cb_mask_in is of size PNHt * Sk_chunk_t
uint32_t total_read_tiles = PNHt * Sk_chunk_t;
uint32_t cur_pos_in_chunk = cur_pos % (Sk_chunk_t * 32);
@@ -303,6 +298,67 @@ void worker_compute(
cb_pop_front(cb_out_l, PNHt);
}

template <uint32_t cb_out, uint32_t out_chunk_tiles, uint32_t barrier_threshold>
uint32_t write_tiles_to_memory(
uint32_t& out_tile_id, const InterleavedAddrGenFast<true>& out_writer, uint32_t& barrier_count) {
constexpr uint32_t tile_bytes = get_tile_size(cb_out);
uint32_t l1_read_addr = get_read_ptr(cb_out);
for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
++out_tile_id;
l1_read_addr += tile_bytes;
if (++barrier_count == barrier_threshold) {
noc_async_writes_flushed();
barrier_count = 0;
}
}
return barrier_count;
}

template <uint32_t cb_out, uint32_t ELEMENT_SIZE, uint32_t barrier_threshold>
uint32_t write_partial_tiles_to_memory(
uint32_t& out_tile_id,
const InterleavedAddrGenFast<true>& out_writer,
uint32_t& barrier_count,
uint32_t cur_head,
uint32_t num_heads_to_write,
uint32_t out_chunk_tiles) {
constexpr uint32_t tile_bytes = get_tile_size(cb_out);
constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;

for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer);
uint32_t l1_read_addr = get_read_ptr(cb_out) + tile * tile_bytes;

// write partial output for each head
for (uint32_t head = 0; head < num_heads_to_write; ++head) {
uint32_t starting_row = cur_head * num_heads_to_write + head;
uint32_t in_tile_offset_by_starting_head =
starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES
: (starting_row - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE;
uint64_t out_writer_noc_addr_head = out_writer_noc_addr + in_tile_offset_by_starting_head;
uint32_t l1_read_addr_head = l1_read_addr + in_tile_offset_by_starting_head;

// Write first phase
noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, SUBTILE_LINE_BYTES);

// Write second phase
noc_async_write(
l1_read_addr_head + 256 * ELEMENT_SIZE,
out_writer_noc_addr_head + 256 * ELEMENT_SIZE,
SUBTILE_LINE_BYTES);

if (++barrier_count == barrier_threshold) {
noc_async_writes_flushed();
barrier_count = 0;
}
}

++out_tile_id;
}
return barrier_count;
}

/******************************************************************************
* Reader Kernel Specific Functions *
******************************************************************************/
@@ -325,7 +381,6 @@ void read_kv_mask_chunks(
uint32_t k_start_tile_id,
uint32_t v_start_tile_id,
uint32_t mask_start_tile_id,
uint32_t valid_seq_len_tiles,
const InterleavedAddrGenFast<true>& k_reader,
const InterleavedAddrGenFast<true>& v_reader,
const InterleavedAddrGenFast<true>& mask_reader,
@@ -341,12 +396,10 @@
for (uint32_t col = 0; col < DHt; ++col) {
uint32_t k_tile_id = k_start_tile_id + col;
for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
if (row <= valid_seq_len_tiles) {
noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
k_tile_id += DHt;
k_write_ptr += k_tile_bytes;
@@ -369,12 +422,10 @@
uint32_t v_tile_id = v_start_tile_id;
for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
for (uint32_t col = 0; col < DHt; ++col) {
if (row <= valid_seq_len_tiles) {
noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
if (++barrier_count == barrier_threshold) {
noc_async_read_barrier();
barrier_count = 0;
}
v_tile_id++;
v_write_ptr += v_tile_bytes;
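The new write_tiles_to_memory and write_partial_tiles_to_memory helpers above both carry a barrier_count across calls and flush the NoC after every barrier_threshold asynchronous writes, so the number of outstanding transactions stays bounded while the common path never stalls. A self-contained analogue of that counter pattern is sketched below; issue_async_write, flush_outstanding_writes, and write_chunk are hypothetical stand-ins, not the kernel API.

#include <cstdint>
#include <cstdio>

// Stand-ins for the asynchronous NoC write and flush calls used by the kernel.
static void issue_async_write(uint32_t tile) { std::printf("write tile %u\n", tile); }
static void flush_outstanding_writes() { std::printf("flush\n"); }

// Issue writes back-to-back, flushing once every barrier_threshold writes.
// The counter is returned so it can be carried across chunks, exactly like
// the barrier_count threaded through the helpers above.
uint32_t write_chunk(uint32_t num_tiles, uint32_t barrier_threshold, uint32_t barrier_count) {
    for (uint32_t tile = 0; tile < num_tiles; ++tile) {
        issue_async_write(tile);
        if (++barrier_count == barrier_threshold) {
            flush_outstanding_writes();
            barrier_count = 0;
        }
    }
    return barrier_count;
}

int main() {
    uint32_t barrier_count = 0;
    barrier_count = write_chunk(/*num_tiles=*/8, /*barrier_threshold=*/3, barrier_count);
    barrier_count = write_chunk(/*num_tiles=*/5, /*barrier_threshold=*/3, barrier_count);
    return 0;
}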
@@ -278,7 +278,6 @@ void kernel_main() {
k_start_tile_id,
v_start_tile_id,
mask_start_tile_id,
PSt,
k_reader,
v_reader,
mask_reader,
@@ -15,22 +15,23 @@ void kernel_main() {
constexpr uint32_t PNHt = get_compile_time_arg_val(1); // padded number of heads in tiles
constexpr uint32_t St = get_compile_time_arg_val(2); // full sequence length of kv cache in tiles
constexpr uint32_t DHt = get_compile_time_arg_val(3); // head dim
constexpr uint32_t identity_scalar_packed = get_compile_time_arg_val(4);
constexpr uint32_t scale_val = get_compile_time_arg_val(5);
constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(6); // num cores per batch
constexpr uint32_t num_cores = get_compile_time_arg_val(7); // num running cores in total
uint32_t reducer_semaphore_addr = get_semaphore(get_compile_time_arg_val(8)); // semaphore for reducer
uint32_t output_semaphore_addr = get_semaphore(get_compile_time_arg_val(9)); // semaphore for sender
constexpr bool is_out_sharded = get_compile_time_arg_val(10);
constexpr uint32_t k_chunk_size = get_compile_time_arg_val(11);
constexpr uint32_t num_q_heads = get_compile_time_arg_val(12);
constexpr uint32_t num_kv_heads = get_compile_time_arg_val(13);
constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14);
constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15);
constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(16);
constexpr uint32_t num_output_cores = get_compile_time_arg_val(17);
constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(18);
constexpr bool is_causal = get_compile_time_arg_val(19) == 1;
constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(4); // number of tiles in seqlen of a k/v/mask chunk
constexpr uint32_t identity_scalar_packed = get_compile_time_arg_val(5);
constexpr uint32_t scale_val = get_compile_time_arg_val(6);
constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(7); // num cores per batch
constexpr uint32_t num_cores = get_compile_time_arg_val(8); // num running cores in total
uint32_t reducer_semaphore_addr = get_semaphore(get_compile_time_arg_val(9)); // semaphore for reducer
uint32_t output_semaphore_addr = get_semaphore(get_compile_time_arg_val(10)); // semaphore for sender
constexpr bool is_out_sharded = get_compile_time_arg_val(11);
constexpr uint32_t k_chunk_size = get_compile_time_arg_val(12);
constexpr uint32_t num_q_heads = get_compile_time_arg_val(13);
constexpr uint32_t num_kv_heads = get_compile_time_arg_val(14);
constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(15);
constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(16);
constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(17);
constexpr uint32_t num_output_cores = get_compile_time_arg_val(18);
constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(19);
constexpr bool is_causal = get_compile_time_arg_val(20) == 1;

uint32_t arg_idx = 0;
const uint32_t out_addr = get_arg_val<uint32_t>(arg_idx++);
@@ -144,7 +145,7 @@

// generate and send mask to compute if causal
if constexpr (is_causal) {
generate_mask<cb_mask_in, PNHt>(k_num_chunks, PSt, cur_pos);
generate_mask<cb_mask_in, PNHt, Sk_chunk_t>(k_num_chunks, cur_pos);
}

for (uint32_t cur_head = cur_head_group * num_heads_per_core;
Expand All @@ -163,7 +164,7 @@ void kernel_main() {
// cb_wait_front(cb_intermed_out, num_tiles_to_wait);
constexpr uint32_t q_read_size = out_chunk_tiles * tile_bytes_intermed;
constexpr uint32_t ml_read_size = PNHt * tile_bytes_intermed;
for (uint32_t block = 0; block < num_cores_to_wait + 1; ++block) {
for (uint32_t block = 0; block < num_cores_to_wait; ++block) {
cb_reserve_back(cb_out_o, out_chunk_tiles);
cb_reserve_back(cb_m_in, PNHt);
cb_reserve_back(cb_l_in, PNHt);
@@ -196,46 +197,17 @@

if constexpr (num_kv_heads > 1) {
// if gqa, we will need to write partial outputs for each head
constexpr uint32_t TILE_WIDTH = 32;
// we are assuming that num_heads_to_write = nh/nkv is a power of 2, so that we don't write
// partials across phases
uint32_t num_heads_to_write = num_q_heads / num_kv_heads; // each head is one row in a tile
uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE; // size of 16 elements (in a row)
constexpr uint32_t num_heads_to_write = num_q_heads / num_kv_heads; // each head is one row in a tile

if (!is_out_sharded) {
for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer);
uint32_t l1_read_addr = get_read_ptr(cb_out) + tile * tile_bytes;

// write partial output for each head
for (uint32_t head = 0; head < num_heads_to_write; ++head) {
uint32_t starting_row = cur_head * num_heads_to_write + head;
uint32_t in_tile_offset_by_starting_head =
starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES
: (starting_row - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE;
uint64_t out_writer_noc_addr_head = out_writer_noc_addr + in_tile_offset_by_starting_head;
uint32_t l1_read_addr_head = l1_read_addr + in_tile_offset_by_starting_head;

// Write first phase
noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, SUBTILE_LINE_BYTES);

// Write second phase
noc_async_write(
l1_read_addr_head + 256 * ELEMENT_SIZE,
out_writer_noc_addr_head + 256 * ELEMENT_SIZE,
SUBTILE_LINE_BYTES);

if (++barrier_count == barrier_threshold) {
noc_async_writes_flushed();
barrier_count = 0;
}
}

++out_tile_id;
}
barrier_count = write_partial_tiles_to_memory<cb_out, ELEMENT_SIZE, barrier_threshold>(
out_tile_id, out_writer, barrier_count, cur_head, num_heads_to_write, out_chunk_tiles);
}
// sharded out case
else if (do_output) {
constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;
// read from reducer cores
constexpr uint32_t num_reducers_per_output = num_reducer_cores / num_output_cores;
constexpr uint32_t num_reducers_to_wait = num_reducers_per_output - 1;
@@ -295,16 +267,8 @@
} else {
// if mqa, we don't need to gather outputs for other heads so we can just write entire tiles to memory
if (!is_out_sharded) {
uint32_t l1_read_addr = get_read_ptr(cb_out);
for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
++out_tile_id;
l1_read_addr += tile_bytes;
if (++barrier_count == barrier_threshold) {
noc_async_writes_flushed();
barrier_count = 0;
}
}
barrier_count = write_tiles_to_memory<cb_out, out_chunk_tiles, barrier_threshold>(
out_tile_id, out_writer, barrier_count);
}
}
noc_async_write_barrier();
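The hang fix in this file is the loop bound change from num_cores_to_wait + 1 to num_cores_to_wait: the reducer's writer must wait for exactly as many intermediate blocks as worker cores actually send, otherwise the final wait never completes. A small standalone sketch of that rule follows, using std::counting_semaphore as a stand-in for the circular-buffer handshake; all names here are hypothetical, not the kernel API.

#include <cstdint>
#include <cstdio>
#include <semaphore>
#include <thread>
#include <vector>

int main() {
    constexpr uint32_t num_cores_to_wait = 3;     // workers feeding the reducer
    std::counting_semaphore<16> blocks_ready{0};  // stands in for the circular buffer

    // Each worker core pushes exactly one intermediate block.
    std::vector<std::thread> workers;
    for (uint32_t i = 0; i < num_cores_to_wait; ++i) {
        workers.emplace_back([&] { blocks_ready.release(); });
    }

    // Correct: wait once per producing worker. With the old bound of
    // num_cores_to_wait + 1, this loop would block forever on the last acquire.
    for (uint32_t i = 0; i < num_cores_to_wait; ++i) {
        blocks_ready.acquire();
        std::printf("consumed block %u\n", i);
    }

    for (auto& w : workers) w.join();
    return 0;
}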
@@ -598,6 +598,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
PNHt,
St,
DHt,
Sk_chunk_t,
packed_identity_scalar,
scale_union.u,
num_cores_per_batch,
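This one-line host-side change is what shifts every compile-time argument index in the writer kernel above by one: Sk_chunk_t is inserted at position 4, so the host argument list and the kernel-side indices must move together. Below is a self-contained analogue of keeping the two in sync; the array, its values, and the get_compile_time_arg helper are illustrative, not the tt-metal API.

#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative compile-time argument list: inserting Sk_chunk_t at index 4
// shifts every later argument up by one.
constexpr std::array<uint32_t, 6> kCompileTimeArgs = {
    1, 1, 32, 128,  // B, PNHt, St, DHt (made-up values)
    4,              // Sk_chunk_t, newly inserted at index 4
    0x3f803f80};    // identity_scalar_packed, was index 4, now index 5

constexpr uint32_t get_compile_time_arg(uint32_t i) { return kCompileTimeArgs[i]; }

int main() {
    // The kernel side must read the same indices the host wrote.
    constexpr uint32_t Sk_chunk_t = get_compile_time_arg(4);
    constexpr uint32_t identity_scalar_packed = get_compile_time_arg(5);
    std::printf("Sk_chunk_t=%u identity_scalar_packed=0x%x\n", Sk_chunk_t, identity_scalar_packed);
    return 0;
}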
