Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 24, 2024
1 parent 818a548 commit 01ed8ab
Showing 7 changed files with 322 additions and 224 deletions.
2 changes: 1 addition & 1 deletion data/README.md
@@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds
to one dataset.

##### 2. Add label_dict.json and label_mapping.json
Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`.
Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`.
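As a rough, hypothetical illustration (the class names, indexes, and exact JSON schema below are made up for this example and are not taken from the repository), the two files could be extended along these lines:

```python
import json

# Hypothetical class names and indexes, for illustration only.
label_dict = {"liver": 1, "new_organ": 133}

# Hypothetical local-to-global mapping for a new dataset:
# each pair is [index in the new dataset, index in label_dict].
label_mapping = {"new_dataset": [[1, 1], [2, 133]]}

with open("label_dict.json", "w") as f:
    json.dump(label_dict, f, indent=2)
with open("label_mapping.json", "w") as f:
    json.dump(label_mapping, f, indent=2)
```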

## SuperVoxel Generation
1. Download the segment anything repo and download the ViT-H weights
24 changes: 14 additions & 10 deletions scripts/export.py
@@ -12,6 +12,7 @@
import logging
import os
import sys
import time
from functools import partial

import monai
@@ -32,7 +33,6 @@
from .train import CONFIG
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point
from .utils.trt_utils import ExportWrapper, TRTWrapper
import time

rearrange, _ = optional_import("einops", name="rearrange")
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -131,16 +131,20 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
self.prev_mask = None
self.batch_data = None

en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
input_names = ['x'], output_names = ['x_out'])
en_wrapper = ExportWrapper.wrap(
self.model.image_encoder.encoder, input_names=["x"], output_names=["x_out"]
)
self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
self.model.image_encoder.encoder.load_engine()

cls_wrapper = ExportWrapper.wrap(self.model.class_head,
input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding'])
cls_wrapper = ExportWrapper.wrap(
self.model.class_head,
input_names=["src", "class_vector"],
output_names=["masks", "class_embedding"],
)
self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
self.model.class_head.load_engine()

return

def clear_cache(self):
@@ -174,7 +178,7 @@ def infer(
used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
time and avoid repeated inference. This is by default disabled.
"""
time00=time.time()
time00 = time.time()
self.model.eval()
if not isinstance(image_file, dict):
image_file = {"image": image_file}
@@ -277,7 +281,7 @@ def infer(

@torch.no_grad()
def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
time00=time.time()
time00 = time.time()
self.model.eval()
device = f"cuda:{rank}"
if not isinstance(image_file, dict):
@@ -344,8 +348,8 @@ def batch_infer_everything(self, datalist=str, basedir=str):

if __name__ == "__main__":
try:
#import torch_onnx
#torch_onnx.patch_torch(error_report=True)
# import torch_onnx
# torch_onnx.patch_torch(error_report=True)
print("patch succeeded")
except Exception:
pass
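For readers skimming this diff, the reformatted `__init__` code earlier in this file repeats one pattern: wrap a submodule with `ExportWrapper.wrap`, replace it with a `TRTWrapper`, then call `load_engine()`. Below is a minimal sketch of that pattern using only the calls visible in this file; the helper name and attribute path are hypothetical, and the relative import only resolves inside the package:

```python
from .utils.trt_utils import ExportWrapper, TRTWrapper


def wrap_submodule_with_trt(parent, attr, engine_name, input_names, output_names):
    """Hypothetical helper illustrating the wrapping pattern used in __init__ above."""
    # Declare the tensor names the exporter should use for this submodule.
    wrapped = ExportWrapper.wrap(
        getattr(parent, attr), input_names=input_names, output_names=output_names
    )
    # Swap the submodule for a TensorRT-backed wrapper and load its engine.
    trt_module = TRTWrapper(engine_name, wrapped)
    trt_module.load_engine()
    setattr(parent, attr, trt_module)


# Mirrors the encoder case above (attribute path is illustrative):
# wrap_submodule_with_trt(model.image_encoder, "encoder", "Encoder", ["x"], ["x_out"])
```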
11 changes: 8 additions & 3 deletions scripts/utils/cast_utils.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
@@ -26,6 +26,7 @@

import torch


def avoid_bfloat16_autocast_context():
"""
If the current autocast context is bfloat16,
@@ -70,7 +71,9 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
return new_dict
elif isinstance(x, tuple):
return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
return tuple(
cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x
)


class CastToFloat(torch.nn.Module):
@@ -92,5 +95,7 @@ def __init__(self, mod):
def forward(self, *args):
from_dtype = args[0].dtype
with torch.cuda.amp.autocast(enabled=False):
ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
ret = self.mod.forward(
*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)
)
return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
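A minimal sketch of what the reformatted `CastToFloat.forward` above does at call time; the module, shapes, and import path are illustrative, not taken from the repository:

```python
import torch
import torch.nn as nn

# Import path is illustrative; the class lives in scripts/utils/cast_utils.py.
from cast_utils import CastToFloat

norm = nn.LayerNorm(8)                      # parameters stay in float32
wrapped = CastToFloat(norm)

x = torch.randn(2, 8, dtype=torch.float16)  # half-precision input
y = wrapped(x)

# forward() casts the inputs to float32, runs the module with autocast disabled,
# then casts the result back to the original input dtype.
print(y.dtype)  # torch.float16
```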
106 changes: 75 additions & 31 deletions scripts/utils/export_utils.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
@@ -22,16 +22,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Dict, Optional, Type
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from .cast_utils import CastToFloat, CastToFloatAll
from .cast_utils import CastToFloat


class LinearWithBiasSkip(nn.Module):
def __init__(self, weight, bias, skip_bias_add):
@@ -45,7 +43,10 @@ def forward(self, x):
return F.linear(x, self.weight), self.bias
return F.linear(x, self.weight, self.bias), None

def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):

def run_ts_and_compare(
ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01
):
# Verify the model can be read, and is valid
ts_out = ts_model(*ts_input_list, **ts_input_dict)

@@ -54,16 +55,20 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
expected = output_example[i]

if torch.is_tensor(expected):
tout = out.to('cpu')
tout = out.to("cpu")
print(f"Checking output {i}, shape: {expected.shape}:\n")
this_good = True
try:
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
if not torch.allclose(
tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance
):
this_good = False
except Exception: # there may be a size mismatch and it may be OK
this_good = False
if not this_good:
print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
print(
f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}"
)
all_good = False
return all_good

@@ -80,12 +85,19 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
print(f"Checking output {i}, shape: {expected.shape}:\n")
this_good = True
try:
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
if not torch.allclose(
tout,
expected.cpu(),
rtol=check_tolerance,
atol=100 * check_tolerance,
):
this_good = False
except Exception: # there may be a size mismatch and it may be OK
this_good = False
if not this_good:
print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
print(
f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}"
)
all_good = False
return all_good

@@ -96,7 +108,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from apex.transformer.tensor_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
)

def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
"""
@@ -115,7 +130,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
else:
return None

mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
mod = nn.LayerNorm(
shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype
)
n_state = n.state_dict()
mod.load_state_dict(n_state)
return mod
@@ -129,7 +146,9 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
Equivalent LayerNorm module
"""
if not isinstance(n, RowParallelLinear):
raise ValueError("This function can only change the RowParallelLinear module.")
raise ValueError(
"This function can only change the RowParallelLinear module."
)

dev = next(n.parameters()).device
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
@@ -146,8 +165,12 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
Returns:
Equivalent Linear module
"""
if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)):
raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.")
if not (
isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)
):
raise ValueError(
"This function can only change the ColumnParallelLinear or RowParallelLinear module."
)

dev = next(n.parameters()).device
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
@@ -165,11 +188,19 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
Equivalent LayerNorm module
"""
if not isinstance(n, FusedScaleMaskSoftmax):
raise ValueError("This function can only change the FusedScaleMaskSoftmax module.")
raise ValueError(
"This function can only change the FusedScaleMaskSoftmax module."
)

# disable the fusion only
mod = FusedScaleMaskSoftmax(
n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
n.input_in_fp16,
n.input_in_bf16,
n.attn_mask_type,
False,
n.mask_func,
n.softmax_in_fp32,
n.scale,
)

return mod
@@ -178,18 +209,20 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
"FusedLayerNorm": replace_FusedLayerNorm,
"MixedFusedLayerNorm": replace_FusedLayerNorm,
"FastLayerNorm": replace_FusedLayerNorm,
"ESM1bLayerNorm" : replace_FusedLayerNorm,
"ESM1bLayerNorm": replace_FusedLayerNorm,
"RowParallelLinear": replace_ParallelLinear,
"ColumnParallelLinear": replace_ParallelLinear,
"FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
}

except Exception as e:
except Exception:
default_Apex_replacements = {}
apex_available = False


def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
def simple_replace(
BaseT: Type[nn.Module], DestT: Type[nn.Module]
) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
Generic function generator to replace BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.
Args:
@@ -218,18 +251,28 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
exportable module
"""
# including the import here to avoid circular imports
from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
from nemo.collections.nlp.modules.common.megatron.fused_softmax import (
MatchedScaleMaskSoftmax,
)

# disabling fusion for the MatchedScaleMaskSoftmax
mod = MatchedScaleMaskSoftmax(
n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
n.input_in_fp16,
n.input_in_bf16,
n.attn_mask_type,
False,
n.mask_func,
n.softmax_in_fp32,
n.scale,
)
return mod


def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
def wrap_module(
BaseT: Type[nn.Module], DestT: Type[nn.Module]
) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
Generic function generator to replace BaseT module with DestT wrapper.
Generic function generator to replace BaseT module with DestT wrapper.
Args:
BaseT : module type to replace
DestT : destination module type
Expand All @@ -256,14 +299,15 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
expanded_path = path.split(".")
parent_mod = model
for sub_path in expanded_path[:-1]:
parent_mod = parent_mod._modules[sub_path] # noqa
parent_mod._modules[expanded_path[-1]] = new_mod # noqa
parent_mod = parent_mod._modules[sub_path]
parent_mod._modules[expanded_path[-1]] = new_mod

return model


def replace_modules(
model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None
model: nn.Module,
expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None,
) -> nn.Module:
"""
Top-level function to replace modules in model, specified by class name with a desired replacement.
@@ -308,7 +352,7 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
if apex_available:
print("Replacing Apex layers ...")
replace_modules(model, default_Apex_replacements)

if do_cast:
print("Adding casts around norms...")
cast_replacements = {
@@ -319,6 +363,6 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
"InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
}
replace_modules(model, cast_replacements)

# This one has to be the last
replace_modules(model, script_replacements)
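A short usage sketch of the replacement flow in this file; the toy model and import path are illustrative, and `replace_for_export` is assumed to mutate the model in place via `swap_modules` as shown above:

```python
import torch.nn as nn

# Import path is illustrative; see scripts/utils/export_utils.py.
from export_utils import replace_for_export

model = nn.Sequential(nn.Conv3d(1, 4, 3), nn.InstanceNorm3d(4))

# Replaces Apex fused layers (when apex is importable) with exportable
# equivalents and, with do_cast=True, wraps norm layers in CastToFloat.
replace_for_export(model, do_cast=True)

print(type(model[1]).__name__)  # expected: CastToFloat, wrapping the InstanceNorm3d
```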
