diff --git a/onnx/defs/tensor/defs.cc b/onnx/defs/tensor/defs.cc index 795d28a5d1a..4cde3e0d690 100644 --- a/onnx/defs/tensor/defs.cc +++ b/onnx/defs/tensor/defs.cc @@ -1227,7 +1227,10 @@ ONNX_OPERATOR_SET_SCHEMA( "Constrain indices to integer types") .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); - propagateShapeFromInputToOutput(ctx, 1, 0); + // propagate indices' shape to output if it exists + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } })); static const char* Squeeze_ver11_doc = R"DOC( @@ -1370,7 +1373,8 @@ ONNX_OPERATOR_SET_SCHEMA( } } - // sort after correcting negative axes values (if any) in the previous step + // sort after correcting negative axes values (if any) in the previous + // step std::sort(axes.begin(), axes.end()); int j = 0; @@ -2704,4 +2708,4 @@ ONNX_OPERATOR_SET_SCHEMA( return; })); -} // namespace ONNX_NAMESPACE \ No newline at end of file +} // namespace ONNX_NAMESPACE diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 7add1089b1d..43f0b621ad8 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -2707,6 +2707,22 @@ def test_pad(self): # type: () -> None initializer=[make_tensor('pads', TensorProto.INT64, (6,), (1, 3, 1, 1, 0, 1,))]) self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (3, None, 4))]) # type: ignore + def test_gatherelements_basic(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (6,)), + ('indices', TensorProto.INT64, (2,))], + [make_node('GatherElements', ['x', 'indices'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2,))]) + + def test_gatherelements_indices_missing_shape(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (6,)), + ('indices', TensorProto.INT64, None)], # type: ignore + [make_node('GatherElements', ['x', 'indices'], ['y'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, None)]) # type: ignore + if __name__ == '__main__': unittest.main()