diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index 02d115fb..68029406 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
 narrow : int (default is 0)
     Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
 rounding_mode : string (default is "ROUND")
-
-    Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
+    Defines how rounding should be applied during quantization. Available modes are "ROUND", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP" and "HALF_DOWN", where "ROUND" is an alias for "HALF_EVEN" (round half to even). The rounding modes are described in the table below. Mode names are case-insensitive.

 #### Inputs

@@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
+#### Rounding modes
+
+| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
+|----------------------------|-----------------|------|-------|----|------|---------|-----------|
+| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
+| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
+| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
+| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
+| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
+| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
+| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
+| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
+| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
+| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
+
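+For reference, these modes map directly onto numpy; the following sketch (mirroring the
+`resolve_rounding_mode` implementations in `src/qonnx/custom_op/general/quant.py`) reproduces the table:
+
+```python
+import numpy as np
+
+modes = {
+    "HALF_EVEN": np.round,  # "ROUND" is an alias; ties go to the even neighbor
+    "CEIL": np.ceil,  # toward +infinity
+    "FLOOR": np.floor,  # toward -infinity
+    "UP": lambda x: np.sign(x) * np.ceil(np.abs(x)),  # away from zero
+    "DOWN": np.fix,  # toward zero (truncation)
+    "HALF_UP": lambda x: np.sign(x) * np.floor(np.abs(x) + 0.5),  # ties away from zero
+    "HALF_DOWN": lambda x: np.sign(x) * np.ceil(np.abs(x) - 0.5),  # ties toward zero
+}
+
+x = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+for name, fn in modes.items():
+    print(f"{name:>9}:", fn(x).astype(int))  # each line matches a column above
+```
+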
+
 #### Examples
 Quant

diff --git a/src/qonnx/converters/qkeras/onnx.py b/src/qonnx/converters/qkeras/onnx.py
index 1f34d653..41b022d7 100644
--- a/src/qonnx/converters/qkeras/onnx.py
+++ b/src/qonnx/converters/qkeras/onnx.py
@@ -53,31 +53,61 @@ def qlayer_handler(ctx, node, name, args):
     if not keras_name:
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
-    if quantizers.get("kernel_quantizer"):
+
+    if quantizers.get("kernel_quantizer_cfg"):
         weights = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
+        quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
+
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_kernel_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
-        ctx.insert_new_node_on_input(
+        quant_node = ctx.insert_new_node_on_input(
             node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
         )
-
-    if quantizers.get("bias_quantizer") and len(node.input) == 3:
-        bias = node.inputs[2].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
+        if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits":
+            bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"]
+            integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"]
+            keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"]
+            if bits == integer + keep_negative:
+                scale_node = ctx.make_const(
+                    name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
+                )
+                ctx.insert_new_node_on_output(
+                    op_type="Mul",
+                    output_name=quant_node.output[0],
+                    name=node.name + "_kernel_requantizer",
+                    inputs=[quant_node.output[0], scale_node.name],
+                )
+
+    if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
+        bias = node.inputs[-1].get_tensor_value(as_list=True)
+        quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
         attr = quant_params["attributes"]
-        input_nodes = [node.input[2]]
+        input_nodes = [node.input[-1]]
         for key in quant_params["inputs"].keys():
             name = f"{node.name}_bias_quantizer_{key}"
             np_val = np.asarray(quant_params["inputs"][key])
             ctx.make_const(name, np_val)
             input_nodes.append(name)
-        ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
+        quant_node = ctx.insert_new_node_on_input(
+            node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx"
+        )
+        if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits":
+            bits = quantizers["bias_quantizer_cfg"]["config"]["bits"]
+            integer = quantizers["bias_quantizer_cfg"]["config"]["integer"]
+            keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"]
+            if bits == integer + keep_negative:
+                scale_node = ctx.make_const(
+                    name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
+                )
+                ctx.insert_new_node_on_output(
+                    op_type="Mul",
+                    output_name=quant_node.output[0],
+                    name=node.name + "_bias_requantizer",
+                    inputs=[quant_node.output[0], scale_node.name],
+                )

     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]
@@ -109,6 +139,11 @@ def qact_handler(ctx, node, name, args):
     quantizers = all_quantizers[keras_name]
     if quantizers.get("activation"):
         dtypes = [ctx.get_dtype(node.output[0])]
+        if "auto" in quantizers["activation"]:
+            if not node.graph.get_node_by_output(node.input[0]).is_const():
+                raise AttributeError(
+                    f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}."
+                )
         quant_params = get_quant_params(None, quantizers["activation"])
         attr = quant_params["attributes"]
         input_nodes = [node.output[0]]
@@ -148,9 +183,9 @@ def bias_handler(ctx, node, name, args):
         return  # Not found in quantizers, nothing to do
     quantizers = all_quantizers[keras_name]
-    if quantizers.get("bias_quantizer"):
+    if quantizers.get("bias_quantizer_cfg"):
         bias = node.inputs[1].get_tensor_value(as_list=True)
-        quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
+        quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
         attr = quant_params["attributes"]
         input_nodes = [node.input[1]]
         for key in quant_params["inputs"].keys():
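To make the requantizer branch concrete: when `bits == integer + keep_negative` the quantizer has no fractional bits, and the handler follows the inserted `Quant` with a `Mul` that multiplies by the quantizer's scale constant. A sketch of the emitted pattern using `onnx.helper`; the tensor names (`w`, `w_scale`, etc.) are illustrative, not taken from the converter:

```python
from onnx import helper

# Quant consumes (x, scale, zero_point, bit_width) and carries the
# signed / narrow / rounding_mode attributes documented in quant_op.md.
w_quant = helper.make_node(
    "Quant",
    inputs=["w", "w_scale", "w_zeropt", "w_bitwidth"],  # consts made via ctx.make_const
    outputs=["w_q"],
    name="dense_0_kernel_quantizer",
    domain="qonnx",
    signed=1,
    narrow=0,
    rounding_mode="HALF_EVEN",
)
# Emitted only when bits == integer + keep_negative (zero fractional bits).
w_requant = helper.make_node(
    "Mul",
    inputs=["w_q", "dense_0_kernel_scale"],  # scale const created by the handler
    outputs=["w_rq"],
    name="dense_0_kernel_requantizer",
)
```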
diff --git a/src/qonnx/converters/qkeras/qlayers.py b/src/qonnx/converters/qkeras/qlayers.py
index fdbca71b..3bfc7fa7 100644
--- a/src/qonnx/converters/qkeras/qlayers.py
+++ b/src/qonnx/converters/qkeras/qlayers.py
@@ -104,11 +104,26 @@ def extract_qlayer(layer):
     keras_config = layer.get_config()

-    keras_config.pop("kernel_quantizer", None)
-    keras_config.pop("bias_quantizer", None)
+    kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
+    bias_quant_cfg = keras_config.pop("bias_quantizer", None)
     keras_config.pop("kernel_range", None)
     keras_config.pop("bias_range", None)

+    quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg
+    quantizers["bias_quantizer_cfg"] = bias_quant_cfg
+
+    # Downstream handlers cannot evaluate a dynamic auto_po2 alpha, so freeze the concrete scale value now
+    if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2":
+        layer.kernel_quantizer_internal(layer.kernel)  # sets .scale (see auto_po2)
+        quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = (
+            layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
+        )
+    if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2":
+        layer.bias_quantizer_internal(layer.bias)  # sets .scale (see auto_po2)
+        quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
+    quantizers.pop("kernel_quantizer", None)
+    quantizers.pop("bias_quantizer", None)
+
     # Check if activation is quantized
     if _is_keras_quantizer(keras_config["activation"]):
         keras_config["activation"] = _replace_activation(quantizers["activation"])
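The auto_po2 handling relies on a QKeras side effect: calling a quantizer on a tensor populates its `.scale` attribute. A minimal standalone sketch of that freezing step, using a free-standing quantizer instead of the layer-internal one (shapes and values are illustrative):

```python
import numpy as np
import tensorflow as tf
from qkeras import quantized_bits

# Same trick as layer.kernel_quantizer_internal(layer.kernel) above: one call
# on the actual weights computes the power-of-two scale as a side effect.
q = quantized_bits(bits=4, integer=4, keep_negative=True, alpha="auto_po2")
w = tf.constant(np.random.uniform(-4.0, 4.0, size=(3, 2)).astype(np.float32))
_ = q(w)  # side effect: q.scale now holds the concrete power-of-two scale
alpha = np.array(q.scale).flatten().tolist()  # what gets baked into the config
print(alpha)
```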
{"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode}, diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index f552e7a8..b0b50b9a 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -135,12 +135,32 @@ def resolve_rounding_mode(mode_string): """Resolve the rounding mode string of Quant and Trunc ops to the corresponding numpy functions.""" normalized_mode_string = mode_string.upper() - if normalized_mode_string == "ROUND": + if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN": return np.round elif normalized_mode_string == "CEIL": return np.ceil elif normalized_mode_string == "FLOOR": return np.floor + elif normalized_mode_string == "UP": + + def round_up(x): + return np.sign(x) * np.ceil(np.abs(x)) + + return round_up + elif normalized_mode_string == "DOWN": + return np.fix + elif normalized_mode_string == "HALF_UP": + + def round_half_up(x): + return np.sign(x) * np.floor(np.abs(x) + 0.5) + + return round_half_up + elif normalized_mode_string == "HALF_DOWN": + + def round_half_down(x): + return np.sign(x) * np.ceil(np.abs(x) - 0.5) + + return round_half_down else: raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py new file mode 100644 index 00000000..eb48d644 --- /dev/null +++ b/tests/custom_op/test_runding_mode.py @@ -0,0 +1,23 @@ +import pytest + +import numpy as np + +from qonnx.custom_op.general.quant import resolve_rounding_mode + + +@pytest.mark.parametrize( + "rmode,exp", + [ + ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])), + ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])), + ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])), + ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])), + ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])), + ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])), + ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])), + ], +) +def test_rounding_modes(rmode, exp): + test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]) + rounding_fn = resolve_rounding_mode(rmode) + assert np.array_equal(rounding_fn(test_array), exp) diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py index 388f39a4..964a2e3c 100644 --- a/tests/keras/test_keras_convert.py +++ b/tests/keras/test_keras_convert.py @@ -13,6 +13,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.infer_shapes import InferShapes +# For reproducibility +tf.keras.utils.set_random_seed(42) +np.random.seed(42) + act_quantizers = [ quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 1, alpha=1), @@ -64,6 +68,64 @@ def test_qkeras_qactivation(quantizer, request): os.remove(model_path) +@pytest.mark.parametrize( + "quantizer", + [ + quantized_relu(bits=4, integer=4), + quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1), + ], +) +def test_qkeras_quantizers_rounding_modes(quantizer, request): + x = x_in = Input((10,), name="input") + x = QActivation(activation=quantizer)(x) + model = Model(inputs=[x_in], outputs=[x]) + model.compile() + + onnx_model, _ = from_keras(model) + model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx" + onnx.save(onnx_model, model_path) + onnx_model = ModelWrapper(onnx_model) + + x_test = np.array([[5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, 
diff --git a/tests/custom_op/test_rounding_mode.py b/tests/custom_op/test_rounding_mode.py
new file mode 100644
index 00000000..eb48d644
--- /dev/null
+++ b/tests/custom_op/test_rounding_mode.py
@@ -0,0 +1,23 @@
+import pytest
+
+import numpy as np
+
+from qonnx.custom_op.general.quant import resolve_rounding_mode
+
+
+@pytest.mark.parametrize(
+    "rmode,exp",
+    [
+        ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
+        ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
+        ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
+        ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
+        ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
+        ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
+        ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
+    ],
+)
+def test_rounding_modes(rmode, exp):
+    test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+    rounding_fn = resolve_rounding_mode(rmode)
+    assert np.array_equal(rounding_fn(test_array), exp)

diff --git a/tests/keras/test_keras_convert.py b/tests/keras/test_keras_convert.py
index 388f39a4..964a2e3c 100644
--- a/tests/keras/test_keras_convert.py
+++ b/tests/keras/test_keras_convert.py
@@ -13,6 +13,10 @@
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.infer_shapes import InferShapes

+# For reproducibility
+tf.keras.utils.set_random_seed(42)
+np.random.seed(42)
+
 act_quantizers = [
     quantized_bits(8, 4, 0, alpha=1),
     quantized_bits(8, 4, 1, alpha=1),
@@ -64,6 +68,64 @@ def test_qkeras_qactivation(quantizer, request):
     os.remove(model_path)


+@pytest.mark.parametrize(
+    "quantizer",
+    [
+        quantized_relu(bits=4, integer=4),
+        quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
+    ],
+)
+def test_qkeras_quantizers_rounding_modes(quantizer, request):
+    x = x_in = Input((10,), name="input")
+    x = QActivation(activation=quantizer)(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    model.compile()
+
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    x_test = np.array([[5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]]).astype(np.float32)
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+    y_qkeras = model.predict(x_test)
+    assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
+
+
+@pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
+    kq = bq = quantized_bits(4, 4, 1, alpha="auto_po2")
+    # Pin the kernel to 1.0 and the bias to the value under test so each rounding case is deterministic
+    x = x_in = Input((10,), name="input")
+    x = QDense(
+        1,
+        kernel_quantizer=kq,
+        bias_quantizer=bq,
+        kernel_initializer=tf.keras.initializers.Constant(1.0),
+        bias_initializer=tf.keras.initializers.Constant(bias),
+        name="dense_0",
+    )(x)
+    model = Model(inputs=[x_in], outputs=[x])
+    x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 10)).astype(dtype=np.float32)
+    _ = model.predict(x_test, verbose=0)
+
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_quantizers_auto_rounding_modes_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    x_test = np.zeros(shape=(1, 10), dtype=np.float32)
+    idict = {onnx_model.graph.input[0].name: x_test}
+    odict = oxe.execute_onnx(onnx_model, idict, True)
+    y_qonnx = odict[onnx_model.graph.output[0].name]
+    y_qkeras = model.predict(x_test, verbose=0)
+    assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
+
+
 # pairs of quantizers for kernel and bias
 kb_quantizers = [
     (quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 0, alpha=1)),
@@ -324,6 +386,67 @@ def test_qkeras_qdense_4(quantizers, request):
     os.remove(model_path)


+@pytest.mark.parametrize(
+    "bits,signed,alpha",
+    [
+        (8, True, [1.000, 1.000, 1.000, 1.000]),
+        (8, False, [1.000, 1.000, 1.000, 1.000]),
+        (4, True, [1.000, 1.000, 1.000, 1.000]),
+        (4, False, [1.000, 1.000, 1.000, 1.000]),
+        (8, True, [0.125, 0.250, 0.500, 1.000]),
+        (8, False, [0.125, 0.250, 0.500, 1.000]),
+        (5, True, [0.250, 0.250, 0.125, 0.125]),
+        (5, False, [0.250, 0.250, 0.125, 0.125]),
+        (4, True, [0.125, 0.250, 0.500, 1.000]),
+        (4, False, [0.125, 0.250, 0.500, 1.000]),
+        (3, True, [0.125, 0.125, 0.250, 0.125]),
+        (3, False, [0.125, 0.125, 0.250, 0.125]),
+    ],
+)
+def test_qkeras_tensor_alpha(bits, signed, alpha, request):
+    random_state = np.random.RandomState(seed=42)
+    max_val = np.array(alpha) * 2 ** (bits - signed)
+    min_val = -(max_val + 1)
+    w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4))
+    b1 = np.array([0.0, 0.0, 0.0, 0.0])
+    x = x_in = tf.keras.layers.Input(shape=(3,))
+    x = QActivation(quantized_bits(bits=4, integer=3, keep_negative=True))(x)
+    x = QDense(
+        4,
+        kernel_quantizer=quantized_bits(bits=bits, integer=(bits - signed), keep_negative=signed, alpha=alpha),
+    )(x)
+    x = QActivation(quantized_relu(bits=3, integer=3))(x)
+    model = tf.keras.Model(inputs=[x_in], outputs=[x])
+    model.compile()
+    model.layers[2].set_weights([w1, b1])
+    onnx_model, _ = from_keras(model)
+    model_path = f"model_test_qkeras_tensor_alpha_{request.node.callspec.id}.onnx"
+    onnx.save(onnx_model, model_path)
+    onnx_model = ModelWrapper(onnx_model)
+
+    data = np.array(
+        [
+            [[0.0, 0.0, 0.0]],
+            [[0.0, 1.0, 2.0]],
+            [[2.0, 1.0, 0.0]],
+            [[4.0, 4.0, 4.0]],
+            [[7.0, 7.0, 7.0]],
+            [[6.0, 0.0, 7.0]],
+            [[3.0, 3.0, 3.0]],
+            [[7.0, 0.0, 0.0]],
+            [[0.0, 7.0, 0.0]],
+            [[0.0, 0.0, 7.0]],
+        ]
+    ).astype(np.float32)
+    for x in data:
+        y_qkeras = model.predict(x, verbose=0)
+        idict = {onnx_model.graph.input[0].name: x}
+        odict = oxe.execute_onnx(onnx_model, idict, True)
+        y_qonnx = odict[onnx_model.graph.output[0].name]
+        assert np.array_equal(y_qkeras, y_qonnx)
+    os.remove(model_path)
+
+
 @pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
 def test_qkeras_qconv2d_1(quantizers, request):
     kq, bq = quantizers