# test_pytorch_onnx_shape_inference.py
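#
# Runs the ONNX Runtime export tests with ONNX shape inference enabled and
# verifies that every node in the exported graph has a scalar type and shape
# assigned. It does so by replacing test_pytorch_onnx_onnxruntime.run_model_test
# with the shape-checking variant defined below.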
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import torch
import copy

import test_pytorch_onnx_onnxruntime
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
from torch.onnx import utils, OperatorExportTypes, TrainingMode
from torch.onnx.utils import _validate_dynamic_axes
from torch.onnx.symbolic_helper import (_set_opset_version, _set_operator_export_type,
                                        _set_onnx_shape_inference, _set_training_mode,
                                        _is_tensor_list, _is_tensor, _is_none)


def verify_inferred_shape(graph):
    # Check every node in graph has type properly assigned.
    for n in graph.nodes():
        for out in n.outputs():
            if not _is_tensor_list(out) and not _is_tensor(out) and not _is_none(out):
                raise RuntimeError("Output of node is neither type Tensor nor type list of Tensor: ", out)
            if _is_tensor(out) and out.type().scalarType() is None:
                raise RuntimeError("Output of node does not have type assigned", out)
            if _is_tensor(out) and out.type().dim() is None:
                raise RuntimeError("Output of node does not have shape assigned", out)


def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None,
                   fixed_batch_size=False):
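    # Traces the model to an ONNX graph with shape inference enabled and checks
    # that every node output has an inferred type and shape. The signature mirrors
    # the run_model_test helper in test_pytorch_onnx_onnxruntime so that this
    # function can be swapped in for it; several parameters (e.g. state_dict,
    # use_gpu, rtol, atol) are accepted only for compatibility and are unused here.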
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)

        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        if isinstance(output, torch.Tensor):
            output = (output,)

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _set_onnx_shape_inference(True)
        _set_training_mode(False)
        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        input_copy = copy.deepcopy(input)
        graph, _, _ = utils._model_to_graph(model, input_copy,
                                            input_names=input_names,
                                            output_names=output_names,
                                            operator_export_type=OperatorExportTypes.ONNX,
                                            example_outputs=output,
                                            do_constant_folding=do_constant_folding,
                                            training=TrainingMode.EVAL,
                                            use_new_jit_passes=self.use_new_jit_passes,
                                            dynamic_axes=dynamic_axes)
        verify_inferred_shape(graph)
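

# Point the ONNX Runtime test suite at the shape-checking run_model_test above,
# so that every TestONNXRuntime case also exercises ONNX shape inference when
# this file is run directly (python test_pytorch_onnx_shape_inference.py).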
if __name__ == '__main__':
    TestONNXRuntime.opset_version = 12
    test_pytorch_onnx_onnxruntime.run_model_test = run_model_test
    unittest.main()