From 075f2241439b496511667b95d9b8df4b8116a22d Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Mon, 10 Feb 2025 10:47:17 +0000 Subject: [PATCH] #17758: Update Running stats Writer kernel --- .../kernels/dataflow/writer_running_statistics.cpp | 12 ++++++------ .../device/running_statistics_program_factory.cpp | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp index 6924193e6f67..03b2b474b364 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp @@ -22,7 +22,7 @@ void kernel_main() { constexpr uint32_t onetile = 1; - constexpr auto cb_id_src = tt::CBIndex::c_1; + constexpr auto cb_id_src = get_compile_time_arg_val(6); constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; const uint32_t src_tile_bytes = get_tile_size(cb_id_src); const DataFormat src_data_format = get_dataformat(cb_id_src); @@ -30,7 +30,7 @@ void kernel_main() { const InterleavedAddrGenFast src = { .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format}; - constexpr auto cb_id_dst = tt::CBIndex::c_2; + constexpr auto cb_id_dst = get_compile_time_arg_val(7); constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst); const DataFormat dst_data_format = get_dataformat(cb_id_dst); @@ -39,7 +39,7 @@ void kernel_main() { .bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format}; // old running mean - constexpr auto cb_id_old_running_mean = tt::CBIndex::c_3; + constexpr auto cb_id_old_running_mean = get_compile_time_arg_val(8); constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(2) == 1; const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean); const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean); @@ -50,7 +50,7 @@ void kernel_main() { .data_format = old_running_mean_data_format}; // old running var - constexpr auto cb_id_old_running_var = tt::CBIndex::c_4; + constexpr auto cb_id_old_running_var = get_compile_time_arg_val(9); constexpr bool old_running_var_is_dram = get_compile_time_arg_val(3) == 1; const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var); const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var); @@ -62,8 +62,8 @@ void kernel_main() { constexpr bool old_running_mean_has_value = get_compile_time_arg_val(4) == 1; constexpr bool old_running_var_has_value = get_compile_time_arg_val(5) == 1; - constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27; - constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28; + constexpr auto cb_id_updated_running_mean = get_compile_time_arg_val(10); + constexpr auto cb_id_updated_running_var = get_compile_time_arg_val(11); uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp index a4d6ee3f27c2..0dfa6b218b03 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp @@ -278,6 +278,12 @@ RunningStatistics::RunningStatisticsProgramFactory::create( e_is_dram, static_cast(running_mean_has_value), static_cast(running_var_has_value), + batch_var_tensor_cb, + output_tensor_cb, + old_running_mean_tensor_cb, + old_running_var_tensor_cb, + updated_m_cb, + updated_v_cb, }, std::move(writer_defines)));