Skip to content

Commit

Permalink
Added remainder op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Nov 16, 2024
1 parent 4d55886 commit fd3d2f5
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,47 @@ def forward(self, x):
verify_module(
Basic(), input_shapes=[(256, 256)], compiler_config=cc, do_assert=False
)


@pytest.mark.parametrize(
("input_range", "input_shapes"),
[
((1, 10), [(32, 32), (32, 32)]),
pytest.param(
(1, 10),
[(3, 3), (3, 3)],
marks=pytest.mark.xfail(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15131"
),
),
pytest.param(
(1, 100),
[(32, 32), (32, 32)],
marks=pytest.mark.xfail(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130"
),
),
pytest.param(
(-100, 100),
[(32, 32), (32, 32)],
marks=pytest.mark.xfail(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130"
),
),
],
)
def test_remainder_op(input_range, input_shapes):
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.remainder(x, y)

verify_module(
Basic(),
input_shapes=input_shapes,
input_data_types=[torch.float32, torch.float32],
input_range=input_range,
required_atol=1,
)

0 comments on commit fd3d2f5

Please sign in to comment.