Skip to content

Commit

Permalink
Merge pull request #142 from fastmachinelearning/feature/test_a2q_nets
Browse files Browse the repository at this point in the history
Add easy fetch of accumulator-aware quantized (A2Q) CIFAR-10 models for testing
  • Loading branch information
maltanar authored Sep 9, 2024
2 parents c9ad9e5 + 8694a6d commit 1f8938a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/qonnx/util/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,76 @@
# utility functions to fetch models and data for
# testing various qonnx transformations

# Preprocessing constants for the A2Q / A2Q+ ResNet-18 CIFAR-10 models:
# inputs are uint8 pixels in [0, 255], rescaled to [0, 1] and then
# normalized per-channel with the CIFAR-10 mean/std below.
a2q_rn18_preproc_mean = np.asarray([0.491, 0.482, 0.447], dtype=np.float32)
a2q_rn18_preproc_std = np.asarray([0.247, 0.243, 0.262], dtype=np.float32)
a2q_rn18_int_range = (0, 255)
a2q_rn18_iscale = 1 / 255
# float input range after normalization: (x_int * iscale - mean) / std
a2q_rn18_rmin = (a2q_rn18_int_range[0] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std
a2q_rn18_rmax = (a2q_rn18_int_range[1] * a2q_rn18_iscale - a2q_rn18_preproc_mean) / a2q_rn18_preproc_std
# affine map from the integer range to the normalized float range:
#   x_float = x_int * scale + bias, with scale = iscale / std, bias = -mean / std
a2q_rn18_scale = (1 / a2q_rn18_preproc_std) * a2q_rn18_iscale
# BUGFIX: bias previously computed as -mean * std, which does not satisfy
# int_range * scale + bias == (rmin, rmax); dividing by std does.
a2q_rn18_bias = -a2q_rn18_preproc_mean / a2q_rn18_preproc_std
a2q_rn18_common = {
    "input_shape": (1, 3, 32, 32),
    "input_range": (a2q_rn18_rmin, a2q_rn18_rmax),
    "int_range": a2q_rn18_int_range,
    "scale": a2q_rn18_scale,
    "bias": a2q_rn18_bias,
}
a2q_rn18_urlbase = "https://github.com/fastmachinelearning/qonnx_model_zoo/releases/download/a2q-20240905/"

# (variant suffix, ONNX file hash) for each released A2Q / A2Q+ model.
# Keys, descriptions and download URLs are generated uniformly from these,
# replacing ten copy-pasted dict entries.
_a2q_rn18_variants = [
    ("a2q_16b", "d4bfa990"),
    ("a2q_15b", "eeca8ac2"),
    ("a2q_14b", "563cf426"),
    ("a2q_13b", "d3cae293"),
    ("a2q_12b", "fb3a0f8a"),
    ("a2q_plus_16b", "09e47feb"),
    ("a2q_plus_15b", "10e7bc83"),
    ("a2q_plus_14b", "8db8c78c"),
    ("a2q_plus_13b", "f57b05ce"),
    ("a2q_plus_12b", "1e2aca29"),
]


def _a2q_rn18_label(variant):
    """Return the human-readable accumulator label for a variant suffix,
    e.g. 'a2q_plus_16b' -> 'A2Q+ 16-bit'."""
    algo, bits = variant.rsplit("_", 1)
    algo = "A2Q+" if algo == "a2q_plus" else "A2Q"
    return f"{algo} {bits[:-1]}-bit"


a2q_model_details = {
    f"rn18_w4a4_{variant}": {
        "description": f"4-bit ResNet-18 on CIFAR-10, {_a2q_rn18_label(variant)} accumulators",
        "url": f"{a2q_rn18_urlbase}quant_resnet18_w4a4_{variant}-{sha}.onnx",
        **a2q_rn18_common,
    }
    for (variant, sha) in _a2q_rn18_variants
}

test_model_details = {
"FINN-CNV_W2A2": {
"description": "2-bit VGG-10-like CNN on CIFAR-10",
Expand Down Expand Up @@ -116,6 +186,7 @@
"input_shape": (1, 3, 224, 224),
"input_range": (0, 1),
},
**a2q_model_details,
}


Expand Down Expand Up @@ -149,6 +220,11 @@ def get_random_input(test_model, seed=42):
rng = np.random.RandomState(seed)
input_shape = test_model_details[test_model]["input_shape"]
(low, high) = test_model_details[test_model]["input_range"]
# some models spec per-channel ranges, be conservative for those
if isinstance(low, np.ndarray):
low = low.max()
if isinstance(high, np.ndarray):
high = high.min()
size = np.prod(np.asarray(input_shape))
input_tensor = rng.uniform(low=low, high=high, size=size)
input_tensor = input_tensor.astype(np.float32)
Expand Down
5 changes: 5 additions & 0 deletions tests/transformation/test_change_batchsize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def test_change_batchsize(test_model):
batch_size = 10
old_ishape = test_details["input_shape"]
imin, imax = test_details["input_range"]
# some models spec per-channel ranges, be conservative for those
if isinstance(imin, np.ndarray):
imin = imin.max()
if isinstance(imax, np.ndarray):
imax = imax.min()
model = download_model(test_model=test_model, do_cleanup=True, return_modelwrapper=True)
iname = model.graph.input[0].name
oname = model.graph.output[0].name
Expand Down

0 comments on commit 1f8938a

Please sign in to comment.