Skip to content

Commit

Permalink
#12253: Add test for optional output tensor in BN
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Feb 14, 2025
1 parent 75429dc commit 5be82c8
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,23 @@ def test_batch_norm_qid(input_shapes, device):
torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data)
comp_BN_Output = compare_results_batch_norm([tt_output], [torch_result])
assert comp_BN_Output


@pytest.mark.parametrize(
"input_shapes",
[
torch.Size([2, 3, 120, 120]),
],
)
def test_batch_norm_output_Default(input_shapes, device):
N, H, W, C = input_shapes
_, tt_output_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device)

ttnn.batch_norm(input_tensor, running_mean=mean_tensor, running_var=var_tensor, queue_id=0, output=tt_output_tensor)
tt_output = ttnn.to_torch(tt_output_tensor)
torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data)
comp_BN_Output = compare_results_batch_norm([tt_output], [torch_result])
assert comp_BN_Output

0 comments on commit 5be82c8

Please sign in to comment.