diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py
index 836da6e68..c7f5b490a 100644
--- a/hls4ml/backends/vivado/passes/core_templates.py
+++ b/hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,15 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
+from hls4ml.model.layers import (
+    Activation,
+    BatchNormalization,
+    Dense,
+    HardActivation,
+    LayerNormalization,
+    ParametrizedActivation,
+    PReLU,
+    Softmax,
+)
 from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT
 
 # Dense templates
@@ -119,6 +128,59 @@ def format(self, node):
         return self.template.format(**params)
 
 
+# LayerNormalization templates
+
+layernorm_config_template = """struct config{index} : nnet::layernorm_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned seq_len = {seq_len};
+    static const unsigned table_size = {table_size};
+    static constexpr double table_range = {table_range};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    static const bool store_weights_in_bram = false;
+    static constexpr double epsilon = {epsilon};
+    typedef {bias_t.name} bias_t;
+    typedef {scale_t.name} scale_t;
+    typedef {mean_t.name} mean_t;
+    typedef {table_t.name} table_t;
+    template <class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+layernorm_function_template = 'nnet::layernormalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
+
+layernorm_include_list = ['nnet_utils/nnet_layernorm.h']
+
+
+class LayerNormalizationConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(LayerNormalization)
+        self.template = layernorm_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['n_in'] = node.get_input_variable().size_cpp()
+        params['seq_len'] = node.get_attr('seq_len')
+        params['product_type'] = get_backend('vivado').product_type(
+            node.get_input_variable().type.precision, node.get_weights('scale').type.precision
+        )
+
+        return self.template.format(**params)
+
+
+class LayerNormalizationFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(LayerNormalization, include_header=layernorm_include_list)
+        self.template = layernorm_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['scale'] = node.get_weights('scale').name
+        params['bias'] = node.get_weights('bias').name
+
+        return self.template.format(**params)
+
+
 # Activation templates
 
 activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 117805dd8..8ebbbc999 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -21,6 +21,7 @@
     GarNet,
     GarNetStack,
     Layer,
+    LayerNormalization,
     Pooling1D,
     Pooling2D,
     SeparableConv1D,
@@ -558,6 +559,21 @@ def init_softmax(self, layer):
             len(layer.get_input_variable().shape) == 1
         ), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.'
 
+    @layer_optimizer(LayerNormalization)
+    def init_layernormalization(self, layer):
+        if 'table_t' not in layer.attributes:
+            layer.set_attr(
+                'table_t', NamedType(name=layer.name + '_table_t', precision=FixedPrecisionType(width=16, integer=6))
+            )
+        if 'table_size' not in layer.attributes:
+            layer.set_attr('table_size', 4096)  # table size
+        if 'table_range' not in layer.attributes:
+            layer.set_attr('table_range', 1.0)  # table range
+        if 'mean_t' not in layer.attributes:
+            layer.set_attr(
+                'mean_t', NamedType(name=layer.name + '_mean_t', precision=FixedPrecisionType(width=19, integer=6))
+            )
+
     @layer_optimizer(Embedding)
     def init_embed(self, layer):
         if layer.attributes['n_in'] is None:
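As a quick sanity check on the defaults chosen in init_layernormalization above, the snippet below works out the resolutions they imply. This is a back-of-the-envelope sketch, not part of the patch; it assumes FixedPrecisionType(width, integer) follows the usual ap_fixed<width, integer> convention of `integer` bits (including the sign) above the binary point.

# Rough resolution check for the LayerNormalization defaults (hypothetical helper, not in hls4ml).
table_size = 4096    # entries in the 1/sqrt(variance) lookup table
table_range = 1.0    # variances covered by the table: [epsilon, table_range)
print('LUT variance step: %.2e' % (table_range / table_size))   # ~2.4e-04


def fixed_point_stats(width, integer):
    """LSB and max value of a signed ap_fixed<width, integer>-style type."""
    lsb = 2.0 ** (integer - width)
    max_val = 2.0 ** (integer - 1) - lsb
    return lsb, max_val


print('table_t ap_fixed<16,6>: lsb=%.2e max=%.3f' % fixed_point_stats(16, 6))  # ~9.8e-04, ~32
print('mean_t  ap_fixed<19,6>: lsb=%.2e max=%.3f' % fixed_point_stats(19, 6))  # ~1.2e-04, ~32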
diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py
index 637bb6d40..47148ee9f 100644
--- a/hls4ml/converters/keras/core.py
+++ b/hls4ml/converters/keras/core.py
@@ -129,6 +129,34 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     return layer, [shape for shape in input_shapes[0]]
 
 
+@keras_handler('LayerNormalization')
+def parse_layernorm_layer(keras_layer, input_names, input_shapes, data_reader):
+    assert 'LayerNormalization' in keras_layer['class_name']
+
+    layer = parse_default_keras_layer(keras_layer, input_names)
+
+    in_size = 1
+    for dim in input_shapes[0][1:]:
+        in_size *= dim
+    layer['n_in'] = layer['n_out'] = in_size
+
+    if not ((len(input_shapes[0])) == 3):
+        raise Exception('input size is not currently supported by hls4ml, only dim3 is supported')
+    layer['seq_len'] = input_shapes[0][-2]
+
+    if not (keras_layer['config']['axis'][0] == 2):
+        raise Exception('assigning the axis is not currently supported by hls4ml, only axis 2 is supported')
+
+    layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma')
+    layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta')
+
+    layer['epsilon'] = keras_layer['config']['epsilon']
+    if layer['epsilon'] <= 0:
+        raise Exception('epsilon must be positive')
+
+    return layer, [shape for shape in input_shapes[0]]
+
+
 @keras_handler('Embedding')
 def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader):
     assert 'Embedding' in keras_layer['class_name']
diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py
index 2c05b7501..e4d99fe28 100644
--- a/hls4ml/converters/pytorch/core.py
+++ b/hls4ml/converters/pytorch/core.py
@@ -138,3 +138,32 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
     layer['n_filt'] = input_shapes[0][1]  # Always channel first for Pytorch
 
     return layer, [shape for shape in input_shapes[0]]
+
+
+@pytorch_handler('LayerNorm')
+def parse_layernorm_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert 'LayerNorm' in operation
+
+    layer = {}
+
+    layer['class_name'] = 'LayerNormalization'
+    layer['name'] = layer_name
+    layer['inputs'] = input_names
+
+    in_size = 1
+    for dim in input_shapes[0][1:]:
+        in_size *= dim
+    layer['n_in'] = layer['n_out'] = in_size
+
+    if not ((len(input_shapes[0])) == 3):
+        raise Exception('input size is not currently supported by hls4ml, only dim3 is supported')
+    layer['seq_len'] = input_shapes[0][-2]
+
+    layer['gamma_data'] = class_object.weight.data.numpy()
+    layer['beta_data'] = class_object.bias.data.numpy()
+
+    layer['epsilon'] = class_object.eps
+    if layer['epsilon'] <= 0:
+        raise Exception('epsilon must be positive')
+
+    return layer, [shape for shape in input_shapes[0]]
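Both handlers flatten everything after the batch dimension into n_in and keep the sequence length separately, so the HLS kernel later normalizes only over the trailing feature axis. A worked example with a hypothetical (batch, sequence, feature) shape, mirroring the parser arithmetic above:

# Hypothetical 3-D input shape as the converters receive it: (batch, seq_len, features).
input_shapes = [[None, 10, 8]]

in_size = 1
for dim in input_shapes[0][1:]:
    in_size *= dim

n_in = in_size                  # 80: elements per example
seq_len = input_shapes[0][-2]   # 10: number of vectors normalized independently
features = n_in // seq_len      # 8:  elements normalized together (the last axis)
print(n_in, seq_len, features)  # 80 10 8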
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index 3847cda9c..f9324c1ee 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -1058,6 +1058,30 @@ def add_bias(self, bias, quantizer=None, precision=None):
         self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer, precision=precision)
 
 
+class LayerNormalization(Layer):
+    _expected_attributes = [
+        Attribute('n_in'),
+        Attribute('seq_len'),
+        Attribute('epsilon', value_type=float, default=1e-3),
+        WeightAttribute('scale'),
+        WeightAttribute('bias'),
+        TypeAttribute('scale'),
+        TypeAttribute('bias'),
+    ]
+
+    def initialize(self):
+        inp = self.get_input_variable()
+        shape = inp.shape
+        dims = inp.dim_names
+        self.add_output_variable(shape, dims)
+
+        scale = self.get_attr('gamma_data')
+        bias = self.get_attr('beta_data')
+
+        self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
+        self.add_weights_variable(name='bias', var_name='b{index}', data=bias)
+
+
 class Merge(Layer):
     def initialize(self):
         assert len(self.inputs) == 2
@@ -1682,6 +1706,7 @@ def initialize(self):
     'BatchNormOnnx': BatchNormOnnx,
     'LayerGroup': LayerGroup,
     'SymbolicExpression': SymbolicExpression,
+    'LayerNormalization': LayerNormalization,
     # TensorFlow-specific layers:
     'BiasAdd': BiasAdd,
 }
diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py
index 0b5f12c00..8150d0a1f 100644
--- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py
+++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py
@@ -2,7 +2,7 @@
 # Based on https://github.com/fastmachinelearning/qonnx/blob/
 # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py
 
-from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose
+from hls4ml.model.layers import Concatenate, Dense, Input, LayerNormalization, Reshape, Transpose
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.types import WeightVariable
 
@@ -45,6 +45,24 @@ def transform(self, model, node):
                 node.get_output_variable().shape = input_shape
                 dim_names = [f'N_INPUT_{i}_{node.index}' for i in range(1, len(input_shape) + 1)]
                 node.get_output_variable().dim_names = dim_names
+        elif isinstance(node, LayerNormalization):
+            # LayerNorm only works on the last dimension in PyTorch
+            perm = [1, 0]
+            pre_transpose = model.make_node(
+                'Transpose', f'pre_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.get_input_node().name]
+            )
+            pre_transpose.channels_last_converted = True
+            model.insert_node(pre_transpose)
+
+            # If not the output layer, transpose again
+            if not (
+                node.get_attr('name') in model.outputs and model.config.config['HLSConfig']['Model']['TransposeOutputs']
+            ):
+                post_transpose = model.make_node(
+                    'Transpose', f'post_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.name]
+                )
+                post_transpose.channels_last_converted = True
+                model.insert_node(post_transpose)
         else:
             # Transpose weight tensors
             tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight']
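The pre/post transposes added for PyTorch models rely on a simple identity: normalizing over the first axis of a channels-first tensor is the same as transposing so that axis is last, normalizing over the last axis (which is all the HLS kernel supports), and transposing back. A small NumPy sketch of that equivalence (illustrative only, with gamma=1 and beta=0):

import numpy as np


def layernorm(x, axis, eps=1e-3):
    # Plain LayerNorm with unit scale and zero bias, over the given axis.
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


x = np.random.default_rng(0).random((8, 10))   # hypothetical (features, seq_len) tensor

wrapped = layernorm(x.T, axis=-1).T   # pre-transpose, last-axis kernel, post-transpose
direct = layernorm(x, axis=0)         # what the original graph computed
assert np.allclose(wrapped, direct)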
diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py
index bd439e4a0..af97b4ccd 100644
--- a/hls4ml/model/optimizer/passes/infer_precision.py
+++ b/hls4ml/model/optimizer/passes/infer_precision.py
@@ -51,7 +51,7 @@ def _infer_precision(self, node, types_to_infer):
         if node_class in ['Dense']:
             return self._infer_dense_precision(node, types_to_infer)
 
-        if node_class in ['BatchNormalization', 'ApplyAlpha']:
+        if node_class in ['BatchNormalization', 'ApplyAlpha', 'LayerNormalization']:
             return self._infer_bn_precision(node, types_to_infer)
 
         if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']:
diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py
index 84a83de23..a47c1647f 100644
--- a/hls4ml/model/profiling.py
+++ b/hls4ml/model/profiling.py
@@ -273,6 +273,18 @@ def _keras_layer(layer):
     return layer.get_weights(), ['w', 'b']
 
 
+def _keras_layernorm(layer):
+    weights = layer.get_weights()
+
+    gamma = weights[0]
+    beta = weights[1]
+
+    scale = gamma
+    bias = beta
+
+    return [scale, bias], ['s', 'b']
+
+
 def _keras_lstm(layer):
     return layer.get_weights(), ['w', 'u', 'b']
 
@@ -282,6 +294,7 @@
     {
         'BatchNormalization': _keras_batchnorm,
        'QBatchNormalization': _keras_batchnorm,
+        'LayerNormalization': _keras_layernorm,
         'LSTM': _keras_lstm,
         'QLSTM': _keras_lstm,
     },
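The new nnet_layernorm.h below replaces the division by sqrt(variance) with a lookup table of 1/sqrt values covering variances in [epsilon, table_range). A rough NumPy model of that scheme, for intuition only (the function name and the clamping generalization are mine, not taken from the header):

import numpy as np


def layernorm_lut_reference(x, scale, bias, epsilon=1e-3, table_size=4096, table_range=1.0):
    """Approximate LayerNorm over the last axis using a 1/sqrt lookup table."""
    # Table entry ii holds 1/sqrt(epsilon + ii * table_range / table_size), as in init_invert_sqr_table.
    step = table_range / table_size
    table = 1.0 / np.sqrt(epsilon + step * np.arange(table_size))

    mean = x.mean(axis=-1, keepdims=True)
    diff = x - mean
    var = (diff * diff).mean(axis=-1, keepdims=True)

    # Map the variance to a table index and clamp it, as layernorm_1d does.
    index = np.clip((var / table_range * table_size).astype(int), 0, table_size - 1)
    inv_std = table[index]

    return diff * inv_std * scale + bias


x = np.random.default_rng(0).random((10, 8))             # hypothetical (seq_len, features) block
y = layernorm_lut_reference(x, np.ones(8), np.zeros(8))  # close to exact LayerNorm when var < table_range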
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h
new file mode 100644
index 000000000..17b071234
--- /dev/null
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_layernorm.h
@@ -0,0 +1,138 @@
+#ifndef NNET_LAYERNORM_H_
+#define NNET_LAYERNORM_H_
+
+#include "hls_stream.h"
+#include "nnet_common.h"
+#include "nnet_dense.h"
+#include <math.h>
+
+#include "hls_math.h"
+
+namespace nnet {
+
+struct layernorm_config {
+    // Internal data type definitions
+    typedef float bias_t;
+    typedef float scale_t;
+
+    // Layer Sizes
+    static const unsigned n_in = 20;
+    static const unsigned seq_len = 4;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+    static const unsigned n_zeros = 0;
+
+    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
+};
+
+template <typename CONFIG_T, int N_TABLE> void init_invert_sqr_table(typename CONFIG_T::table_t table_out[N_TABLE]) {
+    // Inversion function:
+    //   result = 1/sqrt(x)
+    float min_val = CONFIG_T::epsilon;
+    float max_val = CONFIG_T::table_range;
+    float step = max_val / (float)(N_TABLE);
+    for (int ii = 0; ii < N_TABLE; ii++) {
+        float in_val = min_val + step * ii;
+        table_out[ii] = (typename CONFIG_T::table_t)(1.0 / sqrt(in_val));
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void layernorm_1d(data_T data[CONFIG_T::n_in / CONFIG_T::seq_len], res_T res[CONFIG_T::n_in / CONFIG_T::seq_len],
+                  typename CONFIG_T::scale_t scale[CONFIG_T::n_in / CONFIG_T::seq_len],
+                  typename CONFIG_T::bias_t bias[CONFIG_T::n_in / CONFIG_T::seq_len]) {
+    #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+    #pragma HLS ARRAY_PARTITION variable=data complete
+    #pragma HLS ARRAY_PARTITION variable=res complete
+    int inv_range_inv = (int)1 / CONFIG_T::table_range;
+    typename CONFIG_T::table_t deno_inver = 0;
+#ifdef __HLS_SYN__
+    bool initialized = false;
+    typename CONFIG_T::table_t invert_sqr_table[CONFIG_T::table_size];
+#else
+    static bool initialized = false;
+    static typename CONFIG_T::table_t invert_sqr_table[CONFIG_T::table_size];
+#endif
+    if (!initialized) {
+        init_invert_sqr_table<CONFIG_T, CONFIG_T::table_size>(invert_sqr_table);
+        initialized = true;
+    }
+
+    static const unsigned dim = CONFIG_T::n_in / CONFIG_T::seq_len;
+    typename CONFIG_T::mean_t sum_cache = 0;
+    typename CONFIG_T::mean_t sum_cache2 = 0;
+    typename CONFIG_T::mean_t var, mean, diff;
+    typename CONFIG_T::mean_t data_diff[dim];
+    typename CONFIG_T::mean_t var_epsilon = (typename CONFIG_T::mean_t)CONFIG_T::epsilon;
+
+    #pragma HLS ARRAY_PARTITION variable=data_diff complete
+
+    const typename CONFIG_T::mean_t k_inv = 1.0 / dim;
+
+LAYERNORM_1D_SUM:
+    for (int i = 0; i < dim; ++i) {
+        sum_cache += static_cast<typename CONFIG_T::mean_t>(data[i]);
+    }
+    mean = CONFIG_T::template product<typename CONFIG_T::mean_t, typename CONFIG_T::mean_t>::product(sum_cache, k_inv);
+
+LAYERNORM_1D_VAR:
+    for (int i = 0; i < dim; ++i) {
+        data_diff[i] = static_cast<typename CONFIG_T::mean_t>(data[i]) - mean;
+        diff = data_diff[i] * data_diff[i];
+        sum_cache2 += diff;
+    }
+    var = CONFIG_T::template product<typename CONFIG_T::mean_t, typename CONFIG_T::mean_t>::product(sum_cache2, k_inv);
+
+    int index = (var) * (CONFIG_T::table_size)*inv_range_inv;
+    if (CONFIG_T::table_range > 1)
+        index = (var) * (CONFIG_T::table_size) / (int)CONFIG_T::table_range;
+    if (index < 0)
+        index = 0;
+    if (index > CONFIG_T::table_size - 1)
+        index = CONFIG_T::table_size - 1;
+    deno_inver = invert_sqr_table[index];
+
+LAYERNORM_1D_RESULT:
+    for (int i = 0; i < dim; ++i) {
+        res[i] = data_diff[i] * deno_inver * scale[i] + bias[i];
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void layernormalize(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in],
+                    typename CONFIG_T::scale_t scale[CONFIG_T::n_in / CONFIG_T::seq_len],
+                    typename CONFIG_T::bias_t bias[CONFIG_T::n_in / CONFIG_T::seq_len]) {
+    static const unsigned dim = CONFIG_T::n_in / CONFIG_T::seq_len;
+    data_T in_val[dim];
+    res_T outval[dim];
+    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
+    #pragma HLS function_instantiate variable=scale,bias
+
+    #pragma HLS ARRAY_PARTITION variable=scale complete
+    #pragma HLS ARRAY_PARTITION variable=bias complete
+    #pragma HLS ARRAY_PARTITION variable=in_val complete
+    #pragma HLS ARRAY_PARTITION variable=outval complete
+
+LAYERNORM_SEQ_LOOP:
+    for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+        #pragma HLS PIPELINE
+    LAYERNORM_LOAD:
+        for (int i = 0; i < dim; ++i) {
+            #pragma HLS UNROLL
+            in_val[i] = data[j * dim + i];
+        }
+        layernorm_1d<data_T, res_T, CONFIG_T>(in_val, outval, scale, bias);
+    LAYERNORM_STORE:
+        for (int i = 0; i < dim; ++i) {
+            #pragma HLS UNROLL
+            res[j * dim + i] = outval[i];
+        }
+    }
+}
+
+} // namespace nnet
+
+#endif
diff --git a/test/pytest/test_layernorm.py b/test/pytest/test_layernorm.py
new file mode 100644
index 000000000..bc9290b16
--- /dev/null
+++ b/test/pytest/test_layernorm.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+from tensorflow.keras.layers import LayerNormalization
+from tensorflow.keras.models import Sequential
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = (10, 8)
+atol = 5e-2
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    return np.random.rand(100, *in_shape)
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = Sequential()
+    model.add(LayerNormalization(input_shape=in_shape))
+    model.compile()
+    return model
+
+
+# Currently only Vivado/Vitis in io_parallel mode is supported
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
+def test_layernorm(model, data, backend):
+    config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend)
+    output_dir = str(test_root_path / f'hls4mlprj_layernorm_{backend}_io_parallel')
+    hls_model = hls4ml.converters.convert_from_keras_model(
+        model, backend=backend, hls_config=config, io_type='io_parallel', output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    y_keras = model.predict(data).flatten()
+    y_hls = hls_model.predict(data).flatten()
+    np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
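One reason the fairly loose atol of 5e-2 is attainable with the default table settings: the test data is uniform in [0, 1), so the per-row variance stays around 1/12 (roughly 0.08), well inside the default LUT range [epsilon, 1.0], and the 1/sqrt table is never clamped. A quick check of that claim (my own, not part of the tests):

import numpy as np

data = np.random.default_rng(0).random((100, 10, 8))  # same shape as the test fixtures
print(data.var(axis=-1).max())                        # typically below ~0.2, well under table_range = 1.0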
diff --git a/test/pytest/test_layernorm_pytorch.py b/test/pytest/test_layernorm_pytorch.py
new file mode 100644
index 000000000..ca2c9d68a
--- /dev/null
+++ b/test/pytest/test_layernorm_pytorch.py
@@ -0,0 +1,42 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from torch import nn
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = (10, 8)
+atol = 5e-2
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    return np.random.rand(100, *in_shape)
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = nn.Sequential(nn.LayerNorm(in_shape[-1]))
+    model.eval()
+    return model
+
+
+# Currently only Vivado/Vitis in io_parallel mode is supported
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
+def test_layernorm(model, data, backend):
+    config = hls4ml.utils.config_from_pytorch_model(model, in_shape, granularity='name', backend=backend)
+    output_dir = str(test_root_path / f'hls4mlprj_layernorm_pytorch_{backend}_io_parallel')
+    hls_model = hls4ml.converters.convert_from_pytorch_model(
+        model, backend=backend, hls_config=config, io_type='io_parallel', output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    y_pytorch = model(torch.Tensor(data)).detach().numpy().flatten()
+    y_hls = hls_model.predict(data).flatten()
+    np.testing.assert_allclose(y_pytorch, y_hls, rtol=0, atol=atol, verbose=True)