diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index fed91edc6..1a7f494d7 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -389,3 +389,18 @@ def forward(self, x): assert len(model.graph.input) == 1 model = run_model_test(Model(), (torch.rand((1,)),), keep_initializers_as_inputs=True) assert len(model.graph.input) == (2 if persistent else 1) + + +def test_ppe_map(): + torch.manual_seed(100) + + class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv = torch.nn.Conv2d(1, 1, 3) + + def forward(self, x): + y = self.conv(x) + return list(ppe.map(lambda u: u + 1, y))[0] + + run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03)