Skip to content

Commit

Permalink
[Test] add fork cases to RemoveIdentityOps test
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar committed Sep 12, 2024
1 parent 0a4d5c5 commit 2d09341
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions tests/transformation/test_remove_identity_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx):
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
def test_remove_identity_ops(op, as_first_node, approx):
@pytest.mark.parametrize("fork_before_id", [False, True])
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
# set up onnx model
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
Expand Down Expand Up @@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
idict = {"inp": inp_values}
odict = oxe.execute_onnx(model, idict)
out_before = odict["outp"]
odict_before = oxe.execute_onnx(model, idict)
num_of_nodes_before = len(model.graph.node)

if fork_before_id and not as_first_node:
divout_vi = model.get_tensor_valueinfo("div_out")
model.graph.output.append(divout_vi)
model.graph.value_info.remove(divout_vi)
model = model.transform(RemoveIdentityOps())
num_of_nodes_after = len(model.graph.node)
assert num_of_nodes_before - 1 == num_of_nodes_after

odict = oxe.execute_onnx(model, idict)
out_after = odict["outp"]
assert np.isclose(out_before, out_after, atol=1e-3).all()
odict_after = oxe.execute_onnx(model, idict)
outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()]
assert all(outputs_same)

0 comments on commit 2d09341

Please sign in to comment.