From 127be69264932cdc4db4aa4c3f902263c91f6d26 Mon Sep 17 00:00:00 2001 From: thanhnguyen-moreh Date: Tue, 22 Oct 2024 17:58:23 +0000 Subject: [PATCH] #13931: Move index_size to compile time --- .../device/index_fill_multi_core_factory.cpp | 15 ++++++++------- .../device/kernels/reader_index_fill.cpp | 18 ++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_multi_core_factory.cpp b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_multi_core_factory.cpp index 1ed2f3433ec..7327d13178f 100644 --- a/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_multi_core_factory.cpp +++ b/ttnn/cpp/ttnn/operations/index_fill/device/index_fill_multi_core_factory.cpp @@ -50,6 +50,7 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c Device* device = input.device(); auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; uint32_t num_cores_y = compute_with_storage_grid_size.y; auto [num_cores, all_cores, core_group_1, core_group_2, num_rows_per_core_group_1, num_rows_per_core_group_2] = @@ -107,7 +108,8 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c (std::uint32_t)index_is_dram, (std::uint32_t)src_cb_index, (std::uint32_t)index_cb_index, - (std::uint32_t)(dim == n - 1)}; + (std::uint32_t)(dim == n - 1), + (std::uint32_t)index.volume()}; auto reader_kernel_id = CreateKernel( program, @@ -124,10 +126,11 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c WriterDataMovementConfig(writer_compile_time_args)); uint32_t unit_offset = 0; - for (uint32_t i = 0; i < num_cores; i++) { - const CoreCoord core(i / num_cores_y, i % num_cores_y); - - uint32_t num_rows_per_core = 0; + uint32_t num_cores_group_1 = core_group_1.num_cores(); + auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y); + for (uint32_t i = 0; i < cores.size(); i++) { + const auto& core = cores[i]; + uint32_t num_rows_per_core = i < num_cores_group_1 ? num_rows_per_core_group_1 : num_rows_per_core_group_2; if (core_group_1.core_coord_in_core_ranges(core)) { num_rows_per_core = num_rows_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { @@ -146,8 +149,6 @@ IndexFillOperation::MultiCore::cached_program_t IndexFillOperation::MultiCore::c index_unit_size, unit_offset, num_rows_per_core, - input_shape[-1], - index.volume(), num_rows_to_fill_per_index, input_shape[dim]}); SetRuntimeArgs( diff --git a/ttnn/cpp/ttnn/operations/index_fill/device/kernels/reader_index_fill.cpp b/ttnn/cpp/ttnn/operations/index_fill/device/kernels/reader_index_fill.cpp index aeca15b1948..1da1cba100a 100644 --- a/ttnn/cpp/ttnn/operations/index_fill/device/kernels/reader_index_fill.cpp +++ b/ttnn/cpp/ttnn/operations/index_fill/device/kernels/reader_index_fill.cpp @@ -28,16 +28,15 @@ void kernel_main() { uint32_t index_page_size = get_arg_val(4); uint32_t start_row_id = get_arg_val(5); uint32_t num_rows_per_core = get_arg_val(6); - uint32_t row_size = get_arg_val(7); - uint32_t index_size = get_arg_val(8); - uint32_t num_rows_to_fill_per_index = get_arg_val(9); - uint32_t dim = get_arg_val(10); + uint32_t num_rows_to_fill_per_index = get_arg_val(7); + uint32_t dim = get_arg_val(8); constexpr bool input_is_dram = get_compile_time_arg_val(0) == 1; constexpr bool index_is_dram = get_compile_time_arg_val(1) == 1; constexpr uint32_t src_cb_id = get_compile_time_arg_val(2); constexpr uint32_t index_cb_id = get_compile_time_arg_val(3); constexpr bool is_last_dim = get_compile_time_arg_val(4) == 1; + constexpr uint32_t index_size = get_compile_time_arg_val(5); constexpr uint32_t onetile = 1; @@ -55,7 +54,6 @@ void kernel_main() { noc_async_read(index_noc_addr, index_cb_reader, index_page_size); noc_async_read_barrier(); uint32_t *index_ptr = reinterpret_cast(index_cb_reader); - if (is_last_dim) { for (uint32_t row_id = start_row_id; row_id < start_row_id + num_rows_per_core; row_id++) { cb_reserve_back(src_cb_id, onetile); @@ -66,7 +64,7 @@ void kernel_main() { uint32_t *input_ptr = reinterpret_cast(src_cb_reader); - for (uint32_t i = 0; i < index_page_size / 4; i++) { + for (uint32_t i = 0; i < index_size; i++) { uint32_t current_index = index_ptr[i]; input_ptr[current_index] = fill_value; } @@ -81,22 +79,22 @@ void kernel_main() { noc_async_read(input_noc_addr, src_cb_reader, input_page_size); noc_async_read_barrier(); - if (is_in_indices(index_ptr, index_page_size / 4, row_id / num_rows_to_fill_per_index % dim)) { + if (is_in_indices(index_ptr, index_size, row_id / num_rows_to_fill_per_index % dim)) { #ifdef OUTPUT_DTYPE_BFLOAT16 auto ptr = reinterpret_cast(write_addr); - for (uint32_t i = 0; i < input_page_size / 4; ++i) { + for (uint32_t i = 0; i < index_size; ++i) { ptr[i] = val.u >> 16; } #endif #ifdef OUTPUT_DTYPE_INT32 auto ptr = reinterpret_cast(write_addr); - for (uint32_t i = 0; i < input_page_size / 4; ++i) { + for (uint32_t i = 0; i < index_size; ++i) { ptr[i] = fill_value; } #endif #ifdef OUTPUT_DTYPE_FLOAT32 auto ptr = reinterpret_cast(write_addr); - for (uint32_t i = 0; i < input_page_size / 4; ++i) { + for (uint32_t i = 0; i < index_size; ++i) { ptr[i] = val.f; } #endif