Skip to content

Commit

Permalink
#15602: ttnn-padding padding size enhancement. (#15758)
Browse files Browse the repository at this point in the history
### Ticket
#15602
[Link to Github
Issue](#15602)

### Problem description
When use padding size that's not multiples of 16, the result is
erroneous

### What's changed
Re-design the kernel to allow arbitrary front pad/back pad length
Still need to address the alignment issue in L1 since now the total size
of a stick is no longer 16B aligned.

### Checklist
- [x] Post commit CI passes 
https://github.com/tenstorrent/tt-metal/actions/runs/12281633434
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
llongTT authored Dec 11, 2024
1 parent f37ad77 commit 38fddcd
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include <cstring>
#include "dataflow_api.h"

inline __attribute__((always_inline)) void fill_pad_cb_with_val(
Expand Down Expand Up @@ -37,8 +38,10 @@ void kernel_main() {
constexpr uint32_t stick_size_padded_end = get_compile_time_arg_val(10);
constexpr uint32_t num_zero_pad_sticks_read = get_compile_time_arg_val(11);
constexpr uint32_t last_zero_stick_size = get_compile_time_arg_val(12);
constexpr uint32_t stick_size_padded_aligned = get_compile_time_arg_val(21);

#define not_pad_by_zero get_compile_time_arg_val(13) == 1
#define front_padding get_compile_time_arg_val(9)
#if (not_pad_by_zero)
constexpr uint32_t packed_pad_value = get_compile_time_arg_val(14);
constexpr uint32_t row_major_min_bytes = get_compile_time_arg_val(15);
Expand All @@ -47,8 +50,9 @@ void kernel_main() {
constexpr uint32_t num_sticks_padded_read = get_compile_time_arg_val(18);
#endif

constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_pad = tt::CBIndex::c_1;
constexpr uint32_t cb_in0 = tt::CBIndex::c_0;
constexpr uint32_t cb_pad = tt::CBIndex::c_1;
constexpr uint32_t cb_pad_align = tt::CBIndex::c_2;

#define stick_size_is_pow2 get_compile_time_arg_val(19) == 1
#if (stick_size_is_pow2)
Expand All @@ -68,8 +72,14 @@ void kernel_main() {
uint64_t pad_val_addr = get_read_ptr(cb_pad);
uint64_t pad_val_noc_addr = get_noc_addr(pad_val_addr);

uint64_t pad_align_addr = get_read_ptr(cb_pad_align);
uint64_t pad_align_write_addr = get_write_ptr(cb_pad_align);
uint64_t pad_align_noc_addr = get_noc_addr(pad_align_addr);

#if (not_pad_by_zero)
fill_pad_cb_with_val(cb_pad, row_major_min_bytes, packed_pad_value);
fill_pad_cb_with_val(cb_pad, stick_size_padded, packed_pad_value);
#else
fill_pad_cb_with_val(cb_pad, stick_size_padded, 0);
#endif

uint32_t i_stick = start_id;
Expand All @@ -82,55 +92,23 @@ void kernel_main() {
bool read_stick = (curr_h >= front_pad_h and curr_h < H) and (curr_c >= front_pad_c and curr_c < C) and
(curr_n >= front_pad_n and curr_n < N);
uint64_t read_noc_addr = get_noc_addr(i_stick, s);
noc_async_read(pad_val_noc_addr, l1_write_addr, stick_size_padded);

if (read_stick) {
#if (not_pad_by_zero)
if constexpr (stick_size_padded_front != 0) {
for (uint32_t j = 0; j < num_front_pad_sticks_read; ++j) {
noc_async_read(pad_val_noc_addr, l1_write_addr, row_major_min_bytes);
l1_write_addr += row_major_min_bytes;
}
}
#if (front_padding)
// Read noc into cb_pad_align l1
noc_async_read(read_noc_addr, get_write_ptr(cb_pad_align), stick_size_bytes);
noc_async_read_barrier();
memmove(
(void*)(l1_write_addr + stick_size_padded_front),
(void*)(get_read_ptr(cb_pad_align)),
(size_t)(stick_size_bytes));
#else
if constexpr (stick_size_padded_front != 0) {
noc_async_read(zeros_noc_addr, l1_write_addr, stick_size_padded_front);
l1_write_addr += stick_size_padded_front;
}
#endif

noc_async_read(read_noc_addr, l1_write_addr, stick_size_bytes);
l1_write_addr += stick_size_bytes;
i_stick++;

#if (not_pad_by_zero)
if constexpr (stick_size_padded_end != 0) {
for (uint32_t j = 0; j < num_end_pad_sticks_read; ++j) {
noc_async_read(pad_val_noc_addr, l1_write_addr, row_major_min_bytes);
l1_write_addr += row_major_min_bytes;
}
}
#else
if constexpr (stick_size_padded_end != 0) {
noc_async_read(zeros_noc_addr, l1_write_addr, stick_size_padded_end);
l1_write_addr += stick_size_padded_end;
}
#endif

} else {
#if (not_pad_by_zero)
for (uint32_t j = 0; j < num_sticks_padded_read; ++j) {
noc_async_read(pad_val_noc_addr, l1_write_addr, row_major_min_bytes);
l1_write_addr += row_major_min_bytes;
}
#else
for (uint32_t j = 0; j < num_zero_pad_sticks_read; ++j) {
auto read_bytes = j == num_zero_pad_sticks_read - 1 ? last_zero_stick_size : 512;
noc_async_read(zeros_noc_addr, l1_write_addr, read_bytes);
l1_write_addr += read_bytes;
}
#endif
i_stick++;
}

l1_write_addr += stick_size_padded_aligned;
curr_h++;
if (curr_h == H_padded) {
curr_c++;
Expand All @@ -142,7 +120,6 @@ void kernel_main() {
}
}
noc_async_read_barrier();

cb_push_back(cb_in0, num_read_per_barrier);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ void kernel_main() {
constexpr uint32_t cb_out0 = get_compile_time_arg_val(0);
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
constexpr uint32_t W_size_bytes = get_compile_time_arg_val(2);
constexpr uint32_t stick_size_padded_aligned = get_compile_time_arg_val(5);

const uint32_t stick_size_bytes = W_size_bytes;

Expand All @@ -38,7 +39,7 @@ void kernel_main() {
for (uint32_t i = 0; i < num_read_per_barrier; ++i) {
uint64_t write_noc_addr = get_noc_addr(i_stick, s);
noc_async_write(l1_read_addr, write_noc_addr, stick_size_bytes);
l1_read_addr += stick_size_bytes;
l1_read_addr += stick_size_padded_aligned;
i_stick += 1;
}
noc_async_write_barrier();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(
auto stick_size_padded = W_padded * a.element_size();
auto stick_size_padded_front = front_pad[-1] * a.element_size();
auto stick_size_padded_end = stick_size_padded - stick_size - stick_size_padded_front;
uint32_t stick_size_padded_aligned = align(stick_size_padded, hal.get_alignment(HalMemType::L1));
uint32_t row_major_min_bytes = 16;

tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());
Expand All @@ -1050,24 +1051,31 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(
num_sticks_padded_per_core_group_2] =
tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, NCH_padded);

uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
auto num_sticks = num_sticks_padded_per_core_group_1 > num_sticks_padded_per_core_group_2
? num_sticks_padded_per_core_group_1
: num_sticks_padded_per_core_group_2;

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(num_sticks * stick_size_padded, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, stick_size_padded);
tt::tt_metal::CircularBufferConfig(num_sticks * stick_size_padded_aligned, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, stick_size_padded_aligned);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

// construct const buffer with the pad_value
bool not_pad_by_zero = pad_value != 0;
if (not_pad_by_zero) {
uint32_t src1_cb_index = 1;
tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(row_major_min_bytes, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, row_major_min_bytes);
auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

uint32_t src1_cb_index = tt::CBIndex::c_1;
tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(stick_size_padded_aligned, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, stick_size_padded_aligned);
auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

if (stick_size_padded_front != 0) {
uint32_t src2_cb_index = tt::CBIndex::c_2;
tt::tt_metal::CircularBufferConfig cb_src2_config =
tt::tt_metal::CircularBufferConfig(stick_size_padded_aligned, {{src2_cb_index, cb_data_format}})
.set_page_size(src2_cb_index, stick_size_padded_aligned);
auto cb_src2 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src2_config);
}

Buffer* src0_buffer = a.buffer();
Expand Down Expand Up @@ -1104,13 +1112,15 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2(
(std::uint32_t)(stick_size_padded_end / row_major_min_bytes),
(std::uint32_t)(stick_size_padded / row_major_min_bytes),
(std::uint32_t)src_stick_size_is_power_of_two,
(std::uint32_t)src_stick_size_is_power_of_two ? src_log2_stick_size : stick_size};
(std::uint32_t)src_stick_size_is_power_of_two ? src_log2_stick_size : stick_size,
(std::uint32_t)stick_size_padded_aligned};
std::vector<uint32_t> writer_ct_args = {
(std::uint32_t)src0_cb_index,
(std::uint32_t)dst_is_dram,
(std::uint32_t)stick_size_padded,
(std::uint32_t)dst_stick_size_is_power_of_two,
(std::uint32_t)dst_stick_size_is_power_of_two ? dst_log2_stick_size : stick_size_padded};
(std::uint32_t)dst_stick_size_is_power_of_two ? dst_log2_stick_size : stick_size_padded,
(std::uint32_t)stick_size_padded_aligned};

KernelHandle reader_kernel_id = CreateKernel(
program,
Expand Down

0 comments on commit 38fddcd

Please sign in to comment.