From cebf525e9ba0b242a396a1ae6cce8de1a557731c Mon Sep 17 00:00:00 2001 From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:10:51 +0530 Subject: [PATCH] #18332: Move input stats to reader file (#18335) Continuation of another PR. Will be merged once CI passes Used for testing --- .../device/batch_norm_program_factory.cpp | 50 ++++++---- .../kernels/dataflow/reader_batch_norm.cpp | 88 +++++++++++++++++ .../kernels/dataflow/writer_batch_norm.cpp | 97 ++----------------- 3 files changed, 127 insertions(+), 108 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index 1b6f8984a63..2b876923619 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -54,8 +54,8 @@ void set_or_update_runtime_arguments( tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major); auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); - constexpr size_t num_reader_args = 9; - constexpr size_t num_writer_args = 12; + constexpr size_t num_reader_args = 15; + constexpr size_t num_writer_args = 8; constexpr size_t num_kernel_args = 3; for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) { const auto& core = cores[i]; @@ -78,6 +78,9 @@ void set_or_update_runtime_arguments( ? std::bit_cast(scalar) : pack_two_bfloat16_into_uint32({scalar, scalar}); + const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0; + const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0; + std::array reader_runtime_args = { packed_scalar_eps, input_tensor.buffer()->address(), @@ -87,17 +90,18 @@ void set_or_update_runtime_arguments( aHt * aWt * aC * (aN > 1), aHt * aWt * (aC > 1), cN, - cC}; + cC, + bHt * bWt * bC * (bN > 1), + bHt * bWt * (bC > 1), + batch_var_tensor.buffer()->address(), // batch var + weight_addr, // weight + bias_addr, // bias + batch_mean_tensor.buffer()->address() // batch mean + }; handle_args(program, reader_kernel_id, core, reader_runtime_args); - const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0; - const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0; std::array writer_runtime_args = { - batch_mean_tensor.buffer()->address(), // batch mean - batch_var_tensor.buffer()->address(), // batch var - weight_addr, // weight - bias_addr, // bias - c.buffer()->address(), // output + c.buffer()->address(), // output start_tile_id, num_tiles_per_core, cHtWt, @@ -232,7 +236,21 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch program, "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp", all_device_cores, - tt_metal::ReaderDataMovementConfig({a_is_dram, input_tensor_cb, eps_cb}, std::move(reader_defines))); + tt_metal::ReaderDataMovementConfig( + {a_is_dram, + input_tensor_cb, + eps_cb, + d_is_dram, + batch_var_tensor_cb, + e_is_dram, + weight_tensor_cb, + static_cast(weight_has_value), + bias_tensor_cb, + f_is_dram, + static_cast(bias_has_value), + b_is_dram, + batch_mean_tensor_cb}, + std::move(reader_defines))); // WRITER KERNEL auto writer_defines = dataflow_defines; @@ -242,18 +260,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch all_device_cores, tt_metal::WriterDataMovementConfig( { - b_is_dram, c_is_dram, - d_is_dram, - e_is_dram, - f_is_dram, - static_cast(weight_has_value), - static_cast(bias_has_value), - batch_mean_tensor_cb, output_tensor_cb, - batch_var_tensor_cb, - weight_tensor_cb, - bias_tensor_cb, }, std::move(writer_defines))); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp index e0c453eb786..e8049f2702c 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp @@ -18,6 +18,12 @@ void kernel_main() { uint32_t c_stride = get_arg_val(6); uint32_t N = get_arg_val(7); uint32_t C = get_arg_val(8); + uint32_t n_stride_stat = get_arg_val(9); + uint32_t c_stride_stat = get_arg_val(10); + uint32_t batch_var_addr = get_arg_val(11); // batch_var + uint32_t weight_addr = get_arg_val(12); // weight + uint32_t bias_addr = get_arg_val(13); // bias + uint32_t batch_mean_addr = get_arg_val(14); // batch_mean constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; @@ -54,12 +60,92 @@ void kernel_main() { // Input tile offset uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t; + // Inputs stats offset + uint32_t tile_offset_stat = start_n * n_stride_stat + start_c * c_stride_stat; + uint32_t next_batch_shift_stat = n_stride_stat - c_stride_stat * C; + uint32_t next_channel_shift = c_stride - HtWt; uint32_t next_batch_shift = n_stride - c_stride * C; + // batch_mean + constexpr auto cb_id_batch_mean = get_compile_time_arg_val(12); + constexpr bool batch_mean_is_dram = get_compile_time_arg_val(11) == 1; + const uint32_t batch_mean_tile_bytes = get_tile_size(cb_id_batch_mean); + const DataFormat batch_mean_data_format = get_dataformat(cb_id_batch_mean); + + const InterleavedAddrGenFast batch_mean = { + .bank_base_address = batch_mean_addr, + .page_size = batch_mean_tile_bytes, + .data_format = batch_mean_data_format}; + + // batch_var + constexpr auto cb_id_batch_var = get_compile_time_arg_val(4); + constexpr bool batch_var_is_dram = get_compile_time_arg_val(3) == 1; + const uint32_t batch_var_tile_bytes = get_tile_size(cb_id_batch_var); + const DataFormat batch_var_data_format = get_dataformat(cb_id_batch_var); + + const InterleavedAddrGenFast batch_var = { + .bank_base_address = batch_var_addr, .page_size = batch_var_tile_bytes, .data_format = batch_var_data_format}; + + // weight + constexpr auto cb_id_weight = get_compile_time_arg_val(6); + constexpr bool weight_is_dram = get_compile_time_arg_val(5) == 1; + const uint32_t weight_tile_bytes = get_tile_size(cb_id_weight); + const DataFormat weight_data_format = get_dataformat(cb_id_weight); + + const InterleavedAddrGenFast weight = { + .bank_base_address = weight_addr, .page_size = weight_tile_bytes, .data_format = weight_data_format}; + + constexpr bool weight_has_value = get_compile_time_arg_val(7) == 1; + + // bias + constexpr auto cb_id_bias = get_compile_time_arg_val(8); + constexpr bool bias_is_dram = get_compile_time_arg_val(9) == 1; + const uint32_t bias_tile_bytes = get_tile_size(cb_id_bias); + const DataFormat bias_data_format = get_dataformat(cb_id_bias); + + const InterleavedAddrGenFast bias = { + .bank_base_address = bias_addr, .page_size = bias_tile_bytes, .data_format = bias_data_format}; + + constexpr bool bias_has_value = get_compile_time_arg_val(10) == 1; + uint32_t num_tiles_read = 0; for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) { for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) { + // read a tile from batch_mean + cb_reserve_back(cb_id_batch_mean, onetile); + uint32_t l1_write_addr = get_write_ptr(cb_id_batch_mean); + noc_async_read_tile(tile_offset_stat, batch_mean, l1_write_addr); + noc_async_read_barrier(); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_mean); + cb_push_back(cb_id_batch_mean, onetile); + + // read a tile from batch variance + cb_reserve_back(cb_id_batch_var, onetile); + uint32_t l1_batch_var_write_addr = get_write_ptr(cb_id_batch_var); + noc_async_read_tile(tile_offset_stat, batch_var, l1_batch_var_write_addr); + noc_async_read_barrier(); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_var); + cb_push_back(cb_id_batch_var, onetile); + + if constexpr (weight_has_value) { // read a tile from weight tensor + cb_reserve_back(cb_id_weight, onetile); + uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight); + noc_async_read_tile(tile_offset_stat, weight, l1_weight_write_addr); + noc_async_read_barrier(); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_weight); + cb_push_back(cb_id_weight, onetile); + } + + if constexpr (bias_has_value) { // read a tile from bias tensor + cb_reserve_back(cb_id_bias, onetile); + uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias); + noc_async_read_tile(tile_offset_stat, bias, l1_bias_write_addr); + noc_async_read_barrier(); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_bias); + cb_push_back(cb_id_bias, onetile); + } + for (uint32_t t = start_t; t < HtWt && num_tiles_read < num_tiles; ++t, ++num_tiles_read, ++tile_offset) { cb_reserve_back(cb_id_src, onetile); uint32_t l1_write_addr_src = get_write_ptr(cb_id_src); @@ -68,7 +154,9 @@ void kernel_main() { cb_push_back(cb_id_src, onetile); } tile_offset += next_channel_shift; + tile_offset_stat += c_stride_stat; } tile_offset += next_batch_shift; + tile_offset_stat += next_batch_shift_stat; } } diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp index f95965ca242..56a493cc624 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp @@ -8,69 +8,26 @@ #include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp" void kernel_main() { - uint32_t src_addr = get_arg_val(0); // batch_mean - uint32_t batch_var_addr = get_arg_val(1); // batch_var - uint32_t weight_addr = get_arg_val(2); // weight - uint32_t bias_addr = get_arg_val(3); // bias - uint32_t dst_addr = get_arg_val(4); // output - uint32_t start_tile_id = get_arg_val(5); - uint32_t num_tiles = get_arg_val(6); - uint32_t HtWt = get_arg_val(7); - uint32_t n_stride = get_arg_val(8); - uint32_t c_stride = get_arg_val(9); - uint32_t N = get_arg_val(10); - uint32_t C = get_arg_val(11); + uint32_t dst_addr = get_arg_val(0); // output + uint32_t start_tile_id = get_arg_val(1); + uint32_t num_tiles = get_arg_val(2); + uint32_t HtWt = get_arg_val(3); + uint32_t n_stride = get_arg_val(4); + uint32_t c_stride = get_arg_val(5); + uint32_t N = get_arg_val(6); + uint32_t C = get_arg_val(7); constexpr uint32_t onetile = 1; - // batch_mean - constexpr auto cb_id_src = get_compile_time_arg_val(7); - constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; - const uint32_t src_tile_bytes = get_tile_size(cb_id_src); - const DataFormat src_data_format = get_dataformat(cb_id_src); - - const InterleavedAddrGenFast src = { - .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format}; - // output - constexpr auto cb_id_dst = get_compile_time_arg_val(8); - constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; + constexpr auto cb_id_dst = get_compile_time_arg_val(1); + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst); const DataFormat dst_data_format = get_dataformat(cb_id_dst); const InterleavedAddrGenFast dst = { .bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format}; - // batch_var - constexpr auto cb_id_batch_var = get_compile_time_arg_val(9); - constexpr bool batch_var_is_dram = get_compile_time_arg_val(2) == 1; - const uint32_t batch_var_tile_bytes = get_tile_size(cb_id_batch_var); - const DataFormat batch_var_data_format = get_dataformat(cb_id_batch_var); - - const InterleavedAddrGenFast batch_var = { - .bank_base_address = batch_var_addr, .page_size = batch_var_tile_bytes, .data_format = batch_var_data_format}; - - // weight - constexpr auto cb_id_weight = get_compile_time_arg_val(10); - constexpr bool weight_is_dram = get_compile_time_arg_val(3) == 1; - const uint32_t weight_tile_bytes = get_tile_size(cb_id_weight); - const DataFormat weight_data_format = get_dataformat(cb_id_weight); - - const InterleavedAddrGenFast weight = { - .bank_base_address = weight_addr, .page_size = weight_tile_bytes, .data_format = weight_data_format}; - - // bias - constexpr auto cb_id_bias = get_compile_time_arg_val(11); - constexpr bool bias_is_dram = get_compile_time_arg_val(4) == 1; - const uint32_t bias_tile_bytes = get_tile_size(cb_id_bias); - const DataFormat bias_data_format = get_dataformat(cb_id_bias); - - const InterleavedAddrGenFast bias = { - .bank_base_address = bias_addr, .page_size = bias_tile_bytes, .data_format = bias_data_format}; - - constexpr bool weight_has_value = get_compile_time_arg_val(5) == 1; - constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1; - uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; uint32_t start_remaining = start_tile_id % tiles_per_batch; @@ -84,40 +41,6 @@ void kernel_main() { uint32_t num_tiles_written = 0; for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) { for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) { - // read a tile from src - cb_reserve_back(cb_id_src, onetile); - uint32_t l1_write_addr = get_write_ptr(cb_id_src); - noc_async_read_tile(tile_offset, src, l1_write_addr); - noc_async_read_barrier(); - FILL_TILE_WITH_FIRST_ELEMENT(cb_id_src); - cb_push_back(cb_id_src, onetile); - - // read a tile from batch variance - cb_reserve_back(cb_id_batch_var, onetile); - uint32_t l1_batch_var_write_addr = get_write_ptr(cb_id_batch_var); - noc_async_read_tile(tile_offset, batch_var, l1_batch_var_write_addr); - noc_async_read_barrier(); - FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_var); - cb_push_back(cb_id_batch_var, onetile); - - if constexpr (weight_has_value) { // read a tile from weight tensor - cb_reserve_back(cb_id_weight, onetile); - uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight); - noc_async_read_tile(tile_offset, weight, l1_weight_write_addr); - noc_async_read_barrier(); - FILL_TILE_WITH_FIRST_ELEMENT(cb_id_weight); - cb_push_back(cb_id_weight, onetile); - } - - if constexpr (bias_has_value) { // read a tile from bias tensor - cb_reserve_back(cb_id_bias, onetile); - uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias); - noc_async_read_tile(tile_offset, bias, l1_bias_write_addr); - noc_async_read_barrier(); - FILL_TILE_WITH_FIRST_ELEMENT(cb_id_bias); - cb_push_back(cb_id_bias, onetile); - } - for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) { // write a tile to dst cb_wait_front(cb_id_dst, onetile);