Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update qkeras converter for array and auto po2 scale #135

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
20 changes: 19 additions & 1 deletion docs/qonnx-custom-ops/quant_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dd>Defines how rounding should be applied during quantization. Avaiable options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP, HALF_DOWN. The rounding modes are described in the table bellow. The names of rounding modes can be upper case or lower case.</dd>
</dl>

#### Inputs
Expand All @@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
</dl>


#### Rounding modes
<details>
<summary>rounding modes</summary>

| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
|----------------------------|-----------------|------|-------|----|------|---------|-----------|
| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
</details>

#### Examples
<details>
<summary>Quant</summary>
Expand Down
55 changes: 45 additions & 10 deletions src/qonnx/converters/qkeras/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,61 @@ def qlayer_handler(ctx, node, name, args):
if not keras_name:
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]
if quantizers.get("kernel_quantizer"):

if quantizers.get("kernel_quantizer_cfg"):
weights = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
attr = quant_params["attributes"]
input_nodes = [node.input[1]]

for key in quant_params["inputs"].keys():
name = f"{node.name}_kernel_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(
quant_node = ctx.insert_new_node_on_input(
node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
)

if quantizers.get("bias_quantizer") and len(node.input) == 3:
bias = node.inputs[2].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits":
bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"]
integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"]
keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"]
if bits == integer + keep_negative:
scale_node = ctx.make_const(
name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
)
ctx.insert_new_node_on_output(
op_type="Mul",
output_name=quant_node.output[0],
name=node.name + "_kernel_requantizer",
inputs=[quant_node.output[0], scale_node.name],
)

if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
bias = node.inputs[-1].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
attr = quant_params["attributes"]
input_nodes = [node.input[2]]
input_nodes = [node.input[-1]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_bias_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits":
bits = quantizers["bias_quantizer_cfg"]["config"]["bits"]
integer = quantizers["bias_quantizer_cfg"]["config"]["integer"]
keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"]
if bits == integer + keep_negative:
scale_node = ctx.make_const(
name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
)
ctx.insert_new_node_on_output(
op_type="Mul",
output_name=quant_node.output[0],
name=node.name + "_bias_requantizer",
inputs=[quant_node.output[0], scale_node.name],
)

if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
Expand Down Expand Up @@ -109,6 +139,11 @@ def qact_handler(ctx, node, name, args):
quantizers = all_quantizers[keras_name]
if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
if "auto" in quantizers["activation"]:
if not node.graph.get_node_by_output(node.input[0]).is_const():
raise AttributeError(
f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}."
)
quant_params = get_quant_params(None, quantizers["activation"])
attr = quant_params["attributes"]
input_nodes = [node.output[0]]
Expand Down Expand Up @@ -148,9 +183,9 @@ def bias_handler(ctx, node, name, args):
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]

if quantizers.get("bias_quantizer"):
if quantizers.get("bias_quantizer_cfg"):
bias = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
attr = quant_params["attributes"]
input_nodes = [node.input[1]]
for key in quant_params["inputs"].keys():
Expand Down
19 changes: 17 additions & 2 deletions src/qonnx/converters/qkeras/qlayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,26 @@ def extract_qlayer(layer):

keras_config = layer.get_config()

keras_config.pop("kernel_quantizer", None)
keras_config.pop("bias_quantizer", None)
kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
bias_quant_cfg = keras_config.pop("bias_quantizer", None)
keras_config.pop("kernel_range", None)
keras_config.pop("bias_range", None)

quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg
quantizers["bias_quantizer_cfg"] = bias_quant_cfg

# For some reason downstream can't handle auto_po2, so we just calculate the scale value now
if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2":
layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = (
layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
)
if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2":
layer.bias_quantizer_internal(layer.bias)
quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
quantizers.pop("kernel_quantizer", None)
quantizers.pop("bias_quantizer", None)

# Check if activation is quantized
if _is_keras_quantizer(keras_config["activation"]):
keras_config["activation"] = _replace_activation(quantizers["activation"])
Expand Down
11 changes: 7 additions & 4 deletions src/qonnx/converters/qkeras/quantizers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import qkeras
import six
import tensorflow as tf


def get_quant_params(tensor, qkeras_quantizer):
if isinstance(qkeras_quantizer, str):
if isinstance(qkeras_quantizer, (str, dict)):
qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer)

