Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LayerNorm support for Vivado #1110

Open
wants to merge 61 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
c4c818b
paser_mht
Ethan0Jiang Jul 13, 2022
3ee64d1
change parser and modify keras_to_hls
Ethan0Jiang Jul 13, 2022
5626a1a
IR_mutihead_attention
Ethan0Jiang Jul 14, 2022
d51f8a9
IR done
Ethan0Jiang Jul 15, 2022
89025a2
create mha file in template
Ethan0Jiang Jul 19, 2022
d76cf60
mha .h file dummy algo
Ethan0Jiang Jul 19, 2022
56811de
config of mha
Ethan0Jiang Jul 21, 2022
45cd493
update mha config
Ethan0Jiang Jul 21, 2022
1402f48
dummy mha
Ethan0Jiang Jul 21, 2022
430b9ea
add transpose into mha
Ethan0Jiang Jul 23, 2022
97f3e8d
projection_of_qkv_in_mha
Ethan0Jiang Jul 27, 2022
52cc7e8
mha_first_draft
Ethan0Jiang Aug 4, 2022
3961f97
able to predict model correct
Ethan0Jiang Aug 11, 2022
3533999
delete some unnassary comments
Ethan0Jiang Aug 11, 2022
d2f0df6
delete comments
Ethan0Jiang Aug 11, 2022
6aaa5ed
resource strategy of transformer
Ethan0Jiang Sep 16, 2022
3b7a288
change sm lagacy
Ethan0Jiang Oct 1, 2022
130092d
update MHA, optimized
Ethan0Jiang Oct 12, 2022
09b0ba0
support resource
Ethan0Jiang Oct 23, 2022
b49fffd
update
Ethan0Jiang Nov 27, 2022
5324a11
dense_muti_dim_support
Ethan0Jiang Dec 30, 2022
bf8c788
parallel execute dense
Ethan0Jiang Jan 1, 2023
b6be2c4
updates
Ethan0Jiang Jan 27, 2023
2472b7d
add_layerNorm_support
Ethan0Jiang Feb 15, 2023
97e71e9
MHA updated
Ethan0Jiang Feb 27, 2023
5ed4a76
LayerNorm_bug_fix
Ethan0Jiang Apr 4, 2023
5d28f58
update bit precision
Ethan0Jiang Apr 15, 2023
2fc68d0
config update
Ethan0Jiang Apr 17, 2023
b5c95cf
add some comment
Ethan0Jiang Apr 21, 2023
3b8aa8d
run pre-commit
JanFSchulte Sep 13, 2024
d28b24c
Added support on QMultiHeadAttention, QLayerNormalization, and quanti…
LostEcho365 Aug 7, 2023
de79bb9
updated on hls4ml transformer
LostEcho365 Nov 12, 2023
6c23326
trying to clean the diff
JanFSchulte Sep 13, 2024
20a0199
trying to clean the diff
JanFSchulte Sep 13, 2024
ddccde2
trying to clean the diff
Sep 17, 2024
afbe00b
trying to clean the diff
Sep 17, 2024
dedf96c
trying to clean the diff
Sep 17, 2024
a9de9cb
undo vhdl -> verilog change
Sep 18, 2024
49313d3
halfway working layernorm + test
Sep 18, 2024
1156ba5
layernorm is now pretty functional
Sep 18, 2024
17e0048
layernorm on pytorch also
Sep 19, 2024
63891fd
minor cleanup
Sep 19, 2024
8dccac6
more cleanup, pre-commit
Sep 19, 2024
595cc71
test for mha which kinda works maybe if you squint
Sep 19, 2024
5f3ec00
multihead attention working on keras and pytorch
Sep 20, 2024
5697334
fiddly precision / accuracy changes for layernorm
Sep 25, 2024
d2e27b8
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Oct 11, 2024
a149f2e
fix lookup table and label loops
rianbrooksflynn Oct 22, 2024
552fa83
remove dense_seq
rianbrooksflynn Oct 23, 2024
69f26bc
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Oct 23, 2024
be5f5a4
undo qkeras changes
rianbrooksflynn Oct 23, 2024
adf7356
fix merge conflict residue
rianbrooksflynn Oct 24, 2024
8437581
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Nov 4, 2024
39ab36c
remove non-layernorm changes
rianbrooksflynn Nov 4, 2024
b5b82e2
change to uniform LUT and fix precision
rianbrooksflynn Dec 9, 2024
f3ff077
Merge remote-tracking branch 'upstream/main' into layernorm
rianbrooksflynn Dec 9, 2024
0f08e7a
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] Dec 9, 2024
21049e7
Merge remote-tracking branch 'upstream/main' into layernorm
rianbrooksflynn Dec 17, 2024
cbd88bd
fix encodings issue with dos2unix
rianbrooksflynn Jan 6, 2025
0fe0ec3
Merge branch 'main' into layernorm
JanFSchulte Jan 6, 2025
0d96cb0
add Vitis as another tested backend
rianbrooksflynn Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion hls4ml/backends/vivado/passes/core_templates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
from hls4ml.model.layers import (
Activation,
BatchNormalization,
Dense,
HardActivation,
LayerNormalization,
ParametrizedActivation,
PReLU,
Softmax,
)
from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

