UNet cannot be quantized with assertion error #136

Open
4a16dick opened this issue Feb 19, 2025 · 0 comments

@4a16dick

Error received: AssertionError. Stack trace:

Traceback (most recent call last):
  File "segmentation_model_quantization.py", line 47, in <module>
    moq.quantize("unet_batchsize2-preprocessed.onnx", config["quantize_mode"], X_numpy.numpy(), config["calibration_method"], output_path="quantized_model.onnx")
  File "/home/ubuntu/anaconda3/envs/new_fyp/lib/python3.12/site-packages/modelopt/onnx/quantization/quantize.py", line 349, in quantize
    onnx_model = qdq_to_dq(onnx_model, verbose=verbose)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/envs/new_fyp/lib/python3.12/site-packages/modelopt/onnx/quantization/qdq_utils.py", line 579, in qdq_to_dq
    scale_init_idx, zp_init_idx = _convert(node)
                                  ^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/envs/new_fyp/lib/python3.12/site-packages/modelopt/onnx/quantization/qdq_utils.py", line 541, in _convert
    assert not np_y_scale.shape or w32.shape[-1] == np_y_scale.shape[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
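The failing check can be restated in terms of shapes alone. The sketch below mirrors the assertion at `qdq_utils.py` line 541. The check passes when the scale is a scalar (empty shape). It also passes when the scale's length equals the size of the weight's last dimension. The helper name `passes_scale_check` is hypothetical, and the shapes are illustrative, not taken from this UNet.

```python
def passes_scale_check(weight_shape: tuple, scale_shape: tuple) -> bool:
    """Hypothetical helper mirroring the condition
    `not np_y_scale.shape or w32.shape[-1] == np_y_scale.shape[0]`
    using shapes only (no modelopt import)."""
    # An empty shape means a per-tensor (scalar) scale, which always passes.
    if not scale_shape:
        return True
    # Otherwise the per-channel scale length must match the weight's last dimension.
    return weight_shape[-1] == scale_shape[0]


# Illustrative shapes only.
print(passes_scale_check((128, 256), ()))      # per-tensor scale
print(passes_scale_check((128, 256), (256,)))  # one scale per last-dim entry
print(passes_scale_check((128, 256), (128,)))  # length mismatch, would trip the assert
```

You can compare the weight and scale shapes that reach the failing `_convert` call against this condition. That may help narrow down which layer triggers the error.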

Full code:

import torch
from torch.utils.data import DataLoader, RandomSampler, random_split
import modelopt.torch.quantization as mtq
import modelopt.onnx.quantization as moq
import modelopt.torch.opt as mto
import json, os
from unet import UNet
from carvana_dataset import CarvanaDataset
import numpy as np
from tqdm import tqdm

device = ("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")

config = {
    'batch_size': 6,
    'sample_size': 512,
    'quantize_mode': 'int8',
    'calibration_method': 'max',
}

#set up calibration dataloader
DATA_PATH = "../TensorRT_container_data/data"
generator = torch.Generator().manual_seed(42)
train_dataset = CarvanaDataset(DATA_PATH)
train_dataset, val_dataset, test_dataset = random_split(train_dataset, [0.7, 0.2, 0.1], generator=generator)
sampler = RandomSampler(train_dataset, num_samples=config["sample_size"])
train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], sampler=sampler)

quantize_format = "onnx"
if quantize_format == "onnx":  # quantize with ONNX
    # Prepare the calibration data: collect all sampled batches, then concatenate once
    batches = []
    for X, _ in tqdm(train_dataloader, total=len(train_dataloader), desc="Preparing calibration data: "):
        batches.append(X)
    X_numpy = torch.cat(batches)
    np.save("X.npy", X_numpy.numpy())

    # Quantize the ONNX model
    moq.quantize("unet_batchsize2-preprocessed.onnx", config["quantize_mode"], X_numpy.numpy(), config["calibration_method"], output_path="quantized_model.onnx")
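For reference, the calibration tensor built above has `sample_size` rows. `RandomSampler(num_samples=512)` yields 512 indices. A `DataLoader` with `batch_size=6` and the default `drop_last=False` splits them into 85 full batches plus a final batch of 2. The sketch below reproduces that arithmetic in plain Python. `loader_batch_sizes` is a hypothetical helper, not part of PyTorch.

```python
import math


def loader_batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> list:
    """Hypothetical helper: the batch sizes a DataLoader would yield for a
    sampler that produces `num_samples` indices."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    # A partial final batch is kept unless drop_last is set.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


sizes = loader_batch_sizes(512, 6)
print(len(sizes), sizes[-1], sum(sizes))  # number of batches, last batch size, total rows
print(math.ceil(512 / 6))                 # agrees with len(sizes) when drop_last=False
```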

Steps to reproduce this issue:

  1. Create and train a U-Net per the instructions on https://github.com/milesial/Pytorch-UNet.
  2. Convert the PyTorch model to ONNX with the following script:
from unet import UNet
import torch

model = UNet(in_channels=3, num_classes=1)
model.load_state_dict(torch.load("unet_batchsize2.pth", weights_only=False))
torch.onnx.export(model, torch.zeros(2, 3, 512, 512), "unet_batchsize2.onnx",
                  input_names=["img"], output_names=["mask"], dynamic_axes={
                    "img": {
                      2: "img_height",
                      3: "img_width"
                    }
                  })
  3. Preprocess the converted model with: python3 -m onnxruntime.quantization.preprocess --input unet_batchsize2.onnx --output unet_batchsize2-preprocessed.onnx --skip_symbolic_shape img
  4. Run the full code above and observe the AssertionError.
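In step 2, `dynamic_axes` marks only axes 2 and 3 of `img`. `torch.onnx.export` therefore keeps the other input dimensions at the sizes of the example tensor `torch.zeros(2, 3, 512, 512)`. The exported model thus expects a fixed batch of 2 and 3 channels, with symbolic height and width. The pure-Python sketch below computes that resulting input signature. `exported_input_shape` is a hypothetical helper, not a PyTorch API.

```python
def exported_input_shape(example_shape, dynamic_axes):
    """Hypothetical helper: the input shape an ONNX export would record, given the
    example input's shape and an {axis_index: symbolic_name} mapping."""
    # Axes listed in dynamic_axes become symbolic; all others keep the example size.
    return [dynamic_axes.get(axis, size) for axis, size in enumerate(example_shape)]


shape = exported_input_shape((2, 3, 512, 512), {2: "img_height", 3: "img_width"})
print(shape)  # batch and channel stay static; height and width become symbolic
```

If a variable batch size is wanted, adding `0: "batch"` to the `img` entry of `dynamic_axes` would make that dimension symbolic as well.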
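A note on `--skip_symbolic_shape img` in step 3. Recent onnxruntime releases appear to declare this flag with argparse `type=bool`; that is an assumption worth confirming in the installed `preprocess.py`. If so, any non-empty value, including `img`, parses as `True`, and symbolic shape inference is skipped. The sketch below uses a stand-in parser, not onnxruntime's own, to show that argparse behavior.

```python
import argparse

# Stand-in parser; assumes the real flag is declared the same way (type=bool).
parser = argparse.ArgumentParser()
parser.add_argument("--skip_symbolic_shape", type=bool, default=False)

# bool() of any non-empty string is True, so the value's content is ignored.
print(parser.parse_args(["--skip_symbolic_shape", "img"]).skip_symbolic_shape)
print(parser.parse_args(["--skip_symbolic_shape", "False"]).skip_symbolic_shape)
print(parser.parse_args([]).skip_symbolic_shape)
```

Under that assumption, `--skip_symbolic_shape True` states the intent more clearly. Omitting the flag is the only way to keep the default of `False`.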

Environment information:
OS: Ubuntu 22.04 LTS
Python version: 3.12.8
PyTorch version: 2.5.1
PyTorch-TensorRT version: 2.5.0
ModelOpt version: 0.23.0
ONNX version: 1.17.0
onnxruntime-gpu version: 1.20.1
CUDA version: 12.5
CUDA driver version: 555.42.02
GPU: NVIDIA GeForce RTX 3060 Ti

Any help is appreciated, thanks.
