Fixing torch_trt compile and test case
Signed-off-by: Boris Fomitchev <[email protected]>
borisfom committed Nov 1, 2024
1 parent 8baaa74 commit 214def9
Showing 2 changed files with 4 additions and 37 deletions.
10 changes: 2 additions & 8 deletions monai/networks/trt_compiler.py
@@ -23,7 +23,7 @@
import torch

from monai.apps.utils import get_logger
- from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes
+ from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes
from monai.utils.module import optional_import

polygraphy, polygraphy_imported = optional_import("polygraphy")
@@ -517,7 +517,6 @@ def _build_and_save(self, model, input_example):
elif self.precision == "bf16":
enabled_precisions.append(torch.bfloat16)
inputs = list(input_example.values())
- ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True)

def get_torch_trt_input(input_shape, dynamic_batchsize):
min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
@@ -527,12 +526,7 @@ def get_torch_trt_input(input_shape, dynamic_batchsize):

tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
- ir_model,
- "forward",
- arg_inputs=tt_inputs,
- ir="torchscript",
- enabled_precisions=enabled_precisions,
- **export_args,
+ model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args
)
else:
dbs = self.dynamic_batchsize
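
The trt_compiler.py change removes the TorchScript intermediate: the model is no longer traced with convert_to_torchscript and converted with ir="torchscript"; the nn.Module is now handed directly to torch_tensorrt.convert_method_to_trt_engine. Below is a minimal sketch of the resulting conversion path. It assumes a torch_tensorrt release whose convert_method_to_trt_engine accepts a module plus arg_inputs, and the helper name build_trt_engine_bytes is made up for illustration.

    # Illustrative sketch only; build_trt_engine_bytes is a hypothetical helper name.
    import torch
    import torch_tensorrt

    from monai.networks.utils import get_profile_shapes

    def build_trt_engine_bytes(model, input_example, dynamic_batchsize, precision="fp16", **export_args):
        # input_example is a dict of named example tensors, as in TrtCompiler._build_and_save above.
        enabled_precisions = [torch.float32]
        if precision == "fp16":
            enabled_precisions.append(torch.float16)
        elif precision == "bf16":
            enabled_precisions.append(torch.bfloat16)
        inputs = list(input_example.values())

        def get_torch_trt_input(input_shape, dynamic_batchsize):
            # get_profile_shapes expands one static shape into (min, opt, max) profile shapes.
            min_shape, opt_shape, max_shape = get_profile_shapes(input_shape, dynamic_batchsize)
            return torch_tensorrt.Input(min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape)

        tt_inputs = [get_torch_trt_input(i.shape, dynamic_batchsize) for i in inputs]
        # The module goes straight to torch_tensorrt; no convert_to_torchscript / ir="torchscript" step.
        return torch_tensorrt.convert_method_to_trt_engine(
            model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args
        )
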
31 changes: 2 additions & 29 deletions tests/test_trt_compile.py
@@ -68,33 +68,6 @@ def test_handler(self):
net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
self.assertIsNotNone(net1._trt_compiler.engine)

- @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
- def test_unet_value(self, precision):
- model = UNet(
- spatial_dims=3,
- in_channels=1,
- out_channels=2,
- channels=(2, 2, 4, 8, 4),
- strides=(2, 2, 2, 2),
- num_res_units=2,
- norm="batch",
- ).cuda()
- with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
- model.eval()
- input_example = torch.randn(2, 1, 96, 96, 96).cuda()
- output_example = model(input_example)
- args: dict = {"builder_optimization_level": 1}
- trt_compile(
- model,
- f"{tmpdir}/test_unet_trt_compile",
- args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]},
- )
- self.assertIsNone(model._trt_compiler.engine)
- trt_output = model(input_example)
- # Check that lazy TRT build succeeded
- self.assertIsNotNone(model._trt_compiler.engine)
- torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@unittest.skipUnless(has_sam, "Requires SAM installation")
def test_cell_sam_wrapper_value(self, precision):
@@ -106,7 +79,7 @@ def test_cell_sam_wrapper_value(self, precision):
trt_compile(
model,
f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
args={"precision": precision},
)
self.assertIsNone(model._trt_compiler.engine)
trt_output = model(input_example)
@@ -124,7 +97,7 @@ def test_vista3d(self, precision):
model = trt_compile(
model,
f"{tmpdir}/test_vista3d_trt_compile",
args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
args={"precision": precision, "dynamic_batchsize": [1, 2, 4]},
submodule=["image_encoder.encoder", "class_head"],
)
self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)
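
With test_unet_value removed, the remaining tests still assert the lazy-build contract: trt_compile() only attaches a _trt_compiler to the model (engine is None), and the TensorRT engine is built on the first forward call. A minimal usage sketch of that behavior follows; the UNet configuration, input shape, and "fp16" precision are illustrative placeholders rather than values from the test suite, and it assumes trt_compile is importable from monai.networks.trt_compiler as modified above.

    # Sketch of the lazy TRT build behavior the tests assert; model and shapes are placeholders.
    import tempfile

    import torch

    from monai.networks.nets import UNet
    from monai.networks.trt_compiler import trt_compile

    model = UNet(
        spatial_dims=3, in_channels=1, out_channels=2, channels=(2, 2, 4, 8, 4), strides=(2, 2, 2, 2)
    ).cuda().eval()
    with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
        trt_compile(model, f"{tmpdir}/unet_trt", args={"precision": "fp16", "dynamic_batchsize": [1, 2, 4]})
        assert model._trt_compiler.engine is None           # nothing built yet
        out = model(torch.randn(2, 1, 96, 96, 96).cuda())   # first call triggers the TensorRT build
        assert model._trt_compiler.engine is not None       # engine is now cached on the compiler
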