# Dense templates
Expand Down Expand Up @@ -119,6 +128,59 @@ def format(self, node):
return self.template.format(**params)


# LayerNormalization templates

layernorm_config_template = """struct config{index} : nnet::layernorm_config {{
static const unsigned n_in = {n_in};
static const unsigned seq_len = {seq_len};
static const unsigned table_size = {table_size};
static constexpr double table_range = {table_range};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason this is a double? It is not used as such, and breaks Vivado synthesis.

static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const bool store_weights_in_bram = false;
static constexpr double epsilon = {epsilon};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, this will only ever be used to store values in table_t, so we should ensure epsilon is compatible

typedef {bias_t.name} bias_t;
typedef {scale_t.name} scale_t;
typedef {mean_t.name} mean_t;
typedef {table_t.name} table_t;
template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

layernorm_function_template = 'nnet::layernormalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'

layernorm_include_list = ['nnet_utils/nnet_layernorm.h']


class LayerNormalizationConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(LayerNormalization)
self.template = layernorm_config_template

def format(self, node):
params = self._default_config_params(node)
params['n_in'] = node.get_input_variable().size_cpp()
params['seq_len'] = node.get_attr('seq_len')
params['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights('scale').type.precision
)

return self.template.format(**params)


class LayerNormalizationFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(LayerNormalization, include_header=layernorm_include_list)
self.template = layernorm_function_template

def format(self, node):
params = self._default_function_params(node)
params['scale'] = node.get_weights('scale').name
params['bias'] = node.get_weights('bias').name

return self.template.format(**params)


# Activation templates

activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
Expand Down
16 changes: 16 additions & 0 deletions hls4ml/backends/vivado/vivado_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
GarNet,
GarNetStack,
Layer,
LayerNormalization,
Pooling1D,
Pooling2D,
SeparableConv1D,
Expand Down Expand Up @@ -558,6 +559,21 @@ def init_softmax(self, layer):
len(layer.get_input_variable().shape) == 1
), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.'

@layer_optimizer(LayerNormalization)
def init_layernormalization(self, layer):
if 'table_t' not in layer.attributes:
layer.set_attr(
'table_t', NamedType(name=layer.name + '_table_t', precision=FixedPrecisionType(width=16, integer=6))
)
if 'table_size' not in layer.attributes:
layer.set_attr('table_size', 4096) # table size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These attributes should be set as default in _register_layer_attributes() not here.

Also, is 4096 necessary for this implementation to work? All other tables are 1024.

if 'table_range' not in layer.attributes:
layer.set_attr('table_range', 1.0) # table range
if 'mean_t' not in layer.attributes:
layer.set_attr(
'mean_t', NamedType(name=layer.name + '_mean_t', precision=FixedPrecisionType(width=19, integer=6))
)

@layer_optimizer(Embedding)
def init_embed(self, layer):
if layer.attributes['n_in'] is None:
Expand Down
28 changes: 28 additions & 0 deletions hls4ml/converters/keras/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,34 @@ def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
return layer, [shape for shape in input_shapes[0]]


@keras_handler('LayerNormalization')
def parse_layernorm_layer(keras_layer, input_names, input_shapes, data_reader):
assert 'LayerNormalization' in keras_layer['class_name']

layer = parse_default_keras_layer(keras_layer, input_names)

in_size = 1
for dim in input_shapes[0][1:]:
in_size *= dim
layer['n_in'] = layer['n_out'] = in_size

if not ((len(input_shapes[0])) == 3):
raise Exception('input size is not currently supported by hls4ml, only dim3 is supported')
layer['seq_len'] = input_shapes[0][-2]

if not (keras_layer['config']['axis'][0] == 2):
raise Exception('assigning the axis is not currently supported by hls4ml, only axis 2 is supported')

layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'gamma')
layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'beta')

layer['epsilon'] = keras_layer['config']['epsilon']
if layer['epsilon'] <= 0:
raise Exception('epsilon must be positive')

return layer, [shape for shape in input_shapes[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should also parse axis parameter to avoid misparsing. Sure, we only support axis=-1 and we can raise exceptions about it, but should be handled.



@keras_handler('Embedding')
def parse_embedding_layer(keras_layer, input_names, input_shapes, data_reader):
assert 'Embedding' in keras_layer['class_name']
Expand Down
29 changes: 29 additions & 0 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,32 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
layer['n_filt'] = input_shapes[0][1] # Always channel first for Pytorch

return layer, [shape for shape in input_shapes[0]]


@pytorch_handler('LayerNorm')
def parse_layernorm_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'LayerNorm' in operation

layer = {}

layer['class_name'] = 'LayerNormalization'
layer['name'] = layer_name
layer['inputs'] = input_names

in_size = 1
for dim in input_shapes[0][1:]:
in_size *= dim
layer['n_in'] = layer['n_out'] = in_size

if not ((len(input_shapes[0])) == 3):
raise Exception('input size is not currently supported by hls4ml, only dim3 is supported')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is dim3? This sentence could be made a bit nicer :-)

layer['seq_len'] = input_shapes[0][-2]

layer['gamma_data'] = class_object.weight.data.numpy()
layer['beta_data'] = class_object.bias.data.numpy()

layer['epsilon'] = class_object.eps
if layer['epsilon'] <= 0:
raise Exception('epsilon must be positive')

return layer, [shape for shape in input_shapes[0]]
25 changes: 25 additions & 0 deletions hls4ml/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,30 @@ def add_bias(self, bias, quantizer=None, precision=None):
self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer, precision=precision)


class LayerNormalization(Layer):
_expected_attributes = [
Attribute('n_in'),
Attribute('seq_len'),
Attribute('epsilon', value_type=float, default=1e-3),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis too

WeightAttribute('scale'),
WeightAttribute('bias'),
TypeAttribute('scale'),
TypeAttribute('bias'),
]

def initialize(self):
inp = self.get_input_variable()
shape = inp.shape
dims = inp.dim_names
self.add_output_variable(shape, dims)

scale = self.get_attr('gamma_data')
bias = self.get_attr('beta_data')

self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
self.add_weights_variable(name='bias', var_name='b{index}', data=bias)


class Merge(Layer):
def initialize(self):
assert len(self.inputs) == 2
Expand Down Expand Up @@ -1682,6 +1706,7 @@ def initialize(self):
'BatchNormOnnx': BatchNormOnnx,
'LayerGroup': LayerGroup,
'SymbolicExpression': SymbolicExpression,
'LayerNormalization': LayerNormalization,
# TensorFlow-specific layers:
'BiasAdd': BiasAdd,
}
Expand Down
20 changes: 19 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Based on https://github.com/fastmachinelearning/qonnx/blob/
# 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py

from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose
from hls4ml.model.layers import Concatenate, Dense, Input, LayerNormalization, Reshape, Transpose
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.types import WeightVariable

Expand Down Expand Up @@ -45,6 +45,24 @@ def transform(self, model, node):
node.get_output_variable().shape = input_shape
dim_names = [f'N_INPUT_{i}_{node.index}' for i in range(1, len(input_shape) + 1)]
node.get_output_variable().dim_names = dim_names
elif isinstance(node, LayerNormalization):
# LayerNorm only works on the last dimension in PyTorch
perm = [1, 0]
pre_transpose = model.make_node(
'Transpose', f'pre_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.get_input_node().name]
)
pre_transpose.channels_last_converted = True
model.insert_node(pre_transpose)

# If not the output layer, transpose again
if not (
node.get_attr('name') in model.outputs and model.config.config['HLSConfig']['Model']['TransposeOutputs']
):
post_transpose = model.make_node(
'Transpose', f'post_transpose_for_{node.get_attr("name")}', {'perm': perm}, [node.name]
)
post_transpose.channels_last_converted = True
model.insert_node(post_transpose)
else:
# Transpose weight tensors
tensors = ['weight', 'depthwise', 'pointwise', 'zero_bias', 'scale', 'recurrent_weight']
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/model/optimizer/passes/infer_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _infer_precision(self, node, types_to_infer):
if node_class in ['Dense']:
return self._infer_dense_precision(node, types_to_infer)

if node_class in ['BatchNormalization', 'ApplyAlpha']:
if node_class in ['BatchNormalization', 'ApplyAlpha', 'LayerNormalization']:
return self._infer_bn_precision(node, types_to_infer)

if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']:
Expand Down
13 changes: 13 additions & 0 deletions hls4ml/model/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,18 @@ def _keras_layer(layer):
return layer.get_weights(), ['w', 'b']


def _keras_layernorm(layer):
weights = layer.get_weights()

gamma = weights[0]
beta = weights[1]

scale = gamma
bias = beta

return [scale, bias], ['s', 'b']


def _keras_lstm(layer):
return layer.get_weights(), ['w', 'u', 'b']

Expand All @@ -282,6 +294,7 @@ def _keras_lstm(layer):
{
'BatchNormalization': _keras_batchnorm,
'QBatchNormalization': _keras_batchnorm,
'LayerNormalization': _keras_layernorm,
'LSTM': _keras_lstm,
'QLSTM': _keras_lstm,
},
Expand Down
Loading
Loading