Add proper support for binarized and binary neural nets in the QKeras converter #141

Open · wants to merge 29 commits into base: main
Commits (29)
dad9869
Added test, docs/, and updated resolve_rounding_mode function to retu…
Mar 22, 2024
47a88e4
Fix table visualization.
Mar 22, 2024
e2c1504
Fix table visualization again.
Mar 22, 2024
baa0df3
Fixed converter to allow alpha/scale to be a tensor
Aug 9, 2024
de9f731
Added a check to see if tensor is representable by the quantization p…
Aug 9, 2024
72b994a
Extra Mul node inserted only when necessary
Aug 12, 2024
75b40ab
Added parameterized test for tensor style alpha.
Aug 12, 2024
7b5bf4a
Updated QKeras converter to support auto_po2
Aug 14, 2024
9fd5f6a
Added check for None.
Aug 14, 2024
d8a66a7
Fixed pre-commit issues.
Aug 14, 2024
cd45320
Added check if tensor is representable with quant settings.
Aug 14, 2024
3eddb38
Added Identity node to the removal list
Aug 22, 2024
e95042b
Added an input quantization node in the qkeras converter (if applicable)
Aug 22, 2024
3e793f2
Added del_initializer to modelwrapper.
Aug 23, 2024
7e98eb3
Reformatted with pre-commit hooks.
Aug 26, 2024
fafc4d7
Removed the _TO_ to make it consistent with others
Aug 26, 2024
608e3c2
Merge branch 'rounding_mode_new' into update_qkeras_converter_for_arr…
Aug 26, 2024
7e710ea
Merge branch 'update_qkeras_converter_for_array_and_auto_po2_scale' i…
Aug 26, 2024
65114d2
Merge branch 'is_tensor_representable' into add_identity_node_to_remo…
Aug 26, 2024
acc0604
Merge branch 'add_identity_node_to_removal_list' into add_input_quant…
Aug 26, 2024
518c9ae
Merge branch 'add_input_quantization_qkeras_converter' into add_del_i…
Aug 26, 2024
76a791d
Removed quantized relu from input quantization fn, because it is pick…
Aug 27, 2024
5b840a9
Merge branch 'add_input_quantization_qkeras_converter' into add_del_i…
Aug 27, 2024
ac83d47
Simplified qlayer converter.
Aug 28, 2024
f6f95b2
Fixed bug in _add_quant_node_on_input func.
Aug 28, 2024
67d5977
Outputting BipolarQuant node for binary inputs.
Aug 28, 2024
bc215c3
Binarized networks now seem to work (with BinaryQuant)
Aug 29, 2024
12a5ee8
Further cleaned up converter.
Aug 29, 2024
15b9479
Input quantizer logic now also works for the qkeras.binary quantizer.
Aug 29, 2024
20 changes: 19 additions & 1 deletion docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dd>Defines how rounding should be applied during quantization. Available modes are "ROUND", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP" and "HALF_DOWN", as described in the table below. Here "ROUND" implies round-half-to-even (HALF_EVEN). Mode names may be given in upper or lower case.</dd>
</dl>

#### Inputs
@@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
</dl>


#### Rounding modes
<details>
<summary>Rounding modes</summary>

| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
|----------------------------|-----------------|------|-------|----|------|---------|-----------|
| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
</details>
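
A minimal NumPy sketch of how a `resolve_rounding_mode` helper can dispatch these modes; it is consistent with the table above, but it is an illustration rather than necessarily the library's exact implementation:

```python
import numpy as np

def resolve_rounding_mode(mode_string):
    """Return a NumPy rounding function for the given mode (case-insensitive)."""
    mode = mode_string.upper()
    if mode in ("ROUND", "HALF_EVEN"):
        return np.round  # round half to even ("banker's rounding")
    elif mode == "CEIL":
        return np.ceil
    elif mode == "FLOOR":
        return np.floor
    elif mode == "UP":
        return lambda x: np.sign(x) * np.ceil(np.abs(x))  # away from zero
    elif mode == "DOWN":
        return np.trunc  # toward zero
    elif mode == "HALF_UP":
        return lambda x: np.sign(x) * np.floor(np.abs(x) + 0.5)  # ties away from zero
    elif mode == "HALF_DOWN":
        return lambda x: np.sign(x) * np.ceil(np.abs(x) - 0.5)  # ties toward zero
    raise ValueError(f"Could not resolve rounding mode: {mode_string}")

assert resolve_rounding_mode("half_up")(np.array([2.5, -2.5])).tolist() == [3.0, -3.0]
```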

#### Examples
<details>
<summary>Quant</summary>
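As a companion to the Examples section, here is a hedged sketch of the Quant op's arithmetic (quantize, clip to the representable integer range, dequantize). `quant_sketch` is an illustrative name; the authoritative implementation is `qonnx.custom_op.general.quant`:

```python
import numpy as np

def quant_sketch(x, scale, zeropt, bitwidth, signed=1, narrow=0):
    """Simplified Quant semantics: quantize, clip, dequantize."""
    if signed:
        q_min = -(2 ** (bitwidth - 1)) + narrow  # narrow drops the most negative value
        q_max = 2 ** (bitwidth - 1) - 1
    else:
        q_min, q_max = 0, 2 ** bitwidth - 1
    q = np.round(x / scale + zeropt)  # default "ROUND" mode: round half to even
    q = np.clip(q, q_min, q_max)
    return (q - zeropt) * scale

# 4-bit signed, scale 0.25: 0.8 / 0.25 = 3.2 -> 3 -> 0.75; 5.0 saturates at 7 -> 1.75
print(quant_sketch(np.array([0.8, 5.0]), 0.25, 0.0, 4))  # [0.75 1.75]
```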
50 changes: 50 additions & 0 deletions src/qonnx/converters/keras.py
@@ -1,7 +1,10 @@
import numpy as np
import onnx
import tensorflow as tf
import tf2onnx
from collections import OrderedDict
from qkeras.qlayers import QActivation
from qkeras.quantizers import binary, quantized_bits
from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS

from qonnx.core.modelwrapper import ModelWrapper
@@ -164,6 +167,40 @@ def _convert_quantizers_to_nodes(onnx_model, quantizers_dict):
return onnx_model.model


def _add_input_quantizer(onnx_model, quantizer):
"Adds an input quantizer to the onnx_model"
iname = onnx_model.graph.input[0].name
scale_init_name = f"{iname}_init_scale"
zp_init_name = f"{iname}_init_zp"
bw_init_name = f"{iname}_init_bw"
onnx_model.set_initializer(scale_init_name, np.array(quantizer.scale))
onnx_model.set_initializer(zp_init_name, np.array(0.0))
onnx_model.set_initializer(bw_init_name, np.array(quantizer.bits))
if isinstance(quantizer, quantized_bits):
quant_node = onnx.helper.make_node(
op_type="Quant",
inputs=[iname, scale_init_name, zp_init_name, bw_init_name],
outputs=[f"{iname}_quantized"],
name=f"{iname}_Quant",
domain="qonnx.custom_op.general",
narrow=quantizer.symmetric,
rounding_mode="ROUND",
signed=quantizer.keep_negative,
)
else:
quant_node = onnx.helper.make_node(
op_type="BipolarQuant",
inputs=[iname, scale_init_name],
outputs=[f"{iname}_quantized"],
name=f"{iname}_Quant",
domain="qonnx.custom_op.general",
)
for node in onnx_model.graph.node:
if node.input[0] == iname:
node.input[0] = quant_node.output[0]
onnx_model.graph.node.extend([quant_node])


def from_keras(
model,
name="qkeras_to_qonnx_converted",
@@ -230,6 +267,19 @@ def from_keras(
)

onnx_model = ModelWrapper(model_proto)

# Check if there is a quantizer at the model input and add it to the proto.
# This isn't handled by the "qkeras_op_handlers".
for submod in model.submodules:
if (
isinstance(submod, (QActivation, tf.keras.layers.Activation))
and model.input.name == submod.input.name
and isinstance(submod.submodules[0], (quantized_bits, binary))
):
assert len(submod.submodules) == 1
_add_input_quantizer(onnx_model, submod.submodules[0])
break

# Set the first value of input/output shape to 1, currently this is set to unknown,
# because it is technically the batch size
if not (len(onnx_model.graph.input) == 1 and len(onnx_model.graph.output) == 1):
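A usage sketch for the input-quantizer path added above (layer shapes and quantizer settings are illustrative, and the return convention is assumed to mirror tf2onnx.convert.from_keras): a model whose first layer is a QActivation should now receive a Quant (or BipolarQuant) node at the graph input.

```python
import tensorflow as tf
from qkeras import QActivation, QDense
from qkeras.quantizers import quantized_bits

from qonnx.converters.keras import from_keras

inp = tf.keras.layers.Input(shape=(16,))
x = QActivation(quantized_bits(bits=8, integer=0, keep_negative=True))(inp)
out = QDense(4, kernel_quantizer=quantized_bits(4, 0, 1))(x)
model = tf.keras.Model(inp, out)

# The input QActivation is detected in from_keras and turned into a Quant node
# feeding every original consumer of the graph input.
onnx_model, _ = from_keras(model, output_path="model_with_input_quant.onnx")
```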
184 changes: 92 additions & 92 deletions src/qonnx/converters/qkeras/onnx.py
@@ -1,10 +1,14 @@
import logging
import numpy as np
from tf2onnx.late_rewriters import channel_order_rewriters
from tf2onnx.onnx_opset.math import DirectOp, MatMul
from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp

from qonnx.custom_op.general.quant import quant

from .quantizers import get_quant_params

logger = logging.getLogger(__name__)


def get_qkeras_onnx_handlers(all_quantizers):
"""Returns the handlers for each kind of layer
@@ -47,48 +51,73 @@ def _extract_node_name(onnx_node, keras_quantizers):
return None


def qlayer_handler(ctx, node, name, args):
all_quantizers = args[0]
keras_name = _extract_node_name(node, all_quantizers)
if not keras_name:
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]
if quantizers.get("kernel_quantizer"):
weights = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
attr = quant_params["attributes"]
input_nodes = [node.input[1]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_kernel_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(
node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
def check_tensor_is_representable(tensor, quant_params, node):
"Gives a Warning iftensor is not representable with the providede quantization settings"
qtensor = quant(
inp_tensor=np.array(tensor),
scale=np.array(quant_params["inputs"]["scale"]),
zeropt=np.array(quant_params["inputs"]["zero_point"]),
bitwidth=np.array(quant_params["inputs"]["bit_width"]),
signed=quant_params["attributes"]["signed"],
narrow=quant_params["attributes"]["narrow"],
rounding_mode=quant_params["attributes"]["rounding_mode"],
)
if not np.array_equal(tensor, qtensor):
logger.warning(
f"Tensor of node: {node.name} is not representable with the provided quantization settings: {quant_params}"
)

if quantizers.get("bias_quantizer") and len(node.input) == 3:
bias = node.inputs[2].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
attr = quant_params["attributes"]
input_nodes = [node.input[2]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_bias_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")

if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
quant_params = get_quant_params(None, quantizers["activation"])
attr = quant_params["attributes"]
input_nodes = [node.output[0]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_activation_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
def _add_quant_node_on_input(ctx, node, quantizer_cfg, input_ind):
weights = node.inputs[input_ind].get_tensor_value(as_list=True)
quant_params = get_quant_params(weights, quantizer_cfg)
check_tensor_is_representable(weights, quant_params, node)
attr = quant_params["attributes"]
input_nodes = [node.input[input_ind]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_{input_ind}_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
quant_node = ctx.insert_new_node_on_input(
node, "BipolarQuant", input_nodes[0:2], name=node.name + f"_{input_ind}_quantizer", domain="qonnx"
)
else:
quant_node = ctx.insert_new_node_on_input(
node, "Quant", input_nodes, name=node.name + f"_{input_ind}_quantizer", **attr, domain="qonnx"
)
if quantizer_cfg["class_name"] == "quantized_bits":
bits = quantizer_cfg["config"]["bits"]
integer = quantizer_cfg["config"]["integer"]
keep_negative = quantizer_cfg["config"]["keep_negative"]
if bits == integer + keep_negative:
scale_node = ctx.make_const(
name=node.name + f"_{input_ind}_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
)
ctx.insert_new_node_on_output(
op_type="Mul",
output_name=quant_node.output[0],
name=node.name + f"_{input_ind}_requantizer",
inputs=[quant_node.output[0], scale_node.name],
)


def _add_quant_node_on_output(ctx, node, quantizer_cfg, output_ind):
dtypes = [ctx.get_dtype(node.output[output_ind])]
quant_params = get_quant_params(None, quantizer_cfg)
attr = quant_params["attributes"]
input_nodes = [node.output[output_ind]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_{output_ind}_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
if quant_params["inputs"]["bit_width"] == 1 and attr["signed"] == 1:
quant_act_node = ctx.make_node(
"BipolarQuant", input_nodes[0:2], name=node.name + "_activation_quantizer", domain="qonnx"
)
else:
quant_act_node = ctx.make_node(
"Quant",
input_nodes,
@@ -97,8 +126,25 @@ def qlayer_handler(ctx, node, name, args):
attr=attr,
domain="qonnx",
)
ctx.insert_node_on_output(quant_act_node, node.output[0])
ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0]))
ctx.insert_node_on_output(quant_act_node, node.output[output_ind])
ctx.set_shape(quant_act_node.output[output_ind], ctx.get_shape(node.output[output_ind]))


def qlayer_handler(ctx, node, name, args):
all_quantizers = args[0]
keras_name = _extract_node_name(node, all_quantizers)
if not keras_name:
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]

if quantizers.get("kernel_quantizer_cfg"):
_add_quant_node_on_input(ctx, node, quantizers["kernel_quantizer_cfg"], 1)

if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
_add_quant_node_on_input(ctx, node, quantizers["bias_quantizer_cfg"], -1)

if quantizers.get("activation"):
_add_quant_node_on_output(ctx, node, quantizers["activation"], 0)


def qact_handler(ctx, node, name, args):
@@ -108,26 +154,7 @@ def qact_handler(ctx, node, name, args):
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]
if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
quant_params = get_quant_params(None, quantizers["activation"])
attr = quant_params["attributes"]
input_nodes = [node.output[0]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_activation_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
quant_act_node = ctx.make_node(
"Quant",
input_nodes,
dtypes=dtypes,
name=node.name + "_activation_quantizer",
attr=attr,
domain="qonnx",
)
ctx.insert_node_on_output(quant_act_node, node.output[0])
ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0]))
channel_order_rewriters._to_channel_first_handler(ctx, quant_act_node)
_add_quant_node_on_output(ctx, node, quantizers["activation"], 0)


def conv2d_handler(ctx, node, name, args):
@@ -148,42 +175,15 @@ def bias_handler(ctx, node, name, args):
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]

if quantizers.get("bias_quantizer"):
bias = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
attr = quant_params["attributes"]
input_nodes = [node.input[1]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_bias_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
if quantizers.get("bias_quantizer_cfg"):
_add_quant_node_on_input(ctx, node, quantizers["bias_quantizer_cfg"], 1)

if quantizers.get("activation"):
# removes node if added earlier
remove_node_id = node.input[0]
remove_node = ctx.get_node_by_output(remove_node_id)
ctx.replace_all_inputs(node.input[0], remove_node.input[0], ops=None)
dtypes = [ctx.get_dtype(node.output[0])]
quant_params = get_quant_params(None, quantizers["activation"])
attr = quant_params["attributes"]
input_nodes = [node.output[0]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_activation_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
quant_act_node = ctx.make_node(
"Quant",
input_nodes,
dtypes=dtypes,
name=node.input[0] + "_activation_quantizer",
attr=attr,
domain="qonnx",
)
ctx.insert_node_on_output(quant_act_node, node.output[0])
ctx.set_shape(quant_act_node.output[0], ctx.get_shape(node.output[0]))
_add_quant_node_on_output(ctx, node, quantizers["activation"], 0)


def relu_handler(ctx, node, name, args):
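For reference, a sketch of the BipolarQuant semantics that the 1-bit signed branch above relies on (an illustration consistent with the qonnx custom op, not the library code itself): a signed 1-bit quantizer has only the two levels ±scale, so it can be expressed with just an input and a scale instead of the full Quant signature.

```python
import numpy as np

def bipolar_quant_sketch(x, scale):
    """BipolarQuant semantics: binarize to {-1, +1}, then apply the scale."""
    return np.where(x >= 0, 1.0, -1.0) * scale

print(bipolar_quant_sketch(np.array([-0.3, 0.0, 0.7]), 0.5))  # [-0.5  0.5  0.5]
```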
19 changes: 17 additions & 2 deletions src/qonnx/converters/qkeras/qlayers.py
@@ -104,11 +104,26 @@ def extract_qlayer(layer):

keras_config = layer.get_config()

keras_config.pop("kernel_quantizer", None)
keras_config.pop("bias_quantizer", None)
kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
bias_quant_cfg = keras_config.pop("bias_quantizer", None)
keras_config.pop("kernel_range", None)
keras_config.pop("bias_range", None)

quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg
quantizers["bias_quantizer_cfg"] = bias_quant_cfg

# Downstream tools can't handle auto_po2, so we calculate the concrete scale values here
if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2":
layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = (
layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
)
if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2":
layer.bias_quantizer_internal(layer.bias)
quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
quantizers.pop("kernel_quantizer", None)
quantizers.pop("bias_quantizer", None)

# Check if activation is quantized
if _is_keras_quantizer(keras_config["activation"]):
keras_config["activation"] = _replace_activation(quantizers["activation"])
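To illustrate what the auto_po2 handling above relies on (the values here are made up; this assumes QKeras's quantized_bits with alpha="auto_po2"): calling the quantizer on the weights has the side effect of computing a power-of-two scale, which extract_qlayer then freezes into the serialized config as a concrete alpha.

```python
import numpy as np
from qkeras.quantizers import quantized_bits

q = quantized_bits(bits=4, integer=0, keep_negative=True, alpha="auto_po2")
w = np.array([[0.9, 0.2], [-0.7, 0.1], [0.4, -0.05]], dtype=np.float32)
q(w)  # side effect: computes q.scale as (per-channel) powers of two
print(np.asarray(q.scale).flatten().tolist())  # e.g. [0.25, 0.0625] (illustrative)
```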
11 changes: 7 additions & 4 deletions src/qonnx/converters/qkeras/quantizers.py
@@ -1,9 +1,11 @@
import numpy as np
import qkeras
import six
import tensorflow as tf


def get_quant_params(tensor, qkeras_quantizer):
if isinstance(qkeras_quantizer, str):
if isinstance(qkeras_quantizer, (str, dict)):
qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer)

return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer)
@@ -34,11 +36,12 @@ def convert_quantized_bits(tensor, quantizer):
signed = int(config["keep_negative"])
narrow = int(config["symmetric"])
qscale = _get_quantizer_scale(tensor, quantizer)
assert qscale == 1, "Non-unity alpha is not yet supported"
scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
if not isinstance(qscale, (np.ndarray, tf.Tensor)):
qscale = np.array(qscale)
scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
zero_point = 0
bit_width = int(config["bits"])
rounding_mode = "ROUND"
rounding_mode = "HALF_EVEN"

settings = {
"attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
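A quick worked check of the new scale expression (a sketch; `qscale` stands in for the quantizer's alpha, which after this change may be a scalar or a per-channel tensor):

```python
import numpy as np

bits, integer, keep_negative = 8, 0, 1
qscale = np.array(1.0)  # unity alpha
scale = qscale / 2 ** (bits - (integer + keep_negative))
assert scale == 1.0 / 128  # 7 fractional bits after the sign bit

# A per-channel alpha broadcasts through the same expression:
qscale = np.array([1.0, 0.5])
print(qscale / 2 ** (bits - (integer + keep_negative)))  # [0.0078125  0.00390625]
```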