Cleaned up, working TRT wrapping
Signed-off-by: Boris Fomitchev <[email protected]>
borisfom committed Jul 24, 2024
1 parent a82ce56 commit 818a548
Showing 2 changed files with 25 additions and 23 deletions.
8 changes: 4 additions & 4 deletions scripts/export.py
@@ -133,13 +133,13 @@ def __init__(self, config_file="./configs/infer.yaml", **override):

         en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
                                         input_names = ['x'], output_names = ['x_out'])
-        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
-        # self.model.image_encoder.encoder.load_engine()
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
+        self.model.image_encoder.encoder.load_engine()
 
         cls_wrapper = ExportWrapper.wrap(self.model.class_head,
                                         input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding'])
-        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
-        # self.model.class_head.load_engine()
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
+        self.model.class_head.load_engine()
 
         return

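For context, a minimal sketch of the wrap-then-load flow this file settles on: construct the wrapper (the explicit use_cuda_graph=False is dropped, leaving the wrapper's default), then eagerly call load_engine so a previously built engine is picked up at init. MockWrapper below is a hypothetical stand-in, not the real TRTWrapper; only the load_engine method name comes from the diff above.

# Hypothetical stand-in illustrating the wrap/load flow; MockWrapper is not
# the real TRTWrapper -- only the load_engine method name appears in the diff.
import os

class MockWrapper:
    def __init__(self, name, module):
        self.name, self.module, self.engine = name, module, None

    def load_engine(self):
        # The real wrapper would deserialize a saved TRT engine plan here.
        plan = f"{self.name}.plan"
        if os.path.exists(plan):
            self.engine = plan  # stand-in for the deserialized engine

    def __call__(self, *args):
        # Fall back to the wrapped module until an engine is available.
        return self.module(*args)

encoder = MockWrapper("Encoder", lambda x: x)
encoder.load_engine()   # no-op when Encoder.plan has not been built yet
print(encoder(42))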
40 changes: 21 additions & 19 deletions vista3d/modeling/vista3d.py
@@ -302,15 +302,16 @@ def forward(
         ):
             out, out_auto = self.image_embeddings, None
         else:
-            # print(input_images.dtype)
-            self.image_encoder.encoder.build_and_save(
-                (input_images,),
-                dynamo=False,
-                verbose=False,
-                fp16=True, tf32=True,
-                builder_optimization_level=5,
-                enable_all_tactics=True
-            )
+            # Support for TRT wrapping
+            if hasattr(self.image_encoder.encoder, "build_and_save"):
+                self.image_encoder.encoder.build_and_save(
+                    (input_images,),
+                    dynamo=False,
+                    verbose=False,
+                    fp16=True, tf32=True,
+                    builder_optimization_level=5,
+                    enable_all_tactics=True
+                )
 
         time0 = time.time()
         out, out_auto = self.image_encoder(
@@ -325,19 +326,20 @@
         # force releasing memories that set to None
         torch.cuda.empty_cache()
         if class_vector is not None:
-            self.class_head.build_and_save(
-                (out_auto, class_vector,),
-                fp16=True, tf32=True,
-                dynamo=False,
-                verbose=False,
-            )
-            time2 = time.time()
+            if hasattr(self.class_head, "build_and_save"):
+                self.class_head.build_and_save(
+                    (out_auto, class_vector,),
+                    fp16=True, tf32=True,
+                    dynamo=False,
+                    verbose=False,
+                )
+            # time2 = time.time()
             logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
             # torch.cuda.synchronize()
             # print(f"Class Head Time: {time.time() - time2}")
 
         if point_coords is not None:
-            time3 = time.time()
+            # time3 = time.time()
             point_logits = self.point_head(
                 out, point_coords, point_labels, class_vector=prompt_class
             )
@@ -376,8 +378,8 @@ def forward(
                 mapping_index,
             )
 
-        torch.cuda.synchronize()
-        # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
+        # torch.cuda.synchronize()
+        # print(f"Total time : {time.time() - time00} shape : {logits.shape}")
 
         if kwargs.get("keep_cache", False) and class_vector is None:
             self.image_embeddings = out.detach()
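As an aside, a minimal, self-contained sketch of the duck-typing guard this commit introduces in forward: build_and_save is called only when the submodule has actually been replaced by a TRT wrapper, so the same code path keeps working when wrapping is disabled. Both classes below are hypothetical stand-ins, not the real VISTA3D encoder or TRTWrapper; only the build_and_save name and the hasattr pattern come from the diff.

# Hypothetical stand-ins for illustration only -- neither class is the real
# VISTA3D encoder or TRTWrapper; only build_and_save comes from the diff.
class PlainEncoder:
    def __call__(self, x):
        return x + 1

class WrappedEncoder:
    def __init__(self, inner):
        self.inner = inner
        self.built = False

    def build_and_save(self, inputs, **builder_flags):
        # A real TRT wrapper would build and serialize an engine here.
        self.built = True

    def __call__(self, x):
        return self.inner(x)

def run(encoder, x):
    # Same guard as the diff: build only if the module was TRT-wrapped.
    if hasattr(encoder, "build_and_save"):
        encoder.build_and_save((x,), fp16=True, tf32=True)
    return encoder(x)

print(run(PlainEncoder(), 1))                  # plain path: no build step
print(run(WrappedEncoder(PlainEncoder()), 1))  # wrapped path: builds first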
