From 922ddf2c445fb0f5e3a0bdd67cbec420e236dcc3 Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Wed, 5 Feb 2025 17:05:27 +0000 Subject: [PATCH] #0: Float32 support for Training mode in Batch Norm --- .../unit_tests/operations/test_batch_norm.py | 108 +++++++++ .../compute/running_statistics_kernel.cpp | 4 +- .../running_statistics_sfpu_kernel.cpp | 228 ++++++++++++++++++ .../dataflow/reader_running_statistics.cpp | 15 +- .../dataflow/writer_running_statistics.cpp | 4 +- .../running_statistics_device_operation.cpp | 27 +-- .../running_statistics_program_factory.cpp | 81 +++++-- 7 files changed, 423 insertions(+), 44 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_sfpu_kernel.cpp diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py index 56922409d001..1305fc330053 100644 --- a/tests/ttnn/unit_tests/operations/test_batch_norm.py +++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py @@ -13,6 +13,114 @@ from models.utility_functions import skip_for_grayskull +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "input_shapes", + [ + *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])), + *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])), + *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])), + torch.Size([3, 1, 64, 120]), + torch.Size([3, 2, 64, 120]), + ], +) +@pytest.mark.parametrize( + "check_mean, check_var", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +@pytest.mark.parametrize("weight", [True, False]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) +@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5]) +def test_batch_norm_training_fp32( + input_shapes, check_mean, check_var, weight, bias, eps, device, momentum, training=True, testing_dtype="float32" +): + in_data, input_tensor = data_gen_with_range_batch_norm( + input_shapes, 5, 10, device, is_input=True, testing_dtype=testing_dtype + ) + mean_data, mean_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype) + if (check_mean) + else (None, None) + ) + var_data, var_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 20, device, testing_dtype=testing_dtype) + if (check_var) + else (None, None) + ) + weight_data, weight_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype) + if weight + else (None, None) + ) + bias_data, bias_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype) + if bias + else (None, None) + ) + + if (not training) and ((not check_mean) or (not check_var)): + pytest.xfail("running_mean and running_var must be defined in evaluation mode") + + tt_output_tensor_on_device = ttnn.batch_norm( + input_tensor, + running_mean=mean_tensor, + running_var=var_tensor, + training=training, + eps=eps, + weight=weight_tensor, + bias=bias_tensor, + momentum=momentum, + ) + tt_output = ttnn.to_torch(tt_output_tensor_on_device) + tt_updated_mean = None + tt_updated_var = None + if training: + if check_mean: + tt_updated_mean = ttnn.to_torch(mean_tensor) + if check_var: + tt_updated_var = ttnn.to_torch(var_tensor) + + torch_result = torch.nn.functional.batch_norm( + input=in_data, + running_mean=mean_data, + running_var=var_data, + weight=weight_data, + bias=bias_data, + training=training, + eps=eps, + momentum=momentum, + ) + comp_pass = compare_results_batch_norm([tt_output], [torch_result]) + if training: + channels = input_shapes[1] + if check_mean: + comp_pass_1 = compare_results_batch_norm( + [tt_updated_mean], [mean_data.view(1, channels, 1, 1)], stats=True + ) # Check Updated running mean + else: + if tt_updated_mean is None: + comp_pass_1 = True + else: + comp_pass_1 = False + if check_var: + comp_pass_2 = compare_results_batch_norm( + [tt_updated_var], [var_data.view(1, channels, 1, 1)], stats=True + ) # Check Updated running var + else: + if tt_updated_var is None: + comp_pass_2 = True + else: + comp_pass_2 = False + comp_pass = comp_pass and comp_pass_1 and comp_pass_2 + assert comp_pass + + @skip_for_grayskull("Unsupported dtype for Grayskull") @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) @pytest.mark.parametrize("channel_size", [1, 2, 3, 4]) diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp index f7955a6f81d6..642a1c6f807c 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp @@ -39,13 +39,13 @@ void MAIN { sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats - add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3 + add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 + cb_tmp3 } if constexpr (old_running_var_has_value) { sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats - add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3 + add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 + cb_tmp3 } tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_sfpu_kernel.cpp new file mode 100644 index 000000000000..47256317ee82 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_sfpu_kernel.cpp @@ -0,0 +1,228 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" +#include "compute_kernel_api/eltwise_binary_sfpu.h" +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" + +namespace NAMESPACE { +void MAIN { + uint32_t num_tiles = get_arg_val(0); + constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(0) == 1; + constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(1) == 1; + + constexpr auto cb_batch_mean = tt::CBIndex::c_0; // batch mean + constexpr auto cb_batch_var = tt::CBIndex::c_1; // batch var + constexpr auto cb_out0 = tt::CBIndex::c_2; + constexpr auto cb_old_running_mean = tt::CBIndex::c_3; // old running mean tensor + constexpr auto cb_old_running_var = tt::CBIndex::c_4; // old running var tensor + constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor + constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor + constexpr auto cb_momentum = tt::CBIndex::c_5; // momentum + constexpr auto cb_one = tt::CBIndex::c_6; // stores 1 + constexpr auto cb_tmp1 = tt::CBIndex::c_21; // tmp 1 + constexpr auto cb_tmp2 = tt::CBIndex::c_22; // tmp 2 + constexpr auto cb_tmp3 = tt::CBIndex::c_23; // tmp 3 + + unary_op_init_common(cb_batch_mean, cb_out0); + constexpr uint32_t onetile = 1; + + // updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat + for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { + tile_regs_acquire(); + cb_wait_front(cb_one, 1); + cb_wait_front(cb_momentum, 1); + + if constexpr (old_running_mean_has_value) { + // 1 - momentum + cb_reserve_back(cb_tmp1, onetile); + sub_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_one); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_one, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_one, cb_momentum); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_momentum, i, i * 2 + 1); + sub_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp1); + } + tile_regs_release(); + cb_push_back(cb_tmp1, onetile); + + // momentum * batch stat + cb_wait_front(cb_batch_mean, onetile); + cb_reserve_back(cb_tmp2, onetile); + mul_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_batch_mean); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_batch_mean, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_batch_mean, cb_momentum); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_momentum, i, i * 2 + 1); + mul_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp2); + } + tile_regs_release(); + cb_push_back(cb_tmp2, onetile); + cb_pop_front(cb_batch_mean, onetile); + + // cb_tmp1 * running stats --> (1 - momentum) * running stats + cb_wait_front(cb_tmp1, onetile); + cb_wait_front(cb_old_running_mean, onetile); + cb_reserve_back(cb_tmp3, onetile); + mul_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_tmp1, cb_old_running_mean); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_old_running_mean, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_old_running_mean, cb_tmp1); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp1, i, i * 2 + 1); + mul_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp3); + } + tile_regs_release(); + cb_push_back(cb_tmp3, onetile); + cb_pop_front(cb_old_running_mean, onetile); + cb_pop_front(cb_tmp1, onetile); + + // cb_tmp2 + cb_tmp3 --> (momentum * batch stat) + ((1 - momentum) * running stats) + cb_wait_front(cb_tmp2, onetile); + cb_wait_front(cb_tmp3, onetile); + + cb_reserve_back(cb_updated_running_mean, onetile); + + add_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_tmp2, cb_tmp3); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp3, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_tmp3, cb_tmp2); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp2, i, i * 2 + 1); + add_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_updated_running_mean); + } + tile_regs_release(); + cb_push_back(cb_updated_running_mean, onetile); + cb_pop_front(cb_tmp3, onetile); + cb_pop_front(cb_tmp2, onetile); + } + if constexpr (old_running_var_has_value) { + // 1 - momentum + cb_reserve_back(cb_tmp1, onetile); + sub_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_one); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_one, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_one, cb_momentum); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_momentum, i, i * 2 + 1); + sub_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp1); + } + tile_regs_release(); + cb_push_back(cb_tmp1, onetile); + + // momentum * batch stat + cb_wait_front(cb_batch_var, onetile); + cb_reserve_back(cb_tmp2, onetile); + mul_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_batch_var); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_batch_var, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_batch_var, cb_momentum); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_momentum, i, i * 2 + 1); + mul_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp2); + } + tile_regs_release(); + cb_push_back(cb_tmp2, onetile); + cb_pop_front(cb_batch_var, onetile); + + // cb_tmp1 * running stats --> (1 - momentum) * running stats + cb_wait_front(cb_tmp1, onetile); + cb_wait_front(cb_old_running_var, onetile); + cb_reserve_back(cb_tmp3, onetile); + mul_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_tmp1, cb_old_running_var); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_old_running_var, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_old_running_var, cb_tmp1); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp1, i, i * 2 + 1); + mul_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_tmp3); + } + tile_regs_release(); + cb_push_back(cb_tmp3, onetile); + cb_pop_front(cb_old_running_var, onetile); + cb_pop_front(cb_tmp1, onetile); + + // cb_tmp2 + cb_tmp3 --> (momentum * batch stat) + ((1 - momentum) * running stats) + cb_wait_front(cb_tmp2, onetile); + cb_wait_front(cb_tmp3, onetile); + + cb_reserve_back(cb_updated_running_var, onetile); + + add_binary_tile_init(); + tile_regs_acquire(); + tile_regs_wait(); + copy_tile_to_dst_init_short_with_dt(cb_tmp2, cb_tmp3); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp3, i, i * 2); + } + copy_tile_to_dst_init_short_with_dt(cb_tmp3, cb_tmp2); + for (uint32_t i = 0; i < onetile; ++i) { + copy_tile(cb_tmp2, i, i * 2 + 1); + add_binary_tile(i * 2, i * 2 + 1); + tile_regs_commit(); + pack_tile(i * 2, cb_updated_running_var); + } + tile_regs_release(); + cb_push_back(cb_updated_running_var, onetile); + cb_pop_front(cb_tmp3, onetile); + cb_pop_front(cb_tmp2, onetile); + } + } + tile_regs_commit(); + tile_regs_wait(); + pack_tile(0, cb_out0); + tile_regs_release(); + cb_pop_front(cb_momentum, 1); + cb_pop_front(cb_one, 1); + cb_push_back(cb_out0, 1); +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp index e27719d5b5e0..e3c457c13c6d 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp @@ -46,12 +46,19 @@ void kernel_main() { union { float f; uint32_t u; - } scalar; - scalar.f = 1.0f; - fill_cb_with_value(cb_id_one, scalar.u); + } scalar_one, scalar_momentum; + scalar_one.f = 1.0f; + fill_cb_with_value(cb_id_one, scalar_one.u); + // momentum + scalar_momentum.u = momentum; cb_reserve_back(cb_id_momentum, onetile); - fill_with_val_bfloat16(cb_id_momentum, momentum); +#ifdef FILL_WITH_VALUE_FLOAT + FILL_WITH_VALUE_FLOAT(cb_id_momentum, scalar_momentum.f); +#endif +#ifdef FILL_WITH_VALUE + FILL_WITH_VALUE(cb_id_momentum, momentum); +#endif cb_push_back(cb_id_momentum, onetile); uint32_t num_tiles_read = 0; diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp index dec7420448b1..6924193e6f67 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp @@ -93,7 +93,7 @@ void kernel_main() { uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean); noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr); noc_async_read_barrier(); - fill_tile_with_first_element_bfloat16(cb_id_old_running_mean); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_old_running_mean); cb_push_back(cb_id_old_running_mean, onetile); // write data @@ -110,7 +110,7 @@ void kernel_main() { uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var); noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr); noc_async_read_barrier(); - fill_tile_with_first_element_bfloat16(cb_id_old_running_var); + FILL_TILE_WITH_FIRST_ELEMENT(cb_id_old_running_var); cb_push_back(cb_id_old_running_var, onetile); // write data diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp index 30341012f2ec..c688fa17520f 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp @@ -12,34 +12,20 @@ void RunningStatistics::validate_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { const auto& [batch_mean, batch_var, running_mean, running_var] = tensor_args; - check_tensor(batch_mean, "running_statistics", "batch_mean"); - check_tensor(batch_var, "running_statistics", "batch_var"); - check_tensor(running_mean, "running_statistics", "running_mean"); - check_tensor(running_var, "running_statistics", "running_var"); - // mean (1, C, 1, 1) auto C = batch_mean.get_logical_shape()[1]; - // var (1, C, 1, 1) - TT_FATAL(batch_var.get_logical_shape()[1] == C, "batch_var_shape[1] must be the same as input's channel size."); + + check_tensor_BN(batch_mean, "batch_mean_shape", C); + check_tensor_BN(batch_var, "batch_var_shape", C); // running_mean (1, C, 1, 1) if (running_mean.has_value()) { - TT_FATAL( - running_mean.value().get_logical_shape()[1] == C, - "running_mean_shape[1] must be the same as input's channel size."); - TT_FATAL( - running_mean.value().get_logical_shape()[1] == C, - "running_mean_shape[1] must be the same as input's channel size."); + check_tensor_BN(running_mean.value(), "running_mean_shape", C); } // running_var (1, C, 1, 1) if (running_var.has_value()) { - TT_FATAL( - running_var.value().get_logical_shape()[1] == C, - "running_var_shape[1] must be the same as input's channel size."); - TT_FATAL( - running_var.value().get_logical_shape()[1] == C, - "running_var_shape[1] must be the same as input's channel size."); + check_tensor_BN(running_var.value(), "running_var_shape", C); } } @@ -110,7 +96,8 @@ std::tuple running_mean, std::optional running_var, const std::optional& memory_config) { - operation_attributes_t operation_attributes{momentum, memory_config.value_or(batch_mean.memory_config())}; + operation_attributes_t operation_attributes{ + momentum, memory_config.value_or(batch_mean.memory_config()), batch_mean.get_dtype()}; tensor_args_t tensor_args{batch_mean, batch_var, std::move(running_mean), std::move(running_var)}; return {operation_attributes, tensor_args}; } diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp index 7f476e8f2ea8..05ea322dc21e 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp @@ -74,9 +74,10 @@ void set_or_update_runtime_arguments( } uint32_t cHtWt = cHt * cWt; - class bfloat16 bfloat_scalar_momentum(momentum); - uint32_t packed_scalar_momentum = - pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum}); + const auto scalar = momentum; + const auto packed_scalar_momentum = batch_mean_tensor.get_dtype() == DataType::FLOAT32 + ? std::bit_cast(scalar) + : pack_two_bfloat16_into_uint32({scalar, scalar}); std::array reader_runtime_args = { packed_scalar_momentum, batch_mean_tensor.buffer()->address(), @@ -227,8 +228,7 @@ RunningStatistics::RunningStatisticsProgramFactory::create( b_num_tiles_per_cb, e_data_format); // updated running var - // Intermediate buffers required for uodation of running stats - + // Intermediate buffers required for updation of running stats auto [tmp1_cb, tmp1_cb_handle] = create_cb(tt::CBIndex::c_21, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format); @@ -246,37 +246,86 @@ RunningStatistics::RunningStatisticsProgramFactory::create( const auto e_is_dram = running_var_has_value and running_var_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + std::map dataflow_defines; // Currently support only for fp32, bf16 + if (batch_mean_tensor.get_dtype() == DataType::FLOAT32) { + dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element"; + dataflow_defines["FILL_WITH_VALUE_FLOAT"] = "fill_with_val<1024, float>"; + } else { + dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element_bfloat16"; + dataflow_defines["FILL_WITH_VALUE"] = "fill_with_val_bfloat16"; + } + // READER KERNEL + auto reader_defines = dataflow_defines; auto reader_kernel_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp", all_device_cores, - tt_metal::ReaderDataMovementConfig({a_is_dram})); + tt_metal::ReaderDataMovementConfig({a_is_dram}, std::move(reader_defines))); // WRITER KERNEL + auto writer_defines = dataflow_defines; auto writer_kernel_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp", all_device_cores, - tt_metal::WriterDataMovementConfig({ - b_is_dram, - c_is_dram, - d_is_dram, - e_is_dram, - static_cast(running_mean_has_value), - static_cast(running_var_has_value), - })); + tt_metal::WriterDataMovementConfig( + { + b_is_dram, + c_is_dram, + d_is_dram, + e_is_dram, + static_cast(running_mean_has_value), + static_cast(running_var_has_value), + }, + std::move(writer_defines))); // COMPUTE KERNEL bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 || c_data_format == tt::DataFormat::Float32; + + uint32_t src_batch_mean_cb_index = tt::CBIndex::c_0; + uint32_t src_batch_var_cb_index = tt::CBIndex::c_1; + uint32_t src_momentum_cb_index = tt::CBIndex::c_5; + uint32_t src_one_cb_index = tt::CBIndex::c_6; + uint32_t src_temp_1_cb_index = tt::CBIndex::c_21; + uint32_t src_temp_2_cb_index = tt::CBIndex::c_22; + uint32_t src_temp_3_cb_index = tt::CBIndex::c_23; + uint32_t src_updated_running_mean_cb_index = tt::CBIndex::c_27; + uint32_t src_old_running_mean_cb_index = tt::CBIndex::c_3; + uint32_t src_updated_running_var_cb_index = tt::CBIndex::c_28; + uint32_t src_old_running_var_cb_index = tt::CBIndex::c_4; + + std::vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); + if (fp32_dest_acc_en) { + for (const auto cb_index : + {src_batch_mean_cb_index, + src_batch_var_cb_index, + src_momentum_cb_index, + src_one_cb_index, + src_temp_1_cb_index, + src_temp_2_cb_index, + src_temp_3_cb_index, + src_updated_running_mean_cb_index, + src_old_running_mean_cb_index, + src_updated_running_var_cb_index, + src_old_running_var_cb_index}) { + unpack_to_dest_mode[cb_index] = UnpackToDestMode::UnpackToDestFp32; + } + } + std::vector compute_kernel_args = { static_cast(running_mean_has_value), static_cast(running_var_has_value)}; auto compute_kernel_id = tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp", + fmt::format( + "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_{}.cpp", + fp32_dest_acc_en ? "sfpu_kernel" : "kernel"), all_device_cores, - tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_kernel_args}); + tt_metal::ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .unpack_to_dest_mode = std::move(unpack_to_dest_mode), + .compile_args = compute_kernel_args}); auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) { tt_metal::SetRuntimeArgs(program, kernel_id, core, args);