#0: Remove pop for scalar in BN
VirdhatchaniKN committed Feb 24, 2025
1 parent 42adc10 commit 78a8b3f
Showing 3 changed files with 10 additions and 19 deletions.
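The change itself: the batch-norm compute kernels previously waited on and popped the eps scalar's circular buffer (cb_eps) inside the per-tile helper batchnorm_bcast_tiles, re-synchronizing on a value that never changes between tiles. MAIN now waits for the eps tile once, before the tile loop, and never pops it, so the single tile stays resident. The test file consolidates the 64x120 shape cases into one product(...) expression that covers and extends the previous explicit list.

A minimal before/after sketch of the circular-buffer pattern, assuming the tt-metal compute-kernel environment that supplies cb_wait_front/cb_pop_front; the function names and parameters below are illustrative placeholders, not the kernels' real signatures:

// Before (sketch): each per-tile call re-synchronized on the scalar.
void per_tile_before(uint32_t cb_eps) {
    cb_wait_front(cb_eps, 1);  // block until the eps tile is present
    // ... use eps: add to batch_var, rsqrt, ... (see kernels below)
    cb_pop_front(cb_eps, 1);   // frees the tile, so the producer would
                               // have to re-push the same scalar per tile
}

// After (sketch): wait once, never pop; the lone eps tile stays
// resident in its circular buffer for the whole kernel launch.
void main_after(uint32_t cb_eps, uint32_t num_tiles) {
    cb_wait_front(cb_eps, 1);  // scalar is pushed once; keep it resident
    for (uint32_t i = 0; i < num_tiles; ++i) {
        // ... per-tile math reads the resident eps tile directly ...
    }
    // intentionally no cb_pop_front(cb_eps)
}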
12 changes: 3 additions & 9 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -19,9 +19,7 @@
     [
         *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
         *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
-        torch.Size([3, 1, 64, 120]),
-        torch.Size([3, 2, 64, 120]),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
     ],
 )
 @pytest.mark.parametrize(
@@ -171,9 +169,7 @@ def test_BN_fp32_full_value(device, channel_size, eps, weight, bias):
     [
         *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
         *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
-        torch.Size([3, 1, 64, 120]),
-        torch.Size([3, 2, 64, 120]),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
     ],
 )
 @pytest.mark.parametrize(
@@ -248,9 +244,7 @@ def test_batch_norm_fp32(
     [
         *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
         *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
-        torch.Size([3, 1, 64, 120]),
-        torch.Size([3, 2, 64, 120]),
+        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3], [1, 2, 3, 4])),
     ],
 )
 @pytest.mark.parametrize(
8 changes: 3 additions & 5 deletions
@@ -53,7 +53,6 @@ ALWI void batchnorm_bcast_tiles(
     // 1/(sqrt(batch_var + eps))
     cb_reserve_back(cb_den, onetile);
     cb_wait_front(cb_batch_var, 1);
-    cb_wait_front(cb_eps, 1);

     tile_regs_acquire();
     add_tiles_init_with_dt(cb_batch_var, cb_eps);
@@ -67,7 +66,6 @@
     tile_regs_release();

     cb_pop_front(cb_batch_var, 1);
-    cb_pop_front(cb_eps, 1);
     cb_push_back(cb_den, onetile);

     // (input - batch_mean)/(sqrt(batch_var + eps)) = result
@@ -164,6 +162,9 @@ void MAIN {
     sub_tiles_init(cb_other, cb_bcast);
     uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
     uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
+
+    cb_wait_front(cb_eps, 1);
+
     for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
         batchnorm_bcast_tiles(
             cb_bcast,
@@ -198,8 +199,5 @@
             weight_has_value,
             bias_has_value);
     }
-
-    constexpr uint32_t onetile = 1;
-    constexpr int dst0 = 0;
 }
 }  // namespace NAMESPACE
9 changes: 4 additions & 5 deletions
@@ -63,7 +63,6 @@ ALWI void batchnorm_bcast_tiles(
     // 1/(sqrt(batch_var + eps))
     cb_reserve_back(cb_den, onetile);
     cb_wait_front(cb_batch_var, onetile);
-    cb_wait_front(cb_eps, onetile);

     add_binary_tile_init();
     rsqrt_tile_init();
@@ -86,7 +85,6 @@

     cb_push_back(cb_den, onetile);
     cb_pop_front(cb_batch_var, onetile);
-    cb_pop_front(cb_eps, onetile);

     // (input - batch_mean)/(sqrt(batch_var + eps)) = result
     cb_wait_front(cb_den, onetile);
@@ -202,6 +200,10 @@ void MAIN {

     uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
     uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
+
+    constexpr uint32_t onetile = 1;
+    cb_wait_front(cb_eps, onetile);
+
     for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
         batchnorm_bcast_tiles(
             cb_bcast,
@@ -236,8 +238,5 @@
             weight_has_value,
             bias_has_value);
     }
-
-    constexpr uint32_t onetile = 1;
-    constexpr int dst0 = 0;
 }
 }  // namespace NAMESPACE
