Skip to content

Commit

Permalink
Add shape existence check in GatherElements shape inference logic (on…
Browse files Browse the repository at this point in the history
…nx#2402)

* Add shape existence check guard in GatherElements shape inference logic

* PR comments

* PR comments
  • Loading branch information
hariharans29 authored and gramalingam committed Oct 22, 2019
1 parent 192ad8c commit 3ea3b0e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
10 changes: 7 additions & 3 deletions onnx/defs/tensor/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,10 @@ ONNX_OPERATOR_SET_SCHEMA(
"Constrain indices to integer types")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 1, 0);
// propagate indices' shape to output if it exists
if (hasInputShape(ctx, 1)) {
propagateShapeFromInputToOutput(ctx, 1, 0);
}
}));

static const char* Squeeze_ver11_doc = R"DOC(
Expand Down Expand Up @@ -1370,7 +1373,8 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}

// sort after correcting negative axes values (if any) in the previous step
// sort after correcting negative axes values (if any) in the previous
// step
std::sort(axes.begin(), axes.end());

int j = 0;
Expand Down Expand Up @@ -2704,4 +2708,4 @@ ONNX_OPERATOR_SET_SCHEMA(
return;
}));

} // namespace ONNX_NAMESPACE
} // namespace ONNX_NAMESPACE
16 changes: 16 additions & 0 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2707,6 +2707,22 @@ def test_pad(self): # type: () -> None
initializer=[make_tensor('pads', TensorProto.INT64, (6,), (1, 3, 1, 1, 0, 1,))])
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (3, None, 4))]) # type: ignore

def test_gatherelements_basic(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (6,)),
('indices', TensorProto.INT64, (2,))],
[make_node('GatherElements', ['x', 'indices'], ['y'])],
[])
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2,))])

def test_gatherelements_indices_missing_shape(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (6,)),
('indices', TensorProto.INT64, None)], # type: ignore
[make_node('GatherElements', ['x', 'indices'], ['y'])],
[])
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, None)]) # type: ignore


if __name__ == '__main__':
unittest.main()

0 comments on commit 3ea3b0e

Please sign in to comment.