This folder contains the Quantization-Aware Training (QAT) workflow for standard networks.
The QAT end-to-end workflow (TF2-to-ONNX) consists of the following steps:
- Model quantization using the `quantize_model` function with the NVIDIA quantization scheme (see the sketch after this list).
- QAT model fine-tuning (saves checkpoints).
- Accuracy comparison between the Baseline and QAT models.
- QAT model conversion to SavedModel format.
- Conversion of the SavedModel to ONNX.
- TensorRT engine building from the ONNX file, followed by inference.
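A minimal sketch of the quantize, fine-tune, and export steps, assuming a Keras baseline model and the toolkit's `quantize_model` entry point (`train_ds`/`val_ds` stand in for your ImageNet `tf.data` pipelines, and exact keyword arguments may vary across toolkit versions):

```python
import tensorflow as tf
from tensorflow_quantization import quantize_model

# Baseline Keras model (any supported classification model works here).
baseline_model = tf.keras.applications.ResNet50(weights="imagenet")

# Insert Q/DQ nodes following the NVIDIA quantization scheme.
q_model = quantize_model(baseline_model)

# Fine-tune briefly at a low learning rate to recover accuracy.
q_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
q_model.fit(train_ds, validation_data=val_ds, epochs=2)

# Export to SavedModel format for the TF2-to-ONNX conversion step, e.g.:
#   python -m tf2onnx.convert --saved-model saved_model_qat --output model_qat.onnx
q_model.save("saved_model_qat")
```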
Requirements:
- Install the `tensorflow-quantization` toolkit.
- Install the additional requirements: `pip install -r requirements.txt`.
- (Optional) Install TensorRT for full workflow support (needed for `infer_engine.py`).
Note: For CLI runs, go to the cloned repository's root directory and run `export PYTHONPATH=$PWD` so that the `examples` folder is available for import.
We use the ImageNet 2012 dataset (task 1, image classification), which requires a manual download due to its terms-of-access agreement. Please log in or sign up on the ImageNet website and download the "train/validation data". This data is needed for QAT model fine-tuning, and it is also used to evaluate the Baseline and QAT models.
Our workflow supports the `tfrecord` format, so please follow the instructions below (modified from TensorFlow's instructions) to convert the downloaded `.tar` ImageNet files to the required format:
- Set `IMAGENET_HOME=/path/to/imagenet/tar/files` in `data/imagenet_data_setup.sh`.
- Download `imagenet_to_gcs.py` to `$IMAGENET_HOME`.
- Run `./data/imagenet_data_setup.sh`. A sketch for reading the converted records back with `tf.data` follows below.
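Once converted, the validation records can be read back with `tf.data`. A minimal sketch, assuming the feature keys written by `imagenet_to_gcs.py` (the tfrecord path, image size, and batch size are placeholders; each model additionally expects its own input pre-processing):

```python
import tensorflow as tf

def _parse(record):
    # Feature keys follow the imagenet_to_gcs.py layout; adjust if yours differ.
    features = tf.io.parse_single_example(
        record,
        {
            "image/encoded": tf.io.FixedLenFeature([], tf.string),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    image = tf.io.decode_jpeg(features["image/encoded"], channels=3)
    image = tf.image.resize(image, (224, 224))
    return image, features["image/class/label"]

val_ds = (
    tf.data.TFRecordDataset(tf.io.gfile.glob("/path/to/tfrecords/validation-*"))
    .map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
```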
Model quantization, fine-tuning, and conversion to ONNX.
Example models:

| Model | Task | Script (QAT workflow) |
|---|---|---|
| ResNet | Classification | resnet |
| EfficientNet | Classification | efficientnet |
| MobileNet | Classification | mobilenet |
| Inception | Classification | inception |
For each model's performance results, please refer to the toolkit's User Guide ("Model Zoo").
Build the TensorRT engine and evaluate its latency and accuracy.
Convert the ONNX model into a TensorRT engine (also obtains latency measurements):
```bash
trtexec --onnx=model_qat.onnx --int8 --saveEngine=model_qat.engine --verbose
```
Arguments:
- `--onnx`: Path to the QAT ONNX graph.
- `--int8`: Enable INT8 precision mode.
- `--saveEngine`: Output filename for the TensorRT engine.
- `--verbose`: Enable verbose logging.
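The same build can also be done from Python. A minimal sketch, assuming TensorRT 8.x Python bindings and the file names from the `trtexec` command above:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

# Parse the QAT ONNX graph into the TensorRT network definition.
with open("model_qat.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("Failed to parse the ONNX file")

# Enable INT8 so the Q/DQ nodes inserted during QAT are honored.
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)

engine_bytes = builder.build_serialized_network(network, config)
with open("model_qat.engine", "wb") as f:
    f.write(engine_bytes)
```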
Obtain accuracy results on the validation dataset:
```bash
python infer_engine.py --engine=<path_to_trt_engine> --data_dir=<path_to_tfrecord_val_data> -b=<batch_size>
```
Arguments:
- `-e, --engine`: TensorRT engine filename (to load).
- `-m, --model_name`: Name of the model, needed to choose the appropriate input pre-processing. Options: `resnet_v1` (default), `resnet_v2`, `efficientnet_b0`, `efficientnet_b3`, `mobilenet_v1`, `mobilenet_v2`.
- `-d, --data_dir`: Path to the directory of input images in `tfrecord` format (`data["validation"]`).
- `-k, --top_k_value` (default=1): Value of `K` for the top-K predictions used in the accuracy calculation (sketched below).
- `-b, --batch_size` (default=1): Number of inputs to send in parallel (up to the engine's max batch size).
- `--log_file`: Filename for saving logs.
Outputs:
- `.log` file: contains the engine's accuracy results.
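For clarity, top-K accuracy counts a prediction as correct when the true label appears among the `K` highest-scoring classes. A minimal sketch of that computation (a hypothetical helper, not part of `infer_engine.py`):

```python
import numpy as np

def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Fraction of samples whose true label is among the top-K scores.

    logits: (batch, num_classes) raw scores from the engine.
    labels: (batch,) integer class labels.
    """
    # Indices of the K highest-scoring classes per sample.
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = [label in row for label, row in zip(labels, top_k)]
    return float(np.mean(hits))

# Example: 2 samples, 4 classes, K=2.
scores = np.array([[0.1, 0.5, 0.3, 0.1],
                   [0.7, 0.1, 0.1, 0.1]])
print(top_k_accuracy(scores, np.array([2, 0]), k=2))  # 1.0
```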
The following resources provide a deeper understanding of Quantization-Aware Training, TF2ONNX, and importing a model into TensorRT using Python.
Quantization Aware Training
- Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
- Quantization Aware Training guide
- Deep Residual Learning for Image Recognition
Parsers
- Documentation