From 7564aaa8f1ed50ca29cb2ac14cb9f0602e30b8d9 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Tue, 12 Nov 2024 20:08:34 +0000 Subject: [PATCH] Added remainder op tests --- tests/torch/test_basic.py | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/torch/test_basic.py b/tests/torch/test_basic.py index 26c6027b..2d086a4d 100644 --- a/tests/torch/test_basic.py +++ b/tests/torch/test_basic.py @@ -370,3 +370,47 @@ def forward(self, x): verify_module( Basic(), input_shapes=[(256, 256)], compiler_config=cc, do_assert=False ) + + +@pytest.mark.parametrize( + ("input_range", "input_shapes"), + [ + ((1, 10), [(32, 32), (32, 32)]), + pytest.param( + (1, 10), + [(3, 3), (3, 3)], + marks=pytest.mark.skip( + reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15131" + ), + ), + pytest.param( + (1, 100), + [(32, 32), (32, 32)], + marks=pytest.mark.skip( + reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130" + ), + ), + pytest.param( + (-100, 100), + [(32, 32), (32, 32)], + marks=pytest.mark.skip( + reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130" + ), + ), + ], +) +def test_remainder_op(input_range, input_shapes): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.remainder(x, y) + + verify_module( + Basic(), + input_shapes=input_shapes, + input_data_types=[torch.float32, torch.float32], + input_range=input_range, + required_atol=1, + )