From 818a548eed32a3b9dc8c4aef410359078bd16a71 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Tue, 23 Jul 2024 21:07:09 -0700
Subject: [PATCH] Cleaned up, working TRT wrapping

Signed-off-by: Boris Fomitchev
---
 scripts/export.py           |  8 ++++----
 vista3d/modeling/vista3d.py | 40 +++++++++++++++++++------------------
 2 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/scripts/export.py b/scripts/export.py
index 59784d3..7f1a3c2 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -133,13 +133,13 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
                                         input_names = ['x'], output_names = ['x_out'])
-        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
-        # self.model.image_encoder.encoder.load_engine()
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
+        self.model.image_encoder.encoder.load_engine()
 
         cls_wrapper = ExportWrapper.wrap(self.model.class_head,
                                          input_names = ['src', 'class_vector'],
                                          output_names = ['masks', 'class_embedding'])
-        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
-        # self.model.class_head.load_engine()
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
+        self.model.class_head.load_engine()
         return
diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py
index 2cf9663..accfbde 100755
--- a/vista3d/modeling/vista3d.py
+++ b/vista3d/modeling/vista3d.py
@@ -302,15 +302,16 @@ def forward(
         ):
             out, out_auto = self.image_embeddings, None
         else:
-            # print(input_images.dtype)
-            self.image_encoder.encoder.build_and_save(
-                (input_images,),
-                dynamo=False,
-                verbose=False,
-                fp16=True, tf32=True,
-                builder_optimization_level=5,
-                enable_all_tactics=True
-            )
+            # Support for TRT wrapping
+            if hasattr(self.image_encoder.encoder, "build_and_save"):
+                self.image_encoder.encoder.build_and_save(
+                    (input_images,),
+                    dynamo=False,
+                    verbose=False,
+                    fp16=True, tf32=True,
+                    builder_optimization_level=5,
+                    enable_all_tactics=True
+                )
 
             time0 = time.time()
             out, out_auto = self.image_encoder(
@@ -325,19 +326,20 @@
 
         # force releasing memories that set to None
         torch.cuda.empty_cache()
         if class_vector is not None:
-            self.class_head.build_and_save(
-                (out_auto, class_vector,),
-                fp16=True, tf32=True,
-                dynamo=False,
-                verbose=False,
-            )
-            time2 = time.time()
+            if hasattr(self.class_head, "build_and_save"):
+                self.class_head.build_and_save(
+                    (out_auto, class_vector,),
+                    fp16=True, tf32=True,
+                    dynamo=False,
+                    verbose=False,
+                )
+            # time2 = time.time()
             logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
             # torch.cuda.synchronize()
             # print(f"Class Head Time: {time.time() - time2}")
         if point_coords is not None:
-            time3 = time.time()
+            # time3 = time.time()
             point_logits = self.point_head(
                 out, point_coords, point_labels, class_vector=prompt_class
             )
@@ -376,8 +378,8 @@
                 mapping_index,
             )
 
-        torch.cuda.synchronize()
-        # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
+        # torch.cuda.synchronize()
+        # print(f"Total time : {time.time() - time00} shape : {logits.shape}")
 
         if kwargs.get("keep_cache", False) and class_vector is None:
             self.image_embeddings = out.detach()
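
The hasattr guard added in vista3d.py is what makes the wrapping optional: a module
replaced by TRTWrapper exposes build_and_save, while the original nn.Module does not,
so the same forward path serves both the wrapped and unwrapped configurations. Below
is a minimal, self-contained sketch of that duck-typing pattern; StubTRTWrapper, its
build_and_save signature, and run_encoder are illustrative assumptions for this sketch,
not the repository's actual TRTWrapper API.

    # Minimal sketch of the optional-wrapping pattern used by the patch.
    # StubTRTWrapper stands in for the real TRTWrapper; only the duck-typed
    # build_and_save() hook matters here, not actual TensorRT engine building.
    import torch
    import torch.nn as nn


    class StubTRTWrapper(nn.Module):
        """Illustrative stand-in: wraps a module and exposes build_and_save()."""

        def __init__(self, name: str, module: nn.Module):
            super().__init__()
            self.name = name
            self.module = module
            self.engine_built = False

        def build_and_save(self, inputs, **build_kwargs):
            # A real wrapper would export to ONNX and build/load a TRT engine;
            # the stub only records that the one-time build was requested.
            if not self.engine_built:
                print(f"[{self.name}] building engine, options={build_kwargs}")
                self.engine_built = True

        def forward(self, x):
            return self.module(x)


    def run_encoder(encoder: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Same guard as the patch: plain modules lack build_and_save and skip
        # the build step; wrapped modules build their engine on first use.
        if hasattr(encoder, "build_and_save"):
            encoder.build_and_save((x,), fp16=True, tf32=True)
        return encoder(x)


    if __name__ == "__main__":
        plain = nn.Conv3d(1, 8, kernel_size=3, padding=1)
        x = torch.randn(1, 1, 16, 16, 16)
        run_encoder(plain, x)                             # no build step
        run_encoder(StubTRTWrapper("Encoder", plain), x)  # builds once, then runs

The export.py side of the patch pairs this with load_engine(), so a previously built
engine is reused instead of being rebuilt on first inference; the sketch omits that
caching step.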