#0: Float32 support for Training mode in Batch Norm
VirdhatchaniKN committed Feb 9, 2025
1 parent e1a028f commit 7a7022b
Showing 7 changed files with 424 additions and 44 deletions.
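
Going by the diff below, this commit adds a float32 path for the training-mode running-statistics update in ttnn.batch_norm: a new float32 unit test, a compute kernel that open-codes updated_running_stat = (1 - momentum) * running_stat + momentum * batch_stat, and reader-kernel changes so the momentum scalar can be filled into its circular buffer as float32 instead of bfloat16. A host-only sketch (plain PyTorch, arbitrary numbers, no ttnn) of why carrying this update in float32 rather than bfloat16 matters:

# Host-only illustration: the same running-stat recurrence stalls when every
# intermediate is rounded to bfloat16, while the float32 path tracks the true
# value closely. Numbers are arbitrary; this is not device code.
import torch

momentum = 0.1
batch_stat = torch.tensor(5.0)

running_fp32 = torch.tensor(2.0, dtype=torch.float32)
running_bf16 = torch.tensor(2.0, dtype=torch.bfloat16)

for _ in range(100):
    running_fp32 = (1 - momentum) * running_fp32 + momentum * batch_stat
    running_bf16 = (1 - momentum) * running_bf16 + momentum * batch_stat.bfloat16()

print(running_fp32.item())          # ~4.9999, essentially converged to 5.0
print(running_bf16.float().item())  # typically stalls visibly short of 5.0
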
108 changes: 108 additions & 0 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -13,6 +13,114 @@
from models.utility_functions import skip_for_grayskull


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
torch.Size([3, 1, 64, 120]),
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize(
"check_mean, check_var",
[
(False, False),
(True, False),
(False, True),
(True, True),
],
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5])
def test_batch_norm_training_fp32(
input_shapes, check_mean, check_var, weight, bias, eps, device, momentum, training=True, testing_dtype="float32"
):
in_data, input_tensor = data_gen_with_range_batch_norm(
input_shapes, 5, 10, device, is_input=True, testing_dtype=testing_dtype
)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if (check_mean)
else (None, None)
)
var_data, var_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 20, device, testing_dtype=testing_dtype)
if (check_var)
else (None, None)
)
weight_data, weight_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if weight
else (None, None)
)
bias_data, bias_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device, testing_dtype=testing_dtype)
if bias
else (None, None)
)

if (not training) and ((not check_mean) or (not check_var)):
pytest.xfail("running_mean and running_var must be defined in evaluation mode")

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor,
running_mean=mean_tensor,
running_var=var_tensor,
training=training,
eps=eps,
weight=weight_tensor,
bias=bias_tensor,
momentum=momentum,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)
tt_updated_mean = None
tt_updated_var = None
if training:
if check_mean:
tt_updated_mean = ttnn.to_torch(mean_tensor)
if check_var:
tt_updated_var = ttnn.to_torch(var_tensor)

torch_result = torch.nn.functional.batch_norm(
input=in_data,
running_mean=mean_data,
running_var=var_data,
weight=weight_data,
bias=bias_data,
training=training,
eps=eps,
momentum=momentum,
)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
if training:
channels = input_shapes[1]
if check_mean:
comp_pass_1 = compare_results_batch_norm(
[tt_updated_mean], [mean_data.view(1, channels, 1, 1)], stats=True
) # Check Updated running mean
else:
if tt_updated_mean is None:
comp_pass_1 = True
else:
comp_pass_1 = False
if check_var:
comp_pass_2 = compare_results_batch_norm(
[tt_updated_var], [var_data.view(1, channels, 1, 1)], stats=True
) # Check Updated running var
else:
if tt_updated_var is None:
comp_pass_2 = True
else:
comp_pass_2 = False
comp_pass = comp_pass and comp_pass_1 and comp_pass_2
assert comp_pass
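
For context on what the stats=True comparisons above assert, here is a standalone sketch of the reference behaviour in plain PyTorch (no ttnn, no repo helpers): in training mode, torch.nn.functional.batch_norm updates running_mean in place with the batch mean and running_var with the unbiased batch variance, each scaled by momentum.

# Standalone sketch of the in-place running-stat updates the test compares against.
import torch

momentum = 0.1
x = torch.randn(2, 3, 32, 32, dtype=torch.float32)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

batch_mean = x.mean(dim=(0, 2, 3))
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)
expected_mean = (1 - momentum) * running_mean + momentum * batch_mean
expected_var = (1 - momentum) * running_var + momentum * batch_var_unbiased

torch.nn.functional.batch_norm(x, running_mean, running_var, training=True, momentum=momentum)
assert torch.allclose(running_mean, expected_mean, atol=1e-5)
assert torch.allclose(running_var, expected_var, atol=1e-5)
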


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
@pytest.mark.parametrize("channel_size", [1, 2, 3, 4])
@@ -39,13 +39,13 @@ void MAIN {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats
-add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3
+add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 + cb_tmp3
}
if constexpr (old_running_var_has_value) {
sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat
mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats
-add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3
+add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 + cb_tmp3
}
tile_regs_commit();
tile_regs_wait();
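
The hunk above builds the update from the fused *_tiles_to_cb helpers; the new kernel in the next file (the float32 path, going by the commit message) open-codes the same four stages. A worked host-side example of the staging (arbitrary numbers, not device code):

# cb_tmp1/cb_tmp2/cb_tmp3 staging of the running-stat update, worked on scalars.
momentum, old_running, batch_stat = 0.1, 2.0, 5.0

tmp1 = 1 - momentum            # sub_tiles_to_cb  -> 0.9
tmp2 = momentum * batch_stat   # mul_tiles_to_cb  -> 0.5
tmp3 = tmp1 * old_running      # mul_tiles_to_cb  -> 1.8
updated = tmp2 + tmp3          # add_tiles_to_cb  -> 2.3
assert updated == (1 - momentum) * old_running + momentum * batch_stat
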
@@ -0,0 +1,228 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp"
#include "compute_kernel_api/eltwise_binary_sfpu.h"
#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"

namespace NAMESPACE {
void MAIN {
uint32_t num_tiles = get_arg_val<uint32_t>(0);
constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(0) == 1;
constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(1) == 1;

constexpr auto cb_batch_mean = tt::CBIndex::c_0; // batch mean
constexpr auto cb_batch_var = tt::CBIndex::c_1; // batch var
constexpr auto cb_out0 = tt::CBIndex::c_2;
constexpr auto cb_old_running_mean = tt::CBIndex::c_3; // old running mean tensor
constexpr auto cb_old_running_var = tt::CBIndex::c_4; // old running var tensor
constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor
constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor
constexpr auto cb_momentum = tt::CBIndex::c_5; // momentum
constexpr auto cb_one = tt::CBIndex::c_6; // stores 1
constexpr auto cb_tmp1 = tt::CBIndex::c_21; // tmp 1
constexpr auto cb_tmp2 = tt::CBIndex::c_22; // tmp 2
constexpr auto cb_tmp3 = tt::CBIndex::c_23; // tmp 3

unary_op_init_common(cb_batch_mean, cb_out0);
constexpr uint32_t onetile = 1;

// updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat
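// Each binary op below follows the same pattern: wait on / reserve the CBs as
// needed, copy the two operand tiles into dest registers (indices i * 2 and
// i * 2 + 1), run the SFPU binary op in dest, pack the result, and push/pop.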
for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
tile_regs_acquire();
cb_wait_front(cb_one, 1);
cb_wait_front(cb_momentum, 1);

if constexpr (old_running_mean_has_value) {
// 1 - momentum
cb_reserve_back(cb_tmp1, onetile);
sub_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_one);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_one, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_one, cb_momentum);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_momentum, i, i * 2 + 1);
sub_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp1);
}
tile_regs_release();
cb_push_back(cb_tmp1, onetile);

// momentum * batch stat
cb_wait_front(cb_batch_mean, onetile);
cb_reserve_back(cb_tmp2, onetile);
mul_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_batch_mean);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_batch_mean, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_batch_mean, cb_momentum);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_momentum, i, i * 2 + 1);
mul_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp2);
}
tile_regs_release();
cb_push_back(cb_tmp2, onetile);
cb_pop_front(cb_batch_mean, onetile);

// cb_tmp1 * running stats --> (1 - momentum) * running stats
cb_wait_front(cb_tmp1, onetile);
cb_wait_front(cb_old_running_mean, onetile);
cb_reserve_back(cb_tmp3, onetile);
mul_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_tmp1, cb_old_running_mean);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_old_running_mean, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_old_running_mean, cb_tmp1);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp1, i, i * 2 + 1);
mul_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp3);
}
tile_regs_release();
cb_push_back(cb_tmp3, onetile);
cb_pop_front(cb_old_running_mean, onetile);
cb_pop_front(cb_tmp1, onetile);

// cb_tmp2 + cb_tmp3 --> (momentum * batch stat) + ((1 - momentum) * running stats)
cb_wait_front(cb_tmp2, onetile);
cb_wait_front(cb_tmp3, onetile);

cb_reserve_back(cb_updated_running_mean, onetile);

add_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_tmp2, cb_tmp3);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp3, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_tmp3, cb_tmp2);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp2, i, i * 2 + 1);
add_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_updated_running_mean);
}
tile_regs_release();
cb_push_back(cb_updated_running_mean, onetile);
cb_pop_front(cb_tmp3, onetile);
cb_pop_front(cb_tmp2, onetile);
}
if constexpr (old_running_var_has_value) {
// 1 - momentum
cb_reserve_back(cb_tmp1, onetile);
sub_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_one);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_one, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_one, cb_momentum);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_momentum, i, i * 2 + 1);
sub_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp1);
}
tile_regs_release();
cb_push_back(cb_tmp1, onetile);

// momentum * batch stat
cb_wait_front(cb_batch_var, onetile);
cb_reserve_back(cb_tmp2, onetile);
mul_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_momentum, cb_batch_var);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_batch_var, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_batch_var, cb_momentum);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_momentum, i, i * 2 + 1);
mul_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp2);
}
tile_regs_release();
cb_push_back(cb_tmp2, onetile);
cb_pop_front(cb_batch_var, onetile);

// cb_tmp1 * running stats --> (1 - momentum) * running stats
cb_wait_front(cb_tmp1, onetile);
cb_wait_front(cb_old_running_var, onetile);
cb_reserve_back(cb_tmp3, onetile);
mul_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_tmp1, cb_old_running_var);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_old_running_var, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_old_running_var, cb_tmp1);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp1, i, i * 2 + 1);
mul_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_tmp3);
}
tile_regs_release();
cb_push_back(cb_tmp3, onetile);
cb_pop_front(cb_old_running_var, onetile);
cb_pop_front(cb_tmp1, onetile);

// cb_tmp2 + cb_tmp3 --> (momentum * batch stat) + ((1 - momentum) * running stats)
cb_wait_front(cb_tmp2, onetile);
cb_wait_front(cb_tmp3, onetile);

cb_reserve_back(cb_updated_running_var, onetile);

add_binary_tile_init();
tile_regs_acquire();
tile_regs_wait();
copy_tile_to_dst_init_short_with_dt(cb_tmp2, cb_tmp3);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp3, i, i * 2);
}
copy_tile_to_dst_init_short_with_dt(cb_tmp3, cb_tmp2);
for (uint32_t i = 0; i < onetile; ++i) {
copy_tile(cb_tmp2, i, i * 2 + 1);
add_binary_tile(i * 2, i * 2 + 1);
tile_regs_commit();
pack_tile(i * 2, cb_updated_running_var);
}
tile_regs_release();
cb_push_back(cb_updated_running_var, onetile);
cb_pop_front(cb_tmp3, onetile);
cb_pop_front(cb_tmp2, onetile);
}
}
tile_regs_commit();
tile_regs_wait();
pack_tile(0, cb_out0);
tile_regs_release();
cb_pop_front(cb_momentum, 1);
cb_pop_front(cb_one, 1);
cb_push_back(cb_out0, 1);
}
} // namespace NAMESPACE
@@ -46,12 +46,19 @@ void kernel_main() {
union {
float f;
uint32_t u;
-} scalar;
-scalar.f = 1.0f;
-fill_cb_with_value(cb_id_one, scalar.u);
+} scalar_one, scalar_momentum;
+scalar_one.f = 1.0f;
+fill_cb_with_value(cb_id_one, scalar_one.u);

// momentum
+scalar_momentum.u = momentum;
cb_reserve_back(cb_id_momentum, onetile);
-fill_with_val_bfloat16(cb_id_momentum, momentum);
+#ifdef FILL_WITH_VALUE_FLOAT
+FILL_WITH_VALUE_FLOAT(cb_id_momentum, scalar_momentum.f);
+#endif
+#ifdef FILL_WITH_VALUE
+FILL_WITH_VALUE(cb_id_momentum, momentum);
+#endif
cb_push_back(cb_id_momentum, onetile);

uint32_t num_tiles_read = 0;
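
The #ifdef pair above suggests the host-side program factory defines FILL_WITH_VALUE_FLOAT for a float32 momentum CB and FILL_WITH_VALUE for the bfloat16 path; in either case momentum reaches the kernel as a uint32 runtime argument whose raw bits the scalar_momentum union reinterprets as a float. A host-side sketch of that bit packing, under that assumption (illustrative helpers, not code from this repository):

# Pack a float's IEEE-754 bits into a uint32, mirroring what
# `scalar_momentum.u = momentum;` undoes on the device side.
import struct

def float_to_u32_bits(value: float) -> int:
    return struct.unpack("<I", struct.pack("<f", value))[0]

def u32_bits_to_float(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]

momentum_bits = float_to_u32_bits(0.1)   # 0x3DCCCCCD
assert abs(u32_bits_to_float(momentum_bits) - 0.1) < 1e-7
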
