From 924f0171ad986c2abcd284f69ea6c1d100d94ccb Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 7 Jan 2025 21:57:15 +0000 Subject: [PATCH] addressed pr comments and removed hard-coded constants --- .../device/kernels/dataflow/dataflow_common.hpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp index a0ee00c6d64..4cc5ef850af 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/dataflow_common.hpp @@ -323,8 +323,10 @@ uint32_t write_partial_tiles_to_memory( uint32_t cur_head, uint32_t num_heads_to_write, uint32_t out_chunk_tiles) { + constexpr uint32_t FACE_HW = 16; + constexpr uint32_t FACE_ELEMENT_CNT = FACE_HW * FACE_HW; // 256 constexpr uint32_t tile_bytes = get_tile_size(cb_out); - constexpr uint32_t SUBTILE_LINE_BYTES = 16 * ELEMENT_SIZE; + constexpr uint32_t FACE_LINE_BYTES = FACE_HW * ELEMENT_SIZE; for (uint32_t tile = 0; tile < out_chunk_tiles; ++tile) { uint64_t out_writer_noc_addr = get_noc_addr(out_tile_id, out_writer); @@ -334,19 +336,20 @@ uint32_t write_partial_tiles_to_memory( for (uint32_t head = 0; head < num_heads_to_write; ++head) { uint32_t starting_row = cur_head * num_heads_to_write + head; uint32_t in_tile_offset_by_starting_head = - starting_row < 16 ? starting_row * SUBTILE_LINE_BYTES - : (starting_row - 16) * SUBTILE_LINE_BYTES + 512 * ELEMENT_SIZE; + starting_row < FACE_HW + ? starting_row * FACE_LINE_BYTES + : (starting_row + FACE_HW) * FACE_LINE_BYTES; // Skip the second face which has FACE_HW rows uint64_t out_writer_noc_addr_head = out_writer_noc_addr + in_tile_offset_by_starting_head; uint32_t l1_read_addr_head = l1_read_addr + in_tile_offset_by_starting_head; // Write first phase - noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, SUBTILE_LINE_BYTES); + noc_async_write(l1_read_addr_head, out_writer_noc_addr_head, FACE_LINE_BYTES); // Write second phase noc_async_write( - l1_read_addr_head + 256 * ELEMENT_SIZE, - out_writer_noc_addr_head + 256 * ELEMENT_SIZE, - SUBTILE_LINE_BYTES); + l1_read_addr_head + FACE_ELEMENT_CNT * ELEMENT_SIZE, + out_writer_noc_addr_head + FACE_ELEMENT_CNT * ELEMENT_SIZE, + FACE_LINE_BYTES); if (++barrier_count == barrier_threshold) { noc_async_writes_flushed();