diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index fb9b9dcdd1..c5cc43df58 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -23,7 +23,7 @@ import torch from monai.apps.utils import get_logger -from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes from monai.utils.module import optional_import polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -517,7 +517,6 @@ def _build_and_save(self, model, input_example): elif self.precision == "bf16": enabled_precisions.append(torch.bfloat16) inputs = list(input_example.values()) - ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) def get_torch_trt_input(input_shape, dynamic_batchsize): min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) @@ -527,12 +526,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize): tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] engine_bytes = torch_tensorrt.convert_method_to_trt_engine( - ir_model, - "forward", - arg_inputs=tt_inputs, - ir="torchscript", - enabled_precisions=enabled_precisions, - **export_args, + model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args ) else: dbs = self.dynamic_batchsize diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 8007148888..2a15d5e697 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -68,33 +68,6 @@ def test_handler(self): net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) self.assertIsNotNone(net1._trt_compiler.engine) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_unet_value(self, precision): - model = UNet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=(2, 2, 4, 8, 4), - strides=(2, 2, 2, 2), - num_res_units=2, - norm="batch", - ).cuda() - with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: - model.eval() - input_example = torch.randn(2, 1, 96, 96, 96).cuda() - output_example = model(input_example) - args: dict = {"builder_optimization_level": 1} - trt_compile( - model, - f"{tmpdir}/test_unet_trt_compile", - args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, - ) - self.assertIsNone(model._trt_compiler.engine) - trt_output = model(input_example) - # Check that lazy TRT build succeeded - self.assertIsNotNone(model._trt_compiler.engine) - torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @unittest.skipUnless(has_sam, "Requires SAM installation") def test_cell_sam_wrapper_value(self, precision): @@ -106,7 +79,7 @@ def test_cell_sam_wrapper_value(self, precision): trt_compile( model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision}, ) self.assertIsNone(model._trt_compiler.engine) trt_output = model(input_example) @@ -124,7 +97,7 @@ def test_vista3d(self, precision): model = trt_compile( model, f"{tmpdir}/test_vista3d_trt_compile", - args={"precision": precision, "dynamic_batchsize": [1, 1, 1]}, + args={"precision": precision, "dynamic_batchsize": [1, 2, 4]}, submodule=["image_encoder.encoder", "class_head"], ) self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)