diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp
index 54c5496c032..c83b3c20826 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp
@@ -437,6 +437,8 @@ template <
     uint32_t DHt,
     uint32_t Sq_chunk_t,
     uint32_t Sk_chunk_t,
+    uint32_t qk_chunk_tiles,
+    uint32_t out_chunk_tiles,
     // QK matmul block parameters
     uint32_t qk_in0_block_w,
     uint32_t qk_subblock_w,
@@ -477,9 +479,8 @@ void flash_attention_loop(
     uint32_t k_chunk_start,
     uint32_t k_chunk_end,
     bool do_reduce,
-    bool apply_mask_at_last_chunk,  // for causal mode, optionally apply mask at the last chunk
-    uint32_t qk_chunk_tiles,
-    uint32_t out_chunk_tiles) {
+    bool apply_mask_at_last_chunk  // for causal mode, optionally apply mask at the last chunk
+) {
     for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) {
         /* QK = Q_CHUNK @ K_CHUNK */
         reconfig_data_format(cb_q_in, cb_k_in);  // DEBUG
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
index e9a6489faaf..709ed36da7f 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp
@@ -137,6 +137,8 @@ void MAIN {
         DHt,
         Sq_chunk_t,
         Sk_chunk_t,
+        qk_chunk_tiles,
+        out_chunk_tiles,
         // QK matmul block parameters
         qk_in0_block_w,
         qk_subblock_w,
@@ -171,7 +173,7 @@ void MAIN {
         cb_exp_max_diff,
         cb_out_o,
         cb_out_m,
-        cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, apply_mask_at_last_chunk, qk_chunk_tiles, out_chunk_tiles);
+        cb_out_l>(k_chunk_start, k_chunk_end, do_reduce, apply_mask_at_last_chunk);

     // do reduction across intermediates from other cores if this is the reduction core
     if (do_reduce) {
@@ -182,10 +184,6 @@ void MAIN {
             // This indicates that there are computes done by other workers. Needs to wait for them and send to
             // reducer's compute
             for (uint32_t i = 0; i < num_cores_to_wait; i++) {
-                cb_wait_front(cb_out_o, q_chunk_tiles);  // o_2
-                cb_wait_front(cb_m_in, Sq_chunk_t);      // m_2
-                cb_wait_front(cb_l_in, Sq_chunk_t);      // l_2
-
                 // reconfig_data_format(cb_q_in, cb_q_in);  // DEBUG
                 // pack_reconfig_data_format(cb_out_accumulate_im_2);
                 copy_block(cb_out_o, cb_out_accumulate_im_2, q_chunk_tiles);
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp
index 282c0f02c1a..a0ee00c6d64 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp
@@ -161,11 +161,10 @@ uint32_t read_mask_chunk(uint32_t PSt, uint32_t mask_start_tile_id, const Interl
     return mask_start_tile_id;
 }

-template <uint32_t cb_mask_in, uint32_t PNHt>
-void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
+template <uint32_t cb_mask_in, uint32_t PNHt, uint32_t Sk_chunk_t>
+void generate_mask(uint32_t k_num_chunks, uint32_t cur_pos) {
     /*
     example 1: 64 seqlen at cur_pos 40, 2 cores, 32 chunk size
-    PSt = 2
     k_num_chunks = 2
     Sk_chunk_t = 1
     cur_pos = 40
@@ -174,7 +173,6 @@ void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
     cur_pos_in_tile = 8

     example 2: 1024 seqlen at cur_pos 990, 2 cores, 128 chunk size
-    PSt = 32
     k_num_chunks = 8
     Sk_chunk_t = 4
     cur_pos = 990
@@ -183,7 +181,6 @@ void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
     cur_pos_in_tile = 30

     example 3: 64 seqlen at cur_pos 63, 2 cores, 32 chunk size
-    PSt = 2
     k_num_chunks = 2
     Sk_chunk_t = 1
     cur_pos = 63
@@ -192,7 +189,6 @@ void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
     cur_pos_in_tile = 31

     example 3: 64 seqlen at cur_pos 0, 2 cores, 32 chunk size
-    PSt = 2
     k_num_chunks = 2
     Sk_chunk_t = 1
     cur_pos = 0
@@ -201,7 +197,6 @@ void generate_mask(uint32_t k_num_chunks, uint32_t PSt, uint32_t cur_pos) {
     cur_pos_in_tile = 0
     */

-    uint32_t Sk_chunk_t = PSt / k_num_chunks;
     // the cb_mask in is of size PNHt * Sk_chunk_t
     uint32_t total_read_tiles = PNHt * Sk_chunk_t;
     uint32_t cur_pos_in_chunk = cur_pos % (Sk_chunk_t * 32);
@@ -303,6 +298,67 @@ void worker_compute(
     cb_pop_front(cb_out_l, PNHt);
 }

+template <uint32_t cb_out, uint32_t out_chunk_tiles, uint32_t barrier_threshold, bool is_dram>
+uint32_t write_tiles_to_memory(
+    uint32_t& out_tile_id, const InterleavedAddrGenFast<is_dram>& out_writer, uint32_t& barrier_count) {
+    constexpr uint32_t tile_bytes = get_tile_size(cb_out);
+    uint32_t l1_read_addr = get_read_ptr(cb_out);
+    for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
+        noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
+        ++out_tile_id;
+        l1_read_addr += tile_bytes;
+        if (++barrier_count == barrier_threshold) {
+            noc_async_writes_flushed();
+            barrier_count = 0;
+        }
+    }
+    return barrier_count;
+}
+
+template <uint32_t cb_out, uint32_t ELEMENT_SIZE, uint32_t barrier_threshold, bool is_dram>
+uint32_t write_partial_tiles_to_memory(
+    uint32_t& out_tile_id,
+    const InterleavedAddrGenFast<is_dram>& out_writer,
+    uint32_t& barrier_count,
+    uint32_t cur_head,
+    uint32_t num_heads_to_write,
+    uint32_t out_chunk_tiles) {
+    constexpr uint32_t tile_bytes = get_tile_size(cb_out);
+    constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;
+
+    for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
+        uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer);
+        uint32_t l1_read_addr = get_read_ptr(cb_out) + tile * tile_bytes;
+
+        // write partial output for each head
+        for (uint32_t head = 0; head < num_heads_to_write; ++head) {
+            uint32_t starting_row = cur_head * num_heads_to_write + head;
+            uint32_t in_tile_offset_by_starting_head =
+                starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES
+                                  : (starting_row - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE;
+            uint64_t out_writer_noc_addr_head = out_writer_noc_addr + in_tile_offset_by_starting_head;
+            uint32_t l1_read_addr_head = l1_read_addr + in_tile_offset_by_starting_head;
+
+            // Write first phase
+            noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, SUBTILE_LINE_BYTES);
+
+            // Write second phase
+            noc_async_write(
+                l1_read_addr_head + 256 * ELEMENT_SIZE,
+                out_writer_noc_addr_head + 256 * ELEMENT_SIZE,
+                SUBTILE_LINE_BYTES);
+
+            if (++barrier_count == barrier_threshold) {
+                noc_async_writes_flushed();
+                barrier_count = 0;
+            }
+        }
+
+        ++out_tile_id;
+    }
+    return barrier_count;
+}
+
 /******************************************************************************
 *                   Reader Kernel Specific Functions                          *
 ******************************************************************************/
@@ -325,7 +381,6 @@ void read_kv_mask_chunks(
     uint32_t k_start_tile_id,
     uint32_t v_start_tile_id,
     uint32_t mask_start_tile_id,
-    uint32_t valid_seq_len_tiles,
     const InterleavedAddrGenFast<is_dram>& k_reader,
     const InterleavedAddrGenFast<is_dram>& v_reader,
     const InterleavedAddrGenFast<is_dram>& mask_reader,
@@ -341,12 +396,10 @@ void read_kv_mask_chunks(
         for (uint32_t col = 0; col < DHt; ++col) {
             uint32_t k_tile_id = k_start_tile_id + col;
             for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
-                if (row <= valid_seq_len_tiles) {
-                    noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
-                    if (++barrier_count == barrier_threshold) {
-                        noc_async_read_barrier();
-                        barrier_count = 0;
-                    }
+                noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
+                if (++barrier_count == barrier_threshold) {
+                    noc_async_read_barrier();
+                    barrier_count = 0;
                 }
                 k_tile_id += DHt;
                 k_write_ptr += k_tile_bytes;
@@ -369,12 +422,10 @@ void read_kv_mask_chunks(
         uint32_t v_tile_id = v_start_tile_id;
         for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
             for (uint32_t col = 0; col < DHt; ++col) {
-                if (row <= valid_seq_len_tiles) {
-                    noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
-                    if (++barrier_count == barrier_threshold) {
-                        noc_async_read_barrier();
-                        barrier_count = 0;
-                    }
+                noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
+                if (++barrier_count == barrier_threshold) {
+                    noc_async_read_barrier();
+                    barrier_count = 0;
                 }
                 v_tile_id++;
                 v_write_ptr += v_tile_bytes;
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
index aefd75cbc46..dc6480bdd52 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp
@@ -278,7 +278,6 @@ void kernel_main() {
             k_start_tile_id,
             v_start_tile_id,
             mask_start_tile_id,
-            PSt,
             k_reader,
             v_reader,
             mask_reader,
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
index 7188ba77d7f..59581a13ee1 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp
@@ -15,22 +15,23 @@ void kernel_main() {
     constexpr uint32_t PNHt = get_compile_time_arg_val(1);  // padded number of heads in tiles
     constexpr uint32_t St = get_compile_time_arg_val(2);    // full sequence length of kv cache in tiles
     constexpr uint32_t DHt = get_compile_time_arg_val(3);   // head dim
-    constexpr uint32_t identity_scalar_packed = get_compile_time_arg_val(4);
-    constexpr uint32_t scale_val = get_compile_time_arg_val(5);
-    constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(6);  // num cores per batch
-    constexpr uint32_t num_cores = get_compile_time_arg_val(7);  // num running cores in total
-    uint32_t reducer_semaphore_addr = get_semaphore(get_compile_time_arg_val(8));  // semaphore for reducer
-    uint32_t output_semaphore_addr = get_semaphore(get_compile_time_arg_val(9));  // semaphore for sender
-    constexpr bool is_out_sharded = get_compile_time_arg_val(10);
-    constexpr uint32_t k_chunk_size = get_compile_time_arg_val(11);
-    constexpr uint32_t num_q_heads = get_compile_time_arg_val(12);
-    constexpr uint32_t num_kv_heads = get_compile_time_arg_val(13);
-    constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14);
-    constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15);
-    constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(16);
-    constexpr uint32_t num_output_cores = get_compile_time_arg_val(17);
-    constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(18);
-    constexpr bool is_causal = get_compile_time_arg_val(19) == 1;
+    constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(4);  // number of tiles in seqlen of a k/v/mask chunk
+    constexpr uint32_t identity_scalar_packed = get_compile_time_arg_val(5);
+    constexpr uint32_t scale_val = get_compile_time_arg_val(6);
+    constexpr uint32_t num_cores_per_batch = get_compile_time_arg_val(7);  // num cores per batch
+    constexpr uint32_t num_cores = get_compile_time_arg_val(8);  // num running cores in total
+    uint32_t reducer_semaphore_addr = get_semaphore(get_compile_time_arg_val(9));  // semaphore for reducer
+    uint32_t output_semaphore_addr = get_semaphore(get_compile_time_arg_val(10));  // semaphore for sender
+    constexpr bool is_out_sharded = get_compile_time_arg_val(11);
+    constexpr uint32_t k_chunk_size = get_compile_time_arg_val(12);
+    constexpr uint32_t num_q_heads = get_compile_time_arg_val(13);
+    constexpr uint32_t num_kv_heads = get_compile_time_arg_val(14);
+    constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(15);
+    constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(16);
+    constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(17);
+    constexpr uint32_t num_output_cores = get_compile_time_arg_val(18);
+    constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(19);
+    constexpr bool is_causal = get_compile_time_arg_val(20) == 1;

     uint32_t arg_idx = 0;
     const uint32_t out_addr = get_arg_val<uint32_t>(arg_idx++);
@@ -144,7 +145,7 @@ void kernel_main() {

     // generate and send mask to compute if causal
     if constexpr (is_causal) {
-        generate_mask<cb_mask_in, PNHt>(k_num_chunks, PSt, cur_pos);
+        generate_mask<cb_mask_in, PNHt, Sk_chunk_t>(k_num_chunks, cur_pos);
     }

     for (uint32_t cur_head = cur_head_group * num_heads_per_core;
@@ -163,7 +164,7 @@ void kernel_main() {
             // cb_wait_front(cb_intermed_out, num_tiles_to_wait);
             constexpr uint32_t q_read_size = out_chunk_tiles * tile_bytes_intermed;
             constexpr uint32_t ml_read_size = PNHt * tile_bytes_intermed;
-            for (uint32_t block = 0; block < num_cores_to_wait + 1; ++block) {
+            for (uint32_t block = 0; block < num_cores_to_wait; ++block) {
                 cb_reserve_back(cb_out_o, out_chunk_tiles);
                 cb_reserve_back(cb_m_in, PNHt);
                 cb_reserve_back(cb_l_in, PNHt);
@@ -196,46 +197,17 @@ void kernel_main() {

             if constexpr (num_kv_heads > 1) {
                 // if gqa, we will need to write partial outputs for each head
-                constexpr uint32_t TILE_WIDTH = 32;
                 // we are assuming here that num_heads_to_write = nh/nkv is a power of 2 here, so that we don't write
                 // partial across phase
-                uint32_t num_heads_to_write = num_q_heads / num_kv_heads;  // each head is one row in a tile
-                uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;  // size of 16 elements (in a row)
+                constexpr uint32_t num_heads_to_write = num_q_heads / num_kv_heads;  // each head is one row in a tile

                 if (!is_out_sharded) {
-                    for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
-                        uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer);
-                        uint32_t l1_read_addr = get_read_ptr(cb_out) + tile * tile_bytes;
-
-                        // write partial output for each head
-                        for (uint32_t head = 0; head < num_heads_to_write; ++head) {
-                            uint32_t starting_row = cur_head * num_heads_to_write + head;
-                            uint32_t in_tile_offset_by_starting_head =
-                                starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES
-                                                  : (starting_row - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE;
-                            uint64_t out_writer_noc_addr_head = out_writer_noc_addr + in_tile_offset_by_starting_head;
-                            uint32_t l1_read_addr_head = l1_read_addr + in_tile_offset_by_starting_head;
-
-                            // Write first phase
-                            noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, SUBTILE_LINE_BYTES);
-
-                            // Write second phase
-                            noc_async_write(
-                                l1_read_addr_head + 256 * ELEMENT_SIZE,
-                                out_writer_noc_addr_head + 256 * ELEMENT_SIZE,
-                                SUBTILE_LINE_BYTES);
-
-                            if (++barrier_count == barrier_threshold) {
-                                noc_async_writes_flushed();
-                                barrier_count = 0;
-                            }
-                        }
-
-                        ++out_tile_id;
-                    }
+                    barrier_count = write_partial_tiles_to_memory<cb_out, ELEMENT_SIZE, barrier_threshold>(
+                        out_tile_id, out_writer, barrier_count, cur_head, num_heads_to_write, out_chunk_tiles);
                 }
                 // sharded out case
                 else if (do_output) {
+                    constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;
                     // read from reducer cores
                     constexpr uint32_t num_reducers_per_output = num_reducer_cores / num_output_cores;
                     constexpr uint32_t num_reducers_to_wait = num_reducers_per_output - 1;
@@ -295,16 +267,8 @@ void kernel_main() {
             } else {
                 // if mqa, we don't need to gather outputs for other heads so we can just write entire tiles to memory
                 if (!is_out_sharded) {
-                    uint32_t l1_read_addr = get_read_ptr(cb_out);
-                    for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) {
-                        noc_async_write_tile(out_tile_id, out_writer, l1_read_addr);
-                        ++out_tile_id;
-                        l1_read_addr += tile_bytes;
-                        if (++barrier_count == barrier_threshold) {
-                            noc_async_writes_flushed();
-                            barrier_count = 0;
-                        }
-                    }
+                    barrier_count = write_tiles_to_memory<cb_out, out_chunk_tiles, barrier_threshold>(
+                        out_tile_id, out_writer, barrier_count);
                 }
             }
             noc_async_write_barrier();
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
index e1f02e4c00b..7838a974c47 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
@@ -598,6 +598,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
         PNHt,
         St,
         DHt,
+        Sk_chunk_t,
         packed_identity_scalar,
         scale_union.u,
         num_cores_per_batch,
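
Note: the generate_mask examples in dataflow_common.hpp can be sanity-checked with plain arithmetic now that Sk_chunk_t is a compile-time parameter rather than being derived at runtime from PSt / k_num_chunks. A self-contained sketch using example 2's numbers; the cur_pos_in_chunk_t value is derived here rather than quoted from the patch:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Example 2 from the generate_mask docstring: 1024 seqlen at cur_pos 990,
        // 2 cores, 128-element chunks => Sk_chunk_t = 4 tiles of 32 rows each.
        constexpr uint32_t Sk_chunk_t = 4;
        constexpr uint32_t cur_pos = 990;

        uint32_t cur_pos_in_chunk = cur_pos % (Sk_chunk_t * 32);  // 990 % 128 = 94
        uint32_t cur_pos_in_chunk_t = cur_pos_in_chunk / 32;      // tile index within the chunk: 2
        uint32_t cur_pos_in_tile = cur_pos_in_chunk % 32;         // 30, matching the docstring

        std::printf("%u %u %u\n", cur_pos_in_chunk, cur_pos_in_chunk_t, cur_pos_in_tile);
        return 0;
    }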
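
Note: the in_tile_offset_by_starting_head expression in write_partial_tiles_to_memory assumes the usual 32x32 tile layout of four 16x16 faces (rows 0-15/cols 0-15, rows 0-15/cols 16-31, rows 16-31/cols 0-15, rows 16-31/cols 16-31). The sketch below restates that arithmetic outside the kernel; row_offset_bytes and the constants are illustrative, and ELEMENT_SIZE = 2 assumes bfloat16 output:

    #include <cstdint>

    // Illustrative helper, not part of the kernels: byte offset of the first 16
    // columns of row r within a 32x32 tile stored as four 16x16 faces.
    constexpr uint32_t ELEMENT_SIZE = 2;                        // bytes per datum (assuming bfloat16)
    constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE;  // one 16-element face row

    constexpr uint32_t row_offset_bytes(uint32_t r) {
        return r < 16 ? r * SUBTILE_LINE_BYTES                               // rows 0-15 live in face 0
                      : (r - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE;  // face 2 starts 512 elements in
    }

    // Columns 16-31 of the same row sit exactly one face (256 elements) later,
    // hence the second noc_async_write at +256 * ELEMENT_SIZE on both addresses.
    static_assert(row_offset_bytes(0) == 0);
    static_assert(row_offset_bytes(17) == 512 * ELEMENT_SIZE + SUBTILE_LINE_BYTES);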