Skip to content

Commit

Permalink
Enable reduction tests
Browse files Browse the repository at this point in the history
* Reduction sum along 3 dims
* Reduction mean along 3 dims
  • Loading branch information
mmanzoorTT committed Feb 21, 2025
1 parent 227a18c commit f8fa59e
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions tests/torch/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,7 @@
([(4, 2, 32, 32)], [3], False, [torch.bfloat16], 0.02),
([(4, 2, 32, 32)], [0, 2], False, [torch.bfloat16], 0.035),
([(4, 2, 32, 32)], [0, 1, 2, 3], True, [torch.bfloat16], 0.13),
pytest.param(
[(4, 2, 32, 32)],
[0, 2, 3],
True,
[torch.bfloat16],
0.35,
marks=pytest.mark.xfail(
reason="Reduce on more than two dimensions is not currently supported by TTNN"
),
),
([(4, 2, 32, 32)], [0, 2, 3], True, [torch.bfloat16], 0.30),
],
)
def test_reduce_sum(input_shape, dim_arg, keep_dim, input_type, atol):
Expand Down Expand Up @@ -202,15 +193,7 @@ def forward(self, x):
([(4, 2, 32, 32)], [0, 2], False, [torch.bfloat16]),
([(4, 2, 32, 32)], [1, 2], True, [torch.bfloat16]),
([(4, 2, 32, 32)], [0, 1, 2, 3], True, [torch.bfloat16]),
pytest.param(
[(4, 2, 32, 32)],
[1, 2, 3],
True,
[torch.bfloat16],
marks=pytest.mark.xfail(
reason="Reduce on more than two dimensions is not currently supported by TTNN"
),
),
([(4, 2, 32, 32)], [1, 2, 3], True, [torch.bfloat16]),
],
)
def test_reduce_mean(input_shape, dim_arg, keep_dim, input_type):
Expand Down

0 comments on commit f8fa59e

Please sign in to comment.