return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer)
Expand Down Expand Up @@ -34,11 +36,12 @@ def convert_quantized_bits(tensor, quantizer):
signed = int(config["keep_negative"])
narrow = int(config["symmetric"])
qscale = _get_quantizer_scale(tensor, quantizer)
assert qscale == 1, "Non-unity alpha is not yet supported"
scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
if not isinstance(qscale, (np.ndarray, tf.Tensor)):
qscale = np.array(qscale)
scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
zero_point = 0
bit_width = int(config["bits"])
rounding_mode = "ROUND"
rounding_mode = "HALF_EVEN"

settings = {
"attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
Expand Down
22 changes: 21 additions & 1 deletion src/qonnx/custom_op/general/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,32 @@ def resolve_rounding_mode(mode_string):
"""Resolve the rounding mode string of Quant and Trunc ops
to the corresponding numpy functions."""
normalized_mode_string = mode_string.upper()
if normalized_mode_string == "ROUND":
if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
return np.round
elif normalized_mode_string == "CEIL":
return np.ceil
elif normalized_mode_string == "FLOOR":
return np.floor
elif normalized_mode_string == "UP":

def round_up(x):
return np.sign(x) * np.ceil(np.abs(x))

return round_up
elif normalized_mode_string == "DOWN":
return np.fix
elif normalized_mode_string == "HALF_UP":

def round_half_up(x):
return np.sign(x) * np.floor(np.abs(x) + 0.5)

return round_half_up
elif normalized_mode_string == "HALF_DOWN":

def round_half_down(x):
return np.sign(x) * np.ceil(np.abs(x) - 0.5)

return round_half_down
else:
raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")

Expand Down
23 changes: 23 additions & 0 deletions tests/custom_op/test_runding_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

import numpy as np

from qonnx.custom_op.general.quant import resolve_rounding_mode


@pytest.mark.parametrize(
"rmode,exp",
[
("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
],
)
def test_rounding_modes(rmode, exp):
test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
rounding_fn = resolve_rounding_mode(rmode)
assert np.array_equal(rounding_fn(test_array), exp)
123 changes: 123 additions & 0 deletions tests/keras/test_keras_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes

# For reproducibility
tf.keras.utils.set_random_seed(42)
np.random.seed(42)

act_quantizers = [
quantized_bits(8, 4, 0, alpha=1),
quantized_bits(8, 4, 1, alpha=1),
Expand Down Expand Up @@ -64,6 +68,64 @@ def test_qkeras_qactivation(quantizer, request):
os.remove(model_path)


@pytest.mark.parametrize(
"quantizer",
[
quantized_relu(bits=4, integer=4),
quantized_bits(bits=4, integer=4, keep_negative=False, alpha=1),
],
)
def test_qkeras_quantizers_rounding_modes(quantizer, request):
x = x_in = Input((10,), name="input")
x = QActivation(activation=quantizer)(x)
model = Model(inputs=[x_in], outputs=[x])
model.compile()

onnx_model, _ = from_keras(model)
model_path = f"model_test_qkeras_quantizers_rounding_modes_{request.node.callspec.id}.onnx"
onnx.save(onnx_model, model_path)
onnx_model = ModelWrapper(onnx_model)

x_test = np.array([[5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5]]).astype(np.float32)
idict = {onnx_model.graph.input[0].name: x_test}
odict = oxe.execute_onnx(onnx_model, idict, True)
y_qonnx = odict[onnx_model.graph.output[0].name]
y_qkeras = model.predict(x_test)
assert np.array_equal(y_qkeras, y_qonnx)
os.remove(model_path)


@pytest.mark.parametrize("bias", [5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
def test_qkeras_quantizers_autopo2_rounding_modes(bias, request):
kq = bq = quantized_bits(4, 4, 1, alpha="auto_po2")
# Initialize the kernel & bias to RandonUniform within the range of the quantizers
x = x_in = Input((10), name="input")
x = QDense(
1,
kernel_quantizer=kq,
bias_quantizer=bq,
kernel_initializer=tf.keras.initializers.Constant(1.0),
bias_initializer=tf.keras.initializers.Constant(bias),
name="dense_0",
)(x)
model = Model(inputs=[x_in], outputs=[x])
x_test = np.random.uniform(low=-1.0, high=1.0, size=(1, 10)).astype(dtype=np.float32)
_ = model.predict(x_test, verbose=0)

onnx_model, _ = from_keras(model)
model_path = f"model_test_qkeras_quantizers_auto_rounding_modes_{request.node.callspec.id}.onnx"
onnx.save(onnx_model, model_path)
onnx_model = ModelWrapper(onnx_model)

x_test = np.zeros(shape=(1, 10), dtype=np.float32)
idict = {onnx_model.graph.input[0].name: x_test}
odict = oxe.execute_onnx(onnx_model, idict, True)
y_qonnx = odict[onnx_model.graph.output[0].name]
y_qkeras = model.predict(x_test, verbose=0)
assert np.array_equal(y_qkeras, y_qonnx)
os.remove(model_path)


# pairs of quantizers for kernel and bias
kb_quantizers = [
(quantized_bits(8, 4, 0, alpha=1), quantized_bits(8, 4, 0, alpha=1)),
Expand Down Expand Up @@ -324,6 +386,67 @@ def test_qkeras_qdense_4(quantizers, request):
os.remove(model_path)


@pytest.mark.parametrize(
"bits,signed,alpha",
[
(8, True, [1.000, 1.000, 1.000, 1.000]),
(8, False, [1.000, 1.000, 1.000, 1.000]),
(4, True, [1.000, 1.000, 1.000, 1.000]),
(4, False, [1.000, 1.000, 1.000, 1.000]),
(8, True, [0.125, 0.250, 0.500, 1.000]),
(8, False, [0.125, 0.250, 0.500, 1.000]),
(5, True, [0.250, 0.250, 0.125, 0.125]),
(5, False, [0.250, 0.250, 0.125, 0.125]),
(4, True, [0.125, 0.250, 0.500, 1.000]),
(4, False, [0.125, 0.250, 0.500, 1.000]),
(3, True, [0.125, 0.125, 0.250, 0.125]),
(3, False, [0.125, 0.125, 0.250, 0.125]),
],
)
def test_qkeras_tensor_alpha(bits, signed, alpha, request):
random_state = np.random.RandomState(seed=42)
max_val = np.array(alpha) * 2 ** (bits - signed)
min_val = -(max_val + 1)
w1 = random_state.randint(low=min_val, high=max_val, size=(3, 4))
b1 = np.array([0.0, 0.0, 0.0, 0.0])
x = x_in = tf.keras.layers.Input(shape=3)
x = QActivation(quantized_bits(bits=4, integer=3, keep_negative=True))(x)
x = QDense(
4,
kernel_quantizer=quantized_bits(bits=bits, integer=(bits - signed), keep_negative=signed, alpha=alpha),
)(x)
x = QActivation(quantized_relu(bits=3, integer=3))(x)
model = tf.keras.Model(inputs=[x_in], outputs=[x])
model.compile()
model.layers[2].set_weights([w1, b1])
onnx_model, _ = from_keras(model)
model_path = f"model_test_qkeras_tensor_alpha_{request.node.callspec.id}.onnx"
onnx.save(onnx_model, model_path)
onnx_model = ModelWrapper(onnx_model)

data = np.array(
[
[[0.0, 0.0, 0.0]],
[[0.0, 1.0, 2.0]],
[[2.0, 1.0, 0.0]],
[[4.0, 4.0, 4.0]],
[[7.0, 7.0, 7.0]],
[[6.0, 0.0, 7.0]],
[[3.0, 3.0, 3.0]],
[[7.0, 0.0, 0.0]],
[[0.0, 7.0, 0.0]],
[[0.0, 0.0, 7.0]],
]
).astype(np.float32)
for x in data:
y_qkeras = model.predict(x, verbose=0)
idict = {onnx_model.graph.input[0].name: x}
odict = oxe.execute_onnx(onnx_model, idict, True)
y_qonnx = odict[onnx_model.graph.output[0].name]
assert np.array_equal(y_qkeras, y_qonnx)
os.remove(model_path)


@pytest.mark.parametrize("quantizers", kb_quantizers, ids=kb_quantizers_ids)
def test_qkeras_qconv2d_1(quantizers, request):
kq, bq = quantizers
Expand Down