From 01ed8ab4e28af88f2a43657d1d66d6425a6eed41 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jul 2024 04:12:20 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- data/README.md | 2 +- scripts/export.py | 24 ++- scripts/utils/cast_utils.py | 11 +- scripts/utils/export_utils.py | 106 ++++++--- scripts/utils/trt_utils.py | 367 ++++++++++++++++++-------------- vista3d/modeling/segresnetds.py | 8 +- vista3d/modeling/vista3d.py | 28 +-- 7 files changed, 322 insertions(+), 224 deletions(-) diff --git a/data/README.md b/data/README.md index 3fbdde0..03354c3 100644 --- a/data/README.md +++ b/data/README.md @@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds to one dataset. ##### 2. Add label_dict.json and label_mapping.json -Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. +Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. ## SupverVoxel Generation 1. Download the segment anything repo and download the ViT-H weights diff --git a/scripts/export.py b/scripts/export.py index 7f1a3c2..a2020e1 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -12,6 +12,7 @@ import logging import os import sys +import time from functools import partial import monai @@ -32,7 +33,6 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point from .utils.trt_utils import ExportWrapper, TRTWrapper -import time rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -131,16 +131,20 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.prev_mask = None self.batch_data = None - en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder, - input_names = ['x'], output_names = ['x_out']) + en_wrapper = ExportWrapper.wrap( + self.model.image_encoder.encoder, input_names=["x"], output_names=["x_out"] + ) self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper) self.model.image_encoder.encoder.load_engine() - cls_wrapper = ExportWrapper.wrap(self.model.class_head, - input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding']) + cls_wrapper = ExportWrapper.wrap( + self.model.class_head, + input_names=["src", "class_vector"], + output_names=["masks", "class_embedding"], + ) self.model.class_head = TRTWrapper("ClassHead", cls_wrapper) self.model.class_head.load_engine() - + return def clear_cache(self): @@ -174,7 +178,7 @@ def infer( used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. """ - time00=time.time() + time00 = time.time() self.model.eval() if not isinstance(image_file, dict): image_file = {"image": image_file} @@ -277,7 +281,7 @@ def infer( @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): - time00=time.time() + time00 = time.time() self.model.eval() device = f"cuda:{rank}" if not isinstance(image_file, dict): @@ -344,8 +348,8 @@ def batch_infer_everything(self, datalist=str, basedir=str): if __name__ == "__main__": try: - #import torch_onnx - #torch_onnx.patch_torch(error_report=True) + # import torch_onnx + # torch_onnx.patch_torch(error_report=True) print("patch succeeded") except Exception: pass diff --git a/scripts/utils/cast_utils.py b/scripts/utils/cast_utils.py index ff58dde..329033e 100644 --- a/scripts/utils/cast_utils.py +++ b/scripts/utils/cast_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, @@ -26,6 +26,7 @@ import torch + def avoid_bfloat16_autocast_context(): """ If the current autocast context is bfloat16, @@ -70,7 +71,9 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) return new_dict elif isinstance(x, tuple): - return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + return tuple( + cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x + ) class CastToFloat(torch.nn.Module): @@ -92,5 +95,7 @@ def __init__(self, mod): def forward(self, *args): from_dtype = args[0].dtype with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + ret = self.mod.forward( + *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32) + ) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/scripts/utils/export_utils.py b/scripts/utils/export_utils.py index d09cfce..b7eecbd 100644 --- a/scripts/utils/export_utils.py +++ b/scripts/utils/export_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, @@ -22,16 +22,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from contextlib import nullcontext -from enum import Enum from typing import Callable, Dict, Optional, Type -import logging + import torch import torch.nn as nn import torch.nn.functional as F -from .cast_utils import CastToFloat, CastToFloatAll +from .cast_utils import CastToFloat + class LinearWithBiasSkip(nn.Module): def __init__(self, weight, bias, skip_bias_add): @@ -45,7 +43,10 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None -def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01): + +def run_ts_and_compare( + ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01 +): # Verify the model can be read, and is valid ts_out = ts_model(*ts_input_list, **ts_input_dict) @@ -54,16 +55,20 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c expected = output_example[i] if torch.is_tensor(expected): - tout = out.to('cpu') + tout = out.to("cpu") print(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: - if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): + if not torch.allclose( + tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance + ): this_good = False except Exception: # there may ne size mismatch and it may be OK this_good = False if not this_good: - print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") + print( + f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}" + ) all_good = False return all_good @@ -80,12 +85,19 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): print(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: - if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): + if not torch.allclose( + tout, + expected.cpu(), + rtol=check_tolerance, + atol=100 * check_tolerance, + ): this_good = False except Exception: # there may ne size mismatch and it may be OK this_good = False if not this_good: - print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + print( + f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}" + ) all_good = False return all_good @@ -96,7 +108,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): from apex.contrib.layer_norm.layer_norm import FastLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear + from apex.transformer.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: """ @@ -115,7 +130,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: else: return None - mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype) + mod = nn.LayerNorm( + shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype + ) n_state = n.state_dict() mod.load_state_dict(n_state) return mod @@ -129,7 +146,9 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: Equivalent LayerNorm module """ if not isinstance(n, RowParallelLinear): - raise ValueError("This function can only change the RowParallelLinear module.") + raise ValueError( + "This function can only change the RowParallelLinear module." + ) dev = next(n.parameters()).device mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) @@ -146,8 +165,12 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: Returns: Equivalent Linear module """ - if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)): - raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.") + if not ( + isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear) + ): + raise ValueError( + "This function can only change the ColumnParallelLinear or RowParallelLinear module." + ) dev = next(n.parameters()).device mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) @@ -165,11 +188,19 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: Equivalent LayerNorm module """ if not isinstance(n, FusedScaleMaskSoftmax): - raise ValueError("This function can only change the FusedScaleMaskSoftmax module.") + raise ValueError( + "This function can only change the FusedScaleMaskSoftmax module." + ) # disable the fusion only mod = FusedScaleMaskSoftmax( - n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, ) return mod @@ -178,18 +209,20 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: "FusedLayerNorm": replace_FusedLayerNorm, "MixedFusedLayerNorm": replace_FusedLayerNorm, "FastLayerNorm": replace_FusedLayerNorm, - "ESM1bLayerNorm" : replace_FusedLayerNorm, + "ESM1bLayerNorm": replace_FusedLayerNorm, "RowParallelLinear": replace_ParallelLinear, "ColumnParallelLinear": replace_ParallelLinear, "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, } -except Exception as e: +except Exception: default_Apex_replacements = {} apex_available = False -def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def simple_replace( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. Args: @@ -218,18 +251,28 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: exportable module """ # including the import here to avoid circular imports - from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax + from nemo.collections.nlp.modules.common.megatron.fused_softmax import ( + MatchedScaleMaskSoftmax, + ) # disabling fusion for the MatchedScaleMaskSoftmax mod = MatchedScaleMaskSoftmax( - n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, ) return mod -def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def wrap_module( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace BaseT module with DestT wrapper. Args: BaseT : module type to replace DestT : destination module type @@ -256,14 +299,15 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): expanded_path = path.split(".") parent_mod = model for sub_path in expanded_path[:-1]: - parent_mod = parent_mod._modules[sub_path] # noqa - parent_mod._modules[expanded_path[-1]] = new_mod # noqa + parent_mod = parent_mod._modules[sub_path] + parent_mod._modules[expanded_path[-1]] = new_mod return model def replace_modules( - model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None + model: nn.Module, + expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None, ) -> nn.Module: """ Top-level function to replace modules in model, specified by class name with a desired replacement. @@ -308,7 +352,7 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: if apex_available: print("Replacing Apex layers ...") replace_modules(model, default_Apex_replacements) - + if do_cast: print("Adding casts around norms...") cast_replacements = { @@ -319,6 +363,6 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), } replace_modules(model, cast_replacements) - + # This one has to be the last replace_modules(model, script_replacements) diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index 6275a0e..515d4a0 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, @@ -26,65 +26,69 @@ # limitations under the License. # -from collections import OrderedDict -from typing import List -from copy import copy -import numpy as np import os import pickle -from PIL import Image -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.logger import G_LOGGER as L_ +import threading +from collections import OrderedDict -import random -from scipy import integrate import tensorrt as trt import torch -import traceback - -from io import BytesIO from cuda import cudart -from enum import Enum, auto - -import threading +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx +from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + TrtRunner, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.logger import G_LOGGER as L_ # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # trt.init_libnvinfer_plugins(TRT_LOGGER, '') lock_sm = threading.Lock() + @torch.jit.script def check_m(m): t = torch.isnan(m) return not torch.any(t) + # Map of torch dtype -> numpy dtype trt_to_torch_dtype_dict = { - trt.int32 : torch.int32, + trt.int32: torch.int32, trt.float32: torch.float32, trt.float16: torch.float16, - trt.bfloat16 : torch.float16, - trt.int64 : torch.int64, - trt.int8 : torch.int8, - trt.bool : torch.bool, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, } + def CUASSERT(cuda_ret): err = cuda_ret[0] if err != 0: - raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) if len(cuda_ret) > 1: return cuda_ret[1] return None + class ShapeException(Exception): pass -class Engine(): + +class Engine: def __init__( self, engine_path, @@ -93,24 +97,32 @@ def __init__( self.engine = None self.context = None self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph - - def build(self, onnx_path, - profiles=[], fp16=False, bf16=False, tf32=True, - builder_optimization_level=3, - enable_all_tactics=True, - direct_io=False, - timing_cache=None, - update_output_names=None): + self.cuda_graph_instance = None # cuda graph + + def build( + self, + onnx_path, + profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + enable_all_tactics=True, + direct_io=False, + timing_cache=None, + update_output_names=None, + ): L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") config_kwargs = { - 'builder_optimization_level' : builder_optimization_level, - 'direct_io' : direct_io, + "builder_optimization_level": builder_optimization_level, + "direct_io": direct_io, } if not enable_all_tactics: - config_kwargs['tactic_sources'] = [] + config_kwargs["tactic_sources"] = [] - network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + network = network_from_onnx_path( + onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] + ) if update_output_names: L_.info(f"Updating network outputs to {update_output_names}") network = ModifyNetworkOutputs(network, update_output_names) @@ -118,22 +130,22 @@ def build(self, onnx_path, L_.info("Calling engine_from_network...") engine = engine_from_network( - network, - config=CreateConfig( - fp16=fp16, - bf16=bf16, - tf32=tf32, - profiles=profiles, - load_timing_cache=timing_cache, - **config_kwargs - ), - save_timing_cache=timing_cache + network, + config=CreateConfig( + fp16=fp16, + bf16=bf16, + tf32=tf32, + profiles=profiles, + load_timing_cache=timing_cache, + **config_kwargs, + ), + save_timing_cache=timing_cache, ) self.engine = engine - + def save(self): save_engine(self.engine, path=self.engine_path) - + def load(self): L_.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) @@ -155,18 +167,20 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num + self.cur_profile = profile_num # L_.info(self.input_names) # L_.info(self.output_names) - - def allocate_buffers(self, device): + + def allocate_buffers(self, device): # allocate outputs e = self.engine ctx = self.context - + for i, binding in enumerate(self.output_names): - shape=ctx.get_tensor_shape(binding) - t = torch.empty(list(shape), dtype=self.dtypes[i], device=device).contiguous() + shape = ctx.get_tensor_shape(binding) + t = torch.empty( + list(shape), dtype=self.dtypes[i], device=device + ).contiguous() self.tensors[binding] = t ctx.set_tensor_address(binding, t.data_ptr()) @@ -180,41 +194,39 @@ def check_shape(shape, profile): if s < minlist[i] or s > maxlist[i]: good = False return good - - def set_inputs(self, feed_dict, stream): + + def set_inputs(self, feed_dict, stream): e = self.engine ctx = self.context last_profile = self.cur_profile - + def try_set_inputs(): - for binding, t in feed_dict.items(): + for binding, t in feed_dict.items(): if t is not None: t = t.contiguous() shape = t.shape # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) # if not self.check_shape(shape, mincurmax): - # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") + # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") ctx.set_input_shape(binding, shape) ctx.set_tensor_address(binding, t.data_ptr()) while True: try: try_set_inputs() - break; + break except ShapeException: - next_profile = (self.cur_profile+1) % e.num_optimization_profiles + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles if next_profile == last_profile: raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) # torch.cuda.synchronize() - - left = ctx.infer_shapes() - assert len(left)==0 + left = ctx.infer_shapes() + assert len(left) == 0 - - def infer(self, stream, use_cuda_graph=False): + def infer(self, stream, use_cuda_graph=False): e = self.engine ctx = self.context if use_cuda_graph: @@ -225,19 +237,26 @@ def infer(self, stream, use_cuda_graph=False): # do inference before CUDA graph capture noerror = self.context.execute_async_v3(stream) if not noerror: - raise ValueError(f"ERROR: inference failed.") + raise ValueError("ERROR: inference failed.") # capture cuda graph - CUASSERT(cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)) + CUASSERT( + cudart.cudaStreamBeginCapture( + stream, + cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal, + ) + ) self.context.execute_async_v3(stream) graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) - self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(graph, 0)) + self.cuda_graph_instance = CUASSERT( + cudart.cudaGraphInstantiate(graph, 0) + ) print("CUDA Graph captured!") else: noerror = self.context.execute_async_v3(stream) CUASSERT(cudart.cudaStreamSynchronize(stream)) if not noerror: - raise ValueError(f"ERROR: inference failed.") - + raise ValueError("ERROR: inference failed.") + return self.tensors @@ -245,10 +264,8 @@ class ExportWrapper(torch.nn.Module): """ An auxiliary class to facilitate ONNX->TRT export of a module """ - def __init__(self, model, - input_names=None, - output_names=None, - precision="fp32"): + + def __init__(self, model, input_names=None, output_names=None, precision="fp32"): super().__init__() self.input_names = input_names self.output_names = output_names @@ -256,13 +273,13 @@ def __init__(self, model, self.model = model self.precision = precision - + def get_export_obj(self): return self.model def sample_profile(self, min_len=None, max_len=None): return None - + def can_handle(self, **args): return True @@ -271,17 +288,19 @@ def wrap(cls, model, **args): wrapper = cls(model, **args) return wrapper - + @torch.jit.script def no_nans(m): t = torch.isnan(m) return not torch.any(t) + class TRTWrapper(torch.nn.Module): """ An auxiliary class to implement running of TRT optimized engines - + """ + def __init__(self, path, exp, use_cuda_graph=False): super().__init__() self.exp_wrapper = None @@ -291,26 +310,29 @@ def __init__(self, path, exp, use_cuda_graph=False): self.jit_model = None self.onnx_runner = None self.path = path - self.use_cuda_graph=use_cuda_graph + self.use_cuda_graph = use_cuda_graph if exp is not None: self.attach(exp) @property def engine_path(self): - return self.path + '.plan' + return self.path + ".plan" + @property def jit_path(self): - return self.path + '.ts' + return self.path + ".ts" + @property def onnx_path(self): - return self.path + '.onnx' + return self.path + ".onnx" + @property def profiles_path(self): - return self.path + '.profiles.pkl' + return self.path + ".profiles.pkl" def has_engine(self): return self.engine is not None - + def has_onnx(self): return os.path.exists(self.onnx_path) @@ -322,12 +344,12 @@ def has_profiles(self): def load_engine(self): try: - engine=Engine(self.engine_path) + engine = Engine(self.engine_path) engine.load() engine.activate() self.engine = engine except Exception as e: - print (f"Exception while loading the engine:\n{e}") + print(f"Exception while loading the engine:\n{e}") pass def load_jit(self): @@ -344,18 +366,18 @@ def load_onnx(self, providers=["CUDAExecutionProvider"]): onnx_runner.activate() self.onnx_runner = onnx_runner except Exception: - pass + pass def load_profiles(self): with open(self.profiles_path, "rb") as fp: profiles = pickle.load(fp) self.profiles = profiles return profiles - + def save_profiles(self): with open(self.profiles_path, "wb") as fp: pickle.dump(self.profiles, fp) - + def attach(self, exp): self.exp_wrapper = exp self.input_names = exp.input_names @@ -367,10 +389,10 @@ def can_handle(self, **args): def inputs_to_dict(self, input_example): trt_inputs = {} for i, inp in enumerate(input_example): - input_name=self.engine.input_names[i] + input_name = self.engine.input_names[i] trt_inputs[input_name] = inp return trt_inputs - + def forward(self, **args): try: if self.engine is not None: @@ -386,7 +408,7 @@ def forward(self, **args): ret = self.onnx_runner.infer(args) ret = list(ret.values()) ret = [r.cuda() for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret except Exception as e: @@ -402,9 +424,9 @@ def forward_trt(self, trt_inputs): stream.wait_stream(torch.cuda.current_stream()) ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) ret = list(ret.values()) - #for r in ret: + # for r in ret: # assert no_nans(r), "NaNs in TRT output!" - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret @@ -414,18 +436,22 @@ def forward_trt_runner(self, trt_inputs): ret = list(ret.values()) ret = [r.cuda() for r in ret] check = [check_m(r) for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret - - def build_engine(self, input_profiles=[], - fp16=False, bf16=False, tf32=False, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True): + def build_engine( + self, + input_profiles=[], + fp16=False, + bf16=False, + tf32=False, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): profiles = [] - if len(input_profiles) > 0: + if len(input_profiles) > 0: for input_profile in input_profiles: if isinstance(input_profile, Profile): profiles.append(input_profile) @@ -438,61 +464,71 @@ def build_engine(self, input_profiles=[], self.profiles = profiles self.save_profiles() - engine = Engine(self.path+'.plan') - engine.build(self.onnx_path, profiles, - fp16=fp16, - bf16=bf16, - tf32=tf32, - direct_io=direct_io, - builder_optimization_level=builder_optimization_level, - enable_all_tactics=enable_all_tactics - ) + engine = Engine(self.path + ".plan") + engine.build( + self.onnx_path, + profiles, + fp16=fp16, + bf16=bf16, + tf32=tf32, + direct_io=direct_io, + builder_optimization_level=builder_optimization_level, + enable_all_tactics=enable_all_tactics, + ) engine.activate() self.engine = engine - def jit_export(self, input_example, - verbose=False, ): + def jit_export( + self, + input_example, + verbose=False, + ): self.jit_model = torch.jit.trace( self.exp_wrapper, input_example, ).eval() self.jit_model = torch.jit.freeze(self.jit_model) torch.jit.save(self.jit_model, self.jit_path) - - def onnx_export(self, input_example, - dynamo=False, - onnx_registry=None, - dynamic_shapes=None, - verbose=False, - opset_version=18, - ): + + def onnx_export( + self, + input_example, + dynamo=False, + onnx_registry=None, + dynamic_shapes=None, + verbose=False, + opset_version=18, + ): L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") model = self.exp_wrapper.get_export_obj() from .export_utils import replace_for_export + replace_for_export(model, do_cast=True) if dynamo: - torch.onnx.export(model, - input_example, - self.onnx_path, - dynamo=dynamo, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_shapes=dynamic_shapes + torch.onnx.export( + model, + input_example, + self.onnx_path, + dynamo=dynamo, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_shapes=dynamic_shapes, ) else: - torch.onnx.export(model, - input_example, - self.onnx_path, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_axes=dynamic_shapes + torch.onnx.export( + model, + input_example, + self.onnx_path, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=dynamic_shapes, ) L_.info("Folding constants...") model_onnx = onnx_from_path(self.onnx_path) @@ -506,29 +542,32 @@ def onnx_export(self, input_example, ) L_.info("Done saving model.") - def build_and_save(self, - input_example, - dynamo=False, - verbose=False, - input_profiles=[], - fp16=False, bf16=False, tf32=True, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True): + def build_and_save( + self, + input_example, + dynamo=False, + verbose=False, + input_profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): return - if not self.has_engine(): + if not self.has_engine(): if not self.has_onnx(): self.onnx_export( - input_example, - dynamo=dynamo, - verbose=verbose, - ) + input_example, + dynamo=dynamo, + verbose=verbose, + ) self.build_engine( - fp16=fp16, tf32=tf32, + fp16=fp16, + tf32=tf32, direct_io=direct_io, builder_optimization_level=5, - enable_all_tactics=enable_all_tactics) + enable_all_tactics=enable_all_tactics, + ) self.engine.save() - - - diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 6fabe2a..513d8f7 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -12,7 +12,6 @@ from __future__ import annotations from collections.abc import Callable -from typing import Union import numpy as np import torch @@ -497,7 +496,7 @@ def _forward( outputs.reverse() x = x_ - + if with_label: i = 0 for level in self.up_layers_auto: @@ -522,7 +521,10 @@ def _forward( return outputs, outputs_auto def forward( - self, x: torch.Tensor, with_point=True, with_label=True, # **kwargs + self, + x: torch.Tensor, + with_point=True, + with_label=True, # **kwargs ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: return self._forward(x, with_point, with_label) diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index accfbde..3e71efb 100755 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -11,12 +11,13 @@ from __future__ import annotations +import time + import monai import numpy as np import torch import torch.nn as nn from monai.utils import optional_import -import time from scripts.utils.trans_utils import convert_points_to_disc from scripts.utils.trans_utils import get_largest_connected_component_mask as lcc @@ -43,7 +44,7 @@ def __init__(self, image_encoder, class_head, point_head, feature_size): self.auto_freeze = False self.point_freeze = False self.engine = None - + def precompute_embedding(self, input_images): """precompute image embedding, require sliding window inference""" raise NotImplementedError @@ -205,8 +206,6 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False): param.requires_grad = not point_freeze self.point_freeze = point_freeze - - def forward( self, input_images, @@ -308,11 +307,12 @@ def forward( (input_images,), dynamo=False, verbose=False, - fp16=True, tf32=True, + fp16=True, + tf32=True, builder_optimization_level=5, - enable_all_tactics=True + enable_all_tactics=True, ) - + time0 = time.time() out, out_auto = self.image_encoder( x=input_images, @@ -322,22 +322,26 @@ def forward( # torch.cuda.synchronize() # time1 = time.time() # print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}") - input_images = None + input_images = None # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: if hasattr(self.class_head, "build_and_save"): self.class_head.build_and_save( - (out_auto, class_vector,), - fp16=True, tf32=True, + ( + out_auto, + class_vector, + ), + fp16=True, + tf32=True, dynamo=False, verbose=False, ) # time2 = time.time() logits, _ = self.class_head(src=out_auto, class_vector=class_vector) # torch.cuda.synchronize() - # print(f"Class Head Time: {time.time() - time2}") - + # print(f"Class Head Time: {time.time() - time2}") + if point_coords is not None: # time3 = time.time() point_logits = self.point_head(