diff --git a/forge/forge/op/eval/forge/tm.py b/forge/forge/op/eval/forge/tm.py index f90efaefb..0e580d90a 100644 --- a/forge/forge/op/eval/forge/tm.py +++ b/forge/forge/op/eval/forge/tm.py @@ -1080,6 +1080,7 @@ def decompose(type, attr, dc, inputs): dc.fuse(result) return + if type == "adv_index": dim = attr[0] in0_shape = inputs[0].shape @@ -1101,6 +1102,7 @@ def decompose(type, attr, dc, inputs): ) dc.fuse(result) return + if type == "pad": if all([x == 0 for x in attr[0:-2]]): # Pad size is 0 diff --git a/forge/test/mlir/operators/indexing/test_scatter_ops.py b/forge/test/mlir/operators/indexing/test_scatter_ops.py index b461e06ad..f564c7a28 100644 --- a/forge/test/mlir/operators/indexing/test_scatter_ops.py +++ b/forge/test/mlir/operators/indexing/test_scatter_ops.py @@ -16,7 +16,47 @@ torch.zeros(10, dtype=torch.float32), # 1D input tensor torch.tensor([True, False, True, False, True, False, True, False, True, False]), # Mask torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), # Source tensor - id="1d_masked_scatter", + id="test_masked_scatter_1", + marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), + ), + # Less Number of elements in source + pytest.param( + torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,) + torch.tensor([True, False, True, False, True], dtype=torch.bool), # mask shape = (5,) + torch.tensor([10, 20, 30], dtype=torch.float32), # source shape = (3,) + id="test_masked_scatter_2", + marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), + ), + # Broadcasting: 1D input and mask with 2D source tensor + pytest.param( + torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,) + torch.tensor([True, False, True, False, True], dtype=torch.bool), # mask shape = (5,) + torch.tensor([[10], [20], [30]], dtype=torch.float32), # source shape = (3, 1) + id="test_masked_scatter_3", + marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), + ), + # 2D tensors where mask has a different shape from input tensor but can be broadcasted + pytest.param( + torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), # input_tensor shape = (3, 2) + torch.tensor([[True, False], [False, True], [True, True]], dtype=torch.bool), # mask shape = (3, 2) + torch.tensor([10, 20, 30, 40], dtype=torch.float32), # source shape = (4,) + id="test_masked_scatter_4", + marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), + ), + # Test with a mask of all False (nothing should be replaced) + pytest.param( + torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,) + torch.tensor([False, False, False, False, False], dtype=torch.bool), # mask shape = (5,) + torch.tensor([10, 20, 30], dtype=torch.float32), # source shape = (3,) + id="test_masked_scatter_5", + marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), + ), + # Test with a mask of all True (everything should be replaced) + pytest.param( + torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,) + torch.tensor([True, True, True, True, True], dtype=torch.bool), # mask shape = (5,) + torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32), # source shape = (5,) + id="test_masked_scatter_6", marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), ), pytest.param( @@ -30,7 +70,7 @@ ] ), torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), # Source tensor - id="2d_masked_scatter", + id="test_masked_scatter_7", marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"), ), ],