diff --git a/docs/advanced/qonnx.rst b/docs/advanced/qonnx.rst
new file mode 100644
index 0000000000..09b0074a0b
--- /dev/null
+++ b/docs/advanced/qonnx.rst
@@ -0,0 +1,56 @@
+==============
+ONNX and QONNX
+==============
+
+Parsing of ONNX and QONNX models is done in conjunction with the `qonnx `_ package, even if no quantization is used. This is a common initial parser shared with the AMD/Xilinx FINN project. The first step is to do constant folding, shape inference, etc., on the ONNX graph, commonly known as `cleaning`. If a model has convolution layers, the model also needs to be converted to a channels-last format, since that is what hls4ml mainly supports. The ``qonnx`` package also provides a number of additional transforms that may need to be used. For example, ``Gemm`` nodes need to be converted to ``MatMul`` and ``Add`` nodes.
+
+There are command-line versions of cleaning and channels-last conversion:
+
+.. code-block:: bash
+
+    $ qonnx_clean filename.onnx
+    $ qonnx_to_channels_last filename_clean.onnx
+    $ qonnx_clean filename_clean_channels_last.onnx  # good to do a clean again as a last step
+
+The same can be done in Python, which is usually easier if you additionally need to call other transforms. An example is given below that also calls the ``GemmToMatMul`` converter:
+
+.. code-block:: python
+
+    model = ModelWrapper('filename.onnx')
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+
+``ModelWrapper`` is defined in ``qonnx.core.modelwrapper``. More information on the ``qonnx`` package can be found at the `QONNX documentation page `_.
+
+
+The next steps are very similar to those when using a Keras model:
+
+.. code-block:: python
+
+    config = hls4ml.utils.config.config_from_onnx_model(
+        model, granularity='name', backend='Vitis', default_precision='fixed<16,6>'
+    )
+    # modify the config as desired
+    hls_model = hls4ml.converters.convert_from_onnx_model(
+        model,
+        output_dir='my-hls-test',
+        io_type='io_stream',
+        backend='Vitis',
+        hls_config=config,
+    )
+    hls_model.compile()
+
+Note that, unlike the Keras version, "name" granularity is the default for ``config_from_onnx_model``, and it must be used for QONNX models. Unquantized ONNX models can use "model" granularity if so desired, but generally there is no benefit.
+
+One can subsequently call the ``predict`` function to check the performance or build the project.
+
+Note that ``execute_onnx`` in ``qonnx.core.onnx_exec`` can be used to run the QONNX graphs directly, and it also provides the values at intermediate layers for validating the model (tracing).
+
+Quant nodes
+===========
+
+Documentation for ``Quant`` nodes is provided in the `qonnx package `_. Note that hls4ml currently only supports the `Quant operator `_. Also, not all legal ``Quant`` configurations are parsable by hls4ml or synthesizable. The ``scale``, ``zeropt``, and ``bitwidth`` values must be constant (though not necessarily scalar for the ``scale`` and ``zeropt``).
+
+Generally, if the ``zeropt`` is 0 and the ``scale`` is a scalar power of 2, hls4ml uses ``ap_fixed`` or ``ac_fixed`` types (depending on the backend) to represent the quantization. In other cases, the ``scale`` and ``zeropt`` need to be explicitly handled by hls4ml, and there is a greater chance that hls4ml will not be able to process the input. (Please report any issues that you find.)
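For reference, below is a minimal sketch of the tracing/validation workflow described above: it uses ``execute_onnx`` to run the cleaned QONNX graph and compares the result against the compiled ``hls_model`` from the earlier snippet. The file name, the random test input, and the single-input/single-output assumption are illustrative only and not prescribed by hls4ml.

.. code-block:: python

    import numpy as np
    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.core.onnx_exec import execute_onnx

    model = ModelWrapper('filename_clean_channels_last.onnx')  # cleaned, channels-last model (name assumed)

    # Build a random test input matching the graph input shape (batch dimension included)
    ishape = model.get_tensor_shape(model.graph.input[0].name)
    X = np.random.rand(*ishape).astype(np.float32)

    # Run the QONNX graph directly; outputs come back in a dict keyed by tensor name.
    # Passing return_full_exec_context=True additionally returns intermediate tensors (tracing).
    y_qonnx = execute_onnx(model, {model.graph.input[0].name: X})[model.graph.output[0].name]

    # Compare with the hls4ml emulation of the converted model
    y_hls = hls_model.predict(X)
    print(np.max(np.abs(y_qonnx - y_hls)))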
diff --git a/docs/index.rst b/docs/index.rst index 07fcd217db..339c4cfd42 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,7 @@ :hidden: :caption: Advanced Features + advanced/qonnx advanced/fifo_depth advanced/extension advanced/oneapi diff --git a/example-models b/example-models index 3cfbcfd062..d40894b03f 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit 3cfbcfd062f60492507d21ff0e91559b3bdd6550 +Subproject commit d40894b03f840a32da43a5adea0531ffc1db216e diff --git a/hls4ml/backends/catapult/passes/pointwise.py b/hls4ml/backends/catapult/passes/pointwise.py index 0141d7f108..fd464ef172 100755 --- a/hls4ml/backends/catapult/passes/pointwise.py +++ b/hls4ml/backends/catapult/passes/pointwise.py @@ -1,5 +1,3 @@ -from copy import copy - from hls4ml.backends.catapult.passes.convolution_templates import ( Conv1DConfigTemplate, Conv1DFunctionTemplate, @@ -75,8 +73,10 @@ def match(self, node): def transform(self, model, node): dim = node.__class__.__name__[-2:] # '1D' or '2D' - pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy()) - pw_node.weights['bias'].data = node.weights['bias'].data + new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')} + pw_node = model.make_node( + 'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy() + ) # Set strategy to ensure lowercase string is passed to the template if model.config.is_resource_strategy(pw_node): pw_node.set_attr('strategy', 'resource') diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 5c85682354..a9fc09b7aa 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -13,6 +13,8 @@ LSTM, Activation, BatchNormalization, + BatchNormOnnx, + Conv, Conv1D, Conv2D, Dense, @@ -22,8 +24,11 @@ GarNetStack, GlobalPooling1D, GlobalPooling2D, + MatMul, + Merge, Pooling1D, Pooling2D, + Quant, SeparableConv1D, SeparableConv2D, SimpleRNN, @@ -63,6 +68,8 @@ def __init__(self, name): LSTM, GRU, Dot, + Conv, + MatMul, ] for layer in accum_layers: @@ -70,7 +77,16 @@ def __init__(self, name): attrs.append(TypeAttribute('accum')) self.attribute_map[layer] = attrs - rf_layers = accum_layers + [BatchNormalization, Activation, Embedding, GarNet, GarNetStack] + rf_layers = accum_layers + [ + BatchNormalization, + Activation, + Embedding, + GarNet, + GarNetStack, + Quant, + BatchNormOnnx, + Merge, + ] for layer in rf_layers: attrs = self.attribute_map.get(layer, []) diff --git a/hls4ml/backends/quartus/passes/pointwise.py b/hls4ml/backends/quartus/passes/pointwise.py index 0f7f6821ae..d65ab22569 100644 --- a/hls4ml/backends/quartus/passes/pointwise.py +++ b/hls4ml/backends/quartus/passes/pointwise.py @@ -1,5 +1,3 @@ -from copy import copy - from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D from hls4ml.backends.quartus.passes.convolution_templates import ( Conv1DConfigTemplate, @@ -81,10 +79,10 @@ def match(self, node): def transform(self, model, node): dim = node.__class__.__name__[-2:] # '1D' or '2D' + new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')} pw_node = model.make_node( - 'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy() + 'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy() ) - pw_node.weights['bias'].data = 
node.weights['bias'].data model.replace_node(node, pw_node) return True diff --git a/hls4ml/backends/vivado/passes/pointwise.py b/hls4ml/backends/vivado/passes/pointwise.py index 85d2635cb8..34568b09f7 100644 --- a/hls4ml/backends/vivado/passes/pointwise.py +++ b/hls4ml/backends/vivado/passes/pointwise.py @@ -1,5 +1,3 @@ -from copy import copy - from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D from hls4ml.backends.vivado.passes.convolution_templates import ( Conv1DConfigTemplate, @@ -75,8 +73,11 @@ def match(self, node): def transform(self, model, node): dim = node.__class__.__name__[-2:] # '1D' or '2D' - pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy()) - pw_node.weights['bias'].data = node.weights['bias'].data + # to remove warning, since these get set again + new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')} + pw_node = model.make_node( + 'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy() + ) # Set strategy to ensure lowercase string is passed to the template if model.config.is_resource_strategy(pw_node): pw_node.set_attr('strategy', 'resource') diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index 092e53b3d3..13e90df687 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -10,8 +10,7 @@ from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401 from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler - -# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401 +from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401 from hls4ml.model import ModelGraph from hls4ml.utils.config import create_config from hls4ml.utils.symbolic_utils import LUTFunction diff --git a/hls4ml/converters/keras/reshape.py b/hls4ml/converters/keras/reshape.py index bd9d519a2a..1f6dc2a759 100644 --- a/hls4ml/converters/keras/reshape.py +++ b/hls4ml/converters/keras/reshape.py @@ -11,8 +11,8 @@ def parse_flatten_layer(keras_layer, input_names, input_shapes, data_reader): layer = parse_default_keras_layer(keras_layer, input_names) layer['class_name'] = 'Reshape' - layer['target_shape'] = [input_shapes[0][0], np.prod(input_shapes[0][1:])] - output_shape = layer['target_shape'] + layer['target_shape'] = [np.prod(input_shapes[0][1:])] # target shape has no batch dimension + output_shape = input_shapes[0][:1] + layer['target_shape'] return layer, output_shape diff --git a/hls4ml/converters/onnx/convolution.py b/hls4ml/converters/onnx/convolution.py index 39b2232169..d84fb855a8 100644 --- a/hls4ml/converters/onnx/convolution.py +++ b/hls4ml/converters/onnx/convolution.py @@ -1,85 +1,77 @@ -from hls4ml.converters.onnx_to_hls import ( - compute_pads_1d, - compute_pads_2d, - get_onnx_attribute, - get_onnx_input_name, - onnx_handler, -) -from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d +import numpy as np + +from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler @onnx_handler('Conv') -def parse_conv_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_conv_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name - layer['data_format'] = 'channels_first' # ONNX's default is channel first - layer['inputs'] = get_onnx_input_name(node, graph) - 
reader.add_input(layer['name'], node.input) + if node.domain != 'qonnx.custom_op.channels_last': + raise RuntimeError("Please convert the model to channels-last format with qonnx-to-channels-last") + layer['data_format'] = 'channels_last' # QONNX needs to be channels-last. + layer['inputs'] = input_names + layer['outputs'] = node.output strides = get_onnx_attribute(node, 'strides') kernel_shape = get_onnx_attribute(node, 'kernel_shape') - - if len(input_shapes[0]) == 3: # Conv1D - layer['class_name'] = 'Conv1D' - - layer['in_width'] = input_shapes[0][2] - layer['n_chan'] = input_shapes[0][1] - layer['filt_width'] = kernel_shape[0] - layer['n_filt'] = reader.get_weights_data(layer['name'], 'kernel').shape[2] - layer['stride_width'] = strides[0] - pads = compute_pads_1d(node, layer) - + # Note: currently don't have support for auto_pad. + pads = get_onnx_attribute(node, 'pads') + dilations = get_onnx_attribute(node, 'dilations') + if dilations is None: + dilations = [1] * len(layer['kernel_shape']) + + layer['in_width'] = input_shapes[0][-2] + layer['n_chan'] = input_shapes[0][-1] + layer['n_filt'] = input_shapes[1][0] + + layer['group'] = int(get_onnx_attribute(node, 'group')) + if layer['group'] != 1: + layer['depth_multiplier'] = get_onnx_attribute(node, 'group') / layer['n_chan'] + if not layer['depth_multiplier'].is_integer(): + raise ValueError('Depth multiplier must be an integer') + else: + layer['depth_multiplier'] = int(layer['depth_multiplier']) + + layer['n_dim'] = len(input_shapes[0]) - 2 # 2 comes from channels and batch dimentions + if layer['n_dim'] not in (1, 2): + raise ValueError("Only 1D and 2D convolutions are supported") + layer['class_name'] = 'Conv' + + # set some values needed later + if layer['n_dim'] == 1: + # this is 1D convolution + full_width = layer['in_width'] + pads[0] + pads[1] + eff_kernel_width = kernel_shape[0] * dilations[0] + layer['out_width'] = int(np.ceil((full_width - eff_kernel_width + 1) / strides[0])) + # for compatibility interpret some variables layer['pad_left'] = pads[0] layer['pad_right'] = pads[1] - - if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding - layer['padding'] = 'valid' - else: - layer['padding'] = 'same' - - (layer['out_width'], _, _) = compute_padding_1d( - layer['padding'], layer['in_width'], layer['stride_width'], layer['filt_width'] - ) - - output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] - - elif len(input_shapes[0]) == 4: # Conv2D - layer['class_name'] = 'Conv2D' - - layer['in_height'] = input_shapes[0][2] - layer['in_width'] = input_shapes[0][3] - layer['n_chan'] = input_shapes[0][1] - + layer['filt_width'] = kernel_shape[0] + layer['stride_width'] = strides[0] + layer['dilation_width'] = dilations[0] + else: + # 2d + layer['in_height'] = input_shapes[0][-3] + full_height = layer['in_height'] + pads[0] + pads[2] + eff_kernel_height = kernel_shape[0] * dilations[0] + out_height = int(np.ceil((full_height - eff_kernel_height + 1) / strides[0])) + layer['out_height'] = out_height + + full_width = input_shapes[0][-2] + pads[1] + pads[3] + eff_kernel_width = kernel_shape[1] * dilations[1] + out_width = int(np.ceil((full_width - eff_kernel_width + 1) / strides[1])) + layer['out_width'] = out_width + # for compatibility interpret some variables + layer['pad_top'] = pads[0] + layer['pad_left'] = pads[1] + layer['pad_bottom'] = pads[2] + layer['pad_right'] = pads[3] layer['filt_height'] = kernel_shape[0] layer['filt_width'] = kernel_shape[1] - - layer['n_filt'] = next( - 
(x.type.tensor_type.shape.dim[1].dim_value for x in graph.value_info if x.name == node.output[0]), None - ) layer['stride_height'] = strides[0] layer['stride_width'] = strides[1] - pads = compute_pads_2d(node, layer) - - layer['pad_top'] = pads[0] - layer['pad_bottom'] = pads[2] - layer['pad_left'] = pads[1] - layer['pad_right'] = pads[3] - - if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding in Keras/Tensorflow - layer['padding'] = 'valid' - else: # Only 'valid' and 'same' padding are available in Keras - layer['padding'] = 'same' - - (layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d( - layer['padding'], - layer['in_height'], - layer['in_width'], - layer['stride_height'], - layer['stride_width'], - layer['filt_height'], - layer['filt_width'], - ) - - output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']] + layer['dilation_height'] = dilations[0] + layer['dilation_width'] = dilations[1] - return layer, output_shape + return layer diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py index 940b860870..8ad851426d 100644 --- a/hls4ml/converters/onnx/core.py +++ b/hls4ml/converters/onnx/core.py @@ -1,28 +1,20 @@ -from hls4ml.converters.onnx_to_hls import get_onnx_attribute, get_onnx_input_name, onnx_handler +import numpy as np +from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler -@onnx_handler(*['Gemm', 'MatMul']) -def parse_gemm_layer(reader, node, inputs_map, input_shapes, graph, config): + +@onnx_handler('MatMul') +def parse_matmul_layer(node, input_names, input_shapes, graph): layer = {} - layer['class_name'] = 'Dense' + layer['class_name'] = 'MatMul' layer['name'] = node.name - layer['inputs'] = get_onnx_input_name(node, graph) - - tran_weight = get_onnx_attribute(node, 'transB', 0) - reader.add_input(layer['name'], node.input, tran_weight) - - weights_shape = reader.get_weights_data(layer['name'], 'kernel').shape - layer['n_in'] = weights_shape[0] - layer['n_out'] = weights_shape[1] - - output_shape = input_shapes[0][:] - output_shape[-1] = layer['n_out'] + layer['inputs'] = input_names + layer['outputs'] = list(node.output) - return layer, output_shape + return layer -# ------------------Global paras for activations # TODO: repair HardSigmoid support # https://github.com/fastmachinelearning/hls4ml/issues/409 activation_layers = [ @@ -37,7 +29,6 @@ def parse_gemm_layer(reader, node, inputs_map, input_shapes, graph, config): 'Softmax', 'Softsign', 'Softplus', - 'Clip', ] activation_map = { @@ -53,70 +44,79 @@ def parse_gemm_layer(reader, node, inputs_map, input_shapes, graph, config): 'Softmax': 'Softmax', 'Softsign': 'Activation', 'Softplus': 'Activation', - 'Clip': 'Clip', } # --------- @onnx_handler(*activation_layers) -def parse_activation_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_activation_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name layer['class_name'] = activation_map[node.op_type] layer['activation'] = node.op_type.lower() - layer['inputs'] = get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) if layer['class_name'] != 'Activation': if layer['class_name'] == 'Softmax': layer['activation'] = 'softmax' + layer['axis'] = get_onnx_attribute(node, 'axis', -1) + # because -1 is better supported than an explicit index, check if it's the same + if layer['axis'] == len(input_shapes[0]) - 1: + layer['axis'] = -1 elif layer['class_name'] in ['ELU', 
'LeakyReLU', 'ThresholdedReLU']: layer['activation'] = layer['class_name'] layer['activ_param'] = get_onnx_attribute(node, 'alpha', 0.01) - elif layer['class_name'] == 'Clip': - clip_min_node = [x for x in graph.initializer if x.name in node.input] - clip_min = clip_min_node[0].float_data[0] - - # Check if it's relu or not - if clip_min == 0.0: - layer['class_name'] = 'Activation' - layer['activation'] = 'ReLU' - else: - raise Exception('Clip with min != 0 is not supported yet!') - else: layer['activation'] = layer['class_name'] layer['class_name'] = 'Activation' - return layer, [shape for shape in input_shapes[0]] + return layer @onnx_handler('BatchNormalization') -def parse_batchnorm_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_batchnorm_layer(node, input_names, input_shapes, graph): layer = {} - layer['class_name'] = 'BatchNormalization' - layer['data_format'] = 'channels_first' + layer['class_name'] = 'BatchNormOnnx' layer['name'] = node.name - layer['inputs'] = get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) # Other attributes - layer['epsilon'] = get_onnx_attribute(node, 'epsilon') - layer['momentum'] = get_onnx_attribute(node, 'momentum') - - reader.add_input(layer['name'], node.input) - - in_size = 1 - for dim in input_shapes[0][1:]: - in_size *= dim + layer['epsilon'] = get_onnx_attribute(node, 'epsilon', 1e-05) + # layer['momentum'] = get_onnx_attribute(node, 'momentum', 0.9) # not used - layer['n_in'] = layer['n_out'] = in_size + layer['n_in'] = layer['n_out'] = np.prod(input_shapes[0][1:]) if len(input_shapes[0]) == 2: layer['n_filt'] = -1 elif len(input_shapes[0]) > 2: - layer['n_filt'] = input_shapes[0][1] # Always channel first for onnx + if node.domain != 'qonnx.custom_op.channels_last': + raise RuntimeError("Please convert the model to channels-last format with qonnx-to-channels-last") + layer['data_format'] = 'channels_last' # QONNX needs to be channels-last. 
+ layer['n_filt'] = input_shapes[0][-1] + else: + raise RuntimeError(f"Unexpected input shape: {input_shapes[0]}") + + return layer + + +@onnx_handler('Quant') +def parse_quant_layer(node, input_names, input_shapes, graph): + layer = {} + + layer['class_name'] = 'Quant' + layer['name'] = node.name + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + + # Other attributes + layer['narrow'] = bool(get_onnx_attribute(node, 'narrow')) + layer['rounding_mode'] = get_onnx_attribute(node, 'rounding_mode') + layer['signed'] = bool(get_onnx_attribute(node, 'signed')) - return layer, [shape for shape in input_shapes[0]] + return layer diff --git a/hls4ml/converters/onnx/merge.py b/hls4ml/converters/onnx/merge.py index 9ccd432d18..420f077ec2 100644 --- a/hls4ml/converters/onnx/merge.py +++ b/hls4ml/converters/onnx/merge.py @@ -1,16 +1,28 @@ -from hls4ml.converters.onnx_to_hls import get_onnx_attribute, get_onnx_input_name, onnx_handler +from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler -merge_layers = ['Add', 'Sub', 'Mul', 'Average', 'Max', 'Min', 'Concat', 'Sum'] +merge_layers = ['Add', 'Sub', 'Mul', 'Div', 'Average', 'Max', 'Min', 'Concat', 'Sum'] + +op_map = { + 'Add': 'add', + 'Sub': 'subtract', + 'Mul': 'multiply', + 'Div': 'divide', + 'Average': 'average', + 'Max': 'maximum', + 'Min': 'minimum', + 'Sum': 'add', + 'Concat': 'concat', +} @onnx_handler(*merge_layers) -def parse_merge_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_merge_layer(node, input_names, input_shapes, graph): layer = {} layer['class_name'] = node.op_type layer['name'] = node.name - layer['op'] = layer['class_name'].lower() - layer['inputs'] = get_onnx_input_name(node, graph) - output_shape = input_shapes[0] + layer['op'] = op_map[node.op_type] + layer['inputs'] = input_names + layer['outputs'] = list(node.output) if layer['class_name'] == 'Concat': rank = len(input_shapes[0][1:]) @@ -21,22 +33,10 @@ def parse_merge_layer(reader, node, inputs_map, input_shapes, graph, config): layer['op'] = layer['class_name'].lower() + f'{rank}d' layer['axis'] = get_onnx_attribute(node, 'axis') - # Calculate output shape - new_dim = sum( - [x.type.tensor_type.shape.dim[layer['axis']].dim_value for x in graph.value_info if x.name in node.input] - ) - output_shape[layer['axis']] = new_dim - - elif layer['class_name'] == 'Add': - # Check if the layer is an AddBias - for input in node.input: - if "bias" in input: - layer['class_name'] = 'BiasAdd' - reader.add_input(layer['name'], node.input) else: layer['class_name'] = 'Merge' if len(layer['inputs']) > 2: raise Exception('ERROR: Merging more than two tensors is not yet supported.') - return layer, output_shape + return layer diff --git a/hls4ml/converters/onnx/pooling.py b/hls4ml/converters/onnx/pooling.py index 67fa76c7c7..1f5c431004 100644 --- a/hls4ml/converters/onnx/pooling.py +++ b/hls4ml/converters/onnx/pooling.py @@ -1,26 +1,30 @@ -from hls4ml.converters.onnx_to_hls import ( - compute_pads_1d, - compute_pads_2d, - get_onnx_attribute, - get_onnx_input_name, - onnx_handler, -) -from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d +import numpy as np + +from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler pool_operations = ['AveragePool', 'MaxPool'] @onnx_handler(*pool_operations) -def parse_pool_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_pool_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name - layer['inputs'] = 
get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + if node.domain != 'qonnx.custom_op.channels_last': + raise RuntimeError("Please convert the model to channels-last format with qonnx-to-channels-last") layer['class_name'] = node.op_type - layer['data_format'] = 'channels_first' # Default ONNX + layer['data_format'] = 'channels_last' # Default QONNX info = layer['class_name'].replace('Pool', '') strides = get_onnx_attribute(node, 'strides') kernel_shape = get_onnx_attribute(node, 'kernel_shape') + pads = get_onnx_attribute(node, 'pads') + layer['pads'] = pads + dilations = get_onnx_attribute(node, 'dilations') + if dilations is None: + dilations = [1] * len(kernel_shape) + layer['dilations'] = dilations if len(input_shapes[0]) == 3: # 1D layer['class_name'] = info + 'Pooling1D' @@ -31,70 +35,50 @@ def parse_pool_layer(reader, node, inputs_map, input_shapes, graph, config): layer['pool_width'] = kernel_shape[0] layer['stride_width'] = strides[0] - # Padding - pads = compute_pads_1d(node, layer) - layer['pad_left'] = pads[0] - layer['pad_right'] = pads[1] - - if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding - layer['padding'] = 'valid' - else: - layer['padding'] = 'same' - - (layer['n_out'], _, _) = compute_padding_1d( - layer['padding'], layer['n_in'], layer['stride_width'], layer['pool_width'] + # formula from ONNX Operators.md documentation + layer['n_out'] = int( + np.floor((layer['n_in'] + np.sum(pads) - ((kernel_shape[0] - 1) * dilations[0] + 1)) / strides[0] + 1) ) - output_shape = [input_shapes[0][0], layer['n_filt'], layer['n_out']] - elif len(input_shapes[0]) == 4: # 2D layer['class_name'] = info + 'Pooling2D' - layer['n_filt'] = input_shapes[0][1] - layer['in_height'] = input_shapes[0][2] - layer['in_width'] = input_shapes[0][3] + layer['n_filt'] = input_shapes[0][3] + layer['in_height'] = input_shapes[0][1] + layer['in_width'] = input_shapes[0][2] layer['stride_height'] = strides[0] layer['stride_width'] = strides[1] layer['pool_height'] = layer['filt_height'] = kernel_shape[0] layer['pool_width'] = layer['filt_width'] = kernel_shape[1] - pads = compute_pads_2d(node, layer) layer['pad_top'] = pads[0] layer['pad_bottom'] = pads[2] layer['pad_left'] = pads[1] layer['pad_right'] = pads[3] - if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding in Keras/Tensorflow - layer['padding'] = 'valid' - else: # Only 'valid' and 'same' padding are available in Keras - layer['padding'] = 'same' - - (layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d( - layer['padding'], - layer['in_height'], - layer['in_width'], - layer['stride_height'], - layer['stride_width'], - layer['filt_height'], - layer['filt_width'], + # formula from ONNX Operators.md documentation + layer['out_height'] = int( + np.floor((layer['in_height'] + pads[0] + pads[2] - ((kernel_shape[0] - 1) * dilations[0] + 1)) / strides[0] + 1) + ) + layer['out_width'] = int( + np.floor((layer['in_width'] + pads[1] + pads[3] - ((kernel_shape[1] - 1) * dilations[1] + 1)) / strides[1] + 1) ) - output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']] - - return layer, output_shape + return layer global_pooling_layers = ['GlobalMaxPool', 'GlobalAveragePool'] @onnx_handler(*global_pooling_layers) -def parse_global_pooling_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_global_pooling_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name - 
layer['inputs'] = get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) layer['class_name'] = node.op_type - layer['data_format'] = 'channels_first' + layer['data_format'] = 'channels_last' # default QONNX # Sonme default parameters for global pooling layer['n_out'] = 1 @@ -116,6 +100,4 @@ def parse_global_pooling_layer(reader, node, inputs_map, input_shapes, graph, co layer['in_height'] = input_shapes[0][2] layer['in_width'] = input_shapes[0][3] - output_shape = [input_shapes[0][0], layer['n_filt']] + [1] * (len(input_shapes[0]) - 2) - - return layer, output_shape + return layer diff --git a/hls4ml/converters/onnx/reshape.py b/hls4ml/converters/onnx/reshape.py index 5bbf58b079..9ef20f03d7 100644 --- a/hls4ml/converters/onnx/reshape.py +++ b/hls4ml/converters/onnx/reshape.py @@ -1,39 +1,38 @@ -import numpy as np - -from hls4ml.converters.onnx_to_hls import get_onnx_input_name, onnx_handler +from hls4ml.converters.onnx_to_hls import onnx_handler @onnx_handler('Transpose') -def parse_transpose_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_transpose_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name layer['class_name'] = 'Transpose' - layer['inputs'] = get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) perm = [list(i.ints) for i in node.attribute][0] # This will get something like [[a,b,c]][0] = [a,b,c] layer['perm'] = [x - 1 for x in perm[1:]] # Ignore the batch dimension in ONNX, and adjust the perm indexing - output_shape = [input_shapes[0][i] for i in perm] - - return layer, output_shape + return layer @onnx_handler('Reshape') -def parse_reshape_layer(reader, node, inputs_map, input_shapes, graph, config): +def parse_reshape_layer(node, input_names, input_shapes, graph): layer = {} layer['name'] = node.name layer['class_name'] = 'Reshape' - layer['inputs'] = get_onnx_input_name(node, graph) + layer['inputs'] = input_names + layer['outputs'] = list(node.output) - target_shape = list([x for x in graph.initializer if x.name == node.input[1]][0].int64_data)[1:] + return layer - if -1 in target_shape: # Need to infer shape for -1 - print("WARNING: Inferring -1 shape ... ") - dummy_x = np.ones(input_shapes[0][1:]) - dummy_y = np.reshape(dummy_x, target_shape) - target_shape = list(dummy_y.shape) - layer['target_shape'] = target_shape - output_shape = input_shapes[0][:1] + layer['target_shape'] +@onnx_handler('Flatten') +def parse_flatten_layer(node, input_names, input_shapes, graph): + layer = {} + layer['name'] = node.name + layer['class_name'] = 'Reshape' + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + layer['target_shape'] = [-1] # does not contain batch dimension - return layer, output_shape + return layer diff --git a/hls4ml/converters/onnx_to_hls.py b/hls4ml/converters/onnx_to_hls.py index 106daf62da..75850fa93e 100644 --- a/hls4ml/converters/onnx_to_hls.py +++ b/hls4ml/converters/onnx_to_hls.py @@ -1,78 +1,10 @@ -import numpy as np import onnx -from onnx import helper, numpy_helper, shape_inference +from onnx import helper, numpy_helper from hls4ml.model import ModelGraph -MAXMULT = 4096 - -class ONNXDataReader: - """ - ONNX data reader to be used for extracting relevant information during conversion. 
- """ - - def __init__(self, model): - self.model = model - self.input_map = {} - self.index_map = { - # Dense - 'kernel': 1, - 'bias': 2, - # BatchNormalization - 'gamma': 1, - 'beta': 2, - 'moving_mean': 3, - 'moving_variance': 4, - } - - def get_weights_data(self, layer_name, var_name): - """Extract weights data from ONNX model. - - Args: - layer_name (str): Layer's name in the ONNX model. - var_name (str): Variable to be extracted. - - Returns: - ndarray: Extracted weights data. - """ - # Get the node associated with the layer name - node = next(node for node in self.model.graph.node if node.name == layer_name) - - inputs = self.input_map[layer_name] - inp_idx = self.index_map[var_name] - - if inp_idx >= len(inputs['inputs']): - # Check if the layer is an AddBias layer - if (node.op_type == 'Add') and (var_name == 'bias'): - inp_idx = 1 - else: - # Input not found, likely a bias tensor is not available - return None - - tensor = next((x for x in self.model.graph.initializer if x.name == inputs['inputs'][inp_idx]), None) - - if tensor is not None: - data = numpy_helper.to_array(tensor) - - if inputs['transpose']: - if inputs['perm'] is not None and len(data.shape) == len(inputs['perm']): - data = data.transpose(inputs['perm']) - else: - data = data.transpose() - - # Check for transB in Gemm - if node.op_type == 'Gemm': - if not get_onnx_attribute(node, 'transB'): - data = data.transpose() - - return data - - def add_input(self, layer_name, inputs, transpose=True, perm=None): - self.input_map[layer_name] = {'inputs': inputs, 'transpose': transpose, 'perm': perm} - - -# ----------------------Helpers--------------------- # +# ----------------------Helpers--------------------- def sanitize_layer_name(layer): new_name = layer['name'] if new_name[0].isdigit(): @@ -99,9 +31,52 @@ def get_onnx_attribute(operation, name, default=None): return value -def get_input_shape(model, operation, input_idx=0): - value_info_idx = next((i for i, x in enumerate(model.graph.value_info) if x.name == operation.input[input_idx]), 0) - return [d.dim_value for d in model.graph.value_info[value_info_idx].type.tensor_type.shape.dim] +def get_global_input_shape(graph, inp): + """Return the global input shape of the graph with name inp + + Arguments: + graph: the onnx graph + inp (str): the global input name + + Returns: + list: The shape + + Raises: + StopIteration: If the global input name is not found + """ + inp_shape = next(x.type.tensor_type.shape.dim for x in graph.input if x.name == inp) + return list(x.dim_value for x in inp_shape) + + +def get_input_shape(graph, node): + """Return the input shapes of the node in the model + + Arguments: + graph: the onnx graph + node: the onnx node for which the input is desired + + Returns: + list of lists: The shapes of all the inputs + + Raises: + StopIteration: If the an input name is not found in the graph + """ + rv = [] + for inp in node.input: + try: + value_info_idx = next((i for i, x in enumerate(graph.value_info) if x.name == inp)) + dim = list(d.dim_value for d in graph.value_info[value_info_idx].type.tensor_type.shape.dim) + except StopIteration: + # The input is not in the graph, likely it's the input + dim = get_global_input_shape(graph, inp) + if dim: + rv.append(dim) + return rv + + +def get_constant_value(graph, constant_name): + tensor = next((x for x in graph.initializer if x.name == constant_name), None) + return numpy_helper.to_array(tensor) def compute_pads_1d(operation, layer): @@ -155,7 +130,7 @@ def compute_pads_2d(operation, layer): return pads -# 
----------------------Layer handling--------------------- # +# ----------------------Layer handling--------------------- layer_handlers = {} @@ -178,27 +153,6 @@ def decorator(function): return decorator -# --->> A set of functions to address the naming convetion in ONNx's graph -def get_onnx_input_name(node, graph): - """ - In ONNX, when calling node.input, it returns the node input's index in the graph instead of the input's name. - However, the input's name is used for indexing in ModelGraph's graph. This function return the input node's name instead. - """ - - in_node = [in_node for in_node in graph.node if (in_node.output[0] in node.input)] - - if in_node: - if in_node[0].op_type != 'Flatten': - input_node_name = [x.name for x in in_node] - else: # IF it's a flatten - input_node_name = [x.name for x in graph.node if (x.output[0] in in_node[0].input)] - - return input_node_name - - else: # If there is no input name it's actually the first layer - return [replace_char_inconsitency(node.input[0])] - - def get_out_layer_name(graph): """ Get the output layer's name for the model. @@ -208,36 +162,31 @@ def get_out_layer_name(graph): return [node.name for node in graph.node if node.output[0] in output_index_list] -def onnx_to_hls(config): - """Convert onnx model to hls model from configuration. +def parse_onnx_model(onnx_model): + """Parses the onnx model, both for configuration building and general processing. Args: - config (dict): ONNX configuration from yaml file or passed through API. + onnx_model: an ONNX model object. Raises: Exception: Raised if an unsupported operation is found in the ONNX model. Returns: - ModelGraph: hls4ml model object + layer_list (list): The onnx layers + input_layers (list): The input layers + output_layers (list): The output layers """ # This is a list of dictionaries to hold all the layer info we need to generate HLS layer_list = [] - # Extract model architecture - print('Interpreting Model ...') - - model = onnx.load(config['OnnxModel']) if isinstance(config['OnnxModel'], str) else config['OnnxModel'] - - model = shape_inference.infer_shapes(model) - graph = model.graph - - reader = ONNXDataReader(model) + # We don't infer the shapes because the qonnx package preprocessing does it. # Obtain list of input/ouput layers - all_inputs = [x.name for x in model.graph.input] - all_initializers = [x.name for x in model.graph.initializer] + all_inputs = [x.name for x in onnx_model.graph.input] + all_initializers = [x.name for x in onnx_model.graph.initializer] input_layers = [x for x in all_inputs if x not in all_initializers] - output_layers = get_out_layer_name(graph) + constant_layers = all_initializers # no need to copy it even though we change it + output_layers = get_out_layer_name(onnx_model.graph) print("Output layers: ", output_layers) @@ -245,75 +194,93 @@ def onnx_to_hls(config): input_layer = {} input_layer['name'] = replace_char_inconsitency(inp) input_layer['class_name'] = 'InputLayer' - inp_shape = next((x.type.tensor_type.shape.dim for x in model.graph.input if x.name == inp), None) - input_layer['input_shape'] = [x.dim_value for x in inp_shape] - - if len(input_layer['input_shape']) > 1: - input_layer['input_shape'][0] = None # Firt dim is batch + inp_shape = get_global_input_shape(onnx_model.graph, inp) + # We only support ONNX where the first dimension is the batch dimension. 
+ # Remove the batch dimension in all subsequnt use + input_layer['input_shape'] = inp_shape[1:] + print('Input shape:', input_layer['input_shape']) # Clean the layer name for specific models sanitize_layer_name(input_layer) input_layers[i] = input_layer['name'] layer_list.append(input_layer) + for i, constant in enumerate(constant_layers): + constant_layer = {} + constant_layer['name'] = replace_char_inconsitency(constant) + constant_layer['class_name'] = 'Constant' + constant_layer['value'] = get_constant_value(onnx_model.graph, constant) + + # Clean the layer name for specific models + sanitize_layer_name(constant_layer) + constant_layers[i] = constant_layer['name'] + + layer_list.append(constant_layer) + # Defined supported layers and check for unsupported layer type - skip_layers = ['Dropout', 'Identity', 'Flatten'] + skip_layers = ['Dropout', 'Identity'] # Map inputs of skipped layers inputs_map = {} supported_layers = get_supported_onnx_layers() + skip_layers - # Get input shape - current_shape = [input_layer['input_shape']] - print('Input shape:', current_shape[0]) - - # Loop through layers - layer_counter = 0 - - # Output shape tracking - output_shape = None - print('Topology:') - for node in graph.node: + for node in onnx_model.graph.node: if node.op_type not in supported_layers: raise Exception(f'ERROR: Unsupported operation type: {node.op_type}') - # If not the first layer then input shape is taken from last layer's output - if layer_counter != 0: - current_shape = [output_shape] + # Note that at this point, input shape still contains batch dimension + # in cases where it appears. That is not filtered out till later. + input_shapes = get_input_shape(onnx_model.graph, node) if node.op_type in skip_layers: - if node.op_type == 'Flatten': - output_shape = [current_shape[0][0], np.prod(current_shape[0][1:])] - - else: - # Currently supported skipped layers have only one input and output - # Skipped layers can follow each other (e.g., Dropout -> Flatten) - - # Mapping inputs - input_name = inputs_map.get(node.input[0], node.input[0]) - output_name = node.output[0] - inputs_map[output_name] = input_name + # Currently supported skipped layers have only one input and output + # Skipped layers can follow each other - output_shape = current_shape[0] + # Mapping inputs + input_name = inputs_map.get(node.input[0], node.input[0]) + output_name = node.output[0] + inputs_map[output_name] = input_name continue - if node.op_type in supported_layers: - layer_counter = layer_counter + 1 + input_names = [inputs_map.get(x, x) for x in node.input] # Process the layer - layer, output_shape = layer_handlers[node.op_type](reader, node, inputs_map, current_shape, graph, config) + layer = layer_handlers[node.op_type](node, input_names, input_shapes, onnx_model.graph) sanitize_layer_name(layer) - print('Layer name: {}, layer type: {}, current shape: {}'.format(layer['name'], layer['class_name'], current_shape)) + print(f"Layer name: {layer['name']}, layer type: {layer['class_name']}, current shape: {input_shapes}") layer_list.append(layer) + return layer_list, input_layers, output_layers + + +def onnx_to_hls(config): + """Convert onnx model to hls model from configuration. + + Args: + config (dict): ONNX configuration from yaml file or passed through API. + + Raises: + Exception: Raised if an unsupported operation is found in the ONNX model. 
+ + Returns: + ModelGraph: hls4ml model object + """ + + # Extract model architecture + print('Interpreting Model ...') + + onnx_model = onnx.load(config['OnnxModel']) if isinstance(config['OnnxModel'], str) else config['OnnxModel'] + + layer_list, input_layers, output_layers = parse_onnx_model(onnx_model) + ################# # Generate HLS ################# print('Creating HLS model') - hls_model = ModelGraph(config, reader, layer_list, input_layers, output_layers) + hls_model = ModelGraph(config, layer_list, input_layers, output_layers) return hls_model diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 709c3db3ff..cf715fd767 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -121,7 +121,8 @@ def get_precision(self, layer, var='default'): type_name = layer.name.lower() + '_' + var + '_t' if precision is None: precision = self.layer_name_precision.get(layer.name.lower() + '_default') - type_name = layer.name.lower() + '_default_t' + # I think it is better to keep these unique still to avoid inadvertent updates + # type_name = layer.name.lower() + '_default_t' if precision is None: precision = self.layer_type_precision.get(layer.class_name.lower() + '_' + var) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 8054f41ee6..fb548aa164 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -22,6 +22,7 @@ IntegerPrecisionType, NamedType, TensorVariable, + UnspecifiedPrecisionType, WeightVariable, find_minimum_width, ) @@ -344,7 +345,7 @@ class Input(Layer): def initialize(self): shape = self.attributes['input_shape'] if shape[0] is None: - shape = shape[1:] + raise RuntimeError(f"Unexpectedly have a None in {shape=} of Input layer") dims = [f'N_INPUT_{i}_{self.index}' for i in range(1, len(shape) + 1)] if self.index == 1: default_type_name = 'input_t' @@ -355,6 +356,46 @@ def initialize(self): self.add_output_variable(shape, dims, var_name=self.name, type_name=type_name, precision=precision) +class Constant(Layer): + # one could consider making this a weight attribute, but given it's transient nature, I am not sure it helps + _expected_attributes = [ + Attribute('value', value_type=np.ndarray), + ] + + def initialize(self): + value = self.attributes['value'] + shape = list(value.shape) + if not shape: + shape = (1,) + self.set_attr('value', np.array([value])) + dims = [f'{self.name}_{i}' for i in range(len(shape))] + quantizer = self.get_attr('quantizer') + + # Should the else clause below be None or UnspecifiedPrecisionType + precision = quantizer.hls_type if quantizer is not None else UnspecifiedPrecisionType() + + self.add_output_variable(shape, dims, var_name=self.name, precision=precision) + + +class Quant(Layer): # The QONNX quantization layer + """ + This is a QONNX quantization layer. Optimizations should convert it + before HLS is produced. 
+ """ + + _expected_attributes = [ + Attribute('narrow', value_type=bool), + Attribute('rounding_mode', value_type=str), + Attribute('signed', value_type=bool), + ] + + def initialize(self): + inp = self.get_input_variable(self.inputs[0]) + shape = inp.shape + dims = inp.dim_names + self.add_output_variable(shape, dims) + + class Reshape(Layer): _expected_attributes = [ Attribute('target_shape', value_type=typing.Sequence), @@ -362,17 +403,18 @@ class Reshape(Layer): def initialize(self): input_shape = self.get_input_variable(self.inputs[0]).shape - target_shape = self.get_attr('target_shape') + target_shape = self.get_attr('target_shape') # this should not have a batch dimension if target_shape is None: # need to get it from the input shape_node = self.get_input_node(self.inputs[1]) # for QONNX, remove batch dimension - if shape_node: - target_shape = shape_node.value[1:] + # (onnx cleaning should have removed reshapes not on data path) + if isinstance(shape_node, Constant): + target_shape = shape_node.attributes['value'][1:] else: raise RuntimeError("Reshape for ONNX requires the target shape to be a second input.") - # remove Nones -- is this ever triggered? + # remove Nones -- Seems to be used by pytorch parser if target_shape[0] is None: target_shape = target_shape[1:] @@ -406,7 +448,7 @@ class Dense(Layer): ] def initialize(self): - shape = self.get_input_variable().shape[:] + shape = list(self.get_input_variable().shape) shape[-1] = self.attributes['n_out'] if len(shape) > 1: dims = [f'N_LAYER_{i}_{self.index}' for i in range(1, len(shape) + 1)] @@ -417,6 +459,26 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class Conv(Layer): + """ + This is for the ONNX Conv node. Currently, it is only supported as an intermediate + form that gets converted to an explicit ConvXD. + + Note: these are always channels-last. + """ + + def initialize(self): + if self.attributes['n_dim'] == 1: + # this is 1D convolution + shape = [self.attributes['out_width'], self.attributes['n_filt']] + dims = [f'N_OUTPUTS_{self.index}', f'N_FILT_{self.index}'] + else: + shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_filt']] + dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_FILT_{self.index}'] + + self.add_output_variable(shape, dims) + + class Conv1D(Layer): _expected_attributes = [ Attribute('in_width'), @@ -915,10 +977,24 @@ def initialize(self): super().initialize() +class BatchNormOnnx(Layer): + ''' + A transient layer formed from ONNX BatchNormalization that gets converted to + BatchNormalization after the scale and bias are determined + ''' + + def initialize(self): + inp = self.get_input_variable() + shape = inp.shape + dims = inp.dim_names + self.add_output_variable(shape, dims) + + +# TODO: We currently seem to ignore the quantizers to mean, variance, etc. 
class BatchNormalization(Layer): _expected_attributes = [ Attribute('n_in'), - Attribute('n_filt', default=0), + Attribute('n_filt', default=-1), WeightAttribute('scale'), WeightAttribute('bias'), TypeAttribute('scale'), @@ -945,6 +1021,36 @@ def initialize(self): self.add_weights_variable(name='bias', var_name='b{index}', data=bias) +# TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense +class ApplyAlpha(BatchNormalization): + '''A custom layer to scale the output of a QDense layer which used 'alpha != 1' + Inference computation uses BatchNormalization methods''' + + def initialize(self): + inp = self.get_input_variable() + shape = inp.shape + dims = inp.dim_names + self.add_output_variable(shape, dims) + self.set_attr('n_in', inp.size()) + + # precision values are ignored if quantizer is not None + scale = self.get_attr('scale_data') + scale_quantizer = self.get_attr('scale_quantizer') + scale_precision = self.get_attr('scale_precision') + bias = self.get_attr('bias_data') + bias_quantizer = self.get_attr('bias_quantizer') + bias_precision = self.get_attr('bias_precision') + + self.add_weights(scale, quantizer=scale_quantizer, precision=scale_precision) + self.add_bias(bias, quantizer=bias_quantizer, precision=bias_precision) + + def add_weights(self, scale, quantizer=None, precision=None): + self.add_weights_variable(name='scale', var_name='s{index}', data=scale, quantizer=quantizer, precision=precision) + + def add_bias(self, bias, quantizer=None, precision=None): + self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer, precision=precision) + + class Merge(Layer): def initialize(self): assert len(self.inputs) == 2 @@ -959,6 +1065,31 @@ def initialize(self): self.add_output_variable(shape, dims) +class MatMul(Layer): + """ + This is a matrix multiply. Currently, it is only supported as an intermediate + form that gets converted to a Dense layer. 
+ """ + + def initialize(self): + assert len(self.inputs) == 2 + inp1 = self.get_input_variable(self.inputs[0]) + inp2 = self.get_input_variable(self.inputs[1]) + if len(inp2.shape) == 1: + # mat vec multiply + assert inp1.shape[-1] == inp2.shape[0] + shape = list(inp1.shape[:-1]) + [inp2.shape[0]] + else: + assert inp1.shape[-1] == inp2.shape[-2] + shape = list(inp1.shape[:-1]) + [inp2.shape[-1]] + if len(shape) > 1: + dims = [f'N_LAYER_{i}_{self.index}' for i in range(1, len(shape) + 1)] + else: + dims = [f'N_LAYER_{self.index}'] + + self.add_output_variable(shape, dims) + + class Dot(Merge): def initialize(self): assert len(self.inputs) == 2 @@ -1434,6 +1565,7 @@ def initialize(self): layer_map = { 'Input': Input, 'InputLayer': Input, + 'Constant': Constant, 'Activation': Activation, 'QActivation': Activation, 'LeakyReLU': ParametrizedActivation, @@ -1448,6 +1580,7 @@ def initialize(self): 'BinaryDense': Dense, 'TernaryDense': Dense, 'QDense': Dense, + 'Conv': Conv, 'Conv1D': Conv1D, 'QConv1D': Conv1D, 'Conv2D': Conv2D, @@ -1474,6 +1607,7 @@ def initialize(self): 'ZeroPadding1D': ZeroPadding1D, 'ZeroPadding2D': ZeroPadding2D, 'Merge': Merge, + 'MatMul': MatMul, 'Dot': Dot, 'Concatenate': Concatenate, 'Resize': Resize, @@ -1489,6 +1623,9 @@ def initialize(self): 'QGRU': GRU, 'GarNet': GarNet, 'GarNetStack': GarNetStack, + 'Quant': Quant, + 'ApplyAlpha': ApplyAlpha, + 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, 'SymbolicExpression': SymbolicExpression, # TensorFlow-specific layers: diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 77e38b0c5b..0edd549b29 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -30,10 +30,35 @@ del module_path del optimizers +register_flow( + 'parse_qonnx', + [ + 'reshape_constant', + 'quant_constant_parameters', + 'quant_to_activation', + 'fuse_quant_with_constant', + 'const_quant_to_const_alpha', + 'quant_to_alpha_activation_alpha', + 'batch_norm_onnx_constant_parameters', + 'constant_batch_norm_fusion', + 'merge_two_constants', + 'scale_down_add', + 'bias_down_add', + 'scale_down_mat_mul', + 'scale_down_conv', + 'merge_to_apply_alpha', + 'merge_to_apply_alpha_div', + 'matmul_const_to_dense', + 'conv_to_conv_x_d', + 'conv_to_depthwise_conv_x_d', + ], +) + register_flow( 'convert', [ 'channels_last_converter', + 'merge_linear_activation', 'seperable_to_depthwise_and_conv', 'remove_transpose_before_flatten', 'remove_nop_transpose', @@ -51,10 +76,13 @@ # many of the above optimzers need to be done before this 'infer_precision_types', ], + requires=['parse_qonnx'], ) # TODO Maybe not all QKeras optmizers belong here? 
register_flow( 'optimize', - [], + [ + 'remove_nop_batch_normalization', + ], requires=['convert'], ) diff --git a/hls4ml/model/optimizer/passes/batchnorm_opt.py b/hls4ml/model/optimizer/passes/batchnorm_opt.py new file mode 100644 index 0000000000..e18d79ff4a --- /dev/null +++ b/hls4ml/model/optimizer/passes/batchnorm_opt.py @@ -0,0 +1,274 @@ +import warnings + +import numpy as np + +from hls4ml.model.layers import BatchNormalization, BatchNormOnnx, Constant +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.quantizers import QuantNodeQuantizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, UnspecifiedPrecisionType + +_base_attributes = ('epsilon', 'n_in', 'n_filt') + + +class BatchNormOnnxConstantParameters(OptimizerPass): + """Remove Constant from the BatchNormalization node parameters (but not input[0])""" + + def match(self, node): + is_match = isinstance(node, BatchNormOnnx) and any(node.inputs[1:]) + + return is_match + + def transform(self, model, node): + """ + Remove Constant from the BatchNormalization node parameters (but not input[0]) + + TODO: Currently the quantizers are not actually used by the underlying layer. + """ + + if not (len(node.inputs) == 5 and all(node.inputs)): + raise ValueError('All 5 BatchNormOnnnx inputs need to be defined') + + attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes} + + gamma_node = node.get_input_node(node.inputs[1]) + if not isinstance(gamma_node, Constant): + raise TypeError('Only constant gammas supported') + gamma = gamma_node.attributes['value'] + attributes['gamma_data'] = gamma + attributes['gamma_quantizer'] = gamma_node.get_attr('quantizer') + + node.inputs[1] = '' + model.remove_node(gamma_node, rewire=False) + + beta_node = node.get_input_node(node.inputs[2]) + if not isinstance(beta_node, Constant): + raise TypeError('Only constant betas supported') + beta = beta_node.attributes['value'] + attributes['beta_data'] = beta + attributes['beta_quantizer'] = beta_node.get_attr('quantizer') + node.inputs[2] = '' + model.remove_node(beta_node, rewire=False) + + moving_mean_node = node.get_input_node(node.inputs[3]) + if not isinstance(moving_mean_node, Constant): + raise TypeError('Only constant moving_means supported') + moving_mean = moving_mean_node.attributes['value'] + attributes['mean_data'] = moving_mean + attributes['mean_quantizer'] = moving_mean_node.get_attr('quantizer') + node.inputs[3] = '' + model.remove_node(moving_mean_node, rewire=False) + + moving_variance_node = node.get_input_node(node.inputs[4]) + if not isinstance(moving_variance_node, Constant): + raise TypeError('Only constant moving_variances supported') + moving_variance = moving_variance_node.attributes['value'] + attributes['variance_data'] = moving_variance + attributes['variance_quantizer'] = moving_variance_node.get_attr('quantizer') + node.inputs[4] = '' + model.remove_node(moving_variance_node, rewire=False) + + node.inputs = [inp for inp in node.inputs if inp] + if len(node.inputs) != 1: + raise RuntimeError('The QONNX batchnorm had unexpected inputs.') + + new_node = model.make_node(BatchNormalization, node.name, attributes, [node.inputs[0]], [x for x in node.outputs]) + + model.replace_node(node, new_node) + + return True + + +# Most likely this case is removed by qonnx cleaning +class ConstantBatchNormFusion(OptimizerPass): + """ + Merge BatchNorm into Const (after parameters have already been merged in BatchNormalization) + """ + + def match(self, node): + is_match = ( + 
isinstance(node, BatchNormalization) + and not any(node.inputs[1:]) + and isinstance(node.get_input_node(node.inputs[0]), Constant) + and isinstance( + node.get_input_node(node.inputs[0]).get_output_variable().type.precision, UnspecifiedPrecisionType + ) + ) + return is_match + + def transform(self, model, node): + """ + Remove the batch norm + """ + warnings.warn('ConstantBatchNormFusion should probably not be triggered. Check the optimizer order.', stacklevel=2) + const_node = node.get_input_node(node.inputs[0]) + + const_prec = const_node.get_output_variable().type.precision + + new_val = ( + const_node.attributes['value'] * node.weights['scale'].data_unquantized + node.weights['bias'].data_unquantized + ) + + const_node.set_attr('value', new_val) + const_node.set_attr('quantizer', node.get_attr('quantizer')) # None if not defined + + if isinstance(node.get_output_variable().type.precision, UnspecifiedPrecisionType): + if isinstance(const_prec, UnspecifiedPrecisionType): + pass # leave it as is + else: + const_node.get_output_variable().type.precision = UnspecifiedPrecisionType() # default + # propagate precision + scale_q = node.get_attr('scale_quantizer') + bias_q = node.get_attr('bias_quantizer') + if scale_q and bias_q: + # propagate precsion + scale_prec = scale_q.hls_type + bias_prec = bias_q.hls_type + if scale_prec not in (IntegerPrecisionType, FixedPrecisionType) or bias_prec not in ( + IntegerPrecisionType, + FixedPrecisionType, + ): + print("Warning: output type not propagated for constant merge") + else: + signed_prod = const_prec.signed or scale_prec.signed + w_prod = const_prec.width + scale_prec.width + i_prod = const_prec.integer + scale_prec.integer + signed = signed_prod or bias_prec.signed + i_tot = ( + max( + i_prod + (bias_prec.signed and not signed_prod), + bias_prec.ingeter + (signed_prod and not bias_prec.signed), + ) + + 1 + ) + w_tot = i_tot + max(w_prod - i_prod, bias_prec.width - bias_prec.integer) + new_prec = FixedPrecisionType(w_tot, i_tot, signed) + const_node.set_attr('quantizer', QuantNodeQuantizer(new_prec)) + const_node.get_output_variable().type.precision = new_prec + else: + const_node.get_output_variable().type.precision = node.get_output_variable().type.precision + + # remove the batch norm node + model.remove_node(node, rewire=True) + + return True + + +class FuseConsecutiveBatchNormalization(OptimizerPass): + """ + OptimizerPass to merge consecutive BatchNormalization layers, only if the earlier one does not have the output type + specified. There is a further check on the compatibility to merge: except in cases when merging a scale of 1 or a + bias of 0, this does not merge when both scales or both biases are quantized. + + Note: Consider restricting this to ApplyAlpha. Batch Normalization-style quantization seems to be ignored. + + Note: This optimizer may not be safe if weights are updateable, in particular if a scale can go from ones to other + values or if a bias can go from zeros to other values. 
+ """ + + def match(self, node): + prev_node = node.get_input_node(node.inputs[0]) + basic_match = ( + isinstance(node, BatchNormalization) + and isinstance(prev_node, BatchNormalization) + and isinstance(prev_node.get_output_variable().type.precision, UnspecifiedPrecisionType) + ) + + # check for compatibility to merge + if basic_match: + s0 = prev_node.weights['scale'].data_unquantized + b0 = prev_node.weights['bias'].data_unquantized + s1 = node.weights['scale'].data_unquantized + b1 = node.weights['bias'].data_unquantized + scale_compatible = ( + (prev_node.get_attr('scale_quantizer') is None or node.get_attr('scale_quantizer') is None) + or (s0 == np.ones_like(s0)).all() + or (s1 == np.ones_like(s1)).all() + ) + bias_compatible = ( + (prev_node.get_attr('bias_quantizer') is None or node.get_attr('bias_quantizer') is None) + or (b0 == np.zeros_like(b0)).all() + or (b1 == np.zeros_like(b1)).all() + ) + return scale_compatible and bias_compatible + else: + return False + + def transform(self, model, node): + prev_node = node.get_input_node(node.inputs[0]) + + prev_map = prev_node.get_output_use_map() + if len(prev_map[prev_node.outputs[0]]) > 1: + return False + + s0 = prev_node.weights['scale'].data_unquantized + b0 = prev_node.weights['bias'].data_unquantized + s1 = node.weights['scale'].data_unquantized + b1 = node.weights['bias'].data_unquantized + + if (s0 == np.ones_like(s0)).all(): + s_quantizer = node.get_attr('scale_quantizer') + elif (s1 == np.ones_like(s1)).all(): + s_quantizer = prev_node.get_attr('scale_quantizer') + else: + s_quantizer = None + + if (b0 == np.ones_like(b0)).all(): + b_quantizer = node.get_attr('bias_quantizer') + elif (b1 == np.ones_like(b1)).all(): + b_quantizer = prev_node.get_attr('bias_quantizer') + else: + b_quantizer = None + + node.set_attr('scale_quantizer', s_quantizer) + node.set_attr('bias_quantizer', b_quantizer) + + scale_new = s0 * s1 + bias_new = s1 * b0 + b1 + + # Not sure if this setting of this is useful + s_prec = None + if s_quantizer is None and (scale_new == np.ones_like(scale_new)).all(): + if ( + isinstance(prev_node.weights['scale'].type, IntegerPrecisionType) + and isinstance(node.weights['scale'].type, IntegerPrecisionType) + and prev_node.weights['scale'].type.width == 1 + and node.weights['scale'].type.width == 1 + ): + s_prec = node.weights['scale'].type + + b_prec = None + if b_quantizer is None and (bias_new == np.zeros_like(bias_new)).all(): + if ( + isinstance(prev_node.weights['bias'].type, IntegerPrecisionType) + and isinstance(node.weights['bias'].type, IntegerPrecisionType) + and prev_node.weights['bias'].type.width == 1 + and node.weights['bias'].type.width == 1 + ): + b_prec = node.weights['bias'].type + + # call function so that quantizer would be called if needed + node.add_weights_variable(name='scale', var_name='s{index}', data=scale_new, quantizer=s_quantizer, precision=s_prec) + node.add_weights_variable(name='bias', var_name='b{index}', data=bias_new, quantizer=b_quantizer, precision=b_prec) + + model.remove_node(prev_node, rewire=True) + return True + + +class RemoveNopBatchNormalization(OptimizerPass): + """ + OptimizerPass to remove batch normalizations that do nothing (scale 1, bias 0) + + Note: This optimizer may not be safe if weights are updateable. 
+ """ + + def match(self, node): + if isinstance(node, BatchNormalization): + s0 = node.weights['scale'].data_unquantized + b0 = node.weights['bias'].data_unquantized + return (s0 == np.ones_like(s0)).all() and (b0 == np.zeros_like(b0)).all() + else: + return False + + def transform(self, model, node): + model.remove_node(node, rewire=True) + return True diff --git a/hls4ml/model/optimizer/passes/bn_fuse.py b/hls4ml/model/optimizer/passes/bn_fuse.py index 02d9b849ed..000d8380ce 100644 --- a/hls4ml/model/optimizer/passes/bn_fuse.py +++ b/hls4ml/model/optimizer/passes/bn_fuse.py @@ -1,23 +1,54 @@ +import numpy as np + from hls4ml.model.layers import BatchNormalization, Conv1D, Conv2D, Dense from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, UnspecifiedPrecisionType class FuseBatchNormalization(OptimizerPass): + """ + OptimizerPass to merge a BatchNormalization layer with Dense or Conv layer, only if the Dense or Conv layer does not + have the output type specified. There is a further check on the compatibility to merge: except in cases when merging a + weight/scale of 1 or a bias of 0, this optimizer does not merge nodes when both the weight and scale or both biases + are quantized. + + Note: Consider restricting this to ApplyAlpha. Batch Normalization quantization seems to be ignored. + + Note: This optimizer may not be safe if weights are updateable. May need to turn off. + """ + def match(self, node): - is_match = ( + prev_node = node.get_input_node(node.inputs[0]) + basic_match = ( isinstance(node, BatchNormalization) - and isinstance(node.get_input_node(), (Dense, Conv1D, Conv2D)) - and node.get_input_node().get_attr('weight_quantizer') is None - and node.get_input_node().get_attr('bias_quantizer') is None + and isinstance(prev_node, (Dense, Conv1D, Conv2D)) + and isinstance(prev_node.get_output_variable().type.precision, UnspecifiedPrecisionType) ) - return is_match + if basic_match: + s0 = prev_node.weights['weight'].data_unquantized + b0 = prev_node.weights['bias'].data_unquantized + s1 = node.weights['scale'].data_unquantized + b1 = node.weights['bias'].data_unquantized + scale_compatible = ( + (prev_node.get_attr('weight_quantizer') is None and node.get_attr('scale_quantizer') is None) + or ((s0 == np.ones_like(s0)).all() and prev_node.get_attr('weight_quantizer') is None) + or ((s1 == np.ones_like(s1)).all() and node.get_attr('scale_quantizer') is None) + ) + bias_compatible = ( + (prev_node.get_attr('bias_quantizer') is None and node.get_attr('bias_quantizer') is None) + or ((b0 == np.zeros_like(b0)).all() and prev_node.get_attr('bias_quantizer') is None) + or ((b1 == np.zeros_like(b1)).all() and node.get_attr('bias_quantizer') is None) + ) + return scale_compatible and bias_compatible + + else: + return False def transform(self, model, node): - # Fuse weight and bias of Dense/Conv1D/Conv2D layer with BN values + """Fuse weight and bias of Dense/Conv1D/Conv2D layer with BN values.""" parent_node = node.get_input_node() parent_map = parent_node.get_output_use_map() - node_map = node.get_output_use_map() - if len(parent_map[parent_node.name]) > 1 or len(node_map[node.name]) > 1: + if len(parent_map[parent_node.outputs[0]]) > 1: return False parent_weight = parent_node.weights['weight'] @@ -26,13 +57,38 @@ def transform(self, model, node): bn_scale = node.weights['scale'] bn_bias = node.weights['bias'] + allowed_precisions = (IntegerPrecisionType, FixedPrecisionType, UnspecifiedPrecisionType) + + # only merge if 
the types are integer or fixed + if ( + not isinstance(parent_weight.type.precision, allowed_precisions) + or not isinstance(parent_bias.type.precision, allowed_precisions) + or not isinstance(bn_scale.type.precision, allowed_precisions) + or not isinstance(bn_bias.type.precision, allowed_precisions) + ): + return False + fused_weight = bn_scale.data * parent_weight.data fused_bias = bn_scale.data * parent_bias.data + bn_bias.data + w_quantizer = ( + node.get_attr('scale_quantizer') + if node.get_attr('scale_quantizer') is not None + else parent_node.get_attr('weight_quantizer') + ) + b_quantizer = ( + node.get_attr('bias_quantizer') + if node.get_attr('bias_quantizer') is not None + else parent_node.get_attr('bias_quantizer') + ) + + node.set_attr('weight_quantizer', w_quantizer) + node.set_attr('bias_quantizer', b_quantizer) + + # call function so that quantizer would be called if needed + parent_node.add_weights_variable(name='weight', var_name='w{index}', data=fused_weight, quantizer=w_quantizer) + parent_node.add_weights_variable(name='bias', var_name='b{index}', data=fused_bias, quantizer=b_quantizer) + model.remove_node(node, rewire=True) - parent_weight.data = fused_weight - parent_bias.data = fused_bias - if not parent_node.get_attr('use_bias', True): - parent_bias.update_precision(bn_bias.type.precision) return True diff --git a/hls4ml/model/optimizer/passes/conv_to_convxd.py b/hls4ml/model/optimizer/passes/conv_to_convxd.py new file mode 100644 index 0000000000..3e870e43a6 --- /dev/null +++ b/hls4ml/model/optimizer/passes/conv_to_convxd.py @@ -0,0 +1,93 @@ +import numpy as np + +from hls4ml.model.layers import Constant, Conv, Conv1D, Conv2D +from hls4ml.model.optimizer import OptimizerPass + +# these are attributes to copy +_base_attributes = ( + 'in_width', + 'out_width', + 'n_chan', + 'n_filt', + 'pad_left', + 'pad_right', + 'filt_width', + 'stride_width', + 'dilation_width', + 'in_height', + 'out_height', + 'pad_top', + 'pad_bottom', + 'filt_height', + 'stride_height', + 'dilation_height', + 'data_format', +) + + +class ConvToConvXD(OptimizerPass): + """Convert Conv with constant to a Conv1D or Conv2D layer""" + + def match(self, node): + is_match = ( + isinstance(node, Conv) + and node.get_attr('group') == 1 + and ( + (len(node.inputs) == 2 and isinstance(node.get_input_node(node.inputs[1]), Constant)) + or ( + len(node.inputs) == 3 + and isinstance(node.get_input_node(node.inputs[1]), Constant) + and isinstance(node.get_input_node(node.inputs[2]), Constant) + ) + ) + ) + + return is_match + + def transform(self, model, node): + """Convert Conv with constant to a Conv1D or Conv2D layer""" + + weight_node = node.get_input_node(node.inputs[1]) + weight_data = weight_node.attributes['value'] + bias_node = None + if len(node.inputs) == 3: + bias_node = node.get_input_node(node.inputs[2]) + + # creating the attributes + attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes} + + # The ConvxD nodes expect the weight data to be in a different format, not (M, k1.., C) + if node.attributes['n_dim'] == 1: + newtype = Conv1D + attributes['weight_data'] = np.transpose(weight_data, (1, 2, 0)) + else: + newtype = Conv2D + attributes['weight_data'] = np.transpose(weight_data, (1, 2, 3, 0)) + attributes['weight_quantizer'] = weight_node.get_attr('quantizer') + + if bias_node: + attributes['bias_data'] = bias_node.attributes['value'] + attributes['bias_quantizer'] = bias_node.get_attr('quantizer') + attributes['use_bias'] = True + else: + 
attributes['bias_data'] = np.zeros(attributes['n_filt']) + attributes['use_bias'] = False + + # get the configuration name + config = model.config.get_layer_config(node) + new_name = f'{newtype.__name__}_{node.name}' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + # making new node + new_node = model.make_node(newtype, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + + # removing and replacing old nodes + if bias_node: + model.remove_node(bias_node, rewire=False) + del node.inputs[2] + model.remove_node(weight_node, rewire=False) + del node.inputs[1] + model.replace_node(node, new_node) + + return True diff --git a/hls4ml/model/optimizer/passes/conv_to_depthwiseconvxd.py b/hls4ml/model/optimizer/passes/conv_to_depthwiseconvxd.py new file mode 100644 index 0000000000..b1271b5784 --- /dev/null +++ b/hls4ml/model/optimizer/passes/conv_to_depthwiseconvxd.py @@ -0,0 +1,94 @@ +import numpy as np + +from hls4ml.model.layers import Constant, Conv, DepthwiseConv1D, DepthwiseConv2D +from hls4ml.model.optimizer import OptimizerPass + +# these are attributes to copy +_base_attributes = ( + 'in_width', + 'out_width', + 'n_chan', + 'n_filt', + 'pad_left', + 'pad_right', + 'filt_width', + 'stride_width', + 'dilation_width', + 'in_height', + 'out_height', + 'pad_top', + 'pad_bottom', + 'filt_height', + 'stride_height', + 'dilation_height', + 'data_format', +) + + +class ConvToDepthwiseConvXD(OptimizerPass): + """Convert Conv with constant to a DepthwiseConv1D or DepthwiseConv2D layer""" + + def match(self, node): + is_match = ( + isinstance(node, Conv) + and node.get_attr('group') == node.get_attr('n_chan') + and (node.get_attr('group') != 1) + and ( + (len(node.inputs) == 2 and isinstance(node.get_input_node(node.inputs[1]), Constant)) + or ( + len(node.inputs) == 3 + and isinstance(node.get_input_node(node.inputs[1]), Constant) + and isinstance(node.get_input_node(node.inputs[2]), Constant) + ) + ) + ) + + return is_match + + def transform(self, model, node): + """Convert Conv with constant to a DepthwiseConv1D or DepthwiseConv2D layer""" + + weight_node = node.get_input_node(node.inputs[1]) + weight_data = weight_node.attributes['value'] + bias_node = None + if len(node.inputs) == 3: + bias_node = node.get_input_node(node.inputs[2]) + + # creating the attributes + attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes} + + # The ConvxD nodes expect the weight data to be in a different format, not (M, k1.., C) + if node.attributes['n_dim'] == 1: + newtype = DepthwiseConv1D + attributes['depthwise_data'] = np.transpose(weight_data, (1, 2, 0)) + else: + newtype = DepthwiseConv2D + attributes['depthwise_data'] = np.transpose(weight_data, (1, 2, 3, 0)) + attributes['depthwise_quantizer'] = weight_node.get_attr('quantizer') + + if bias_node: + attributes['bias_data'] = bias_node.attributes['value'] + attributes['bias_quantizer'] = bias_node.get_attr('quantizer') + attributes['use_bias'] = True + else: + attributes['bias_data'] = np.zeros(attributes['n_filt']) + attributes['use_bias'] = False + + # get the configuration name + config = model.config.get_layer_config(node) + new_name = f'{newtype.__name__}_{node.name}' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + # making new node + new_node = model.make_node(newtype, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + + # removing and replacing old nodes + if bias_node: + 
model.remove_node(bias_node, rewire=False) + del node.inputs[2] + model.remove_node(weight_node, rewire=False) + del node.inputs[1] + model.replace_node(node, new_node) + + return True diff --git a/hls4ml/model/optimizer/passes/linear.py b/hls4ml/model/optimizer/passes/linear.py new file mode 100644 index 0000000000..b1aee7adc7 --- /dev/null +++ b/hls4ml/model/optimizer/passes/linear.py @@ -0,0 +1,46 @@ +from hls4ml.model.layers import Activation, BatchNormalization, Conv1D, Conv2D, Dense +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.types import UnspecifiedPrecisionType + + +class EliminateLinearActivation(OptimizerPass): + def match(self, node): + cast = False + if isinstance(node, Activation): + cast = node.get_input_variable().type.precision != node.get_output_variable().type.precision + return isinstance(node, Activation) and node.get_attr('activation') == 'linear' and not cast + + def transform(self, model, node): + model.remove_node(node) + return True + + +_safe_parents = (Dense, Conv1D, Conv2D, BatchNormalization, Activation) + + +class MergeLinearActivation(OptimizerPass): + ''' + For many objects it's safe to change the output precision independently of the calculation. + ''' + + def match(self, node): + ''' + Only match if the parent is safe and the precision is not explicitly set. + ''' + if isinstance(node, Activation) and node.get_attr('activation') == 'linear': + parent = node.get_input_node(node.inputs[0]) + safe_parent = isinstance(parent, _safe_parents) + return safe_parent and isinstance(parent.get_output_variable().type.precision, UnspecifiedPrecisionType) + else: + return False + + def transform(self, model, node): + prev_node = node.get_input_node(node.inputs[0]) + quantizer = node.get_attr("quantizer") + # if the activation has a quantizer (usually from a QONNX Quant node), set the previous node's output precision + if quantizer is not None: + prev_node.set_attr("quantizer", quantizer) + prev_node.types['result_t'] = quantizer.hls_type + prev_node.get_output_variable().type.precision = quantizer.hls_type + model.remove_node(node) + return True diff --git a/hls4ml/model/optimizer/passes/matmul_const_to_dense.py b/hls4ml/model/optimizer/passes/matmul_const_to_dense.py new file mode 100644 index 0000000000..4c48944eb3 --- /dev/null +++ b/hls4ml/model/optimizer/passes/matmul_const_to_dense.py @@ -0,0 +1,58 @@ +import numpy as np + +from hls4ml.model.layers import Constant, Dense, MatMul +from hls4ml.model.optimizer import OptimizerPass + + +class MatmulConstToDense(OptimizerPass): + """ + Convert MatMul with constant to a dense layer. Note, this only supports the second input + being the constant. If needed, one could add transposes to make that be the case in + other yet to be written optimizers. 
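The shape bookkeeping done when a ``MatMul`` with a constant second input becomes a ``Dense`` is simple: the output keeps the leading input dimensions and replaces the last one with the constant's trailing dimension. A sketch with hypothetical shapes:

.. code-block:: python

    import numpy as np

    x = np.random.rand(6, 4)   # the non-constant MatMul input
    W = np.random.rand(4, 3)   # the Constant that becomes the Dense weight

    in_shape = x.shape
    out_shape = (*in_shape[:-1], W.shape[-1])            # (6, 3)
    n_in, n_out = np.prod(in_shape), np.prod(out_shape)  # 24, 18

    assert (x @ W).shape == out_shape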
+ """ + + def match(self, node): + is_match = ( + isinstance(node, MatMul) and len(node.inputs) == 2 and isinstance(node.get_input_node(node.inputs[1]), Constant) + ) + return is_match + + def transform(self, model, node): + """Substitute Matmul + Constant for a single dense""" + # determining Constant layer input + const_node = node.get_input_node(node.inputs[1]) + other_var = node.get_input_variable(node.inputs[0]) + + weight_data = const_node.attributes['value'] + weight_quantizer = const_node.get_attr('quantizer') + + # get the configuration name + config = model.config.get_layer_config(node) + new_name = f'Dense_{node.name}' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + in_shape = other_var.shape + n_in = np.prod(in_shape) + out_shape = list(in_shape[:-1]) + [weight_data.shape[-1]] + n_out = np.prod(out_shape) + + # creating the attributes + attributes = { + 'weight_data': weight_data, + 'weight_quantizer': weight_quantizer, + 'bias_data': np.zeros(out_shape), + 'use_bias': False, + 'n_in': n_in, + 'n_out': n_out, + } + + # making new node + new_dense = model.make_node(Dense, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + + # removing and replacing old nodes + model.remove_node(const_node, rewire=False) + del node.inputs[1] + model.replace_node(node, new_dense) + + return True diff --git a/hls4ml/model/optimizer/passes/merge_const.py b/hls4ml/model/optimizer/passes/merge_const.py new file mode 100644 index 0000000000..a75ed27aca --- /dev/null +++ b/hls4ml/model/optimizer/passes/merge_const.py @@ -0,0 +1,246 @@ +import numpy as np + +from hls4ml.model.layers import ApplyAlpha, Constant, Merge +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.quantizers import QuantNodeQuantizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType + + +# This should generally not happen because of qonnx cleaning +class MergeTwoConstants(OptimizerPass): + """Merge of two constants makes another constant""" + + def match(self, node): + is_match = ( + isinstance(node, Merge) + and isinstance(node.get_input_node(node.inputs[0]), Constant) + and isinstance(node.get_input_node(node.inputs[1]), Constant) + ) + + return is_match + + def transform(self, model, node): + """ + Merge of two constants makes another constant. + + Note: full precision is used in the calculation, and precision is not propagated. 
+ The precision + """ + const_node0 = node.get_input_node(node.inputs[0]) + const_node1 = node.get_input_node(node.inputs[1]) + + val0 = const_node0.attributes['value'] + val1 = const_node1.attributes['value'] + + op = node.attributes['op'] + if op == 'add': + new_val = val0 + val1 + elif op == 'subtract': + new_val = val0 - val1 + elif op == 'multiply': + new_val = val0 * val1 + elif op == 'divide': + new_val = val0 / val1 + elif op == 'average': + new_val = np.mean(np.array([val0, val1]), axis=0) + elif op == 'maximum': + new_val = np.maximum(val0, val1) + elif op == 'minimum': + new_val = np.minimum(val0, val1) + else: + raise RuntimeError(f'Unexpected op_type: {op}') + + quantizer = node.get_attr('quantizer') # None if not defined + const_node0.set_attr('quantizer', quantizer) # overwrite the quantizer + if quantizer: + const_node0.set_attr('quantizer', quantizer) + const_node0.types['result_t'] = quantizer.hls_type + const_node0.get_output_variable().type.precision = quantizer.hls_type + const_node0.set_attr('value', new_val) + + model.remove_node(const_node1, rewire=False) + + # remove the batch norm node + model.remove_node(node, rewire=True) + + return True + + +class MergeToApplyAlpha(OptimizerPass): + """Convert Add, Sub, Mul, or Div Merges with constant to ApplyAlpha""" + + def match(self, node): + is_match = ( + isinstance(node, Merge) + and node.attributes['op'] in ('add', 'subtract', 'multiply') # Div is separate + and ( + isinstance(node.get_input_node(node.inputs[0]), Constant) + != isinstance(node.get_input_node(node.inputs[1]), Constant) + ) + ) + # note: != for booleans is xor. + return is_match + + def transform(self, model, node): + node1 = node.get_input_node(node.inputs[1]) + + node1const = isinstance(node1, Constant) + if node1const: + const_node = node1 + input_node_idx = 0 + const_node_idx = 1 + else: + const_node = node.get_input_node(node.inputs[0]) + input_node_idx = 1 + const_node_idx = 0 + + input_shape = node.get_input_variable(node.inputs[input_node_idx]).shape + n_in = np.prod(input_shape) + + # Note: precision is ignored if quantizer is not None + scale_precision = None + scale_quantizer = None + bias_precision = None + bias_quantizer = None + + op = node.attributes['op'] + if op == 'add': + scale = np.array(1) + scale_precision = IntegerPrecisionType(1, False) + bias = const_node.attributes['value'] + bias_quantizer = const_node.get_attr('quantizer') + elif op == 'subtract': + bias_quantizer = const_node.get_attr('quantizer') + if node1const: + scale = np.array(1) + scale_precision = IntegerPrecisionType(1, False) + bias = -const_node.attributes['value'] + if ( + bias_quantizer is not None + and isinstance(bias_quantizer.hls_type, (IntegerPrecisionType, FixedPrecisionType)) + and not bias_quantizer.hls_type.signed + ): + # need to make signed and increas the bit, if unsigned + bias_precision = FixedPrecisionType( + bias_quantizer.hls_type.width + 1, + bias_quantizer.hls_type.integer + 1, + True, + bias_quantizer.hls_type.rounding_mode, + bias_quantizer.hls_type.saturation_mode, + bias_quantizer.hls_type.saturation_bits, + ) + bias_quantizer = QuantNodeQuantizer(bias_precision) + else: + scale = np.array(-1) + scale_precision = IntegerPrecisionType(2, True) + bias = const_node.attributes['value'] + + elif op == 'multiply': + scale = const_node.attributes['value'] + scale_quantizer = const_node.get_attr('quantizer') + bias = np.array(0) + bias_precision = IntegerPrecisionType(1, False) + + # because C++ doesn't do broadcasting, we may have to change the 
shapes of the scale and bias + if scale.shape != tuple(input_shape) and np.squeeze(scale).shape != tuple(input_shape): + scale = np.broadcast_to(scale, input_shape) + if bias.shape != tuple(input_shape) and np.squeeze(bias).shape != tuple(input_shape): + bias = np.broadcast_to(bias, input_shape) + + attributes = { + 'scale_data': scale, + 'bias_data': bias, + 'n_in': n_in, + 'n_out': n_in, + 'n_filt': -1, + 'scale_precision': scale_precision, + 'scale_quantizer': scale_quantizer, + 'bias_precision': bias_precision, + 'bias_quantizer': bias_quantizer, + } + + # get the configuration name + config = model.config.get_layer_config(node) + new_name = f'bn_{node.name}' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + aa_layer = model.make_node( + ApplyAlpha, new_name, attributes, [node.inputs[input_node_idx]], [x for x in node.outputs] + ) + + model.remove_node(const_node, rewire=False) + del node.inputs[const_node_idx] + model.replace_node(node, aa_layer) + + return True + + +class MergeToApplyAlphaDiv(OptimizerPass): + """ + Convert Div Merges with constant to ApplyAlpha + + TODO: propagate precision + """ + + def match(self, node): + is_match = ( + isinstance(node, Merge) + and node.attributes['op'] == 'divide' + and isinstance(node.get_input_node(node.inputs[1]), Constant) + ) # only second can be const + + return is_match + + def transform(self, model, node): + input_shape = node.get_input_variable().shape + n_in = np.prod(input_shape) + const_node = node.get_input_node(node.inputs[1]) + scale = 1 / const_node.attributes['value'] + scale_quantizer = const_node.get_attr('quantizer') + if scale_quantizer: + scale_precision = scale_quantizer.hls_type + i_new = 1 + int(scale_precision.signed) + scale_precision.fractional + w_new = 1 + int(scale_precision.signed) + max(scale_precision.fractional, 0) + new_scale_precision = FixedPrecisionType( + w_new, + i_new, + scale_precision.signed, + rounding_mode=scale_precision.rounding_mode, + saturation_mode=scale_precision.saturation_mode, + saturation_bits=scale_precision.saturation_bits, + ) + scale_quantizer = QuantNodeQuantizer(new_scale_precision) + + bias = np.array(0) + bias_precision = IntegerPrecisionType(1, False) + + # because C++ doesn't do broadcasting, we may have to change the shapes of the scale and bias + if scale.shape != tuple(input_shape) and np.squeeze(scale).shape != tuple(input_shape): + scale = np.broadcast_to(scale, input_shape) + if bias.shape != tuple(input_shape) and np.squeeze(bias).shape != tuple(input_shape): + bias = np.broadcast_to(bias, input_shape) + + attributes = { + 'scale_data': scale, + 'bias_data': bias, + 'scale_quantizer': scale_quantizer, + 'bias_precision': bias_precision, + 'n_in': n_in, + 'n_out': n_in, + 'n_filt': -1, + } + + # get the configuration name + config = model.config.get_layer_config(node) + new_name = f'bn_{node.name}' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + bn_layer = model.make_node(ApplyAlpha, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + + model.remove_node(const_node, rewire=False) + del node.inputs[1] + model.replace_node(node, bn_layer) + + return True diff --git a/hls4ml/model/optimizer/passes/move_scales.py b/hls4ml/model/optimizer/passes/move_scales.py new file mode 100644 index 0000000000..43fcaa0da7 --- /dev/null +++ b/hls4ml/model/optimizer/passes/move_scales.py @@ -0,0 +1,489 @@ +''' +This file includes optimizations related to moving the 
ApplyAphas across MatMul and Conv nodes. + +TODO: Check that biases are properly handled. (Attempt to do it via Merge) + +''' + +import numpy as np + +from hls4ml.model.layers import ApplyAlpha, Constant, Conv, MatMul, Merge +from hls4ml.model.optimizer import OptimizerPass + + +class ScaleDownMatMul(OptimizerPass): + '''Shift an ApplyAlpha below a MatMul''' + + def match(self, node): + ''' + Check to see if we have a MatMul with at least one input ApplyAlpha. + Note, if both are this optimizer runs twice. + ''' + is_match = ( + isinstance(node, MatMul) + and len(node.inputs) == 2 + and ( + isinstance(node.get_input_node(node.inputs[0]), ApplyAlpha) + or isinstance(node.get_input_node(node.inputs[1]), ApplyAlpha) + ) + ) + return is_match + + def transform(self, model, node): + # determine input with ApplyAlpha. If both, first propagate apply alpha associated with a constant + is_aa = [False, False] + from_const = [False, False] + inp = [node.get_input_node(node.inputs[0]), node.get_input_node(node.inputs[1])] + for i in range(2): + if isinstance(inp[i], ApplyAlpha): + is_aa[i] = True + from_const[i] = isinstance(inp[i].get_input_node(inp[i].inputs[0]), Constant) + + # prefer alpha from constant + if from_const[0]: + alpha_idx = 0 + elif from_const[1]: + alpha_idx = 1 + elif is_aa[0]: + alpha_idx = 0 + else: + alpha_idx = 1 # is_aa[1] must be true + + apply_alpha = inp[alpha_idx] + other_idx = 0 if alpha_idx else 1 + + # Check if we can move + scale = apply_alpha.weights['scale'].data_unquantized + bias = apply_alpha.weights['bias'].data_unquantized + + scale, bias = _make_scalar(scale, bias) + + output = node.get_output_variable() + # to remove warning, since these get set again + new_attrs = {k: v for k, v in apply_alpha.attributes.items() if k not in ('trace', 'precision')} + + can_propagate = False + if not bias.shape and bias == 0: + # zero bias, propagate through, if possible + # (always possible if scale is scalar) + try: + newscale = np.broadcast_to(scale, output.shape) # check size compatibility + newbias = np.zeros(output.shape) + can_propagate = True + except ValueError: + can_propagate = False + + # if did not succeed in propagating, try again + if not can_propagate and isinstance(inp[other_idx], Constant): + # can handle nonzero bias in some cases if other value is a Constant + try: + newscale = np.broadcast_to(scale, output.shape) # check size compatibility + newbias = np.broadcast_to(inp[other_idx].attributes['value'] * bias, output.shape) + new_attrs.pop('bias_precision', None) # remove special bias precision settings + can_propagate = True + except ValueError: + can_propagate = False + + if not can_propagate: + return False + + model.remove_node(apply_alpha) + + new_attrs['scale_data'] = newscale + new_attrs['bias_data'] = newbias + + new_node = model.make_node('ApplyAlpha', apply_alpha.name, new_attrs, [x for x in node.outputs]) + model.insert_node(new_node) + return True + + +class ScaleDownAdd(OptimizerPass): + '''Shift an identical ApplyAlpha below a Merge (Add)''' + + def match(self, node): + '''Check to see if we have an add with two ApplyAlphas with identical scale''' + is_match = isinstance(node, Merge) and len(node.inputs) == 2 and node.attributes["op"] == "add" + if is_match: + in0 = node.get_input_node(node.inputs[0]) + in1 = node.get_input_node(node.inputs[1]) + is_match = ( + isinstance(in0, ApplyAlpha) + and isinstance(in1, ApplyAlpha) + and (in0.weights['scale'].data_unquantized == in1.weights['scale'].data_unquantized).all() + ) + return is_match + + def 
transform(self, model, node): + in0 = node.get_input_node(node.inputs[0]) + in1 = node.get_input_node(node.inputs[1]) + + # Check if we can move + scale = in0.weights['scale'].data_unquantized + bias0 = in0.weights['bias'].data_unquantized + bias1 = in1.weights['bias'].data_unquantized + try: + bias = bias0 + bias1 + except ValueError: + return False + + model.remove_node(in0) + model.remove_node(in1) + + new_attrs = in0.attributes + new_attrs['scale_data'] = scale + new_attrs['bias_data'] = bias + + new_node = model.make_node('ApplyAlpha', in0.name, new_attrs, [x for x in node.outputs]) + model.insert_node(new_node) + return True + + +class BiasDownAdd(OptimizerPass): + '''Shift a ApplyAlpha with only bias below a Merge (Add)''' + + def match(self, node): + '''Match if there is only one ApplyAlpha. If there are two, if the scale of both is 0, they would + match the ScaleDownAdd, so this optimizer does not need to handle that case. + ''' + is_match = isinstance(node, Merge) and len(node.inputs) == 2 and node.attributes["op"] == "add" + if is_match: + in0 = node.get_input_node(node.inputs[0]) + in1 = node.get_input_node(node.inputs[1]) + is_match = (isinstance(in0, ApplyAlpha) or isinstance(in1, ApplyAlpha)) and not ( + isinstance(in0, ApplyAlpha) and isinstance(in1, ApplyAlpha) + ) # only one ApplyAlpha + return is_match + + def transform(self, model, node): + in0 = node.get_input_node(node.inputs[0]) + in1 = node.get_input_node(node.inputs[1]) + + alpha_node = in0 if isinstance(in0, ApplyAlpha) else in1 + + # Check if we can move + scale = alpha_node.weights['scale'].data_unquantized + + if (scale == 0).all(): + model.remove_node(alpha_node) + new_node = model.make_node('ApplyAlpha', alpha_node.name, alpha_node.attributes, [x for x in node.outputs]) + model.insert_node(new_node) + return True + else: + return False + + +class ScaleDownConv(OptimizerPass): + '''Shift an ApplyAlpha on a Conv with 2-3 inputs''' + + def match(self, node): + '''Shift an ApplyAlpha from the Weight''' + is_match = ( + isinstance(node, Conv) + and len(node.inputs) > 1 + and ( + isinstance(node.get_input_node(node.inputs[0]), ApplyAlpha) + or isinstance(node.get_input_node(node.inputs[1]), ApplyAlpha) + or (len(node.inputs) == 3 and isinstance(node.get_input_node(node.inputs[2]), ApplyAlpha)) + ) + ) + return is_match + + def transform(self, model, node): + in0 = node.get_input_node(node.inputs[0]) + in1 = node.get_input_node(node.inputs[1]) + in2 = node.get_input_node(node.inputs[2]) if len(node.inputs) == 3 else None + + aa0 = isinstance(in0, ApplyAlpha) + aa1 = isinstance(in1, ApplyAlpha) + aa2 = isinstance(in2, ApplyAlpha) if len(node.inputs) == 3 else False + + if not isinstance(in1, (Constant, ApplyAlpha)): + raise RuntimeError("The weight node needs to be ApplyAlpha or Constant") + if len(node.inputs) == 3 and not isinstance(in2, (Constant, ApplyAlpha)): + raise RuntimeError("The bias node needs to be ApplyAlpha or Constant") + + scale0 = in0.weights['scale'].data_unquantized if aa0 else None + bias0 = in0.weights['bias'].data_unquantized if aa0 else None + scale1 = in1.weights['scale'].data_unquantized if aa1 else None + bias1 = in1.weights['bias'].data_unquantized if aa1 else None + scale2 = in2.weights['scale'].data_unquantized if aa2 else None + bias2 = in2.weights['bias'].data_unquantized if aa2 else None + + # If possible, make scale and bias have scalar values + if aa0: + scale0, bias0 = _make_scalar(scale0, bias0) + if aa1: + scale1, bias1 = _make_scalar(scale1, bias1) + if aa2: + scale2, bias2 = 
_make_scalar(scale2, bias2) + + output = node.get_output_variable() + if aa0 and not aa1 and not aa2: + # only datapath has a scale + bias = in2.attributes['value'] if len(node.inputs) == 3 else 0 + conv_nobias = np.all(bias == 0) + + can_propagate = False + if not bias0.shape and bias0 == 0: + # No zero offset, propagate through, if possible + # (always possible if scale is scalar) + if conv_nobias: + try: + newscale = np.broadcast_to(_remove_redundant_dims(scale0), output.shape) # check broadcastable + newbias = np.zeros(output.shape) + can_propagate = True + except ValueError: + can_propagate = False + elif not scale0.shape: + # scalar scale0 + try: + newscale = np.broadcast_to(scale0, output.shape) # check broadcastable + newbias = np.broadcast_to(bias * (1 - scale0), output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_name = in0.name + model.remove_node(in0) + + elif not aa0 and aa1 and not aa2: + # only weights have an ApplyAlpha + bias = in2.attributes['value'] if len(node.inputs) == 3 else 0 + conv_nobias = np.all(bias == 0) + + can_propagate = False + if not bias1.shape and bias1 == 0: + # No zero offset, propagate through, if possible + # (always possible if scale is scalar) + if conv_nobias: + try: + if scale1.ndim > 1: + # undo any broadcast_to + reduced_scale = _remove_redundant_dims(scale1) + if reduced_scale.shape[-1] == 1: + reduced_scale = reduced_scale[..., 0] + if node.attributes['n_dim'] == 1: + scale_trans = np.transpose(reduced_scale, (1, 0)) + else: + scale_trans = np.transpose(reduced_scale, (1, 2, 0)) + newscale = np.broadcast_to(scale_trans, output.shape) # make sure broadcastable + can_propagate = True + else: + newscale = np.broadcast_to(scale1, output.shape) # make sure broadcastable + can_propagate = True + newbias = np.zeros(output.shape) + except ValueError: + can_propagate = False + elif not scale1.shape: + # scalar scale1 + try: + newscale = np.broadcast_to(scale1, output.shape) # check broadcastable + newbias = np.broadcast_to(bias * (1 - scale1), output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_name = in1.name + model.remove_node(in1) + + elif not aa0 and not aa1 and aa2: + # only bias has a scale + + can_propagate = False + if not scale2.shape and scale2 == 1: + # No scale, just additional bias + try: + newscale = np.ones(output.shape) + newbias = np.broadcast_to(bias2, output.shape) + can_propagate = True + except ValueError: + can_propagate = False + + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in2.attributes.items() if k not in ('trace', 'precision')} + new_name = in2.name + model.remove_node(in2) + + elif aa0 and aa1 and not aa2: + # dataflow and weights have an ApplyAlpha + bias = in2.attributes['value'] if len(node.inputs) == 3 else 0 + conv_nobias = np.all(bias == 0) + + can_propagate = False + if not bias0.shape and bias0 == 0 and not bias1.shape and bias1 == 0: + # No zero offset, propagate through, if possible + # (always possible if scale is scalar) + if conv_nobias: + try: + if scale1.ndim > 1: + # undo any broadcast_to + reduced_scale0 
= _remove_redundant_dims(scale0) if scale0.ndim > 1 else scale0 + reduced_scale1 = _remove_redundant_dims(scale1) + reduced_scale = reduced_scale0 @ reduced_scale1 + if reduced_scale.shape[-1] == 1: + reduced_scale = reduced_scale[..., 0] + if node.attributes['n_dim'] == 1: + scale_trans = np.transpose(reduced_scale, (1, 0)) + else: + scale_trans = np.transpose(reduced_scale, (1, 2, 0)) + newscale = np.broadcast_to(scale_trans, output.shape) # make sure broadcastable + can_propagate = True + elif scale0.ndim > 1: + # scale1 is scalar + # undo any broadcast_to + reduced_scale0 = _remove_redundant_dims(scale0) + reduced_scale = scale1 * reduced_scale0 + if reduced_scale.shape[-1] == 1: + reduced_scale = reduced_scale[..., 0] + if node.attributes['n_dim'] == 1: + scale_trans = np.transpose(reduced_scale, (1, 0)) + else: + scale_trans = np.transpose(reduced_scale, (1, 2, 0)) + newscale = np.broadcast_to(scale_trans, output.shape) # make sure broadcastable + can_propagate = True + else: + newscale = np.broadcast_to(scale0 * scale1, output.shape) # make sure broadcastable + can_propagate = True + newbias = np.zeros(output.shape) + except ValueError: + can_propagate = False + elif not scale0.shape and not scale1.shape: + # scalar scale1 + try: + newscale = np.broadcast_to(scale0 * scale1, output.shape) # check broadcastable + newbias = np.broadcast_to(bias * (1 - scale0 * scale1), output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_name = in1.name + model.remove_node(in0) + model.remove_node(in1) + + elif aa0 and not aa1 and aa2: + # datapath and bias have a scale + + can_propagate = False + if not bias0.shape and bias0 == 0 and not scale2.shape and not scale0.shape and scale2 == scale0: + # scalar scale0, no bais0 and scale2. 
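The branches of this transform all rest on the same identity: a scalar scale commutes with the (linear) convolution, and any convolution bias is corrected by ``bias * (1 - scale)`` when the scale is moved below the node. A numpy check using a toy valid-padding 1D convolution (shapes and values are made up):

.. code-block:: python

    import numpy as np

    def conv1d(x, W):
        """Toy stride-1, valid-padding conv: x is (width, n_chan), W is (filt_width, n_chan, n_filt)."""
        kw, _, nf = W.shape
        out = np.zeros((x.shape[0] - kw + 1, nf))
        for i in range(out.shape[0]):
            out[i] = np.tensordot(x[i:i + kw], W, axes=([0, 1], [0, 1]))
        return out

    rng = np.random.default_rng(0)
    x, W, b, s0 = rng.random((4, 3)), rng.random((2, 3, 5)), rng.random(5), 0.5

    before = conv1d(s0 * x, W) + b                  # ApplyAlpha (scale s0, zero bias) feeding the Conv
    after = s0 * (conv1d(x, W) + b) + b * (1 - s0)  # equivalent ApplyAlpha moved below the Conv
    assert np.allclose(before, after)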
+ try: + newscale = np.broadcast_to(scale0, output.shape) # check broadcastable + newbias = np.broadcast_to(bias2, output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_name = in0.name + model.remove_node(in0) + model.remove_node(in2) + + elif not aa0 and aa1 and aa2: + # only weights and bias have an ApplyAlpha + + can_propagate = False + if not bias1.shape and bias1 == 0 and not scale2.shape and not scale1.shape and scale2 == scale1: + # No zero offset, propagate through, if possible + # (always possible if scale is scalar) + if not scale1.shape: + # scalar scale1 + try: + newscale = np.broadcast_to(scale1, output.shape) # check broadcastable + newbias = np.broadcast_to(bias2, output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in1.attributes.items() if k not in ('trace', 'precision')} + new_name = in1.name + model.remove_node(in1) + model.remove_node(in2) + + elif aa0 and aa1 and aa2: + # have all + + can_propagate = False + if ( + not bias0.shape + and bias0 == 0 + and not bias1.shape + and bias1 == 0 + and not scale2.shape + and not scale1.shape + and not scale0.shape + and scale2 == scale1 * scale0 + ): + # No zero offset, propagate through, if possible + # (always possible if scale is scalar) + if not scale1.shape: + # scalar scale1 + try: + newscale = np.broadcast_to(scale0 * scale1, output.shape) # check broadcastable + newbias = np.broadcast_to(bias2, output.shape) + can_propagate = True + except ValueError: + can_propagate = False + if not can_propagate: + return False + + # to remove warning, since these get set again + new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_name = in0.name + model.remove_node(in0) + model.remove_node(in1) + model.remove_node(in2) + + # after the big if-else above + new_attrs['scale_data'] = newscale + new_attrs['bias_data'] = newbias + + new_node = model.make_node('ApplyAlpha', new_name, new_attrs, [x for x in node.outputs]) + model.insert_node(new_node) + return True + + +def _remove_redundant_dims(X): + """This is somewhat of the inverse of broadcast-to. 
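In other words, the helper collapses every axis along which all entries are equal back to size 1, undoing an earlier ``broadcast_to``. For example:

.. code-block:: python

    import numpy as np

    scale = np.broadcast_to(np.array([[2.0], [3.0]]), (2, 4))    # per-row scale broadcast to (2, 4)
    reduced = np.expand_dims(np.take(scale, 0, axis=1), axis=1)  # axis 1 holds repeated values
    assert reduced.shape == (2, 1) and np.all(reduced == scale)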
It sets the dimension size to 1 if all values are identical""" + + shape = X.shape + for i in range(len(shape)): + reduced = np.expand_dims(np.take(X, 0, axis=i), axis=i) + if np.all(reduced == X): + X = reduced + return X + + +def _make_scalar(scale, bias): + """Make the scale and bias scalar if possible""" + scale1d = np.ravel(scale) + if (scale1d[0] == scale).all(): + # scalar scale + scale = np.array(scale1d[0]) + + bias1d = np.ravel(bias) + if (bias1d[0] == bias).all(): + # scalar bias + bias = np.array(bias1d[0]) + + return scale, bias diff --git a/hls4ml/model/optimizer/passes/nop.py b/hls4ml/model/optimizer/passes/nop.py deleted file mode 100644 index 55fcf16e93..0000000000 --- a/hls4ml/model/optimizer/passes/nop.py +++ /dev/null @@ -1,14 +0,0 @@ -from hls4ml.model.layers import Activation -from hls4ml.model.optimizer import OptimizerPass - - -class EliminateLinearActivation(OptimizerPass): - def match(self, node): - cast = False - if isinstance(node, Activation): - cast = node.get_input_variable().type.precision != node.get_output_variable().type.precision - return isinstance(node, Activation) and node.get_attr('activation') == 'linear' and not cast - - def transform(self, model, node): - model.remove_node(node) - return True diff --git a/hls4ml/model/optimizer/passes/qkeras.py b/hls4ml/model/optimizer/passes/qkeras.py index ebc66fe59e..03690bed0d 100644 --- a/hls4ml/model/optimizer/passes/qkeras.py +++ b/hls4ml/model/optimizer/passes/qkeras.py @@ -1,7 +1,7 @@ import numpy as np import tensorflow as tf -from hls4ml.model.layers import BatchNormalization, register_layer +from hls4ml.model.layers import ApplyAlpha from hls4ml.model.optimizer import ConfigurableOptimizerPass, OptimizerPass, register_pass from hls4ml.model.quantizers import QKerasPO2Quantizer from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType @@ -77,40 +77,11 @@ def precision_string_modify(self, pstr): return pstr -class ApplyAlpha(BatchNormalization): - '''A custom layer to scale the output of a QDense layer which used 'alpha != 1' - Inference computation uses BatchNormalization methods''' - - def initialize(self): - inp = self.get_input_variable() - shape = inp.shape - dims = inp.dim_names - self.add_output_variable(shape, dims) - - scale = self.get_attr('scale_data') - scale_quantizer = self.get_attr('scale_quantizer') - bias = self.get_attr('bias_data') - bias_quantizer = self.get_attr('bias_quantizer') - - self.add_weights(scale, quantizer=scale_quantizer) - self.add_bias(bias, quantizer=bias_quantizer) - - def add_weights(self, scale, quantizer=None): - self.add_weights_variable(name='scale', var_name='s{index}', data=scale, quantizer=quantizer) - - def add_bias(self, bias, quantizer=None): - self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer) - - def register_qkeras(): - # Register the layer types to the layer map - register_layer('ApplyAlpha', ApplyAlpha) - # Register the optimization passes register_pass('output_rounding_saturation_mode', OutputRoundingSaturationMode) register_pass('qkeras_factorize_alpha', QKerasFactorizeAlpha) register_pass('extract_ternary_threshold', ExtractTernaryThreshold) - register_pass('fuse_consecutive_batch_normalization', FuseConsecutiveBatchNormalization) class QKerasFactorizeAlpha(OptimizerPass): @@ -192,8 +163,16 @@ def transform(self, model, node): else: n_in = node.get_attr('n_out') + # the name of the new ApplyAlpha node + alpha_name = node.get_attr('name') + '_alpha' + + # make the precision auto + 
alpha_precision = {'Precision': 'auto'} + model.config.set_name_config(alpha_name, alpha_precision) + model.config.parse_name_config(alpha_name, alpha_precision) + attrs = { - 'name': node.get_attr('name') + '_alpha', + 'name': alpha_name, 'class_name': 'Alpha', 'inputs': node.outputs, 'n_in': n_in, @@ -210,38 +189,6 @@ def transform(self, model, node): return True -class FuseConsecutiveBatchNormalization(OptimizerPass): - '''OptimizerPass to merge consecutive BatchNormalization layers. - These may exist in a model after QKerasFactorizeAlpha layer. - Scale and Bias of each layer are combined into scale and bias of a single layer. - ''' - - def match(self, node): - return isinstance(node, BatchNormalization) and isinstance(node.get_input_node(), BatchNormalization) - - def transform(self, model, node): - bn0 = node.get_input_node() - bn1 = node - bn0_map = bn0.get_output_use_map() - bn1_map = bn1.get_output_use_map() - if len(bn0_map[bn0.name]) > 1 or len(bn1_map[bn1.name]) > 1: - return False - - s0 = bn0.weights['scale'].data - b0 = bn0.weights['bias'].data - s1 = bn1.weights['scale'].data - b1 = bn1.weights['bias'].data - - s2 = s0 * s1 - b2 = s1 * b0 + b1 - - bn0.weights['scale'].data = s2 - bn0.weights['bias'].data = b2 - - model.remove_node(node, rewire=True) - return True - - class ExtractTernaryThreshold(OptimizerPass): '''The input value (threshold) at which the output of a a ternary activation changes is configurable. This pass extracts that threshold point, inserting diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py new file mode 100644 index 0000000000..cac29b5040 --- /dev/null +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -0,0 +1,385 @@ +""" +This file includes optimizations related to quant nodes. + +As a first step, QuantConstantParameters converts the extra inputs to attributes. + +The next step differs between the case of (1) (positive) power-of-2 scale and zero offset, or (2) other cases. In the first +case no explicit scaling is required, so a Quant node logically becomes a linear activation. (Cases when the scale is a +power of 2 not equal to one are implicitly scaled with fixed precision types.) When the activation is applied to a constant +weight, the activation is immediately merged with the weight, quantizing the weights. In case (2), we need to explicitly +scale and unscale, so the Quant node becomes 3 nodes, an ApplyAlpha node to apply a scale/shift, a Linear node to apply the +quantization, and another ApplyAlpha to unscale/shift. We depend on optimization steps to move the unscaling ApplyAlpha +down as needed so that we can do integer or fixed-point calculations. When the Quant is a applied to a weight, the scaling +and Linear nodes are immediately merged into the Constant. 
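The power-of-2 condition described here is what the matchers below test with ``np.frexp``: a positive number is a power of two exactly when its mantissa is 0.5. A standalone sketch of that check (not the exact code used in the passes):

.. code-block:: python

    import numpy as np

    def is_uniform_power_of_two(scale):
        """True when every element is the same positive power of two."""
        scale = np.asarray(scale, dtype=float).ravel()
        if not np.all(scale == scale[0]):
            return False
        mantissa, _ = np.frexp(scale[0])
        return bool(mantissa == 0.5)

    print(is_uniform_power_of_two([0.25, 0.25]))  # True  (0.25 == 2**-2)
    print(is_uniform_power_of_two([1.0]))         # True
    print(is_uniform_power_of_two([0.75]))        # False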
+ +""" + +import copy +import math # prefer to use math.ceil for scalar values + +import numpy as np + +from hls4ml.model.layers import Activation, ApplyAlpha, Constant, Quant +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.quantizers import QuantNodeQuantizer +from hls4ml.model.types import FixedPrecisionType + +_ALSO_MATCH_PO2 = True + + +class QuantConstantParameters(OptimizerPass): + """Remove Constant from the Qaunt node parameters (but not input[0])""" + + def match(self, node): + is_match = ( + isinstance(node, Quant) + and len(node.inputs) == 4 + and ( + (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant)) + or (node.get_input_node(node.inputs[2]) and isinstance(node.get_input_node(node.inputs[2]), Constant)) + or (node.get_input_node(node.inputs[3]) and isinstance(node.get_input_node(node.inputs[3]), Constant)) + ) + ) + + return is_match + + def transform(self, model, node): + """ + Remove Constant from the Quant node parameters (but not input[0]) + """ + if node.get_input_node(node.inputs[1]): + scale_node = node.get_input_node(node.inputs[1]) + if isinstance(scale_node, Constant): + node.set_attr('scale', scale_node.get_attr('value')) + node.inputs[1] = '' + model.remove_node(scale_node, rewire=False) + + if node.get_input_node(node.inputs[2]): + zeropt_node = node.get_input_node(node.inputs[2]) + if isinstance(zeropt_node, Constant): + node.set_attr('zeropt', zeropt_node.get_attr('value')) + node.inputs[2] = '' + model.remove_node(zeropt_node, rewire=False) + + if node.get_input_node(node.inputs[3]): + bitwidth_node = node.get_input_node(node.inputs[3]) + if isinstance(bitwidth_node, Constant): + bitwidth = bitwidth_node.get_attr('value') + if bitwidth.size != 1: + raise RuntimeError('Only scalar bitwidth values are supporeted by the Quant node') + node.set_attr('bitwidth', bitwidth[0]) + node.inputs[3] = '' + model.remove_node(bitwidth_node, rewire=False) + + node.inputs = [inp for inp in node.inputs if inp] + if len(node.inputs) != 1: + raise RuntimeError("hls4ml only supports constant scale, zeropt, and bitwidth values") + + return True + + +class QuantToActivation(OptimizerPass): + """ + This is for the case when scale is a (positive) power of 2 and zeropt is 0. It is a a 1:1 transformation of + a Quant to an Activation. + + As an optimization, this is not called when the input is constant. 
+ """ + + def match(self, node): + # only matches after the other inputs are already folded + + is_match = ( + isinstance(node, Quant) + and len(node.inputs) == 1 + and not isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + # Only match if the scale is power of 2 and the zero-point is 0s + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + bias = node.get_attr('zeropt') + is_match = is_match and (bias == np.zeros_like(bias)).all() + + # check if scale is ones-like or a power of two + scale_unit_or_po2 = (scale == np.ones_like(scale)).all() + if not scale_unit_or_po2 and _ALSO_MATCH_PO2: + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + + is_match = scale_unit_or_po2 + + return is_match + + def transform(self, model, node): + """ + Change quant node to Activation + """ + + rounding_mode = node.get_attr('rounding_mode') + narrow = node.get_attr('narrow') + signed = node.get_attr('signed') + bitwidth = node.get_attr('bitwidth') + integer = bitwidth + scale = node.get_attr('scale') + if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): + _, exp = np.frexp(scale[0]) + integer = bitwidth + exp - 1 + + precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) + + attributes = {'activation': 'linear', 'quantizer': quantizer} + + # update the configuration + config = model.config.get_layer_config(node) + prec_config = config.setdefault('Precision', {}) + prec_config['result'] = str(precision) + new_name = f'{node.name}_act' + model.config.set_name_config(new_name, config) + model.config.parse_name_config(new_name, config) + + new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + model.replace_node(node, new_node) + + return True + + +class FuseQuantWithConstant(OptimizerPass): + """ + This is for the case when scale is a positive power of 2 and zeropt is 0. + """ + + def match(self, node): + # only matches after the other inputs are already folded + is_match = ( + isinstance(node, Quant) and len(node.inputs) == 1 and isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + # Only match if the scale is power of 2 and the zero-point is 0s + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + bias = node.get_attr('zeropt') + is_match = is_match and (bias == np.zeros_like(bias)).all() + + # check if scale is ones-like or a power of two + scale_unit_or_po2 = (scale == np.ones_like(scale)).all() + if not scale_unit_or_po2 and _ALSO_MATCH_PO2: + # This optimization only works if all scales are the same + if np.all(scale[0] == scale): + mantissa, _ = np.frexp(scale[0]) + scale_unit_or_po2 = mantissa == 0.5 + + is_match = scale_unit_or_po2 + + return is_match + + def transform(self, model, node): + """ + Fuse Quant with Constant. 
+ """ + + rounding_mode = node.get_attr('rounding_mode') + narrow = node.get_attr('narrow') + signed = node.get_attr('signed') + bitwidth = node.get_attr('bitwidth') + integer = bitwidth + scale = node.get_attr('scale') + if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): + _, exp = np.frexp(scale[0]) # know that np.all(scale[0] == scale) must be true + integer = bitwidth + exp - 1 + + precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) + + const_node = node.get_input_node(node.inputs[0]) + const_node.set_attr('quantizer', quantizer) + const_node.set_attr('result_t', precision) + const_node.get_output_variable().type.precision = precision + + # Should we update the configuration to reflect the new precision? I don't think it's necessary + + # remove the Quant node + model.remove_node(node, rewire=True) + + return True + + +class QuantToAlphaActivationAlpha(OptimizerPass): + """ + This is for the case when scale is not power-of-2 or zeropt is not 0. It is a a 1:3 transformation of + a Quant to an ApplyAlpha (to scale), Activatio, ApplyAlpho (to rescale). + + NOTE: It needs to be scheduled after QuantToActivation (or we need to make the match criteria stricter) + """ + + def match(self, node): + # only matches after the other inputs are already folded + is_match = ( + isinstance(node, Quant) + and len(node.inputs) == 1 + and not isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + return is_match + + def transform(self, model, node): + """ + Change quant node to ApplyAlhpa, Activation, ApplyAlpha + """ + + # Do the Activation as in the simple case + + rounding_mode = node.get_attr('rounding_mode') + narrow = node.get_attr('narrow') + signed = node.get_attr('signed') + bitwidth = node.get_attr('bitwidth') + + precision, quantizer = _calculate_precision_quantizer(bitwidth, bitwidth, signed, narrow, rounding_mode) + + activation_attributes = {'activation': 'linear', 'quantizer': quantizer} + + # update the configuration + config = model.config.get_layer_config(node) + act_config = copy.deepcopy(config) + prec_config = act_config.setdefault('Precision', {}) + prec_config['result'] = str(precision) + act_name = f'{node.name}_act' + model.config.set_name_config(act_name, act_config) + model.config.parse_name_config(act_name, act_config) + + new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs]) + model.replace_node(node, new_node) + + # but now add the ApplyAlhpas before and after + + inshape = node.get_input_variable().shape + + scale = node.get_attr('scale') + bias = node.get_attr('zeropt') + + attributes_scale = {'n_filt': -1} + attributes_rescale = {'n_filt': -1} + + scale_config = copy.deepcopy(config) + scale_name = f'{node.name}_scale' + model.config.set_name_config(scale_name, scale_config) + model.config.parse_name_config(scale_name, scale_config) + + rescale_config = config # no need to deep copy the last + rescale_name = f'{node.name}_rescale' + model.config.set_name_config(rescale_name, rescale_config) + model.config.parse_name_config(rescale_name, rescale_config) + + firstscale = 1 / scale + firstbias = bias + attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape) + attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape) + + scale_node = model.make_node(ApplyAlpha, scale_name, attributes_scale, [node.inputs[0]]) + model.insert_node(scale_node) + + rescale = scale + rebias = -bias * scale + attributes_rescale['scale_data'] 
= np.broadcast_to(rescale, inshape) + attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape) + + rescale_node = model.make_node(ApplyAlpha, rescale_name, attributes_rescale, [new_node.outputs[0]]) + model.insert_node(rescale_node) + + return True + + +class ConstQuantToConstAlpha(OptimizerPass): + """ + This is for the case when scale is not power-of-2 or zeropt is not 0. It is a a 1:3 transformation of + a Quant to an ApplyAlpha (to scale), Activation, ApplyAlpho (to unscale), but an input + consts allows for optimization, so the ApplyAlpha (to scale), Activation are + optimized away right away. + """ + + def match(self, node): + # only matches after the other inputs are already folded + is_match = ( + isinstance(node, Quant) and len(node.inputs) == 1 and isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + bias = node.get_attr('zeropt') + is_match = is_match and ((scale != np.ones_like(scale)).any() or (bias != np.zeros_like(bias)).any()) + return is_match + + def transform(self, model, node): + """ + Change Constant + Quant node to Constant, ApplyAlpha + """ + + rounding_mode = node.get_attr('rounding_mode') + narrow = node.get_attr('narrow') + signed = node.get_attr('signed') + bitwidth = node.get_attr('bitwidth') + + precision, quantizer = _calculate_precision_quantizer(bitwidth, bitwidth, signed, narrow, rounding_mode) + + const_node = node.get_input_node(node.inputs[0]) + + scale = node.get_attr('scale') + bias = node.get_attr('zeropt') + + # caclucate the new value + new_val = const_node.get_attr('value') / scale + bias + const_node.set_attr('value', new_val) + const_node.set_attr('quantizer', quantizer) + + const_node.types['result_t'].precision = precision + const_node.get_output_variable().type.precision = precision + + inshape = node.get_input_variable().shape + + attributes_rescale = {'n_filt': -1} + + rescale_config = copy.deepcopy(model.config.get_layer_config(node)) + rescale_name = f'{node.name}_rescale' + model.config.set_name_config(rescale_name, rescale_config) + model.config.parse_name_config(rescale_name, rescale_config) + + rescale = scale + rebias = -bias * scale + attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape) + attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape) + + rescale_node = model.make_node( + ApplyAlpha, rescale_name, attributes_rescale, [x for x in node.inputs], [x for x in node.outputs] + ) + model.replace_node(node, rescale_node) + + return True + + +def _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode): + """ + A function to determine the precision and quantizer + """ + if rounding_mode == 'ROUND': + bn_round = 'AP_RND_CONV' + elif rounding_mode == 'FLOOR': + bn_round = 'AP_TRN' + else: + raise NotImplementedError( + f'Rounding mode {rounding_mode} not supported in Quant node. Only ROUND and FLOOR supported.' 
+ ) + + if narrow and not signed: + raise NotImplementedError('Narrow mode is only supported for singed numbers.') + + if narrow: + bn_sat = 'AP_SAT_SYM' + else: + bn_sat = 'AP_SAT' + + bitwidth = math.ceil(bitwidth) + integer = math.ceil(integer) + + precision = FixedPrecisionType(bitwidth, integer, signed, bn_round, bn_sat) + quantizer = QuantNodeQuantizer(precision) + return (precision, quantizer) diff --git a/hls4ml/model/optimizer/passes/reshape_const.py b/hls4ml/model/optimizer/passes/reshape_const.py new file mode 100644 index 0000000000..0012b2761e --- /dev/null +++ b/hls4ml/model/optimizer/passes/reshape_const.py @@ -0,0 +1,27 @@ +from hls4ml.model.layers import Constant, Reshape +from hls4ml.model.optimizer import OptimizerPass + + +class ReshapeConstant(OptimizerPass): + """ + ONNX has the target shape come as an input, not a parameter. This removes + the Constant input from new shape input. (Non-constant inputs are not supported.) + The constant value was already used; this is just a cleanup uptimization. + """ + + def match(self, node): + is_match = isinstance(node, Reshape) and len(node.inputs) > 1 and node.get_input_node(node.inputs[1]) + + return is_match + + def transform(self, model, node): + """ + Remove Constant from new shape input. Note, input shape node is already used on initialize + """ + shape_node = node.get_input_node(node.inputs[1]) + node.inputs[1] = '' + if not isinstance(shape_node, Constant): + raise RuntimeError("Nonconstant shape inputs are not currently supported") + model.remove_node(shape_node, rewire=False) + + return True diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index c857ef51ac..a5b9ceb8c4 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -8,7 +8,14 @@ import tensorflow as tf from qkeras.quantizers import get_quantizer -from hls4ml.model.types import ExponentPrecisionType, FixedPrecisionType, IntegerPrecisionType, XnorPrecisionType +from hls4ml.model.types import ( + ExponentPrecisionType, + FixedPrecisionType, + IntegerPrecisionType, + RoundingMode, + SaturationMode, + XnorPrecisionType, +) class Quantizer: @@ -158,3 +165,98 @@ def __call__(self, data): if hasattr(y, 'numpy'): y = y.numpy() return y + + +class QuantNodeQuantizer(Quantizer): + """ + This implements a quantizer for a FixedPrecisionType with width==integer + + This is based on the sample implementation in finn-base + """ + + def __init__(self, precision): + super().__init__(precision.width, precision) + if not isinstance(precision, (FixedPrecisionType, IntegerPrecisionType)): + raise TypeError('QuantNodeQuantizer is only defined for FixedPrecisionType and IntegerPrecisionType') + + def __call__(self, data): + """Apply the quantization on the data""" + + scale = 2 ** (self.hls_type.width - self.hls_type.integer) + + data = data * scale # (not using *= to avoid modifying data) + # Clamping + min_int_val = self._min_int(self.hls_type.signed, self.hls_type.saturation_mode, self.bits) + max_int_val = self._max_int(self.hls_type.signed, self.bits) + data = np.where(data > max_int_val, max_int_val, data) + data = np.where(data < min_int_val, min_int_val, data) + # Rounding + rounding_fx = self._resolve_rounding_mode(self.hls_type.rounding_mode) + return rounding_fx(data) / scale + + @staticmethod + def _min_int(signed: bool, saturation_mode: str, bit_width: int) -> int: + """Compute the minimum integer representable by a given number of bits. + Args: + signed (bool): Indicates whether the represented integer is signed or not. 
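
(For orientation only, not part of the patch: the power-of-two mapping that ``_calculate_precision_quantizer`` and ``QuantNodeQuantizer`` rely on above can be sketched in plain numpy. The helper name ``quant_po2_reference`` and the sample values are illustrative assumptions, and ``ROUND`` is taken to mean round-half-to-even, as in ``AP_RND_CONV``.)

    import numpy as np

    def quant_po2_reference(x, scale, bitwidth, signed=True, narrow=False):
        """Quant with zeropt == 0 and a power-of-two scale, 'ROUND' rounding (illustrative only)."""
        if signed:
            q_min = -(2 ** (bitwidth - 1)) + (1 if narrow else 0)
            q_max = 2 ** (bitwidth - 1) - 1
        else:
            q_min, q_max = 0, 2**bitwidth - 1
        q = np.clip(np.round(x / scale), q_min, q_max)  # np.round rounds half to even, like AP_RND_CONV
        return q * scale

    x = np.array([-1.3, 0.24, 0.9])
    scale, bitwidth = 2.0**-4, 8
    y = quant_po2_reference(x, scale, bitwidth)

    _, exp = np.frexp(scale)           # scale == 0.5 * 2**exp, so exp == -3 here
    integer = bitwidth + exp - 1       # 8 + (-3) - 1 == 4, i.e. the pass would choose fixed<8,4>
    lsb = 2.0 ** (integer - bitwidth)  # equals scale, so every quantized value lies on the fixed<8,4> grid
    assert np.allclose(y, np.round(y / lsb) * lsb)
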
diff --git a/hls4ml/model/optimizer/passes/reshape_const.py b/hls4ml/model/optimizer/passes/reshape_const.py
new file mode 100644
index 0000000000..0012b2761e
--- /dev/null
+++ b/hls4ml/model/optimizer/passes/reshape_const.py
@@ -0,0 +1,27 @@
+from hls4ml.model.layers import Constant, Reshape
+from hls4ml.model.optimizer import OptimizerPass
+
+
+class ReshapeConstant(OptimizerPass):
+    """
+    ONNX has the target shape come as an input, not a parameter. This removes
+    the Constant node that provides the new shape input. (Non-constant shape inputs are not supported.)
+    The constant value was already used; this is just a cleanup optimization.
+    """
+
+    def match(self, node):
+        is_match = isinstance(node, Reshape) and len(node.inputs) > 1 and node.get_input_node(node.inputs[1])
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Remove the Constant from the new shape input. Note: the shape input node was already used on initialization.
+        """
+        shape_node = node.get_input_node(node.inputs[1])
+        node.inputs[1] = ''
+        if not isinstance(shape_node, Constant):
+            raise RuntimeError("Nonconstant shape inputs are not currently supported")
+        model.remove_node(shape_node, rewire=False)
+
+        return True
diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py
index c857ef51ac..a5b9ceb8c4 100644
--- a/hls4ml/model/quantizers.py
+++ b/hls4ml/model/quantizers.py
@@ -8,7 +8,14 @@
 import tensorflow as tf
 from qkeras.quantizers import get_quantizer
 
-from hls4ml.model.types import ExponentPrecisionType, FixedPrecisionType, IntegerPrecisionType, XnorPrecisionType
+from hls4ml.model.types import (
+    ExponentPrecisionType,
+    FixedPrecisionType,
+    IntegerPrecisionType,
+    RoundingMode,
+    SaturationMode,
+    XnorPrecisionType,
+)
 
 
 class Quantizer:
@@ -158,3 +165,98 @@ def __call__(self, data):
         if hasattr(y, 'numpy'):
             y = y.numpy()
         return y
+
+
+class QuantNodeQuantizer(Quantizer):
+    """
+    This implements a quantizer for a FixedPrecisionType with width==integer
+
+    This is based on the sample implementation in finn-base
+    """
+
+    def __init__(self, precision):
+        super().__init__(precision.width, precision)
+        if not isinstance(precision, (FixedPrecisionType, IntegerPrecisionType)):
+            raise TypeError('QuantNodeQuantizer is only defined for FixedPrecisionType and IntegerPrecisionType')
+
+    def __call__(self, data):
+        """Apply the quantization on the data"""
+
+        scale = 2 ** (self.hls_type.width - self.hls_type.integer)
+
+        data = data * scale  # (not using *= to avoid modifying data)
+        # Clamping
+        min_int_val = self._min_int(self.hls_type.signed, self.hls_type.saturation_mode, self.bits)
+        max_int_val = self._max_int(self.hls_type.signed, self.bits)
+        data = np.where(data > max_int_val, max_int_val, data)
+        data = np.where(data < min_int_val, min_int_val, data)
+        # Rounding
+        rounding_fx = self._resolve_rounding_mode(self.hls_type.rounding_mode)
+        return rounding_fx(data) / scale
+
+    @staticmethod
+    def _min_int(signed: bool, saturation_mode: str, bit_width: int) -> int:
+        """Compute the minimum integer representable by a given number of bits.
+        Args:
+            signed (bool): Indicates whether the represented integer is signed or not.
+            saturation_mode (bool): Indicates the saturation mode used (AP_SAT_SYM or AP_SAT)
+            bit_width (int): Number of bits available for the representation.
+        Returns:
+            int: Minimum integer that can be represented according to
+            the input arguments.
+        Examples:
+            >>> min_int(signed=True, saturation_mode='AP_SAT_SYM', bit_width=8)
+            int(-127)
+            >>> min_int(signed=False, saturation_mode='AP_SAT_SYM', bit_width=8)
+            int(0)
+            >>> min_int(signed=True, saturation_mode='AP_SAT', bit_width=8)
+            int(-128)
+            >>> min_int(signed=False, saturation_mode='AP_SAT', bit_width=8)
+            int(0)
+        """
+        if saturation_mode not in (SaturationMode.SAT_SYM, SaturationMode.SAT, SaturationMode.WRAP):
+            raise ValueError(
+                f'Saturation mode {saturation_mode} not supported. Only AP_SAT_SYM, AP_SAT supported, WRAP partially'
+            )
+        if signed and saturation_mode == SaturationMode.SAT_SYM:
+            value = -(2 ** (bit_width - 1)) + 1
+        elif signed:
+            value = -(2 ** (bit_width - 1))
+        else:
+            value = 0
+        return value
+
+    @staticmethod
+    def _max_int(signed: bool, bit_width: int) -> int:
+        """Compute the maximum integer representable by a given number of bits.
+        (Note, narrow and unsigned is not supported by the implementation, so saturation mode is not used)
+        Args:
+            signed (bool): Indicates whether the represented integer is signed or not.
+            bit_width (int): Number of bits available for the representation.
+        Returns:
+            int: Maximum integer that can be represented according to
+            the input arguments.
+        Examples:
+            >>> max_int(signed=True, bit_width=8)
+            int(127)
+            >>> max_int(signed=False, bit_width=8)
+            int(255)
+        """
+        if not signed:
+            value = (2**bit_width) - 1
+        else:
+            value = (2 ** (bit_width - 1)) - 1
+        return value
+
+    @staticmethod
+    def _resolve_rounding_mode(mode):
+        """Resolve the rounding mode of Quant and Trunc ops
+        to the corresponding numpy functions."""
+        if mode == RoundingMode.RND_CONV:
+            return np.round
+        # elif mode_string == 'CEIL':  # not supported
+        #     return np.ceil
+        elif mode == RoundingMode.TRN:
+            return np.floor
+        else:
+            raise ValueError(f'Rounding mode {mode} not supported.')
diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py
index fb5cde3863..9fb257a1ef 100644
--- a/hls4ml/model/types.py
+++ b/hls4ml/model/types.py
@@ -88,6 +88,7 @@ def __str__(self):
         typestring = '{signed}int<{width}>'.format(signed='u' if not self.signed else '', width=self.width)
         return typestring
 
+    # Does this need to make sure other is also an IntegerPrecisionType?
I could see a match between Fixed and Integer def __eq__(self, other): if isinstance(other, IntegerPrecisionType): return super().__eq__(other) @@ -136,6 +137,8 @@ def __init__(self, width=16, integer=6, signed=True, rounding_mode=None, saturat self.saturation_mode = saturation_mode self.saturation_bits = saturation_bits + # make this a property to avoid inconsistencies + @property def fractional(self): return self.width - self.integer @@ -204,6 +207,7 @@ def __init__(self): super().__init__(width=1, signed=False) self.integer = 1 + # TODO: this should really be a specific type def __str__(self): typestring = 'uint<1>' return typestring @@ -218,6 +222,7 @@ class ExponentPrecisionType(PrecisionType): def __init__(self, width=16, signed=True): super().__init__(width=width, signed=signed) + # TODO: this should really be a specific type, not int def __str__(self): typestring = '{signed}int<{width}>'.format(signed='u' if not self.signed else '', width=self.width) return typestring diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h index 961c65037d..9dbbd92425 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h @@ -1,6 +1,8 @@ #ifndef NNET_CONV2D_RESOURCE_H_ #define NNET_CONV2D_RESOURCE_H_ +#include + #include "nnet_common.h" #include "nnet_dense.h" #include "nnet_helpers.h" diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index 5cd17d02e9..1bd9ff25ef 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -413,7 +413,7 @@ def make_layer_config(layer): def config_from_onnx_model( - model, granularity='model', backend=None, default_precision='ap_fixed<16,6>', default_reuse_factor=1 + model, granularity='name', backend=None, default_precision='ap_fixed<16,6>', default_reuse_factor=1 ): """Create an HLS conversion config given the ONNX model. @@ -423,8 +423,8 @@ def config_from_onnx_model( Args: model: ONNX model - granularity (str, optional): Granularity of the created config. Defaults to 'model'. - Can be set to 'model', 'type' and 'layer'. + granularity (str, optional): Granularity of the created config. Defaults to 'name'. + Can be set to 'model', 'type' and 'name'. Granularity can be used to generate a more verbose config that can be fine-tuned. The default granularity ('model') will generate config keys that apply to the whole @@ -443,6 +443,16 @@ def config_from_onnx_model( [dict]: The created config. """ + if granularity.lower() not in ['model', 'type', 'name']: + raise Exception( + f'Invalid configuration granularity specified, expected "model", "type" or "name" got "{granularity}"' + ) + + if backend is not None: + backend = hls4ml.backends.get_backend(backend) + elif granularity.lower() != 'model': + print('Warning: it is recommended to pass the backend to "config_from_onnx_model"') + config = {} model_config = {} @@ -452,4 +462,56 @@ def config_from_onnx_model( config['Model'] = model_config + layer_list, _, _ = hls4ml.converters.parse_onnx_model(model) + + def make_layer_config(layer): + cls_name = layer['class_name'] + + layer_cls = hls4ml.model.layers.layer_map[cls_name] + if backend is not None: + layer_cls = backend.create_layer_class(layer_cls) + + layer_config = {} + + # set the default precision of the layer to auto? 
+ # (not really necessary if we set the backend appropriately) + # layer_config['Precision'] = {'default': 'auto'} + + config_attrs = [a for a in layer_cls.expected_attributes if a.configurable] + for attr in config_attrs: + if isinstance(attr, hls4ml.model.attributes.TypeAttribute): + precision_cfg = layer_config.setdefault('Precision', {}) + name = attr.name + if name.endswith('_t'): + name = name[:-2] + if attr.default is None: + precision_cfg[name] = 'auto' + else: + precision_cfg[name] = str(attr.default) + elif attr.name == 'reuse_factor': + layer_config[attr.config_name] = default_reuse_factor + else: + if attr.default is not None: + layer_config[attr.config_name] = attr.default + + return layer_config + + if granularity.lower() == 'type': + type_config = {} + for layer in layer_list: + if layer['class_name'] in type_config: + continue + layer_config = make_layer_config(layer) + type_config[layer['class_name']] = layer_config + + config['LayerType'] = type_config + + elif granularity.lower() == 'name': + name_config = {} + for layer in layer_list: + layer_config = make_layer_config(layer) + name_config[layer['name']] = layer_config + + config['LayerName'] = name_config + return config diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py index 3d66107c85..a1ff93292e 100644 --- a/test/pytest/test_qkeras.py +++ b/test/pytest/test_qkeras.py @@ -356,8 +356,10 @@ def test_relu_negative_slope(randX_1000_1, quantizer, backend, io_type): ], ) def test_qactivation_kwarg(randX_100_10, activation_quantizer, weight_quantizer): - if activation_quantizer in ['binary', 'ternary']: + if activation_quantizer in ['binary']: name = 'bnbt_qdense_alpha' + elif activation_quantizer in ['ternary']: + name = 'bnbt_qdense_ternary_scale' else: name = f'qdense_{eval(activation_quantizer).__class__.__name__}' diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py new file mode 100644 index 0000000000..f822c591a7 --- /dev/null +++ b/test/pytest/test_qonnx.py @@ -0,0 +1,356 @@ +import os +import urllib +from pathlib import Path + +import numpy as np +import pytest +import qonnx.core.onnx_exec as oxe +import qonnx.util.cleanup +import qonnx.util.to_channels_last + +# To conveniently run QONNX inference +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean +from qonnx.transformation.gemm_to_matmul import GemmToMatMul + +import hls4ml + +test_root_path = Path(__file__).parent +example_model_path = (test_root_path / '../../example-models').resolve() + +# The models + + +@pytest.fixture(scope='module') +def tfc_2w2a_model(): + ''' + Load the tiny fully-connected model + ''' + dl_dir = test_root_path + dl_file = str(dl_dir / "qonnx-tfc-2w2a.onnx") + tfc_w2a2_qonnx_url = ( + "https://raw.githubusercontent.com/fastmachinelearning/" + "QONNX_model_zoo/main/models/MNIST/Brevitas_FINN_TFC/TFC/TFC_2W2A.onnx" + ) + urllib.request.urlretrieve(tfc_w2a2_qonnx_url, dl_file) + assert os.path.isfile(dl_file) + out_file = str(dl_dir / "qonnx-tfc-2w2a-clean.onnx") + + # cleanup + qonnx.util.cleanup.cleanup(dl_file, out_file=out_file) + model = ModelWrapper(out_file) + return model + + +@pytest.fixture(scope='module') +def cnv_2w2a_model(): + ''' + Load the small convolution model + ''' + dl_dir = test_root_path + dl_file = str(dl_dir / "qonnx-cnv-2w2a.onnx") + cnv_w2a2_qonnx_url = ( + "https://raw.githubusercontent.com/fastmachinelearning/" + "QONNX_model_zoo/main/models/CIFAR10/Brevitas_FINN_CNV/CNV_2W2A.onnx" + ) + 
urllib.request.urlretrieve(cnv_w2a2_qonnx_url, dl_file)
+    assert os.path.isfile(dl_file)
+    out_clean = str(dl_dir / "qonnx-cnv-2w2a-clean.onnx")
+    out_chanlast = str(dl_dir / "qonnx-cnv-2w2a-clean-channels-last.onnx")
+    out_file = str(dl_dir / "qonnx-cnv-2w2a-clean-channels-last-clean.onnx")
+
+    # cleanup
+    qonnx.util.cleanup.cleanup(dl_file, out_file=out_clean)
+    qonnx.util.to_channels_last.to_channels_last(out_clean, make_input_channels_last=True, out_file=out_chanlast)
+    qonnx.util.cleanup.cleanup(out_chanlast, out_file=out_file)
+    model = ModelWrapper(out_file)
+    return model
+
+
+@pytest.fixture(scope='module')
+def jettagging_model():
+    '''
+    Load the 3 hidden layer QKeras example model trained on the jet tagging dataset
+    '''
+    dl_dir = test_root_path
+    dl_file = str(dl_dir / "qkeras_jettagging.onnx")
+    jet_tagging_qonnx_url = (
+        "https://raw.githubusercontent.com/fastmachinelearning/"
+        "QONNX_model_zoo/main/models/JetTagging/QKeras_hls4ml_3layer/qkeras_jettagging.onnx"
+    )
+    urllib.request.urlretrieve(jet_tagging_qonnx_url, dl_file)
+    assert os.path.isfile(dl_file)
+    out_file = str(dl_dir / "qkeras_jettagging-clean.onnx")
+
+    # cleanup
+    qonnx.util.cleanup.cleanup(dl_file, out_file=out_file)
+    model = ModelWrapper(out_file)
+    return model
+
+
+@pytest.fixture(scope='module')
+def sep_conv_model():
+    """
+    Load separable conv model, already channels-last and cleaned
+    """
+    dl_file = str(example_model_path / "onnx/separable_conv_model_ch_last.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+
+    return model
+
+
+@pytest.fixture(scope='module')
+def two_layer_keras_model():
+    """
+    Load a simple, two-layer, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/two_layer_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def three_layer_keras_model():
+    """
+    Load a simple, three-layer, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/three_layer_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def two_layer_pytorch_model():
+    """
+    Load a simple, two-layer, originally pytorch, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/two_layer_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def three_layer_pytorch_model():
+    """
+    Load a simple, three-layer, originally pytorch, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/three_layer_pytorch.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv1d_small_keras_model():
+    """
+    Load a simple conv1d, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv1d_small_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv2d_small_keras_model():
+    """
+    Load a simple conv2d, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv2d_small_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv2d_small_mp_keras_model():
+    """
+    Load a conv2d model with max pooling, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv2d_small_mp_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+# The actual tests
+
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+def test_tfc_2w2a(tfc_2w2a_model, backend):
+    model = tfc_2w2a_model
+
+    ishape = (1, 1, 28, 28)
+    X = np.random.uniform(low=-1, high=+1, size=np.prod(ishape)).reshape(ishape)
+    X = (np.round(X * 2**16) * 2**-16).astype(np.float32)
+
+    idict = {model.graph.input[0].name: X}
+    y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name]
+
+    # Convert QONNX model, compile, and run inference
+    config = hls4ml.utils.config_from_onnx_model(model, backend=backend, default_precision='fixed<32,16>')
+    hls_model = hls4ml.converters.convert_from_onnx_model(
+        model, output_dir=str(test_root_path / f'hls4mlprj_qonnx_tfc-2w2a_{backend}'), backend=backend, hls_config=config
+    )
+    hls_model.compile()
+    y_hls4ml = hls_model.predict(X)
+
+    np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)
+
+
+@pytest.mark.parametrize('backend', ['Vitis'])
+def test_cnv_2w2a(cnv_2w2a_model, backend):
+    """
+    This tests a convolution model. Note: the batch normalization weights are not quantized, so it is
+    difficult to make this match perfectly. It is also a slow test, which is why only Vitis is tested.
+ """ + model = cnv_2w2a_model + + ishape = (1, 32, 32, 3) + X = np.random.uniform(low=-1, high=+1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**6) * 2**-6).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + # Convert QONNX model, compile, and run inference + config = hls4ml.utils.config_from_onnx_model(model, backend=backend, default_precision='fixed<32,6>') + hls_model = hls4ml.converters.convert_from_onnx_model( + model, + output_dir=str(test_root_path / f'hls4mlprj_qonnx_cnv-2w2a_{backend}'), + io_type='io_stream', + backend=backend, + hls_config=config, + ) + hls_model.compile() + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_jet_tagging(jettagging_model, backend): + model = jettagging_model + + # Execute QONNX model inference + # TODO make the test bigger + ishape = (1, 16) + X = np.random.uniform(low=-1, high=+1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**16) * 2**-16).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + # Convert QONNX model, compile, and run inference + config = hls4ml.utils.config_from_onnx_model(model, backend=backend, default_precision='fixed<32,16>') + + hls_model = hls4ml.converters.convert_from_onnx_model( + model, output_dir=str(test_root_path / f'hls4mlprj_qonnx_jettag_{backend}'), backend=backend, hls_config=config + ) + hls_model.compile() + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) + + +@pytest.mark.parametrize('backend', ['Vitis']) +def test_sep_conv(sep_conv_model, backend): + model = sep_conv_model + ishape = tuple(model.get_tensor_shape(model.graph.input[0].name)) + X = np.random.uniform(low=0, high=1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**16) * 2**-16).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + config = hls4ml.utils.config.config_from_onnx_model( + model, granularity='name', backend=backend, default_precision='fixed<32,16>' + ) + + hls_model = hls4ml.converters.convert_from_onnx_model( + model, + output_dir=str(test_root_path / f'hls4mlprj_qonnx_sep_conv_{backend}'), + io_type='io_stream', + backend=backend, + hls_config=config, + ) + hls_model.compile() + y_hls4ml = hls_model.predict(np.ascontiguousarray(X)) + + np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) + + +@pytest.mark.parametrize( + 'model_name', + [ + 'two_layer_keras_model', + 'three_layer_keras_model', + 'two_layer_pytorch_model', + 'three_layer_pytorch_model', + 'conv1d_small_keras_model', + 'conv2d_small_keras_model', + 'conv2d_small_mp_keras_model', + ], +) +@pytest.mark.parametrize('backend', ['Vitis']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_simple_model(model_name, io_type, backend, request): + if model_name == 'conv2d_small_mp_keras_model' and io_type == 'io_stream': + # Not yet supported due to an issue with channels last conversion + # There is a qonnx PR. 
+ pytest.skip() + model = request.getfixturevalue(model_name) + ishape = tuple(model.get_tensor_shape(model.graph.input[0].name)) + X = np.random.uniform(low=0, high=1, size=np.prod(ishape)).reshape(ishape) + X = (np.round(X * 2**10) * 2**-10).astype(np.float32) + idict = {model.graph.input[0].name: X} + y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name] + + config = hls4ml.utils.config.config_from_onnx_model( + model, granularity='name', backend=backend, default_precision='fixed<16,6>' + ) + + for layer in config['LayerName']: + if layer.startswith('Softmax'): + config['LayerName'][layer]['Implementation'] = 'legacy' + + hls_model = hls4ml.converters.convert_from_onnx_model( + model, + output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}'), + io_type=io_type, + backend=backend, + hls_config=config, + ) + hls_model.compile() + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)