Skip to content

Commit

Permalink
Update tm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ashokkumarkannan1 committed Feb 25, 2025
1 parent 6efb2c5 commit 0d6ee4a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
2 changes: 2 additions & 0 deletions forge/forge/op/eval/forge/tm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,7 @@ def decompose(type, attr, dc, inputs):

dc.fuse(result)
return

if type == "adv_index":
dim = attr[0]
in0_shape = inputs[0].shape
Expand All @@ -1101,6 +1102,7 @@ def decompose(type, attr, dc, inputs):
)
dc.fuse(result)
return

if type == "pad":
if all([x == 0 for x in attr[0:-2]]):
# Pad size is 0
Expand Down
44 changes: 42 additions & 2 deletions forge/test/mlir/operators/indexing/test_scatter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,47 @@
torch.zeros(10, dtype=torch.float32), # 1D input tensor
torch.tensor([True, False, True, False, True, False, True, False, True, False]), # Mask
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), # Source tensor
id="1d_masked_scatter",
id="test_masked_scatter_1",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
# Less Number of elements in source
pytest.param(
torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,)
torch.tensor([True, False, True, False, True], dtype=torch.bool), # mask shape = (5,)
torch.tensor([10, 20, 30], dtype=torch.float32), # source shape = (3,)
id="test_masked_scatter_2",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
# Broadcasting: 1D input and mask with 2D source tensor
pytest.param(
torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,)
torch.tensor([True, False, True, False, True], dtype=torch.bool), # mask shape = (5,)
torch.tensor([[10], [20], [30]], dtype=torch.float32), # source shape = (3, 1)
id="test_masked_scatter_3",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
# 2D tensors where mask has a different shape from input tensor but can be broadcasted
pytest.param(
torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32), # input_tensor shape = (3, 2)
torch.tensor([[True, False], [False, True], [True, True]], dtype=torch.bool), # mask shape = (3, 2)
torch.tensor([10, 20, 30, 40], dtype=torch.float32), # source shape = (4,)
id="test_masked_scatter_4",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
# Test with a mask of all False (nothing should be replaced)
pytest.param(
torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,)
torch.tensor([False, False, False, False, False], dtype=torch.bool), # mask shape = (5,)
torch.tensor([10, 20, 30], dtype=torch.float32), # source shape = (3,)
id="test_masked_scatter_5",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
# Test with a mask of all True (everything should be replaced)
pytest.param(
torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32), # input_tensor shape = (5,)
torch.tensor([True, True, True, True, True], dtype=torch.bool), # mask shape = (5,)
torch.tensor([10, 20, 30, 40, 50], dtype=torch.float32), # source shape = (5,)
id="test_masked_scatter_6",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
pytest.param(
Expand All @@ -30,7 +70,7 @@
]
),
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), # Source tensor
id="2d_masked_scatter",
id="test_masked_scatter_7",
marks=pytest.mark.xfail(reason="RuntimeError: users.size() > 0"),
),
],
Expand Down

0 comments on commit 0d6ee4a

Please sign in to comment.