diff --git a/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
index 5499c0dc7de..02058d8f739 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
@@ -18,12 +18,15 @@ def data_gen_with_range_batch_norm(
     device,
     is_input=False,
     required_grad=False,
+    testing_dtype="bfloat16",
 ):
     assert high > low, "Incorrect range provided"
     torch.manual_seed(213919)
     channels = input_shapes[1]
     size = input_shapes if is_input else channels
-    pt_tensor = torch.rand(size, requires_grad=required_grad).bfloat16() * (high - low) + low
+    torch_dtype = getattr(torch, testing_dtype)
+    ttnn_dtype = getattr(ttnn, testing_dtype)
+    pt_tensor = torch.rand(size, requires_grad=required_grad, dtype=torch_dtype) * (high - low) + low
     reshaped_tensor = pt_tensor
     if not is_input:
         reshaped_tensor = pt_tensor.view(1, channels, 1, 1)
@@ -31,7 +34,7 @@ def data_gen_with_range_batch_norm(
         reshaped_tensor,
         device=device,
         layout=ttnn.TILE_LAYOUT,
-        dtype=ttnn.bfloat16,
+        dtype=ttnn_dtype,
         memory_config=ttnn.DRAM_MEMORY_CONFIG,
     )
     return pt_tensor, tt_tensor
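For reference, the dtype plumbing added above can be exercised on its own. A minimal sketch, assuming the `testing_dtype` string names an attribute present in both `torch` and `ttnn` (true for `"bfloat16"` and `"float32"`):

```python
import torch
import ttnn

def resolve_dtypes(testing_dtype: str):
    # getattr maps one dtype string onto both frameworks' dtype objects,
    # e.g. "float32" -> (torch.float32, ttnn.float32).
    return getattr(torch, testing_dtype), getattr(ttnn, testing_dtype)

torch_dtype, ttnn_dtype = resolve_dtypes("float32")
```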
diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index 66d5d432d01..56922409d00 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -10,6 +10,129 @@
     compare_results_batch_norm,
 )
 from itertools import product
+from models.utility_functions import skip_for_grayskull
+
+
+@skip_for_grayskull("Unsupported dtype for Grayskull")
+@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
+@pytest.mark.parametrize("channel_size", [1, 2, 3, 4])
+@pytest.mark.parametrize("weight", [True, False])
+@pytest.mark.parametrize("bias", [True, False])
+def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
+    input_tensor_torch = torch.full(torch.Size([3, channel_size, 64, 120]), 1, dtype=torch.float32)
+    batch_mean_torch = torch.full(torch.Size([channel_size]), 0.00030171126, dtype=torch.float32)
+    batch_var_torch = torch.full(torch.Size([channel_size]), 0.1262342343, dtype=torch.float32)
+    weight_torch = torch.full(torch.Size([channel_size]), 0.246943565369, dtype=torch.float32) if weight else None
+    bias_torch = torch.full(torch.Size([channel_size]), 0.59, dtype=torch.float32) if bias else None
+
+    result_torch = torch.nn.functional.batch_norm(
+        input=input_tensor_torch,
+        running_mean=batch_mean_torch,
+        running_var=batch_var_torch,
+        weight=weight_torch,
+        bias=bias_torch,
+        eps=eps,
+    )
+
+    batch_mean_torch = batch_mean_torch.view(1, channel_size, 1, 1)
+    batch_var_torch = batch_var_torch.view(1, channel_size, 1, 1)
+    weight_torch = weight_torch.view(1, channel_size, 1, 1) if weight else None
+    bias_torch = bias_torch.view(1, channel_size, 1, 1) if bias else None
+
+    input_tensor_tt = ttnn.from_torch(input_tensor_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    batch_mean_tt = ttnn.from_torch(batch_mean_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    batch_var_tt = ttnn.from_torch(batch_var_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    weight_tt = (
+        ttnn.from_torch(weight_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) if weight else None
+    )
+    bias_tt = ttnn.from_torch(bias_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) if bias else None
+
+    result_tt = ttnn.batch_norm(
+        input_tensor_tt, running_mean=batch_mean_tt, running_var=batch_var_tt, eps=eps, weight=weight_tt, bias=bias_tt
+    )
+    tt_out = ttnn.to_torch(result_tt)
+
+    status_1 = torch.allclose(result_torch, tt_out, atol=1e-10, rtol=1e-5)
+    status_2 = compare_results_batch_norm([result_torch], [tt_out])
+    assert status_2 and status_1
+
+
+@skip_for_grayskull("Unsupported dtype for Grayskull")
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
+        torch.Size([3, 1, 64, 120]),
+        torch.Size([3, 2, 64, 120]),
+    ],
+)
+@pytest.mark.parametrize(
+    "check_mean, check_var",
+    [
+        (False, False),  # xfail case
+        (True, False),  # xfail case
+        (False, True),  # xfail case
+        (True, True),
+    ],
+)
+@pytest.mark.parametrize("weight", [True, False])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
+def test_batch_norm_fp32(
+    input_shapes, check_mean, check_var, weight, bias, eps, device, training=False, testing_dtype="float32"
+):
+    in_data, input_tensor = data_gen_with_range_batch_norm(
+        input_shapes, 5, 10, device, is_input=True, testing_dtype=testing_dtype
+    )
+    mean_data, mean_tensor = (
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
+        if (check_mean)
+        else (None, None)
+    )
+    var_data, var_tensor = (
+        data_gen_with_range_batch_norm(input_shapes, 4, 20, device, testing_dtype=testing_dtype)
+        if (check_var)
+        else (None, None)
+    )
+    weight_data, weight_tensor = (
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
+        if weight
+        else (None, None)
+    )
+    bias_data, bias_tensor = (
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
+        if bias
+        else (None, None)
+    )
+
+    if (not training) and ((not check_mean) or (not check_var)):
+        pytest.xfail("running_mean and running_var must be defined in evaluation mode")
+
+    tt_output_tensor_on_device = ttnn.batch_norm(
+        input_tensor,
+        running_mean=mean_tensor,
+        running_var=var_tensor,
+        training=training,
+        eps=eps,
+        weight=weight_tensor,
+        bias=bias_tensor,
+    )
+    tt_output = ttnn.to_torch(tt_output_tensor_on_device)
+    torch_result = torch.nn.functional.batch_norm(
+        input=in_data,
+        running_mean=mean_data,
+        running_var=var_data,
+        weight=weight_data,
+        bias=bias_data,
+        training=training,
+        eps=eps,
+    )
+    comp_pass = compare_results_batch_norm([tt_output], [torch_result]) and torch.allclose(
+        torch_result, tt_output, atol=1e-6, rtol=1e-3
+    )
+    assert comp_pass
 
 
 @pytest.mark.parametrize(
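Condensed, the fp32 inference path these tests parametrize looks like the sketch below (illustrative shapes and values; single-device setup assumed):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

C = 3
x = torch.rand(2, C, 32, 32, dtype=torch.float32)
mean = torch.rand(C, dtype=torch.float32)
var = torch.rand(C, dtype=torch.float32) + 0.5  # keep the variance positive

ref = torch.nn.functional.batch_norm(input=x, running_mean=mean, running_var=var, eps=1e-5)

def to_tt(t):
    return ttnn.from_torch(t, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

out = ttnn.batch_norm(
    to_tt(x),
    running_mean=to_tt(mean.view(1, C, 1, 1)),  # stats are passed rank-4 as (1, C, 1, 1)
    running_var=to_tt(var.view(1, C, 1, 1)),
    eps=1e-5,
)
assert torch.allclose(ref, ttnn.to_torch(out), atol=1e-6, rtol=1e-3)

ttnn.close_device(device)
```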
Got: {}", tensor.get_layout()); + TT_FATAL( + tensor.get_dtype() == DataType::BFLOAT16 || tensor.get_dtype() == DataType::FLOAT32, + "batch_norm only supports bfloat16, float32. Got: {}", + tensor.get_dtype()); + TT_FATAL( + tensor.storage_type() == StorageType::DEVICE, + "Operands to batch_norm need to be on device! Got: {}", + tensor.storage_type()); + TT_FATAL(tensor.buffer() != nullptr, "Operands to batch_norm need to be allocated in buffers on device!"); + TT_FATAL(tensor.get_logical_shape().rank() == 4, "batch_norm supports tensors of rank 4"); + TT_FATAL(tensor.get_logical_shape()[1] == input_c_dim, "{}[1] must be the same as input's channel size.", name); +} +} // namespace + void BatchNormOperation::validate_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args; - check_tensor(input, "batch_norm", "input"); - check_tensor(batch_mean, "batch_norm", "batch_mean"); - check_tensor(batch_var, "batch_norm", "batch_var"); - check_tensor(weight, "batch_norm", "weight"); - check_tensor(bias, "batch_norm", "bias"); - check_tensor(output, "batch_norm", "output"); - // input (N, C, H, W) auto C = input.get_logical_shape()[1]; + + check_tensor_BN(input, "input_shape", C); + check_tensor_BN(batch_mean, "batch_mean_shape", C); + check_tensor_BN(batch_var, "batch_mean_shape", C); + // output (N, C, H, W) if (output.has_value()) { - auto check_C = output.value().get_logical_shape()[1]; - TT_FATAL(C == check_C, "output_shape[1] must be the same as input's channel size."); + check_tensor_BN(output.value(), "output_shape", C); } - // mean (1, C, 1, 1) - TT_FATAL(batch_mean.get_logical_shape()[1] == C, "batch_mean_shape[1] must be the same as input's channel size."); - // var (1, C, 1, 1) - TT_FATAL(batch_var.get_logical_shape()[1] == C, "batch_var_shape[1] must be the same as input's channel size."); - // weight (1, C, 1, 1) if (weight.has_value()) { - TT_FATAL( - weight.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size."); - TT_FATAL( - weight.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size."); + check_tensor_BN(weight.value(), "weight_shape", C); } // bias (1, C, 1, 1) if (bias.has_value()) { - TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size."); - TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size."); + check_tensor_BN(bias.value(), "bias_shape", C); } } @@ -127,7 +134,7 @@ std::tuple bias, std::optional output, const std::optional& memory_config) { - operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config())}; + operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config()), input.get_dtype()}; tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)}; return {operation_attributes, tensor_args}; } diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index c640a45e00d..a0f062da2f8 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -73,8 +73,11 @@ void 
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
index c640a45e00d..a0f062da2f8 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
@@ -73,8 +73,11 @@ void set_or_update_runtime_arguments(
     }
     uint32_t cHtWt = cHt * cWt;
 
-    class bfloat16 bfloat_scalar_eps(eps);
-    uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps});
+    const auto scalar = eps;
+    const auto packed_scalar_eps = input_tensor.get_dtype() == DataType::FLOAT32
+                                       ? std::bit_cast<uint32_t>(scalar)
+                                       : pack_two_bfloat16_into_uint32({scalar, scalar});
+
     std::array reader_runtime_args = {
         packed_scalar_eps,
         input_tensor.buffer()->address(),
@@ -218,38 +221,83 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     const auto e_is_dram = weight_has_value and weight_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
     const auto f_is_dram = bias_has_value and bias_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
 
+    std::map<std::string, std::string> dataflow_defines;  // Currently supported only for fp32 and bf16
+    if (input_tensor.get_dtype() == DataType::FLOAT32) {
+        dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element";
+        dataflow_defines["FILL_WITH_VALUE_FLOAT"] = "fill_with_val<1024, float>";
+    } else {
+        dataflow_defines["FILL_TILE_WITH_FIRST_ELEMENT"] = "fill_tile_with_first_element_bfloat16";
+        dataflow_defines["FILL_WITH_VALUE"] = "fill_with_val_bfloat16";
+    }
+
     // READER KERNEL
+    auto reader_defines = dataflow_defines;
     auto reader_kernel_id = tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp",
         all_device_cores,
-        tt_metal::ReaderDataMovementConfig({a_is_dram}));
+        tt_metal::ReaderDataMovementConfig({a_is_dram}, std::move(reader_defines)));
 
     // WRITER KERNEL
+    auto writer_defines = dataflow_defines;
     auto writer_kernel_id = tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
         all_device_cores,
-        tt_metal::WriterDataMovementConfig({
-            b_is_dram,
-            c_is_dram,
-            d_is_dram,
-            e_is_dram,
-            f_is_dram,
-            static_cast<uint32_t>(weight_has_value),
-            static_cast<uint32_t>(bias_has_value),
-        }));
+        tt_metal::WriterDataMovementConfig(
+            {
+                b_is_dram,
+                c_is_dram,
+                d_is_dram,
+                e_is_dram,
+                f_is_dram,
+                static_cast<uint32_t>(weight_has_value),
+                static_cast<uint32_t>(bias_has_value),
+            },
+            std::move(writer_defines)));
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp", + fmt::format( + "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_{}.cpp", + fp32_dest_acc_en ? "sfpu_kernel" : "kernel"), all_device_cores, - tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_kernel_args}); + tt_metal::ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .unpack_to_dest_mode = std::move(unpack_to_dest_mode), + .compile_args = compute_kernel_args}); auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) { tt_metal::SetRuntimeArgs(program, kernel_id, core, args); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp new file mode 100644 index 00000000000..52942da1f55 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "compute_kernel_api/eltwise_binary_sfpu.h" +#include "cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" + +#include + +namespace NAMESPACE { + +ALWI void batchnorm_bcast_tiles( + uint32_t cb_bcast, + uint32_t cb_other, + uint32_t freq, + uint32_t tile_start, + uint32_t cb_batch_var, + uint32_t cb_eps, + uint32_t cb_den, + uint32_t cb_num, + uint32_t cb_weight, + uint32_t cb_bias, + uint32_t cb_tmp_1, + uint32_t cb_output_0, + uint32_t weight_has, + uint32_t bias_has) { + constexpr uint32_t onetile = 1; + constexpr int dst0 = 0; + uint32_t weight_has_value = weight_has; + uint32_t bias_has_value = bias_has; + auto cb_affine_or_out = (weight_has_value || bias_has_value) ? cb_tmp_1 : cb_output_0; + auto cb_scaled_output = (bias_has_value) ? 
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp
new file mode 100644
index 00000000000..52942da1f55
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_sfpu_kernel.cpp
@@ -0,0 +1,242 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "compute_kernel_api/eltwise_binary_sfpu.h"
+#include "cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"
+#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
+#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
+
+#include <cstdint>
+
+namespace NAMESPACE {
+
+ALWI void batchnorm_bcast_tiles(
+    uint32_t cb_bcast,
+    uint32_t cb_other,
+    uint32_t freq,
+    uint32_t tile_start,
+    uint32_t cb_batch_var,
+    uint32_t cb_eps,
+    uint32_t cb_den,
+    uint32_t cb_num,
+    uint32_t cb_weight,
+    uint32_t cb_bias,
+    uint32_t cb_tmp_1,
+    uint32_t cb_output_0,
+    uint32_t weight_has,
+    uint32_t bias_has) {
+    constexpr uint32_t onetile = 1;
+    constexpr int dst0 = 0;
+    uint32_t weight_has_value = weight_has;
+    uint32_t bias_has_value = bias_has;
+    auto cb_affine_or_out = (weight_has_value || bias_has_value) ? cb_tmp_1 : cb_output_0;
+    auto cb_scaled_output = (bias_has_value) ? cb_tmp_1 : cb_output_0;
+
+    // input - batch_mean
+    cb_wait_front(cb_bcast, onetile);
+    for (uint32_t j = tile_start; j < freq; ++j) {
+        cb_wait_front(cb_other, onetile);
+
+        cb_reserve_back(cb_num, onetile);
+
+        sub_binary_tile_init();
+        tile_regs_acquire();
+        tile_regs_wait();
+        copy_tile_to_dst_init_short_with_dt(cb_bcast, cb_other);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_other, i, i * 2);
+        }
+        copy_tile_to_dst_init_short_with_dt(cb_other, cb_bcast);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_bcast, i, i * 2 + 1);
+            sub_binary_tile(i * 2, i * 2 + 1);
+            tile_regs_commit();
+            pack_tile(i * 2, cb_num);
+        }
+        tile_regs_release();
+        cb_push_back(cb_num, onetile);
+        cb_pop_front(cb_other, onetile);
+    }
+    cb_pop_front(cb_bcast, onetile);
+
+    // 1/(sqrt(batch_var + eps))
+    cb_reserve_back(cb_den, onetile);
+    cb_wait_front(cb_batch_var, onetile);
+    cb_wait_front(cb_eps, onetile);
+
+    add_binary_tile_init();
+    rsqrt_tile_init();
+    tile_regs_acquire();
+    copy_tile_to_dst_init_short_with_dt(cb_eps, cb_batch_var);
+    for (uint32_t i = 0; i < onetile; ++i) {
+        copy_tile(cb_batch_var, i, i * 2);
+    }
+    copy_tile_to_dst_init_short_with_dt(cb_batch_var, cb_eps);
+    for (uint32_t i = 0; i < onetile; ++i) {
+        copy_tile(cb_eps, i, i * 2 + 1);
+
+        add_binary_tile(i * 2, i * 2 + 1);
+        rsqrt_tile(i * 2);
+        tile_regs_commit();
+
+        tile_regs_wait();
+        pack_tile(i * 2, cb_den);
+    }
+    tile_regs_release();
+
+    cb_push_back(cb_den, onetile);
+    cb_pop_front(cb_batch_var, onetile);
+    cb_pop_front(cb_eps, onetile);
+
+    // (input - batch_mean)/(sqrt(batch_var + eps)) = result
+    cb_wait_front(cb_den, onetile);
+    for (uint32_t j = tile_start; j < freq; ++j) {
+        cb_wait_front(cb_num, onetile);
+
+        cb_reserve_back(cb_affine_or_out, onetile);
+
+        mul_binary_tile_init();
+        tile_regs_acquire();
+        tile_regs_wait();
+        copy_tile_to_dst_init_short_with_dt(cb_den, cb_num);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_num, i, i * 2);
+        }
+        copy_tile_to_dst_init_short_with_dt(cb_num, cb_den);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_den, i, i * 2 + 1);
+            mul_binary_tile(i * 2, i * 2 + 1);
+            tile_regs_commit();
+            pack_tile(i * 2, cb_affine_or_out);
+        }
+        tile_regs_release();
+        cb_push_back(cb_affine_or_out, onetile);
+        cb_pop_front(cb_num, onetile);
+    }
+    cb_pop_front(cb_den, onetile);
+
+    if (weight_has_value) {  // result = result * weight
+        cb_wait_front(cb_weight, onetile);
+        for (uint32_t j = tile_start; j < freq; ++j) {
+            cb_wait_front(cb_affine_or_out, onetile);
+
+            cb_reserve_back(cb_scaled_output, onetile);
+
+            mul_binary_tile_init();
+            tile_regs_acquire();
+            tile_regs_wait();
+            copy_tile_to_dst_init_short_with_dt(cb_weight, cb_affine_or_out);
+            for (uint32_t i = 0; i < onetile; ++i) {
+                copy_tile(cb_affine_or_out, i, i * 2);
+            }
+            copy_tile_to_dst_init_short_with_dt(cb_affine_or_out, cb_weight);
+            for (uint32_t i = 0; i < onetile; ++i) {
+                copy_tile(cb_weight, i, i * 2 + 1);
+                mul_binary_tile(i * 2, i * 2 + 1);
+                tile_regs_commit();
+                pack_tile(i * 2, cb_scaled_output);
+            }
+            tile_regs_release();
+            cb_push_back(cb_scaled_output, onetile);
+            cb_pop_front(cb_affine_or_out, onetile);
+        }
+        cb_pop_front(cb_weight, onetile);
+    }
+
+    if (bias_has_value) {  // result = result + bias
+        cb_wait_front(cb_bias, onetile);
+        for (uint32_t j = tile_start; j < freq; ++j) {
+            cb_wait_front(cb_tmp_1, onetile);
+
+            cb_reserve_back(cb_output_0, onetile);
+
+            add_binary_tile_init();
+            tile_regs_acquire();
+            tile_regs_wait();
+            copy_tile_to_dst_init_short_with_dt(cb_bias, cb_tmp_1);
+            for (uint32_t i = 0; i < onetile; ++i) {
+                copy_tile(cb_tmp_1, i, i * 2);
+            }
+            copy_tile_to_dst_init_short_with_dt(cb_tmp_1, cb_bias);
+            for (uint32_t i = 0; i < onetile; ++i) {
+                copy_tile(cb_bias, i, i * 2 + 1);
+                add_binary_tile(i * 2, i * 2 + 1);
+                tile_regs_commit();
+                pack_tile(i * 2, cb_output_0);
+            }
+            tile_regs_release();
+            cb_push_back(cb_output_0, onetile);
+            cb_pop_front(cb_tmp_1, onetile);
+        }
+        cb_pop_front(cb_bias, onetile);
+    }
+}
+
+void MAIN {
+    uint32_t num_tiles = get_arg_val<uint32_t>(0);
+    uint32_t tile_freq = get_arg_val<uint32_t>(1);
+    uint32_t tile_start = get_arg_val<uint32_t>(2);
+    constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
+    constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
+
+    if (num_tiles == 0) {
+        return;
+    }
+
+    constexpr auto cb_input = tt::CBIndex::c_0;       // input
+    constexpr auto cb_batch_mean = tt::CBIndex::c_1;  // batch_mean
+    constexpr auto cb_output_0 =
+        tt::CBIndex::c_2;  // output --> [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
+    constexpr auto cb_batch_var = tt::CBIndex::c_3;  // batch_var
+    constexpr auto cb_eps = tt::CBIndex::c_4;        // eps
+    constexpr auto cb_den = tt::CBIndex::c_5;        // 1/(sqrt(batch_var + eps))
+    constexpr auto cb_num = tt::CBIndex::c_6;        // input - batch_mean
+    constexpr auto cb_weight = tt::CBIndex::c_16;    // weight tensor
+    constexpr auto cb_tmp_1 = tt::CBIndex::c_17;     // (input - batch_mean)/(sqrt(batch_var + eps))
+    constexpr auto cb_bias = tt::CBIndex::c_18;      // bias tensor
+
+    auto cb_bcast = cb_batch_mean;
+    auto cb_other = cb_input;
+
+    unary_op_init_common(cb_other, cb_output_0);
+
+    uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
+    uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
+    for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
+        batchnorm_bcast_tiles(
+            cb_bcast,
+            cb_other,
+            tile_freq,
+            tile_start,
+            cb_batch_var,
+            cb_eps,
+            cb_den,
+            cb_num,
+            cb_weight,
+            cb_bias,
+            cb_tmp_1,
+            cb_output_0,
+            weight_has_value,
+            bias_has_value);
+    }
+    if (remaining_iterations > 0) {
+        batchnorm_bcast_tiles(
+            cb_bcast,
+            cb_other,
+            remaining_iterations,
+            tile_start,
+            cb_batch_var,
+            cb_eps,
+            cb_den,
+            cb_num,
+            cb_weight,
+            cb_bias,
+            cb_tmp_1,
+            cb_output_0,
+            weight_has_value,
+            bias_has_value);
+    }
+}
+}  // namespace NAMESPACE
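What the kernel above computes per tile, restated with torch for readability (a reference model of the dataflow, not the device code):

```python
import torch

def batch_norm_reference(x, mean, var, eps, weight=None, bias=None):
    out = (x - mean) * torch.rsqrt(var + eps)  # cb_num * cb_den
    if weight is not None:
        out = out * weight  # cb_affine_or_out -> cb_scaled_output
    if bias is not None:
        out = out + bias    # -> cb_output_0
    return out
```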
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
index a5f9c86787a..ebf287dce1f 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
@@ -37,8 +37,18 @@ void kernel_main() {
 
     constexpr auto cb_id_eps = tt::CBIndex::c_4;
 
+    union {
+        float f;
+        uint32_t u;
+    } scalar;
+    scalar.u = eps;
     cb_reserve_back(cb_id_eps, onetile);
-    fill_with_val_bfloat16(cb_id_eps, eps);
+#ifdef FILL_WITH_VALUE_FLOAT
+    FILL_WITH_VALUE_FLOAT(cb_id_eps, scalar.f);
+#endif
+#ifdef FILL_WITH_VALUE
+    FILL_WITH_VALUE(cb_id_eps, eps);
+#endif
     cb_push_back(cb_id_eps, onetile);
 
     // Input tile offset
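On the fp32 path the reader receives eps as a raw `uint32_t` runtime argument and type-puns it back to `float` through the union before filling the CB. The equivalent round trip in Python:

```python
import struct

def u32_to_float(bits):
    # Equivalent of writing scalar.u and reading scalar.f in the reader.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

assert abs(u32_to_float(0x3727C5AC) - 1e-5) < 1e-11  # bit pattern of 1e-5f
```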
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
index 0143fbec042..0c80abbc870 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
@@ -89,7 +89,7 @@ void kernel_main() {
     uint32_t l1_write_addr = get_write_ptr(cb_id_src);
     noc_async_read_tile(tile_offset, src, l1_write_addr);
     noc_async_read_barrier();
-    fill_tile_with_first_element_bfloat16(cb_id_src);
+    FILL_TILE_WITH_FIRST_ELEMENT(cb_id_src);
     cb_push_back(cb_id_src, onetile);
 
     // read a tile from batch variance
@@ -97,7 +97,7 @@ void kernel_main() {
     uint32_t l1_batch_var_write_addr = get_write_ptr(cb_id_batch_var);
     noc_async_read_tile(tile_offset, batch_var, l1_batch_var_write_addr);
     noc_async_read_barrier();
-    fill_tile_with_first_element_bfloat16(cb_id_batch_var);
+    FILL_TILE_WITH_FIRST_ELEMENT(cb_id_batch_var);
     cb_push_back(cb_id_batch_var, onetile);
 
     if constexpr (weight_has_value) {  // read a tile from weight tensor
@@ -105,7 +105,7 @@ void kernel_main() {
         uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight);
         noc_async_read_tile(tile_offset, weight, l1_weight_write_addr);
         noc_async_read_barrier();
-        fill_tile_with_first_element_bfloat16(cb_id_weight);
+        FILL_TILE_WITH_FIRST_ELEMENT(cb_id_weight);
         cb_push_back(cb_id_weight, onetile);
     }
 
@@ -114,7 +114,7 @@ void kernel_main() {
         uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias);
         noc_async_read_tile(tile_offset, bias, l1_bias_write_addr);
         noc_async_read_barrier();
-        fill_tile_with_first_element_bfloat16(cb_id_bias);
+        FILL_TILE_WITH_FIRST_ELEMENT(cb_id_bias);
         cb_push_back(cb_id_bias, onetile);
     }
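Taken together, the per-dtype dispatch this diff introduces can be summarized as follows (a hedged restatement of the factory logic above in Python, not code from the tree):

```python
def batch_norm_plan(input_dtype):
    fp32 = input_dtype == "float32"
    return {
        "compute_kernel": "batch_norm_sfpu_kernel.cpp" if fp32 else "batch_norm_kernel.cpp",
        "fp32_dest_acc_en": fp32,
        "unpack_to_dest_mode": "UnpackToDestFp32" if fp32 else "Default",
        "eps_packing": "bit_cast<uint32_t>" if fp32 else "pack_two_bfloat16_into_uint32",
    }
```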