From dad98696a28b3c2781f13c53b5c463e58f1df891 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Fri, 22 Mar 2024 15:23:37 +0100
Subject: [PATCH 01/23] Added test, docs, and updated the resolve_rounding_mode function to support the new rounding modes.

---
 docs/qonnx-custom-ops/quant_op.md | 19 ++++++++++++++++++-
 src/qonnx/custom_op/general/quant.py | 16 +++++++++++++++-
 tests/custom_op/test_runding_mode.py | 20 ++++++++++++++++++++
 3 files changed, 53 insertions(+), 2 deletions(-)
 create mode 100644 tests/custom_op/test_runding_mode.py

diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index 02d115fb..953fdca7 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
<dl>
narrow : int (default is 0)
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
rounding_mode : string (default is "ROUND")
-
Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
Defines how rounding should be applied during quantization. Available options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP, HALF_DOWN. Here ROUND is an alias for HALF_EVEN (round half to even). The rounding modes are described in the table below; mode names may be given in upper or lower case.
</dd>
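The modes differ only in how fractional values are mapped to integers. As a quick illustration (not part of the op definition), the `resolve_rounding_mode` helper added in this patch can be called directly; the printed results match the table below:

```python
import numpy as np
from qonnx.custom_op.general.quant import resolve_rounding_mode

x = np.array([2.5, 1.6, -1.1, -2.5])
print(resolve_rounding_mode("ROUND")(x))      # [ 2.  2. -1. -2.]  ties to even
print(resolve_rounding_mode("HALF_UP")(x))    # [ 3.  2. -1. -3.]  ties away from zero
print(resolve_rounding_mode("HALF_DOWN")(x))  # [ 2.  2. -1. -2.]  ties toward zero
print(resolve_rounding_mode("UP")(x))         # [ 3.  2. -2. -3.]  away from zero
print(resolve_rounding_mode("DOWN")(x))       # [ 2.  1. -1. -2.]  toward zero
```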
#### Inputs @@ -46,6 +46,23 @@ This operator is not part of the ONNX standard and is not currently versioned. +#### Rounding modes +
+rounding modes +| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN | +|---------------------------- |----------------- |------ |------- |---- |------ |--------- |----------- | +| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 | +| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 | +| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 | +| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 | +| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | +| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 | +| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 | +| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 | +| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 | +| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 | +
+ #### Examples
Quant diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index f552e7a8..15afd048 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -135,12 +135,26 @@ def resolve_rounding_mode(mode_string): """Resolve the rounding mode string of Quant and Trunc ops to the corresponding numpy functions.""" normalized_mode_string = mode_string.upper() - if normalized_mode_string == "ROUND": + if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN": return np.round elif normalized_mode_string == "CEIL": return np.ceil elif normalized_mode_string == "FLOOR": return np.floor + elif normalized_mode_string == "UP": + def round_up(x): + return np.sign(x) * np.ceil(np.abs(x)) + return round_up + elif normalized_mode_string == "DOWN": + return np.fix + elif normalized_mode_string == "HALF_UP": + def round_half_up(x): + return np.sign(x) * np.floor(np.abs(x) + 0.5) + return round_half_up + elif normalized_mode_string == "HALF_DOWN": + def round_half_down(x): + return np.sign(x) * np.ceil(np.abs(x) - 0.5) + return round_half_down else: raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py new file mode 100644 index 00000000..54a81f0e --- /dev/null +++ b/tests/custom_op/test_runding_mode.py @@ -0,0 +1,20 @@ +import pytest + +import numpy as np + +from qonnx.custom_op.general.quant import resolve_rounding_mode + +@pytest.mark.parametrize("rmode,exp", [ + ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])), + ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])), + ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])), + ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])), + ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])), + ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])), + ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])) + ] +) +def test_rounding_modes(rmode, exp): + test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]) + rounding_fn = resolve_rounding_mode(rmode) + assert np.array_equal(rounding_fn(test_array), exp) From 47a88e4fb4d0bc69059297bbea39e69650f95d1f Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 22 Mar 2024 15:42:21 +0100 Subject: [PATCH 02/23] Fix table visualization. --- docs/qonnx-custom-ops/quant_op.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md index 953fdca7..b9e11c79 100644 --- a/docs/qonnx-custom-ops/quant_op.md +++ b/docs/qonnx-custom-ops/quant_op.md @@ -49,18 +49,18 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Rounding modes
rounding modes -| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN | -|---------------------------- |----------------- |------ |------- |---- |------ |--------- |----------- | -| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 | -| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 | -| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 | -| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 | -| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | -| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 | -| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 | -| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 | -| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 | -| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 | +| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN | +|----------------------------|-----------------|------|-------|----|------|---------|-----------| +| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 | +| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 | +| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 | +| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 | +| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | +| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 | +| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 | +| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 | +| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 | +| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
#### Examples From e2c15045d9ccf5f4c8162c1555c366c337616fee Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 22 Mar 2024 15:43:56 +0100 Subject: [PATCH 03/23] Fix table visualization again. --- docs/qonnx-custom-ops/quant_op.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md index b9e11c79..68029406 100644 --- a/docs/qonnx-custom-ops/quant_op.md +++ b/docs/qonnx-custom-ops/quant_op.md @@ -49,6 +49,7 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Rounding modes
rounding modes + | **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN | |----------------------------|-----------------|------|-------|----|------|---------|-----------| | 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 | From baa0df36b48b69a515604a3f5f4c04a00bd0712f Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 9 Aug 2024 14:06:07 +0200 Subject: [PATCH 04/23] Fixed converter to allow alpha/scale to be a tensor Fixed rounding_mode specifier in convert_quantized_bits --- src/qonnx/converters/qkeras/onnx.py | 14 ++++++++++++-- src/qonnx/converters/qkeras/quantizers.py | 14 +++++++++----- src/qonnx/custom_op/general/quant.py | 2 +- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py index 1f34d653..3bb1fed9 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -55,7 +55,7 @@ def qlayer_handler(ctx, node, name, args): quantizers = all_quantizers[keras_name] if quantizers.get("kernel_quantizer"): weights = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(weights, quantizers["kernel_quantizer"]) + quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer']) attr = quant_params["attributes"] input_nodes = [node.input[1]] for key in quant_params["inputs"].keys(): @@ -63,9 +63,19 @@ def qlayer_handler(ctx, node, name, args): np_val = np.asarray(quant_params["inputs"][key]) ctx.make_const(name, np_val) input_nodes.append(name) - ctx.insert_new_node_on_input( + quant_node = ctx.insert_new_node_on_input( node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx" ) + scale_node = ctx.make_const( + name = node.name + "_kernel_scale", + np_val = quant_params['inputs']['scale'].astype(np.float32) + ) + ctx.insert_new_node_on_output( + op_type = "Mul", + output_name = quant_node.output[0], + name = node.name + "_kernel_requantizer", + inputs = [quant_node.output[0], scale_node.name] + ) if quantizers.get("bias_quantizer") and len(node.input) == 3: bias = node.inputs[2].get_tensor_value(as_list=True) diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py index 983cc997..c6a00a00 100644 --- a/src/qonnx/converters/qkeras/quantizers.py +++ b/src/qonnx/converters/qkeras/quantizers.py @@ -1,9 +1,9 @@ import qkeras import six - +import numpy as np def get_quant_params(tensor, qkeras_quantizer): - if isinstance(qkeras_quantizer, str): + if isinstance(qkeras_quantizer, (str, dict)): qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer) return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer) @@ -34,11 +34,15 @@ def convert_quantized_bits(tensor, quantizer): signed = int(config["keep_negative"]) narrow = int(config["symmetric"]) qscale = _get_quantizer_scale(tensor, quantizer) - assert qscale == 1, "Non-unity alpha is not yet supported" - scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed)) + if not isinstance(qscale, np.ndarray): + qscale = np.array(qscale) + scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed)) zero_point = 0 bit_width = int(config["bits"]) - rounding_mode = "ROUND" + if config['alpha'] == "auto_po2": + rounding_mode = "ROUND_UP" + else: + rounding_mode = "HALF_EVEN" settings = { "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode}, diff --git a/src/qonnx/custom_op/general/quant.py 
b/src/qonnx/custom_op/general/quant.py index 15afd048..5af3f9f3 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -135,7 +135,7 @@ def resolve_rounding_mode(mode_string): """Resolve the rounding mode string of Quant and Trunc ops to the corresponding numpy functions.""" normalized_mode_string = mode_string.upper() - if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN": + if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN": return np.round elif normalized_mode_string == "CEIL": return np.ceil From de9f73173705cb757daaa80a58044ed8f09a376d Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 9 Aug 2024 15:08:07 +0200 Subject: [PATCH 05/23] Added a check to see if tensor is representable by the quantization parameters. --- src/qonnx/converters/qkeras/onnx.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py index 3bb1fed9..cadb10c4 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -4,7 +4,7 @@ from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp from .quantizers import get_quant_params - +from qonnx.custom_op.general.quant import quant def get_qkeras_onnx_handlers(all_quantizers): """Returns the handlers for each kind of layer @@ -58,6 +58,23 @@ def qlayer_handler(ctx, node, name, args): quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer']) attr = quant_params["attributes"] input_nodes = [node.input[1]] + qweights = quant(inp_tensor=np.array(weights), + scale=np.array(quant_params['inputs']['scale']), + zeropt=np.array(quant_params['inputs']['zero_point']), + bitwidth=np.array(quant_params['inputs']['bit_width']), + signed=quant_params['attributes']['signed'], + narrow=quant_params['attributes']['narrow'], + rounding_mode=quant_params['attributes']['rounding_mode'] + ) + assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings. 
+ The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights};
+ scale: {np.array(quant_params['inputs']['scale'])},
+ zeropt: {np.array(quant_params['inputs']['zero_point'])},
+ bitwidth: {np.array(quant_params['inputs']['bit_width'])},
+ signed: {quant_params['attributes']['signed']},
+ narrow: {quant_params['attributes']['narrow']},
+ rounding_mode: {quant_params['attributes']['rounding_mode']}
+ """
 for key in quant_params["inputs"].keys():
 name = f"{node.name}_kernel_quantizer_{key}"
 np_val = np.asarray(quant_params["inputs"][key])
 ctx.make_const(name, np_val)

From 72b994a923e7ea0e012de4f069af4b443ca60d83 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Mon, 12 Aug 2024 11:17:57 +0200
Subject: [PATCH 06/23] Extra Mul node inserted only when necessary. Commented out assertion on non-representability.

---
 src/qonnx/converters/qkeras/onnx.py | 59 ++++++++++++++++-------------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index cadb10c4..1444865f 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -58,23 +58,23 @@ def qlayer_handler(ctx, node, name, args):
 quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer'])
 attr = quant_params["attributes"]
 input_nodes = [node.input[1]]
- qweights = quant(inp_tensor=np.array(weights),
- scale=np.array(quant_params['inputs']['scale']),
- zeropt=np.array(quant_params['inputs']['zero_point']),
- bitwidth=np.array(quant_params['inputs']['bit_width']),
- signed=quant_params['attributes']['signed'],
- narrow=quant_params['attributes']['narrow'],
- rounding_mode=quant_params['attributes']['rounding_mode']
- )
- assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
- The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights};
- scale: {np.array(quant_params['inputs']['scale'])},
- zeropt: {np.array(quant_params['inputs']['zero_point'])},
- bitwidth: {np.array(quant_params['inputs']['bit_width'])},
- signed: {quant_params['attributes']['signed']},
- narrow: {quant_params['attributes']['narrow']},
- rounding_mode: {quant_params['attributes']['rounding_mode']}
- """
+ #qweights = quant(inp_tensor=np.array(weights),
+ # scale=np.array(quant_params['inputs']['scale']),
+ # zeropt=np.array(quant_params['inputs']['zero_point']),
+ # bitwidth=np.array(quant_params['inputs']['bit_width']),
+ # signed=quant_params['attributes']['signed'],
+ # narrow=quant_params['attributes']['narrow'],
+ # rounding_mode=quant_params['attributes']['rounding_mode']
+ # )
+ #assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings.
+ # The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; + # scale: {np.array(quant_params['inputs']['scale'])}, + # zeropt: {np.array(quant_params['inputs']['zero_point'])}, + # bitwidth: {np.array(quant_params['inputs']['bit_width'])}, + # signed: {quant_params['attributes']['signed']}, + # narrow: {quant_params['attributes']['narrow']}, + # rounding_mode: {quant_params['attributes']['rounding_mode']} + # """ for key in quant_params["inputs"].keys(): name = f"{node.name}_kernel_quantizer_{key}" np_val = np.asarray(quant_params["inputs"][key]) @@ -83,16 +83,21 @@ def qlayer_handler(ctx, node, name, args): quant_node = ctx.insert_new_node_on_input( node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx" ) - scale_node = ctx.make_const( - name = node.name + "_kernel_scale", - np_val = quant_params['inputs']['scale'].astype(np.float32) - ) - ctx.insert_new_node_on_output( - op_type = "Mul", - output_name = quant_node.output[0], - name = node.name + "_kernel_requantizer", - inputs = [quant_node.output[0], scale_node.name] - ) + if quantizers["kernel_initializer"]['config']['quantizer']['class_name'] == 'quantized_bits': + bits = quantizers["kernel_initializer"]['config']['quantizer']['config']['bits'] + integer = quantizers["kernel_initializer"]['config']['quantizer']['config']['integer'] + keep_negative = quantizers["kernel_initializer"]['config']['quantizer']['config']['keep_negative'] + if bits == integer + keep_negative: + scale_node = ctx.make_const( + name = node.name + "_kernel_scale", + np_val = quant_params['inputs']['scale'].astype(np.float32) + ) + ctx.insert_new_node_on_output( + op_type = "Mul", + output_name = quant_node.output[0], + name = node.name + "_kernel_requantizer", + inputs = [quant_node.output[0], scale_node.name] + ) if quantizers.get("bias_quantizer") and len(node.input) == 3: bias = node.inputs[2].get_tensor_value(as_list=True) From 75b40ab81d5b544533373ba8b0654911ab3ec4e4 Mon Sep 17 00:00:00 2001 From: jvreca Date: Mon, 12 Aug 2024 15:05:29 +0200 Subject: [PATCH 07/23] Added parameterized test for tensor style alpha. 
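For context: a tensor-style (per-channel) alpha simply becomes a per-channel scale on the exported Quant node. A minimal sketch of the mapping this test exercises, following the scale formula in convert_quantized_bits (the concrete values below are illustrative):

```python
import numpy as np

# quantized_bits(bits, integer, keep_negative, alpha=[...]) maps to a Quant node
# whose per-channel scale is alpha / 2 ** (bits - integer - keep_negative).
bits, integer, keep_negative = 4, 3, 1
alpha = np.array([0.125, 0.250, 0.500, 1.000])
scale = alpha / 2 ** (bits - integer - keep_negative)
print(scale)  # [0.125 0.25  0.5   1.   ] -> one scale value per output channel
```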
--- tests/keras/test_keras_convert.py | 65 ++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py index 388f39a4..46f445ef 100644 --- a/tests/keras/test_keras_convert.py +++ b/tests/keras/test_keras_convert.py @@ -4,6 +4,9 @@ import onnx import os import tensorflow as tf +tf.config.run_functions_eagerly(True) +tf.keras.utils.set_random_seed(42) +np.random.seed(42) from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input from tensorflow.keras.models import Model @@ -323,7 +326,67 @@ def test_qkeras_qdense_4(quantizers, request): np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4) os.remove(model_path) - +@pytest.mark.parametrize("bits,signed,alpha",[ + (8, True, [1.000, 1.000, 1.000, 1.000]), + (8, False, [1.000, 1.000, 1.000, 1.000]), + (4, True, [1.000, 1.000, 1.000, 1.000]), + (4, False, [1.000, 1.000, 1.000, 1.000]), + (8, True, [0.125, 0.250, 0.500, 1.000]), + (8, False, [0.125, 0.250, 0.500, 1.000]), + (5, True, [0.250, 0.250, 0.125, 0.125]), + (5, False, [0.250, 0.250, 0.125, 0.125]), + (4, True, [0.125, 0.250, 0.500, 1.000]), + (4, False, [0.125, 0.250, 0.500, 1.000]), + (3, True, [0.125, 0.125, 0.250, 0.125]), + (3, False, [0.125, 0.125, 0.250, 0.125]) +]) +def test_qkeras_tensor_alpha(bits, signed, alpha, request): + random_state = np.random.RandomState(seed=42) + max_val = np.array(alpha) * 2**(bits-signed) + min_val = -(max_val + 1) + w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4)) + b1 = np.array([0.0, 0.0, 0.0, 0.0]) + x = x_in = tf.keras.layers.Input(shape=3) + x = QActivation( + quantized_bits(bits=4, integer=3, keep_negative=True) + )(x) + x = QDense( + 4, + kernel_quantizer=quantized_bits( + bits=bits, integer=(bits-signed), keep_negative=signed, alpha=alpha + ), + )(x) + x = QActivation(quantized_relu(bits=3, integer=3))(x) + model = tf.keras.Model(inputs=[x_in], outputs=[x]) + model.compile() + model.layers[2].set_weights([w1, b1]) + onnx_model, _ = from_keras(model) + model_path = f"model_test_qkeras_tensor_alpha_{request.node.callspec.id}.onnx" + onnx.save(onnx_model, model_path) + onnx_model = ModelWrapper(onnx_model) + + data = np.array( + [ + [[0.0, 0.0, 0.0]], + [[0.0, 1.0, 2.0]], + [[2.0, 1.0, 0.0]], + [[4.0, 4.0, 4.0]], + [[7.0, 7.0, 7.0]], + [[6.0, 0.0, 7.0]], + [[3.0, 3.0, 3.0]], + [[7.0, 0.0, 0.0]], + [[0.0, 7.0, 0.0]], + [[0.0, 0.0, 7.0]], + ] + ).astype(np.float32) + for x in data: + y_qkeras = model.predict(x, verbose=0) + idict = {onnx_model.graph.input[0].name: x} + odict = oxe.execute_onnx(onnx_model, idict, True) + y_qonnx = odict[onnx_model.graph.output[0].name] + assert np.array_equal(y_qkeras, y_qonnx) + os.remove(model_path) + @pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids) def test_qkeras_qconv2d_1(quantizers, request): kq, bq = quantizers From 7b5bf4a750ec0a0b66f094f917aa91aa7e677826 Mon Sep 17 00:00:00 2001 From: jvreca Date: Wed, 14 Aug 2024 08:16:56 +0200 Subject: [PATCH 08/23] Updated QKeras converter to support auto_po2 --- src/qonnx/converters/qkeras/onnx.py | 61 ++++++++++++----------- src/qonnx/converters/qkeras/qlayers.py | 19 +++++-- src/qonnx/converters/qkeras/quantizers.py | 10 ++-- tests/keras/test_keras_convert.py | 53 +++++++++++++++++++- 4 files changed, 104 insertions(+), 39 deletions(-) diff --git a/src/qonnx/converters/qkeras/onnx.py 
b/src/qonnx/converters/qkeras/onnx.py index 1444865f..1383bec9 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -53,28 +53,13 @@ def qlayer_handler(ctx, node, name, args): if not keras_name: return # Not found in quantizers, nothing to do quantizers = all_quantizers[keras_name] - if quantizers.get("kernel_quantizer"): + + if quantizers.get("kernel_quantizer_cfg"): weights = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(weights, quantizers["kernel_initializer"]['config']['quantizer']) + quant_params = get_quant_params(weights, quantizers['kernel_quantizer_cfg']) attr = quant_params["attributes"] input_nodes = [node.input[1]] - #qweights = quant(inp_tensor=np.array(weights), - # scale=np.array(quant_params['inputs']['scale']), - # zeropt=np.array(quant_params['inputs']['zero_point']), - # bitwidth=np.array(quant_params['inputs']['bit_width']), - # signed=quant_params['attributes']['signed'], - # narrow=quant_params['attributes']['narrow'], - # rounding_mode=quant_params['attributes']['rounding_mode'] - # ) - #assert np.array_equal(weights, qweights), f"""Weights of tensor {node.name} are not representable with the given quantization settings. - # The original weight tensor is: {np.array(weights)} and the quantized tensor is: {qweights}; - # scale: {np.array(quant_params['inputs']['scale'])}, - # zeropt: {np.array(quant_params['inputs']['zero_point'])}, - # bitwidth: {np.array(quant_params['inputs']['bit_width'])}, - # signed: {quant_params['attributes']['signed']}, - # narrow: {quant_params['attributes']['narrow']}, - # rounding_mode: {quant_params['attributes']['rounding_mode']} - # """ + for key in quant_params["inputs"].keys(): name = f"{node.name}_kernel_quantizer_{key}" np_val = np.asarray(quant_params["inputs"][key]) @@ -83,10 +68,10 @@ def qlayer_handler(ctx, node, name, args): quant_node = ctx.insert_new_node_on_input( node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx" ) - if quantizers["kernel_initializer"]['config']['quantizer']['class_name'] == 'quantized_bits': - bits = quantizers["kernel_initializer"]['config']['quantizer']['config']['bits'] - integer = quantizers["kernel_initializer"]['config']['quantizer']['config']['integer'] - keep_negative = quantizers["kernel_initializer"]['config']['quantizer']['config']['keep_negative'] + if quantizers['kernel_quantizer_cfg']['class_name'] == 'quantized_bits': + bits = quantizers['kernel_quantizer_cfg']['config']['bits'] + integer = quantizers['kernel_quantizer_cfg']['config']['integer'] + keep_negative = quantizers['kernel_quantizer_cfg']['config']['keep_negative'] if bits == integer + keep_negative: scale_node = ctx.make_const( name = node.name + "_kernel_scale", @@ -99,17 +84,32 @@ def qlayer_handler(ctx, node, name, args): inputs = [quant_node.output[0], scale_node.name] ) - if quantizers.get("bias_quantizer") and len(node.input) == 3: - bias = node.inputs[2].get_tensor_value(as_list=True) - quant_params = get_quant_params(bias, quantizers["bias_quantizer"]) + if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3: + bias = node.inputs[-1].get_tensor_value(as_list=True) + quant_params = get_quant_params(bias, quantizers['bias_quantizer_cfg']) attr = quant_params["attributes"] - input_nodes = [node.input[2]] + input_nodes = [node.input[-1]] for key in quant_params["inputs"].keys(): name = f"{node.name}_bias_quantizer_{key}" np_val = np.asarray(quant_params["inputs"][key]) ctx.make_const(name, np_val) 
input_nodes.append(name) ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx") + if quantizers['bias_quantizer_cfg']['class_name'] == 'quantized_bits': + bits = quantizers['bias_quantizer_cfg']['config']['bits'] + integer = quantizers['bias_quantizer_cfg']['config']['integer'] + keep_negative = quantizers['bias_quantizer_cfg']['config']['keep_negative'] + if bits == integer + keep_negative: + scale_node = ctx.make_const( + name = node.name + "_bias_scale", + np_val = quant_params['inputs']['scale'].astype(np.float32) + ) + ctx.insert_new_node_on_output( + op_type = "Mul", + output_name = quant_node.output[0], + name = node.name + "_bias_requantizer", + inputs = [quant_node.output[0], scale_node.name] + ) if quantizers.get("activation"): dtypes = [ctx.get_dtype(node.output[0])] @@ -141,6 +141,9 @@ def qact_handler(ctx, node, name, args): quantizers = all_quantizers[keras_name] if quantizers.get("activation"): dtypes = [ctx.get_dtype(node.output[0])] + if "auto" in quantizers['activation']: + if not node.graph.get_node_by_output(node.input[0]).is_const(): + raise AttributeError(f'Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}.') quant_params = get_quant_params(None, quantizers["activation"]) attr = quant_params["attributes"] input_nodes = [node.output[0]] @@ -180,9 +183,9 @@ def bias_handler(ctx, node, name, args): return # Not found in quantizers, nothing to do quantizers = all_quantizers[keras_name] - if quantizers.get("bias_quantizer"): + if quantizers.get("bias_quantizer_cfg"): bias = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(bias, quantizers["bias_quantizer"]) + quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"]) attr = quant_params["attributes"] input_nodes = [node.input[1]] for key in quant_params["inputs"].keys(): diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py index fdbca71b..0a4f907d 100644 --- a/src/qonnx/converters/qkeras/qlayers.py +++ b/src/qonnx/converters/qkeras/qlayers.py @@ -101,13 +101,26 @@ def _replace_activation(quant_act): def extract_qlayer(layer): quantizers = layer.get_quantization_config() - + keras_config = layer.get_config() - keras_config.pop("kernel_quantizer", None) - keras_config.pop("bias_quantizer", None) + kernel_quant_cfg = keras_config.pop("kernel_quantizer", None) + bias_quant_cfg = keras_config.pop("bias_quantizer", None) keras_config.pop("kernel_range", None) keras_config.pop("bias_range", None) + + quantizers['kernel_quantizer_cfg'] = kernel_quant_cfg + quantizers['bias_quantizer_cfg'] = bias_quant_cfg + + # For some reason downstream can't handle auto_po2, so we just calculate the scale value now + if kernel_quant_cfg['config']['alpha'] == "auto_po2": + layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2) + quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist() + if bias_quant_cfg['config']['alpha'] == "auto_po2": + layer.bias_quantizer_internal(layer.bias) + quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist() + quantizers.pop('kernel_quantizer', None) + quantizers.pop('bias_quantizer', None) # Check if activation is quantized if _is_keras_quantizer(keras_config["activation"]): diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py index 
c6a00a00..e38cf710 100644
--- a/src/qonnx/converters/qkeras/quantizers.py
+++ b/src/qonnx/converters/qkeras/quantizers.py
@@ -1,6 +1,8 @@
 import qkeras
 import six
 import numpy as np
+import tensorflow as tf
+

 def get_quant_params(tensor, qkeras_quantizer):
 if isinstance(qkeras_quantizer, (str, dict)):
@@ -24,7 +26,6 @@ def _get_scale_from_alpha(tensor, quantizer):
 def _get_quantizer_scale(tensor, quantizer):
 # call the quantizer on the tensor to get its scale
 import numpy as np
-
 quantizer(np.array(tensor).astype(np.float32))
 return quantizer.scale

@@ -34,15 +35,12 @@ def convert_quantized_bits(tensor, quantizer):
 signed = int(config["keep_negative"])
 narrow = int(config["symmetric"])
 qscale = _get_quantizer_scale(tensor, quantizer)
- if not isinstance(qscale, np.ndarray):
+ if not isinstance(qscale, (np.ndarray, tf.Tensor)):
 qscale = np.array(qscale)
 scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
 zero_point = 0
 bit_width = int(config["bits"])
- if config['alpha'] == "auto_po2":
- rounding_mode = "ROUND_UP"
- else:
- rounding_mode = "HALF_EVEN"
+ rounding_mode = "HALF_EVEN"

 settings = {
 "attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py
index 46f445ef..eb3e9b2e 100644
--- a/tests/keras/test_keras_convert.py
+++ b/tests/keras/test_keras_convert.py
@@ -4,7 +4,6 @@
 import onnx
 import os
 import tensorflow as tf
-tf.config.run_functions_eagerly(True)
 tf.keras.utils.set_random_seed(42)
 np.random.seed(42)
 from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, ternary
 from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input
 from tensorflow.keras.models import Model
@@ -66,6 +65,58 @@ def test_qkeras_qactivation(quantizer, request):
 np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5)
 os.remove(model_path)

+@pytest.mark.parametrize("quantizer", [
+ quantized_relu(bits=4, integer=4),
+ quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
+ ])
+def test_qkeras_quantizers_rounding_modes(quantizer, request):
+ x = x_in = Input((10,), name="input")
+ x = QActivation(activation=quantizer)(x)
+ model = Model(inputs=[x_in], outputs=[x])
+ model.compile()
+
+ onnx_model, _ = from_keras(model)
+ model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx"
+ onnx.save(onnx_model, model_path)
+ onnx_model = ModelWrapper(onnx_model)
+
+ x_test = np.array([[5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]]).astype(np.float32)
+ idict = {onnx_model.graph.input[0].name: x_test}
+ odict = oxe.execute_onnx(onnx_model, idict, True)
+ y_qonnx = odict[onnx_model.graph.output[0].name]
+ y_qkeras = model.predict(x_test)
+ assert np.array_equal(y_qkeras, y_qonnx)
+ os.remove(model_path)
+
+@pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
+ kq = bq = quantized_bits(4, 4, 1, alpha='auto_po2')
+ # Initialize the kernel & bias to constant values within the range of the quantizers
+ x = x_in = Input((10), name="input")
+ x = QDense(
+ 1,
+ kernel_quantizer=kq,
+ bias_quantizer=bq,
+ kernel_initializer=tf.keras.initializers.Constant(1.0),
+ bias_initializer=tf.keras.initializers.Constant(bias),
+ name="dense_0",
+ )(x)
+ model = Model(inputs=[x_in], outputs=[x])
+ x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 10)).astype(dtype=np.float32)
+ _ = model.predict(x_test, verbose=0)
+
+ onnx_model, _ = from_keras(model)
+ model_path =
f"model_test_qkeras_quantizers_auto_rounding_modes_{request.node.callspec.id}.onnx" + onnx.save(onnx_model, model_path) + onnx_model = ModelWrapper(onnx_model) + + x_test = np.zeros(shape=(1, 10), dtype=np.float32) + idict = {onnx_model.graph.input[0].name: x_test} + odict = oxe.execute_onnx(onnx_model, idict, True) + y_qonnx = odict[onnx_model.graph.output[0].name] + y_qkeras = model.predict(x_test, verbose=0) + assert np.array_equal(y_qkeras, y_qonnx) + os.remove(model_path) # pairs of quantizers for kernel and bias kb_quantizers = [ From 9fd5f6a263fd7b8bb2d11cccdb899f61b4479fb4 Mon Sep 17 00:00:00 2001 From: jvreca Date: Wed, 14 Aug 2024 08:54:12 +0200 Subject: [PATCH 09/23] Added check for none. --- src/qonnx/converters/qkeras/qlayers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py index 0a4f907d..4a62b5d7 100644 --- a/src/qonnx/converters/qkeras/qlayers.py +++ b/src/qonnx/converters/qkeras/qlayers.py @@ -113,10 +113,10 @@ def extract_qlayer(layer): quantizers['bias_quantizer_cfg'] = bias_quant_cfg # For some reason downstream can't handle auto_po2, so we just calculate the scale value now - if kernel_quant_cfg['config']['alpha'] == "auto_po2": + if kernel_quant_cfg is not None and kernel_quant_cfg['config']['alpha'] == "auto_po2": layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2) quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist() - if bias_quant_cfg['config']['alpha'] == "auto_po2": + if bias_quant_cfg is not None and bias_quant_cfg['config']['alpha'] == "auto_po2": layer.bias_quantizer_internal(layer.bias) quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist() quantizers.pop('kernel_quantizer', None) From d8a66a7db408772f052d3386b74083d2d1c66f51 Mon Sep 17 00:00:00 2001 From: jvreca Date: Wed, 14 Aug 2024 09:36:25 +0200 Subject: [PATCH 10/23] Fixed pre-commit issues. 
--- src/qonnx/converters/qkeras/onnx.py | 52 ++++++++--------- src/qonnx/converters/qkeras/qlayers.py | 24 ++++---- src/qonnx/converters/qkeras/quantizers.py | 5 +- src/qonnx/custom_op/general/quant.py | 6 ++ tests/custom_op/test_runding_mode.py | 11 ++-- tests/keras/test_keras_convert.py | 71 +++++++++++++---------- 6 files changed, 95 insertions(+), 74 deletions(-) diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py index 1383bec9..41b022d7 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -4,7 +4,7 @@ from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp from .quantizers import get_quant_params -from qonnx.custom_op.general.quant import quant + def get_qkeras_onnx_handlers(all_quantizers): """Returns the handlers for each kind of layer @@ -53,10 +53,10 @@ def qlayer_handler(ctx, node, name, args): if not keras_name: return # Not found in quantizers, nothing to do quantizers = all_quantizers[keras_name] - + if quantizers.get("kernel_quantizer_cfg"): weights = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(weights, quantizers['kernel_quantizer_cfg']) + quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"]) attr = quant_params["attributes"] input_nodes = [node.input[1]] @@ -68,25 +68,24 @@ def qlayer_handler(ctx, node, name, args): quant_node = ctx.insert_new_node_on_input( node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx" ) - if quantizers['kernel_quantizer_cfg']['class_name'] == 'quantized_bits': - bits = quantizers['kernel_quantizer_cfg']['config']['bits'] - integer = quantizers['kernel_quantizer_cfg']['config']['integer'] - keep_negative = quantizers['kernel_quantizer_cfg']['config']['keep_negative'] + if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits": + bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"] + integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"] + keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"] if bits == integer + keep_negative: scale_node = ctx.make_const( - name = node.name + "_kernel_scale", - np_val = quant_params['inputs']['scale'].astype(np.float32) + name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32) ) ctx.insert_new_node_on_output( - op_type = "Mul", - output_name = quant_node.output[0], - name = node.name + "_kernel_requantizer", - inputs = [quant_node.output[0], scale_node.name] + op_type="Mul", + output_name=quant_node.output[0], + name=node.name + "_kernel_requantizer", + inputs=[quant_node.output[0], scale_node.name], ) if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3: bias = node.inputs[-1].get_tensor_value(as_list=True) - quant_params = get_quant_params(bias, quantizers['bias_quantizer_cfg']) + quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"]) attr = quant_params["attributes"] input_nodes = [node.input[-1]] for key in quant_params["inputs"].keys(): @@ -95,20 +94,19 @@ def qlayer_handler(ctx, node, name, args): ctx.make_const(name, np_val) input_nodes.append(name) ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx") - if quantizers['bias_quantizer_cfg']['class_name'] == 'quantized_bits': - bits = quantizers['bias_quantizer_cfg']['config']['bits'] - integer = quantizers['bias_quantizer_cfg']['config']['integer'] - keep_negative = 
quantizers['bias_quantizer_cfg']['config']['keep_negative'] + if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits": + bits = quantizers["bias_quantizer_cfg"]["config"]["bits"] + integer = quantizers["bias_quantizer_cfg"]["config"]["integer"] + keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"] if bits == integer + keep_negative: scale_node = ctx.make_const( - name = node.name + "_bias_scale", - np_val = quant_params['inputs']['scale'].astype(np.float32) + name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32) ) ctx.insert_new_node_on_output( - op_type = "Mul", - output_name = quant_node.output[0], - name = node.name + "_bias_requantizer", - inputs = [quant_node.output[0], scale_node.name] + op_type="Mul", + output_name=quant_node.output[0], + name=node.name + "_bias_requantizer", + inputs=[quant_node.output[0], scale_node.name], ) if quantizers.get("activation"): @@ -141,9 +139,11 @@ def qact_handler(ctx, node, name, args): quantizers = all_quantizers[keras_name] if quantizers.get("activation"): dtypes = [ctx.get_dtype(node.output[0])] - if "auto" in quantizers['activation']: + if "auto" in quantizers["activation"]: if not node.graph.get_node_by_output(node.input[0]).is_const(): - raise AttributeError(f'Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}.') + raise AttributeError( + f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}." + ) quant_params = get_quant_params(None, quantizers["activation"]) attr = quant_params["attributes"] input_nodes = [node.output[0]] diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py index 4a62b5d7..3bfc7fa7 100644 --- a/src/qonnx/converters/qkeras/qlayers.py +++ b/src/qonnx/converters/qkeras/qlayers.py @@ -101,26 +101,28 @@ def _replace_activation(quant_act): def extract_qlayer(layer): quantizers = layer.get_quantization_config() - + keras_config = layer.get_config() kernel_quant_cfg = keras_config.pop("kernel_quantizer", None) bias_quant_cfg = keras_config.pop("bias_quantizer", None) keras_config.pop("kernel_range", None) keras_config.pop("bias_range", None) - - quantizers['kernel_quantizer_cfg'] = kernel_quant_cfg - quantizers['bias_quantizer_cfg'] = bias_quant_cfg + + quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg + quantizers["bias_quantizer_cfg"] = bias_quant_cfg # For some reason downstream can't handle auto_po2, so we just calculate the scale value now - if kernel_quant_cfg is not None and kernel_quant_cfg['config']['alpha'] == "auto_po2": - layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2) - quantizers['kernel_quantizer_cfg']['config']['alpha'] = layer.kernel_quantizer_internal.scale.numpy().flatten().tolist() - if bias_quant_cfg is not None and bias_quant_cfg['config']['alpha'] == "auto_po2": + if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2": + layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2) + quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = ( + layer.kernel_quantizer_internal.scale.numpy().flatten().tolist() + ) + if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2": layer.bias_quantizer_internal(layer.bias) - quantizers['bias_quantizer_cfg']['config']['alpha'] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist() - quantizers.pop('kernel_quantizer', None) - quantizers.pop('bias_quantizer', None) + 
quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist() + quantizers.pop("kernel_quantizer", None) + quantizers.pop("bias_quantizer", None) # Check if activation is quantized if _is_keras_quantizer(keras_config["activation"]): diff --git a/src/qonnx/converters/qkeras/quantizers.py b/src/qonnx/converters/qkeras/quantizers.py index e38cf710..3d232390 100644 --- a/src/qonnx/converters/qkeras/quantizers.py +++ b/src/qonnx/converters/qkeras/quantizers.py @@ -1,6 +1,6 @@ +import numpy as np import qkeras import six -import numpy as np import tensorflow as tf @@ -26,6 +26,7 @@ def _get_scale_from_alpha(tensor, quantizer): def _get_quantizer_scale(tensor, quantizer): # call the quantizer on the tensor to get its scale import numpy as np + quantizer(np.array(tensor).astype(np.float32)) return quantizer.scale @@ -36,7 +37,7 @@ def convert_quantized_bits(tensor, quantizer): narrow = int(config["symmetric"]) qscale = _get_quantizer_scale(tensor, quantizer) if not isinstance(qscale, (np.ndarray, tf.Tensor)): - qscale = np.array(qscale) + qscale = np.array(qscale) scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed)) zero_point = 0 bit_width = int(config["bits"]) diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 5af3f9f3..b0b50b9a 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -142,18 +142,24 @@ def resolve_rounding_mode(mode_string): elif normalized_mode_string == "FLOOR": return np.floor elif normalized_mode_string == "UP": + def round_up(x): return np.sign(x) * np.ceil(np.abs(x)) + return round_up elif normalized_mode_string == "DOWN": return np.fix elif normalized_mode_string == "HALF_UP": + def round_half_up(x): return np.sign(x) * np.floor(np.abs(x) + 0.5) + return round_half_up elif normalized_mode_string == "HALF_DOWN": + def round_half_down(x): return np.sign(x) * np.ceil(np.abs(x) - 0.5) + return round_half_down else: raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py index 54a81f0e..eb48d644 100644 --- a/tests/custom_op/test_runding_mode.py +++ b/tests/custom_op/test_runding_mode.py @@ -4,15 +4,18 @@ from qonnx.custom_op.general.quant import resolve_rounding_mode -@pytest.mark.parametrize("rmode,exp", [ + +@pytest.mark.parametrize( + "rmode,exp", + [ ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])), - ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])), + ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])), ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])), ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])), ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])), ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])), - ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])) - ] + ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])), + ], ) def test_rounding_modes(rmode, exp): test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]) diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py index eb3e9b2e..964a2e3c 100644 --- a/tests/keras/test_keras_convert.py +++ b/tests/keras/test_keras_convert.py @@ -4,8 +4,6 @@ import onnx import os import tensorflow as tf -tf.keras.utils.set_random_seed(42) -np.random.seed(42) from qkeras import QActivation, QConv2D, QDense, binary, quantized_bits, quantized_relu, 
ternary from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input from tensorflow.keras.models import Model @@ -15,6 +13,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.infer_shapes import InferShapes +# For reproducibility +tf.keras.utils.set_random_seed(42) +np.random.seed(42) + act_quantizers = [ quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 1, alpha=1), @@ -65,16 +67,20 @@ def test_qkeras_qactivation(quantizer, request): np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-5, atol=1e-5) os.remove(model_path) -@pytest.mark.parametrize("quantizer", [ - quantized_relu(bits=4, integer=4), - quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1), - ]) -def test_qkeras_quantizers_rounding_modes(quantizer, request): + +@pytest.mark.parametrize( + "quantizer", + [ + quantized_relu(bits=4, integer=4), + quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1), + ], +) +def test_qkeras_quantizers_rounding_modes(quantizer, request): x = x_in = Input((10,), name="input") x = QActivation(activation=quantizer)(x) model = Model(inputs=[x_in], outputs=[x]) model.compile() - + onnx_model, _ = from_keras(model) model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx" onnx.save(onnx_model, model_path) @@ -88,9 +94,10 @@ def test_qkeras_quantizers_rounding_modes(quantizer, request): assert np.array_equal(y_qkeras, y_qonnx) os.remove(model_path) + @pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]) def test_qkeras_quantizers_autopo2_rounding_modes(bias, request): - kq = bq = quantized_bits(4, 4, 1, alpha='auto_po2') + kq = bq = quantized_bits(4, 4, 1, alpha="auto_po2") # Initialize the kernel & bias to RandonUniform within the range of the quantizers x = x_in = Input((10), name="input") x = QDense( @@ -118,6 +125,7 @@ def test_qkeras_quantizers_autopo2_rounding_modes(bias, request): assert np.array_equal(y_qkeras, y_qonnx) os.remove(model_path) + # pairs of quantizers for kernel and bias kb_quantizers = [ (quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 0, alpha=1)), @@ -377,35 +385,35 @@ def test_qkeras_qdense_4(quantizers, request): np.testing.assert_allclose(y_qkeras, y_qonnx, rtol=1e-4, atol=1e-4) os.remove(model_path) -@pytest.mark.parametrize("bits,signed,alpha",[ - (8, True, [1.000, 1.000, 1.000, 1.000]), - (8, False, [1.000, 1.000, 1.000, 1.000]), - (4, True, [1.000, 1.000, 1.000, 1.000]), - (4, False, [1.000, 1.000, 1.000, 1.000]), - (8, True, [0.125, 0.250, 0.500, 1.000]), - (8, False, [0.125, 0.250, 0.500, 1.000]), - (5, True, [0.250, 0.250, 0.125, 0.125]), - (5, False, [0.250, 0.250, 0.125, 0.125]), - (4, True, [0.125, 0.250, 0.500, 1.000]), - (4, False, [0.125, 0.250, 0.500, 1.000]), - (3, True, [0.125, 0.125, 0.250, 0.125]), - (3, False, [0.125, 0.125, 0.250, 0.125]) -]) + +@pytest.mark.parametrize( + "bits,signed,alpha", + [ + (8, True, [1.000, 1.000, 1.000, 1.000]), + (8, False, [1.000, 1.000, 1.000, 1.000]), + (4, True, [1.000, 1.000, 1.000, 1.000]), + (4, False, [1.000, 1.000, 1.000, 1.000]), + (8, True, [0.125, 0.250, 0.500, 1.000]), + (8, False, [0.125, 0.250, 0.500, 1.000]), + (5, True, [0.250, 0.250, 0.125, 0.125]), + (5, False, [0.250, 0.250, 0.125, 0.125]), + (4, True, [0.125, 0.250, 0.500, 1.000]), + (4, False, [0.125, 0.250, 0.500, 1.000]), + (3, True, [0.125, 0.125, 0.250, 0.125]), + (3, False, [0.125, 0.125, 0.250, 0.125]), + ], +) def test_qkeras_tensor_alpha(bits, signed, alpha, request): random_state = 
np.random.RandomState(seed=42)
- max_val = np.array(alpha) * 2**(bits-signed)
+ max_val = np.array(alpha) * 2 ** (bits - signed)
 min_val = -(max_val + 1)
 w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4))
 b1 = np.array([0.0, 0.0, 0.0, 0.0])
 x = x_in = tf.keras.layers.Input(shape=3)
- x = QActivation(
- quantized_bits(bits=4, integer=3, keep_negative=True)
- )(x)
+ x = QActivation(quantized_bits(bits=4, integer=3, keep_negative=True))(x)
 x = QDense(
 4,
- kernel_quantizer=quantized_bits(
- bits=bits, integer=(bits-signed), keep_negative=signed, alpha=alpha
- ),
+ kernel_quantizer=quantized_bits(bits=bits, integer=(bits - signed), keep_negative=signed, alpha=alpha),
 )(x)
 x = QActivation(quantized_relu(bits=3, integer=3))(x)
 model = tf.keras.Model(inputs=[x_in], outputs=[x])
@@ -437,7 +445,8 @@ def test_qkeras_tensor_alpha(bits, signed, alpha, request):
 y_qonnx = odict[onnx_model.graph.output[0].name]
 assert np.array_equal(y_qkeras, y_qonnx)
 os.remove(model_path)
-
+
+
 @pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
 def test_qkeras_qconv2d_1(quantizers, request):
 kq, bq = quantizers

From cd453206cf3ed1c90026924b24ae7fc2af6e785e Mon Sep 17 00:00:00 2001
From: jvreca
Date: Wed, 14 Aug 2024 10:12:38 +0200
Subject: [PATCH 11/23] Added check if tensor is representable with quant settings.

---
 src/qonnx/converters/qkeras/onnx.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 41b022d7..bbccde8d 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -1,10 +1,15 @@
+import logging
 import numpy as np
 from tf2onnx.late_rewriters import channel_order_rewriters
 from tf2onnx.onnx_opset.math import DirectOp, MatMul
 from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp
+from qonnx.custom_op.general.quant import quant
+
 from .quantizers import get_quant_params

+logger = logging.getLogger(__name__)
+

 def get_qkeras_onnx_handlers(all_quantizers):
 """Returns the handlers for each kind of layer
@@ -47,6 +52,23 @@ def _extract_node_name(onnx_node, keras_quantizers):
 return None

+def check_tensor_is_representable(tensor, quant_params, node):
+ "Gives a warning if tensor is not representable with the provided quantization settings"
+ qtensor = quant(
+ inp_tensor=np.array(tensor),
+ scale=np.array(quant_params["inputs"]["scale"]),
+ zeropt=np.array(quant_params["inputs"]["zero_point"]),
+ bitwidth=np.array(quant_params["inputs"]["bit_width"]),
+ signed=quant_params["attributes"]["signed"],
+ narrow=quant_params["attributes"]["narrow"],
+ rounding_mode=quant_params["attributes"]["rounding_mode"],
+ )
+ if not np.array_equal(tensor, qtensor):
+ logger.warn(
+ f"Tensor of node: {node.name} is not representable with the provided quantization settings: {quant_params}"
+ )
+
+
 def qlayer_handler(ctx, node, name, args):
 all_quantizers = args[0]
 keras_name = _extract_node_name(node, all_quantizers)
@@ -57,9 +79,9 @@ def qlayer_handler(ctx, node, name, args):
 if quantizers.get("kernel_quantizer_cfg"):
 weights = node.inputs[1].get_tensor_value(as_list=True)
 quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
+ check_tensor_is_representable(weights, quant_params, node)
 attr = quant_params["attributes"]
 input_nodes = [node.input[1]]
- for key in quant_params["inputs"].keys():
+ for key in quant_params["inputs"].keys():
 name = f"{node.name}_kernel_quantizer_{key}"
 np_val = np.asarray(quant_params["inputs"][key])
 ctx.make_const(name, np_val)
@@ -86,6 +108,7 @@ def
qlayer_handler(ctx, node, name, args):
 if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
 bias = node.inputs[-1].get_tensor_value(as_list=True)
 quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
+ check_tensor_is_representable(bias, quant_params, node)
 attr = quant_params["attributes"]
 input_nodes = [node.input[-1]]
 for key in quant_params["inputs"].keys():

From 3eddb384c52ab516a3e8ce197be89c73fd5b6f80 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 22 Aug 2024 17:06:23 +0200
Subject: [PATCH 12/23] Added Identity node to the removal list

---
 src/qonnx/transformation/remove.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py
index e745f0f0..4aa56ee2 100644
--- a/src/qonnx/transformation/remove.py
+++ b/src/qonnx/transformation/remove.py
@@ -117,5 +117,9 @@ def apply(self, model):
 remove_node_and_rewire(model, n)
 graph_modified = True
 break
+ elif n.op_type == "Identity":
+ remove_node_and_rewire(model, n)
+ graph_modified = True
+ break
 model = model.transform(InferShapes())
 return (model, graph_modified)

From e95042b8da6911d3b941ad9937dd5136f26945a0 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 22 Aug 2024 17:07:05 +0200
Subject: [PATCH 13/23] Added an input quantization node in the qkeras converter (if applicable)

---
 src/qonnx/converters/keras.py | 51 +++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py
index 5b9e7e09..9332cff0 100644
--- a/src/qonnx/converters/keras.py
+++ b/src/qonnx/converters/keras.py
@@ -1,7 +1,10 @@
+import numpy as np
 import onnx
 import tensorflow as tf
 import tf2onnx
 from collections import OrderedDict
+from qkeras.qlayers import QActivation
+from qkeras.quantizers import quantized_bits, quantized_relu
 from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS

 from qonnx.core.modelwrapper import ModelWrapper
@@ -164,6 +167,41 @@ def _convert_quantizers_to_nodes(onnx_model, quantizers_dict):
 return onnx_model.model

+def _add_input_quantizer(onnx_model, quantizer):
+ "Adds an input quantizer to the onnx_model"
+ iname = onnx_model.graph.input[0].name
+ scale_init_name = f"{iname}_init_scale"
+ zp_init_name = f"{iname}_init_zp"
+ bw_init_name = f"{iname}_init_bw"
+ onnx_model.set_initializer(scale_init_name, np.array(quantizer.scale))
+ onnx_model.set_initializer(zp_init_name, np.array(0.0))
+ onnx_model.set_initializer(bw_init_name, np.array(quantizer.bits))
+ if isinstance(quantizer, quantized_bits):
+ signed = quantizer.keep_negative
+ narrow = quantizer.symmetric
+ rounding_mode = "ROUND"
+ elif isinstance(quantizer, quantized_relu):
+ signed = False
+ narrow = False
+ rounding_mode = "HALF_EVEN"
+ else:
+ raise NotImplementedError
+ quant_node = onnx.helper.make_node(
+ op_type="Quant",
+ inputs=[iname, scale_init_name, zp_init_name, bw_init_name],
+ outputs=[f"{iname}_quantized"],
+ name=f"{iname}_Quant",
+ domain="qonnx.custom_op.general",
+ narrow=narrow,
+ rounding_mode=rounding_mode,
+ signed=signed,
+ )
+ for node in onnx_model.graph.node:
+ if node.input[0] == iname:
+ node.input[0] = quant_node.output[0]
+ onnx_model.graph.node.extend([quant_node])
+
+
 def from_keras(
 model,
 name="qkeras_to_qonnx_converted",
@@ -230,6 +268,19 @@ def from_keras(
 )

 onnx_model = ModelWrapper(model_proto)
+
+ # checks if there is a quantizer at the input and adds it to the proto
+ # This isn't handled in the "qkeras_op_handlers"
+ for submod in model.submodules:
+ if (
+ isinstance(submod, (QActivation, tf.keras.layers.Activation))
+ and model.input.name == submod.input.name
+ and isinstance(submod.submodules[0], (quantized_bits, quantized_relu))
+ ):
+ assert len(submod.submodules) == 1
+ _add_input_quantizer(onnx_model, submod.submodules[0])
+ break
+
 # Set the first value of input/output shape to 1, currently this is set to unknown,
 # because it is technically the batch size
 if not (len(onnx_model.graph.input) == 1 and len(onnx_model.graph.output) == 1):

From 3e793f251f7469449b01b85489acc3492a399305 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Fri, 23 Aug 2024 15:20:36 +0200
Subject: [PATCH 14/23] Added del_initializer to modelwrapper.

---
 src/qonnx/core/modelwrapper.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py
index db9797dc..696cb5f5 100644
--- a/src/qonnx/core/modelwrapper.py
+++ b/src/qonnx/core/modelwrapper.py
@@ -339,6 +339,14 @@ def get_initializer(self, tensor_name, return_dtype=False):
 else:
 return None

+ def del_initializer(self, initializer_name):
+ """Deletes an initializer from the model."""
+ graph = self._model_proto.graph
+ for init in graph.initializer:
+ if init.name == initializer_name:
+ graph.initializer.remove(init)
+ break
+
 def find_producer(self, tensor_name):
 """Finds and returns the node that produces the tensor with given name."""
 for x in self._model_proto.graph.node:

From 7e98eb3287de5c36e463084fc5724a449751c064 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Mon, 26 Aug 2024 09:16:02 +0200
Subject: [PATCH 15/23] Reformatted with pre-commit hooks.

---
 src/qonnx/custom_op/general/quant.py | 6 ++++++
 tests/custom_op/test_runding_mode.py | 11 +++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 15afd048..5cdc1294 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -142,18 +142,24 @@ def resolve_rounding_mode(mode_string):
 elif normalized_mode_string == "FLOOR":
 return np.floor
 elif normalized_mode_string == "UP":
+
 def round_up(x):
 return np.sign(x) * np.ceil(np.abs(x))
+
 return round_up
 elif normalized_mode_string == "DOWN":
 return np.fix
 elif normalized_mode_string == "HALF_UP":
+
 def round_half_up(x):
 return np.sign(x) * np.floor(np.abs(x) + 0.5)
+
 return round_half_up
 elif normalized_mode_string == "HALF_DOWN":
+
 def round_half_down(x):
 return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+
 return round_half_down
 else:
 raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")

diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py
index 54a81f0e..eb48d644 100644
--- a/tests/custom_op/test_runding_mode.py
+++ b/tests/custom_op/test_runding_mode.py
@@ -4,15 +4,18 @@

 from qonnx.custom_op.general.quant import resolve_rounding_mode

-@pytest.mark.parametrize("rmode,exp", [
+
+@pytest.mark.parametrize(
+ "rmode,exp",
+ [
 ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
- ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])),
+ ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
 ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
 ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
 ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
 ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
- ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5]))
- ]
+ ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
+ ],
 )
 def test_rounding_modes(rmode, exp):
     test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])

From fafc4d7f5496fed4129bc820a2815c99cd0105b0 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Mon, 26 Aug 2024 09:18:49 +0200
Subject: [PATCH 16/23] Removed the _TO_ to make it consistent with the others

---
 src/qonnx/custom_op/general/quant.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index 5cdc1294..b0b50b9a 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -135,7 +135,7 @@ def resolve_rounding_mode(mode_string):
     """Resolve the rounding mode string of Quant and Trunc ops to the
     corresponding numpy functions."""
     normalized_mode_string = mode_string.upper()
-    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_TO_EVEN":
+    if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
        return np.round
     elif normalized_mode_string == "CEIL":
         return np.ceil

From 76a791d9f6f4f44045c2849a0adcc228587835cd Mon Sep 17 00:00:00 2001
From: jvreca
Date: Tue, 27 Aug 2024 15:24:36 +0200
Subject: [PATCH 17/23] Removed quantized_relu from the input quantization fn,
 because it is picked up by the normal converter flow. Added an extra
 InferShapes transform

---
 src/qonnx/converters/keras.py | 20 +++++---------------
 src/qonnx/util/cleanup.py     |  1 +
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py
index 9332cff0..b3c9f0b3 100644
--- a/src/qonnx/converters/keras.py
+++ b/src/qonnx/converters/keras.py
@@ -4,7 +4,7 @@
 import tf2onnx
 from collections import OrderedDict
 from qkeras.qlayers import QActivation
-from qkeras.quantizers import quantized_bits, quantized_relu
+from qkeras.quantizers import quantized_bits
 from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS
 
 from qonnx.core.modelwrapper import ModelWrapper
@@ -176,25 +176,15 @@ def _add_input_quantizer(onnx_model, quantizer):
     onnx_model.set_initializer(scale_init_name, np.array(quantizer.scale))
     onnx_model.set_initializer(zp_init_name, np.array(0.0))
     onnx_model.set_initializer(bw_init_name, np.array(quantizer.bits))
-    if isinstance(quantizer, quantized_bits):
-        signed = quantizer.keep_negative
-        narrow = quantizer.symmetric
-        rounding_mode = "ROUND"
-    elif isinstance(quantizer, quantized_relu):
-        signed = False
-        narrow = False
-        rounding_mode = "HALF_EVEN"
-    else:
-        raise NotImplementedError
     quant_node = onnx.helper.make_node(
         op_type="Quant",
         inputs=[iname, scale_init_name, zp_init_name, bw_init_name],
         outputs=[f"{iname}_quantized"],
         name=f"{iname}_Quant",
         domain="qonnx.custom_op.general",
-        narrow=narrow,
-        rounding_mode=rounding_mode,
-        signed=signed,
+        narrow=quantizer.symmetric,
+        rounding_mode="ROUND",
+        signed=quantizer.keep_negative,
     )
     for node in onnx_model.graph.node:
         if node.input[0] == iname:
@@ -275,7 +265,7 @@ def from_keras(
         if (
             isinstance(submod, (QActivation, tf.keras.layers.Activation))
             and model.input.name == submod.input.name
-            and isinstance(submod.submodules[0], (quantized_bits, quantized_relu))
+            and isinstance(submod.submodules[0], quantized_bits)
         ):
             assert len(submod.submodules) == 1
             _add_input_quantizer(onnx_model, submod.submodules[0])
diff --git a/src/qonnx/util/cleanup.py b/src/qonnx/util/cleanup.py
index 933f729d..aadb0047 100644
--- a/src/qonnx/util/cleanup.py
+++ b/src/qonnx/util/cleanup.py
@@ -83,6 +83,7 @@ def cleanup_model(model, preserve_qnt_ops=True, 
override_inpsize=None, extract_c RemoveStaticGraphInputs(), GiveUniqueNodeNames(), GiveReadableTensorNames(), + InferShapes(), ] for t in cleanup_transformations: model = model.transform(t) From ac83d477cf98854b1de0cbf65d89d9ead9240c9b Mon Sep 17 00:00:00 2001 From: jvreca Date: Wed, 28 Aug 2024 15:01:38 +0200 Subject: [PATCH 18/23] Simplified qlayer converter. --- src/qonnx/converters/qkeras/onnx.py | 84 +++++++++++------------------ 1 file changed, 32 insertions(+), 52 deletions(-) diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py index bbccde8d..5cf47d87 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -69,6 +69,36 @@ def check_tensor_is_representable(tensor, quant_params, node): ) +def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind): + weights = node.inputs[input_ind].get_tensor_value(as_list=True) + quant_params = get_quant_params(weights, quantizer_cfg) + check_tensor_is_representable(weights, quant_params, node) + attr = quant_params["attributes"] + input_nodes = [node.input[1]] + for key in quant_params["inputs"].keys(): + name = f"{node.name}_{input_ind}_quantizer_{key}" + np_val = np.asarray(quant_params["inputs"][key]) + ctx.make_const(name, np_val) + input_nodes.append(name) + quant_node = ctx.insert_new_node_on_input( + node, "Quant", input_nodes, name=node.name + f"_{input_ind}_quantizer", **attr, domain="qonnx" + ) + if quantizer_cfg["class_name"] == "quantized_bits": + bits = quantizer_cfg["config"]["bits"] + integer = quantizer_cfg["config"]["integer"] + keep_negative = quantizer_cfg["config"]["keep_negative"] + if bits == integer + keep_negative: + scale_node = ctx.make_const( + name=node.name + f"_{input_ind}_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32) + ) + ctx.insert_new_node_on_output( + op_type="Mul", + output_name=quant_node.output[0], + name=node.name + f"_{input_ind}_requantizer", + inputs=[quant_node.output[0], scale_node.name], + ) + + def qlayer_handler(ctx, node, name, args): all_quantizers = args[0] keras_name = _extract_node_name(node, all_quantizers) @@ -77,60 +107,10 @@ def qlayer_handler(ctx, node, name, args): quantizers = all_quantizers[keras_name] if quantizers.get("kernel_quantizer_cfg"): - weights = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"]) - check_tensor_is_representable(weights, quant_params, node) - attr = quant_params["attributes"] - input_nodes = [node.input[1]] - for key in quant_params["inputs"].keys(): - name = f"{node.name}_kernel_quantizer_{key}" - np_val = np.asarray(quant_params["inputs"][key]) - ctx.make_const(name, np_val) - input_nodes.append(name) - quant_node = ctx.insert_new_node_on_input( - node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx" - ) - if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits": - bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"] - integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"] - keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"] - if bits == integer + keep_negative: - scale_node = ctx.make_const( - name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32) - ) - ctx.insert_new_node_on_output( - op_type="Mul", - output_name=quant_node.output[0], - name=node.name + "_kernel_requantizer", - inputs=[quant_node.output[0], scale_node.name], - ) + _add_quant_node_on_input(ctx, 
node, quantizers["kernel_quantizer_cfg"], 1)
 
     if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
-        bias = node.inputs[-1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
-        check_tensor_is_representable(bias, quant_params, node)
-        attr = quant_params["attributes"]
-        input_nodes = [node.input[-1]]
-        for key in quant_params["inputs"].keys():
-            name = f"{node.name}_bias_quantizer_{key}"
-            np_val = np.asarray(quant_params["inputs"][key])
-            ctx.make_const(name, np_val)
-            input_nodes.append(name)
-        ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
-        if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits":
-            bits = quantizers["bias_quantizer_cfg"]["config"]["bits"]
-            integer = quantizers["bias_quantizer_cfg"]["config"]["integer"]
-            keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"]
-            if bits == integer + keep_negative:
-                scale_node = ctx.make_const(
-                    name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
-                )
-                ctx.insert_new_node_on_output(
-                    op_type="Mul",
-                    output_name=quant_node.output[0],
-                    name=node.name + "_bias_requantizer",
-                    inputs=[quant_node.output[0], scale_node.name],
-                )
+        _add_quant_node_on_input(ctx, node, quantizers["bias_quantizer_cfg"], -1)
 
     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]

From f6f95b287d8401ef7dabeea70420d80b3a4d73bc Mon Sep 17 00:00:00 2001
From: jvreca
Date: Wed, 28 Aug 2024 16:22:10 +0200
Subject: [PATCH 19/23] Fixed bug in _add_quant_node_on_input func.

---
 src/qonnx/converters/qkeras/onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 5cf47d87..8785268a 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -74,7 +74,7 @@ def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind):
     quant_params = get_quant_params(weights, quantizer_cfg)
     check_tensor_is_representable(weights, quant_params, node)
     attr = quant_params["attributes"]
-    input_nodes = [node.input[1]]
+    input_nodes = [node.input[input_ind]]
     for key in quant_params["inputs"].keys():
         name = f"{node.name}_{input_ind}_quantizer_{key}"
         np_val = np.asarray(quant_params["inputs"][key])

From 67d5977576cd219de7635d43ee1cc49133099340 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Wed, 28 Aug 2024 16:46:48 +0200
Subject: [PATCH 20/23] Outputting BipolarQuant node for binary inputs.
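
A 1-bit signed quantizer has no meaningful zero point, bit width or rounding
mode, so QONNX models it with the dedicated BipolarQuant op, which only takes
the input tensor and a scale. As a rough numpy sketch of the intended
semantics (illustrative only, not code from this patch):

    import numpy as np

    def bipolar_quant(x, scale):
        # Every element maps to +scale or -scale; treating x == 0 as
        # positive is an assumption of this sketch.
        return np.where(x >= 0, 1.0, -1.0) * scale

Accordingly, when the resolved quant_params report bit_width == 1 and
signed == 1, the handler now emits a BipolarQuant node with only the first
two inputs (tensor and scale) instead of a full Quant node.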
---
 src/qonnx/converters/qkeras/onnx.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 8785268a..0a680caf 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -80,9 +80,14 @@ def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind):
         np_val = np.asarray(quant_params["inputs"][key])
         ctx.make_const(name, np_val)
         input_nodes.append(name)
-    quant_node = ctx.insert_new_node_on_input(
-        node, "Quant", input_nodes, name=node.name + f"_{input_ind}_quantizer", **attr, domain="qonnx"
-    )
+    if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
+        quant_node = ctx.insert_new_node_on_input(
+            node, "BipolarQuant", input_nodes[0:2], name=node.name + f"_{input_ind}_quantizer", **dict(), domain="qonnx"
+        )
+    else:
+        quant_node = ctx.insert_new_node_on_input(
+            node, "Quant", input_nodes, name=node.name + f"_{input_ind}_quantizer", **attr, domain="qonnx"
+        )
 if quantizer_cfg["class_name"] == "quantized_bits":
     bits = quantizer_cfg["config"]["bits"]
     integer = quantizer_cfg["config"]["integer"]

From bc215c3c793cd5b32ed405160a77c52da10580a8 Mon Sep 17 00:00:00 2001
From: jvreca
Date: Thu, 29 Aug 2024 10:04:37 +0200
Subject: [PATCH 21/23] Binarized networks now seem to work (with
 BipolarQuant). Added a purely binarized network as a test for the qkeras
 converter

---
 src/qonnx/converters/qkeras/onnx.py | 76 ++++++++++++++++-------------
 tests/keras/test_keras_convert.py   | 46 +++++++++++++++++
 2 files changed, 87 insertions(+), 35 deletions(-)

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 0a680caf..ccf0252e 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -82,7 +82,7 @@ def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind):
         input_nodes.append(name)
     if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
         quant_node = ctx.insert_new_node_on_input(
-            node, "BipolarQuant", input_nodes[0:2], name=node.name + f"_{input_ind}_quantizer", **dict(), domain="qonnx"
+            node, "BipolarQuant", input_nodes[0:2], name=node.name + f"_{input_ind}_quantizer", domain="qonnx"
         )
     else:
         quant_node = ctx.insert_new_node_on_input(
@@ -127,14 +127,19 @@ def qlayer_handler(ctx, node, name, args):
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
-        quant_act_node = ctx.make_node(
-            "Quant",
-            input_nodes,
-            dtypes=dtypes,
-            name=node.name + "_activation_quantizer",
-            attr=attr,
-            domain="qonnx",
-        )
+        if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
+            quant_act_node = ctx.make_node(
+                "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx"
+            )
+        else:
+            quant_act_node = ctx.make_node(
+                "Quant",
+                input_nodes,
+                dtypes=dtypes,
+                name=node.name + "_activation_quantizer",
+                attr=attr,
+                domain="qonnx",
+            )
         ctx.insert_node_on_output(quant_act_node, node.output[0])
         ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0]))
 
@@ -160,14 +165,19 @@ def qact_handler(ctx, node, name, args):
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
-        quant_act_node = ctx.make_node(
-            "Quant",
-            input_nodes,
-            dtypes=dtypes,
-            name=node.name + "_activation_quantizer",
-            attr=attr,
-            domain="qonnx",
-        )
+        if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
+            quant_act_node = ctx.make_node(
+                
"BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" + ) + else: + quant_act_node = ctx.make_node( + "Quant", + input_nodes, + dtypes=dtypes, + name=node.name + "_activation_quantizer", + attr=attr, + domain="qonnx", + ) ctx.insert_node_on_output(quant_act_node, node.output[0]) ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0])) channel_order_rewriters._to_channel_first_handler(ctx, quant_act_node) @@ -192,16 +202,7 @@ def bias_handler(ctx, node, name, args): quantizers = all_quantizers[keras_name] if quantizers.get("bias_quantizer_cfg"): - bias = node.inputs[1].get_tensor_value(as_list=True) - quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"]) - attr = quant_params["attributes"] - input_nodes = [node.input[1]] - for key in quant_params["inputs"].keys(): - name = f"{node.name}_bias_quantizer_{key}" - np_val = np.asarray(quant_params["inputs"][key]) - ctx.make_const(name, np_val) - input_nodes.append(name) - ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx") + _add_quant_node_on_input(ctx, node, quantizers["bias_quantizer_cfg"], 1) if quantizers.get("activation"): # removes node if added earlier @@ -217,14 +218,19 @@ def bias_handler(ctx, node, name, args): np_val = np.asarray(quant_params["inputs"][key]) ctx.make_const(name, np_val) input_nodes.append(name) - quant_act_node = ctx.make_node( - "Quant", - input_nodes, - dtypes=dtypes, - name=node.input[0] + "_activation_quantizer", - attr=attr, - domain="qonnx", - ) + if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1: + quant_act_node = ctx.make_node( + "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" + ) + else: + quant_act_node = ctx.make_node( + "Quant", + input_nodes, + dtypes=dtypes, + name=node.input[0] + "_activation_quantizer", + attr=attr, + domain="qonnx", + ) ctx.insert_node_on_output(quant_act_node, node.output[0]) ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0])) diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py index 964a2e3c..ac8828fd 100644 --- a/tests/keras/test_keras_convert.py +++ b/tests/keras/test_keras_convert.py @@ -748,6 +748,52 @@ def test_qkeras_qconv2d_conversion_2(quantizers, request): os.remove(model_path) +def test_qkeras_binarized_model(): + w1 = np.array([[1, -1, -1, 1], [-1, 1, 1, -1], [-1, -1, 1, 1]]) + b1 = np.array([1, 2, 0, 1]) + w2 = np.array([-1, 1, -1, -1]).reshape(4, 1) + b2 = np.array([1]) + + x = x_in = tf.keras.layers.Input(shape=3) + x = QActivation(binary(alpha=1))(x) + x = QDense(4, kernel_quantizer=binary(alpha=1), activation="binary")(x) + x = QDense(1, kernel_quantizer=binary(alpha=1), activation="binary")(x) + model = tf.keras.Model(inputs=[x_in], outputs=[x]) + model.compile() + model.layers[2].set_weights([w1, b1]) + model.layers[3].set_weights([w2, b2]) + data = np.array( + [ + [[-1.0, -1.0, -1.0]], + [[-1.0, -1.0, 1.0]], + [[-1.0, 1.0, -1.0]], + [[-1.0, 1.0, 1.0]], + [[1.0, -1.0, -1.0]], + [[1.0, -1.0, 1.0]], + [[1.0, 1.0, -1.0]], + [[1.0, 1.0, 1.0]], + ] + ).astype(np.float32) + + onnx_model, external_storage = from_keras(model, "test_qkeras_binarized_model", opset=9) + assert external_storage is None + model_path = "model_test_qkeras_binarized_model.onnx" + onnx.save(onnx_model, model_path) + onnx_model = ModelWrapper(model_path) + onnx_model = onnx_model.transform(InferShapes()) + + for x_test in data: + y_qkeras = 
model.predict(x_test) + idict = {onnx_model.graph.input[0].name: x_test} + odict = oxe.execute_onnx(onnx_model, idict, True) + y_qonnx = odict[onnx_model.graph.output[0].name] + assert np.array_equal(y_qkeras, y_qonnx) + + for node in onnx_model.graph.node: + assert node.op_type != "Quant", "A Binarized model must have only BipolarQuant quantizers!" + os.remove(model_path) + + # quantized_relu should not be used as a layer activation # def test_qkeras_broken_1(quantizers, request): # kq, bq = (quantized_bits(4, 4, 0, alpha=1), quantized_bits(8, 8, 0, alpha=1)) From 12a5ee8206bc53b9a86a5f12f11161f7b71b309b Mon Sep 17 00:00:00 2001 From: jvreca Date: Thu, 29 Aug 2024 10:22:51 +0200 Subject: [PATCH 22/23] Further cleaned up converter. --- src/qonnx/converters/qkeras/onnx.py | 109 ++++++++-------------------- 1 file changed, 30 insertions(+), 79 deletions(-) diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py index ccf0252e..27c99ae5 100644 --- a/src/qonnx/converters/qkeras/onnx.py +++ b/src/qonnx/converters/qkeras/onnx.py @@ -1,6 +1,5 @@ import logging import numpy as np -from tf2onnx.late_rewriters import channel_order_rewriters from tf2onnx.onnx_opset.math import DirectOp, MatMul from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp @@ -104,6 +103,33 @@ def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind): ) +def _add_quant_node_on_output(ctx, node, quantizer_cfg, output_ind): + dtypes = [ctx.get_dtype(node.output[output_ind])] + quant_params = get_quant_params(None, quantizer_cfg) + attr = quant_params["attributes"] + input_nodes = [node.output[output_ind]] + for key in quant_params["inputs"].keys(): + name = f"{node.name}_{output_ind}_quantizer_{key}" + np_val = np.asarray(quant_params["inputs"][key]) + ctx.make_const(name, np_val) + input_nodes.append(name) + if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1: + quant_act_node = ctx.make_node( + "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" + ) + else: + quant_act_node = ctx.make_node( + "Quant", + input_nodes, + dtypes=dtypes, + name=node.name + "_activation_quantizer", + attr=attr, + domain="qonnx", + ) + ctx.insert_node_on_output(quant_act_node, node.output[output_ind]) + ctx.set_shape(quant_act_node.output[output_ind], ctx.get_shape(node.output[output_ind])) + + def qlayer_handler(ctx, node, name, args): all_quantizers = args[0] keras_name = _extract_node_name(node, all_quantizers) @@ -118,30 +144,7 @@ def qlayer_handler(ctx, node, name, args): _add_quant_node_on_input(ctx, node, quantizers["bias_quantizer_cfg"], -1) if quantizers.get("activation"): - dtypes = [ctx.get_dtype(node.output[0])] - quant_params = get_quant_params(None, quantizers["activation"]) - attr = quant_params["attributes"] - input_nodes = [node.output[0]] - for key in quant_params["inputs"].keys(): - name = f"{node.name}_activation_quantizer_{key}" - np_val = np.asarray(quant_params["inputs"][key]) - ctx.make_const(name, np_val) - input_nodes.append(name) - if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1: - quant_act_node = ctx.make_node( - "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" - ) - else: - quant_act_node = ctx.make_node( - "Quant", - input_nodes, - dtypes=dtypes, - name=node.name + "_activation_quantizer", - attr=attr, - domain="qonnx", - ) - ctx.insert_node_on_output(quant_act_node, node.output[0]) - ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0])) + 
_add_quant_node_on_output(ctx, node, quantizers["activation"], 0) def qact_handler(ctx, node, name, args): @@ -151,36 +154,7 @@ def qact_handler(ctx, node, name, args): return # Not found in quantizers, nothing to do quantizers = all_quantizers[keras_name] if quantizers.get("activation"): - dtypes = [ctx.get_dtype(node.output[0])] - if "auto" in quantizers["activation"]: - if not node.graph.get_node_by_output(node.input[0]).is_const(): - raise AttributeError( - f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}." - ) - quant_params = get_quant_params(None, quantizers["activation"]) - attr = quant_params["attributes"] - input_nodes = [node.output[0]] - for key in quant_params["inputs"].keys(): - name = f"{node.name}_activation_quantizer_{key}" - np_val = np.asarray(quant_params["inputs"][key]) - ctx.make_const(name, np_val) - input_nodes.append(name) - if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1: - quant_act_node = ctx.make_node( - "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" - ) - else: - quant_act_node = ctx.make_node( - "Quant", - input_nodes, - dtypes=dtypes, - name=node.name + "_activation_quantizer", - attr=attr, - domain="qonnx", - ) - ctx.insert_node_on_output(quant_act_node, node.output[0]) - ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0])) - channel_order_rewriters._to_channel_first_handler(ctx, quant_act_node) + _add_quant_node_on_output(ctx, node, quantizers["activation"], 0) def conv2d_handler(ctx, node, name, args): @@ -209,30 +183,7 @@ def bias_handler(ctx, node, name, args): remove_node_id = node.input[0] remove_node = ctx.get_node_by_output(remove_node_id) ctx.replace_all_inputs(node.input[0], remove_node.input[0], ops=None) - dtypes = [ctx.get_dtype(node.output[0])] - quant_params = get_quant_params(None, quantizers["activation"]) - attr = quant_params["attributes"] - input_nodes = [node.output[0]] - for key in quant_params["inputs"].keys(): - name = f"{node.name}_activation_quantizer_{key}" - np_val = np.asarray(quant_params["inputs"][key]) - ctx.make_const(name, np_val) - input_nodes.append(name) - if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1: - quant_act_node = ctx.make_node( - "BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx" - ) - else: - quant_act_node = ctx.make_node( - "Quant", - input_nodes, - dtypes=dtypes, - name=node.input[0] + "_activation_quantizer", - attr=attr, - domain="qonnx", - ) - ctx.insert_node_on_output(quant_act_node, node.output[0]) - ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0])) + _add_quant_node_on_output(ctx, node, quantizers["activation"], 0) def relu_handler(ctx, node, name, args): From 15b9479c317320806290df3bbbc4202dbf15df54 Mon Sep 17 00:00:00 2001 From: jvreca Date: Thu, 29 Aug 2024 12:52:44 +0200 Subject: [PATCH 23/23] Input quantizer logic now working also for qkeras.binary quantizer. 
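
The submodule scan in from_keras previously only matched a leading
QActivation wrapping quantized_bits; it now also matches qkeras.binary and
emits a BipolarQuant input node (input tensor and scale only) for that case.
A hedged sketch of a model that should now receive an input quantizer (layer
sizes and names are illustrative, not taken from this patch; import paths
assumed from the tree layout in these diffs):

    import tensorflow as tf
    from qkeras import QActivation, QDense, binary
    from qonnx.converters.keras import from_keras

    x = x_in = tf.keras.layers.Input(shape=3)
    x = QActivation(binary(alpha=1))(x)
    x = QDense(4, kernel_quantizer=binary(alpha=1))(x)
    model = tf.keras.Model(inputs=[x_in], outputs=[x])

    onnx_model, _ = from_keras(model, "binary_input_example")
    # Expected: a BipolarQuant node named "<input>_Quant" sits between the
    # graph input and the first compute node.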
--- src/qonnx/converters/keras.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/qonnx/converters/keras.py b/src/qonnx/converters/keras.py index b3c9f0b3..c0dc8c55 100644 --- a/src/qonnx/converters/keras.py +++ b/src/qonnx/converters/keras.py @@ -4,7 +4,7 @@ import tf2onnx from collections import OrderedDict from qkeras.qlayers import QActivation -from qkeras.quantizers import quantized_bits +from qkeras.quantizers import binary, quantized_bits from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS from qonnx.core.modelwrapper import ModelWrapper @@ -176,16 +176,25 @@ def _add_input_quantizer(onnx_model, quantizer): onnx_model.set_initializer(scale_init_name, np.array(quantizer.scale)) onnx_model.set_initializer(zp_init_name, np.array(0.0)) onnx_model.set_initializer(bw_init_name, np.array(quantizer.bits)) - quant_node = onnx.helper.make_node( - op_type="Quant", - inputs=[iname, scale_init_name, zp_init_name, bw_init_name], - outputs=[f"{iname}_quantized"], - name=f"{iname}_Quant", - domain="qonnx.custom_op.general", - narrow=quantizer.symmetric, - rounding_mode="ROUND", - signed=quantizer.keep_negative, - ) + if isinstance(quantizer, quantized_bits): + quant_node = onnx.helper.make_node( + op_type="Quant", + inputs=[iname, scale_init_name, zp_init_name, bw_init_name], + outputs=[f"{iname}_quantized"], + name=f"{iname}_Quant", + domain="qonnx.custom_op.general", + narrow=quantizer.symmetric, + rounding_mode="ROUND", + signed=quantizer.keep_negative, + ) + else: + quant_node = onnx.helper.make_node( + op_type="BipolarQuant", + inputs=[iname, scale_init_name], + outputs=[f"{iname}_quantized"], + name=f"{iname}_Quant", + domain="qonnx.custom_op.general", + ) for node in onnx_model.graph.node: if node.input[0] == iname: node.input[0] = quant_node.output[0] @@ -265,7 +274,7 @@ def from_keras( if ( isinstance(submod, (QActivation, tf.keras.layers.Activation)) and model.input.name == submod.input.name - and isinstance(submod.submodules[0], quantized_bits) + and isinstance(submod.submodules[0], (quantized_bits, binary)) ): assert len(submod.submodules) == 1 _add_input_quantizer(onnx_model, submod.submodules[0])
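
Note: as a quick usage sketch of the rounding modes introduced earlier in
this series (expected values taken from the table added to quant_op.md; the
printed formatting is illustrative):

    import numpy as np
    from qonnx.custom_op.general.quant import resolve_rounding_mode

    round_half_up = resolve_rounding_mode("HALF_UP")
    print(round_half_up(np.array([5.5, 2.5, -1.6])))  # [ 6.  3. -2.]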