From 2531afedf12159fb8cce42288e292b9fd778755c Mon Sep 17 00:00:00 2001 From: dgburnette <36940078+dgburnette@users.noreply.github.com> Date: Mon, 15 Apr 2024 07:12:17 -0700 Subject: [PATCH] Stage initial set of changes for the Catapult backend (#956) * Stage initial set of changes for the Catapult backend * applied some changes for issues reported by pre-commit. But pre-commit still reorders backends/__init__.py incorrectly * final changes for clean pre-commit * minor edits * Checkin * Add file * pre-commit format * add in nnet_utils files * format changes for pre-commit * run flows by netlist type * update design pragmas on some blocks. cleaned up TCL script * move AC submodules under hls4ml/templates/catapult * merged in latest changes from mainline * remove bad submodules * recreate AC submodules in hls4ml/templates/catapult * pre-commit fixes * pre-commit fixes * turn on Catapult backend testing * removed io_parallel testing for Catapult backend * add Catapult * added Catapult * added Catapult * added Catapult to some pytests * Added concept of ProjectDir to distinguish the project directory of the HLS tool from the ProjectName which is used for the cpp file and top function name * better handling of c++ testbench data files. enhanced directory naming. * fix syntax * workaround from Giuseppe * Add concept of ProjectDir for Catapult which is different from ProjectName that gets used for the top function name and the cpp files * add new file from Giuseppe * improvements to project management, reporting and testbench * include new file in generation of parameters.h * add hard_tanh for io_parallel. formatting * Full path to the header nnet_helpers.h is necessary in the include (check if this is not an issue with recent versions of Catapult) * Avoid ceiling function from the math library: ceil(n/d) ---> (n+d-1)/n * These are mostly workarounds for the BUP synyhesis of a testing model (should these changes make in something more general?) * revert format back to what clang-format currently enforces * simplification from Giuesspe * Fixes for bottom-up handling of libraries * pre-commit format fixes * fix loops * consolidate prj scripts * cleanup pragmas * switch from using ssh to https for submodules * fix include path for non-catapult install * update pytest environment * CL 1100381 * CL 1098112 * roll in latest changes. turn off Catapult variants of test_binary_cnn and test_cnn_mnist_qkeras for now * fix test failure * disable Catapult test for pytorch until it is supported * disable Catapult for pytorch tests * Simpler submodule initialization for CI --------- Co-authored-by: David Burnette Co-authored-by: Giuseppe Di Guglielmo Co-authored-by: Jovan Mitrevski Co-authored-by: Vladimir Loncar --- .gitmodules | 9 + hls4ml/backends/__init__.py | 3 + hls4ml/backends/catapult/__init__.py | 0 hls4ml/backends/catapult/catapult_backend.py | 515 ++++++++ hls4ml/backends/catapult/passes/__init__.py | 0 .../catapult/passes/broadcast_stream.py | 117 ++ .../backends/catapult/passes/conv_same_pad.py | 109 ++ .../backends/catapult/passes/conv_stream.py | 52 + .../catapult/passes/convolution_templates.py | 508 ++++++++ .../catapult/passes/convolution_winograd.py | 175 +++ .../catapult/passes/core_templates.py | 216 ++++ .../passes/fifo_depth_optimization.py | 104 ++ .../catapult/passes/garnet_templates.py | 249 ++++ .../catapult/passes/merge_templates.py | 106 ++ hls4ml/backends/catapult/passes/pointwise.py | 92 ++ .../catapult/passes/pooling_templates.py | 109 ++ .../catapult/passes/quantization_templates.py | 36 + .../catapult/passes/recurrent_templates.py | 175 +++ .../catapult/passes/reshaping_templates.py | 132 ++ .../catapult/passes/resource_strategy.py | 48 + .../catapult/passes/transform_types.py | 52 + hls4ml/backends/fpga/fpga_types.py | 65 + hls4ml/converters/__init__.py | 6 +- hls4ml/model/graph.py | 6 + hls4ml/model/layers.py | 1 + hls4ml/model/profiling.py | 1 + hls4ml/report/__init__.py | 3 + hls4ml/report/catapult_report.py | 256 ++++ hls4ml/templates/catapult/ac_math | 1 + hls4ml/templates/catapult/ac_simutils | 1 + hls4ml/templates/catapult/ac_types | 1 + hls4ml/templates/catapult/build_lib.sh | 21 + hls4ml/templates/catapult/build_prj.tcl | 356 ++++++ hls4ml/templates/catapult/catapult_synth.tcl | 3 + hls4ml/templates/catapult/firmware/defines.h | 15 + .../templates/catapult/firmware/myproject.cpp | 29 + .../templates/catapult/firmware/myproject.h | 15 + .../templates/catapult/firmware/parameters.h | 15 + .../templates/catapult/myproject_bridge.cpp | 72 ++ hls4ml/templates/catapult/myproject_test.cpp | 164 +++ .../catapult/nnet_utils/ap_shift_reg.h | 136 ++ .../templates/catapult/nnet_utils/hls_math.h | 24 + .../catapult/nnet_utils/nnet_activation.h | 1107 +++++++++++++++++ .../nnet_utils/nnet_activation_stream.h | 922 ++++++++++++++ .../catapult/nnet_utils/nnet_array.h | 52 + .../catapult/nnet_utils/nnet_batchnorm.h | 127 ++ .../nnet_utils/nnet_batchnorm_stream.h | 113 ++ .../catapult/nnet_utils/nnet_code_gen.h | 32 + .../catapult/nnet_utils/nnet_common.h | 66 + .../catapult/nnet_utils/nnet_conv1d.h | 62 + .../catapult/nnet_utils/nnet_conv1d_latency.h | 198 +++ .../nnet_utils/nnet_conv1d_resource.h | 241 ++++ .../catapult/nnet_utils/nnet_conv1d_stream.h | 94 ++ .../catapult/nnet_utils/nnet_conv2d.h | 84 ++ .../catapult/nnet_utils/nnet_conv2d_latency.h | 392 ++++++ .../nnet_utils/nnet_conv2d_resource.h | 275 ++++ .../catapult/nnet_utils/nnet_conv2d_stream.h | 117 ++ .../catapult/nnet_utils/nnet_conv_stream.h | 398 ++++++ .../catapult/nnet_utils/nnet_dense.h | 49 + .../nnet_utils/nnet_dense_compressed.h | 106 ++ .../catapult/nnet_utils/nnet_dense_latency.h | 92 ++ .../catapult/nnet_utils/nnet_dense_resource.h | 262 ++++ .../catapult/nnet_utils/nnet_dense_stream.h | 72 ++ .../catapult/nnet_utils/nnet_embed.h | 47 + .../catapult/nnet_utils/nnet_embed_stream.h | 34 + .../catapult/nnet_utils/nnet_garnet.h | 816 ++++++++++++ .../catapult/nnet_utils/nnet_helpers.h | 461 +++++++ .../catapult/nnet_utils/nnet_image.h | 41 + .../catapult/nnet_utils/nnet_image_stream.h | 66 + .../templates/catapult/nnet_utils/nnet_math.h | 178 +++ .../catapult/nnet_utils/nnet_merge.h | 232 ++++ .../catapult/nnet_utils/nnet_merge_stream.h | 380 ++++++ .../templates/catapult/nnet_utils/nnet_mult.h | 127 ++ .../catapult/nnet_utils/nnet_padding.h | 145 +++ .../catapult/nnet_utils/nnet_padding_stream.h | 95 ++ .../catapult/nnet_utils/nnet_pooling.h | 362 ++++++ .../catapult/nnet_utils/nnet_pooling_stream.h | 601 +++++++++ .../nnet_utils/nnet_recr_activations.h | 56 + .../catapult/nnet_utils/nnet_recurrent.h | 572 +++++++++ .../nnet_utils/nnet_sepconv1d_stream.h | 127 ++ .../catapult/nnet_utils/nnet_sepconv2d.h | 82 ++ .../nnet_utils/nnet_sepconv2d_stream.h | 152 +++ .../catapult/nnet_utils/nnet_sepconv_stream.h | 315 +++++ .../catapult/nnet_utils/nnet_stream.h | 156 +++ .../catapult/nnet_utils/nnet_types.h | 64 + .../templates/vivado_accelerator/build_lib.sh | 0 hls4ml/writer/__init__.py | 2 + hls4ml/writer/catapult_writer.py | 929 ++++++++++++++ test/pytest/ci-template.yml | 3 +- test/pytest/test_activations.py | 2 +- test/pytest/test_batchnorm.py | 2 +- test/pytest/test_batchnorm_pytorch.py | 2 +- test/pytest/test_clone_flatten.py | 2 +- test/pytest/test_cnn_mnist.py | 2 +- test/pytest/test_conv1d.py | 4 + test/pytest/test_embed.py | 4 +- test/pytest/test_globalpooling.py | 4 +- test/pytest/test_keras_h5_loader.py | 2 +- test/pytest/test_keras_nested_model.py | 4 +- test/pytest/test_pointwiseconv.py | 4 + test/pytest/test_pooling.py | 4 +- test/pytest/test_repack_stream.py | 4 +- test/pytest/test_reshape.py | 2 +- test/pytest/test_sepconv1d.py | 2 +- test/pytest/test_sepconv2d.py | 2 +- test/pytest/test_softmax.py | 4 +- test/pytest/test_softsign.py | 2 +- test/pytest/test_upsampling.py | 2 +- test/pytest/test_zeropadding.py | 2 +- 109 files changed, 14930 insertions(+), 28 deletions(-) create mode 100644 hls4ml/backends/catapult/__init__.py create mode 100644 hls4ml/backends/catapult/catapult_backend.py create mode 100644 hls4ml/backends/catapult/passes/__init__.py create mode 100644 hls4ml/backends/catapult/passes/broadcast_stream.py create mode 100755 hls4ml/backends/catapult/passes/conv_same_pad.py create mode 100755 hls4ml/backends/catapult/passes/conv_stream.py create mode 100755 hls4ml/backends/catapult/passes/convolution_templates.py create mode 100644 hls4ml/backends/catapult/passes/convolution_winograd.py create mode 100755 hls4ml/backends/catapult/passes/core_templates.py create mode 100755 hls4ml/backends/catapult/passes/fifo_depth_optimization.py create mode 100755 hls4ml/backends/catapult/passes/garnet_templates.py create mode 100755 hls4ml/backends/catapult/passes/merge_templates.py create mode 100755 hls4ml/backends/catapult/passes/pointwise.py create mode 100755 hls4ml/backends/catapult/passes/pooling_templates.py create mode 100755 hls4ml/backends/catapult/passes/quantization_templates.py create mode 100755 hls4ml/backends/catapult/passes/recurrent_templates.py create mode 100755 hls4ml/backends/catapult/passes/reshaping_templates.py create mode 100755 hls4ml/backends/catapult/passes/resource_strategy.py create mode 100755 hls4ml/backends/catapult/passes/transform_types.py create mode 100755 hls4ml/report/catapult_report.py create mode 160000 hls4ml/templates/catapult/ac_math create mode 160000 hls4ml/templates/catapult/ac_simutils create mode 160000 hls4ml/templates/catapult/ac_types create mode 100755 hls4ml/templates/catapult/build_lib.sh create mode 100755 hls4ml/templates/catapult/build_prj.tcl create mode 100644 hls4ml/templates/catapult/catapult_synth.tcl create mode 100755 hls4ml/templates/catapult/firmware/defines.h create mode 100755 hls4ml/templates/catapult/firmware/myproject.cpp create mode 100755 hls4ml/templates/catapult/firmware/myproject.h create mode 100755 hls4ml/templates/catapult/firmware/parameters.h create mode 100755 hls4ml/templates/catapult/myproject_bridge.cpp create mode 100755 hls4ml/templates/catapult/myproject_test.cpp create mode 100644 hls4ml/templates/catapult/nnet_utils/ap_shift_reg.h create mode 100755 hls4ml/templates/catapult/nnet_utils/hls_math.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_activation.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_array.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_batchnorm.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_code_gen.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_common.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_conv1d.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_conv1d_latency.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv1d_resource.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv1d_stream.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_conv2d.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv2d_latency.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv2d_resource.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv2d_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_conv_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_dense.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_dense_compressed.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_dense_latency.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_dense_resource.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_dense_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_embed.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_embed_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_garnet.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_helpers.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_image.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_image_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_math.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_merge.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_merge_stream.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_mult.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_padding.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_padding_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_pooling.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_pooling_stream.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_recr_activations.h create mode 100755 hls4ml/templates/catapult/nnet_utils/nnet_recurrent.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_sepconv1d_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_sepconv_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_stream.h create mode 100644 hls4ml/templates/catapult/nnet_utils/nnet_types.h mode change 100644 => 100755 hls4ml/templates/vivado_accelerator/build_lib.sh create mode 100755 hls4ml/writer/catapult_writer.py diff --git a/.gitmodules b/.gitmodules index 3513213a23..98c3df68fd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "example-models"] path = example-models url = https://github.com/hls-fpga-machine-learning/example-models.git +[submodule "hls4ml/templates/catapult/ac_types"] + path = hls4ml/templates/catapult/ac_types + url = https://github.com/hlslibs/ac_types.git +[submodule "hls4ml/templates/catapult/ac_simutils"] + path = hls4ml/templates/catapult/ac_simutils + url = https://github.com/hlslibs/ac_simutils.git +[submodule "hls4ml/templates/catapult/ac_math"] + path = hls4ml/templates/catapult/ac_math + url = https://github.com/hlslibs/ac_math.git diff --git a/hls4ml/backends/__init__.py b/hls4ml/backends/__init__.py index 6396d7815f..8b3117af7a 100644 --- a/hls4ml/backends/__init__.py +++ b/hls4ml/backends/__init__.py @@ -6,10 +6,13 @@ from hls4ml.backends.vivado_accelerator.vivado_accelerator_backend import VivadoAcceleratorBackend from hls4ml.backends.vivado_accelerator.vivado_accelerator_config import VivadoAcceleratorConfig # noqa: F401 +from hls4ml.backends.catapult.catapult_backend import CatapultBackend # isort: skip + from hls4ml.backends.vitis.vitis_backend import VitisBackend # isort: skip register_backend('Vivado', VivadoBackend) register_backend('VivadoAccelerator', VivadoAcceleratorBackend) register_backend('Vitis', VitisBackend) register_backend('Quartus', QuartusBackend) +register_backend('Catapult', CatapultBackend) register_backend('SymbolicExpression', SymbolicExpressionBackend) diff --git a/hls4ml/backends/catapult/__init__.py b/hls4ml/backends/catapult/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/catapult/catapult_backend.py b/hls4ml/backends/catapult/catapult_backend.py new file mode 100644 index 0000000000..5556154dcb --- /dev/null +++ b/hls4ml/backends/catapult/catapult_backend.py @@ -0,0 +1,515 @@ +import os +import sys + +import numpy as np + +from hls4ml.backends import FPGABackend +from hls4ml.backends.fpga.fpga_types import ACTypeConverter, CatapultArrayVariableConverter, HLSTypeConverter +from hls4ml.model.attributes import ChoiceAttribute, ConfigurableAttribute, TypeAttribute +from hls4ml.model.flow import register_flow +from hls4ml.model.layers import ( + GRU, + LSTM, + Conv1D, + Conv2D, + Dense, + DepthwiseConv2D, + Embedding, + GarNet, + GarNetStack, + GlobalPooling1D, + GlobalPooling2D, + Layer, + Pooling1D, + Pooling2D, + SeparableConv1D, + SeparableConv2D, + SimpleRNN, + Softmax, +) +from hls4ml.model.optimizer import get_backend_passes, layer_optimizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType +from hls4ml.report import parse_catapult_report +from hls4ml.utils.fixed_point_utils import ceil_log2 + + +class CatapultBackend(FPGABackend): + def __init__(self): + super().__init__('Catapult') + self._register_layer_attributes() + self._register_flows() + + def _register_layer_attributes(self): + # Add RNN-specific attributes, recurrent_reuse_factor and static implementation + rnn_layers = [ + SimpleRNN, + LSTM, + GRU, + ] + + for layer in rnn_layers: + attrs = self.attribute_map.get(layer, []) + attrs.append(ConfigurableAttribute('recurrent_reuse_factor', default=1)) + attrs.append(ConfigurableAttribute('static', value_type=bool, default=True)) + attrs.append(ConfigurableAttribute('table_size', default=1024)) + attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8))) + self.attribute_map[layer] = attrs + + # Add ParallelizationFactor to Conv1D/2D + pf_layers = [ + Conv1D, + Conv2D, + ] + + for layer in pf_layers: + attrs = self.attribute_map.get(layer, []) + attrs.append(ConfigurableAttribute('parallelization_factor', default=1)) + self.attribute_map[layer] = attrs + + # Add ConvImplementation to Convolution+Pooling layers + cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D] + + for layer in cnn_layers: + attrs = self.attribute_map.get(layer, []) + # attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer')) + attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer')) + self.attribute_map[layer] = attrs + + sep_conv_layers = [SeparableConv1D, SeparableConv2D] + for layer in sep_conv_layers: + attrs = self.attribute_map.get(layer, []) + attrs.append(TypeAttribute('dw_output', default=FixedPrecisionType(18, 8))) + self.attribute_map[layer] = attrs + + def _register_flows(self): + initializers = self._get_layer_initializers() + init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) + + streaming_passes = [ + 'catapult:reshape_stream', + 'catapult:clone_output', + 'catapult:insert_zero_padding_before_conv1d', + 'catapult:insert_zero_padding_before_conv2d', + 'catapult:broadcast_stream', + ] + streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name) + + quantization_passes = [ + 'catapult:merge_batch_norm_quantized_tanh', + 'catapult:quantize_dense_output', + 'fuse_consecutive_batch_normalization', + 'catapult:xnor_pooling', + ] + quantization_flow = register_flow('quantization', quantization_passes, requires=[init_flow], backend=self.name) + + optimization_passes = [ + 'catapult:remove_final_reshape', + 'catapult:optimize_pointwise_conv', + 'catapult:inplace_parallel_reshape', + 'catapult:inplace_stream_flatten', + 'catapult:skip_softmax', + 'catapult:fix_softmax_table_size', + ] + optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name) + + catapult_types = [ + 'catapult:transform_types', + 'catapult:register_bram_weights', + 'catapult:generate_conv_streaming_instructions', + 'catapult:apply_resource_strategy', + 'catapult:generate_conv_im2col', + ] + catapult_types_flow = register_flow('specific_types', catapult_types, requires=[init_flow], backend=self.name) + + templates = self._get_layer_templates() + template_flow = register_flow('apply_templates', self._get_layer_templates, requires=[init_flow], backend=self.name) + + writer_passes = ['make_stamp', 'catapult:write_hls'] + self._writer_flow = register_flow('write', writer_passes, requires=['catapult:ip'], backend=self.name) + + fifo_depth_opt_passes = [ + 'catapult:fifo_depth_optimization' + ] + writer_passes # After optimization, a new project will be written + + register_flow('fifo_depth_optimization', fifo_depth_opt_passes, requires=[self._writer_flow], backend=self.name) + + all_passes = get_backend_passes(self.name) + + extras = [ + # Ideally this should be empty + opt_pass + for opt_pass in all_passes + if opt_pass + not in initializers + + streaming_passes + + quantization_passes + + optimization_passes + + catapult_types + + templates + + writer_passes + + fifo_depth_opt_passes + ] + + if len(extras) > 0: + extras_flow = register_flow('extras', extras, requires=[init_flow], backend=self.name) + else: + extras_flow = None + + ip_flow_requirements = [ + 'optimize', + init_flow, + streaming_flow, + quantization_flow, + optimization_flow, + catapult_types_flow, + extras_flow, + template_flow, + ] + ip_flow_requirements = list(filter(None, ip_flow_requirements)) + + self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name) + + def get_default_flow(self): + return self._default_flow + + def get_writer_flow(self): + return self._writer_flow + + def create_initial_config( + self, + tech='fpga', + part='xcku115-flvb2104-2-i', + asiclibs='nangate-45nm', + fifo=None, + clock_period=5, + io_type='io_parallel', + ): + config = {} + + config['Technology'] = tech + if tech == 'fpga': + config['Part'] = part if part is not None else 'xcvu13p-flga2577-2-e' + else: + config['ASICLibs'] = asiclibs if asiclibs is not None else 'nangate-45nm' + config['ClockPeriod'] = clock_period + config['FIFO'] = fifo + config['IOType'] = io_type + config['HLSConfig'] = {} + + return config + + def build( + self, + model, + reset=False, + csim=True, + synth=True, + cosim=False, + validation=False, + vhdl=False, + verilog=True, + export=False, + vsynth=False, + fifo_opt=False, + bitfile=False, + ran_frame=5, + sw_opt=False, + power=False, + da=False, + bup=False, + ): + # print(f'ran_frame value: {ran_frame}') # Add this line for debugging + catapult_exe = 'catapult' + if 'linux' in sys.platform: + cmd = 'command -v ' + catapult_exe + ' > /dev/null' + found = os.system(cmd) + if found != 0: + catapult_exe = os.getenv('MGC_HOME') + '/bin/catapult' + cmd = 'command -v ' + catapult_exe + ' > /dev/null' + found = os.system(cmd) + if found != 0: + catapult_exe = os.getenv('CATAPULT_HOME') + '/bin/catapult' + cmd = 'command -v ' + catapult_exe + ' > /dev/null' + if found != 0: + raise Exception('Catapult HLS installation not found. Make sure "catapult" is on PATH.') + + curr_dir = os.getcwd() + # this execution moves into the hls4ml-generated "output_dir" and runs the build_prj.tcl script. + os.chdir(model.config.get_output_dir()) + ccs_args = f'"reset={reset} csim={csim} synth={synth} cosim={cosim} validation={validation}' + ccs_args += f' export={export} vsynth={vsynth} fifo_opt={fifo_opt} bitfile={bitfile} ran_frame={ran_frame}' + ccs_args += f' sw_opt={sw_opt} power={power} da={da} vhdl={vhdl} verilog={verilog} bup={bup}"' + ccs_invoke = catapult_exe + ' -product ultra -shell -f build_prj.tcl -eval \'set ::argv ' + ccs_args + '\'' + print(ccs_invoke) + os.system(ccs_invoke) + os.chdir(curr_dir) + + return parse_catapult_report(model.config.get_output_dir()) + + def _validate_conv_strategy(self, layer): + if layer.model.config.pipeline_style.lower() != 'dataflow': + print(f'WARNING: Layer {layer.name} requires "dataflow" pipeline style. Switching to "dataflow" pipeline style.') + layer.model.config.pipeline_style = 'dataflow' + + @layer_optimizer(Layer) + def init_base_layer(self, layer): + reuse_factor = layer.model.config.get_reuse_factor(layer) + layer.set_attr('reuse_factor', reuse_factor) + + target_cycles = layer.model.config.get_target_cycles(layer) + layer.set_attr('target_cycles', target_cycles) + + @layer_optimizer(Dense) + def init_dense(self, layer): + index_t = IntegerPrecisionType(width=1, signed=False) + compression = layer.model.config.get_compression(layer) + if layer.model.config.is_resource_strategy(layer): + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + if compression: + layer.set_attr('strategy', 'compressed') + index_t = layer.get_weights('weight').type.index_precision + else: + layer.set_attr('strategy', 'resource') + else: + layer.set_attr('strategy', 'latency') + layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', index_t)) + + # TODO consolidate these functions into a single `init_conv` + @layer_optimizer(Conv1D) + def init_conv1d(self, layer): + if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv1D + layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0, 1)) + + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + out_width = layer.get_output_variable().shape[0] + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(1, out_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + valid_pf_str = ','.join(map(str, valid_pf)) + print( + f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".' + f'Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.' + ) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', out_width // closest_pf) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + + @layer_optimizer(SeparableConv1D) + def init_sepconv1d(self, layer): + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr( + 'n_partitions', 1 + ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + # Set the output type of the depthwise phase + dw_out_precision, _ = layer.model.config.get_precision(layer, 'dw_output') + dw_out_name = layer.name + '_dw_out_t' + if layer.model.config.get_config_value('IOType') == 'io_stream': + dw_output_t = PackedType(dw_out_name, dw_out_precision, layer.get_attr('n_chan'), n_pack=1) + else: + dw_output_t = NamedType(dw_out_name, dw_out_precision) + layer.set_attr('dw_output_t', dw_output_t) + + @layer_optimizer(Conv2D) + def init_conv2d(self, layer): + if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D + layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0, 1)) + + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + self.set_target_reuse_factor(layer) + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + out_height = layer.get_output_variable().shape[0] + out_width = layer.get_output_variable().shape[1] + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(out_height, out_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + valid_pf_str = ','.join(map(str, valid_pf)) + print( + f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".' + f'Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.' + ) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', out_height * out_width // closest_pf) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + + @layer_optimizer(SeparableConv2D) + def init_sepconv2d(self, layer): + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr( + 'n_partitions', 1 + ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + # Set the output type of the depthwise phase + dw_out_precision, _ = layer.model.config.get_precision(layer, 'dw_output') + dw_out_name = layer.name + '_dw_out_t' + if layer.model.config.get_config_value('IOType') == 'io_stream': + dw_output_t = PackedType(dw_out_name, dw_out_precision, layer.get_attr('n_chan'), n_pack=1) + else: + dw_output_t = NamedType(dw_out_name, dw_out_precision) + layer.set_attr('dw_output_t', dw_output_t) + + @layer_optimizer(DepthwiseConv2D) + def init_depconv2d(self, layer): + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr( + 'n_partitions', 1 + ) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + # Set the output type of the depthwise phase + dw_out_precision, _ = layer.model.config.get_precision(layer, 'dw_output') + dw_out_name = layer.name + '_dw_out_t' + if layer.model.config.get_config_value('IOType') == 'io_stream': + dw_output_t = PackedType(dw_out_name, dw_out_precision, layer.get_attr('n_chan'), n_pack=1) + else: + dw_output_t = NamedType(dw_out_name, dw_out_precision) + layer.set_attr('dw_output_t', dw_output_t) + + def _set_pooling_accum_t(self, layer, pool_size): + extra_bits = ceil_log2(pool_size) + accum_t = layer.get_attr('accum_t') + accum_t.precision.width += extra_bits * 2 + if isinstance(accum_t.precision, FixedPrecisionType): + accum_t.precision.integer += extra_bits + + @layer_optimizer(Pooling1D) + def init_pooling1d(self, layer): + pool_size = layer.get_attr('pool_width') + self._set_pooling_accum_t(layer, pool_size) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + @layer_optimizer(Pooling2D) + def init_pooling2d(self, layer): + pool_size = layer.get_attr('pool_height') * layer.get_attr('pool_width') + self._set_pooling_accum_t(layer, pool_size) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + @layer_optimizer(GlobalPooling1D) + def init_global_pooling1d(self, layer): + pool_size = layer.get_attr('n_in') + self._set_pooling_accum_t(layer, pool_size) + + @layer_optimizer(GlobalPooling2D) + def init_global_pooling2d(self, layer): + pool_size = layer.get_attr('in_height') * layer.get_attr('in_width') + self._set_pooling_accum_t(layer, pool_size) + + @layer_optimizer(Softmax) + def init_softmax(self, layer): + if layer.model.config.get_config_value('IOType') == 'io_parallel': + assert ( + len(layer.get_input_variable().shape) == 1 + ), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.' + + @layer_optimizer(Embedding) + def init_embed(self, layer): + if layer.attributes['n_in'] is None: + raise Exception('Input length of Embedding layer must be specified.') + + @layer_optimizer(LSTM) + def init_lstm(self, layer): + # TODO Allow getting recurrent reuse factor from the config + reuse_factor = layer.model.config.get_reuse_factor(layer) + layer.set_attr('recurrent_reuse_factor', reuse_factor) + + if layer.model.config.is_resource_strategy(layer): + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') + layer.set_attr('strategy', 'resource') + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) + + @layer_optimizer(GRU) + def init_gru(self, layer): + reuse_factor = layer.model.config.get_reuse_factor(layer) + layer.set_attr('recurrent_reuse_factor', reuse_factor) + + if layer.model.config.is_resource_strategy(layer): + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute='recurrent_reuse_factor') + layer.set_attr('strategy', 'resource') + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) + + @layer_optimizer(GarNet) + def init_garnet(self, layer): + reuse_factor = layer.attributes['reuse_factor'] + + var_converter = CatapultArrayVariableConverter( + type_converter=HLSTypeConverter(precision_converter=ACTypeConverter()) + ) + + # A bit controversial but we are going to set the partitioning of the input here + in_layer = layer.model.graph[layer.inputs[0]] + in_var = layer.get_input_variable(layer.inputs[0]) + partition_factor = in_var.shape[1] * (in_var.shape[0] // reuse_factor) + in_pragma = ('partition', 'cyclic', partition_factor) + new_in_var = var_converter.convert(in_var, pragma=in_pragma) + in_layer.set_attr(layer.inputs[0], new_in_var) + + if layer.attributes['collapse']: + out_pragma = 'partition' + else: + partition_factor = layer._output_features * (layer.attributes['n_vertices'] // reuse_factor) + out_pragma = ('partition', 'cyclic', partition_factor) + + out_name, out_var = next(iter(layer.variables.items())) + new_out_var = var_converter.convert(out_var, pragma=out_pragma) + + layer.set_attr(out_name, new_out_var) + + @layer_optimizer(GarNetStack) + def init_garnet_stack(self, layer): + self.init_garnet(layer) diff --git a/hls4ml/backends/catapult/passes/__init__.py b/hls4ml/backends/catapult/passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/catapult/passes/broadcast_stream.py b/hls4ml/backends/catapult/passes/broadcast_stream.py new file mode 100644 index 0000000000..97019e074b --- /dev/null +++ b/hls4ml/backends/catapult/passes/broadcast_stream.py @@ -0,0 +1,117 @@ +import numpy as np + +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Concatenate, Layer, Merge, register_layer +from hls4ml.model.optimizer import OptimizerPass + + +class Broadcast(Layer): + '''Inserted between layers for broadcasting.''' + + def initialize(self): + shape = self.attributes['target_shape'] + if shape[0] is None: + shape = shape[1:] + dims = [f'N_SIZE_{i}_{self.index}' for i in range(1, len(shape) + 1)] + self.add_output_variable(shape, dims) + + +broadcast_function_template = 'nnet::broadcast_stream<{input_t}, {output_t}, {config}>({input}, {output});' +broadcast_config_template = """struct config{index} : nnet::broadcast_config {{ + static const unsigned in_width = {in_width}; + static const unsigned in_height = {in_height}; + static const unsigned in_chan = {in_chan}; + static const unsigned out_width = {out_width}; + static const unsigned out_height = {out_height}; + static const unsigned out_chan = {out_chan}; +}};\n""" +broadcast_include_list = ['nnet_utils/nnet_stream.h'] + + +class BroadcastConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Broadcast) + self.template = broadcast_config_template + + def format(self, node): + params = self._default_config_params(node) + params['in_height'] = node.get_input_variable().shape[0] + params['in_width'] = node.get_input_variable().shape[1] + params['in_chan'] = node.get_input_variable().shape[2] + params['out_height'] = node.get_output_variable().shape[0] + params['out_width'] = node.get_output_variable().shape[1] + params['out_chan'] = node.get_output_variable().shape[2] + + return self.template.format(**params) + + +class BroadcastFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Broadcast, include_header=broadcast_include_list) + self.template = broadcast_function_template + + def format(self, node): + params = self._default_function_params(node) + return self.template.format(**params) + + +def register_broadcast_stream(backend): + # Register the layer types to the layer map + register_layer('Broadcast', Broadcast) + + # Register the optimization passes + backend.register_pass('broadcast_stream', BroadcastStream) + + # Register template passes + backend.register_template(BroadcastConfigTemplate) + backend.register_template(BroadcastFunctionTemplate) + + +class BroadcastStream(OptimizerPass): + def match(self, node): + if isinstance(node, Merge) and not isinstance(node, Concatenate): + inp1 = node.get_input_variable(node.inputs[0]) + inp2 = node.get_input_variable(node.inputs[1]) + return inp1.shape != inp2.shape + else: + return False + + def transform(self, model, node): + if model.config.backend.name not in ['Catapult'] or model.config.get_config_value('IOType') != 'io_stream': + return False + + inp = [node.get_input_variable(inp_name) for inp_name in node.inputs] + + if np.prod(inp[0].shape) > np.prod(inp[1].shape): + idx = 1 + attrs = {'target_shape': inp[0].shape} + else: + idx = 0 + attrs = {'target_shape': inp[1].shape} + + def supported_broadcast(inp_shape, target_shape): + # Must be (H, W, C) + if not len(inp_shape) == 3: + return False + # Supported: (1, 1, C) -> (H, W, C) + if inp_shape[0] == inp_shape[1] == 1 and inp_shape[2] == target_shape[2]: + return True + # Supported: (H, W, 1) -> (H, W, C) + if inp_shape[2] == 1 and inp_shape[0] == target_shape[0] and inp_shape[1] == target_shape[1]: + return True + return False + + brdcst_inp = node.inputs[idx] + inp_shape = node.get_input_variable(brdcst_inp).shape + target_shape = attrs['target_shape'] + if not supported_broadcast(inp_shape, target_shape): + raise RuntimeError( + f'Unsupported broadcast type for stream: {inp_shape} -> {target_shape};' + + 'Only (1, 1, C) -> (H, W, C) and (H, W, 1) -> (H, W, C) currently supported' + ) + brdcst_out = 'broadcast_' + brdcst_inp + brdcst_layer = model.make_node('Broadcast', brdcst_out, attrs, [brdcst_inp].copy()) + model.insert_node(brdcst_layer, before=node, input_idx=idx) + node.inputs[idx] = brdcst_out + + return True diff --git a/hls4ml/backends/catapult/passes/conv_same_pad.py b/hls4ml/backends/catapult/passes/conv_same_pad.py new file mode 100755 index 0000000000..bb8354a3d0 --- /dev/null +++ b/hls4ml/backends/catapult/passes/conv_same_pad.py @@ -0,0 +1,109 @@ +from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.optimizer import OptimizerPass + + +class InsertZeroPaddingBeforeConv1D(OptimizerPass): + name = 'insert_zero_padding_before_conv1d' + + def match(self, node): + is_match = ( + isinstance(node, (Conv1D, SeparableConv1D)) + and ((node.get_attr('padding') == 'same') or (node.get_attr('padding') == 'causal')) + and node.get_attr('filt_width') != 1 + ) + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv1D layer + pad_left = node.get_attr('pad_left') + pad_right = node.get_attr('pad_right') + + # Check if no padding needs to be done + if pad_left == pad_right == 0: + return False + + out_width = pad_left + node.get_attr('in_width') + pad_right + + attrs = { + 'pad_left': pad_left, + 'pad_right': pad_right, + 'in_width': node.get_attr('in_width'), + 'out_width': out_width, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last'), + } + + # Switch Conv1D layer padding to 'valid' + node.set_attr('padding', 'valid') + node.set_attr('pad_left', 0) + node.set_attr('pad_right', 0) + node.set_attr('in_width', out_width) + + # Insert new ZeroPadding1D node above Conv1D + padding_layer = model.make_node('ZeroPadding1D', 'zp1d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer) + + return True + + +class InsertZeroPaddingBeforeConv2D(OptimizerPass): + name = 'insert_zero_padding_before_conv2d' + + def match(self, node): + is_match = ( + isinstance(node, (Conv2D, SeparableConv2D)) + and node.get_attr('padding') == 'same' + and node.get_attr('filt_height') != 1 + and node.get_attr('filt_width') != 1 + ) + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv2D layer + pad_top = node.get_attr('pad_top') + pad_bottom = node.get_attr('pad_bottom') + pad_left = node.get_attr('pad_left') + pad_right = node.get_attr('pad_right') + + # Check if no padding neeeds to be done + if pad_top == pad_bottom == pad_left == pad_right == 0: + return False + + out_height = pad_top + node.get_attr('in_height') + pad_bottom + out_width = pad_left + node.get_attr('in_width') + pad_right + + attrs = { + 'pad_top': pad_top, + 'pad_bottom': pad_bottom, + 'pad_left': pad_left, + 'pad_right': pad_right, + 'in_height': node.get_attr('in_height'), + 'in_width': node.get_attr('in_width'), + 'out_height': out_height, + 'out_width': out_width, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last'), + } + + # Switch Conv2D layer padding to 'valid' + node.set_attr('padding', 'valid') + node.set_attr('pad_top', 0) + node.set_attr('pad_bottom', 0) + node.set_attr('pad_left', 0) + node.set_attr('pad_right', 0) + node.set_attr('in_height', out_height) + node.set_attr('in_width', out_width) + + # Insert new ZeroPadding2D node above Conv2D + padding_layer = model.make_node('ZeroPadding2D', 'zp2d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer, before=node) + + return True diff --git a/hls4ml/backends/catapult/passes/conv_stream.py b/hls4ml/backends/catapult/passes/conv_stream.py new file mode 100755 index 0000000000..e0bb853d83 --- /dev/null +++ b/hls4ml/backends/catapult/passes/conv_stream.py @@ -0,0 +1,52 @@ +from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.optimizer import OptimizerPass + + +class GenerateConvStreamingInstructions(OptimizerPass): + '''Generates the instructions for streaming implementation of CNNs''' + + def match(self, node): + return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D)) + + def transform(self, model, node): + node_class = node.__class__.__name__ + if '1D' in node_class: + self._generate_1d_instructions(node) + elif '2D' in node_class: + self._generate_2d_instructions(node) + else: + raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})') + + def _generate_1d_instructions(self, node): + if node.model.config.get_config_value('IOType') == 'io_stream': + min_w, instructions = node.model.config.backend.compute_conv1d_instructions( + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_attr('filt_width'), + node.get_attr('stride_width'), + ) + instructions_str = ','.join(str(i) for i in instructions) + node.set_attr('min_width', min_w) + node.set_attr('instructions', instructions_str) + else: + # these are unused; just put dummy values + node.set_attr('min_width', node.get_attr('in_width')) + node.set_attr('instructions', '0') + + def _generate_2d_instructions(self, node): + if node.model.config.get_config_value('IOType') == 'io_stream': + min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions( + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_input_variable().shape[2], + node.get_attr('filt_height'), + node.get_attr('stride_height'), + ) + instructions_str = ','.join(str(i) for i in instructions) + node.set_attr('min_height', min_h) + node.set_attr('min_width', min_w) + node.set_attr('instructions', instructions_str) + else: + node.set_attr('min_height', node.get_attr('in_height')) + node.set_attr('min_width', node.get_attr('in_width')) + node.set_attr('instructions', '0') diff --git a/hls4ml/backends/catapult/passes/convolution_templates.py b/hls4ml/backends/catapult/passes/convolution_templates.py new file mode 100755 index 0000000000..8014a4ac8e --- /dev/null +++ b/hls4ml/backends/catapult/passes/convolution_templates.py @@ -0,0 +1,508 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import ( + Conv1D, + Conv2D, + Conv2DBatchnorm, + DepthwiseConv1D, + DepthwiseConv2D, + SeparableConv1D, + SeparableConv2D, +) + +# Shared multiplication template + +conv_mult_config_template = """struct config{index}_mult : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned reuse_factor = {reuse}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + template + using product = nnet::product::{product_type}; +}};\n""" + +# Conv1D templates + +conv1d_config_template = """struct config{index} : nnet::conv1d_config {{ + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_width = {stride_width}; + static const unsigned dilation = {dilation}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = + DIV_ROUNDUP(kernel_size * n_chan * n_filt, reuse_factor) - n_zeros / reuse_factor; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_width = {min_width}; + static const ac_int pixels[min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned n_pixels = out_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; + template + using scale_index = nnet::{scale_index_type}; +}}; +// really this allocation of pixels array ought to be in a .cpp file +#ifndef INCLUDED_MC_TESTBENCH_H +const ac_int config{index}::pixels[] = {{{instructions}}}; +#endif\n""" + +conv1d_function_template = 'nnet::conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +depthconv1d_function_template = ( + 'nnet::depthwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) + +conv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_conv1d_stream.h'] + + +class Conv1DConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((Conv1D, DepthwiseConv1D)) + self.template = conv1d_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + + params['config_t'] = f'config{node.index}_mult' + if node.get_attr('in_width') == node.get_attr('min_width'): + params['scale_index_type'] = 'scale_index_unscaled' + else: + params['scale_index_type'] = 'scale_index_regular' + + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = f'fill_buffer_{node.index}' + else: + params['fill_fn'] = 'FillConv1DBuffer' + + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('weight').nzeros + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + + +class Conv1DFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Conv1D, include_header=conv1d_include_list) + self.template = conv1d_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) + + +class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate): + def __init__(self): + super(Conv1DFunctionTemplate, self).__init__(DepthwiseConv1D, include_header=sepconv1d_include_list) + self.template = depthconv1d_function_template + + +# Conv2D Templates + +conv2d_config_template = """struct config{index} : nnet::conv2d_config {{ + static const unsigned pad_top = {pad_top}; + static const unsigned pad_bottom = {pad_bottom}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_height = {filt_height}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_height = {stride_height}; + static const unsigned stride_width = {stride_width}; + static const unsigned out_height = {out_height}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned multiplier_limit = + DIV_ROUNDUP(kernel_size * n_chan * n_filt, reuse_factor) - n_zeros / reuse_factor; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_height = {min_height}; + static const unsigned min_width = {min_width}; + static const ac_int pixels[min_height * min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned n_pixels = out_height * out_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; + template + using scale_index_height = nnet::{scale_index_height_type}; + template + using scale_index_width = nnet::{scale_index_width_type}; +}}; +// really this allocation of pixels array ought to be in a .cpp file +#ifndef INCLUDED_MC_TESTBENCH_H +const ac_int config{index}::pixels[] = {{{instructions}}}; +#endif\n""" + +conv2d_function_template = 'nnet::conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +depthconv2d_function_template = ( + 'nnet::depthwise_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) + +conv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_conv2d_stream.h'] + + +class Conv2DConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((Conv2D, Conv2DBatchnorm, DepthwiseConv2D)) + self.template = conv2d_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + + params['config_t'] = f'config{node.index}_mult' + + if node.get_attr('in_height') == node.get_attr('min_height'): + params['scale_index_height_type'] = 'scale_index_unscaled' + else: + params['scale_index_height_type'] = 'scale_index_regular' + + if node.get_attr('in_width') == node.get_attr('min_width'): + params['scale_index_width_type'] = 'scale_index_unscaled' + else: + params['scale_index_width_type'] = 'scale_index_regular' + + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = f'fill_buffer_{node.index}' + else: + params['fill_fn'] = 'FillConv2DBuffer' + + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('weight').nzeros + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + + +class Conv2DFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((Conv2D, Conv2DBatchnorm), include_header=conv2d_include_list) + self.template = conv2d_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) + + +class DepthwiseConv2DFunctionTemplate(Conv2DFunctionTemplate): + def __init__(self): + super(Conv2DFunctionTemplate, self).__init__(DepthwiseConv2D, include_header=sepconv2d_include_list) + self.template = depthconv2d_function_template + + +# SeparableConv1D/2D Templates + +sepconv_config_template = """struct config{index} {{ + typedef {depthwise_config} depthwise_config; + typedef {pointwise_config} pointwise_config; +}};\n""" + +sepconv1d_function_template = ( + 'nnet::separable_conv_1d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>(' + '{input}, {output}, {d}, {p}, {z}, {b});' +) +sepconv2d_function_template = ( + 'nnet::separable_conv_2d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>(' + '{input}, {output}, {d}, {p}, {z}, {b});' +) + +sepconv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_sepconv1d_stream.h'] +sepconv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_sepconv2d.h', 'nnet_utils/nnet_sepconv2d_stream.h'] + + +class SeparableConv1DConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(SeparableConv1D) + self.template = sepconv_config_template + self.depthwise_template = conv1d_config_template + self.pointwise_template = conv1d_config_template + self.depthwise_mult_template = conv_mult_config_template + self.pointwise_mult_template = conv_mult_config_template + + def format(self, node): + # Separable master config + params = {} + params['index'] = node.index + params['depthwise_config'] = f'config{node.index}_depthwise' + params['pointwise_config'] = f'config{node.index}_pointwise' + sep_config = self.template.format(**params) + + # Depthwise config + params = self._default_config_params(node) + # Override bias and bias_t since these are zeros in depthwise step of SepConv1D + params['bias'] = params['zero_bias'] + params['bias_t'] = params['zero_bias_t'] + params['n_filt'] = params['n_chan'] # In depthwise step n_chan == n_filt + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('depthwise').nzeros + params['index'] = str(node.index) + '_depthwise' + params['weight_t'] = node.get_weights('depthwise').type + params['fill_fn'] = 'FillConv1DBuffer' + + if node.get_attr('unscaled'): + params['scale_index_type'] = 'scale_index_unscaled' + else: + params['scale_index_type'] = 'scale_index_regular' + + params['config_t'] = f'config{node.index}_depthwise_mult' + depthwise_config = self.depthwise_template.format(**params) + + # Depthwise mult config + mult_params = self._default_config_params(node) + mult_params['index'] = str(node.index) + '_depthwise' + mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') + mult_params['n_out'] = node.get_attr('n_chan') + mult_params['nzeros'] = node.get_weights('depthwise').nzeros + mult_params['weight_t'] = node.get_weights('depthwise').type + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision + ) + depthwise_mult_config = self.depthwise_mult_template.format(**mult_params) + + # Pointwise config + params = self._default_config_params(node) + if node.get_attr('data_format') == 'channels_last': + params['in_width'] = node.get_output_variable().shape[0] + else: + params['in_width'] = node.get_output_variable().shape[1] + + params['filt_width'] = 1 + params['stride_width'] = 1 + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('pointwise').nzeros + params['index'] = str(node.index) + '_pointwise' + params['weight_t'] = node.get_weights('pointwise').type + params['min_width'] = params['in_width'] + params['instructions'] = '0' + params['fill_fn'] = 'FillConv1DBuffer' + + if node.get_attr('unscaled'): + params['scale_index_type'] = 'scale_index_unscaled' + else: + params['scale_index_type'] = 'scale_index_regular' + + params['config_t'] = f'config{node.index}_pointwise_mult' + pointwise_config = self.pointwise_template.format(**params) + + # Pointwise mult config + mult_params = self._default_config_params(node) + mult_params['index'] = str(node.index) + '_pointwise' + mult_params['n_in'] = node.get_attr('n_chan') + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('pointwise').nzeros + mult_params['weight_t'] = node.get_weights('pointwise').type + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision + ) + pointwise_mult_config = self.pointwise_mult_template.format(**mult_params) + + return ( + depthwise_mult_config + + '\n' + + depthwise_config + + '\n' + + pointwise_mult_config + + '\n' + + pointwise_config + + '\n' + + sep_config + ) + + +class SeparableConv1DFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(SeparableConv1D, include_header=sepconv1d_include_list) + self.template = sepconv1d_function_template + + def format(self, node): + params = self._default_function_params(node) + params['dw_output_t'] = node.get_attr('dw_output_t').name + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['d'] = node.get_weights('depthwise').name + params['p'] = node.get_weights('pointwise').name + params['b'] = node.get_weights('bias').name + params['z'] = node.get_weights('zero_bias').name + + return self.template.format(**params) + + +class SeparableConv2DConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(SeparableConv2D) + self.template = sepconv_config_template + self.depthwise_template = conv2d_config_template + self.pointwise_template = conv2d_config_template + self.depthwise_mult_template = conv_mult_config_template + self.pointwise_mult_template = conv_mult_config_template + + def format(self, node): + # Separable master config + params = {} + params['index'] = node.index + params['depthwise_config'] = f'config{node.index}_depthwise' + params['pointwise_config'] = f'config{node.index}_pointwise' + sep_config = self.template.format(**params) + + # Depthwise config + params = self._default_config_params(node) + # Override bias and bias_t since these are zeros in depthwise step of SepConv2D + params['bias'] = params['zero_bias'] + params['bias_t'] = params['zero_bias_t'] + params['n_filt'] = params['n_chan'] # In depthwise step n_chan == n_filt + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('depthwise').nzeros + params['index'] = str(node.index) + '_depthwise' + params['weight_t'] = node.get_weights('depthwise').type + params['fill_fn'] = 'FillConv2DBuffer' + + if node.get_attr('unscaled_h'): + params['scale_index_height_type'] = 'scale_index_unscaled' + else: + params['scale_index_height_type'] = 'scale_index_regular' + + if node.get_attr('unscaled_w'): + params['scale_index_width_type'] = 'scale_index_unscaled' + else: + params['scale_index_width_type'] = 'scale_index_regular' + + params['config_t'] = f'config{node.index}_depthwise_mult' + depthwise_config = self.depthwise_template.format(**params) + + # Depthwise mult config + mult_params = self._default_config_params(node) + mult_params['index'] = str(node.index) + '_depthwise' + mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') + mult_params['n_out'] = node.get_attr('n_chan') + mult_params['nzeros'] = node.get_weights('depthwise').nzeros + mult_params['weight_t'] = node.get_weights('depthwise').type + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision + ) + depthwise_mult_config = self.depthwise_mult_template.format(**mult_params) + + # Pointwise config + params = self._default_config_params(node) + if node.get_attr('data_format') == 'channels_last': + params['in_height'] = node.get_output_variable().shape[0] + params['in_width'] = node.get_output_variable().shape[1] + else: + params['in_height'] = node.get_output_variable().shape[1] + params['in_width'] = node.get_output_variable().shape[2] + + params['filt_height'] = params['filt_width'] = 1 + params['stride_height'] = params['stride_width'] = 1 + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('pointwise').nzeros + params['index'] = str(node.index) + '_pointwise' + params['weight_t'] = node.get_weights('pointwise').type + params['min_height'] = params['in_height'] + params['min_width'] = params['in_width'] + params['instructions'] = '0' + params['fill_fn'] = 'FillConv2DBuffer' + + if node.get_attr('unscaled_h'): + params['scale_index_height_type'] = 'scale_index_unscaled' + else: + params['scale_index_height_type'] = 'scale_index_regular' + + if node.get_attr('unscaled_w'): + params['scale_index_width_type'] = 'scale_index_unscaled' + else: + params['scale_index_width_type'] = 'scale_index_regular' + params['config_t'] = f'config{node.index}_pointwise_mult' + pointwise_config = self.pointwise_template.format(**params) + + # Pointwise mult config + mult_params = self._default_config_params(node) + mult_params['index'] = str(node.index) + '_pointwise' + mult_params['n_in'] = node.get_attr('n_chan') + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('pointwise').nzeros + mult_params['weight_t'] = node.get_weights('pointwise').type + mult_params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision + ) + pointwise_mult_config = self.pointwise_mult_template.format(**mult_params) + + return ( + depthwise_mult_config + + '\n' + + depthwise_config + + '\n' + + pointwise_mult_config + + '\n' + + pointwise_config + + '\n' + + sep_config + ) + + +class SeparableConv2DFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(SeparableConv2D, include_header=sepconv2d_include_list) + self.template = sepconv2d_function_template + + def format(self, node): + params = self._default_function_params(node) + params['dw_output_t'] = node.get_attr('dw_output_t').name + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['d'] = node.get_weights('depthwise').name + params['p'] = node.get_weights('pointwise').name + params['b'] = node.get_weights('bias').name + params['z'] = node.get_weights('zero_bias').name + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/convolution_winograd.py b/hls4ml/backends/catapult/passes/convolution_winograd.py new file mode 100644 index 0000000000..8b25ab41b8 --- /dev/null +++ b/hls4ml/backends/catapult/passes/convolution_winograd.py @@ -0,0 +1,175 @@ +import math + +import numpy as np + +from hls4ml.model.layers import Conv1D, Conv2D +from hls4ml.model.optimizer import OptimizerPass + + +class ApplyWinogradKernelTransformation(OptimizerPass): + ''' + Transforms the weights of a Conv2D kernel to a format suitable for Wingorad convolution + For further information, refer to Lavin & Gray, 2015 - Fast Algorithms for Convolutional Neural Networks + ''' + + def match(self, node): + node_matches = isinstance(node, (Conv1D, Conv2D)) + + # This optimizer works only after the Resource Strategy Optimizer, since order of transposition matters + weights_transformed = node.get_attr('_weights_transposed', False) is True + + # User opted for Winograd + implementation_is_winograd = ( + node.get_attr('implementation', 'combination') == 'combination' + or node.get_attr('implementation', 'combination') == 'winograd' + ) + + parallel_io_type = node.model.config.get_config_value('IOType') == 'io_parallel' + + # Winograd algorithm-specific conditions + if isinstance(node, Conv1D): + # Winograd only applies to specific kernel sizes + # Current implementation only supports fs = 3; easily extendable to other filter sizes + filter_size_matches = node.get_attr('filt_width', 3) == 3 + + # Winograd's minimal filtering algorithm doesn't work with stride != 1 + stride_is_one = node.get_attr('stride_width', 1) == 1 + + # HLS Compiler fails to pipeline the entire component if Winograd loop only executes once + loop_itr_gt_one = node.get_attr('out_width') > 2 + + winograd_conditions = filter_size_matches and stride_is_one and loop_itr_gt_one and parallel_io_type + + elif isinstance(node, (Conv2D)): + # Winograd only applies to specific kernel sizes + # Current implementation only supports fs = 3; easily extendable to other filter sizes + filter_size_matches = node.get_attr('filt_height', 3) == 3 and node.get_attr('filt_width', 3) == 3 + + # Winograd's minimal filtering algorithm doesn't work with striede != 1 + stride_is_one = node.get_attr('stride_height', 1) == 1 and node.get_attr('stride_width', 1) == 1 + + # HLS Compiler fails to pipeline the entire component if Winograd loop only executes once + loop_itr_gt_one = node.get_attr('out_height') > 2 and node.get_attr('out_width') > 2 + + padding_is_equal = node.get_attr('pad_top', 0) == node.get_attr('pad_bottom', 0) and node.get_attr( + 'pad_left', 0 + ) == node.get_attr('pad_right', 0) + + winograd_conditions = ( + filter_size_matches and stride_is_one and padding_is_equal and loop_itr_gt_one and parallel_io_type + ) + + else: + winograd_conditions = False + + # Check any previous transformations + already_transformed = node.get_attr('_winograd_transformation_applied', False) is True + + if not winograd_conditions and node.get_attr('implementation', 'combination') == 'winograd': + raise RuntimeError( + 'Not possible to use Winograd algorithm with current architecture. ' + 'Please set implementation to im2col or combination' + ) + + return ( + node_matches + and weights_transformed + and winograd_conditions + and not already_transformed + and implementation_is_winograd + ) + + def transform(self, model, node): + if isinstance(node, Conv1D): + if node.get_attr('filt_width', 3) == 3: + # First, transpose to a format suitable for the Winograd algorithm (F, C, W) + # Note, this assumes a format post-resource strategy optimizer, that is (F, W, C) + # Therefore, (F, W, C) => (F, C, W) + node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[0, 2, 1]) + + # Temporary copy of data + weights = node.weights['weight'].data + + # Expand weight dimensionality (3) => (4) + node.weights['weight'].data = np.zeros((weights.shape[0], weights.shape[1], 4)) + + # Transformation matrices for 3x1 kernels + G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]]) + + # Transformation GfG' + for filter in range(0, weights.data.shape[0]): + for channel in range(0, weights.data.shape[1]): + node.weights['weight'].data[filter][channel] = np.matmul(G, weights[filter][channel]) + node.weights['weight'].data_length = node.weights['weight'].data.size + + # Winograd's minimal filtering algorithm transforms the weight matrix + # This transformation consists of addition and division (by 2&4) of the weight matrix + # Therefore, increase precision (if needed), to accomodate for new weights + # This error is only noticeable for low precisions, such as those used with QKeras + + # Integer precision is only updated if it exceeds the one defined in hls4ml config + maximum_value_rounded = int(math.ceil(np.abs(node.weights['weight'].data).max())) + if maximum_value_rounded.bit_length() + 1 > node.weights['weight'].type.precision.integer: + node.weights['weight'].type.precision.integer = maximum_value_rounded.bit_length() + 1 + node.weights['weight'].type.precision.width += ( + maximum_value_rounded.bit_length() + 1 - node.weights['weight'].type.precision.integer + ) + + # Fractional precision is increased by 2 bits (division by 4), + # for low-precision (less than 8) fractional weights + if node.weights['weight'].type.precision.fractional < 8: + node.weights['weight'].type.precision.width += 2 + + # Modified kernel size + node.set_attr('impl_filt_width', 4) + + elif isinstance(node, Conv2D): + if node.get_attr('filt_height', 3) == 3 and node.get_attr('filt_width', 3) == 3: + # First, transpose to a format suitable for the Winograd algorithm (F, C, H, W) + # Note, this assumes a format post-resource strategy optimizer, that is (F, H, W, C) + # Therefore, (F, H, W, C) => (F, C, H, W) + node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[0, 3, 1, 2]) + + # Temporary copy of data + weights = node.weights['weight'].data + + # Expand weight dimensionality (3x3) => (4x4) + node.weights['weight'].data = np.zeros((weights.shape[0], weights.shape[1], 4, 4)) + + # Transformation matrices for 3x3 kernels + G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]]) + GT = np.array([[1, 0.5, 0.5, 0], [0, 0.5, -0.5, 0], [0, 0.5, 0.5, 1]]) + + # Transformation GfG' + for filter in range(0, weights.data.shape[0]): + for channel in range(0, weights.data.shape[1]): + node.weights['weight'].data[filter][channel] = np.matmul(np.matmul(G, weights[filter][channel]), GT) + node.weights['weight'].data_length = node.weights['weight'].data.size + + # Winograd's minimal filtering algorithm transforms the weight matrix + # This transformation consists of addition and division (by 2&4) of the weight matrix + # Therefore, increase precision (if needed), to accomodate for new weights + # This error is only noticeable for low precisions, such as those used with QKeras + + # Integer precision is only updated if it exceeds the one defined in hls4ml config + maximum_value_rounded = int(math.ceil(np.abs(node.weights['weight'].data).max())) + if maximum_value_rounded.bit_length() + 1 > node.weights['weight'].type.precision.integer: + node.weights['weight'].type.precision.integer = maximum_value_rounded.bit_length() + 1 + node.weights['weight'].type.precision.width += ( + maximum_value_rounded.bit_length() + 1 - node.weights['weight'].type.precision.integer + ) + + # Fractional precision is increased by 2 bits (division by 4), + # for low-precision (less than 8) fractional weights + if node.weights['weight'].type.precision.fractional < 8: + node.weights['weight'].type.precision.width += 2 + + # Modified kernel size + node.set_attr('impl_filt_height', 4) + node.set_attr('impl_filt_width', 4) + else: + raise Exception(f'Unexpected layer {node.class_name} with Winograd kernel optimizer') + + node.set_attr('_winograd_transformation_applied', True) + + return False diff --git a/hls4ml/backends/catapult/passes/core_templates.py b/hls4ml/backends/catapult/passes/core_templates.py new file mode 100755 index 0000000000..2088923428 --- /dev/null +++ b/hls4ml/backends/catapult/passes/core_templates.py @@ -0,0 +1,216 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax + +# Dense templates + +dense_config_template = """struct config{index} : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned n_nonzeros = {nonzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + static const bool store_weights_in_bram = false; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {index_t.name} index_t; + template + using product = nnet::product::{product_type}; +}};\n""" + +dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' + +dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h'] + + +class DenseConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Dense) + self.template = dense_config_template + + def format(self, node): + params = self._default_config_params(node) + params['nzeros'] = node.get_weights('weight').nzeros + params['nonzeros'] = node.get_weights('weight').nonzeros + params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + + return self.template.format(**params) + + +class DenseFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Dense, include_header=dense_include_list) + self.template = dense_function_template + + def format(self, node): + params = self._default_function_params(node) + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) + + +# BatchNormalization templates + +batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_filt = {n_filt}; + static const unsigned n_scale_bias = (n_filt == -1) ? n_in : n_filt; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in, reuse_factor); + static const bool store_weights_in_bram = false; + typedef {bias_t.name} bias_t; + typedef {scale_t.name} scale_t; + template + using product = nnet::product::{product_type}; +}};\n""" + +batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});' + +batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h'] + + +class BatchNormalizationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(BatchNormalization) + self.template = batchnorm_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_in'] = node.get_input_variable().size_cpp() + params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('scale').type.precision + ) + + return self.template.format(**params) + + +class BatchNormalizationFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(BatchNormalization, include_header=batchnorm_include_list) + self.template = batchnorm_function_template + + def format(self, node): + params = self._default_function_params(node) + params['scale'] = node.get_weights('scale').name + params['bias'] = node.get_weights('bias').name + + return self.template.format(**params) + + +# Activation templates + +activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; +}};\n""" + +hard_activ_config_template = """struct {type}_config{index} {{ + static const unsigned n_in = {n_in}; + static const {slope_t.name} slope; + static const {shift_t.name} shift; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; +}}; +// really this allocation of pixels array ought to be in a .cpp file +#ifndef INCLUDED_MC_TESTBENCH_H +const {slope_t.name} {type}_config{index}::slope = {slope}; +const {shift_t.name} {type}_config{index}::shift = {shift}; +#endif\n""" + +softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + static const unsigned axis = {axis}; + static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation}; + typedef {exp_table_t.name} exp_table_t; + typedef {inv_table_t.name} inv_table_t; +}};\n""" + +activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' +param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});' + +activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h'] + + +class ActivationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((Activation, ParametrizedActivation, PReLU)) + self.template = activ_config_template + + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + + return self.template.format(**params) + + +class HardActivationConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(HardActivation) + self.template = hard_activ_config_template + + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + + return self.template.format(**params) + + +class SoftmaxConfigTemplate(ActivationConfigTemplate): + def __init__(self): + super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__ + self.template = softmax_config_template + + +class ActivationFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list) + self.template = activ_function_template + + def format(self, node): + params = self._default_function_params(node) + params['activation'] = node.get_attr('activation').lower() + params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) + + return self.template.format(**params) + + +class ParametrizedActivationFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(ParametrizedActivation, include_header=activ_include_list) + self.template = param_activ_function_template + + def format(self, node): + params = self._default_function_params(node) + params['activation'] = node._get_act_function_name() + params['param'] = node.get_attr('activ_param', 1.0) + params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) + + return self.template.format(**params) + + +class PReLUFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(PReLU, include_header=activ_include_list) + self.template = param_activ_function_template + + def format(self, node): + params = self._default_function_params(node) + params['activation'] = node.get_attr('activation').lower() + params['param'] = node.get_weights('alpha').name + params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index) + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/fifo_depth_optimization.py b/hls4ml/backends/catapult/passes/fifo_depth_optimization.py new file mode 100755 index 0000000000..4d92e98de1 --- /dev/null +++ b/hls4ml/backends/catapult/passes/fifo_depth_optimization.py @@ -0,0 +1,104 @@ +import json + +from pyDigitalWaveTools.vcd.parser import VcdParser + +from hls4ml.model.optimizer.optimizer import ConfigurableOptimizerPass, ModelOptimizerPass + + +def populate_values(values, name, data, depth): + def get_values(x): + return int(x[1][1:], 2) + + values.append({'name': name, 'data': [], 'max': 0, 'depth': 0}) + values[-1]['data'] = [get_values(x) for x in data] + values[-1]['max'] = max(values[-1]['data']) + values[-1]['depth'] = int(depth[0][1][1:], 2) + return values + + +def set_big_fifos(vars_to_profile, profiling_fifo_depth): + for v in vars_to_profile.values(): + if v.pragma: + v.pragma = (v.pragma[0], profiling_fifo_depth) + + +def get_vcd_data(model): + model.write() + model.build(reset=False, csim=True, synth=True, cosim=True, validation=False, export=False, vsynth=False, fifo_opt=True) + + with open( + model.config.get_output_dir() + + '/' + + model.config.get_project_name() + + '_prj' + + '/solution1/sim/verilog/fifo_opt.vcd' + ) as vcd_file: + vcd = VcdParser() + vcd.parse(vcd_file) + data = vcd.scope.toJson() + return data + + +def generate_max_depth_file(model, maxs): + with open(model.config.get_output_dir() + '/max_depth.json', 'w') as f: + json.dump(maxs, f, indent=4) + + +def set_fifo_depth(model, maxs): + for v in model.output_vars.values(): + if v.pragma: + filtered_max = [x['max'] for x in maxs if v.name in x['name']] + if len(filtered_max) == 0: + continue + if len(filtered_max) > 1: + print('WARNING! Check names of FIFOs') + v.pragma = (v.pragma[0], filtered_max[0] + 1) + + +class FifoDepthOptimization(ConfigurableOptimizerPass, ModelOptimizerPass): + def __init__(self): + self.values = [] + + def transform(self, model): + # use `large_fifo_depth = 0` to keep the default fifo depth + profiling_fifo_depth = getattr(self, 'profiling_fifo_depth', 100_000) + + # check axi-stream or io-stream, if not one the 2 exit + if not (model.config.get_config_value('IOType') == 'io_stream'): + raise RuntimeError('To use this optimization you have to set `IOType` field to `io_stream` in the HLS config') + + # initialize all the fifos to `profiling_fifo_depth` so that they will be automatically implemented in BRAMs + # and so they will be profiled + if profiling_fifo_depth: + vars_to_profile = { + k: v + for k, v in model.output_vars.items() + if v != model.get_output_variables()[0] and v != model.get_input_variables()[0] + } + + set_big_fifos(vars_to_profile, profiling_fifo_depth) + + data = get_vcd_data(model) + + if len(data['children']) == 0: + print( + "FIFO depth optimization found no FIFOs implemented using BRAMs in the design, no optimization is possible." + ) + print("Consider increasing profiling_fifo_depth.") + return False + + n_elem = len(data['children'][0]['children'][0]['children']) + for i in range(n_elem): + name = data['children'][0]['children'][0]['children'][i]['name'] + data_p = data['children'][0]['children'][0]['children'][i]['children'][0]['data'] + depth = data['children'][0]['children'][0]['children'][i]['children'][1]['data'] + populate_values(self.values, name, data_p, depth) + + maxs = [{'name': i['name'], 'max': i['max'], 'depth': i['depth']} for i in self.values] + + generate_max_depth_file(model, maxs) + + set_fifo_depth(model, maxs) + + print('[hls4ml] - FIFO optimization completed') + return False diff --git a/hls4ml/backends/catapult/passes/garnet_templates.py b/hls4ml/backends/catapult/passes/garnet_templates.py new file mode 100755 index 0000000000..f73f627683 --- /dev/null +++ b/hls4ml/backends/catapult/passes/garnet_templates.py @@ -0,0 +1,249 @@ +import numpy as np + +from hls4ml.backends.fpga.fpga_types import ACTypeConverter +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import GarNet, GarNetStack +from hls4ml.model.types import FixedPrecisionType + +# GarNet templates + +garnet_common_config_template = """ + static const unsigned n_vertices = {n_vertices}; + static const unsigned n_vertices_width = {n_vertices_width}; + static const unsigned n_in_features = {n_in_features}; + static const unsigned distance_width = {distance_width}; + static const unsigned output_collapse = {collapse_type}; + static const bool mean_by_nvert = {mean_by_nvert}; + + typedef {norm_t} norm_t; + typedef ac_fixed<{distance_width}, {distance_nint}, true, AC_TRN, AC_SAT> distance_t; + typedef {edge_weight_t} edge_weight_t; + typedef {edge_weight_aggr_t} edge_weight_aggr_t; + typedef {aggr_t} aggr_t; + typedef {output_t} output_t; + + static const unsigned reuse_factor = {reuse}; + static const unsigned log2_reuse_factor = {log2_reuse}; +""" + +garnet_config_template = """struct config{index} : nnet::garnet_config {{""" +garnet_config_template += garnet_common_config_template +garnet_config_template += """ + static const unsigned n_propagate = {n_propagate}; + static const unsigned n_aggregators = {n_aggregators}; + static const unsigned n_out_features = {n_out_features}; + + typedef {input_transform_weights_t} input_transform_weights_t; + typedef {input_transform_biases_t} input_transform_biases_t; + typedef {aggregator_distance_weights_t} aggregator_distance_weights_t; + typedef {aggregator_distance_biases_t} aggregator_distance_biases_t; + typedef {output_transform_weights_t} output_transform_weights_t; + typedef {output_transform_biases_t} output_transform_biases_t; + + static const input_transform_weights_t (&input_transform_weights)[{input_transform_weights_size}]; + static const input_transform_biases_t (&input_transform_biases)[{input_transform_biases_size}]; + static const aggregator_distance_weights_t (&aggregator_distance_weights)[{aggregator_distance_weights_size}]; + static const aggregator_distance_biases_t (&aggregator_distance_biases)[{aggregator_distance_biases_size}]; + static const output_transform_weights_t (&output_transform_weights)[{output_transform_weights_size}]; + static const output_transform_biases_t (&output_transform_biases)[{output_transform_biases_size}]; + + typedef config{index} base_t; +}}; + +const config{index}::input_transform_weights_t (&config{index}::input_transform_weights)[{input_transform_weights_size}] = {input_transform_weights}; +const config{index}::input_transform_biases_t (&config{index}::input_transform_biases)[{input_transform_biases_size}] = {input_transform_biases}; +const config{index}::aggregator_distance_weights_t (&config{index}::aggregator_distance_weights)[{aggregator_distance_weights_size}] = {aggregator_distance_weights}; +const config{index}::aggregator_distance_biases_t (&config{index}::aggregator_distance_biases)[{aggregator_distance_biases_size}] = {aggregator_distance_biases}; +const config{index}::output_transform_weights_t (&config{index}::output_transform_weights)[{output_transform_weights_size}] = {output_transform_weights}; +const config{index}::output_transform_biases_t (&config{index}::output_transform_biases)[{output_transform_biases_size}] = {output_transform_biases}; +""" # noqa: E501 + +garnet_function_template = ( + 'nnet::garnet{impl}<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});' +) + +garnet_include_list = ['nnet_utils/nnet_garnet.h'] + + +class GarNetConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(GarNet) + self.template = (garnet_config_template,) + + def get_transforms_config(self, node, params): + params['n_in_features'] = node.attributes['n_in_features'] + params['n_propagate'] = node.attributes['n_propagate'] + params['n_aggregators'] = node.get_weights('aggregator_distance_biases').shape[0] + params['n_out_features'] = node.get_weights('output_transform_biases').shape[0] + + for wname, weights in node.weights.items(): + params[wname] = weights.name + params[f'{wname}_t'] = weights.type.name + params[f'{wname}_size'] = weights.data_length + + def format(self, node): + params = self._default_config_params(node) + + params['n_vertices'] = node.attributes['n_vertices'] + params['n_vertices_width'] = int(np.log2(params['n_vertices'])) + params['distance_width'] = 12 + params['distance_nint'] = min(4, params['distance_width'] - 6) # this is tuned + params['log2_reuse'] = int(np.log2(params['reuse'])) + + # Define default precisions for various internal arrays (can be overridden from the config file) + # We always give 10 digits for the subintegral part + fwidth = 10 + # Integral precision for aggr_t depends on how large the temporary sum for weighed feature mean will be + aggr_intw = max(params['log2_reuse'], params['n_vertices_width'] - params['log2_reuse']) + 3 # safety factor 2**3 + aggr_w = aggr_intw + fwidth + # edge_weight_aggr_t does not need the safety factor + ew_aggr_intw = aggr_intw - 3 + ew_aggr_w = ew_aggr_intw + fwidth + # Integral precision for norm is fixed to 4 + norm_intw = 4 + norm_w = norm_intw + fwidth + + vspecs = [ + ('edge_weight', FixedPrecisionType(10, 0, signed=False)), + ('edge_weight_aggr', FixedPrecisionType(ew_aggr_w, ew_aggr_intw, signed=False)), + ('aggr', FixedPrecisionType(aggr_w, aggr_intw)), + ('norm', FixedPrecisionType(norm_w, norm_intw, signed=False)), + ] + precision_converter = ACTypeConverter() + for vname, default_precision in vspecs: + params[f'{vname}_t'], type_name = node.model.config.get_precision(node, var=vname) + if type_name.endswith('default_t'): + params[f'{vname}_t'] = precision_converter.convert(default_precision).definition_cpp() + else: + params[f'{vname}_t'] = precision_converter.convert(params[f'{vname}_t']).definition_cpp() + params['output_t'] = node.get_output_variable().type.name + + if node.attributes['collapse'] in ['mean', 'max']: + params['collapse_type'] = 'collapse_{}'.format(node.attributes['collapse']) + else: + params['collapse_type'] = 'no_collapse' + + params['mean_by_nvert'] = str(node.attributes['mean_by_nvert']).lower() + + self.get_transforms_config(node, params) + + return self.template[0].format(**params) + + +class GarNetFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(GarNet, include_header=garnet_include_list) + self.template = garnet_function_template + + def format(self, node): + params = self._default_function_params(node) + + data = node.get_input_variable(node.inputs[0]) + integer_input = node.get_input_variable(node.inputs[1]) + params['input_t'] = data.type.name + params['input'] = data.name + + params['integer_input_t'] = integer_input.type.name + params['nvtx'] = integer_input.name + + if node.ref_impl: + params['impl'] = '_ref' + else: + params['impl'] = '' + + return self.template.format(**params) + + +# GarNetStack Templates + +garnet_stack_base_config_template = """struct config{index}_base : nnet::garnet_config {{""" +garnet_stack_base_config_template += garnet_common_config_template +garnet_stack_base_config_template += """ + static const bool is_stack = true; + + typedef config{index}_base base_t; +}}; + +struct config{index} : config{index}_base {{ + static const unsigned n_sublayers = {n_sublayers}; + + template + struct sublayer_t : config{index}_base {{}}; +}}; + +{sublayer_configs} +""" + +garnet_stack_sublayer_config_template = """template<> +struct config{index}::sublayer_t<{il}> : config{index}_base {{ + static const unsigned n_in_features = {n_in_features}; + static const unsigned n_propagate = {n_propagate}; + static const unsigned n_aggregators = {n_aggregators}; + static const unsigned n_out_features = {n_out_features}; + + typedef {input_transform_weights_t} input_transform_weights_t; + typedef {input_transform_biases_t} input_transform_biases_t; + typedef {aggregator_distance_weights_t} aggregator_distance_weights_t; + typedef {aggregator_distance_biases_t} aggregator_distance_biases_t; + typedef {output_transform_biases_t} output_transform_biases_t; + + static const input_transform_weights_t (&input_transform_weights)[{input_transform_weights_size}]; + static const input_transform_biases_t (&input_transform_biases)[{input_transform_biases_size}]; + static const aggregator_distance_weights_t (&aggregator_distance_weights)[{aggregator_distance_weights_size}]; + static const aggregator_distance_biases_t (&aggregator_distance_biases)[{aggregator_distance_biases_size}]; + static const output_transform_biases_t (&output_transform_biases)[{output_transform_biases_size}]; + + typedef config{index}::sublayer_t<{next}> next_layer_t; +}}; + +const config{index}::sublayer_t<{il}>::input_transform_weights_t (&config{index}::sublayer_t<{il}>::input_transform_weights)[{input_transform_weights_size}] = {input_transform_weights}; +const config{index}::sublayer_t<{il}>::input_transform_biases_t (&config{index}::sublayer_t<{il}>::input_transform_biases)[{input_transform_biases_size}] = {input_transform_biases}; +const config{index}::sublayer_t<{il}>::aggregator_distance_weights_t (&config{index}::sublayer_t<{il}>::aggregator_distance_weights)[{aggregator_distance_weights_size}] = {aggregator_distance_weights}; +const config{index}::sublayer_t<{il}>::aggregator_distance_biases_t (&config{index}::sublayer_t<{il}>::aggregator_distance_biases)[{aggregator_distance_biases_size}] = {aggregator_distance_biases}; +const config{index}::sublayer_t<{il}>::output_transform_biases_t (&config{index}::sublayer_t<{il}>::output_transform_biases)[{output_transform_biases_size}] = {output_transform_biases}; +""" # noqa: E501 + +garnet_stack_config_template = (garnet_stack_base_config_template, garnet_stack_sublayer_config_template) +garnet_stack_function_template = ( + 'nnet::garnet_stack<{input_t}, {integer_input_t}, {output_t}, {config}>({input}, {nvtx}, {output});' +) + + +class GarNetStackConfigTemplate(GarNetConfigTemplate): + def __init__(self): + super(GarNetConfigTemplate, self).__init__(GarNetStack) + self.template = garnet_stack_config_template + + def get_transforms_config(self, node, params): + _, sublayer_template = self.template + + params['n_sublayers'] = node.attributes['n_sublayers'] + params['n_in_features'] = node.attributes['n_in_features'][0] + params['n_out_features'] = node.attributes['n_out_features'][-1] + + sublayer_configs = [] + for il in range(node.attributes['n_sublayers'] - 1, -1, -1): + sub_params = {'index': node.index, 'il': il} + + for p in ['n_in_features', 'n_propagate', 'n_aggregators', 'n_out_features']: + sub_params[p] = node.attributes[p][il] + + for wname, weights in node._sublayer_weights[il].items(): + sub_params[wname] = weights.name + sub_params[f'{wname}_t'] = weights.type.name + sub_params[f'{wname}_size'] = weights.data_length + + if il != node.attributes['n_sublayers'] - 1: + sub_params['next'] = il + 1 + else: + sub_params['next'] = 0 + + sublayer_configs.append(sublayer_template.format(**sub_params)) + + params['sublayer_configs'] = '\n'.join(sublayer_configs) + + +class GarNetStackFunctionTemplate(GarNetFunctionTemplate): + def __init__(self): + super(GarNetFunctionTemplate, self).__init__(GarNetStack, include_header=garnet_include_list) + self.template = garnet_stack_function_template diff --git a/hls4ml/backends/catapult/passes/merge_templates.py b/hls4ml/backends/catapult/passes/merge_templates.py new file mode 100755 index 0000000000..ff6928679c --- /dev/null +++ b/hls4ml/backends/catapult/passes/merge_templates.py @@ -0,0 +1,106 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Concatenate, Dot, Merge + +# Merge templates + +merge_config_template = """struct config{index} : nnet::merge_config {{ + static const unsigned n_elem = {n_elem}; +}};\n""" + +merge_function_template = 'nnet::{merge}<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});' + +merge_include_list = ['nnet_utils/nnet_merge.h', 'nnet_utils/nnet_merge_stream.h'] + + +class MergeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Merge) + self.template = merge_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_elem'] = node.get_input_variable(node.inputs[0]).size_cpp() + + return self.template.format(**params) + + +class MergeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((Merge, Concatenate, Dot), include_header=merge_include_list) + self.template = merge_function_template + + def format(self, node): + params = {} + params['merge'] = node.get_attr('op').lower() + params['config'] = f'config{node.index}' + params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name + params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name + params['output_t'] = node.get_output_variable().type.name + params['input1'] = node.get_input_variable(node.inputs[0]).name + params['input2'] = node.get_input_variable(node.inputs[1]).name + params['output'] = node.get_output_variable().name + + return self.template.format(**params) + + +# Dot templates + +dot_config_template = """struct config{index} : nnet::dot_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned reuse_factor = {reuse}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in, reuse_factor); + typedef {accum_t.name} accum_t; + template + using product = nnet::product::{product_type}; +}};\n""" + + +class DotConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Dot) + self.template = dot_config_template + + def format(self, node): + inp1 = node.get_input_variable(node.inputs[0]) + inp2 = node.get_input_variable(node.inputs[1]) + params = self._default_config_params(node) + params['n_out'] = 1 + params['n_in'] = inp1.shape[0] + params['product_type'] = get_backend('catapult').product_type(inp1.type.precision, inp2.type.precision) + + return self.template.format(**params) + + +# Concatenate templates + +concat_config_template = """struct config{index} : nnet::concat_config {{ + static const unsigned n_elem1_0 = {n_elem1_0}; + static const unsigned n_elem1_1 = {n_elem1_1}; + static const unsigned n_elem1_2 = {n_elem1_2}; + static const unsigned n_elem2_0 = {n_elem2_0}; + static const unsigned n_elem2_1 = {n_elem2_1}; + static const unsigned n_elem2_2 = {n_elem2_2}; + + static const int axis = {axis}; +}};\n""" + + +class ConcatenateConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Concatenate) + self.template = concat_config_template + + def format(self, node): + params = self._default_config_params(node) + for i in range(3): + params.setdefault(f'n_elem1_{i}', 0) + params.setdefault(f'n_elem2_{i}', 0) + inp1 = node.get_input_variable(node.inputs[0]) + inp2 = node.get_input_variable(node.inputs[1]) + for i, (s1, s2) in enumerate(zip(inp1.shape, inp2.shape)): + params[f'n_elem1_{i}'] = s1 + params[f'n_elem2_{i}'] = s2 + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/pointwise.py b/hls4ml/backends/catapult/passes/pointwise.py new file mode 100755 index 0000000000..2dd982b5d4 --- /dev/null +++ b/hls4ml/backends/catapult/passes/pointwise.py @@ -0,0 +1,92 @@ +from copy import copy + +import numpy as np + +from hls4ml.backends.catapult.passes.convolution_templates import ( + Conv1DConfigTemplate, + Conv1DFunctionTemplate, + Conv2DConfigTemplate, + Conv2DFunctionTemplate, + conv1d_config_template, + conv2d_config_template, + conv_mult_config_template, +) +from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D +from hls4ml.model.layers import register_layer +from hls4ml.model.optimizer import OptimizerPass + +pointwise_conv1d_function_template = ( + 'nnet::pointwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) +pointwise_conv2d_function_template = ( + 'nnet::pointwise_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) + +sepconv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_sepconv1d_stream.h'] +sepconv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_sepconv2d_stream.h'] + + +class PointwiseConv1DConfigTemplate(Conv1DConfigTemplate): + def __init__(self): + super(Conv1DConfigTemplate, self).__init__(PointwiseConv1D) + self.template = conv1d_config_template + self.mult_template = conv_mult_config_template + + +class PointwiseConv1DFunctionTemplate(Conv1DFunctionTemplate): + def __init__(self): + super(Conv1DFunctionTemplate, self).__init__(PointwiseConv1D, include_header=sepconv1d_include_list) + self.template = pointwise_conv1d_function_template + + +class PointwiseConv2DConfigTemplate(Conv2DConfigTemplate): + def __init__(self): + super(Conv2DConfigTemplate, self).__init__(PointwiseConv2D) + self.template = conv2d_config_template + self.mult_template = conv_mult_config_template + + +class PointwiseConv2DFunctionTemplate(Conv2DFunctionTemplate): + def __init__(self): + super(Conv2DFunctionTemplate, self).__init__(PointwiseConv2D, include_header=sepconv2d_include_list) + self.template = pointwise_conv2d_function_template + + +def register_pointwise(backend): + # Register the layer types to the layer map + register_layer('PointwiseConv1D', PointwiseConv1D) + register_layer('PointwiseConv2D', PointwiseConv2D) + + # Register the optimization passes + backend.register_pass('optimize_pointwise_conv', OptimizePointwiseConv) + + # Register template passes + backend.register_template(PointwiseConv1DConfigTemplate) + backend.register_template(PointwiseConv1DFunctionTemplate) + backend.register_template(PointwiseConv2DConfigTemplate) + backend.register_template(PointwiseConv2DFunctionTemplate) + + +class OptimizePointwiseConv(OptimizerPass): + def match(self, node): + return ( + node.class_name in ('Conv1D', 'Conv2D') + and node.get_attr('filt_height', 1) == 1 + and node.get_attr('filt_width') == 1 + ) + + def transform(self, model, node): + dim = node.__class__.__name__[-2:] # '1D' or '2D' + pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy()) + if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D + expand_axis = tuple(range(int(dim[0]))) + pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis) + pw_node.weights['bias'].data = node.weights['bias'].data + # Set strategy to ensure lowercase string is passed to the template + if model.config.is_resource_strategy(pw_node): + pw_node.set_attr('strategy', 'resource') + else: + pw_node.set_attr('strategy', 'latency') + model.replace_node(node, pw_node) + + return True diff --git a/hls4ml/backends/catapult/passes/pooling_templates.py b/hls4ml/backends/catapult/passes/pooling_templates.py new file mode 100755 index 0000000000..77205a5df7 --- /dev/null +++ b/hls4ml/backends/catapult/passes/pooling_templates.py @@ -0,0 +1,109 @@ +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import GlobalPooling1D, GlobalPooling2D, Pooling1D, Pooling2D + +# Pooling templates + +pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned n_filt = {n_filt}; + static const unsigned pool_width = {pool_width}; + + static const unsigned filt_width = pool_width; + static const unsigned n_chan = n_filt; + + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const bool count_pad = {count_pad}; + static const unsigned stride_width = {stride_width}; + static const nnet::Pool_Op pool_op = nnet::{pool_op}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned reuse_factor = {reuse}; + typedef {accum_t.name} accum_t; +}};\n""" + +pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{ + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_height = {stride_height}; + static const unsigned stride_width = {stride_width}; + static const unsigned pool_height = {pool_height}; + static const unsigned pool_width = {pool_width}; + + static const unsigned filt_height = pool_height; + static const unsigned filt_width = pool_width; + static const unsigned n_chan = n_filt; + + static const unsigned out_height = {out_height}; + static const unsigned out_width = {out_width}; + static const unsigned pad_top = {pad_top}; + static const unsigned pad_bottom = {pad_bottom}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const bool count_pad = {count_pad}; + static const nnet::Pool_Op pool_op = nnet::{pool_op}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned reuse_factor = {reuse}; + typedef {accum_t.name} accum_t; +}};\n""" + +global_pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_filt = {n_filt}; + static const nnet::Pool_Op pool_op = nnet::{pool_op}; + static const unsigned reuse_factor = {reuse}; + typedef {accum_t.name} accum_t; +}};\n""" + +global_pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{ + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_filt = {n_filt}; + static const nnet::Pool_Op pool_op = nnet::{pool_op}; + static const unsigned reuse_factor = {reuse}; + typedef {accum_t.name} accum_t; +}};\n""" + +pooling1d_function_template = 'nnet::pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' +pooling2d_function_template = 'nnet::pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' +global_pooling1d_function_template = ( + 'nnet::global_pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' +) +global_pooling2d_function_template = ( + 'nnet::global_pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' +) + +pooling_include_list = ['nnet_utils/nnet_pooling.h', 'nnet_utils/nnet_pooling_stream.h'] + + +class PoolingConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D)) + self.templates = { + 'Pooling1D': pooling1d_config_template, + 'Pooling2D': pooling2d_config_template, + 'GlobalPooling1D': global_pooling1d_config_template, + 'GlobalPooling2D': global_pooling2d_config_template, + } + + def format(self, node): + params = self._default_config_params(node) + return self.templates[node.class_name].format(**params) + + +class PoolingFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((Pooling1D, Pooling2D, GlobalPooling1D, GlobalPooling2D), include_header=pooling_include_list) + self.templates = { + 'Pooling1D': pooling1d_function_template, + 'Pooling2D': pooling2d_function_template, + 'GlobalPooling1D': global_pooling1d_function_template, + 'GlobalPooling2D': global_pooling2d_function_template, + } + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + + return self.templates[node.class_name].format(**params) diff --git a/hls4ml/backends/catapult/passes/quantization_templates.py b/hls4ml/backends/catapult/passes/quantization_templates.py new file mode 100755 index 0000000000..7086b187f9 --- /dev/null +++ b/hls4ml/backends/catapult/passes/quantization_templates.py @@ -0,0 +1,36 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.catapult.passes.core_templates import ( + batchnorm_config_template, + batchnorm_function_template, + batchnorm_include_list, +) +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.optimizer.passes.qkeras import ApplyAlpha + + +class ApplyAlphaConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(ApplyAlpha) + self.template = batchnorm_config_template + + def format(self, node): + params = self._default_config_params(node) + params['n_in'] = node.get_input_variable().size_cpp() + params['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('scale').type.precision + ) + + return self.template.format(**params) + + +class ApplyAlphaFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(ApplyAlpha, include_header=batchnorm_include_list) + self.template = batchnorm_function_template + + def format(self, node): + params = self._default_function_params(node) + params['scale'] = node.get_weights('scale').name + params['bias'] = node.get_weights('bias').name + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/recurrent_templates.py b/hls4ml/backends/catapult/passes/recurrent_templates.py new file mode 100755 index 0000000000..4079f25721 --- /dev/null +++ b/hls4ml/backends/catapult/passes/recurrent_templates.py @@ -0,0 +1,175 @@ +from hls4ml.backends.backend import get_backend +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import GRU, LSTM + +# recurrent multiplication template + +recr_mult_config_template = """struct config{index} : nnet::dense_config {{ + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned strategy = nnet::{strategy}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned n_nonzeros = {nonzeros}; + static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; + static const bool store_weights_in_bram = false; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {index_t.name} index_t; + template + using product = nnet::product::{product_type}; +}};\n""" + +# activation templates + +activ_config_template = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; +}};\n""" + +recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; +}};\n""" + +# LSTM + GRU templates + +recr_config_template = """struct config{index} : nnet::{recr_type}_config {{ + typedef {accum_t.name} accum_t; + typedef {weight_t.name} weight_t; // Matrix + typedef {bias_t.name} bias_t; // Vector + typedef {config_mult_t1} mult_config1; + typedef {config_mult_t2} mult_config2; + typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE}; + template + using activation_recr = nnet::activation::{recurrent_activation}; + typedef {act_t} ACT_CONFIG_T; + template + using activation = nnet::activation::{activation}; + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned n_state = {n_state}; + static const unsigned n_sequence = {n_sequence}; + static const unsigned n_sequence_out = {n_sequence_out}; + static const unsigned io_type = nnet::{strategy}; + static const unsigned reuse_factor = {reuse}; + static const bool store_weights_in_bram = false; + static const bool use_static = {static}; +}};\n""" + +recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});' + +recr_include_list = ['nnet_utils/nnet_recurrent.h'] + + +class RecurrentConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((LSTM, GRU)) + self.template = recr_config_template + self.act_template = activ_config_template + self.recr_act_template = recr_activ_config_template + self.mult1_template = recr_mult_config_template + self.mult2_template = recr_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + + params['n_in'] = node.get_input_variable().dim_names[1] + params['n_sequence'] = node.get_input_variable().dim_names[0] + if node.get_attr('return_sequences'): + params['n_sequence_out'] = node.get_output_variable().dim_names[0] + params['n_state'] = node.get_output_variable().dim_names[1] + params['n_out'] = node.get_output_variable().dim_names[1] + else: + params['n_sequence_out'] = 1 + params['n_state'] = node.get_output_variable().dim_names[0] + params['n_out'] = node.get_output_variable().dim_names[0] + params['config_mult_t1'] = f'config{node.index}_1' + params['config_mult_t2'] = f'config{node.index}_2' + params['recr_act_t'] = '{}_config{}_recr'.format(node.get_attr('recurrent_activation'), node.index) + params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index) + params['strategy'] = node.get_attr('strategy') + params['static'] = 'true' if node.attributes['static'] else 'false' + params['recr_type'] = node.class_name.lower() + params['RECR_TYPE'] = node.class_name + + if node.class_name == 'LSTM': + n_recr_mult = 4 + else: # GRU + n_recr_mult = 3 + + recr_config = self.template.format(**params) + + act_params = self._default_config_params(node) + recr_act_params = self._default_config_params(node) + + act_params['type'] = node.get_attr('activation') + recr_act_params['type'] = node.get_attr('recurrent_activation') + if node.get_attr('return_sequences'): + act_params['n_in'] = node.get_output_variable().dim_names[1] + recr_act_params['n_in'] = node.get_output_variable().dim_names[1] + ' * %i' % (n_recr_mult - 1) + else: + act_params['n_in'] = node.get_output_variable().dim_names[0] + recr_act_params['n_in'] = node.get_output_variable().dim_names[0] + ' * %i' % (n_recr_mult - 1) + + act_config = self.act_template.format(**act_params) + recr_act_config = self.recr_act_template.format(**recr_act_params) + + mult_params1 = self._default_config_params(node) + mult_params2 = self._default_config_params(node) + + mult_params1['n_in'] = node.get_input_variable().dim_names[1] + if node.get_attr('return_sequences'): + mult_params1['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult + else: + mult_params1['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult + mult_params1['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + mult_params1['reuse'] = params['reuse'] + mult_params1['index'] = str(node.index) + '_1' + mult_params1['nzeros'] = node.get_weights('weight').nzeros + mult_params1['nonzeros'] = node.get_weights('weight').nonzeros + if node.get_attr('return_sequences'): + mult_params2['n_in'] = node.get_output_variable().dim_names[1] + mult_params2['n_out'] = node.get_output_variable().dim_names[1] + ' * %i' % n_recr_mult + else: + mult_params2['n_in'] = node.get_output_variable().dim_names[0] + mult_params2['n_out'] = node.get_output_variable().dim_names[0] + ' * %i' % n_recr_mult + mult_params2['product_type'] = get_backend('catapult').product_type( + node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision + ) + mult_params2['reuse'] = node.attributes['recurrent_reuse_factor'] + mult_params2['index'] = str(node.index) + '_2' + mult_params2['nzeros'] = node.get_weights('recurrent_weight').nzeros + mult_params2['nonzeros'] = node.get_weights('recurrent_weight').nonzeros + + mult_config1 = self.mult1_template.format(**mult_params1) + mult_config2 = self.mult2_template.format(**mult_params2) + + return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config + + +class RecurrentFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((LSTM, GRU), include_header=recr_include_list) + self.template = recr_function_template + + def format(self, node): + params = self._default_function_params(node) + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + params['wr'] = node.get_weights('recurrent_weight').name + params['br'] = node.get_weights('recurrent_bias').name + params['activation'] = node.get_attr('activation') + params['recurrent_activation'] = node.get_attr('recurrent_activation') + params['recr_type'] = node.class_name.lower() + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/reshaping_templates.py b/hls4ml/backends/catapult/passes/reshaping_templates.py new file mode 100755 index 0000000000..ec6705eb29 --- /dev/null +++ b/hls4ml/backends/catapult/passes/reshaping_templates.py @@ -0,0 +1,132 @@ +from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate +from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D + +# ZeroPadding templates + +zeropad1d_config_template = """struct config{index} : nnet::padding1d_config {{ + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned out_width = {out_width}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; +}};\n""" + +zeropad2d_config_template = """struct config{index} : nnet::padding2d_config {{ + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned out_height = {out_height}; + static const unsigned out_width = {out_width}; + static const unsigned pad_top = {pad_top}; + static const unsigned pad_bottom = {pad_bottom}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; +}};\n""" + +zeropad1d_function_template = 'nnet::zeropad1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' +zeropad2d_function_template = 'nnet::zeropad2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});' + +padding_include_list = ['nnet_utils/nnet_padding.h', 'nnet_utils/nnet_padding_stream.h'] + + +class ZeroPaddingConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__((ZeroPadding1D, ZeroPadding2D)) + self.templates = { + 'ZeroPadding1D': zeropad1d_config_template, + 'ZeroPadding2D': zeropad2d_config_template, + } + + def format(self, node): + params = self._default_config_params(node) + return self.templates[node.class_name].format(**params) + + +class ZeroPaddingFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((ZeroPadding1D, ZeroPadding2D), include_header=padding_include_list) + self.templates = { + 'ZeroPadding1D': zeropad1d_function_template, + 'ZeroPadding2D': zeropad2d_function_template, + } + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + + return self.templates[node.class_name].format(**params) + + +# Resize templates + +resize_config_template = """struct config{index} : nnet::resize_config {{ + static const unsigned height = {in_height}; + static const unsigned width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned new_height = {out_height}; + static const unsigned new_width = {out_width}; +}};\n""" + +resize_function_template = 'nnet::resize_{algorithm}<{input_t}, {config}>({input}, {output});' + +resize_include_list = ['nnet_utils/nnet_image.h', 'nnet_utils/nnet_image_stream.h'] + + +class ResizeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Resize) + self.template = resize_config_template + + def format(self, node): + params = self._default_config_params(node) + + return self.template.format(**params) + + +class ResizeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Resize, include_header=resize_include_list) + self.template = resize_function_template + + def format(self, node): + params = self._default_function_params(node) + params['algorithm'] = node.get_attr('algorithm') + + return self.template.format(**params) + + +# Transpose templates + +transpose_config_template = """struct config{index} : nnet::transpose_config {{ + static const unsigned depth = {depth}; + static const unsigned height = {height}; + static const unsigned width = {width}; + static constexpr unsigned perm[3] = {{{perm_str}}}; +}};\n""" + +transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});' + +transpose_include_list = ['nnet_utils/nnet_array.h', 'nnet_utils/nnet_stream.h'] + + +class TransposeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Transpose) + self.template = transpose_config_template + + def format(self, node): + params = self._default_config_params(node) + + return self.template.format(**params) + + +class TransposeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Transpose, include_header=transpose_include_list) + self.template = transpose_function_template + + def format(self, node): + params = self._default_function_params(node) + params['dim'] = node.get_attr('dim') + + return self.template.format(**params) diff --git a/hls4ml/backends/catapult/passes/resource_strategy.py b/hls4ml/backends/catapult/passes/resource_strategy.py new file mode 100755 index 0000000000..63e6e0b4db --- /dev/null +++ b/hls4ml/backends/catapult/passes/resource_strategy.py @@ -0,0 +1,48 @@ +import numpy as np + +from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D +from hls4ml.model.optimizer import OptimizerPass + + +class ApplyResourceStrategy(OptimizerPass): + '''Transposes the weights to use the dense_resource matrix multiply routine''' + + def match(self, node): + node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) + is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' + already_transformed = node.get_attr('_weights_transposed', False) is True + + return node_matches and is_resource_strategy and not already_transformed + + def transform(self, model, node): + if isinstance(node, Dense): + node.weights['weight'].data = np.transpose(node.weights['weight'].data) + elif isinstance(node, Conv1D): + node.weights['weight'].data = np.transpose(node.weights['weight'].data, axes=[2, 0, 1]) # (W,C,F) => (F,W,C) + elif isinstance(node, SeparableConv1D): + node.weights['depthwise'].data = np.transpose( + node.weights['depthwise'].data, axes=[2, 0, 1] + ) # (W,C,F) => (F,W,C) + node.weights['pointwise'].data = np.transpose( + node.weights['pointwise'].data, axes=[2, 0, 1] + ) # (W,C,F) => (F,W,C) + elif isinstance(node, Conv2D): + node.weights['weight'].data = np.transpose( + node.weights['weight'].data, axes=[3, 0, 1, 2] + ) # (H,W,C,F) => (F,H,W,C) + elif isinstance(node, SeparableConv2D): + node.weights['depthwise'].data = np.transpose( + node.weights['depthwise'].data, axes=[3, 0, 1, 2] + ) # (H,W,C,F) => (F,H,W,C) + node.weights['pointwise'].data = np.transpose( + node.weights['pointwise'].data, axes=[3, 0, 1, 2] + ) # (H,W,C,F) => (F,H,W,C) + elif isinstance(node, (LSTM, GRU)): + node.weights['weight'].data = np.transpose(node.weights['weight'].data) + node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data) + else: + raise Exception(f'Unexpected layer {node.class_name} with resource strategy') + + node.set_attr('_weights_transposed', True) + + return False diff --git a/hls4ml/backends/catapult/passes/transform_types.py b/hls4ml/backends/catapult/passes/transform_types.py new file mode 100755 index 0000000000..4ef3548cb6 --- /dev/null +++ b/hls4ml/backends/catapult/passes/transform_types.py @@ -0,0 +1,52 @@ +from hls4ml.backends.fpga.fpga_types import ( + ACTypeConverter, + CatapultArrayVariableConverter, + CatapultInplaceArrayVariableConverter, + CatapultInplaceStreamVariableConverter, + CatapultStreamVariableConverter, + HLSTypeConverter, + StaticWeightVariableConverter, +) +from hls4ml.model.optimizer import GlobalOptimizerPass +from hls4ml.model.types import InplaceTensorVariable + + +class TransformTypes(GlobalOptimizerPass): + def __init__(self): + self.type_converter = HLSTypeConverter(precision_converter=ACTypeConverter()) + self.array_var_converter = CatapultArrayVariableConverter(type_converter=self.type_converter) + self.inplace_array_var_converter = CatapultInplaceArrayVariableConverter(type_converter=self.type_converter) + self.stream_var_converter = CatapultStreamVariableConverter(type_converter=self.type_converter) + self.inplace_stream_var_converter = CatapultInplaceStreamVariableConverter(type_converter=self.type_converter) + self.weight_var_converter = StaticWeightVariableConverter(type_converter=self.type_converter) + + def transform(self, model, node): + io_type = node.model.config.get_config_value('IOType') + + for out_name, var in node.variables.items(): + if io_type == 'io_stream': + if isinstance(var, InplaceTensorVariable): + new_var = self.inplace_stream_var_converter.convert(var) + else: + new_var = self.stream_var_converter.convert(var) + elif io_type == 'io_serial': + new_var = self.array_var_converter.convert(var, pragma='stream') + elif io_type == 'io_parallel': + if out_name in node.model.inputs: + new_var = self.array_var_converter.convert(var, pragma='reshape') + elif isinstance(var, InplaceTensorVariable): + new_var = self.inplace_array_var_converter.convert(var, pragma='') + else: + new_var = self.array_var_converter.convert(var, pragma='partition') + else: + raise Exception(f'Unknown IOType {io_type} in {node.name} ({node.__class__.__name__})') + + node.set_attr(out_name, new_var) + + for w_name, weight in node.weights.items(): + new_weight = self.weight_var_converter.convert(weight) + node.set_attr(w_name, new_weight) + + for t_name, type in node.types.items(): + new_type = self.type_converter.convert(type) + node.set_attr(t_name, new_type) diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py index c5327dab8c..408f1320e4 100644 --- a/hls4ml/backends/fpga/fpga_types.py +++ b/hls4ml/backends/fpga/fpga_types.py @@ -248,6 +248,13 @@ def definition_cpp(self, name_suffix='', as_reference=False): ) +class CatapultArrayVariableDefinition(VariableDefinition): + def definition_cpp(self, name_suffix='', as_reference=False): + return '{type} {name}{suffix}[{shape}] /* {pragma} */'.format( + type=self.type.name, name=self.name, suffix=name_suffix, shape=self.size_cpp(), pragma=self.pragma + ) + + class VivadoInplaceArrayVariableDefinition(VariableDefinition): def definition_cpp(self): return f'auto& {self.name} = {self.input_var.name}' @@ -258,6 +265,11 @@ def definition_cpp(self): return f'auto& {self.name} = {self.input_var.name}' +class CatapultInplaceArrayVariableDefinition(VariableDefinition): + def definition_cpp(self): + return f'auto& {self.name} = {self.input_var.name}' + + class ArrayVariableConverter: def __init__(self, type_converter, prefix, definition_cls): self.type_converter = type_converter @@ -285,6 +297,11 @@ def __init__(self, type_converter): super().__init__(type_converter=type_converter, prefix='Quartus', definition_cls=QuartusArrayVariableDefinition) +class CatapultArrayVariableConverter(ArrayVariableConverter): + def __init__(self, type_converter): + super().__init__(type_converter=type_converter, prefix='Catapult', definition_cls=CatapultArrayVariableDefinition) + + class VivadoInplaceArrayVariableConverter(ArrayVariableConverter): def __init__(self, type_converter): super().__init__(type_converter=type_converter, prefix='Vivado', definition_cls=VivadoInplaceArrayVariableDefinition) @@ -297,6 +314,13 @@ def __init__(self, type_converter): ) +class CatapultInplaceArrayVariableConverter(ArrayVariableConverter): + def __init__(self, type_converter): + super().__init__( + type_converter=type_converter, prefix='Catapult', definition_cls=CatapultInplaceArrayVariableDefinition + ) + + # endregion # region StructMemberVariable @@ -309,6 +333,13 @@ def definition_cpp(self, name_suffix='', as_reference=False): ) +class CatapultStructMemberVariableDefinition(VariableDefinition): + def definition_cpp(self, name_suffix='', as_reference=False): + return '{type} {name}{suffix}[{shape}]'.format( + type=self.type.name, name=self.member_name, suffix=name_suffix, shape=self.size_cpp() + ) + + class StructMemberVariableConverter: def __init__(self, type_converter, prefix, definition_cls): self.type_converter = type_converter @@ -338,6 +369,13 @@ def __init__(self, type_converter): ) +class CatapultStructMemberVariableConverter(StructMemberVariableConverter): + def __init__(self, type_converter): + super().__init__( + type_converter=type_converter, prefix='Catapult', definition_cls=CatapultStructMemberVariableDefinition + ) + + # endregion # region StreamVariable @@ -371,6 +409,21 @@ def definition_cpp(self): return f'auto& {self.name} = {self.input_var.name}' +class CatapultStreamVariableDefinition(VariableDefinition): + def definition_cpp(self, name_suffix='', as_reference=False): + if as_reference: # Function parameter + return f'ac_channel<{self.type.name}> &{self.name}{name_suffix}' + else: # Declaration (string name arg not implemented in ac_channel) + return 'ac_channel<{type}> {name}{suffix}/*("{name}")*/'.format( + type=self.type.name, name=self.name, suffix=name_suffix + ) + + +class CatapultInplaceStreamVariableDefinition(VariableDefinition): + def definition_cpp(self): + return f'auto& {self.name} = {self.input_var.name}' + + class StreamVariableConverter: def __init__(self, type_converter, prefix, definition_cls): self.type_converter = type_converter @@ -402,6 +455,11 @@ def __init__(self, type_converter): super().__init__(type_converter=type_converter, prefix='Quartus', definition_cls=QuartusStreamVariableDefinition) +class CatapultStreamVariableConverter(StreamVariableConverter): + def __init__(self, type_converter): + super().__init__(type_converter=type_converter, prefix='Catapult', definition_cls=CatapultStreamVariableDefinition) + + # endregion # region InplaceStreamVariable @@ -435,6 +493,13 @@ def __init__(self, type_converter): ) +class CatapultInplaceStreamVariableConverter(InplaceStreamVariableConverter): + def __init__(self, type_converter): + super().__init__( + type_converter=type_converter, prefix='Catapult', definition_cls=CatapultInplaceStreamVariableDefinition + ) + + # endregion # region WeightsVariable diff --git a/hls4ml/converters/__init__.py b/hls4ml/converters/__init__.py index b69dbec0f0..3bd6d06c3b 100644 --- a/hls4ml/converters/__init__.py +++ b/hls4ml/converters/__init__.py @@ -196,7 +196,7 @@ def convert_from_keras_model( output_data_tb (str, optional): String representing the path of output data in .npy or .dat format that will be used during csim and cosim. backend (str, optional): Name of the backend to use, e.g., 'Vivado' - or 'Quartus'. + or 'Quartus' or 'Catapult'. board (str, optional): One of target boards specified in `supported_board.json` file. If set to `None` a default device of a backend will be used. See documentation of the backend used. part (str, optional): The FPGA part. If set to `None` a default part of a backend will be used. @@ -258,7 +258,7 @@ def convert_from_pytorch_model( used during csim and cosim. Defaults to None. output_data_tb (str, optional): String representing the path of output data in .npy or .dat format that will be used during csim and cosim. Defaults to None. - backend (str, optional): Name of the backend to use, e.g., 'Vivado' or 'Quartus'. Defaults to 'Vivado'. + backend (str, optional): Name of the backend to use, e.g., 'Vivado' or 'Quartus' or 'Catapult'. Defaults to 'Vivado'. board (str, optional): One of target boards specified in `supported_board.json` file. If set to `None` a default device of a backend will be used. See documentation of the backend used. part (str, optional): The FPGA part. If set to `None` a default part of a backend will be used. @@ -332,7 +332,7 @@ def convert_from_onnx_model( output_data_tb (str, optional): String representing the path of output data in .npy or .dat format that will be used during csim and cosim. backend (str, optional): Name of the backend to use, e.g., 'Vivado' - or 'Quartus'. + or 'Quartus' or 'Catapult'. board (str, optional): One of target boards specified in `supported_board.json` file. If set to `None` a default device of a backend will be used. See documentation of the backend used. part (str, optional): The FPGA part. If set to `None` a default part of a backend will be used. diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index a6b5c29e89..04ec33294d 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -60,6 +60,12 @@ def get_config_value(self, key, default=None): def get_project_name(self): return self.get_config_value('ProjectName') + def get_project_dir(self): + if self.get_config_value('ProjectDir') is not None: + return self.get_config_value('ProjectDir') + else: + return self.get_config_value('ProjectName') + '_prj' + def get_output_dir(self): return self.get_config_value('OutputDir') diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index b74918f642..de191baa40 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -560,6 +560,7 @@ def initialize(self): if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in [ 'Vivado', 'VivadoAccelerator', + 'Catapult', ]: self.weights['weight'].data_unquantized = np.transpose(folded_weights, axes=[3, 0, 1, 2]) self.weights['weight'].data = self.get_attr('weight_quantizer')(self.weights['weight'].data_unquantized) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 904ecc3d35..c4a2f57051 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -589,6 +589,7 @@ def get_ymodel_keras(keras_model, X): name = layer.name if ( hasattr(layer, "activation") + and hasattr(layer.activation, "__name__") and layer.activation.__name__ != "linear" and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation)) ): diff --git a/hls4ml/report/__init__.py b/hls4ml/report/__init__.py index b73558f6ee..3c9b7707b7 100644 --- a/hls4ml/report/__init__.py +++ b/hls4ml/report/__init__.py @@ -1,3 +1,6 @@ +from hls4ml.report.catapult_report import parse_catapult_report # noqa: F401 +from hls4ml.report.catapult_report import qofr # noqa: F401 +from hls4ml.report.catapult_report import read_catapult_report # noqa: F401 from hls4ml.report.quartus_report import parse_quartus_report # noqa: F401 from hls4ml.report.quartus_report import read_quartus_report # noqa: F401 from hls4ml.report.vivado_report import parse_vivado_report # noqa: F401 diff --git a/hls4ml/report/catapult_report.py b/hls4ml/report/catapult_report.py new file mode 100755 index 0000000000..563a3a7594 --- /dev/null +++ b/hls4ml/report/catapult_report.py @@ -0,0 +1,256 @@ +import os +import re + +import yaml + + +def read_catapult_report(hls_dir, full_report=False): + if not os.path.exists(hls_dir): + print(f'Path {hls_dir} does not exist. Exiting.') + return + + prj_dir = None + top_func_name = None + + if os.path.isfile(hls_dir + '/build_prj.tcl'): + prj_dir, top_func_name = _parse_build_script(hls_dir + '/build_prj.tcl') + print('Prj Dir:', prj_dir) + print('Top func name:', top_func_name) + + if prj_dir is None or top_func_name is None: + print('Unable to read project data. Exiting.') + return + + sln_dir = hls_dir + '/' + prj_dir + if not os.path.exists(sln_dir): + print(f'Project {prj_dir} does not exist. Rerun "hls4ml build -p {hls_dir}".') + return + + solutions = _find_solutions(sln_dir, hls_dir) + + for sln in solutions: + print(f'Reports for solution "{sln}":\n') + _find_reports(sln_dir + '/' + sln, top_func_name, full_report) + + +def _parse_build_script(script_path): + prj_dir = None + top_func_name = None + + with open(script_path) as f: + for line in f.readlines(): + if 'project new' in line: + prj_dir = line.split()[-1] + if 'set design_top' in line: + top_func_name = line.split()[-1] + + return prj_dir, top_func_name + + +def _find_solutions(sln_dir, hls_dir): + solutions = [] + prj_dir, top_func_name = _parse_build_script(hls_dir + '/build_prj.tcl') + for path in os.listdir(sln_dir): + # check if current path is a dir + if os.path.isdir(os.path.join(sln_dir, path)): + pathstring = str(path) + if top_func_name in pathstring: + solutions.append(pathstring) + return solutions + + +def _find_reports(sln_dir, top_func_name, full_report=False): + csim_file = sln_dir + '/../../tb_data/csim_results.log' + if os.path.isfile(csim_file): + _show_csim_report(csim_file) + else: + print('C simulation report not found.') + + syn_file = sln_dir + '/rtl.rpt' + if os.path.isfile(syn_file): + _show_synth_report(syn_file, full_report) + else: + print('Synthesis report not found.') + + cosim_file = sln_dir + f'/sim/report/{top_func_name}_cosim.rpt' + if os.path.isfile(cosim_file): + _show_cosim_report(cosim_file) + else: + print('Co-simulation report not found.') + + timing_report = sln_dir + '/vivado_concat_v/timing_summary_synth.rpt' + if os.path.isfile(timing_report): + _show_timing_report(timing_report) + else: + print('Timing synthesis report not found.') + + utilization_report = sln_dir + '/vivado_concat_v/utilization_synth.rpt' + if os.path.isfile(utilization_report): + _show_utilization_report(utilization_report) + else: + print('Utilization synthesis report not found.') + + +def _show_csim_report(csim_file): + with open(csim_file) as f: + print('C SIMULATION RESULT:') + print(f.read()) + + +def _show_synth_report(synth_file, full_report=False): + with open(synth_file) as f: + print('SYNTHESIS REPORT:') + for line in f.readlines()[2:]: + if not full_report and '* DSP48' in line: + break + print(line, end='') + + +def _show_cosim_report(cosim_file): + with open(cosim_file) as f: + print('CO-SIMULATION RESULT:') + print(f.read()) + + +def _show_timing_report(timing_report): + with open(timing_report) as f: + print('TIMING REPORT:') + print(f.read()) + + +def _show_utilization_report(utilization_report): + with open(utilization_report) as f: + print('UTILIZATION REPORT:') + print(f.read()) + + +def _get_abs_and_percentage_values(unparsed_cell): + return int(unparsed_cell.split('(')[0]), float(unparsed_cell.split('(')[1].replace('%', '').replace(')', '')) + + +def parse_catapult_report(output_dir): + if not os.path.exists(output_dir): + print(f'Project OutputDir {output_dir} does not exist. Exiting.') + return + + # Read the YAML config file to determine the project settings + with open(output_dir + '/hls4ml_config.yml') as yfile: + ydata = yaml.safe_load(yfile) + + if not ydata['ProjectDir'] is None: + ProjectDir = ydata['ProjectDir'] + else: + ProjectDir = ydata['ProjectName'] + '_prj' + ProjectName = ydata['ProjectName'] + + sln_dir = output_dir + '/' + ProjectDir + if not os.path.exists(sln_dir): + print(f'Project {ProjectDir} does not exist. Rerun "hls4ml build -p {output_dir}".') + return + + solutions = _find_solutions(sln_dir, output_dir) + if len(solutions) > 1: + print(f'WARNING: Found {len(solutions)} solution(s) in {sln_dir}. Using the first solution.') + + report = {} + + sim_file = output_dir + '/tb_data/csim_results.log' + if os.path.isfile(sim_file): + csim_results = [] + with open(sim_file) as f: + for line in f.readlines(): + csim_results.append([r for r in line.split()]) + report['CSimResults'] = csim_results + + util_report_file = output_dir + '/' + ProjectDir + '/' + solutions[0] + '/vivado_concat_v/utilization_synth.rpt' + if os.path.isfile(util_report_file): + util_report = {} + a = 0 + with open(util_report_file) as f: + for line in f.readlines(): + # Sometimes, phrases such as 'CLB Registers' can show up in the non-tabular sections of the report + if '|' in line: + if ('CLB LUTs' in line) and (a == 0): + a += 1 + util_report['LUT'] = line.split('|')[2].strip() + elif ('CLB Registers' in line) and (a == 1): + a += 1 + util_report['FF'] = line.split('|')[2].strip() + elif ('RAMB18 ' in line) and (a == 2): + a += 1 + util_report['BRAM_18K'] = line.split('|')[2].strip() + elif ('DSPs' in line) and (a == 3): + a += 1 + util_report['DSP48E'] = line.split('|')[2].strip() + elif ('URAM' in line) and (a == 4): + a += 1 + util_report['URAM'] = line.split('|')[2].strip() + report['UtilizationReport'] = util_report + else: + print('Utilization report not found.') + + timing_report_file = output_dir + '/' + ProjectDir + '/' + solutions[0] + '/vivado_concat_v/timing_summary_synth.rpt' + if os.path.isfile(timing_report_file): + timing_report = {} + with open(timing_report_file) as f: + while not re.search('WNS', next(f)): + pass + # skip the successive line + next(f) + result = next(f).split() + + timing_report['WNS'] = float(result[0]) + timing_report['TNS'] = float(result[1]) + timing_report['WHS'] = float(result[4]) + timing_report['THS'] = float(result[5]) + timing_report['WPWS'] = float(result[8]) + timing_report['TPWS'] = float(result[9]) + + report['TimingReport'] = timing_report + else: + print('Timing report not found.') + + latest_prj_dir = get_latest_project_prj_directory(output_dir, ProjectDir) + latest_ver_dir = get_latest_project_version_directory(latest_prj_dir, ProjectName) + file_path = os.path.join(latest_ver_dir, 'nnet_layer_results.txt') + print('Results in nnet_layer_results.txt from:', file_path) + + # Initialize the array + report['PerLayerQOFR'] = [] + # Open the file and read its contents + with open(file_path) as file: + # Read each line and append it to the list + for line in file: + report['PerLayerQOFR'].append(line.strip()) # strip() removes leading/trailing + + return report + + +def get_latest_project_version_directory(base_path, ProjectName): + versions = [d for d in os.listdir(base_path) if d.startswith(ProjectName + '.v')] + if not versions: + raise FileNotFoundError('Error: No versions found.') + latest_version = max(versions) + return os.path.join(base_path, latest_version) + + +def get_latest_project_prj_directory(base_path, ProjectDir): + versions = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d)) and d.startswith(ProjectDir)] + if not versions: + raise FileNotFoundError('Error: No versions found.') + latest_version = max(versions) + return os.path.join(base_path, latest_version) + + +def qofr(report): + # Access the PerLayerQOFR list from the report dictionary + PerLayerQOFR = report.get('PerLayerQOFR', []) + + # Check if the list is not empty + if PerLayerQOFR: + # print('Results in nnet_layer_results.txt:') + # Iterate over each line in the list and print it + for line in PerLayerQOFR: + print(line) + else: + print('No results found in nnet_layer_results.txt') diff --git a/hls4ml/templates/catapult/ac_math b/hls4ml/templates/catapult/ac_math new file mode 160000 index 0000000000..3696be957d --- /dev/null +++ b/hls4ml/templates/catapult/ac_math @@ -0,0 +1 @@ +Subproject commit 3696be957d0b0fa0a285f90382d75c8a521d77ee diff --git a/hls4ml/templates/catapult/ac_simutils b/hls4ml/templates/catapult/ac_simutils new file mode 160000 index 0000000000..9dfe23415c --- /dev/null +++ b/hls4ml/templates/catapult/ac_simutils @@ -0,0 +1 @@ +Subproject commit 9dfe23415cf670ed7c990d9a6a237d06a5a62e57 diff --git a/hls4ml/templates/catapult/ac_types b/hls4ml/templates/catapult/ac_types new file mode 160000 index 0000000000..134dcb1a05 --- /dev/null +++ b/hls4ml/templates/catapult/ac_types @@ -0,0 +1 @@ +Subproject commit 134dcb1a05e16f242de593b9c9a33f6aa08c66e6 diff --git a/hls4ml/templates/catapult/build_lib.sh b/hls4ml/templates/catapult/build_lib.sh new file mode 100755 index 0000000000..2c7a11c626 --- /dev/null +++ b/hls4ml/templates/catapult/build_lib.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +CC=g++ +if [[ "$OSTYPE" == "linux-gnu" ]]; then + CFLAGS="-O3 -fPIC -std=c++11 -fno-gnu-unique" +elif [[ "$OSTYPE" == "linux"* ]]; then + CFLAGS="-O3 -fPIC -std=c++11 -fno-gnu-unique -Wno-pragmas" +elif [[ "$OSTYPE" == "darwin"* ]]; then + CFLAGS="-O3 -fPIC -std=c++11" +fi +LDFLAGS= + +# Pick up AC libraries from Catapult install first +INCFLAGS="-I$MGC_HOME/shared/include -I$MGC_HOME/shared/include/nnet_utils -Ifirmware/ac_types/include -Ifirmware/ac_math/include -Ifirmware/ac_simutils/include -Ifirmware/nnet_utils" +PROJECT=myproject +LIB_STAMP=mystamp + +${CC} ${CFLAGS} ${INCFLAGS} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o +${CC} ${CFLAGS} ${INCFLAGS} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o +${CC} ${CFLAGS} ${INCFLAGS} -shared ${PROJECT}.o ${PROJECT}_bridge.o -o firmware/${PROJECT}-${LIB_STAMP}.so +rm -f *.o diff --git a/hls4ml/templates/catapult/build_prj.tcl b/hls4ml/templates/catapult/build_prj.tcl new file mode 100755 index 0000000000..7ee4d640dd --- /dev/null +++ b/hls4ml/templates/catapult/build_prj.tcl @@ -0,0 +1,356 @@ +################# +# HLS4ML +################# +array set opt { + reset 0 + csim 0 + synth 1 + cosim 0 + validation 0 + vhdl 1 + verilog 1 + export 0 + vsynth 0 + bitfile 0 + fifo_opt 0 + ran_frame 2 + sw_opt 0 + power 0 + da 0 + bup 0 +} + +# Get pathname to this script to use as dereference path for relative file pathnames +set sfd [file dirname [info script]] + +if { [info exists ::argv] } { + foreach arg $::argv { + foreach {optname optval} [split $arg '='] {} + if { [info exists opt($optname)] } { + if {[string is integer -strict $optval]} { + set opt($optname) $optval + } else { + set opt($optname) [string is true -strict $optval] + } + } + } +} + +puts "***** INVOKE OPTIONS *****" +foreach x [lsort [array names opt]] { + puts "[format { %-20s %s} $x $opt($x)]" +} +puts "" + +proc report_time { op_name time_start time_end } { + set time_taken [expr $time_end - $time_start] + set time_s [expr ($time_taken / 1000) % 60] + set time_m [expr ($time_taken / (1000*60)) % 60] + set time_h [expr ($time_taken / (1000*60*60)) % 24] + puts "***** ${op_name} COMPLETED IN ${time_h}h${time_m}m${time_s}s *****" +} + +proc setup_xilinx_part { part } { + # Map Xilinx PART into Catapult library names + set part_sav $part + set libname [lindex [library get /CONFIG/PARAMETERS/Vivado/PARAMETERS/Xilinx/PARAMETERS/*/PARAMETERS/*/PARAMETERS/$part/LIBRARIES/*/NAME -match glob -ret v] 0] + puts "Library Name: $libname" + if { [llength $libname] == 1 } { + set libpath [library get /CONFIG/PARAMETERS/Vivado/PARAMETERS/Xilinx/PARAMETERS/*/PARAMETERS/*/PARAMETERS/$part/LIBRARIES/*/NAME -match glob -ret p] + puts "Library Path: $libpath" + if { [regexp {/CONFIG/PARAMETERS/(\S+)/PARAMETERS/(\S+)/PARAMETERS/(\S+)/PARAMETERS/(\S+)/PARAMETERS/(\S+)/.*} $libpath dummy rtltool vendor family speed part] } { + solution library add $libname -- -rtlsyntool $rtltool -vendor $vendor -family $family -speed $speed -part $part_sav + } else { + solution library add $libname -- -rtlsyntool Vivado + } + } else { + logfile message "Could not find specific Xilinx base library for part '$part'. Using KINTEX-u\n" warning + solution library add mgc_Xilinx-KINTEX-u-2_beh -- -rtlsyntool Vivado -manufacturer Xilinx -family KINTEX-u -speed -2 -part xcku115-flvb2104-2-i + } + solution library add Xilinx_RAMS + solution library add Xilinx_ROMS + solution library add Xilinx_FIFO +} + + +proc setup_asic_libs { args } { + set do_saed 0 + foreach lib $args { + solution library add $lib -- -rtlsyntool DesignCompiler + if { [lsearch -exact {saed32hvt_tt0p78v125c_beh saed32lvt_tt0p78v125c_beh saed32rvt_tt0p78v125c_beh} $lib] != -1 } { + set do_saed 1 + } + } + solution library add ccs_sample_mem + solution library add ccs_sample_rom + solution library add hls4ml_lib + go libraries + + # special exception for SAED32 for use in power estimation + if { $do_saed } { + # SAED32 selected - enable DC settings to access Liberty data for power estimation + source [application get /SYSTEM/ENV_MGC_HOME]/pkgs/siflibs/saed/setup_saedlib.tcl + } +} + +options set Input/CppStandard {c++17} +options set Input/CompilerFlags -DRANDOM_FRAMES=$opt(ran_frame) +options set Input/SearchPath {$MGC_HOME/shared/include/nnet_utils} -append +options set ComponentLibs/SearchPath {$MGC_HOME/shared/pkgs/ccs_hls4ml} -append + +if {$opt(reset)} { + project load CATAPULT_DIR.ccs + go new +} else { + project new -name CATAPULT_DIR +} + +#-------------------------------------------------------- +# Configure Catapult Options +# downgrade HIER-10 +options set Message/ErrorOverride HIER-10 -remove +solution options set Message/ErrorOverride HIER-10 -remove + +if {$opt(vhdl)} { + options set Output/OutputVHDL true +} else { + options set Output/OutputVHDL false +} +if {$opt(verilog)} { + options set Output/OutputVerilog true +} else { + options set Output/OutputVerilog false +} + +#-------------------------------------------------------- +# Configure Catapult Flows +if { [info exists ::env(XILINX_PCL_CACHE)] } { +options set /Flows/Vivado/PCL_CACHE $::env(XILINX_PCL_CACHE) +solution options set /Flows/Vivado/PCL_CACHE $::env(XILINX_PCL_CACHE) +} + +# Turn on HLS4ML flow (wrapped in a cache so that older Catapult installs still work) +catch {flow package require /HLS4ML} + +# Turn on SCVerify flow +flow package require /SCVerify +# flow package option set /SCVerify/INVOKE_ARGS {$sfd/firmware/weights $sfd/tb_data/tb_input_features.dat $sfd/tb_data/tb_output_predictions.dat} +#hls-fpga-machine-learning insert invoke_args + +# Turn on VSCode flow +# flow package require /VSCode +# To launch VSCode on the C++ HLS design: +# cd my-Catapult-test +# code Catapult.code-workspace + +#-------------------------------------------------------- +# Start of HLS script +set design_top myproject +solution file add $sfd/firmware/myproject.cpp +solution file add $sfd/myproject_test.cpp -exclude true + +# Parse parameters.h to determine config info to control directives/pragmas +set IOType io_stream +if { ![file exists $sfd/firmware/parameters.h] } { + logfile message "Could not locate firmware/parameters.h. Unable to determine network configuration.\n" warning +} else { + set pf [open "$sfd/firmware/parameters.h" "r"] + while {![eof $pf]} { + gets $pf line + if { [string match {*io_type = nnet::io_stream*} $line] } { + set IOType io_stream + break + } + } + close $pf +} + +if { $IOType == "io_stream" } { +solution options set Architectural/DefaultRegisterThreshold 2050 +} +directive set -RESET_CLEARS_ALL_REGS no +# Constrain arrays to map to memory only over a certain size +directive set -MEM_MAP_THRESHOLD [expr 2048 * 16 + 1] +# The following line gets modified by the backend writer +set hls_clock_period 5 + +go analyze + +# NORMAL TOP DOWN FLOW +if { ! $opt(bup) } { + +go compile + +if {$opt(csim)} { + puts "***** C SIMULATION *****" + set time_start [clock clicks -milliseconds] + flow run /SCVerify/launch_make ./scverify/Verify_orig_cxx_osci.mk {} SIMTOOL=osci sim + set time_end [clock clicks -milliseconds] + report_time "C SIMULATION" $time_start $time_end +} + +puts "***** SETTING TECHNOLOGY LIBRARIES *****" +#hls-fpga-machine-learning insert techlibs + +directive set -CLOCKS [list clk [list -CLOCK_PERIOD $hls_clock_period -CLOCK_EDGE rising -CLOCK_OFFSET 0.000000 -CLOCK_UNCERTAINTY 0.0 -RESET_KIND sync -RESET_SYNC_NAME rst -RESET_SYNC_ACTIVE high -RESET_ASYNC_NAME arst_n -RESET_ASYNC_ACTIVE low -ENABLE_NAME {} -ENABLE_ACTIVE high]] + +if {$opt(synth)} { + puts "***** C/RTL SYNTHESIS *****" + set time_start [clock clicks -milliseconds] + + go assembly + + go architect + + go allocate + + go schedule + + go extract + set time_end [clock clicks -milliseconds] + report_time "C/RTL SYNTHESIS" $time_start $time_end +} + +# BOTTOM-UP FLOW +} else { + # Start at 'go analyze' + go analyze + + # Build the design bottom-up + directive set -CLOCKS [list clk [list -CLOCK_PERIOD $hls_clock_period -CLOCK_EDGE rising -CLOCK_OFFSET 0.000000 -CLOCK_UNCERTAINTY 0.0 -RESET_KIND sync -RESET_SYNC_NAME rst -RESET_SYNC_ACTIVE high -RESET_ASYNC_NAME arst_n -RESET_ASYNC_ACTIVE low -ENABLE_NAME {} -ENABLE_ACTIVE high]] + + set blocks [solution get /HIERCONFIG/USER_HBS/*/RESOLVED_NAME -match glob -rec 1 -ret v -state analyze] + set bu_mappings {} + set top [lindex $blocks 0] + foreach block [lreverse [lrange $blocks 1 end]] { + # skip blocks that are net nnet:: functions + if { [string match {nnet::*} $block] == 0 } { continue } + go analyze + solution design set $block -top + go compile + solution library remove * + puts "***** SETTING TECHNOLOGY LIBRARIES *****" +#hls-fpga-machine-learning insert techlibs + go extract + set block_soln "[solution get /TOP/name -checkpath 0].[solution get /VERSION -checkpath 0]" + lappend bu_mappings [solution get /CAT_DIR] /$top/$block "\[Block\] $block_soln" + } + + # Move to top design + go analyze + solution design set $top -top + go compile + + if {$opt(csim)} { + puts "***** C SIMULATION *****" + set time_start [clock clicks -milliseconds] + flow run /SCVerify/launch_make ./scverify/Verify_orig_cxx_osci.mk {} SIMTOOL=osci sim + set time_end [clock clicks -milliseconds] + report_time "C SIMULATION" $time_start $time_end + } + foreach {d i l} $bu_mappings { + logfile message "solution options set ComponentLibs/SearchPath $d -append\n" info + solution options set ComponentLibs/SearchPath $d -append + } + + # Add bottom-up blocks + puts "***** SETTING TECHNOLOGY LIBRARIES *****" + solution library remove * +#hls-fpga-machine-learning insert techlibs + # need to revert back to go compile + go compile + foreach {d i l} $bu_mappings { + logfile message "solution library add [list $l]\n" info + eval solution library add [list $l] + } + go libraries + + # Map to bottom-up blocks + foreach {d i l} $bu_mappings { + # Make sure block exists + set cnt [directive get $i/* -match glob -checkpath 0 -ret p] + if { $cnt != {} } { + logfile message "directive set $i -MAP_TO_MODULE [list $l]\n" info + eval directive set $i -MAP_TO_MODULE [list $l] + } + } + go assembly + set design [solution get -name] + logfile message "Adjusting FIFO_DEPTH for top-level interconnect channels\n" warning + # FIFO interconnect between layers + foreach ch_fifo_m2m [directive get -match glob -checkpath 0 -ret p $design/*_out:cns/MAP_TO_MODULE] { + set ch_fifo [join [lrange [split $ch_fifo_m2m '/'] 0 end-1] /]/FIFO_DEPTH + logfile message "directive set -match glob $ch_fifo 1\n" info + directive set -match glob "$ch_fifo" 1 + } + # For bypass paths - the depth will likely need to be larger than 1 + foreach ch_fifo_m2m [directive get -match glob -checkpath 0 -ret p $design/*_cpy*:cns/MAP_TO_MODULE] { + set ch_fifo [join [lrange [split $ch_fifo_m2m '/'] 0 end-1] /]/FIFO_DEPTH + logfile message "Bypass FIFO '$ch_fifo' depth set to 1 - larger value may be required to prevent deadlock\n" warning + logfile message "directive set -match glob $ch_fifo 1\n" info + directive set -match glob "$ch_fifo" 1 + } + go architect + go allocate + go schedule + go dpfsm + go extract +} + +project save + +if {$opt(cosim) || $opt(validation)} { + if {$opt(verilog)} { + flow run /SCVerify/launch_make ./scverify/Verify_rtl_v_msim.mk {} SIMTOOL=msim sim + } + if {$opt(vhdl)} { + flow run /SCVerify/launch_make ./scverify/Verify_rtl_vhdl_msim.mk {} SIMTOOL=msim sim + } +} + +if {$opt(export)} { + puts "***** EXPORT IP *****" + set time_start [clock clicks -milliseconds] +# Not yet implemented. Do we need to include value of $version ? +# flow package option set /Vivado/BoardPart xilinx.com:zcu102:part0:3.1 +# flow package option set /Vivado/IP_Taxonomy {/Catapult} +# flow run /Vivado/launch_package_ip -shell ./vivado_concat_v/concat_v_package_ip.tcl + set time_end [clock clicks -milliseconds] + report_time "EXPORT IP" $time_start $time_end +} +if {$opt(sw_opt)} { + puts "***** Pre Power Optimization *****" + go switching + if {$opt(verilog)} { + flow run /PowerAnalysis/report_pre_pwropt_Verilog + } + if {$opt(vhdl)} { + flow run /PowerAnalysis/report_pre_pwropt_VHDL + } +} + +if {$opt(power)} { + puts "***** Power Optimization *****" + go power +} + +if {$opt(vsynth)} { + puts "***** VIVADO SYNTHESIS *****" + set time_start [clock clicks -milliseconds] + flow run /Vivado/synthesize -shell vivado_concat_v/concat_rtl.v.xv + set time_end [clock clicks -milliseconds] + report_time "VIVADO SYNTHESIS" $time_start $time_end +} + +if {$opt(bitfile)} { + puts "***** Option bitfile not supported yet *****" +} + +if {$opt(da)} { + puts "***** Launching DA *****" + flow run /DesignAnalyzer/launch +} + +if { [catch {flow package present /HLS4ML}] == 0 } { + flow run /HLS4ML/collect_reports +} diff --git a/hls4ml/templates/catapult/catapult_synth.tcl b/hls4ml/templates/catapult/catapult_synth.tcl new file mode 100644 index 0000000000..6d80a33ef5 --- /dev/null +++ b/hls4ml/templates/catapult/catapult_synth.tcl @@ -0,0 +1,3 @@ +add_files myproject_prj/solution1/syn/vhdl +synth_design -top myproject -part xcku115-flvb2104-2-i +report_utilization -file vivado_synth.rpt diff --git a/hls4ml/templates/catapult/firmware/defines.h b/hls4ml/templates/catapult/firmware/defines.h new file mode 100755 index 0000000000..c5601779e4 --- /dev/null +++ b/hls4ml/templates/catapult/firmware/defines.h @@ -0,0 +1,15 @@ +#ifndef DEFINES_H_ +#define DEFINES_H_ + +#include "nnet_utils/nnet_types.h" +#include +#include +#include +#include +#include + +// hls-fpga-machine-learning insert numbers + +// hls-fpga-machine-learning insert layer-precision + +#endif diff --git a/hls4ml/templates/catapult/firmware/myproject.cpp b/hls4ml/templates/catapult/firmware/myproject.cpp new file mode 100755 index 0000000000..bdb0570f8b --- /dev/null +++ b/hls4ml/templates/catapult/firmware/myproject.cpp @@ -0,0 +1,29 @@ +#include + +#include "myproject.h" +#include "parameters.h" + +#include + +#pragma hls_design top +// hls-fpga-machine-learning insert IFSynPragmas +void CCS_BLOCK(myproject)( + // hls-fpga-machine-learning insert header +) { + + // hls-fpga-machine-learning insert IO + +#ifndef __SYNTHESIS__ + static bool loaded_weights = false; + if (!loaded_weights) { + // hls-fpga-machine-learning insert load weights + loaded_weights = true; + } +#endif + + // **************************************** + // NETWORK INSTANTIATION + // **************************************** + + // hls-fpga-machine-learning insert layers +} diff --git a/hls4ml/templates/catapult/firmware/myproject.h b/hls4ml/templates/catapult/firmware/myproject.h new file mode 100755 index 0000000000..dd73c3e807 --- /dev/null +++ b/hls4ml/templates/catapult/firmware/myproject.h @@ -0,0 +1,15 @@ +#ifndef MYPROJECT_H_ +#define MYPROJECT_H_ + +#include +#include +#include + +#include "defines.h" + +// Prototype of top level function for C-synthesis +void myproject( + // hls-fpga-machine-learning insert header +); + +#endif diff --git a/hls4ml/templates/catapult/firmware/parameters.h b/hls4ml/templates/catapult/firmware/parameters.h new file mode 100755 index 0000000000..2915c145c8 --- /dev/null +++ b/hls4ml/templates/catapult/firmware/parameters.h @@ -0,0 +1,15 @@ +#ifndef PARAMETERS_H_ +#define PARAMETERS_H_ + +#include +#include + +#include "nnet_utils/nnet_code_gen.h" +#include "nnet_utils/nnet_helpers.h" +// hls-fpga-machine-learning insert includes + +// hls-fpga-machine-learning insert weights + +// hls-fpga-machine-learning insert layer-config + +#endif diff --git a/hls4ml/templates/catapult/myproject_bridge.cpp b/hls4ml/templates/catapult/myproject_bridge.cpp new file mode 100755 index 0000000000..f1326a1faf --- /dev/null +++ b/hls4ml/templates/catapult/myproject_bridge.cpp @@ -0,0 +1,72 @@ +#ifndef MYPROJECT_BRIDGE_H_ +#define MYPROJECT_BRIDGE_H_ + +#include "firmware/myproject.h" +#include "nnet_helpers.h" +#include +#include + +static std::string s_weights_dir = "weights"; + +const char *get_weights_dir() { return s_weights_dir.c_str(); } + +// hls-fpga-machine-learning insert bram + +// hls-fpga-machine-learning insert declare weights + +namespace nnet { +bool trace_enabled = false; +std::map *trace_outputs = NULL; +size_t trace_type_size = sizeof(double); +} // namespace nnet + +extern "C" { + +struct trace_data { + const char *name; + void *data; +}; + +void allocate_trace_storage(size_t element_size) { + nnet::trace_enabled = true; + nnet::trace_outputs = new std::map; + nnet::trace_type_size = element_size; + // hls-fpga-machine-learning insert trace_outputs +} + +void free_trace_storage() { + for (std::map::iterator i = nnet::trace_outputs->begin(); i != nnet::trace_outputs->end(); i++) { + void *ptr = i->second; + free(ptr); + } + nnet::trace_outputs->clear(); + delete nnet::trace_outputs; + nnet::trace_outputs = NULL; + nnet::trace_enabled = false; +} + +void collect_trace_output(struct trace_data *c_trace_outputs) { + int ii = 0; + for (std::map::iterator i = nnet::trace_outputs->begin(); i != nnet::trace_outputs->end(); i++) { + c_trace_outputs[ii].name = i->first.c_str(); + c_trace_outputs[ii].data = i->second; + ii++; + } +} + +// Wrapper of top level function for Python bridge +void myproject_float( + // hls-fpga-machine-learning insert header #float +) { + + // hls-fpga-machine-learning insert wrapper #float +} + +void myproject_double( + // hls-fpga-machine-learning insert header #double +) { + // hls-fpga-machine-learning insert wrapper #double +} +} + +#endif diff --git a/hls4ml/templates/catapult/myproject_test.cpp b/hls4ml/templates/catapult/myproject_test.cpp new file mode 100755 index 0000000000..66b87f6741 --- /dev/null +++ b/hls4ml/templates/catapult/myproject_test.cpp @@ -0,0 +1,164 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static std::string s_weights_dir; + +const char *get_weights_dir() { return s_weights_dir.c_str(); } + +#include "firmware/myproject.h" +#include "nnet_utils/nnet_helpers.h" +// #include "firmware/parameters.h" + +#include + +// hls-fpga-machine-learning insert bram + +#define CHECKPOINT 5000 + +#ifndef RANDOM_FRAMES +#define RANDOM_FRAMES 1 +#endif + +// hls-fpga-machine-learning insert declare weights + +namespace nnet { +bool trace_enabled = true; +std::map *trace_outputs = NULL; +size_t trace_type_size = sizeof(double); +} // namespace nnet + +CCS_MAIN(int argc, char *argv[]) { + if (argc < 2) { + std::cerr << "Error - too few arguments" << std::endl; + std::cerr << "Usage: " << argv[0] << " " << std::endl; + std::cerr << "Where: - string pathname to directory containing wN.txt and bN.txt files" + << std::endl; + std::cerr << " - string pathname to tb_input_features.dat (optional)" << std::endl; + std::cerr << " - string pathname to tb_output_predictions.dat (optional)" << std::endl; + std::cerr << std::endl; + std::cerr << "If no testbench input/prediction data provided, random input data will be generated" << std::endl; + CCS_RETURN(-1); + } + s_weights_dir = argv[1]; + std::cout << " Weights directory: " << s_weights_dir << std::endl; + + std::string tb_in; + std::string tb_out; + std::ifstream fin; + std::ifstream fpr; + bool use_random = false; + if (argc == 2) { + std::cout << "No testbench files provided - Using random input data" << std::endl; + use_random = true; + } else { + tb_in = argv[2]; + tb_out = argv[3]; + std::cout << " Test Feature Data: " << tb_in << std::endl; + std::cout << " Test Predictions : " << tb_out << std::endl; + + // load input data from text file + fin.open(tb_in); + // load predictions from text file + fpr.open(tb_out); + if (!fin.is_open() || !fpr.is_open()) { + use_random = true; + } + } + +#ifdef RTL_SIM + std::string RESULTS_LOG = "tb_data/rtl_cosim_results.log"; +#else + std::string RESULTS_LOG = "tb_data/csim_results.log"; +#endif + std::ofstream fout(RESULTS_LOG); + +#ifndef __SYNTHESIS__ + static bool loaded_weights = false; + if (!loaded_weights) { + // hls-fpga-machine-learning insert load weights + loaded_weights = true; + } +#endif + std::string iline; + std::string pline; + int e = 0; + + if (!use_random) { + while (std::getline(fin, iline) && std::getline(fpr, pline)) { + if (e % CHECKPOINT == 0) + std::cout << "Processing input " << e << std::endl; + char *cstr = const_cast(iline.c_str()); + char *current; + std::vector in; + current = strtok(cstr, " "); + while (current != NULL) { + in.push_back(atof(current)); + current = strtok(NULL, " "); + } + cstr = const_cast(pline.c_str()); + std::vector pr; + current = strtok(cstr, " "); + while (current != NULL) { + pr.push_back(atof(current)); + current = strtok(NULL, " "); + } + // std::cout << " Input feature map size = " << in.size() << " Output predictions size = " << pr.size() << + // std::endl; + + // hls-fpga-machine-learning insert data + + // hls-fpga-machine-learning insert top-level-function + + if (e % CHECKPOINT == 0) { + std::cout << "Predictions" << std::endl; + // hls-fpga-machine-learning insert predictions + std::cout << "Quantized predictions" << std::endl; + // hls-fpga-machine-learning insert quantized + } + e++; + + // hls-fpga-machine-learning insert tb-output + } + if (fin.is_open()) { + fin.close(); + } + if (fpr.is_open()) { + fpr.close(); + } + } else { + std::cout << "INFO: Unable to open input/predictions file(s) so feeding random values" << std::endl; + std::cout << "Number of Frames Passed from the tcl= " << RANDOM_FRAMES << std::endl; + + if (RANDOM_FRAMES > 0) { + for (unsigned int k = 0; k < RANDOM_FRAMES; k++) { + // hls-fpga-machine-learning insert random + + // hls-fpga-machine-learning insert top-level-function + + // hls-fpga-machine-learning insert output + + // hls-fpga-machine-learning insert tb-output + } + } else { + // hls-fpga-machine-learning insert zero + + // hls-fpga-machine-learning insert top-level-function + + // hls-fpga-machine-learning insert output + + // hls-fpga-machine-learning insert tb-output + } + } + + fout.close(); + std::cout << "INFO: Saved inference results to file: " << RESULTS_LOG << std::endl; + + return 0; +} diff --git a/hls4ml/templates/catapult/nnet_utils/ap_shift_reg.h b/hls4ml/templates/catapult/nnet_utils/ap_shift_reg.h new file mode 100644 index 0000000000..0645efa73f --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/ap_shift_reg.h @@ -0,0 +1,136 @@ +/* +#- (c) Copyright 2011-2019 Xilinx, Inc. All rights reserved. +#- +#- This file contains confidential and proprietary information +#- of Xilinx, Inc. and is protected under U.S. and +#- international copyright and other intellectual property +#- laws. +#- +#- DISCLAIMER +#- This disclaimer is not a license and does not grant any +#- rights to the materials distributed herewith. Except as +#- otherwise provided in a valid license issued to you by +#- Xilinx, and to the maximum extent permitted by applicable +#- law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND +#- WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES +#- AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING +#- BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON- +#- INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and +#- (2) Xilinx shall not be liable (whether in contract or tort, +#- including negligence, or under any other theory of +#- liability) for any loss or damage of any kind or nature +#- related to, arising under or in connection with these +#- materials, including for any direct, or any indirect, +#- special, incidental, or consequential loss or damage +#- (including loss of data, profits, goodwill, or any type of +#- loss or damage suffered as a result of any action brought +#- by a third party) even if such damage or loss was +#- reasonably foreseeable or Xilinx had been advised of the +#- possibility of the same. +#- +#- CRITICAL APPLICATIONS +#- Xilinx products are not designed or intended to be fail- +#- safe, or for use in any application requiring fail-safe +#- performance, such as life-support or safety devices or +#- systems, Class III medical devices, nuclear facilities, +#- applications related to the deployment of airbags, or any +#- other applications that could lead to death, personal +#- injury, or severe property or environmental damage +#- (individually and collectively, "Critical +#- Applications"). Customer assumes the sole risk and +#- liability of any use of Xilinx products in Critical +#- Applications, subject only to applicable laws and +#- regulations governing limitations on product liability. +#- +#- THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS +#- PART OF THIS FILE AT ALL TIMES. +#- ************************************************************************ + + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#ifndef __SIM_AP_SHIFT_REG_H__ +#define __SIM_AP_SHIFT_REG_H__ + +/* + * This file contains a C++ model of shift register. + * It defines C level simulation model. + */ +#ifndef __cplusplus +#error C++ is required to include this header file +#else + +#ifndef __SYNTHESIS__ +#include +#endif + +////////////////////////////////////////////// +// C level simulation model for ap_shift_reg +////////////////////////////////////////////// +template class ap_shift_reg { + public: + /// Constructors + ap_shift_reg() { + for (unsigned int i = 0; i < __SHIFT_DEPTH__; i++) { + __SHIFT_T__ dummy; + Array[i] = dummy; // uninitialize so Catapult does not add a reset + } + } + ap_shift_reg(const char *name) {} + /// Destructor + virtual ~ap_shift_reg() {} + + private: + /// Make copy constructor and assignment operator private + ap_shift_reg(const ap_shift_reg<__SHIFT_T__, __SHIFT_DEPTH__> &shreg) { + for (unsigned i = 0; i < __SHIFT_DEPTH__; ++i) + Array[i] = shreg.Array[i]; + } + + ap_shift_reg &operator=(const ap_shift_reg<__SHIFT_T__, __SHIFT_DEPTH__> &shreg) { + for (unsigned i = 0; i < __SHIFT_DEPTH__; ++i) + Array[i] = shreg.Array[i]; + return *this; + } + + public: + // Shift the queue, push to back and read from a given address. + __SHIFT_T__ shift(__SHIFT_T__ DataIn, unsigned int Addr = __SHIFT_DEPTH__ - 1, bool Enable = true) { +#ifndef __SYNTHESIS__ + assert(Addr < __SHIFT_DEPTH__ && "Out-of-bound shift is found in ap_shift_reg."); +#endif + __SHIFT_T__ ret = Array[Addr]; + if (Enable) { + for (unsigned int i = __SHIFT_DEPTH__ - 1; i > 0; --i) + Array[i] = Array[i - 1]; + Array[0] = DataIn; + } + return ret; + } + + // Read from a given address. + __SHIFT_T__ read(unsigned int Addr = __SHIFT_DEPTH__ - 1) const { +#ifndef __SYNTHESIS__ + assert(Addr < __SHIFT_DEPTH__ && "Out-of-bound read is found in ap_shift_reg."); +#endif + return Array[Addr]; + } + + protected: + __SHIFT_T__ Array[__SHIFT_DEPTH__]; +}; + +#endif //__cplusplus + +#endif //__SIM_AP_SHIFT_REG_H__ diff --git a/hls4ml/templates/catapult/nnet_utils/hls_math.h b/hls4ml/templates/catapult/nnet_utils/hls_math.h new file mode 100755 index 0000000000..ea05fe122a --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/hls_math.h @@ -0,0 +1,24 @@ +#ifndef X_HLS_MATH_H +#define X_HLS_MATH_H + +#include "ac_fixed.h" +#include + +namespace hls { + +template static T exp(const T x) { return (T)std::exp(x.to_double()); } + +template T sin(T x) { return (T)std::sin(x.to_double()); }; + +template T cos(T x) { return (T)std::cos(x.to_double()); }; + +template T asin(T x) { return (T)std::asin(x.to_double()); }; + +template T acos(T x) { return (T)std::acos(x.to_double()); }; + +template T atan(T x) { return (T)std::atan(x.to_double()); }; + +template T atan2(T x, T y) { return (T)hls::atan2(x.to_double(), y.to_double()); }; + +} // namespace hls +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h new file mode 100644 index 0000000000..f08e75a0d6 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h @@ -0,0 +1,1107 @@ + +// Change History: +// 2022-06-30 dgburnette - Cleaned up code to separate AC Math from LUT code. +// Added LUT dump to text file. +// Activation functions not implemented in AC Math will assert. +// 2022-06-28 dgburnette - Replaced AP Types with AC Datatypes. +// Commented out all Vivado pragmas. +// Added Catapult hierarchy pragmas. +// Started replacement of activation functions with +// AC Math piecewise-linear versions. + +#ifndef NNET_ACTIVATION_H_ +#define NNET_ACTIVATION_H_ + +// Define this macro to switch the implementations of certain activiation functions +// from the original HLS4ML look-up table approach to using the piecewise-linear approximation +// functions in AC Math. +#define USE_AC_MATH 1 + +#if !defined(USE_AC_MATH) && !defined(__SYNTHESIS__) +// Define a macro that causes the look-up table generation code to dump out text files +// of the array contents. +// #define BUILD_TABLE_FILE 1 +#endif + +#include "ac_fixed.h" +#include "ac_std_float.h" +#include "nnet_common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nnet { + +struct activ_config { + // IO size + static const unsigned n_in = 10; + + // Internal info + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + + // Internal data type definitions + typedef ac_fixed<18, 8, true> table_t; +}; + +// ************************************************* +// LINEAR Activation -- See Issue 53 +// ************************************************* +template void linear(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + res[ii] = data[ii]; + } +} + +// ************************************************* +// RELU Activation +// ************************************************* +template void relu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; +#ifndef USE_AC_MATH + if (datareg > 0) + res[ii] = datareg; + else + res[ii] = 0; +#else + ac_math::ac_relu(datareg, res[ii]); +#endif + } +} + +template +void relu_max(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + data_T datareg; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg < 0) + res[ii] = 0; + else if (datareg > MAX_INT) + res[ii] = MAX_INT; + else + res[ii] = datareg; + } +} + +template void relu6(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + relu_max(data, res); +} + +template void relu1(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + relu_max(data, res); +} + +// ************************************************* +// Sigmoid Activation +// ************************************************* + +template +void ac_sigmoid_pwl_wrapper(const ac_fixed(&input) /*[K]*/, + ac_fixed(&output) /*[K]*/) { + ac_fixed tmp; //[K]; + ac_math::ac_sigmoid_pwl(input, tmp); + output = tmp; +} + +inline float sigmoid_fcn_float(float input) { return 1.0 / (1 + std::exp(-input)); } + +template void init_sigmoid_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "sigmoid_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_sigmoid_table()\n"); +#endif + // Default logistic sigmoid function: + // result = 1/(1+e^(-x)) + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to +8) + float in_val = 2 * 8.0 * (ii - float(N_TABLE) / 2.0) / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = sigmoid_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", sigmoid_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // sigmoid(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t sigmoid_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t sigmoid_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_sigmoid_table(sigmoid_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + // Index into the lookup table based on data + int data_round; + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + data_round = data[ii].to_double() * (int)CONFIG_T::table_size / 16; + index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = (res_T)sigmoid_table[index]; + } +} + +#else + +template +void sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + // res[ii] = ac_math::ac_sigmoid_pwl(data[ii]); + ac_sigmoid_pwl_wrapper(data[ii], res[ii]); + } +} + +#endif + +// ************************************************* +// Softmax Activation +// ************************************************* + +enum class softmax_implementation { latency = 0, legacy = 1, stable = 2 }; + +inline float exp_fcn_float(float input) { return std::exp(input); } + +template inline float softmax_real_val_from_idx(unsigned i) { + // Treat the index as the top N bits + static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table + data_T x(0); + // CATAPULT_PORT + // x(x.width-1, x.width-N) = i; + ac_int tmp = i; + x.template set_slc(x.width - N, tmp); + return (float)x.to_double(); +} + +template inline unsigned softmax_idx_from_real_val(data_T x) { + // Slice the top N bits to get an index into the table + static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table + // CATAPULT_PORT + // ac_int y = x(x.width-1, x.width-N); // slice the top N bits of input + // return (unsigned) y(N-1, 0); + ac_int y = x.template slc(x.width - N); // slice the top N bits of input + return (unsigned)y.template slc(0); +} + +template +void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::table_size]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "exp_table%d.tab", CONFIG_T::table_size); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_exp_table()\n"); +#endif + // The template data_T is the data type used to address the table + for (unsigned i = 0; i < CONFIG_T::table_size; i++) { + // Slicing bits for address is going to round towards 0, so take the central value + float x = softmax_real_val_from_idx(i); + typename CONFIG_T::exp_table_t exp_x = exp_fcn_float(x); + table_out[i] = exp_x; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", exp_fcn_float(x)); + if (i < CONFIG_T::table_size - 1) + fprintf(f, ","); + fprintf(f, " // exp(%32.31f)", x); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +template +void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::table_size]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "invert_table%d.tab", CONFIG_T::table_size); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_invert_table()\n"); +#endif + // The template data_T is the data type used to address the table + for (unsigned i = 0; i < CONFIG_T::table_size; i++) { + float x = softmax_real_val_from_idx(i); +#ifdef __SYNTHESIS__ + // hack for now to get through the flow + typename CONFIG_T::inv_table_t inv_x = 1 + x; +#else + typename CONFIG_T::inv_table_t inv_x = 1 / x; +#endif + table_out[i] = inv_x; +#ifdef BUILD_TABLE_FILE + if (x > 0.0) + fprintf(f, "%32.31f", (1.0 / x)); + else + fprintf(f, "%32.31f", 0.0); + if (i < CONFIG_T::table_size - 1) + fprintf(f, ","); + fprintf(f, " // 1/(%32.31f)", x); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS pipeline + // Initialize the lookup tables +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + +#endif + if (!initialized) { + // Note we are exponentiating the inputs, which have type data_T + init_exp_table(exp_table); + // Note we are inverting the exponentials, which have type exp_table_t + init_invert_table(invert_table); + initialized = true; + } + + // Calculate all the e^x's + typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + //#pragma HLS array_partition variable=exp_res complete + typename CONFIG_T::exp_table_t exp_sum(0); + for (unsigned i = 0; i < CONFIG_T::n_in; i++) { + //#pragma HLS unroll + unsigned x = softmax_idx_from_real_val(data[i]); + exp_res[i] = exp_table[x]; + } + + // Explicitly sum the results with an adder tree. + // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add op_add; + exp_sum = + reduce>(exp_res, op_add); + + typename CONFIG_T::inv_table_t inv_exp_sum = + invert_table[softmax_idx_from_real_val(exp_sum)]; + for (unsigned i = 0; i < CONFIG_T::n_in; i++) { + //#pragma HLS unroll + res[i] = exp_res[i] * inv_exp_sum; + } +} + +template +void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS pipeline + // Initialize the lookup tables +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + +#endif + if (!initialized) { + // Note we are exponentiating the inputs, which have type data_T + init_exp_table(exp_table); + // Note we are inverting the exponentials, which have type exp_table_t + init_invert_table(invert_table); + initialized = true; + } + + // Find the max and compute all delta(x_i, x_max) + Op_max op_max; + data_T x_max = reduce>(data, op_max); + + // For the diffs, use the same type as the input but force rounding and saturation + ac_fixed d_xi_xmax[CONFIG_T::n_in]; + for (unsigned i = 0; i < CONFIG_T::n_in; i++) { + //#pragma HLS unroll + d_xi_xmax[i] = data[i] - x_max; + } + + // Calculate all the e^x's + typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + //#pragma HLS array_partition variable=exp_res complete + typename CONFIG_T::exp_table_t exp_sum(0); + for (unsigned i = 0; i < CONFIG_T::n_in; i++) { + //#pragma HLS unroll + unsigned x = softmax_idx_from_real_val(d_xi_xmax[i]); + exp_res[i] = exp_table[x]; + } + + // Explicitly sum the results with an adder tree. + // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add op_add; + exp_sum = + reduce>(exp_res, op_add); + + typename CONFIG_T::inv_table_t inv_exp_sum = + invert_table[softmax_idx_from_real_val(exp_sum)]; + for (unsigned i = 0; i < CONFIG_T::n_in; i++) { + //#pragma HLS unroll + res[i] = exp_res[i] * inv_exp_sum; + } +} + +#endif + +template void init_exp_table_legacy(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "exp_table_legacy%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_exp_table_legacy()\n"); +#endif + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to +8) + float in_val = 2 * 8.0 * (ii - float(N_TABLE) / 2.0) / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = exp_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", exp_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // exp(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +template void init_invert_table_legacy(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "invert_table_legacy%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_invert_table_legacy()\n"); +#endif + // Inversion function: + // result = 1/x + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range 0 to +64) + float in_val = 64.0 * ii / float(N_TABLE); + // Next, compute lookup table function + if (in_val > 0.0) + table_out[ii] = 1.0 / in_val; + else + table_out[ii] = 0.0; +#ifdef BUILD_TABLE_FILE + if (in_val > 0.0) + fprintf(f, "%32.31f", (1.0 / in_val)); + else + fprintf(f, "%32.31f", 0.0); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // 1/%32.31f", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_exp_table_legacy(exp_table); + init_invert_table_legacy(invert_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + // Index into the lookup table based on data for exponentials + typename CONFIG_T::table_t exp_res[CONFIG_T::n_in]; // different, independent, fixed point precision + typename CONFIG_T::table_t exp_diff_res; // different, independent, fixed point precision + data_T data_cache[CONFIG_T::n_in]; + int data_round; + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + data_cache[ii] = data[ii]; + exp_res[ii] = 0; + } + + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + for (int jj = 0; jj < CONFIG_T::n_in; jj++) { + if (ii == jj) + exp_diff_res = 1; + else { + // CATAPULT_PORT + // data_round = (data_cache[jj]-data_cache[ii])*CONFIG_T::table_size/16; + auto tmp_data_round = (data_cache[jj] - data_cache[ii]) * CONFIG_T::table_size / 16; + data_round = tmp_data_round.to_int(); + index = data_round + 8 * CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = CONFIG_T::table_size - 1; + exp_diff_res = exp_table[index]; + } + exp_res[ii] += exp_diff_res; + } + } + + // Second loop to invert + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + // CATAPULT_PORT + // int exp_res_index = exp_res[ii]*CONFIG_T::table_size/64; + auto tmp_exp_res_index = exp_res[ii] * CONFIG_T::table_size / 64; + int exp_res_index = tmp_exp_res_index.to_int(); + if (exp_res_index < 0) + exp_res_index = 0; + if (exp_res_index > CONFIG_T::table_size - 1) + exp_res_index = CONFIG_T::table_size - 1; + // typename CONFIG_T::table_t exp_res_invert = invert_table[exp_res_index]; + res[ii] = (res_T)invert_table[exp_res_index]; + } +} + +template +void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + switch (CONFIG_T::implementation) { + case softmax_implementation::latency: + softmax_latency(data, res); + break; + case softmax_implementation::stable: + softmax_stable(data, res); + break; + case softmax_implementation::legacy: + softmax_legacy(data, res); + break; + } +} + +#else +// This is a workaround to help the template deduction to work correctly and fix the inconsistency that HLS4ML expects +// softmax output to be signed but AC Math softmax knows it is always unsigned +template +void ac_softmax_pwl_wrapper(const ac_fixed (&input)[K], ac_fixed (&output)[K]) { + ac_fixed tmp[K]; + ac_math::ac_softmax_pwl(input, tmp); + for (unsigned int x = 0; x < K; x++) + output[x] = tmp[x]; +} + +template +void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + data_T data_copy[CONFIG_T::n_in]; + res_T res_copy[CONFIG_T::n_in]; +// workaround for the array passing - alternative is to change the signature of all of the functions to reference-of-array +COPY_IN_ARRAY: + for (unsigned i = 0; i < CONFIG_T::n_in; i++) + data_copy[i] = data[i]; + ac_softmax_pwl_wrapper(data_copy, res_copy); +COPY_OUT_ARRAY: + for (unsigned i = 0; i < CONFIG_T::n_in; i++) + res[i] = res_copy[i]; +} + +#endif + +// ************************************************* +// TanH Activation +// ************************************************* +template void init_tanh_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "tanh_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_tanh_table()\n"); +#endif + // Implement tanh lookup + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -4 to +4) + float in_val = 2 * 4.0 * (ii - float(N_TABLE) / 2.0) / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = tanh(in_val); + // std::cout << "Tanh: Lookup table Index: " << ii<< " In Value: " << in_val << " Result: " << real_val << + // std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", tanh(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // tanh(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template void tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t tanh_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t tanh_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_tanh_table(tanh_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + // Index into the lookup table based on data + int data_round; + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + data_round = data[ii].to_double() * (int)CONFIG_T::table_size / 8; + index = data_round + 4 * (int)CONFIG_T::table_size / 8; + // std::cout << "Input: " << data[ii] << " Round: " << data_round << " Index: " << index << std::endl; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = (res_T)tanh_table[index]; + } +} + +#else + +template void tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + // res[ii] = ac_math::ac_tanh_pwl(data[ii]); + ac_math::ac_tanh_pwl(data[ii], res[ii]); + } +} + +#endif + +// ************************************************* +// Hard sigmoid Activation +// ************************************************* +template +void hard_sigmoid(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + data_T slope = (data_T)0.2; + data_T shift = (data_T)0.5; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = slope * data[ii] + shift; + if (datareg > 1) + datareg = 1; + else if (datareg < 0) + datareg = 0; + res[ii] = datareg; + } +} + +// ************************************************* +// Hard TanH Activation +// ************************************************* +template +void hard_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + data_T slope = (data_T)0.2; + data_T shift = (data_T)0.5; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + auto sigmoid = CONFIG_T::slope * data[ii] + CONFIG_T::shift; + if (sigmoid > 1) + datareg = 1; + else if (sigmoid < 0) + datareg = 0; + res[ii] = 2 * sigmoid - 1; + } +} + +// ************************************************* +// Leaky RELU Activation +// ************************************************* +template +void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg > 0) + res[ii] = datareg; + else + res[ii] = alpha * datareg; + } +} + +// ************************************************* +// Thresholded RELU Activation +// ************************************************* +template +void thresholded_relu(data_T data[CONFIG_T::n_in], data_T theta, res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg > theta) + res[ii] = datareg; + else + res[ii] = 0; + } +} + +// ************************************************* +// Softplus Activation +// ************************************************* +inline float softplus_fcn_float(float input) { return std::log(std::exp(input) + 1.); } + +template void init_softplus_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "softplus_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_softplus_table()\n"); +#endif + // Default softplus function: + // result = log(exp(x) + 1) + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to +8) + float in_val = 2 * 8.0 * (ii - float(N_TABLE) / 2.0) / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = softplus_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", softplus_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // softplus(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void softplus(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t softplus_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t softplus_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_softplus_table(softplus_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + // Index into the lookup table based on data + int data_round; + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + data_round = data[ii].to_double() * (int)CONFIG_T::table_size / 16; + index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = (res_T)softplus_table[index]; + } +} + +#else +template +void ac_softplus_pwl_wrapper(const ac_fixed(&input), ac_fixed(&output)) { + ac_fixed tmp; + ac_math::ac_softplus_pwl(input, tmp); + output = tmp; +} + +template +void softplus(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + ac_softplus_pwl_wrapper(data[ii], res[ii]); + } +} + +#endif + +// ************************************************* +// Softsign Activation +// ************************************************* +inline float softsign_fcn_float(float input) { return input / (std::abs(input) + 1.); } + +template void init_softsign_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "softsign_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_softsign_table()\n"); +#endif + // Default softsign function: + // result = x / (abs(x) + 1) + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to +8) + float in_val = 2 * 8.0 * (ii - float(N_TABLE) / 2.0) / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = softsign_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", softsign_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // softsign(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t softsign_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t softsign_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_softsign_table(softsign_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + // Index into the lookup table based on data + int data_round; + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + data_round = data[ii].to_double() * (int)CONFIG_T::table_size / 16; + index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = (res_T)softsign_table[index]; + } +} + +#else + +template +void softsign(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + // res[ii] = ac_math::ac_softsign_pwl(data[ii]); + ac_math::ac_softsign_pwl(data[ii], res[ii]); + } +} + +#endif + +// ************************************************* +// ELU Activation +// ************************************************* +inline float elu_fcn_float(float input) { return std::exp(input) - 1.; } + +template void init_elu_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "elu_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_elu_table()\n"); +#endif + // Default ELU function: + // result = alpha * (e^(x) - 1) + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to 0) + float in_val = -8.0 * ii / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = elu_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", elu_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // elu(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template +void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t elu_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t elu_table[CONFIG_T::table_size]; +#endif + + if (!initialized) { + init_elu_table(elu_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + data_T datareg; + // Index into the lookup table based on data + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg >= 0) { + res[ii] = datareg; + } else { + index = datareg.to_double() * (int)CONFIG_T::table_size / -8; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = alpha * elu_table[index]; + } + } +} + +#else + +template +void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + ac_math::ac_elu_pwl(data[ii], res[ii], alpha); + } +} + +#endif + +// ************************************************* +// SELU Activation +// ************************************************* +inline float selu_fcn_float(float input) { + return 1.0507009873554804934193349852946 * (1.6732632423543772848170429916717 * (std::exp(input) - 1.)); +} + +template void init_selu_table(typename CONFIG_T::table_t table_out[N_TABLE]) { +#ifdef BUILD_TABLE_FILE + char filename[1024]; + sprintf(filename, "selu_table%d.tab", N_TABLE); + FILE *f = fopen(filename, "w"); + fprintf(f, "// init_selu_table()\n"); +#endif + // Default SELU function: + // result = 1.05 * (1.673 * (e^(x) - 1)) + for (int ii = 0; ii < N_TABLE; ii++) { + // First, convert from table index to X-value (signed 8-bit, range -8 to 0) + float in_val = -8.0 * ii / float(N_TABLE); + // Next, compute lookup table function + typename CONFIG_T::table_t real_val = selu_fcn_float(in_val); + // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << std::endl; + table_out[ii] = real_val; +#ifdef BUILD_TABLE_FILE + fprintf(f, "%32.31f", selu_fcn_float(in_val)); + if (ii < N_TABLE - 1) + fprintf(f, ","); + fprintf(f, " // selu(%32.31f)", in_val); + fprintf(f, "\n"); +#endif + } +#ifdef BUILD_TABLE_FILE + fclose(f); +#endif +} + +#ifndef USE_AC_MATH + +template void selu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t selu_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t selu_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_selu_table(selu_table); + initialized = true; + } + + //#pragma HLS PIPELINE + + data_T datareg; + // Index into the lookup table based on data + int index; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg >= 0) { + res[ii] = res_T(1.0507009873554804934193349852946) * datareg; + } else { + index = datareg.to_double() * (int)CONFIG_T::table_size / -8; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + res[ii] = selu_table[index]; + } + } +} + +#else + +template void selu(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + res[ii] = ac_math::ac_selu_pwl(data[ii]); + } +} + +#endif + +// ************************************************* +// PReLU Activation +// ************************************************* +template +void prelu(data_T data[CONFIG_T::n_in], data_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg > 0) + res[ii] = datareg; + else + res[ii] = alpha[ii] * datareg; + } +} + +// ************************************************* +// Binary TanH Activation +// ************************************************* +template +void binary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + + data_T datareg; + res_T cache; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + if (datareg > 0) + cache = 1; + else + cache = -1; + + res[ii] = (res_T)cache; + } +} + +// ************************************************* +// Ternary TanH Activation +// ************************************************* +template +void ternary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + + //#pragma HLS PIPELINE + + data_T datareg; + res_T cache; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = 2 * data[ii]; + if (datareg > 1) + cache = 1; + else if (datareg > -1 && datareg <= 1) + cache = 0; + else + cache = -1; + + res[ii] = (res_T)cache; + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h new file mode 100644 index 0000000000..509560bc2b --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h @@ -0,0 +1,922 @@ + +// Change History: +// 2022-06-30 dgburnette - Cleaned up code to separate AC Math from LUT code. +// Activation functions not implemented in AC Math will assert. +// 2022-06-28 dgburnette - Replaced AP Types with AC Datatypes. + +#ifndef NNET_ACTIVATION_STREAM_H_ +#define NNET_ACTIVATION_STREAM_H_ + +#include "ac_channel.h" +#include "ac_fixed.h" +#include "nnet_activation.h" +#include "nnet_common.h" +#include "nnet_stream.h" +#include "nnet_types.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nnet { + +// ************************************************* +// LINEAR Activation +// ************************************************* +// Adding this to work around problem with Catapult and SR model where the output channel appears to be inout +template void linear(ac_channel &data, ac_channel &res) { +LinearActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + LinearPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + out_data[j] = in_data[j]; + } + + res.write(out_data); + } +} + +// ************************************************* +// RELU Activation +// ************************************************* +template void relu(ac_channel &data, ac_channel &res) { +ReLUActLoop: + for (unsigned int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ReLUPackLoop: + for (unsigned int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL +#ifndef USE_AC_MATH + if (in_data[j] > 0) + out_data[j] = in_data[j]; + else + out_data[j] = 0; +#else + ac_math::ac_relu(in_data[j], out_data[j]); +#endif + } + + res.write(out_data); + } +} + +// ************************************************* +// Sigmoid Activation +// ************************************************* +#ifndef USE_AC_MATH + +template void sigmoid(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t sigmoid_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t sigmoid_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_sigmoid_table(sigmoid_table); + initialized = true; + } + +SigmoidActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + SigmoidPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + int data_round = in_data[j].to_double() * (int)CONFIG_T::table_size / 16; + int index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + else if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + out_data[j] = sigmoid_table[index]; + } + + res.write(out_data); + } +} + +#else + +template void sigmoid(ac_channel &data, ac_channel &res) { +SigmoidActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + data_T in_data = data.read(); + res_T out_data; + SigmoidPackLoop: + for (int j = 0; j < res_T::size; j++) { + // ac_math::ac_sigmoid_pwl(in_data[j], out_data[j]); + ac_sigmoid_pwl_wrapper(in_data[j], out_data[j]); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// Softmax Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template +void softmax_latency(ac_channel &data, ac_channel &res) { + // Initialize the lookup tables +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + +#endif + if (!initialized) { + // Note we are exponentiating the inputs, which have type data_T + init_exp_table(exp_table); + // Note we are inverting the exponentials, which have type exp_table_t + init_invert_table(invert_table); + initialized = true; + } + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); + constexpr unsigned ii = data_T::size / multiplier_limit; + (void)ii; + + // Calculate all the e^x's + typename CONFIG_T::exp_table_t exp_res[data_T::size]; + //#pragma HLS array_partition variable=exp_res complete + typename CONFIG_T::exp_table_t exp_sum(0); + +SoftmaxExpLoop: + for (unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + //#pragma HLS PIPELINE II=ii + + data_T in_pack = data.read(); + SoftmaxExpPackLoop: + for (unsigned j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + unsigned x = softmax_idx_from_real_val(in_pack[j]); + exp_res[j] = exp_table[x]; + } + + // Explicitly sum the results with an adder tree. + // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add op_add; + exp_sum = + reduce>(exp_res, op_add); + + typename CONFIG_T::inv_table_t inv_exp_sum = + invert_table[softmax_idx_from_real_val(exp_sum)]; + + res_T out_pack; + //#pragma HLS DATA_PACK variable=out_pack + SoftmaxInvPackLoop: + for (unsigned j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + //#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + out_pack[j] = exp_res[j] * inv_exp_sum; + } + res.write(out_pack); + } +} + +template +void softmax_stable(ac_channel &data, ac_channel &res) { + // Initialize the lookup tables +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + +#endif + if (!initialized) { + // Note we are exponentiating the inputs, which have type data_T + init_exp_table(exp_table); + // Note we are inverting the exponentials, which have type exp_table_t + init_invert_table(invert_table); + initialized = true; + } + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(data_T::size, CONFIG_T::reuse_factor); + constexpr unsigned ii = data_T::size / multiplier_limit; + (void)ii; + + typename data_T::value_type data_array[data_T::size]; + //#pragma HLS ARRAY_PARTITION variable=data_array complete + + if constexpr (ii == 1) { + } + if constexpr (ii != 1) { + // future enhancement for Catapult + } +SoftmaxArrayLoop: + for (unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + //#pragma HLS PIPELINE II=ii + + data_T in_pack = data.read(); + SoftmaxArrayPackLoop: + for (unsigned j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + data_array[j] = in_pack[j]; + } + + // Find the max and compute all delta(x_i, x_max) + Op_max op_max; + typename data_T::value_type x_max = + reduce>(data_array, op_max); + + // For the diffs, use the same type as the input but force rounding and saturation + ac_fixed d_xi_xmax[data_T::size]; + for (unsigned j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + d_xi_xmax[j] = data_array[j] - x_max; + } + + // Calculate all the e^x's + typename CONFIG_T::exp_table_t exp_res[data_T::size]; + //#pragma HLS ARRAY_PARTITION variable=exp_res complete + typename CONFIG_T::exp_table_t exp_sum(0); + for (unsigned j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + unsigned x = softmax_idx_from_real_val(d_xi_xmax[j]); + exp_res[j] = exp_table[x]; + } + + // Explicitly sum the results with an adder tree. + // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing + Op_add op_add; + exp_sum = + reduce>(exp_res, op_add); + + typename CONFIG_T::inv_table_t inv_exp_sum = + invert_table[softmax_idx_from_real_val(exp_sum)]; + + res_T out_pack; + //#pragma HLS DATA_PACK variable=out_pack + SoftmaxInvPackLoop: + for (unsigned j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + //#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + out_pack[j] = exp_res[j] * inv_exp_sum; + } + res.write(out_pack); + } +} + +template +void softmax_legacy(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; + typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; + static typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_exp_table_legacy(exp_table); + init_invert_table_legacy(invert_table); + initialized = true; + } + + // Index into the lookup table based on data for exponentials + typename CONFIG_T::table_t exp_res[data_T::size]; + typename CONFIG_T::table_t exp_diff_res; + typename data_T::value_type data_cache[data_T::size]; + +SoftmaxInitLoop: + for (unsigned s = 0; s < CONFIG_T::n_in / data_T::size; s++) { + //#pragma HLS PIPELINE + data_T in_pack = data.read(); + SoftmaxInitPackLoop: + for (unsigned j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + data_cache[j] = in_pack[j]; + exp_res[j] = 0; + } + + SoftmaxExpLoop: + for (int i = 0; i < data_T::size; i++) { + //#pragma HLS UNROLL + SoftmaxExpInner: + for (int j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + + if (i == j) { + exp_diff_res = 1; + } else { + int data_round = + (data_cache[j].to_double() - data_cache[i].to_double()) * (int)CONFIG_T::table_size / 16; + int index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + exp_diff_res = exp_table[index]; + } + + exp_res[i] += exp_diff_res; + } + } + + res_T out_pack; + //#pragma HLS DATA_PACK variable=out_pack + SoftmaxInvPackLoop: + for (unsigned j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + + int exp_res_index = exp_res[j].to_double() * (int)CONFIG_T::table_size / 64; + if (exp_res_index < 0) + exp_res_index = 0; + if (exp_res_index > CONFIG_T::table_size - 1) + exp_res_index = (int)CONFIG_T::table_size - 1; + + out_pack[j] = (typename res_T::value_type)invert_table[exp_res_index]; + } + res.write(out_pack); + } +} + +template void softmax(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::axis == -1); + + switch (CONFIG_T::implementation) { + case softmax_implementation::latency: + softmax_latency(data, res); + break; + case softmax_implementation::stable: + softmax_stable(data, res); + break; + case softmax_implementation::legacy: + softmax_legacy(data, res); + break; + } +} + +#else + +template void softmax(ac_channel &data, ac_channel &res) { + typename data_T::value_type data_cache[data_T::size]; + typename res_T::value_type res_cache[res_T::size]; +SoftmaxInitLoop: + for (unsigned s = 0; s < CONFIG_T::n_in / data_T::size; s++) { + data_T in_pack = data.read(); + + SoftmaxInitPackLoop: + for (unsigned j = 0; j < data_T::size; j++) { + data_cache[j] = in_pack[j]; + } + + res_T out_pack; + // ac_math::ac_softmax_pwl(data_cache,res_cache); + ac_softmax_pwl_wrapper(data_cache, res_cache); + + SoftmaxResPackLoop: + for (unsigned j = 0; j < res_T::size; j++) { + out_pack[j] = res_cache[j]; + } + + res.write(out_pack); + } +} + +#endif + +// ************************************************* +// TanH Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template void tanh(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t tanh_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t tanh_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_tanh_table(tanh_table); + initialized = true; + } + +TanHActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + TanHPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + int data_round = in_data[j].to_double() * (int)CONFIG_T::table_size / 8; + int index = data_round + 4 * (int)CONFIG_T::table_size / 8; + if (index < 0) + index = 0; + else if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + out_data[j] = tanh_table[index]; + } + + res.write(out_data); + } +} + +#else + +template void tanh(ac_channel &data, ac_channel &res) { +TanHActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + + data_T in_data = data.read(); + res_T out_data; + TanHPackLoop: + for (int j = 0; j < res_T::size; j++) { + // int data_round = in_data[j]*CONFIG_T::table_size/8; + ac_math::ac_tanh_pwl(in_data[j], out_data[j]); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// Hard sigmoid Activation +// ************************************************* + +template void hard_sigmoid(ac_channel &data, ac_channel &res) { + typename data_T::value_type slope = (typename data_T::value_type)0.2; + typename data_T::value_type shift = (typename data_T::value_type)0.5; + +HardSigmoidActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + HardSigmoidPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + typename data_T::value_type datareg = slope * in_data[j] + shift; + if (datareg > 1) + datareg = 1; + else if (datareg < 0) + datareg = 0; + out_data[j] = datareg; + } + + res.write(out_data); + } +} + +// ************************************************* +// Hard TanH Activation +// ************************************************* + +template void hard_tanh(ac_channel &data, ac_channel &res) { + // typename data_T::value_type slope = (typename data_T::value_type) 0.2; + // typename data_T::value_type shift = (typename data_T::value_type) 0.5; + +HardTanhActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + // PRAGMA_DATA_PACK(out_data) + + HardTanhPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + auto sigmoid = CONFIG_T::slope * in_data[j] + CONFIG_T::shift; + if (sigmoid > 1) + sigmoid = 1; + else if (sigmoid < 0) + sigmoid = 0; + out_data[j] = 2 * sigmoid - 1; + } + + res.write(out_data); + } +} + +// ************************************************* +// Leaky RELU Activation +// ************************************************* +template +void leaky_relu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { +LeakyReLUActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + LeakyReLUPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + if (in_data[j] > 0) + out_data[j] = in_data[j]; + else + out_data[j] = alpha * in_data[j]; + } + res.write(out_data); + } +} + +// ************************************************* +// Thresholded RELU Activation +// ************************************************* + +template +void thresholded_relu(ac_channel &data, typename data_T::value_type theta, ac_channel &res) { +ThresholdedReLUActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ThresholdedReLUPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + if (in_data[j] > theta) + out_data[j] = in_data[j]; + else + out_data[j] = 0; + } + + res.write(out_data); + } +} + +// ************************************************* +// Softplus Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template void softplus(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t softplus_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t softplus_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_softplus_table(softplus_table); + initialized = true; + } + +SoftplusActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + SoftplusPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + int data_round = in_data[j].to_double() * (int)CONFIG_T::table_size / 16; + int index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + else if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + out_data[j] = softplus_table[index]; + } + res.write(out_data); + } +} + +#else + +template void softplus(ac_channel &data, ac_channel &res) { +SoftplusActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + data_T in_data = data.read(); + res_T out_data; + SoftplusPackLoop: + for (int j = 0; j < res_T::size; j++) { + ac_softplus_pwl_wrapper(in_data[j], out_data[j]); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// Softsign Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template void softsign(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t softsign_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t softsign_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_softsign_table(softsign_table); + initialized = true; + } + +SoftsignActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + SoftsignPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + int data_round = in_data[j].to_double() * (int)CONFIG_T::table_size / 16; + int index = data_round + 8 * (int)CONFIG_T::table_size / 16; + if (index < 0) + index = 0; + else if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + out_data[j] = softsign_table[index]; + } + res.write(out_data); + } +} + +#else + +template void softsign(ac_channel &data, ac_channel &res) { +SoftsignActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + data_T in_data = data.read(); + res_T out_data; + SoftsignPackLoop: + for (int j = 0; j < res_T::size; j++) { + ac_math::ac_softsign_pwl(in_data[j], out_data[j]); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// ELU Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template +void elu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t elu_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t elu_table[CONFIG_T::table_size]; +#endif + + if (!initialized) { + init_elu_table(elu_table); + initialized = true; + } + +EluActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + EluPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + + typename data_T::value_type datareg = in_data[j]; + if (datareg >= 0) { + out_data[j] = datareg; + } else { + int index = (int)datareg.to_double() * (int)CONFIG_T::table_size / -8; + if (index > CONFIG_T::table_size - 1) + index = CONFIG_T::table_size - 1; + out_data[j] = alpha * elu_table[index]; + } + } + res.write(out_data); + } +} + +#else +template +void elu(ac_channel &data, typename data_T::value_type alpha, ac_channel &res) { +EluActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + data_T in_data = data.read(); + res_T out_data; + EluPackLoop: + for (int j = 0; j < res_T::size; j++) { + ac_math::ac_elu_pwl(in_data[j], out_data[j], alpha); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// SELU Activation +// ************************************************* + +#ifndef USE_AC_MATH + +template void selu(ac_channel &data, ac_channel &res) { + // Initialize the lookup table +#ifdef __HLS_SYN__ + bool initialized = false; + typename CONFIG_T::table_t selu_table[CONFIG_T::table_size]; +#else + static bool initialized = false; + static typename CONFIG_T::table_t selu_table[CONFIG_T::table_size]; +#endif + if (!initialized) { + init_selu_table(selu_table); + initialized = true; + } + +SeluActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + SeluPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + + typename data_T::value_type datareg = in_data[j]; + if (datareg >= 0) { + out_data[j] = (typename data_T::value_type)1.0507009873554804934193349852946 * datareg; + } else { + int index = (int)datareg.to_double() * (int)CONFIG_T::table_size / -8; + if (index > CONFIG_T::table_size - 1) + index = (int)CONFIG_T::table_size - 1; + out_data[j] = selu_table[index]; + } + } + res.write(out_data); + } +} + +#else + +template void selu(ac_channel &data, ac_channel &res) { +SeluActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + data_T in_data = data.read(); + res_T out_data; + SeluPackLoop: + for (int j = 0; j < res_T::size; j++) { + ac_math::ac_selu_pwl(in_data[j], out_data[j]); + } + res.write(out_data); + } +} + +#endif + +// ************************************************* +// PReLU Activation +// ************************************************* +template +void prelu(ac_channel &data, typename data_T::value_type alpha[CONFIG_T::n_in], ac_channel &res) { +PReLUActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + PReLUPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + if (in_data[j] > 0) + out_data[j] = in_data[j]; + else + out_data[j] = alpha[i * res_T::size + j] * in_data[j]; + } + res.write(out_data); + } +} + +// ************************************************* +// Binary TanH Activation +// ************************************************* +template void binary_tanh(ac_channel &data, ac_channel &res) { +PReLUActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + PReLUPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + if (in_data[j] > 0) + out_data[j] = (typename res_T::value_type)1; + else + out_data[j] = (typename res_T::value_type) - 1; + } + res.write(out_data); + } +} + +// ************************************************* +// Ternary TanH Activation +// ************************************************* +template void ternary_tanh(ac_channel &data, ac_channel &res) { +PReLUActLoop: + for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + PReLUPackLoop: + for (int j = 0; j < res_T::size; j++) { + //#pragma HLS UNROLL + if (in_data[j] > 1) + out_data[j] = (typename res_T::value_type)1; + else if (in_data[j] <= -1) + out_data[j] = (typename res_T::value_type) - 1; + else + out_data[j] = (typename res_T::value_type)0; + } + res.write(out_data); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_array.h b/hls4ml/templates/catapult/nnet_utils/nnet_array.h new file mode 100755 index 0000000000..cd3b73cf73 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_array.h @@ -0,0 +1,52 @@ +#ifndef NNET_ARRAY_H_ +#define NNET_ARRAY_H_ + +#include + +namespace nnet { + +struct transpose_config { + static const unsigned height = 10; + static const unsigned width = 10; + static const unsigned depth = 10; + static constexpr unsigned perm[3] = {2, 0, 1}; +}; + +template +void transpose_2d(data_T data[CONFIG_T::height * CONFIG_T::width], res_T data_t[CONFIG_T::height * CONFIG_T::width]) { + //#pragma HLS PIPELINE + + for (int i = 0; i < CONFIG_T::height; i++) { + for (int j = 0; j < CONFIG_T::width; j++) { + data_t[j * CONFIG_T::height + i] = data[i * CONFIG_T::width + j]; + } + } +} + +template +void transpose_3d(data_T data[CONFIG_T::depth * CONFIG_T::height * CONFIG_T::width], + res_T data_t[CONFIG_T::depth * CONFIG_T::height * CONFIG_T::width]) { + unsigned dims[3] = {CONFIG_T::depth, CONFIG_T::height, CONFIG_T::width}; + unsigned dims_t[3]; + dims_t[0] = dims[CONFIG_T::perm[0]]; + dims_t[1] = dims[CONFIG_T::perm[1]]; + dims_t[2] = dims[CONFIG_T::perm[2]]; + + int idx[3] = {0}, idx_t[3] = {0}; + for (idx[0] = 0; idx[0] < dims[0]; idx[0]++) { + for (idx[1] = 0; idx[1] < dims[1]; idx[1]++) { + for (idx[2] = 0; idx[2] < dims[2]; idx[2]++) { + idx_t[0] = idx[CONFIG_T::perm[0]]; + idx_t[1] = idx[CONFIG_T::perm[1]]; + idx_t[2] = idx[CONFIG_T::perm[2]]; + + data_t[idx_t[0] * dims_t[1] * dims_t[2] + idx_t[1] * dims_t[2] + idx_t[2]] = + data[idx[0] * dims[1] * dims[2] + idx[1] * dims[2] + idx[2]]; + } + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm.h new file mode 100644 index 0000000000..1db18043ec --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm.h @@ -0,0 +1,127 @@ +#ifndef NNET_BATCHNORM_H_ +#define NNET_BATCHNORM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_dense.h" +#include + +namespace nnet { + +struct batchnorm_config { + // Internal data type definitions + typedef float bias_t; + typedef float scale_t; + + // Layer Sizes + static const unsigned n_in = 10; + static const int n_filt = -1; + static const unsigned n_scale_bias = 10; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; + // partitioning arrays cyclically to go with roll factors? + template using product = nnet::product::mult; +}; + +template +void normalize(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in], + typename CONFIG_T::scale_t scale[CONFIG_T::n_scale_bias], + typename CONFIG_T::bias_t bias[CONFIG_T::n_scale_bias]) { + data_T cache; + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=scale,bias + + // For parallel inputs: + // - completely partition arrays -- target fabric + // - if we have an unroll factor, limit number of multipliers + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + + // #pragma HLS ARRAY_PARTITION variable=weights complete // remove this line for now, it breaks compression sometimes + //#pragma HLS ARRAY_PARTITION variable=scale complete + //#pragma HLS ARRAY_PARTITION variable=bias complete + + int multiplier_limit = ceil(float(CONFIG_T::n_in) / float(CONFIG_T::reuse_factor)); + CONFIG_T::template product::limit(multiplier_limit); + + // Calcuate result +Result: + for (int ires = 0; ires < CONFIG_T::n_in; ires++) { + if (CONFIG_T::n_filt == -1) { + res[ires] = CONFIG_T::template product::product(data[ires], scale[ires]) + + bias[ires]; + } else { + int norm_index = ires % CONFIG_T::n_filt; + res[ires] = + CONFIG_T::template product::product(data[ires], scale[norm_index]) + + bias[norm_index]; + } + } +} + +// **************************************************** +// Merged Batch Normalization and Quantized Tanh +// **************************************************** +struct batchnorm_quantized_tanh_config { + // Layer Sizes + static const unsigned n_in = 10; + static const int n_filt = -1; + static const unsigned n_scale_bias = 10; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const unsigned n_zeros = 0; +}; + +template +void normalize_binary_tanh(data_T data[CONFIG_T::n_in], ac_int<1, false> res[CONFIG_T::n_in], + data_T threshold[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=res complete + + data_T datareg; + ac_int<1, false> cache; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt; + if (datareg >= threshold[norm_index]) + cache = 1; + else + cache = 0; + + res[ii] = cache; + } +} + +template +void normalize_ternary_tanh(data_T data[CONFIG_T::n_in], ac_int<2, true> res[CONFIG_T::n_in], + data_T threshold_hi[CONFIG_T::n_in], data_T threshold_lo[CONFIG_T::n_in]) { + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=res complete + + data_T datareg; + ac_int<2, true> cache; + for (int ii = 0; ii < CONFIG_T::n_in; ii++) { + datareg = data[ii]; + int norm_index = CONFIG_T::n_filt == -1 ? ii : ii % CONFIG_T::n_filt; + if (datareg > threshold_hi[norm_index]) + cache = 1; + else if (datareg <= threshold_lo[norm_index]) + cache = -1; + else + cache = 0; + + res[ii] = cache; + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h new file mode 100644 index 0000000000..48085f82dc --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h @@ -0,0 +1,113 @@ + +#ifndef NNET_BATCHNORM_STREAM_H_ +#define NNET_BATCHNORM_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_mult.h" +#include "nnet_types.h" + +namespace nnet { + +// **************************************************** +// Streaming Batch Normalization +// **************************************************** + +template +void normalize(ac_channel &data, ac_channel &res, typename CONFIG_T::scale_t scale[CONFIG_T::n_scale_bias], + typename CONFIG_T::bias_t bias[CONFIG_T::n_scale_bias]) { + //#pragma HLS ARRAY_PARTITION variable=scale complete + //#pragma HLS ARRAY_PARTITION variable=bias complete + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor); + constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit; + (void)ii; + CONFIG_T::template product::limit(multiplier_limit); + +BatchNormLoop: + for (unsigned int i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + //#pragma HLS PIPELINE II=ii + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + BatchNormpack: + for (unsigned int j = 0; j < data_T::size; j++) { + // #pragma HLS UNROLL + int norm_index; + if (CONFIG_T::n_filt == -1) { + norm_index = i * data_T::size + j; + } else { + norm_index = j % CONFIG_T::n_filt; + } + out_data[j] = CONFIG_T::template product::product( + in_data[j], scale[norm_index]) + + bias[norm_index]; + } + + res.write(out_data); + } +} + +// **************************************************** +// Merged Batch Normalization and Quantized Tanh +// **************************************************** +template +void normalize_binary_tanh(ac_channel &data, ac_channel, CONFIG_T::n_in>> &res, + typename data_T::value_type threshold[CONFIG_T::n_in]) { + //#pragma HLS ARRAY_PARTITION variable=threshold complete + +BinaryNormLoop: + for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + nnet::array, CONFIG_T::n_scale_bias> out_data; + //#pragma HLS DATA_PACK variable=out_data + + BatchNormPack: + for (int j = 0; j < data_T::size; j++) { + out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0; + } + + res.write(out_data); + } +} + +template +void normalize_ternary_tanh(ac_channel &data, ac_channel, CONFIG_T::n_in>> &res, + typename data_T::value_type threshold_hi[CONFIG_T::n_in], + typename data_T::value_type threshold_lo[CONFIG_T::n_in]) { + //#pragma HLS ARRAY_PARTITION variable=threshold_hi complete + //#pragma HLS ARRAY_PARTITION variable=threshold_lo complete + +TernaryNormLoop: + for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + nnet::array, CONFIG_T::n_scale_bias> out_data; + //#pragma HLS DATA_PACK variable=out_data + + BatchNormPack: + for (int j = 0; j < data_T::size; j++) { + + int norm_index = i * data_T::size + j; + + if (in_data[j] > threshold_hi[norm_index]) { + out_data[j] = 1; + } else if (in_data[j] <= threshold_lo[norm_index]) { + out_data[j] = -1; + } else { + out_data[j] = 0; + } + } + + res.write(out_data); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_code_gen.h b/hls4ml/templates/catapult/nnet_utils/nnet_code_gen.h new file mode 100755 index 0000000000..e4db43682e --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_code_gen.h @@ -0,0 +1,32 @@ +#ifndef NNET_INSTR_GEN_H_ +#define NNET_INSTR_GEN_H_ + +#include "nnet_helpers.h" +#include + +namespace nnet { + +template class FillConv1DBuffer { + public: + static void fill_buffer(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition) { + // To be implemented in subclasses + } +}; + +template class FillConv2DBuffer { + public: + static void + fill_buffer(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition) { + // To be implemented in subclasses + } +}; + +// hls4ml insert code + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_common.h b/hls4ml/templates/catapult/nnet_utils/nnet_common.h new file mode 100755 index 0000000000..b9b27209fa --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_common.h @@ -0,0 +1,66 @@ + +#ifndef NNET_COMMON_H_ +#define NNET_COMMON_H_ + +#include "ac_fixed.h" + +// This is a substitute for "ceil(n/(float)d)". +#define DIV_ROUNDUP(n, d) ((n + d - 1) / d) +#define MIN(n, d) (n > d ? d : n) +#define MAX(n, d) (n > d ? n : d) + +namespace nnet { + +// Common type definitions +enum io_type { io_parallel = 0, io_stream }; +enum strategy { latency, resource }; + +/* --- + * Balanced tree reduce implementation. + * For use in scenarios where Vivado cannot expression balance + * Reduces an array of inputs to a single value using the template binary operator 'Op', + * for example summing all elements with Op_add, or finding the maximum with Op_max + * Use only when the input array is fully unrolled. Or, slice out a fully unrolled section + * before applying and accumulate the result over the rolled dimension. + * --- */ +template T reduce(const T *x, Op op) { + static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0; + static constexpr int rightN = N - leftN > 0 ? N - leftN : 0; + + if (N == 1) { + return x[0]; + } else if (N == 2) { + return op(x[0], x[1]); + } else { + return op(reduce(x, op), reduce(x + leftN, op)); + } +} + +template class Op_add { + public: + T operator()(T a, T b) { return a + b; } +}; + +template class Op_and { + public: + T operator()(T a, T b) { return a && b; } +}; + +template class Op_or { + public: + T operator()(T a, T b) { return a || b; } +}; + +template class Op_max { + public: + T operator()(T a, T b) { return a >= b ? a : b; } +}; + +template class Op_min { + public: + T operator()(T a, T b) { return a <= b ? a : b; } +}; + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv1d.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d.h new file mode 100755 index 0000000000..98e075d4ab --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d.h @@ -0,0 +1,62 @@ + +#ifndef NNET_CONV1D_H_ +#define NNET_CONV1D_H_ + +#include "nnet_common.h" +#include "nnet_conv1d_latency.h" +#include "nnet_conv1d_resource.h" +#include + +namespace nnet { + +struct conv1d_config { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Convolutional parameters + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_width = 10; + static const unsigned n_chan = 0; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_width; + static const unsigned n_filt = 1; + static const unsigned stride_width = 1; + static const unsigned dilation = 1; + static const unsigned out_width = 10; //(N_IN + PAD_LEFT * PAD_RIGHT - (DILATION * (FILT_WIDTH - 1) + 1)) / STRIDE + 1 + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; // not used yet +}; + +template +void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + if (CONFIG_T::strategy == nnet::latency) { + conv_1d_latency_cl(data, res, weights, biases); + } else { + conv_1d_resource_cl(data, res, weights, biases); + } +} + +template +void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + if (CONFIG_T::strategy == nnet::latency) { + pointwise_conv_1d_latency_cl(data, res, weights, biases); + } else { + pointwise_conv_1d_resource_cl(data, res, weights, biases); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_latency.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_latency.h new file mode 100755 index 0000000000..0323b1ac4b --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_latency.h @@ -0,0 +1,198 @@ +#ifndef NNET_CONV1D_LATENCY_H_ +#define NNET_CONV1D_LATENCY_H_ + +#include "nnet_common.h" +#include + +namespace nnet { + +// Computes multiplier limit +// This function should not be synthesized into firmware +template +int compute_multiplier_limit( + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt]) { + int n_mult = 0; + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + for (int jj = 0; jj < CONFIG_T::filt_width; jj++) { + + int index_weight = jj * CONFIG_T::n_chan * CONFIG_T::n_filt + cc * CONFIG_T::n_filt + ff; + + if ((ii * CONFIG_T::stride_width + jj) < CONFIG_T::pad_left || + (ii * CONFIG_T::stride_width + jj) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + // padded -- do nothing + continue; + } else { + // need to tune this cut? + if (weights[index_weight] > 1e-20 || weights[index_weight] < -1e-20) { + n_mult++; + } // end if nonzero weight + } // end not padding + } // end loop accross filter + } // end channel loop + } // end filter loop + } // end output loop + + return ceil(float(n_mult) / float(CONFIG_T::reuse_factor)); + +} // end compute_n_mult + +template +void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + + typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_width]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_width][CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + const int multiplier_limit = compute_multiplier_limit(weights); +//#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + +// Convolve, saving all multiplication results to accumulate later +ConvOut: + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + ConvMult: + for (int jj = 0; jj < CONFIG_T::filt_width; jj++) { + + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_width + cc * CONFIG_T::filt_width + jj; + int index_weight = jj * CONFIG_T::n_chan * CONFIG_T::n_filt + cc * CONFIG_T::n_filt + ff; + int index_data = (ii * CONFIG_T::stride_width + jj - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; + + if ((ii * CONFIG_T::stride_width + jj) < CONFIG_T::pad_left || + (ii * CONFIG_T::stride_width + jj) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + mult[index_mult] = data[index_data] * weights[index_weight]; + } + } + } // end channel loop + } // end filter loop + } // end output loop + + // Initialize accumulator with input biases + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + acc[ii][ff] = biases[ff]; + } + } + +// Accumulate multiplication result +AccumOut: + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + AccumDot: + for (int jj = 0; jj < CONFIG_T::filt_width; jj++) { + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_width + cc * CONFIG_T::filt_width + jj; + acc[ii][ff] += mult[index_mult]; + } // end dot product loop + } // end channel loop + } // end filter loop + } // end output loop + + // Cast to "res_t" type + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]); + } + } +} + +template +void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_width][CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + const int multiplier_limit = compute_multiplier_limit(weights); +//#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + +// Convolve, saving all multiplication results to accumulate later +ConvOut: + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + int index_weight = cc * CONFIG_T::n_filt + ff; + int index_data = (ii * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; + + if ((ii * CONFIG_T::stride_width) < CONFIG_T::pad_left || + (ii * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + mult[index_mult] = data[index_data] * weights[index_weight]; + } + } // end channel loop + } // end filter loop + } // end output loop + + // Initialize accumulator with input biases + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + acc[ii][ff] = biases[ff]; + } + } + +// Accumulate multiplication result +AccumOut: + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + acc[ii][ff] += mult[index_mult]; + } // end channel loop + } // end filter loop + } // end output loop + + // Cast to "res_t" type + for (int ii = 0; ii < CONFIG_T::out_width; ii++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]); + } + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_resource.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_resource.h new file mode 100644 index 0000000000..143a1271ba --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_resource.h @@ -0,0 +1,241 @@ +#ifndef NNET_CONV1D_RESOURCE_H_ +#define NNET_CONV1D_RESOURCE_H_ + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet { + +template +void im2col_1d(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_width]) { + // int index = 0; + for (int channel = CONFIG_T::n_chan; channel--; data += CONFIG_T::in_width) { + //#pragma HLS PIPELINE II=1 rewind + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation; + for (int output_col = CONFIG_T::out_width; output_col; output_col--) { + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + *(data_col++) = data[input_col]; + // data_col[index] = data[input_col]; + } else { + *(data_col++) = 0; + // data_col[index] = 0; + } + // index++; + input_col += CONFIG_T::stride_width; + } + } + } +} + +template +void conv_1d_full(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + data_T data_conv[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_width]; + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + ////#pragma HLS ARRAY_PARTITION variable=data_conv complete + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + + im2col_1d(data, data_conv); + + for (int i = 0; i < CONFIG_T::out_width; i++) { + for (int j = 0; j < CONFIG_T::filt_width * CONFIG_T::n_chan; j++) { + data_col[j] = data_conv[j * CONFIG_T::out_width + i]; + } + dense_resource(data_col, res_col, weights, biases); + for (int j = 0; j < CONFIG_T::n_filt; j++) { + // res[i * CONFIG_T::n_filt + j] = res_col[j]; + res[j * CONFIG_T::out_width + i] = res_col[j]; // Transposed order + } + } +} + +template +void im2col_1d_cf_idx(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan], const int col) { +ChannelLoop: + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + //#pragma HLS PIPELINE II=1 rewind + KernelLoop: + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation + col * CONFIG_T::stride_width; + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + //*(data_col++) = data[input_col]; + data_col[channel * CONFIG_T::filt_width + kernel_col] = data[channel * CONFIG_T::in_width + input_col]; + } else { + //*(data_col++) = 0; + data_col[channel * CONFIG_T::filt_width + kernel_col] = 0; + } + } + } +} + +template +void im2col_1d_cf(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::n_chan * CONFIG_T::filt_width], const int col) { + int index = 0; +ChannelLoop: + for (int channel = CONFIG_T::n_chan; channel--; data += CONFIG_T::in_width) { + KernelLoop: + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation + col * CONFIG_T::stride_width; + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + //*(data_col++) = data[input_col]; + data_col[index] = data[input_col]; + } else { + //*(data_col++) = 0; + data_col[index] = 0; + } + index++; + } + } +} + +template +void conv_1d_resource_cf(data_T data[CONFIG_T::n_chan * CONFIG_T::in_width], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +ColLoop: + for (int i = 0; i < CONFIG_T::out_width; i++) { + //#pragma HLS PIPELINE + im2col_1d_cf(data, data_col, i); + dense_resource(data_col, res_col, weights, biases); + for (int j = 0; j < CONFIG_T::n_filt; j++) { + // res[i * CONFIG_T::n_filt + j] = res_col[j]; + res[j * CONFIG_T::out_width + i] = res_col[j]; // Transposed order + } + } +} + +template +void im2col_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan], const int col) { + int index = 0; +KernelLoop: + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + + ChannelLoop: + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + int index_data = (col * CONFIG_T::stride_width + kernel_col - CONFIG_T::pad_left) * CONFIG_T::n_chan + channel; + + if (index_data >= 0 && index_data < CONFIG_T::in_width * CONFIG_T::n_chan) { + data_col[index] = data[index_data]; + } else { + data_col[index] = 0; + } + index++; + } + } +} + +template +void im2col_1d_pointwise_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::n_chan], + const int col) { + int index = 0; +ChannelLoop: + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + + int index_data = (col * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + channel; + + if (index_data >= 0 && index_data < CONFIG_T::in_width * CONFIG_T::n_chan) { + data_col[index] = data[index_data]; + } else { + data_col[index] = 0; + } + index++; + } +} + +template +void conv_1d_resource_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +ColLoop: + for (int i = 0; i < CONFIG_T::out_width; i++) { + //#pragma HLS PIPELINE + im2col_1d_cl(data, data_col, i); + dense_resource(data_col, res_col, weights, biases); + for (int j = 0; j < CONFIG_T::n_filt; j++) { + res[i * CONFIG_T::n_filt + j] = res_col[j]; + } + } +} + +template +void pointwise_conv_1d_resource_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + const int nin = CONFIG_T::n_chan; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +ColLoop: + for (int i = 0; i < CONFIG_T::out_width; i++) { + //#pragma HLS PIPELINE + im2col_1d_pointwise_cl(data, data_col, i); + dense_resource(data_col, res_col, weights, biases); + for (int j = 0; j < CONFIG_T::n_filt; j++) { + res[i * CONFIG_T::n_filt + j] = res_col[j]; + } + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_stream.h new file mode 100644 index 0000000000..48f6244ce1 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv1d_stream.h @@ -0,0 +1,94 @@ +#ifndef NNET_CONV1D_STREAM_H_ +#define NNET_CONV1D_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_conv_stream.h" + +namespace nnet { + +template +void compute_scaled_indices_1d(const unsigned w_idx, ac_int *pixel_idx) { + unsigned wp_idx = w_idx * (data_T::size / CONFIG_T::n_chan); + +ComputeIndex: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_chan; p++) { + // #pragma HLS UNROLL + unsigned sw_idx = + CONFIG_T::template scale_index::scale_index( + wp_idx + p); + pixel_idx[p] = CONFIG_T::pixels[sw_idx]; + } +} + +template +void conv_1d_encoded_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + ac_channel data_window[CONFIG_T::filt_width * CONFIG_T::n_chan]; + // const int win_depth = CONFIG_T::out_width; + // for (unsigned i_out = 0; i_out < CONFIG_T::filt_width * CONFIG_T::n_chan; i_out++) { + // #pragma HLS STREAM variable=data_window[i_out] depth=win_depth + // } + + //#pragma HLS ARRAY_PARTITION variable=CONFIG_T::pixels complete + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + unsigned outputs_ready = 0; + + ac_int pixel_idx[data_T::size / CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=pixel_idx complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_scaled_indices_1d(i_iw, pixel_idx); + compute_output_encoded(data.read(), data_window, res, res_pack, outputs_ready, weights, + biases, pixel_idx); + } +} + +template +void conv_1d_buffer_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency); + (void)ce_reuse_factor; +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_output_buffer_1d(data.read(), res, weights, biases); + } +} + +template +void conv_1d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS inline region + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + conv_1d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + conv_1d_encoded_cl(data, res, weights, biases); + break; + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv2d.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d.h new file mode 100755 index 0000000000..01476a0449 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d.h @@ -0,0 +1,84 @@ + +#ifndef NNET_CONV2D_H_ +#define NNET_CONV2D_H_ + +#include "nnet_common.h" +#include "nnet_conv2d_latency.h" +#include "nnet_conv2d_resource.h" +#include + +namespace nnet { + +struct conv2d_config { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Convolutional parameters + static const unsigned pad_top = 0; + static const unsigned pad_bottom = 0; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_height = 10; + static const unsigned in_width = 10; + static const unsigned n_chan = 1; + static const unsigned filt_height = 1; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = 1; + static const unsigned stride_height = 1; + static const unsigned stride_width = 1; + static const unsigned out_height = 10; + static const unsigned out_width = 10; + static const unsigned dilation_height = 1; + static const unsigned dilation_width = 1; + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; // not used yet +}; + +template +void conv_2d_cf( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + if (CONFIG_T::strategy == nnet::latency) { + conv_2d_latency_cf(data, res, weights, biases); + } else { + conv_2d_resource_cf(data, res, weights, biases); + } +} + +template +void conv_2d_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + if (CONFIG_T::strategy == nnet::latency) { + conv_2d_latency_cl(data, res, weights, biases); + } else { + conv_2d_resource_cl(data, res, weights, biases); + } +} + +template +void pointwise_conv_2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_width == 1); + + if (CONFIG_T::strategy == nnet::latency) { + pointwise_conv_2d_latency_cl(data, res, weights, biases); + } else { + pointwise_conv_2d_resource_cl(data, res, weights, biases); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_latency.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_latency.h new file mode 100644 index 0000000000..29dd8ca633 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_latency.h @@ -0,0 +1,392 @@ +#ifndef NNET_CONV2D_LATENCY_H_ +#define NNET_CONV2D_LATENCY_H_ + +#include "nnet_common.h" +#include + +namespace nnet { + +// Computes multiplier limit +// This function should not be synthesized into firmware +template +int compute_multiplier_limit_conv2d(typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * + CONFIG_T::n_chan * CONFIG_T::n_filt]) { + int n_mult = 0; + + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + for (int fh = 0; fh < CONFIG_T::filt_height; fh++) { + for (int fw = 0; fw < CONFIG_T::filt_width; fw++) { + + int index_weight = fh * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt + + fw * CONFIG_T::n_chan * CONFIG_T::n_filt + cc * CONFIG_T::n_filt + ff; + + if ((oh * CONFIG_T::stride_height + fh) < CONFIG_T::pad_top || + (oh * CONFIG_T::stride_height + fh) >= (CONFIG_T::pad_top + CONFIG_T::in_height) || + (ow * CONFIG_T::stride_width + fw) < CONFIG_T::pad_left || + (ow * CONFIG_T::stride_width + fw) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + // padded - do nothing + continue; + } else { + if (weights[index_weight] > 1e-20 || weights[index_weight] < -1e-20) { + n_mult++; + } + } + + } // end mult loop + } // end channel loop + } // end filter width loop + } // end filter height loop + } // end output width loop + } // end output height loop + + // return ceil(float(n_mult) / float(CONFIG_T::reuse_factor)); + return (n_mult + CONFIG_T::reuse_factor - 1) / CONFIG_T::reuse_factor; + +} // end compute_n_mult + +template +void conv_2d_latency_cf( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + + typename CONFIG_T::accum_t mult[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * + CONFIG_T::filt_height * CONFIG_T::filt_width]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + const int multiplier_limit = compute_multiplier_limit_conv2d(weights); +//#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + +// Convolve, saving all multiplication results to accumulate later +ConvOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + ConvOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + ConvFiltHeight: + for (int fh = 0; fh < CONFIG_T::filt_height; fh++) { + ConvFiltWidth: + for (int fw = 0; fw < CONFIG_T::filt_width; fw++) { + + int index_mult = + oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * + CONFIG_T::filt_width + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + cc * CONFIG_T::filt_height * CONFIG_T::filt_width + fh * CONFIG_T::filt_width + fw; + + int index_weight = fh * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt + + fw * CONFIG_T::n_chan * CONFIG_T::n_filt + cc * CONFIG_T::n_filt + ff; + + if ((oh * CONFIG_T::stride_height + fh) < CONFIG_T::pad_top || + (oh * CONFIG_T::stride_height + fh) >= (CONFIG_T::pad_top + CONFIG_T::in_height) || + (ow * CONFIG_T::stride_width + fw) < CONFIG_T::pad_left || + (ow * CONFIG_T::stride_width + fw) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + int index_data = + cc * CONFIG_T::in_height * CONFIG_T::in_width + + (oh * CONFIG_T::stride_height + fh - CONFIG_T::pad_top) * CONFIG_T::in_width + + (ow * CONFIG_T::stride_width + fw - CONFIG_T::pad_left); + mult[index_mult] = data[index_data] * weights[index_weight]; + } + + } // end mult loop + } // end channel loop + } // end filter width loop + } // end filter height loop + } // end output width loop + } // end output height loop + + // Initialize accumulator with input biases + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + acc[oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff] = biases[ff]; + } + } + } + +// Accumulate multiplication result +AccumOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + AccumOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + AccumDotHeight: + for (int fh = 0; fh < CONFIG_T::filt_height; fh++) { + AccumDotWidth: + for (int fw = 0; fw < CONFIG_T::filt_width; fw++) { + + int index_mult = + oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * + CONFIG_T::filt_width + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + cc * CONFIG_T::filt_height * CONFIG_T::filt_width + fh * CONFIG_T::filt_width + fw; + int index_acc = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + + acc[index_acc] += mult[index_mult]; + + } // end dot product filter width loop + } // end dot product filter height loop + } // end n channel loop + } // end n filter loop + } // end output width loop + } // end output height loop + + // Cast to "res_t" type + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + int res_index = ff * CONFIG_T::out_height * CONFIG_T::out_width + oh * CONFIG_T::out_width + ow; + int acc_index = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + res[res_index] = acc[acc_index]; + } + } + } + +} // end conv2d + +template +void conv_2d_latency_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + + typename CONFIG_T::accum_t mult[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * + CONFIG_T::filt_height * CONFIG_T::filt_width]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + const int multiplier_limit = compute_multiplier_limit_conv2d(weights); +//#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + +// Convolve, saving all multiplication results to accumulate later +ConvOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + ConvOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + ConvFiltHeight: + for (int fh = 0; fh < CONFIG_T::filt_height; fh++) { + ConvFiltWidth: + for (int fw = 0; fw < CONFIG_T::filt_width; fw++) { + + int index_mult = + oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * + CONFIG_T::filt_width + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + cc * CONFIG_T::filt_height * CONFIG_T::filt_width + fh * CONFIG_T::filt_width + fw; + + int index_weight = fh * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt + + fw * CONFIG_T::n_chan * CONFIG_T::n_filt + cc * CONFIG_T::n_filt + ff; + + if ((oh * CONFIG_T::stride_height + fh) < CONFIG_T::pad_top || + (oh * CONFIG_T::stride_height + fh) >= (CONFIG_T::pad_top + CONFIG_T::in_height) || + (ow * CONFIG_T::stride_width + fw) < CONFIG_T::pad_left || + (ow * CONFIG_T::stride_width + fw) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + int index_data = (oh * CONFIG_T::stride_height + fh - CONFIG_T::pad_top) * + CONFIG_T::in_width * CONFIG_T::n_chan + + (ow * CONFIG_T::stride_width + fw - CONFIG_T::pad_left) * CONFIG_T::n_chan + + cc; + mult[index_mult] = data[index_data] * weights[index_weight]; + } + + } // end mult loop + } // end channel loop + } // end filter width loop + } // end filter height loop + } // end output width loop + } // end output height loop + + // Initialize accumulator with input biases + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + acc[oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff] = biases[ff]; + } + } + } + +// Accumulate multiplication result +AccumOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + AccumOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + AccumDotHeight: + for (int fh = 0; fh < CONFIG_T::filt_height; fh++) { + AccumDotWidth: + for (int fw = 0; fw < CONFIG_T::filt_width; fw++) { + + int index_mult = + oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * + CONFIG_T::filt_width + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + ff * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width + + cc * CONFIG_T::filt_height * CONFIG_T::filt_width + fh * CONFIG_T::filt_width + fw; + int index_acc = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + + acc[index_acc] += mult[index_mult]; + + } // end dot product filter width loop + } // end dot product filter height loop + } // end n channel loop + } // end n filter loop + } // end output width loop + } // end output height loop + + // Cast to "res_t" type + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + int index = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + res[index] = (res_T)(acc[index]); + } + } + } + +} // end conv2d + +template +void pointwise_conv_2d_latency_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + + typename CONFIG_T::accum_t mult[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan]; + typename CONFIG_T::accum_t acc[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=mult complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // Parallel mode + //#pragma HLS PIPELINE + //#pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + + // Limit multipliers to control parallelization + const int multiplier_limit = compute_multiplier_limit_conv2d(weights); +//#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + +// Convolve, saving all multiplication results to accumulate later +ConvOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + ConvOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + ConvFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + ConvChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + + int index_mult = oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + + int index_weight = cc * CONFIG_T::n_filt + ff; + + if ((oh * CONFIG_T::stride_height) < CONFIG_T::pad_top || + (oh * CONFIG_T::stride_height) >= (CONFIG_T::pad_top + CONFIG_T::in_height) || + (ow * CONFIG_T::stride_width) < CONFIG_T::pad_left || + (ow * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) { + mult[index_mult] = 0; + } else { + int index_data = + (oh * CONFIG_T::stride_height - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_chan + + (ow * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; + mult[index_mult] = data[index_data] * weights[index_weight]; + } + } + } + } + } + + // Initialize accumulator with input biases + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + acc[oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff] = biases[ff]; + } + } + } + +// Accumulate multiplication result +AccumOutHeight: + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + AccumOutWidth: + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + AccumFilt: + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Do "dot product" sum within filter and sum over channels + AccumChan: + for (int cc = 0; cc < CONFIG_T::n_chan; cc++) { + + int index_mult = oh * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan + + ow * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc; + int index_acc = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + + acc[index_acc] += mult[index_mult]; + } + } + } + } + + // Cast to "res_t" type + for (int oh = 0; oh < CONFIG_T::out_height; oh++) { + for (int ow = 0; ow < CONFIG_T::out_width; ow++) { + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + int index = oh * CONFIG_T::out_width * CONFIG_T::n_filt + ow * CONFIG_T::n_filt + ff; + res[index] = (res_T)(acc[index]); + } + } + } + +} // end conv2d + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_resource.h new file mode 100644 index 0000000000..c5e386b5e9 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_resource.h @@ -0,0 +1,275 @@ +#ifndef NNET_CONV2D_RESOURCE_H_ +#define NNET_CONV2D_RESOURCE_H_ + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet { + +template +void im2col_2d(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_height * + CONFIG_T::out_width]) { + const int output_h = (CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom - + (CONFIG_T::dilation_height * (CONFIG_T::filt_height - 1) + 1)) / + CONFIG_T::stride_height + + 1; + const int output_w = (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right - + (CONFIG_T::dilation_width * (CONFIG_T::filt_width - 1) + 1)) / + CONFIG_T::stride_width + + 1; + const int channel_size = CONFIG_T::in_height * CONFIG_T::in_width; + + for (int channel = CONFIG_T::n_chan; channel--; data += channel_size) { + for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height; + for (int output_rows = output_h; output_rows; output_rows--) { + if (input_row < 0 || input_row > CONFIG_T::in_height) { + for (int output_cols = output_w; output_cols; output_cols--) { + *(data_col++) = 0; + } + } else { + int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width; + for (int output_col = output_w; output_col; output_col--) { + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + *(data_col++) = data[input_row * CONFIG_T::in_width + input_col]; + } else { + *(data_col++) = 0; + } + input_col += CONFIG_T::stride_width; + } + } + input_row += CONFIG_T::stride_height; + } + } + } + } +} + +template +void conv_2d_full( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + data_T data_conv[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_height * + CONFIG_T::out_width]; + data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + ////#pragma HLS ARRAY_PARTITION variable=data_conv complete + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + + im2col_2d(data, data_conv); + + for (int i = 0; i < CONFIG_T::out_height * CONFIG_T::out_width; i++) { + for (int j = 0; j < CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; j++) { + data_col[j] = data[j * CONFIG_T::out_height * CONFIG_T::out_width + i]; + } + dense(data_col, res_col, weights, biases); + for (int j = 0; j < CONFIG_T::n_filt; j++) { + // res[i * CONFIG_T::n_filt + j] = res_col[j]; + res[j * CONFIG_T::out_height * CONFIG_T::out_width + i] = res_col[j]; // Transposed order + } + } +} + +template +void im2col_2d_cf(data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], + data_T data_col[CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width], const int row, + const int col) { + const int channel_size = CONFIG_T::in_height * CONFIG_T::in_width; + int index = 0; + for (int channel = CONFIG_T::n_chan; channel--; data += channel_size) { + for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { + int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + if (input_row < 0 || input_row > CONFIG_T::in_height) { + data_col[index++] = 0; + } else { + int input_col = + -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width; + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + //*(data_col++) = data[input_row * CONFIG_T::in_width + input_col]; + data_col[index++] = data[input_row * CONFIG_T::in_width + input_col]; + } else { + //*(data_col++) = 0; + data_col[index++] = 0; + } + input_col += CONFIG_T::stride_width; + } + } + input_row += CONFIG_T::stride_height; + } + } +} + +template +void conv_2d_resource_cf( + data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +HeightLoop: + for (int i = 0; i < CONFIG_T::out_height; i++) { + WidthLoop: + for (int j = 0; j < CONFIG_T::out_width; j++) { + //#pragma HLS PIPELINE + im2col_2d_cf(data, data_col, i, j); + dense(data_col, res_col, weights, biases); + FiltLoop: + for (int k = 0; k < CONFIG_T::n_filt; k++) { + // res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; + res[k * CONFIG_T::out_height * CONFIG_T::out_width + i * CONFIG_T::out_width + j] = + res_col[k]; // Transposed order + } + } + } +} + +template +void im2col_2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], const int row, + const int col) { + int index = 0; + for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { + int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; + for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + if (input_row < 0 || input_row >= CONFIG_T::in_height) { + data_col[index++] = 0; + } else { + int input_col = + -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width; + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + data_col[index++] = + data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel]; + } else { + data_col[index++] = 0; + } + } + } + } + } +} + +template +void im2col_2d_pointwise_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T data_col[CONFIG_T::n_chan], const int row, const int col) { + int index = 0; + int input_row = -CONFIG_T::pad_top + row * CONFIG_T::stride_height; + +ChannelLoop: + for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { + if (input_row < 0 || input_row >= CONFIG_T::in_height) { + data_col[index++] = 0; + } else { + int input_col = -CONFIG_T::pad_left + col * CONFIG_T::stride_width; + if (input_col >= 0 && input_col < CONFIG_T::in_width) { + data_col[index++] = + data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel]; + } else { + data_col[index++] = 0; + } + } + } +} + +template +void conv_2d_resource_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +HeightLoop: + for (int i = 0; i < CONFIG_T::out_height; i++) { + WidthLoop: + for (int j = 0; j < CONFIG_T::out_width; j++) { + //#pragma HLS PIPELINE + im2col_2d_cl(data, data_col, i, j); + dense(data_col, res_col, weights, biases); + FiltLoop: + for (int k = 0; k < CONFIG_T::n_filt; k++) { + res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; + } + } + } +} + +template +void pointwise_conv_2d_resource_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::filt_height == 1 && CONFIG_T::filt_width == 1); + + const int nin = CONFIG_T::n_chan; + const int nout = CONFIG_T::n_filt; + const int rufactor = CONFIG_T::reuse_factor; + const int block_factor = DIV_ROUNDUP(nin * nout, rufactor); + + ////#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose + /// correctly + ////#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + ////#pragma HLS ARRAY_PARTITION variable=biases complete + + data_T data_col[CONFIG_T::n_chan]; + res_T res_col[CONFIG_T::n_filt]; + + //#pragma HLS ARRAY_PARTITION variable=data_col complete + //#pragma HLS ARRAY_PARTITION variable=res_col complete + +HeightLoop: + for (int i = 0; i < CONFIG_T::out_height; i++) { + WidthLoop: + for (int j = 0; j < CONFIG_T::out_width; j++) { + //#pragma HLS PIPELINE + im2col_2d_pointwise_cl(data, data_col, i, j); + dense(data_col, res_col, weights, biases); + FiltLoop: + for (int k = 0; k < CONFIG_T::n_filt; k++) { + res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; + } + } + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_stream.h new file mode 100644 index 0000000000..7e76be12a9 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv2d_stream.h @@ -0,0 +1,117 @@ +#ifndef NNET_CONV2D_STREAM_H_ +#define NNET_CONV2D_STREAM_H_ + +#include "ac_channel.h" +#include "ap_shift_reg.h" +#include "nnet_common.h" +#include "nnet_conv_stream.h" + +namespace nnet { + +template +void compute_scaled_indices_2d(const unsigned h_idx, const unsigned w_idx, + ac_int *pixel_idx) { + const unsigned sh_idx = CONFIG_T::template scale_index_height::scale_index(h_idx); + unsigned wp_idx = w_idx * (data_T::size / CONFIG_T::n_chan); + +ComputeIndex: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_chan; p++) { + // #pragma HLS UNROLL + + unsigned sw_idx = CONFIG_T::template scale_index_width::scale_index(wp_idx + p); + pixel_idx[p] = CONFIG_T::pixels[sh_idx * CONFIG_T::min_width + sw_idx]; + } +} + +template +void conv_2d_encoded_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::filt_height == CONFIG_T::filt_width); + + ac_channel data_window[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + const int win_depth = CONFIG_T::filt_height * CONFIG_T::out_width; + for (unsigned i_out = 0; i_out < CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; i_out++) { + //#pragma HLS STREAM variable=data_window[i_out] depth=win_depth + } + + //#pragma HLS ARRAY_PARTITION variable=CONFIG_T::pixels complete + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + unsigned outputs_ready = 0; + + ac_int pixel_idx[data_T::size / CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=pixel_idx complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_scaled_indices_2d(i_ih, i_iw, pixel_idx); + compute_output_encoded(data.read(), data_window, res, res_pack, outputs_ready, weights, + biases, pixel_idx); + } + } +} + +// Line Buffer +template +void conv_2d_buffer_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + static ap_shift_reg line_buffer[MAX(CONFIG_T::filt_height - 1, 1)] + [CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency); + (void)ce_reuse_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + if (CONFIG_T::filt_height > 1) { + compute_output_buffer_2d(data.read(), line_buffer, res, weights, biases); + } else { + compute_output_buffer_1d(data.read(), res, weights, biases); + } + } + } +} + +template +void conv_2d_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS inline region + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + conv_2d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + conv_2d_encoded_cl(data, res, weights, biases); + break; + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_conv_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_conv_stream.h new file mode 100644 index 0000000000..4d92cbf69f --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_conv_stream.h @@ -0,0 +1,398 @@ +#ifndef NNET_CONV_STREAM_H_ +#define NNET_CONV_STREAM_H_ + +#include "ac_channel.h" +#include "ap_shift_reg.h" +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet { + +enum class conv_implementation { linebuffer = 0, encoded = 1 }; + +// ************************************************* +// Encoded Implementation (Vlad's) +// ************************************************* +template unsigned scale_index_K_gte_S(const unsigned idx) { + //#pragma HLS INLINE + + if (idx < K - S) { + return idx; + } + + constexpr unsigned nW = ((W - K) / S) * S + K; // Nearest W without unused pixels on the right + constexpr unsigned sW = (DIV_ROUNDUP(K, S) - 1) * S + K; // Scaled W that behaves like original W + if (idx >= nW) { + return sW; + } + + const unsigned r = nW - idx; + if (r <= K - S) { + return sW - r; + } + + return K - S + (idx - (K - S)) % S; +} + +template unsigned scale_index_K_lt_S(const unsigned idx) { + //#pragma HLS INLINE + + if (idx < S - K) { + return idx; + } + + constexpr unsigned nW = ((W - K) / S) * S + K; // Nearest W without unused pixels on the right + constexpr unsigned sW = (DIV_ROUNDUP(S, K) - 1) * S + K; // Scaled W that behaves like original W + if (idx >= nW) { + return sW; + } + + const unsigned r = nW - idx; + if (r <= S - K) { + return sW - r; + } + + return S - K + (idx - (S - K)) % S; +} + +template class scale_index_regular { + public: + static unsigned scale_index(const unsigned idx) { + // #pragma HLS INLINE + + if (K >= S) { + return scale_index_K_gte_S(idx); + } else { + return scale_index_K_lt_S(idx); + } + } +}; + +template class scale_index_unscaled { + public: + static unsigned scale_index(const unsigned idx) { + // #pragma HLS INLINE + return idx; + } +}; + +template +void mult_buffer(ac_channel data_window[CONFIG_T::kernel_size * CONFIG_T::n_chan], + res_T &res_pack, ac_channel &res_stream, unsigned &outputs_ready, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE + + typename data_T::value_type data[CONFIG_T::kernel_size * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=data complete + typename res_T::value_type res[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=res complete + +InitData: + for (unsigned int id = 0; id < CONFIG_T::kernel_size * CONFIG_T::n_chan; id++) { + // #pragma HLS UNROLL + data[id] = data_window[id].read(); + } + + //#pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + data, res, weights, biases); + } else { + dense_resource( + data, res, weights, biases); + } + +CastLoop: + for (unsigned jj = 0; jj < CONFIG_T::n_filt; jj++) { + // #pragma HLS UNROLL + if (res_T::size / CONFIG_T::n_filt == 1) { + res_pack[jj] = res[jj]; + } else { + res_pack[outputs_ready * CONFIG_T::n_filt + jj] = res[jj]; + } + } + + if (res_T::size / CONFIG_T::n_filt == 1) { + res_stream.write(res_pack); + } else { + if (outputs_ready == (res_T::size / CONFIG_T::n_filt) - 1) { + res_stream.write(res_pack); + outputs_ready = 0; + } else { + outputs_ready++; + } + } +} + +template +void compute_output_encoded(const data_T &in_elem, + ac_channel data_window[CONFIG_T::kernel_size * CONFIG_T::n_chan], + ac_channel &res, res_T &res_pack, unsigned &outputs_ready, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt], + ac_int *pixel_idx) { + //#pragma HLS INLINE + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +MultLoop: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_chan; p++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + CopyDataFilt: + for (unsigned f = 0; f < CONFIG_T::kernel_size; f++) { + // #pragma HLS UNROLL + CopyDataChan: + for (unsigned c = 0; c < CONFIG_T::n_chan; c++) { + // #pragma HLS UNROLL + if (pixel_idx[p][f]) + data_window[f * CONFIG_T::n_chan + c].write(in_elem[p * CONFIG_T::n_chan + c]); + } + } + if (pixel_idx[p][CONFIG_T::kernel_size - 1]) { + mult_buffer(data_window, res_pack, res, outputs_ready, weights, biases); + } + } +} + +// ************************************************* +// Line Buffer Implementation (Phil's) +// ************************************************* +template +void kernel_shift_1d(const data_T &in_elem, + typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan]) { + //#pragma HLS inline + //#pragma HLS PIPELINE II = 1 + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::filt_width - 1; +KernelShiftWidth: + for (int i_iw = 0; i_iw < filt_width; i_iw++) { + // #pragma HLS PIPELINE II = 1 + KernelShiftChannel: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // #pragma HLS UNROLL + // Shift every element in kernel_window to the left + kernel_window[i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[(i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::filt_width - 1) * CONFIG_T::n_chan; +KernelPushChannel: + for (unsigned int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // #pragma HLS UNROLL + kernel_window[lastheight + i_ic] = in_elem[i_ic]; + } +} + +template +void kernel_shift_2d( + typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::filt_height * CONFIG_T::n_chan]) { + //#pragma HLS inline + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::filt_width - 1; +KernelShiftWidth: + for (int i_iw = 0; i_iw < filt_width; i_iw++) { + //#pragma HLS PIPELINE II = 1 + KernelShiftHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::filt_height; i_ih++) { + KernelShiftChannel: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift every element in kernel_window to the left + kernel_window[i_ih * CONFIG_T::filt_width * CONFIG_T::n_chan + i_iw * CONFIG_T::n_chan + i_ic] = + kernel_window[i_ih * CONFIG_T::filt_width * CONFIG_T::n_chan + (i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::filt_width - 1) * CONFIG_T::n_chan; +KernelPushHeight: + for (unsigned int i_ih = 0; i_ih < CONFIG_T::filt_height; i_ih++) { + KernelPushChannel: + for (unsigned int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + kernel_window[lastheight + i_ih * CONFIG_T::filt_width * CONFIG_T::n_chan + i_ic] = shift_buffer[i_ih][i_ic]; + } + } +} + +template +void shift_line_buffer( + const data_T &in_elem, + ap_shift_reg line_buffer[MAX(CONFIG_T::filt_height - 1, 1)] + [CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]) { + + //#pragma HLS PIPELINE + + // Temporary buffer for popped (shifted) elements + typename data_T::value_type shift_buffer[CONFIG_T::filt_height][CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable = shift_buffer complete dim = 0 + +UpdateBuffer: + for (unsigned int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // #pragma HLS UNROLL + + // Insert pixel(s) at end of shift buffer + shift_buffer[CONFIG_T::filt_height - 1][i_ic] = in_elem[i_ic]; + } + +LineBufferDataIn: + for (unsigned int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift the shift buffer into the line buffer + LineBufferShift: + for (unsigned i_ih = 1; i_ih < CONFIG_T::filt_height; i_ih++) { + // #pragma HLS UNROLL + typename data_T::value_type pop_elem = line_buffer[i_ih - 1][i_ic].shift( + shift_buffer[CONFIG_T::filt_height - i_ih][i_ic]); // Shift the line buffer, return the popped pixel + shift_buffer[CONFIG_T::filt_height - i_ih - 1][i_ic] = + pop_elem; // Popped element placed back into shift_buffer, one row up. + } + } + kernel_shift_2d(shift_buffer, kernel_window); +} + +template +void compute_output_buffer_2d( + const data_T &in_elem, + ap_shift_reg line_buffer[MAX(CONFIG_T::filt_height - 1, 1)] + [CONFIG_T::n_chan], + ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::filt_width - 1; + const static int lShiftY = CONFIG_T::filt_height - 1; + + // Counters + static int pX = 0; // Pixel X + static int pY = 0; // Pixel Y + + static int sX = 0; // Stride X + static int sY = 0; // Stride Y + + static typename data_T::value_type kernel_data[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + + // Add pixel to buffer + nnet::shift_line_buffer(in_elem, line_buffer, kernel_data); + + // Check to see if we have a full kernel + if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > lShiftY - 1 && pX > lShiftX - 1) { + + // Dense multiply + //#pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights, biases); + } else { + dense_resource( + kernel_data, res_out, weights, biases); + } + + // Pack output + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + // #pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + + // Write output to stream when output ready + res_stream.write(res_pack); + } + + // Counter Housekeeping + if (pX + 1 == CONFIG_T::in_width) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + if (pY + 1 == CONFIG_T::in_height) { // Reached bottom of image + pY = 0; + sY = 0; + } else { + pY = pY + 1; + // Update stride (threshold) ? subtract stride : increment stride + sY = ((sY - lShiftY) == 0) ? sY - CONFIG_T::stride_height + 1 : sY + 1; + } + } else { + pX = pX + 1; + // Update stride (threshold) ? subtract stride : increment stride + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +// Conv 1D compute output +template +void compute_output_buffer_1d( + const data_T &in_elem, ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::filt_width - 1; + + // Counters + static int pX = 0; // pixel counter + static int sX = 0; // stride counter + + static typename data_T::value_type kernel_data[CONFIG_T::filt_width * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + + // Add pixel to buffer + nnet::kernel_shift_1d(in_elem, kernel_data); + + // Check to see if we have a full kernel + if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { + + // Dense multiply + //#pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights, biases); + } else { + dense_resource( + kernel_data, res_out, weights, biases); + } + + // Pack output + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + // #pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + + // Write output to stream when output ready + res_stream.write(res_pack); + } + + // Counter Housekeeping + if (pX + 1 == CONFIG_T::in_width) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + } else { + pX = pX + 1; + // Update stride (threshold) ? subtract stride : increment stride + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_dense.h b/hls4ml/templates/catapult/nnet_utils/nnet_dense.h new file mode 100644 index 0000000000..64b927cc64 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_dense.h @@ -0,0 +1,49 @@ +#ifndef NNET_DENSE_H_ +#define NNET_DENSE_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_dense_latency.h" +#include "nnet_dense_resource.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" +#include + +namespace nnet { + +struct dense_config { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Layer Sizes + static const unsigned n_in = 10; + static const unsigned n_out = 10; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned strategy = latency; + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; + // partitioning arrays cyclically to go with roll factors? + // Product function to use + template using product = nnet::product::mult; +}; + +template +void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + //#pragma HLS inline + if (CONFIG_T::strategy == nnet::latency) { + dense_latency(data, res, weights, biases); + } else { + dense_resource(data, res, weights, biases); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_dense_compressed.h b/hls4ml/templates/catapult/nnet_utils/nnet_dense_compressed.h new file mode 100644 index 0000000000..f3f27b6db8 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_dense_compressed.h @@ -0,0 +1,106 @@ +// +// hls4ml: Vivado HLS code for neural-net building blocks +// +// Copyright (C) 2018 Giuseppe Di Guglielmo +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// + +#ifndef NNET_COMPRESSED_LAYER_H_ +#define NNET_COMPRESSED_LAYER_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_dense.h" +#include + +namespace nnet { + +template +void fill_mult(typename CONFIG_T::index_t index, typename CONFIG_T::accum_t mult[CONFIG_T::n_out], + typename CONFIG_T::accum_t weight) { + for (unsigned k = 0; k < CONFIG_T::n_out; k++) { + // #pragma HLS UNROLL + if (k == index) + mult[k] += weight; + } +} + +template +void dense_compressed(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_nonzeros], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_nonzeros, CONFIG_T::reuse_factor); + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=acc complete + //#pragma HLS ARRAY_PARTITION variable=biases complete + //#pragma HLS ARRAY_RESHAPE variable=weights block factor=multiplier_limit + // if (CONFIG_T::store_weights_in_bram){ + ////#pragma HLS RESOURCE variable=weights core=ROM_1P_BRAM + //#pragma HLS data_pack variable=weights struct_level + //} + +InitAccum: + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + acc[i] = (typename CONFIG_T::accum_t)(biases[i]); + } + + // Do the compressed matrix-multiply + const int rufactor = CONFIG_T::reuse_factor; +ReuseLoop: + for (unsigned ir = 0; ir < rufactor; ir++) { + //#pragma HLS PIPELINE II=1 rewind + + typename CONFIG_T::accum_t mult[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=mult complete + + ResetMult: + for (int imult = 0; imult < CONFIG_T::n_out; imult++) { + // #pragma HLS UNROLL + mult[imult] = 0; + } + + CompressedMultLoop: + for (unsigned im = 0; im < multiplier_limit; im++) { + // #pragma HLS UNROLL + unsigned w = im * rufactor + ir; + auto row = weights[w].row_index; + auto col = weights[w].col_index; + auto weight_cache = weights[w].weight; + data_T data_cache = data[row]; + // mult[col] += weight_cache * data_cache; + typename CONFIG_T::accum_t prod = + CONFIG_T::template product::product(data_cache, weight_cache); + fill_mult(col, mult, prod); + } + + for (int im = 0; im < CONFIG_T::n_out; im++) { + acc[im] += mult[im]; + } + } + +// Cast to "res_t" type +ResultLoop: + for (unsigned i = 0; i < CONFIG_T::n_out; i++) { + // #pragma HLS UNROLL + // res[i] = (res_T) (acc[i]); + res[i] = cast(acc[i]); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_dense_latency.h b/hls4ml/templates/catapult/nnet_utils/nnet_dense_latency.h new file mode 100644 index 0000000000..40e5cd2b9d --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_dense_latency.h @@ -0,0 +1,92 @@ + +#ifndef NNET_DENSE_LATENCY_H_ +#define NNET_DENSE_LATENCY_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_helpers.h" +#include "nnet_mult.h" +#include + +namespace nnet { + +template +void dense_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + // Partial unroll config + constexpr int prod1_unroll = + (ce_reuse_factor < CONFIG_T::n_in) ? CONFIG_T::n_in : (int)(CONFIG_T::n_in * CONFIG_T::n_out) / ce_reuse_factor; + constexpr int prod2_unroll = (int)CONFIG_T::n_out / ce_reuse_factor; + + (void)ce_reuse_factor; // to silence compiler warnings + (void)prod1_unroll; + (void)prod2_unroll; + + // For Catapult, add an extra scope so that we can apply the pipeline pragma as if it applied to the function + do { + data_T cache; + typename CONFIG_T::accum_t mult[CONFIG_T::n_in * CONFIG_T::n_out]; + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + //#pragma HLS function_instantiate variable=weights,biases + + // For parallel inputs: + // - completely partition arrays -- target fabric + // - if we have an unroll factor, limit number of multipliers + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + // //#pragma HLS ARRAY_PARTITION variable=weights complete // remove this line for now, it breaks compression + // sometimes + //#pragma HLS ARRAY_PARTITION variable=biases complete + //#pragma HLS ARRAY_PARTITION variable=mult complete + //#pragma HLS ARRAY_PARTITION variable=acc complete + + // int multiplier_limit = ceil(float(CONFIG_T::n_in*CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)) - + // floor(float(CONFIG_T::n_zeros) / float(CONFIG_T::reuse_factor)); + constexpr int multiplier_limit = + ((CONFIG_T::n_in * CONFIG_T::n_out) / CONFIG_T::reuse_factor) - CONFIG_T::n_zeros / CONFIG_T::reuse_factor; + CONFIG_T::template product::limit(multiplier_limit); + + // Do the matrix-multiply + Product1: + for (unsigned int ii = 0; ii < CONFIG_T::n_in; ii++) { + cache = data[ii]; + Product2: + for (unsigned int jj = 0; jj < CONFIG_T::n_out; jj++) { + int index = ii * CONFIG_T::n_out + jj; + mult[index] = + CONFIG_T::template product::product(cache, weights[index]); + } + } + + // Initialize accumulator with input biases + ResetAccum: + for (unsigned int iacc = 0; iacc < CONFIG_T::n_out; iacc++) { + acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc]; + } + + // Accumulate multiplication result + Accum1: + for (unsigned int ii = 0; ii < CONFIG_T::n_in; ii++) { + Accum2: + for (unsigned int jj = 0; jj < CONFIG_T::n_out; jj++) { + int index = ii * CONFIG_T::n_out + jj; + acc[jj] += mult[index]; + } + } + + // Cast to "res_t" type + Result: + for (unsigned int ires = 0; ires < CONFIG_T::n_out; ires++) { + // res[ires] = (res_T) (acc[ires]); + res[ires] = cast(acc[ires]); + } + } while (false); // one iteration loop +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/catapult/nnet_utils/nnet_dense_resource.h new file mode 100644 index 0000000000..5bcd1a54b7 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_dense_resource.h @@ -0,0 +1,262 @@ + +#ifndef NNET_DENSE_RESOURCE_H_ +#define NNET_DENSE_RESOURCE_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_mult.h" +#include +#include + +namespace nnet { + +template +void dense_resource_rf_leq_nin(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = CONFIG_T::reuse_factor; + const int multfactor = MIN(CONFIG_T::n_in, CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int multscale = multiplier_limit / CONFIG_T::n_out; + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((multiplier_limit == block_factor) && "This function is correct only for RF <= N_IN"); + + //#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + //#pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=acc complete + +InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + //#pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc]; + } + +ReuseLoop: + for (int ir = 0; ir < rufactor; ir++) { + //#pragma HLS PIPELINE II=1 rewind + + int w_index = ir; + int in_index = ir; + int out_index = 0; + int acc_step = 0; + + MultLoop: + for (int im = 0; im < block_factor; im++) { + //#pragma HLS UNROLL + + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); + + // Increment w_index + w_index += rufactor; + // Increment in_index + in_index += rufactor; + if (in_index >= nin) { + in_index = ir; + } + // Increment out_index + if (acc_step + 1 >= multscale) { + acc_step = 0; + out_index++; + } else { + acc_step++; + } + } + } + +// Cast to "res_t" type +Result: + for (unsigned int ires = 0; ires < CONFIG_T::n_out; ires++) { + //#pragma HLS UNROLL + res[ires] = cast(acc[ires]); + } +} + +template +void dense_resource_rf_gt_nin_rem0(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = MIN(CONFIG_T::reuse_factor, CONFIG_T::n_in * CONFIG_T::n_out); + const int multfactor = MIN(CONFIG_T::n_in, CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((rufactor > nin && rufactor % nin == 0) && "This function is correct only for RF > N_IN && RF % N_IN == 0"); + + //#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + //#pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=acc complete + +InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + //#pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc]; + } + + unsigned int w_index; + int in_index = 0; + int out_index; + int outstep = 0; + const int outscale = rufactor / nin; + + int outidx[rufactor]; +IndexLoop: + for (int ir = 0; ir < rufactor; ir++) { + outidx[ir] = outstep; + if ((ir + 1) % nin == 0) { + outstep++; + } + } + +ReuseLoop: + for (unsigned int ir = 0; ir < rufactor; ir++) { + //#pragma HLS PIPELINE II=1 rewind + + w_index = ir; + out_index = outidx[ir] /*outstep*/; + + MultLoop: + for (unsigned int im = 0; im < block_factor; im++) { + //#pragma HLS UNROLL + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); + + w_index += rufactor; + if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) + break; // check out of bounds + out_index += outscale; + } + + in_index++; + if (in_index >= nin) { + in_index = 0; + // outstep++; // This causes a huge increase in scheduling and RTL generation times, hence the above workaround. + } + } + +// Cast to "res_t" type +Result: + for (unsigned int ires = 0; ires < CONFIG_T::n_out; ires++) { + //#pragma HLS UNROLL + res[ires] = cast(acc[ires]); + } +} + +template +void dense_resource_rf_gt_nin(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = CONFIG_T::reuse_factor; + const int multfactor = MIN(CONFIG_T::n_in, CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((rufactor > nin) && "This function is correct only for RF > N_IN"); + + //#pragma HLS function_instantiate variable=weights,biases + ////#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + //#pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=acc complete + +InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + //#pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc]; + } + +ReuseLoop: + for (int ir = 0; ir < rufactor; ir++) { + //#pragma HLS PIPELINE II=1 rewind + typename CONFIG_T::accum_t tmpmult[block_factor]; + //#pragma HLS ARRAY_PARTITION variable=tmpmult complete + + MultLoop: + for (int im = 0; im < block_factor; im++) { + //#pragma HLS UNROLL + unsigned int w_index = ir + rufactor * im; + int in_index = w_index % nin; + if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) + continue; // check out of bounds + tmpmult[im] = + CONFIG_T::template product::product(data[in_index], weights[w_index]); + } + + typename CONFIG_T::accum_t mult[multiplier_limit]; + //#pragma HLS ARRAY_PARTITION variable=mult complete + + ResetMult: + for (int imult = 0; imult < multiplier_limit; imult++) { + //#pragma HLS UNROLL + mult[imult] = 0; + } + + AccumLoop1: + for (int im = 0; im < block_factor; im++) { + //#pragma HLS UNROLL + int w_index = ir + rufactor * im; + int out_index = w_index / multfactor; + if (out_index >= multiplier_limit) + continue; // check out of bounds + mult[out_index] += tmpmult[im]; + } + + AccumLoop2: + for (int im = 0; im < multiplier_limit; im++) { + //#pragma HLS UNROLL + // int out_index = im/multscale; // This is the general case + // acc[out_index] += mult[im]; + acc[im] += mult[im]; // If RF > N_IN then multiplier_limit == n_out + } + } + +// Cast to "res_t" type +Result: + for (unsigned int ires = 0; ires < CONFIG_T::n_out; ires++) { + //#pragma HLS UNROLL + res[ires] = cast(acc[ires]); + } +} + +template +void dense_resource(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + //#pragma HLS INLINE region + + if (CONFIG_T::reuse_factor <= CONFIG_T::n_in) { + dense_resource_rf_leq_nin(data, res, weights, biases); + } else if (CONFIG_T::reuse_factor % CONFIG_T::n_in == 0) { + dense_resource_rf_gt_nin_rem0(data, res, weights, biases); + } else { + dense_resource_rf_gt_nin(data, res, weights, biases); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_dense_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_dense_stream.h new file mode 100644 index 0000000000..665d2f43f3 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_dense_stream.h @@ -0,0 +1,72 @@ +#ifndef NNET_DENSE_STREAM_H_ +#define NNET_DENSE_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_types.h" +#include +#include + +namespace nnet { + +template +void dense_wrapper(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + //#pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + dense_latency(data, res, weights, biases); + } else { + dense_resource(data, res, weights, biases); + } +} + +template +void dense(ac_channel &data_stream, ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + typename data_T::value_type data[CONFIG_T::n_in]; + //#pragma HLS ARRAY_PARTITION variable=data complete + + typename res_T::value_type res[CONFIG_T::n_out]; + //#pragma HLS ARRAY_PARTITION variable=res complete + + if ((CONFIG_T::n_in / data_T::size) > 1) { + } +DataPrepare: + for (unsigned int i_in = 0; i_in < CONFIG_T::n_in / data_T::size; i_in++) { + if (CONFIG_T::n_in / data_T::size > 1) { + //#pragma HLS PIPELINE + } + data_T data_pack = data_stream.read(); + DataPack: + for (unsigned int i_pack = 0; i_pack < data_T::size; i_pack++) { + //#pragma HLS UNROLL + data[i_in * data_T::size + i_pack] = data_pack[i_pack]; + } + } + + dense_wrapper(data, res, weights, biases); + + if ((CONFIG_T::n_out / res_T::size) > 1) { + } +ResWrite: + for (unsigned i_out = 0; i_out < CONFIG_T::n_out / res_T::size; i_out++) { + if (CONFIG_T::n_out / res_T::size > 1) { + //#pragma HLS PIPELINE + } + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + ResPack: + for (unsigned int i_pack = 0; i_pack < res_T::size; i_pack++) { + //#pragma HLS UNROLL + res_pack[i_pack] = res[i_out * res_T::size + i_pack]; + } + res_stream.write(res_pack); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_embed.h b/hls4ml/templates/catapult/nnet_utils/nnet_embed.h new file mode 100644 index 0000000000..4cdf507f9d --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_embed.h @@ -0,0 +1,47 @@ +#ifndef NNET_EMBED_H_ +#define NNET_EMBED_H_ + +#include "nnet_common.h" +#include "nnet_helpers.h" + +namespace nnet { + +struct embed_config { + // Internal data type definitions + typedef float embeddings_t; + + // Layer Sizes + static const unsigned n_in = 10; + static const unsigned n_out = 16; + static const unsigned vocab_size = 50; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; +}; + +template +void embedding(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in * CONFIG_T::n_out], + typename CONFIG_T::embeddings_t embeddings[CONFIG_T::vocab_size * CONFIG_T::n_out]) { + + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + // This can save a few cycles, but it will create a large multiplexer due to + // non-constant access pattern, so let's leave it out + ////#pragma HLS ARRAY_PARTITION variable=embeddings complete + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +InputSequence: + for (int j = 0; j < CONFIG_T::n_in; j++) { + // #pragma HLS UNROLL + DenseEmbedding: + for (int i = 0; i < CONFIG_T::n_out; i++) { + // #pragma HLS UNROLL + res[j * CONFIG_T::n_out + i] = embeddings[data[j] * CONFIG_T::n_out + i]; + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_embed_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_embed_stream.h new file mode 100644 index 0000000000..1378100879 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_embed_stream.h @@ -0,0 +1,34 @@ +#ifndef NNET_EMBED_STREAM_H_ +#define NNET_EMBED_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_helpers.h" + +namespace nnet { + +template +void embedding(ac_channel &data, ac_channel &res, + typename CONFIG_T::embeddings_t embeddings[CONFIG_T::vocab_size * CONFIG_T::n_out]) { + data_T in_data = data.read(); + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +InputSequence: + for (int j = 0; j < data_T::size; j++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + + DenseEmbedding: + for (int i = 0; i < CONFIG_T::n_out; i++) { + // #pragma HLS UNROLL + res_pack[i] = embeddings[in_data[j] * CONFIG_T::n_out + i]; + } + res.write(res_pack); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_garnet.h b/hls4ml/templates/catapult/nnet_utils/nnet_garnet.h new file mode 100644 index 0000000000..7451110fba --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_garnet.h @@ -0,0 +1,816 @@ + +#ifndef NNET_GARNET_H_ +#define NNET_GARNET_H_ + +#include "ac_channel.h" +#include "hls_math.h" +#include "nnet_common.h" + +namespace nnet { +namespace garnet_utils { + +template +inline typename std::enable_if::value>::type +initialize_edge_weights_table(typename CONFIG_T::edge_weight_t edge_weights_table[]) { + typedef ac_int index_t; + + unsigned const table_size = (1 << CONFIG_T::distance_width); + + index_t index; + typename CONFIG_T::distance_t distance; + + // edge_weight_t is ap_ufixed with 0 iwidth -> let index 0 be a saturated version of 1 + edge_weights_table[0] = ac_fixed(1.); + + for (unsigned iw = 1; iw < table_size; ++iw) { + index = iw; + distance.range(CONFIG_T::distance_width - 1, 0) = index.range(CONFIG_T::distance_width - 1, 0); + edge_weights_table[iw] = hls::exp(-distance * distance); + } +} + +template +inline typename std::enable_if::value>::type +initialize_edge_weights_table(typename CONFIG_T::edge_weight_t edge_weights_table[]) { + unsigned const table_size = (1 << CONFIG_T::distance_width); + double const step = 64. / table_size; + + typename CONFIG_T::distance_t v = -32.; + for (unsigned iw = 0; iw < table_size; ++iw) { +#ifdef __SYNTHESIS__ + // hack for now to get through the flow + edge_weights_table[iw] = (-v * v); +#else + edge_weights_table[iw] = std::exp(-v * v); +#endif + v += step; + } +} + +template +inline typename std::enable_if::value, typename CONFIG_T::edge_weight_t>::type +get_edge_weight(typename CONFIG_T::distance_t distance, typename CONFIG_T::edge_weight_t edge_weights_table[]) { + typedef ac_int index_t; + + index_t index(distance.range(CONFIG_T::distance_width - 1, 0)); + + return edge_weights_table[index]; +} + +template +inline + typename std::enable_if::value, typename CONFIG_T::edge_weight_t>::type + get_edge_weight(typename CONFIG_T::distance_t distance, typename CONFIG_T::edge_weight_t edge_weights_table[]) { + unsigned const table_size = (1 << CONFIG_T::distance_width); + double const step = 64. / table_size; + + int index = (distance + 32.) / step; + if (index < 0) + index = 0; + else if (index >= table_size) + index = table_size - 1; + + return edge_weights_table[index]; +} + +template typename CONFIG_T::edge_weight_t compute_edge_weight(typename CONFIG_T::distance_t distance) { + if (CONFIG_T::is_stack) { + //#pragma HLS INLINE OFF + } +#ifdef __SYNTHESIS__ + typename CONFIG_T::edge_weight_t edge_weights_table[1 << CONFIG_T::distance_width]; + // unsigned const reshape_factor = CONFIG_T::n_aggregators * CONFIG_T::n_in_features * (CONFIG_T::n_vertices / + // CONFIG_T::reuse_factor); + // //#pragma HLS ARRAY_RESHAPE variable=edge_weights_table cyclic factor=reshape_factor dim=1 + bool initialized = false; +#else + static typename CONFIG_T::edge_weight_t edge_weights_table[1 << CONFIG_T::distance_width]; + static bool initialized = false; +#endif + if (not initialized) { + initialize_edge_weights_table(edge_weights_table); + initialized = true; + } + + return get_edge_weight(distance, edge_weights_table); +} + +template +inline typename std::enable_if::value, dividend_T>::type normalize_log2(dividend_T dividend, + exponent_T exponent) { + //#pragma HLS INLINE + return dividend >> exponent; +} + +template +inline typename std::enable_if::value, dividend_T>::type normalize_log2(dividend_T dividend, + exponent_T exponent) { + //#pragma HLS INLINE + return dividend / std::pow(2., exponent); +} + +template struct Means { + typedef E edge_weight_t; + + edge_weight_t edge_weight_mean[CONFIG_T::n_aggregators]; + typename CONFIG_T::aggr_t weighted_feature_mean[CONFIG_T::n_aggregators * CONFIG_T::n_in_features]; + + Means() { + //#pragma HLS INLINE + //#pragma HLS ARRAY_PARTITION variable=edge_weight_mean complete + //#pragma HLS ARRAY_PARTITION variable=weighted_feature_mean complete + + Aggregators: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + edge_weight_mean[ia] = 0.; + + InFeatures: + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + weighted_feature_mean[iax] = 0.; + } + } + } + + void set_weight(unsigned, edge_weight_t const &) { + //#pragma HLS INLINE + } + + void add_means_normalized(Means const &local) { + //#pragma HLS INLINE + // Always called within a pipelined region - no UNROLL needed + + unsigned const log2_unroll_factor = CONFIG_T::n_vertices_width - CONFIG_T::log2_reuse_factor; + + Aggregators: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + edge_weight_mean[ia] += normalize_log2(local.edge_weight_mean[ia], log2_unroll_factor); + + InFeatures: + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + weighted_feature_mean[iax] += normalize_log2(local.weighted_feature_mean[iax], log2_unroll_factor); + } + } + } + + template + typename std::enable_if::type set_means_normalized(nvtx_T const nvtx, arrays_T const &accum) { + //#pragma HLS INLINE + + // accum comes divided by unroll factor + typename T::norm_t nvtx_norm = (T::n_vertices / T::reuse_factor) / nvtx; + + Aggregators: + for (unsigned ia = 0; ia < T::n_aggregators; ++ia) { + edge_weight_mean[ia] = accum.edge_weight_mean[ia] * nvtx_norm; + + InFeatures: + for (unsigned ix = 0; ix < T::n_in_features; ++ix) { + unsigned const iax = ia * T::n_in_features + ix; + + weighted_feature_mean[iax] = accum.weighted_feature_mean[iax] * nvtx_norm; + } + } + } + + template + typename std::enable_if::type set_means_normalized(nvtx_T const nvtx, arrays_T const &accum) { + //#pragma HLS INLINE + + Aggregators: + for (unsigned ia = 0; ia < T::n_aggregators; ++ia) { + + edge_weight_mean[ia] = normalize_log2(accum.edge_weight_mean[ia], T::log2_reuse_factor); + + InFeatures: + for (unsigned ix = 0; ix < T::n_in_features; ++ix) { + unsigned const iax = ia * T::n_in_features + ix; + + weighted_feature_mean[iax] = normalize_log2(accum.weighted_feature_mean[iax], T::log2_reuse_factor); + } + } + } +}; + +template struct WeightsAndMeans : public Means { + typedef E edge_weight_t; + + edge_weight_t edge_weights[CONFIG_T::n_vertices * CONFIG_T::n_aggregators]; + + WeightsAndMeans() : Means() { + //#pragma HLS INLINE + unsigned const reshape_factor = CONFIG_T::n_aggregators * (CONFIG_T::n_vertices / CONFIG_T::reuse_factor); + //#pragma HLS ARRAY_PARTITION variable=edge_weights cyclic factor=reshape_factor + } + + void set_weight(unsigned iva, edge_weight_t const &weight) { + //#pragma HLS INLINE + edge_weights[iva] = weight; + } +}; + +template struct OutputBiasNormalizer; + +template +struct OutputBiasNormalizer::type> { + typedef typename CONFIG_T::output_transform_biases_t biases_t; + + biases_t const (&output_biases)[CONFIG_T::n_out_features]; + + OutputBiasNormalizer(nvtx_T const) : output_biases{CONFIG_T::output_transform_biases} { + //#pragma HLS INLINE + } +}; + +template +struct OutputBiasNormalizer::type> { + typedef typename CONFIG_T::output_transform_biases_t biases_t; + + biases_t output_biases[CONFIG_T::n_out_features]; + + OutputBiasNormalizer(nvtx_T const nvtx) { + //#pragma HLS ARRAY_PARTITION variable=output_biases complete + + // Cannot add a loop label here due to a Vivado HLS bug, apparently + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + typename CONFIG_T::aggr_t bias = CONFIG_T::output_transform_biases[io]; + bias *= nvtx; + output_biases[io] = normalize_log2(bias, CONFIG_T::n_vertices_width); + } + } +}; + +template struct InputDataGetter { + typedef data_T data_t; + + data_T const *dataref; + + InputDataGetter(data_T const *d) : dataref{d} { + //#pragma HLS INLINE + } + data_T const &get(unsigned iv, unsigned ix) const { + //#pragma HLS INLINE + unsigned const ivx = iv * CONFIG_T::n_in_features + ix; + return dataref[ivx]; + } +}; + +template struct SingleVertexDataGetter { + typedef data_T data_t; + + data_T const (&dataref)[CONFIG_T::n_in_features]; + + SingleVertexDataGetter(data_T const (&d)[CONFIG_T::n_in_features]) : dataref{d} { + //#pragma HLS INLINE + } + data_T const &get(unsigned, unsigned ix) const { + //#pragma HLS INLINE + return dataref[ix]; + } +}; + +template struct OutputResSetter { + typedef res_T res_t; + + res_T *resref; + + OutputResSetter(res_T *r) : resref{r} { + //#pragma HLS INLINE + } + void set(unsigned iv, unsigned io, res_T const &acc) { + //#pragma HLS INLINE + unsigned const ivo = iv * CONFIG_T::n_out_features + io; + resref[ivo] = acc; + } +}; + +template struct SingleVertexResSetter { + typedef res_T res_t; + + res_T (&resref)[CONFIG_T::n_out_features]; + + SingleVertexResSetter(res_T (&r)[CONFIG_T::n_out_features]) : resref{r} { + //#pragma HLS INLINE + } + void set(unsigned, unsigned io, res_T const &acc) { + //#pragma HLS INLINE + resref[io] = acc; + } +}; + +template +inline void compute_weights_aggregates(data_getter_T const &data_getter, unsigned iv, arrays_local_T &arrays_local, + arrays_T &arrays) { + //#pragma HLS INLINE + +Aggregators: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + typename CONFIG_T::distance_t distance = CONFIG_T::aggregator_distance_biases[ia]; + + InFeatures1: + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + + typename CONFIG_T::distance_t incr = data_getter.get(iv, ix) * CONFIG_T::aggregator_distance_weights[iax]; + + distance += incr; + } + + typename CONFIG_T::edge_weight_t edge_weight = + garnet_utils::compute_edge_weight(distance); + + arrays_local.edge_weight_mean[ia] += edge_weight; + + InFeatures2: + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + + typename data_getter_T::data_t incr = data_getter.get(iv, ix) * edge_weight; + + arrays_local.weighted_feature_mean[iax] += incr; + } + + unsigned const iva = iv * CONFIG_T::n_aggregators + ia; + arrays.set_weight(iva, edge_weight); + } +} + +template +inline typename CONFIG_T::aggr_t compute_output_base_core(arrays_T const &arrays, unsigned io, unsigned ia) { + //#pragma HLS INLINE + + unsigned const ioa = io * CONFIG_T::n_aggregators + ia; + typename CONFIG_T::aggr_t aggr = arrays.edge_weight_mean[ia] * CONFIG_T::input_transform_biases[ioa]; + +InFeatures: + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const ioax = ioa * CONFIG_T::n_in_features + ix; + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + + aggr += arrays.weighted_feature_mean[iax] * CONFIG_T::input_transform_weights[ioax]; + } + + return aggr; +} + +template +inline void compute_output_base(arrays_T const &arrays, + typename CONFIG_T::aggr_t output_base[CONFIG_T::n_out_features * CONFIG_T::n_aggregators]) { + //#pragma HLS INLINE + +OutFeatures: + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + Aggregators: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + unsigned const ioa = io * CONFIG_T::n_aggregators + ia; + + output_base[ioa] = compute_output_base_core(arrays, io, ia); + } + } +} + +template +inline void +compute_vertex_output(arrays_T const &arrays, unsigned iv, + typename CONFIG_T::aggr_t const output_base[CONFIG_T::n_out_features * CONFIG_T::n_aggregators], + res_setter_T &res_setter) { + //#pragma HLS INLINE + + typename arrays_T::edge_weight_t edge_weights[CONFIG_T::n_aggregators]; + //#pragma HLS ARRAY_PARTITION variable=edge_weights complete + +Aggregators1: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + unsigned const iva = iv * CONFIG_T::n_aggregators + ia; + + edge_weights[ia] = arrays.edge_weights[iva]; + } + +OutFeatures: + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + typename res_setter_T::res_t acc = CONFIG_T::output_transform_biases[io]; + + Aggregators2: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + unsigned const ioa = io * CONFIG_T::n_aggregators + ia; + + typename res_setter_T::res_t incr = edge_weights[ia] * output_base[ioa]; + acc += incr; + } + + res_setter.set(iv, io, acc); + } +} + +template +void aggregate(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx, arrays_T &arrays) { + InputDataGetter data_getter(data); + + unsigned const unroll_factor = CONFIG_T::n_vertices >> CONFIG_T::log2_reuse_factor; + + Means means_accum; + +VerticesOuter: + for (unsigned ivv = 0; ivv < CONFIG_T::reuse_factor; ++ivv) { + //#pragma HLS PIPELINE + + if (ivv * unroll_factor >= nvtx) + break; + + Means means_local; + + VerticesInner: + for (unsigned ir = 0; ir < unroll_factor; ++ir) { + unsigned iv = ivv * unroll_factor + ir; + + if (iv == nvtx) + break; + + compute_weights_aggregates(data_getter, iv, means_local, arrays); + } + + means_accum.add_means_normalized(means_local); + } + + arrays.set_means_normalized(nvtx, means_accum); +} + +template +void distribute(nvtx_T const nvtx, arrays_T const &arrays, res_T res[CONFIG_T::n_vertices * CONFIG_T::n_out_features]) { + OutputResSetter res_setter(res); + + typename CONFIG_T::aggr_t output_base[CONFIG_T::n_out_features * CONFIG_T::n_aggregators]; + //#pragma HLS ARRAY_PARTITION variable=output_base complete + + compute_output_base(arrays, output_base); + + unsigned const unroll_factor = CONFIG_T::n_vertices >> CONFIG_T::log2_reuse_factor; + +VerticesOuter: + for (unsigned ivv = 0; ivv < CONFIG_T::reuse_factor; ++ivv) { + //#pragma HLS PIPELINE + + if (ivv * unroll_factor >= nvtx) + break; + + VerticesInner: + for (unsigned ir = 0; ir < unroll_factor; ++ir) { + unsigned iv = ivv * unroll_factor + ir; + + if (iv == nvtx) + break; + + compute_vertex_output(arrays, iv, output_base, res_setter); + } + } +} + +template +void set_output(output_biases_T const &output_transform_biases, arrays_T const &arrays, + res_T res[CONFIG_T::n_out_features]) { + //#pragma HLS PIPELINE + +OutFeatures: + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + res_T acc = output_transform_biases.output_biases[io]; + + Aggregators: + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + typename CONFIG_T::aggr_t aggr = compute_output_base_core(arrays, io, ia); + + acc += arrays.edge_weight_mean[ia] * aggr; + } + + res[io] = acc; + } +} + +template +void distribute_aggregate(nvtx_T const nvtx, prev_arrays_T const &prev_arrays, current_arrays_T ¤t_arrays) { + typedef typename prev_layer_t::output_t data_T; + + typename prev_layer_t::aggr_t prev_output_base[prev_layer_t::n_out_features * prev_layer_t::n_aggregators]; + //#pragma HLS ARRAY_PARTITION variable=prev_output_base complete + + compute_output_base(prev_arrays, prev_output_base); + + unsigned const unroll_factor = current_layer_t::n_vertices >> current_layer_t::log2_reuse_factor; + + Means means_accum; + +VerticesOuter: + for (unsigned ivv = 0; ivv < current_layer_t::reuse_factor; ++ivv) { + //#pragma HLS PIPELINE + + if (ivv * unroll_factor >= nvtx) + break; + + Means means_local; + + VerticesInner: + for (unsigned ir = 0; ir < unroll_factor; ++ir) { + unsigned iv = ivv * unroll_factor + ir; + + if (iv == nvtx) + break; + + data_T data[prev_layer_t::n_out_features]; + //#pragma HLS ARRAY_PARTITION variable=data complete + + SingleVertexResSetter res_setter(data); + + compute_vertex_output(prev_arrays, iv, prev_output_base, res_setter); + + SingleVertexDataGetter data_getter(data); + + compute_weights_aggregates(data_getter, iv, means_local, current_arrays); + } + + means_accum.add_means_normalized(means_local); + } + + current_arrays.set_means_normalized(nvtx, means_accum); +} + +template +inline typename std::enable_if::value>::type +sublayer(nvtx_T const nvtx, prev_arrays_T const &prev_arrays, last_arrays_T &last_arrays) { + //#pragma HLS INLINE + + distribute_aggregate(nvtx, prev_arrays, last_arrays); +} + +template +inline typename std::enable_if::value>::type +sublayer(nvtx_T const nvtx, prev_arrays_T const &prev_arrays, last_arrays_T &last_arrays) { + //#pragma HLS INLINE + + WeightsAndMeans current_arrays; + + distribute_aggregate(nvtx, prev_arrays, current_arrays); + + sublayer(nvtx, current_arrays, last_arrays); +} +} // namespace garnet_utils + +struct garnet_config { + // Layer specs + static const unsigned n_vertices_width = 8; + static const unsigned n_vertices = (1 << n_vertices_width); + static const unsigned n_in_features = 4; + static const unsigned n_propagate = 4; + static const unsigned n_aggregators = 4; + static const unsigned n_out_features = 4; + static const unsigned distance_width = 12; + + // Internal data type definitions + typedef float input_transform_weights_t; + typedef float input_transform_biases_t; + typedef float output_transform_weights_t; + typedef float output_transform_biases_t; + typedef float aggregator_distance_weights_t; + typedef float aggregator_distance_biases_t; + + typedef float norm_t; + typedef float distance_t; + typedef float edge_weight_t; + typedef float edge_weight_aggr_t; + typedef float aggr_t; + typedef float output_t; + + /* static const input_transform_weights_t (&input_transform_weights)[n_out_features * n_aggregators * n_in_features]; */ + /* static const input_transform_biases_t (&input_transform_biases)[n_out_features * n_aggregators]; */ + /* static const aggregator_distance_weights_t (&aggregator_distance_weights)[n_aggregators * n_in_features]; */ + /* static const aggregator_distance_biases_t (&aggregator_distance_biases)[n_aggregators]; */ + /* static const output_transform_biases_t (&output_transform_biases)[n_out_features]; */ + + enum OutputCollapse { no_collapse, collapse_mean, collapse_max }; + + static const unsigned output_collapse = no_collapse; + + static const bool mean_by_nvert = false; + static const bool is_stack = false; + + // Optimization specs + static const unsigned reuse_factor = 64; + static const unsigned log2_reuse_factor = 6; +}; + +// vertices -> vertices +template +typename std::enable_if::type +garnet(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_vertices * CONFIG_T::n_out_features]) { + //#pragma HLS DATAFLOW + + garnet_utils::WeightsAndMeans arrays; + + garnet_utils::aggregate(data, nvtx[0], arrays); + + garnet_utils::distribute(nvtx[0], arrays, res); +} + +// vertices -> out features +template +typename std::enable_if::type +garnet(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_out_features]) { + //#pragma HLS DATAFLOW + + garnet_utils::Means arrays; + + garnet_utils::aggregate(data, nvtx[0], arrays); + + garnet_utils::OutputBiasNormalizer normalize_bias(nvtx[0]); + + garnet_utils::set_output(normalize_bias, arrays, res); +} + +// vertices -> vertices +template +typename std::enable_if::type +garnet_stack(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_vertices * CONFIG_T::n_out_features]) { + //#pragma HLS DATAFLOW + + typedef typename CONFIG_T::template sublayer_t<0> first_layer_t; + unsigned const ilast = CONFIG_T::n_sublayers - 1; + typedef typename CONFIG_T::template sublayer_t last_layer_t; + + garnet_utils::WeightsAndMeans arrays_first; + garnet_utils::Means arrays_last; + + garnet_utils::aggregate(data, nvtx[0], arrays_first); + + garnet_utils::sublayer(nvtx[0], arrays_first, + arrays_last); + + garnet_utils::distribute(nvtx[0], arrays_last, res); +} + +// vertices -> out features +template +typename std::enable_if::type +garnet_stack(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_out_features]) { + //#pragma HLS DATAFLOW + + typedef typename CONFIG_T::template sublayer_t<0> first_layer_t; + unsigned const ilast = CONFIG_T::n_sublayers - 1; + typedef typename CONFIG_T::template sublayer_t last_layer_t; + + garnet_utils::WeightsAndMeans arrays_first; + garnet_utils::Means arrays_last; + + garnet_utils::aggregate(data, nvtx[0], arrays_first); + + garnet_utils::sublayer(nvtx[0], arrays_first, + arrays_last); + + garnet_utils::OutputBiasNormalizer normalize_bias(nvtx[0]); + + garnet_utils::set_output(normalize_bias, arrays_last, res); +} + +/* Reference (dumb) implementation returning (Vertices, Features) */ +template +typename std::enable_if::type +garnet_ref(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_vertices * CONFIG_T::n_out_features]) { + typename CONFIG_T::edge_weight_t edge_weights[CONFIG_T::n_vertices * CONFIG_T::n_aggregators]; + typename CONFIG_T::aggr_t propagated_features[CONFIG_T::n_vertices * CONFIG_T::n_propagate]; + + for (unsigned iv = 0; iv < CONFIG_T::n_vertices; ++iv) { + if (iv == nvtx[0]) + break; + + for (unsigned ip = 0; ip < CONFIG_T::n_propagate; ++ip) { + unsigned const ivp = iv * CONFIG_T::n_propagate + ip; + + propagated_features[ivp] = CONFIG_T::input_transform_biases[ip]; + + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const ivx = iv * CONFIG_T::n_in_features + ix; + unsigned const ipx = ip * CONFIG_T::n_in_features + ix; + + propagated_features[ivp] += data[ivx] * CONFIG_T::input_transform_weights[ipx]; + } + } + + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + unsigned const iva = iv * CONFIG_T::n_aggregators + ia; + + typename CONFIG_T::aggr_t distance = CONFIG_T::aggregator_distance_biases[ia]; + + for (unsigned ix = 0; ix < CONFIG_T::n_in_features; ++ix) { + unsigned const ivx = iv * CONFIG_T::n_in_features + ix; + unsigned const iax = ia * CONFIG_T::n_in_features + ix; + + distance += data[ivx] * CONFIG_T::aggregator_distance_weights[iax]; + } + + edge_weights[iva] = garnet_utils::compute_edge_weight(distance); + } + } + + typename CONFIG_T::aggr_t aggregated_features[CONFIG_T::n_aggregators * CONFIG_T::n_propagate]; + + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + for (unsigned ip = 0; ip < CONFIG_T::n_propagate; ++ip) { + unsigned const iap = ia * CONFIG_T::n_propagate + ip; + + aggregated_features[iap] = 0.; + + for (unsigned iv = 0; iv < CONFIG_T::n_vertices; ++iv) { + if (iv == nvtx[0]) + break; + + unsigned const iva = iv * CONFIG_T::n_aggregators + ia; + unsigned const ivp = iv * CONFIG_T::n_propagate + ip; + + aggregated_features[iap] += edge_weights[iva] * propagated_features[ivp]; + } + } + } + + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + for (unsigned ip = 0; ip < CONFIG_T::n_propagate; ++ip) { + unsigned const iap = ia * CONFIG_T::n_propagate + ip; + + if (CONFIG_T::mean_by_nvert) + aggregated_features[iap] /= nvtx[0]; + else { + // Not using right shift in case aggr_t is float or double + aggregated_features[iap] /= CONFIG_T::n_vertices; + } + } + } + + for (unsigned iv = 0; iv < CONFIG_T::n_vertices; ++iv) { + if (iv == nvtx[0]) + break; + + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + unsigned const ivo = iv * CONFIG_T::n_out_features + io; + + typename CONFIG_T::aggr_t acc = CONFIG_T::output_transform_biases[io]; + + for (unsigned ia = 0; ia < CONFIG_T::n_aggregators; ++ia) { + unsigned const iva = iv * CONFIG_T::n_aggregators + ia; + unsigned const ioa = io * CONFIG_T::n_aggregators + ia; + + typename CONFIG_T::aggr_t aggr = 0.; + + for (unsigned ip = 0; ip < CONFIG_T::n_propagate; ++ip) { + unsigned const iap = ia * CONFIG_T::n_propagate + ip; + unsigned const ioap = ioa * CONFIG_T::n_propagate + ip; + + aggr += CONFIG_T::output_transform_weights[ioap] * aggregated_features[iap]; + } + + acc += edge_weights[iva] * aggr; + } + + res[ivo] = acc; + } + } +} + +/* Reference (dumb) implementation returning (Features) - output averaged over vertices already */ +template +typename std::enable_if::type +garnet_ref(data_T const data[CONFIG_T::n_vertices * CONFIG_T::n_in_features], nvtx_T const nvtx[1], + res_T res[CONFIG_T::n_out_features]) { + typename CONFIG_T::aggr_t vertex_res[CONFIG_T::n_vertices * CONFIG_T::n_out_features]; + + garnet_ref(data, nvtx, vertex_res); + + for (unsigned io = 0; io < CONFIG_T::n_out_features; ++io) { + typename CONFIG_T::aggr_t acc = 0.; + + for (unsigned iv = 0; iv < CONFIG_T::n_vertices; ++iv) { + if (iv == nvtx[0]) + break; + + unsigned const ivo = iv * CONFIG_T::n_out_features + io; + + acc += vertex_res[ivo]; + } + + if (CONFIG_T::mean_by_nvert) + acc /= nvtx[0]; + else { + // Not using right shift in case aggr_t is float or double + acc /= CONFIG_T::n_vertices; + } + + res[io] = acc; + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_helpers.h b/hls4ml/templates/catapult/nnet_utils/nnet_helpers.h new file mode 100644 index 0000000000..ed701e5c59 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_helpers.h @@ -0,0 +1,461 @@ + +#ifndef NNET_HELPERS_H +#define NNET_HELPERS_H + +#include "ac_channel.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern const char *get_weights_dir(); + +namespace nnet { + +#ifndef __SYNTHESIS__ + +#ifndef WEIGHTS_DIR +#define WEIGHTS_DIR get_weights_dir() +#endif + +template void load_weights_from_txt(T *w, const char *fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + + size_t i = 0; + while (std::getline(iss, token, ',')) { + // CATAPULT_PORT + // std::istringstream(token) >> w[i]; + double tmp; + std::istringstream(token) >> tmp; + w[i] = tmp; + i++; + } + + if (SIZE != i) { + std::cerr << "ERROR: Expected " << SIZE << " values"; + std::cerr << " but read only " << i << " values" << std::endl; + } + } +} + +template void load_compressed_weights_from_txt(T *w, const char *fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + std::string extra_chars = "} "; + + size_t i = 0; + while (std::getline(iss, token, '{')) { + if (token.length() == 0) { + continue; + } + for (char c : extra_chars) { + token.erase(std::remove(token.begin(), token.end(), c), token.end()); + } + if (token.back() == ',') { + token.erase(token.end() - 1); + } + + std::replace(token.begin(), token.end(), ',', ' '); + std::istringstream structss(token); + + if (!(structss >> w[i].row_index >> w[i].col_index >> w[i].weight)) { + std::cerr << "ERROR: Unable to parse file " << std::string(fname); + exit(1); + } + i++; + } + + if (SIZE != i) { + std::cerr << "ERROR: Expected " << SIZE << " values"; + std::cerr << " but read only " << i << " values" << std::endl; + } + } +} + +template void load_exponent_weights_from_txt(T *w, const char *fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + std::string extra_chars = "} "; + + size_t i = 0; + while (std::getline(iss, token, '{')) { + if (token.length() == 0) { + continue; + } + for (char c : extra_chars) { + token.erase(std::remove(token.begin(), token.end(), c), token.end()); + } + if (token.back() == ',') { + token.erase(token.end() - 1); + } + + std::replace(token.begin(), token.end(), ',', ' '); + std::istringstream structss(token); + + double sign; + double weight; + if (!(structss >> sign >> weight)) { + std::cerr << "ERROR: Unable to parse file " << std::string(fname); + exit(1); + } + w[i].sign = sign; + w[i].weight = weight; + i++; + } + + if (SIZE != i) { + std::cerr << "ERROR: Expected " << SIZE << " values"; + std::cerr << " but read only " << i << " values" << std::endl; + } + } +} + +template +void convert_single_data(ac_fixed &src, double &dst) { + dst = src.to_double(); +} +template +void convert_single_data(ac_fixed &src, float &dst) { + dst = src.to_double(); +} +template void convert_single_data(srcType &src, dstType &dst) { dst = dstType(src); } +template void convert_data(srcType *src, dstType *dst) { + for (size_t i = 0; i < SIZE; i++) { + convert_single_data(src[i], dst[i]); + } +} + +template void convert_data(srcType *src, ac_channel &dst) { + for (size_t i = 0; i < SIZE / dstType::size; i++) { + dstType ctype; + for (size_t j = 0; j < dstType::size; j++) { + ctype[j] = typename dstType::value_type(src[i * dstType::size + j]); + } + dst.write(ctype); + } +} + +template void convert_data(ac_channel &src, dstType *dst) { + for (size_t i = 0; i < SIZE / srcType::size; i++) { + srcType ctype = src.read(); + for (size_t j = 0; j < srcType::size; j++) { + dst[i * srcType::size + j] = dstType(ctype[j].to_double()); // this may only work for ac_fixed + } + } +} + +extern bool trace_enabled; +extern std::map *trace_outputs; +extern size_t trace_type_size; + +template void save_output_array(data_T *data, save_T *ptr, size_t layer_size) { + for (int i = 0; i < layer_size; i++) { + ptr[i] = static_cast(data[i].to_double()); + } +} + +template void save_output_array(ac_channel &data, save_T *ptr, size_t layer_size) { + for (size_t i = 0; i < layer_size / data_T::size; i++) { + data_T ctype = data.read(); + for (size_t j = 0; j < data_T::size; j++) { + ptr[i * data_T::size + j] = save_T(ctype[j]); + } + data.write(ctype); + } +} + +template void save_output_array(ac_channel &data, float *ptr, size_t layer_size) { + for (size_t i = 0; i < layer_size / data_T::size; i++) { + data_T ctype = data.read(); + for (size_t j = 0; j < data_T::size; j++) { + ptr[i * data_T::size + j] = ctype[j].to_double(); + } + data.write(ctype); + } +} + +template void save_output_array(ac_channel &data, double *ptr, size_t layer_size) { + for (size_t i = 0; i < layer_size / data_T::size; i++) { + data_T ctype = data.read(); + for (size_t j = 0; j < data_T::size; j++) { + ptr[i * data_T::size + j] = ctype[j].to_double(); + } + data.write(ctype); + } +} + +// We don't want to include save_T in this function because it will be inserted into myproject.cpp +// so a workaround with element size is used +template void save_layer_output(data_T *data, const char *layer_name, size_t layer_size) { + if (!trace_enabled) + return; + + if (trace_outputs) { + if (trace_outputs->count(layer_name) > 0) { + if (trace_type_size == 4) { + save_output_array(data, (float *)(*trace_outputs)[layer_name], layer_size); + } else if (trace_type_size == 8) { + save_output_array(data, (double *)(*trace_outputs)[layer_name], layer_size); + } else { + std::cout << "Unknown trace type!" << std::endl; + } + } else { + std::cout << "Layer name: " << layer_name << " not found in debug storage!" << std::endl; + } + } else { + std::ostringstream filename; + filename << "./tb_data/" << layer_name << "_output.log"; // TODO if run as a shared lib, path should be ../tb_data + std::fstream out; + out.open(filename.str(), std::ios::app); + assert(out.is_open()); + for (int i = 0; i < layer_size; i++) { + out << data[i] << " "; // We don't care about precision in text files + } + out << std::endl; + out.close(); + } +} + +template void save_layer_output(ac_channel &data, const char *layer_name, size_t layer_size) { + if (!trace_enabled) + return; + + if (trace_outputs) { + if (trace_outputs->count(layer_name) > 0) { + if (trace_type_size == 4) { + save_output_array(data, (float *)(*trace_outputs)[layer_name], layer_size); + } else if (trace_type_size == 8) { + save_output_array(data, (double *)(*trace_outputs)[layer_name], layer_size); + } else { + std::cout << "Unknown trace type!" << std::endl; + } + } else { + std::cout << "Layer name: " << layer_name << " not found in debug storage!" << std::endl; + } + } else { + std::ostringstream filename; + filename << "./tb_data/" << layer_name << "_output.log"; // TODO if run as a shared lib, path should be ../tb_data + std::fstream out; + out.open(filename.str(), std::ios::app); + assert(out.is_open()); + for (size_t i = 0; i < layer_size / data_T::size; i++) { + data_T ctype = data.read(); + for (size_t j = 0; j < data_T::size; j++) { + out << ctype[j].to_double(); + out << " "; // We don't care about precision in text files + } + data.write(ctype); + } + out << std::endl; + out.close(); + } +} + +#endif + +template void copy_data(std::vector src, dst_T dst[SIZE]) { + typename std::vector::const_iterator in_begin = src.cbegin() + OFFSET; + typename std::vector::const_iterator in_end = in_begin + SIZE; + std::copy(in_begin, in_end, dst); +} + +template +void copy_data(std::vector src, ac_channel &dst) { + typename std::vector::const_iterator in_begin = src.cbegin() + OFFSET; + typename std::vector::const_iterator in_end = in_begin + SIZE; + + size_t i_pack = 0; + dst_T dst_pack; + for (typename std::vector::const_iterator i = in_begin; i != in_end; ++i) { + dst_pack[i_pack++] = typename dst_T::value_type(*i); + if (i_pack == dst_T::size) { + i_pack = 0; + dst.write(dst_pack); + } + } +} + +template void copy_data_axi(std::vector src, dst_T dst[SIZE]) { + for (auto i = 0; i < SIZE; i++) + if (i == SIZE - 1) { + dst[i].data = src[i]; + dst[i].last = 1; + } else { + dst[i].data = src[i]; + dst[i].last = 0; + } +} + +template void print_result(res_T result[SIZE], std::ostream &out, bool keep = false) { + for (unsigned i = 0; i < SIZE; i++) { + out << result[i] << " "; + } + out << std::endl; +} + +template void print_result(ac_channel &result, std::ostream &out, bool keep = false) { + if (!keep) { + while (result.available(1)) { + res_T res_pack = result.read(); + for (unsigned int j = 0; j < res_T::size; j++) { + out << res_pack[j] << " "; + } + } + out << std::endl; + } else { + if (result.debug_size() >= SIZE / res_T::size) { + for (unsigned int i = 0; i < SIZE / res_T::size; i++) { + res_T res_pack = result[i]; // peek + for (unsigned int j = 0; j < res_T::size; j++) { + out << res_pack[j] << " "; + } + } + out << std::endl; + } + } +} + +template void fill_zero(data_T data[SIZE]) { std::fill_n(data, SIZE, 0.); } + +template void fill_zero(ac_channel &data) { + for (unsigned int i = 0; i < SIZE / data_T::size; i++) { + data_T data_pack; + for (unsigned int j = 0; j < data_T::size; j++) { + data_pack[j] = 0.; + } + data.write(data_pack); + } +} + +// Fix for CAT-36531 +template void fill_random(data_T data[SIZE]) { + // std::cout << "Fill_Random SIZE:"<< SIZE << std::endl; + data_T MAX_VALUE; + for (unsigned int i = 0; i < SIZE; i++) { + // Generate a random value (for example, between 0 and 1) + data_T random_value = (data_T)rand() / MAX_VALUE.template set_val(); + data[i] = random_value; + } +} + +template void fill_random(ac_channel &data) { + typedef typename data_T::value_type base_T; + base_T MAX_VALUE; + for (unsigned int i = 0; i < SIZE / data_T::size; i++) { + data_T data_pack; + for (unsigned int j = 0; j < data_T::size; j++) { + // Generate a random value (for example, between 0 and 1) + base_T random_value = (base_T)rand() / MAX_VALUE.template set_val(); + data_pack[j] = random_value; + } + data.write(data_pack); + } + // std::cout << "Fill_Random AC_CHANNEL" << std::endl; +} + +template int read_file_1D(const char *filename, dataType data[nrows]) { + FILE *fp; + fp = fopen(filename, "r"); + if (fp == 0) { + return -1; + } + // Read data from file + float newval; + for (int ii = 0; ii < nrows; ii++) { + if (fscanf(fp, "%f\n", &newval) != 0) { + data[ii] = newval; + } else { + return -2; + } + } + fclose(fp); + return 0; +} + +template +int read_file_2D(const char *filename, dataType data[nrows][ncols]) { + FILE *fp; + fp = fopen(filename, "r"); + if (fp == 0) { + return -1; + } + // Read data from file + float newval; + for (int ii = 0; ii < nrows; ii++) { + for (int jj = 0; jj < ncols; jj++) { + if (fscanf(fp, "%f\n", &newval) != 0) { + data[ii][jj] = newval; + } else { + return -2; + } + } + } + fclose(fp); + return 0; +} + +template void change_type(ac_channel &in, ac_channel &out) { + in_T datareg; + ac_channel input_trunc; + for (int ii = 0; ii < N_IN; ii++) { + out << (out_T)in.read(); + } +} + +template void hls_stream_debug(ac_channel &data, ac_channel &res) { + data_T datareg; + for (int ii = 0; ii < N_IN; ii++) { + datareg = data.read(); + std::cout << "[" << ii << "]: " << datareg << std::endl; + res << datareg; + } +} + +constexpr int ceillog2(int x) { return (x <= 2) ? 1 : 1 + ceillog2((x + 1) / 2); } + +constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); } + +constexpr int pow2(int x) { return x == 0 ? 1 : 2 * pow2(x - 1); } + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_image.h b/hls4ml/templates/catapult/nnet_utils/nnet_image.h new file mode 100755 index 0000000000..26947fae01 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_image.h @@ -0,0 +1,41 @@ +#ifndef NNET_IMAGE_H_ +#define NNET_IMAGE_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include + +namespace nnet { + +struct resize_config { + static const unsigned height = 10; + static const unsigned width = 10; + static const unsigned n_chan = 10; + static const unsigned new_height = 10; + static const unsigned new_width = 10; +}; + +template +void resize_nearest(data_T image[CONFIG_T::height * CONFIG_T::width * CONFIG_T::n_chan], + data_T resized[CONFIG_T::new_height * CONFIG_T::new_width * CONFIG_T::n_chan]) { + int y_ratio = (int)((CONFIG_T::height << 16) / CONFIG_T::new_height) + 1; + int x_ratio = (int)((CONFIG_T::width << 16) / CONFIG_T::new_width) + 1; + int x2, y2; + + //#pragma HLS PIPELINE + + for (int i = 0; i < CONFIG_T::new_height; i++) { + for (int j = 0; j < CONFIG_T::new_width; j++) { + x2 = ((j * x_ratio) >> 16); + y2 = ((i * y_ratio) >> 16); + for (int k = 0; k < CONFIG_T::n_chan; k++) { + resized[(i * CONFIG_T::new_width * CONFIG_T::n_chan) + j * CONFIG_T::n_chan + k] = + image[(y2 * CONFIG_T::width * CONFIG_T::n_chan) + x2 * CONFIG_T::n_chan + k]; + } + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_image_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_image_stream.h new file mode 100644 index 0000000000..1757f7bfb8 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_image_stream.h @@ -0,0 +1,66 @@ +#ifndef NNET_IMAGE_STREAM_H_ +#define NNET_IMAGE_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" + +namespace nnet { + +template void resize_nearest(ac_channel &image, ac_channel &resized) { + assert(CONFIG_T::new_height % CONFIG_T::height == 0); + assert(CONFIG_T::new_width % CONFIG_T::width == 0); + constexpr unsigned ratio_height = CONFIG_T::new_height / CONFIG_T::height; + constexpr unsigned ratio_width = CONFIG_T::new_width / CONFIG_T::width; + +ImageHeight: + for (unsigned h = 0; h < CONFIG_T::height; h++) { + //#pragma HLS PIPELINE + + data_T data_in_row[CONFIG_T::width]; + + ImageWidth: + for (unsigned i = 0; i < CONFIG_T::width; i++) { + //#pragma HLS UNROLL + + data_T in_data = image.read(); + + ImageChan: + for (unsigned j = 0; j < CONFIG_T::n_chan; j++) { + //#pragma HLS UNROLL + + data_in_row[i][j] = in_data[j]; + } + } + + ResizeHeight: + for (unsigned i = 0; i < ratio_height; i++) { + //#pragma HLS UNROLL + + ImageWidth2: + for (unsigned l = 0; l < CONFIG_T::width; l++) { + //#pragma HLS UNROLL + + ResizeWidth: + for (unsigned j = 0; j < ratio_width; j++) { + //#pragma HLS UNROLL + + data_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ResizeChan: + for (unsigned k = 0; k < CONFIG_T::n_chan; k++) { + //#pragma HLS UNROLL + + out_data[k] = data_in_row[l][k]; + } + + resized.write(out_data); + } + } + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_math.h b/hls4ml/templates/catapult/nnet_utils/nnet_math.h new file mode 100644 index 0000000000..c25f7187b6 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_math.h @@ -0,0 +1,178 @@ +#ifndef NNET_MATH_H_ +#define NNET_MATH_H_ + +#include "hls_math.h" + +namespace nnet { + +// This header defines the functions that return type different from the input +// For example, hls::sin(x) returns ac_fixed +// By ensuring we return the same type we can avoid casting issues in expressions + +template T sin(T x) { return (T)hls::sin(x); }; + +template T cos(T x) { return (T)hls::cos(x); }; + +template T asin(T x) { return (T)hls::asin(x); }; + +template T acos(T x) { return (T)hls::acos(x); }; + +template T atan(T x) { return (T)hls::atan(x); }; + +template T atan2(T x, T y) { return (T)hls::atan2(x, y); }; + +template void init_sincos_table(T table[1 << (W - I - 3)][2]) { + unsigned int NTE = 1 << (W - I - 3); // No of table entries + double step = M_PI / (4 * NTE); // Interval between angles + double y = 0; + // double scaled_angle = 0; + + for (unsigned int i = 0; i < NTE; i++) { + table[i][0] = std::cos(y); + table[i][1] = std::sin(y); + y += step; + // scaled_angle = y/(2*M_PI); + // printf("cos(%f) = %23.22f, sin(%f) = %23.22f index = %d, scaled angle = %13.12f \n", y, cos(y), y, sin(y), i, + // scaled_angle); + } +} + +template void sincos_lut(const T &input, T output[2]) { + + #pragma HLS INLINE + + // This implementation is based on ac_sincos_lut.h from AC math library + + static bool flag = true; + if (flag && T::width - T::iwidth > 12) { +#if !defined(__SYNTHESIS__) && defined(SINCOS_LUT_DEBUG) + std::cout << "FILE : " << __FILE__ << ", LINE : " << __LINE__ << std::endl; + std::cout << "Warning: The output of sincos_lut will not be accurate" << std::endl; +#endif + flag = false; + } + // Datatype for lookup table entries + typedef ac_fixed luttype; + // Datatype for posinput which is used to handle negative inputs + typedef ac_fixed posinputtype; + + typedef ac_int<9, false> lutindextype; // 9 bits required for indexing into 512 entry table + typedef ac_int<3, false> octanttype; // 3 bits required for octant value range of 0 thru 7 + T outputtemp[2]; + lutindextype luTdex = 0; + posinputtype posinput = input; + + // Initialize the lookup table +#ifdef __SYNTHESIS__ + bool initialized = false; + luttype sincos[512][2]; +#else + static bool initialized = false; + static luttype sincos[512][2]; +#endif + if (!initialized) { + init_sincos_table(sincos); + initialized = true; + } + + // Leaving this commented out makes the table to to BRAM + //#pragma HLS ARRAY_PARTITION variable=sincos complete dim=0 + + typedef ac_int lutindextype1; + // Extracting (MSB-3:LSB) bits of scaled input to determine the lookup table index + lutindextype1 luTdex1 = posinput.range(AP_MAX(T::width - T::iwidth - 3, 1), 0); // Extracting the lookup table index + + if (T::width - T::iwidth >= 4 && T::width - T::iwidth <= 12) { + luTdex(8, 12 - (T::width - T::iwidth)) = luTdex1; // stride + } + // Approximation for the scaled inputs whose number of bits are greater than 12 + else if (T::width - T::iwidth > 12) { + // Lookup table index for the scaled inputs whose number of bits are greater than 12 + luTdex = luTdex1 / (1 << (AP_MAX(T::width - T::iwidth - 12, 0))); + if ((luTdex1 % (1 << (AP_MAX(T::width - T::iwidth - 12, 0)))) > (1 << (AP_MAX(T::width - T::iwidth - 13, 0)))) { + luTdex = luTdex + 1; + } + typedef ac_fixed + datatype; + datatype x = (datatype)luTdex1; + x = x >> AP_MAX(T::width - T::iwidth - 12, 0); + if (x > 511.5) { + luTdex = 511; + } + if (luTdex1 <= 1 << (AP_MAX(T::width - T::iwidth - 13, 0)) && luTdex1 != 0) { + luTdex = 1; + } + } + + if (T::width - T::iwidth >= 3) { + // Getting the octant 0-7 by extracting the first 3 bits from MSB side of scaled input where + // octant 0 corresponds to [0-PI/4), + // octant 1 corresponds to [PI/4-2PI/4), + // octant 2 corresponds to [2PI/4-3PI/4) and so on + // octanttype octant = posinput.template slc<3>(T::width-T::iwidth-3); + octanttype octant = posinput(T::width - T::iwidth - 1, T::width - T::iwidth - 3); + luTdex = (octant[0] == 1) ? (lutindextype)(512 - luTdex) : (lutindextype)(luTdex); + // imaginary part is sine + outputtemp[1] = ((octant == 0) | (octant == 3)) ? (T)sincos[luTdex][1] + : ((octant == 2) | (octant == 1)) ? (T)sincos[luTdex][0] + : ((octant == 7) | (octant == 4)) ? (T)-sincos[luTdex][1] + : (T)-sincos[luTdex][0]; + // real part is cosine + outputtemp[0] = ((octant == 6) | (octant == 1)) ? (T)sincos[luTdex][1] + : ((octant == 3) | (octant == 4)) ? (T)-sincos[luTdex][0] + : ((octant == 2) | (octant == 5)) ? (T)-sincos[luTdex][1] + : (T)sincos[luTdex][0]; + // Below two are the cases when the output corresponds to + or - (0 or 1) for which there is no entry in the lookup + // table + output[1] = ((posinput == 0.125) | (posinput == 0.375)) ? T(0.7071067811865475244008) + : ((posinput == 0.625) | (posinput == 0.875)) ? T(-0.7071067811865475244008) + : outputtemp[1]; + output[0] = ((posinput == 0.125) | (posinput == 0.875)) ? T(0.7071067811865475244008) + : ((posinput == 0.375) | (posinput == 0.625)) ? T(-0.7071067811865475244008) + : outputtemp[0]; + } + + if (T::width - T::iwidth <= 2) { + output[1] = (posinput == 0) ? (T)0 + : (posinput == 0.25) ? (T)1 + : (posinput == 0.5) ? (T)0 + : (posinput == 0.75) ? (T)-1 + : outputtemp[1]; + output[0] = (posinput == 0) ? (T)1 + : (posinput == 0.25) ? (T)0 + : (posinput == 0.5) ? (T)-1 + : (posinput == 0.75) ? (T)0 + : outputtemp[0]; + } + +#if !defined(__SYNTHESIS__) && defined(SINCOS_LUT_DEBUG) + std::cout << "FILE : " << __FILE__ << ", LINE : " << __LINE__ << std::endl; + std::cout << "============AP_FIXED SINCOS======================" << std::endl; + std::cout << "positive input is = " << posinput << std::endl; + std::cout << "lut index is = " << luTdex << std::endl; + std::cout << "sin value is = " << output[1] << std::endl; + std::cout << "cos value is = " << output[0] << std::endl; + std::cout << "=================================================" << std::endl; +#endif +} + +template T sin_lut(const T input) { + #pragma HLS INLINE + T sincos_res[2]; + T scaled_input = input * ac_fixed<16, 0, false>(0.15915494309); // 1/(2*pi) + sincos_lut(scaled_input, sincos_res); + return sincos_res[1]; +} + +template T cos_lut(const T input) { + #pragma HLS INLINE + T sincos_res[2]; + T scaled_input = input * ac_fixed<16, 0, false>(0.15915494309); // 1/(2*pi) + sincos_lut(scaled_input, sincos_res); + return sincos_res[0]; +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_merge.h b/hls4ml/templates/catapult/nnet_utils/nnet_merge.h new file mode 100644 index 0000000000..00c2cf5e12 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_merge.h @@ -0,0 +1,232 @@ + +#ifndef NNET_MERGE_H_ +#define NNET_MERGE_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_mult.h" +#include + +namespace nnet { + +struct merge_config { + static const unsigned n_elem = 10; +}; + +struct dot_config { + static const unsigned n_in = 10; + static const unsigned n_out = 1; + static const unsigned reuse_factor = 1; + typedef float accum_t; + // Product function to use + template using product = nnet::product::mult; +}; + +struct concat_config { + static const unsigned n_elem1_0 = 10; + static const unsigned n_elem1_1 = 10; + static const unsigned n_elem1_2 = 10; + static const unsigned n_elem2_0 = 10; + static const unsigned n_elem2_1 = 10; + static const unsigned n_elem2_2 = 10; + + static const int axis = -1; +}; + +template +void add(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = data1[ii] + data2[ii]; + } +} + +template +void subtract(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = data1[ii] - data2[ii]; + } +} + +template +void multiply(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = data1[ii] * data2[ii]; + } +} + +template +void average(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = (data1[ii] + data2[ii]) / (res_T)2; + } +} + +template +void maximum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = (data1[ii] > data2[ii]) ? data1[ii] : data2[ii]; + } +} + +template +void minimum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) { + for (int ii = 0; ii < CONFIG_T::n_elem; ii++) { + res[ii] = (data1[ii] < data2[ii]) ? data1[ii] : data2[ii]; + } +} + +template +void dot1d(input1_T data1[CONFIG_T::n_in], input2_T data2[CONFIG_T::n_in], res_T res[CONFIG_T::n_out]) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor); + CONFIG_T::template product::limit(multiplier_limit); + + typename CONFIG_T::accum_t mult[CONFIG_T::n_in]; + //#pragma HLS ARRAY_PARTITION variable=mult complete + typename CONFIG_T::accum_t acc = 0; + +Product: + for (int i_mult = 0; i_mult < CONFIG_T::n_in; i_mult++) { + // #pragma HLS UNROLL + mult[i_mult] = CONFIG_T::template product::product(data1[i_mult], data2[i_mult]); + } + +Accum: + for (int i_acc = 0; i_acc < CONFIG_T::n_in; i_acc++) { + // #pragma HLS UNROLL + acc += mult[i_acc]; + } + + res[0] = cast(acc); +} + +template +void concatenate1d(input1_T data1[CONFIG_T::n_elem1_0], input2_T data2[CONFIG_T::n_elem2_0], + res_T res[CONFIG_T::n_elem1_0 + CONFIG_T::n_elem2_0]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0; ii++) { + res[ii] = data1[ii]; + } + for (int ii = 0; ii < CONFIG_T::n_elem2_0; ii++) { + res[CONFIG_T::n_elem1_0 + ii] = data2[ii]; + } +} + +template +void concatenate2d_0(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1; ii++) { + res[ii] = data1[ii]; + } + for (int ii = 0; ii < CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1; ii++) { + res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + ii] = data2[ii]; + } +} + +template +void concatenate2d_1(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0; ii++) { + for (int jj = 0; jj < CONFIG_T::n_elem1_1; jj++) { + res[ii * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) + jj] = data1[ii * CONFIG_T::n_elem1_1 + jj]; + } + for (int jj = 0; jj < CONFIG_T::n_elem2_1; jj++) { + res[ii * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) + CONFIG_T::n_elem1_1 + jj] = + data2[ii * CONFIG_T::n_elem2_1 + jj]; + } + } +} + +template +void concatenate2d(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]) { + if (CONFIG_T::axis == 2 || CONFIG_T::axis == -1) { + concatenate2d_1(data1, data2, res); + } else { + concatenate2d_0(data1, data2, res); + } +} + +template +void concatenate3d_0(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2; ii++) { + res[ii] = data1[ii]; + } + for (int ii = 0; ii < CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2; ii++) { + res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + ii] = data2[ii]; + } +} + +template +void concatenate3d_1(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0; ii++) { + for (int jj = 0; jj < CONFIG_T::n_elem1_1; jj++) { + for (int kk = 0; kk < CONFIG_T::n_elem1_2; kk++) { + int res_idx = + ii * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) * CONFIG_T::n_elem1_2 + jj * CONFIG_T::n_elem1_2 + kk; + int data_idx = ii * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + jj * CONFIG_T::n_elem1_2 + kk; + res[res_idx] = data1[data_idx]; + } + } + for (int jj = 0; jj < CONFIG_T::n_elem2_1; jj++) { + for (int kk = 0; kk < CONFIG_T::n_elem2_2; kk++) { + int res_idx = ii * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) * CONFIG_T::n_elem1_2 + + (jj + CONFIG_T::n_elem1_1) * CONFIG_T::n_elem1_2 + kk; + int data_idx = ii * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2 + jj * CONFIG_T::n_elem2_2 + kk; + res[res_idx] = data2[data_idx]; + } + } + } +} + +template +void concatenate3d_2(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]) { + for (int ii = 0; ii < CONFIG_T::n_elem1_0; ii++) { + for (int jj = 0; jj < CONFIG_T::n_elem1_1; jj++) { + for (int kk = 0; kk < CONFIG_T::n_elem1_2; kk++) { + int res_idx = ii * CONFIG_T::n_elem1_1 * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + + jj * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + kk; + int data_idx = ii * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + jj * CONFIG_T::n_elem1_2 + kk; + res[res_idx] = data1[data_idx]; + } + for (int kk = 0; kk < CONFIG_T::n_elem1_2; kk++) { + int res_idx = ii * CONFIG_T::n_elem1_1 * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + + jj * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + kk + CONFIG_T::n_elem1_2; + int data_idx = ii * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2 + jj * CONFIG_T::n_elem2_2 + kk; + res[res_idx] = data2[data_idx]; + } + } + } +} + +template +void concatenate3d(input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2], + input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2], + res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]) { + if (CONFIG_T::axis == 3 || CONFIG_T::axis == -1) { + concatenate3d_2(data1, data2, res); + } else if (CONFIG_T::axis == 2 || CONFIG_T::axis == -2) { + concatenate3d_1(data1, data2, res); + } else { + concatenate3d_0(data1, data2, res); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_merge_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_merge_stream.h new file mode 100644 index 0000000000..ef0d542fc0 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_merge_stream.h @@ -0,0 +1,380 @@ + +#ifndef NNET_MERGE_STREAM_H_ +#define NNET_MERGE_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include + +namespace nnet { + +template +void add(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + +AddLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + AddPack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = in_data1[j] + in_data2[j]; + } + + res.write(out_data); + } +} + +template +void subtract(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + +SubtractLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + SubtractPack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = in_data1[j] - in_data2[j]; + } + + res.write(out_data); + } +} + +template +void multiply(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +MultiplyLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + MultiplyPack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = in_data1[j] * in_data2[j]; + } + + res.write(out_data); + } +} + +template +void average(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +AverageLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + AveragePack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = (in_data1[j] + in_data2[j]) / (typename res_T::value_type)2; + } + + res.write(out_data); + } +} + +template +void maximum(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +MaximumLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + MaximumPack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = (in_data1[j] > in_data2[j]) ? in_data1[j] : in_data2[j]; + } + + res.write(out_data); + } +} + +template +void minimum(ac_channel &data1, ac_channel &data2, ac_channel &res) { + assert(input1_T::size == input2_T::size && input1_T::size == res_T::size); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +MinimumLoop: + for (int i = 0; i < CONFIG_T::n_elem / input1_T::size; i++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + MinimumPack: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = (in_data1[j] < in_data2[j]) ? in_data1[j] : in_data2[j]; + } + + res.write(out_data); + } +} + +template +void concatenate3d_0(ac_channel &data1, ac_channel &data2, ac_channel &res) { +ConcatLoopHeight1: + for (int i = 0; i < CONFIG_T::n_elem1_0; i++) { + ConcatLoopWidth1: + for (int j = 0; j < CONFIG_T::n_elem1_1; j++) { + //#pragma HLS PIPELINE II=1 + + input1_T in_data1 = data1.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput1: + for (int k = 0; k < input1_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data1[k]; + } + + res.write(out_data); + } + } +ConcatLoopHeight2: + for (int i = 0; i < CONFIG_T::n_elem2_0; i++) { + ConcatLoopWidth2: + for (int j = 0; j < CONFIG_T::n_elem2_1; j++) { + //#pragma HLS PIPELINE II=1 + + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput2: + for (int k = 0; k < input2_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data2[k]; + } + + res.write(out_data); + } + } +} + +template +void concatenate3d_1(ac_channel &data1, ac_channel &data2, ac_channel &res) { +ConcatLoopHeight: + for (int i = 0; i < CONFIG_T::n_elem1_0; i++) { + ConcatLoopWidth1: + for (int j = 0; j < CONFIG_T::n_elem1_1; j++) { + //#pragma HLS PIPELINE II=1 + + input1_T in_data1 = data1.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput1: + for (int k = 0; k < input1_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data1[k]; + } + + res.write(out_data); + } + ConcatLoopWidth2: + for (int j = 0; j < CONFIG_T::n_elem2_1; j++) { + //#pragma HLS PIPELINE II=1 + + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput2: + for (int k = 0; k < input2_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data2[k]; + } + + res.write(out_data); + } + } +} + +template +void concatenate3d_2(ac_channel &data1, ac_channel &data2, ac_channel &res) { +ConcatLoopHeight: + for (int i = 0; i < CONFIG_T::n_elem1_0; i++) { + ConcatLoopWidth: + for (int j = 0; j < CONFIG_T::n_elem1_1; j++) { + //#pragma HLS PIPELINE II=1 + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput1: + for (int k = 0; k < input1_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data1[k]; + } + + ConcatPackInput2: + for (int k = 0; k < input2_T::size; k++) { + // #pragma HLS UNROLL + out_data[input1_T::size + k] = in_data2[k]; + } + + res.write(out_data); + } + } +} + +template +void concatenate3d(ac_channel &data1, ac_channel &data2, ac_channel &res) { + if (CONFIG_T::axis == 3 || CONFIG_T::axis == -1) { + concatenate3d_2(data1, data2, res); + } else if (CONFIG_T::axis == 2 || CONFIG_T::axis == -2) { + concatenate3d_1(data1, data2, res); + } else { + concatenate3d_0(data1, data2, res); + } +} + +template +void concatenate2d_0(ac_channel &data1, ac_channel &data2, ac_channel &res) { +ConcatLoopHeight1: + for (int i = 0; i < CONFIG_T::n_elem1_0; i++) { + // pragma HLS PIPELINE II=1 + + input1_T in_data1 = data1.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput1: + for (int k = 0; k < input1_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data1[k]; + } + + res.write(out_data); + } +ConcatLoopHeight2: + for (int i = 0; i < CONFIG_T::n_elem2_0; i++) { + //#pragma HLS PIPELINE II=1 + + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput2: + for (int k = 0; k < input2_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data2[k]; + } + + res.write(out_data); + } +} + +template +void concatenate2d_1(ac_channel &data1, ac_channel &data2, ac_channel &res) { +ConcatLoopHeight: + for (int i = 0; i < CONFIG_T::n_elem1_0; i++) { + //#pragma HLS PIPELINE II=1 + + input1_T in_data1 = data1.read(); + input2_T in_data2 = data2.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + ConcatPackInput1: + for (int k = 0; k < input1_T::size; k++) { + // #pragma HLS UNROLL + out_data[k] = in_data1[k]; + } + + ConcatPackInput2: + for (int k = 0; k < input2_T::size; k++) { + // #pragma HLS UNROLL + out_data[input1_T::size + k] = in_data2[k]; + } + + res.write(out_data); + } +} + +template +void concatenate2d(ac_channel &data1, ac_channel &data2, ac_channel &res) { + if (CONFIG_T::axis == 2 || CONFIG_T::axis == -1) { + concatenate2d_1(data1, data2, res); + } else { + concatenate2d_0(data1, data2, res); + } +} + +template +void concatenate1d(ac_channel &data1, ac_channel &data2, ac_channel &res) { + res_T out_data; +//#pragma HLS DATA_PACK variable=out_data +ConcatLoop1: + for (int i = 0; i < CONFIG_T::n_elem1_0 / input1_T::size; i++) { + //#pragma HLS PIPELINE + input1_T in_data1 = data1.read(); + ConcatPack1: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = in_data1[j]; + } + res.write(out_data); + } +ConcatLoop2: + for (int i = 0; i < CONFIG_T::n_elem2_0 / input2_T::size; i++) { + //#pragma HLS PIPELINE + input2_T in_data2 = data2.read(); + ConcatPack2: + for (int j = 0; j < res_T::size; j++) { + // #pragma HLS UNROLL + out_data[j] = in_data2[j]; + } + res.write(out_data); + } +} +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_mult.h b/hls4ml/templates/catapult/nnet_utils/nnet_mult.h new file mode 100755 index 0000000000..7379eec489 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_mult.h @@ -0,0 +1,127 @@ +#ifndef NNET_MULT_H_ +#define NNET_MULT_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_helpers.h" +#include +#include + +namespace nnet { + +namespace product { + +/* --- + * different methods to perform the product of input and weight, depending on the + * types of each. + * --- */ + +class Product { + public: + static void limit(unsigned multiplier_limit) {} // Nothing to do here +}; + +template class both_binary : public Product { + public: + static x_T product(x_T a, w_T w) { + // specialisation for 1-bit weights and incoming data + //#pragma HLS INLINE + return a == w; + } +}; + +template class weight_binary : public Product { + public: + static auto product(x_T a, w_T w) -> decltype(-a) { + // Specialisation for 1-bit weights, arbitrary data + //#pragma HLS INLINE + if (w == 0) + return -a; + else + return a; + } +}; + +template class data_binary : public Product { + public: + static auto product(x_T a, w_T w) -> decltype(-w) { + // Specialisation for 1-bit data, arbitrary weight + //#pragma HLS INLINE + if (a == 0) + return -w; + else + return w; + } +}; + +template class weight_ternary : public Product { + public: + static auto product(x_T a, w_T w) -> decltype(-a) { + // Specialisation for 2-bit weights, arbitrary data + //#pragma HLS INLINE + if (w == 0) + return 0; + else if (w == -1) + return -a; + else + return a; // if(w == 1) + } +}; + +template class mult : public Product { + public: + static auto product(x_T a, w_T w) -> decltype(a * w) { + // 'Normal' product + //#pragma HLS INLINE + return a * w; + } + static void limit(unsigned multiplier_limit) { + //#pragma HLS INLINE + //#pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation + } +}; + +template class weight_exponential : public Product { + public: + // Construct the return type from the multiplication equivalent to the largest shifts + // ap_int is the type if the multiplicand equivalent to the largest lshift << + // ap_fixed is the type of the multiplicand equivalent to the largest rshift >> + using r_T = decltype(x_T(0) * (ac_int(1) + + ac_fixed(1))); + static r_T product(x_T a, w_T w) { + // Shift product for exponential weights + //#pragma HLS INLINE + // shift by the exponent. Negative weights shift right + r_T y = static_cast(a) << w.weight; + // negate or not depending on weight sign + return w.sign == 1 ? y : static_cast(-y); + } +}; + +} // namespace product + +template +inline typename std::enable_if>::value && + std::is_same>::value, + ac_int>::type +cast(typename CONFIG_T::accum_t x) { + return (ac_int)(x - CONFIG_T::n_in / 2) * 2; +} + +template +inline typename std::enable_if>::value && + !std::is_same>::value, + res_T>::type +cast(typename CONFIG_T::accum_t x) { + return (res_T)x; +} + +template +inline typename std::enable_if<(!std::is_same>::value), res_T>::type +cast(typename CONFIG_T::accum_t x) { + return (res_T)x; +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_padding.h b/hls4ml/templates/catapult/nnet_utils/nnet_padding.h new file mode 100755 index 0000000000..47986523fb --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_padding.h @@ -0,0 +1,145 @@ +#ifndef NNET_PADDING_H_ +#define NNET_PADDING_H_ + +#include + +namespace nnet { + +struct padding1d_config { + static const unsigned n_chan = 10; + static const unsigned in_width = 10; + static const unsigned out_width = 10; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; +}; + +template +void zeropad1d_cf(data_T data[CONFIG_T::n_chan * CONFIG_T::in_width], data_T res[CONFIG_T::n_chan * CONFIG_T::out_width]) { + //#pragma HLS PIPELINE + + for (int j = 0; j < CONFIG_T::n_chan; j++) { + for (int i = 0; i < CONFIG_T::pad_left; i++) { + *(res++) = 0; + } + + for (int i = 0; i < CONFIG_T::in_width; i++) { + *(res++) = (res_T) * (data++); + } + + for (int i = 0; i < CONFIG_T::pad_right; i++) { + *(res++) = 0; + } + } +} + +template +void zeropad1d_cl(data_T data[CONFIG_T::n_chan * CONFIG_T::in_width], res_T res[CONFIG_T::n_chan * CONFIG_T::out_width]) { + //#pragma HLS PIPELINE + + for (int i = 0; i < CONFIG_T::pad_left; i++) { + for (int j = 0; j < CONFIG_T::n_chan; j++) { + *(res++) = 0; + } + } + + for (int i = 0; i < CONFIG_T::in_width; i++) { + for (int j = 0; j < CONFIG_T::n_chan; j++) { + *(res++) = (res_T) * (data++); + } + } + + for (int i = 0; i < CONFIG_T::pad_right; i++) { + for (int j = 0; j < CONFIG_T::n_chan; j++) { + *(res++) = 0; + } + } +} + +struct padding2d_config { + static const unsigned n_chan = 10; + static const unsigned in_height = 10; + static const unsigned in_width = 10; + static const unsigned out_height = 10; + static const unsigned out_width = 10; + static const unsigned pad_top = 0; + static const unsigned pad_bottom = 0; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; +}; + +template +void zeropad2d_cf(data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], + data_T res[CONFIG_T::n_chan * CONFIG_T::out_height * CONFIG_T::out_width]) { + //#pragma HLS PIPELINE + + for (int k = 0; k < CONFIG_T::n_chan; k++) { + + for (int i = 0; i < CONFIG_T::pad_top; i++) { + for (int j = 0; j < CONFIG_T::out_width; j++) { + *(res++) = 0; + } + } + + for (int i = 0; i < CONFIG_T::in_height; i++) { + for (int j = 0; j < CONFIG_T::pad_left; j++) { + *(res++) = 0; + } + for (int j = 0; j < CONFIG_T::in_width; j++) { + *(res++) = (res_T) * (data++); + } + for (int j = 0; j < CONFIG_T::pad_right; j++) { + *(res++) = 0; + } + } + + for (int i = 0; i < CONFIG_T::pad_bottom; i++) { + for (int j = 0; j < CONFIG_T::out_width; j++) { + *(res++) = 0; + } + } + } +} + +template +void zeropad2d_cl(data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], + res_T res[CONFIG_T::n_chan * CONFIG_T::out_height * CONFIG_T::out_width]) { + //#pragma HLS PIPELINE + + for (int i = 0; i < CONFIG_T::pad_top; i++) { + for (int j = 0; j < CONFIG_T::out_width; j++) { + for (int k = 0; k < CONFIG_T::n_chan; k++) { + *(res++) = 0; + } + } + } + + for (int i = 0; i < CONFIG_T::in_height; i++) { + for (int j = 0; j < CONFIG_T::pad_left; j++) { + for (int k = 0; k < CONFIG_T::n_chan; k++) { + *(res++) = 0; + } + } + for (int j = 0; j < CONFIG_T::in_width; j++) { + for (int k = 0; k < CONFIG_T::n_chan; k++) { + *(res++) = (res_T) * (data++); + } + } + for (int j = 0; j < CONFIG_T::pad_right; j++) { + for (int k = 0; k < CONFIG_T::n_chan; k++) { + *(res++) = 0; + } + } + } + + for (int i = 0; i < CONFIG_T::pad_bottom; i++) { + for (int j = 0; j < CONFIG_T::out_width; j++) { + for (int k = 0; k < CONFIG_T::n_chan; k++) { + *(res++) = 0; + } + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_padding_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_padding_stream.h new file mode 100644 index 0000000000..9c11683746 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_padding_stream.h @@ -0,0 +1,95 @@ +#ifndef NNET_PADDING_STREAM_H_ +#define NNET_PADDING_STREAM_H_ + +#include + +namespace nnet { + +template void fill_zero(ac_channel &res) { + //#pragma HLS INLINE + res_T res_part; + for (unsigned int c = 0; c < CONFIG_T::n_chan; c++) { + //#pragma HLS UNROLL + res_part[c] = 0; + } + res.write(res_part); +} + +template void fill_data(ac_channel &data, ac_channel &res) { + //#pragma HLS INLINE + data_T data_part = data.read(); + res_T res_part; + for (unsigned int c = 0; c < CONFIG_T::n_chan; c++) { + //#pragma HLS UNROLL + res_part[c] = data_part[c]; + } + res.write(res_part); +} + +template void zeropad1d_cl(ac_channel &data, ac_channel &res) { +PadLeft: + for (int i = 0; i < CONFIG_T::pad_left; i++) { + fill_zero(res); + } + +CopyMain: + for (int i = 0; i < CONFIG_T::in_width; i++) { + fill_data(data, res); + } + +PadRight: + for (int i = 0; i < CONFIG_T::pad_right; i++) { + fill_zero(res); + } +} + +// Description: +// apply zero padding to input feature data "data" based on +// padding parameters in CONFIG_T +// +// CONFIG_T::pad_top +// CONFIG_T::pad_left "data" CONFIG_T::pad_right +// CONFIG_T::pad_bottom +// +// Template Params: +// data_T - typically nnet::array< ac_fixed<>, 3*1> (see myproject.cpp -> firmware/defines.h) +// res_T - typically nnet::array< ac_fixed<>, 3*1> + +template void zeropad2d_cl(ac_channel &data, ac_channel &res) { + +PadTop: + for (unsigned i = 0; i < CONFIG_T::pad_top; i++) { + PadTopWidth: + for (unsigned j = 0; j < CONFIG_T::out_width; j++) { + fill_zero(res); + } + } + +PadMain: + for (unsigned i = 0; i < CONFIG_T::in_height; i++) { + PadLeft: + for (unsigned j = 0; j < CONFIG_T::pad_left; j++) { + fill_zero(res); + } + CopyMain: + for (unsigned j = 0; j < CONFIG_T::in_width; j++) { + fill_data(data, res); + } + PadRight: + for (unsigned j = 0; j < CONFIG_T::pad_right; j++) { + fill_zero(res); + } + } + +PadBottom: + for (unsigned i = 0; i < CONFIG_T::pad_bottom; i++) { + PadBottomWidth: + for (unsigned j = 0; j < CONFIG_T::out_width; j++) { + fill_zero(res); + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h b/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h new file mode 100644 index 0000000000..82e281023b --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_pooling.h @@ -0,0 +1,362 @@ +#ifndef NNET_POOLING_H_ +#define NNET_POOLING_H_ + +#include "nnet_helpers.h" +#include + +namespace nnet { + +// Return the maximum value from an array +template T max(T x[N]) { + T y = x[0]; + for (int i = 1; i < N; i++) { + y = x[i] > y ? x[i] : y; + } + return y; +} + +template ac_int avg(ac_int (&x)[N]) { + // Use a wider accumulator than the input to avoid overflow + ac_int tmp = 0; + for (int i = 0; i < N; i++) { + tmp += x[i]; + } + tmp /= N; + // Now cast back to original type + ac_int y = tmp; + return tmp; +} + +template ac_fixed avg(ac_fixed (&x)[N]) { + // Use a wider accumulator than the input to avoid overflow + ac_fixed tmp = 0; + for (int i = 0; i < N; i++) { + tmp += x[i]; + } + tmp /= N; + // Now cast back to original type + ac_fixed y = tmp; + return y; +} + +// Return the mean value of an array +template T avg(T (&x)[N]) { + T y = 0; + for (int i = 0; i < N; i++) { + y += x[i]; + } + y /= N; + return y; +} + +// Enumeration for pooling operation (max, avg, l2norm pooling) +enum Pool_Op { Max, Average }; // L2Norm }; +template T pool_op(T (&x)[N]) { + switch (op) { + case Max: + return max(x); + case Average: + return avg(x); + // case L2Norm: return l2norm(x); + } +} + +template T pad_val() { + /*--- + *- In Tensorflow, pooling ignores the value in the padded cells + *- For Avg pooling, return 0 (the divisior is modified to the + *- area overlapping the unpadded image. + *- For max pooling, return the most negative value for the type. + *- TODO this is not really generic, it assumes fixed point or integer T + ---*/ + switch (op) { + case Max: { + T x = 0; + x[x.width - 1] = 1; + return x; + break; + } + case Average: + return 0; + } +} + +struct pooling1d_config { + // IO size + static const unsigned n_in = 10; + static const unsigned pool_width = 2; + static const unsigned stride_width = 2; + static const unsigned n_out = (n_in - pool_width) / stride_width + 1; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const bool count_pad = false; + // Pooling function + static const Pool_Op pool_op = Max; +}; + +template constexpr int pool_op_limit_1d() { + return CONFIG_T::n_in * CONFIG_T::n_filt / CONFIG_T::reuse_factor; +} + +template +void pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONFIG_T::n_out * CONFIG_T::n_filt]) { + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + // TODO partition the arrays according to the reuse factor + const int limit = pool_op_limit_1d(); + #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit + // Add any necessary padding + unsigned padded_width = CONFIG_T::n_in + CONFIG_T::pad_left + CONFIG_T::pad_right; + if (CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { + padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); + } + + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Loop over input image x in steps of stride + for (int ii = 0; ii < padded_width; ii += CONFIG_T::stride_width) { + data_T pool[CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 + // Keep track of number of pixels in image vs padding region + unsigned img_overlap = 0; + // Loop over pool window x + for (int jj = 0; jj < CONFIG_T::stride_width; jj++) { + if (ii + jj < CONFIG_T::pad_left || ii + jj >= (padded_width - CONFIG_T::pad_right)) { + // Add padding + pool[jj] = pad_val(); + if (CONFIG_T::count_pad) { + img_overlap++; + } + } else { + pool[jj] = data[(ii + jj - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff]; + img_overlap++; + } + } + // do the pooling + // TODO in the case of average pooling, need to reduce width to area of pool window + // not overlapping padding region + res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] = + pool_op(pool); + // If the pool op is Average, the zero-padding needs to be removed from the results + if (CONFIG_T::pool_op == Average) { + data_T rescale = static_cast(CONFIG_T::pool_width) / img_overlap; + res[(ii / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale; + } + } + } +} + +template +void global_pooling1d_cl(data_T data[CONFIG_T::n_in * CONFIG_T::n_filt], res_T res[CONFIG_T::n_filt]) { + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_width == CONFIG_T::stride_width); + + // TODO partition the arrays according to the reuse factor + const int limit = pool_op_limit_1d(); + #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit + + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + data_T pool[CONFIG_T::n_in]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 + for (int jj = 0; jj < CONFIG_T::n_in; jj++) { + pool[jj] = data[jj * CONFIG_T::n_filt + ff]; + } + // do the pooling + res[ff] = pool_op(pool); + } +} + +struct pooling2d_config { + // IO size + static const unsigned in_height = 10; + static const unsigned in_width = 10; + static const unsigned n_filt = 4; + static const unsigned stride_height = 2; + static const unsigned stride_width = 2; + static const unsigned pool_height = 2; + static const unsigned pool_width = 2; + static const unsigned out_height = (in_height - pool_height) / stride_height + 1; + static const unsigned out_width = (in_width - pool_width) / stride_width + 1; + // Padding + static const unsigned pad_top = 0; + static const unsigned pad_bottom = 0; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const bool count_pad = false; + // Pooling function + static const Pool_Op pool_op = Max; + // Reuse factor + static const unsigned reuse_factor = 1; + + // Internal data type definitions + typedef float accum_t; +}; + +template constexpr int pool_op_limit() { + return (CONFIG_T::out_height * CONFIG_T::out_width) * CONFIG_T::n_filt / CONFIG_T::reuse_factor; +} + +template +void pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_filt], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]) { + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + // TODO partition the arrays according to the reuse factor + const int limit = pool_op_limit(); + #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit + // Add any necessary padding + unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { + padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); + padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); + } + + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Loop over input image y in steps of stride + for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + // Loop over input image x in steps of stride + for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 + // Keep track of number of pixels in image vs padding region + unsigned img_overlap = 0; + // Loop over pool window y + for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { + // Loop over pool window x + for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { + if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + // Add padding + pool[kk * CONFIG_T::stride_width + ll] = pad_val(); + if (CONFIG_T::count_pad) { + img_overlap++; + } + } else { + pool[kk * CONFIG_T::stride_width + ll] = + data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width * CONFIG_T::n_filt + + (jj + ll - CONFIG_T::pad_left) * CONFIG_T::n_filt + ff]; + img_overlap++; + } + } + } + // do the pooling + // TODO in the case of average pooling, need to reduce height * width to area of pool window + // not overlapping padding region + res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt + + (jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] = + pool_op(pool); + // If the pool op is Average, the zero-padding needs to be removed from the results + if (CONFIG_T::pool_op == Average) { + data_T rescale = + static_cast(CONFIG_T::pool_height) * static_cast(CONFIG_T::pool_width) / img_overlap; + res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width * CONFIG_T::n_filt + + (jj / CONFIG_T::stride_width) * CONFIG_T::n_filt + ff] *= rescale; + } + } + } + } +} + +template +void pooling2d_cf(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_filt], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]) { + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + // TODO partition the arrays according to the reuse factor + const int limit = pool_op_limit(); + #pragma HLS ALLOCATION function instances=CONFIG_T::pool_op limit=limit + // Add any necessary padding + unsigned padded_height = CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom; + unsigned padded_width = CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right; + if (CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0) { + padded_height -= padded_height - (padded_height / CONFIG_T::stride_height * CONFIG_T::stride_height); + padded_width -= padded_width - (padded_width / CONFIG_T::stride_width * CONFIG_T::stride_width); + } + + for (int ff = 0; ff < CONFIG_T::n_filt; ff++) { + // Loop over input image y in steps of stride + for (int ii = 0; ii < padded_height; ii += CONFIG_T::stride_height) { + // Loop over input image x in steps of stride + for (int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width) { + data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 + // Keep track of number of pixels in image vs padding region + unsigned img_overlap = 0; + // Loop over pool window y + for (int kk = 0; kk < CONFIG_T::stride_height; kk++) { + // Loop over pool window x + for (int ll = 0; ll < CONFIG_T::stride_width; ll++) { + if (ii + kk < CONFIG_T::pad_top || ii + kk >= (padded_height - CONFIG_T::pad_bottom) || + jj + ll < CONFIG_T::pad_left || jj + ll >= (padded_width - CONFIG_T::pad_right)) { + // Add padding + pool[kk * CONFIG_T::stride_width + ll] = pad_val(); + if (CONFIG_T::count_pad) { + img_overlap++; + } + } else { + pool[kk * CONFIG_T::stride_width + ll] = + data[(ii + kk - CONFIG_T::pad_top) * CONFIG_T::in_width + + ff * CONFIG_T::in_width * CONFIG_T::in_height + ll + jj - CONFIG_T::pad_left]; + img_overlap++; + } + } + } + // do the pooling + // TODO in the case of average pooling, need to reduce height * width to area of pool window + // not overlapping padding region + res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) + + ff * CONFIG_T::out_height * CONFIG_T::out_width] = + pool_op(pool); + // If the pool op is Average, the zero-padding needs to be removed from the results + if (CONFIG_T::pool_op == Average) { + data_T rescale = + static_cast(CONFIG_T::pool_height) * static_cast(CONFIG_T::pool_width) / img_overlap; + res[(ii / CONFIG_T::stride_height) * CONFIG_T::out_width + (jj / CONFIG_T::stride_width) + + ff * CONFIG_T::out_height * CONFIG_T::out_width] *= rescale; + } + } + } + } +} + +template +void global_pooling2d_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_filt], + res_T res[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0); + assert(CONFIG_T::pool_width == CONFIG_T::stride_width); + assert(CONFIG_T::pool_height == CONFIG_T::stride_height); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + const int limit = pool_op_limit(); + #pragma HLS ALLOCATION instances=pool_op limit=limit function + +FiltLoop: + for (int filt = 0; filt < CONFIG_T::n_filt; filt++) { + data_T pool[CONFIG_T::in_height * CONFIG_T::in_width]; + + InputLoop: + for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width; i++) { + pool[i] = data[i * CONFIG_T::n_filt + filt]; + } + + res[filt] = static_cast(pool_op(pool)); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_pooling_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_pooling_stream.h new file mode 100644 index 0000000000..051a27a54b --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_pooling_stream.h @@ -0,0 +1,601 @@ +#ifndef NNET_POOLING_STREAM_H_ +#define NNET_POOLING_STREAM_H_ + +// #include "utils/x_hls_utils.h" +#include "ac_channel.h" +#include "ap_shift_reg.h" +#include "nnet_common.h" +#include "nnet_conv_stream.h" +#include "nnet_pooling.h" + +namespace nnet { + +// ************************************************* +// Max/average pooling +// ************************************************* + +template T reduce_pool(T x[N]) { + //#pragma HLS INLINE + if (CONFIG_T::pool_op == Max) { + Op_max op_max; + return reduce>(x, op_max); + } else { + Op_add op_add; + T sum = reduce>(x, op_add); + return sum / N; + } +} + +template void init_pool_table(unsigned table[TABLE_SIZE]) { + for (unsigned ii = 0; ii < TABLE_SIZE; ii++) { + table[ii] = ii % POOL_SIZE; + } +} + +template +void compute_pool_encoded_2d( + const unsigned h_idx, const unsigned w_idx, const data_T &in_elem, + ac_channel data_window[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt], + ac_channel &res, res_T &res_pack, unsigned &outputs_ready) { + // Nearest H without unused pixels on the right + constexpr unsigned nH = + ((CONFIG_T::in_height - CONFIG_T::pool_height) / CONFIG_T::stride_height) * CONFIG_T::stride_height + + CONFIG_T::pool_height; + // Scaled H that behaves like original H + constexpr unsigned sH = + (DIV_ROUNDUP(CONFIG_T::pool_height, CONFIG_T::stride_height) - 1) * CONFIG_T::stride_height + CONFIG_T::pool_height; + // Nearest W without unused pixels on the right + constexpr unsigned nW = ((CONFIG_T::in_width - CONFIG_T::pool_width) / CONFIG_T::stride_width) * CONFIG_T::stride_width + + CONFIG_T::pool_width; + // Scaled W that behaves like original W + constexpr unsigned sW = + (DIV_ROUNDUP(CONFIG_T::pool_width, CONFIG_T::stride_width) - 1) * CONFIG_T::stride_width + CONFIG_T::pool_width; + +#ifdef __SYNTHESIS__ + bool initialized = false; + unsigned pool_table_height[CONFIG_T::in_height]; + unsigned pool_table_width[CONFIG_T::in_width]; +#else + static bool initialized = false; + static unsigned pool_table_height[CONFIG_T::in_height]; + static unsigned pool_table_width[CONFIG_T::in_width]; +#endif + if (!initialized) { + init_pool_table(pool_table_height); + init_pool_table(pool_table_width); + initialized = true; + } + + //#pragma HLS INLINE + + if (data_T::size / CONFIG_T::n_filt > 1) { + //#pragma HLS ARRAY_PARTITION variable=pool_table_height complete + //#pragma HLS ARRAY_PARTITION variable=pool_table_width complete + } + + typename CONFIG_T::accum_t pool_window[CONFIG_T::pool_height * CONFIG_T::pool_width]; + //#pragma HLS ARRAY_PARTITION variable=pool_window complete + + const unsigned sh_idx = pool_table_height[h_idx] * CONFIG_T::pool_width; + const unsigned wp_idx = w_idx * (data_T::size / CONFIG_T::n_filt); +PixelLoop: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_filt; p++) { + //#pragma HLS PIPELINE + + ac_int filt_mask = 0; + if ((h_idx < nH) && (wp_idx + p < nW)) { + filt_mask = sh_idx + pool_table_width[wp_idx + p] + 1; + } + CopyDataFilt: + for (unsigned c = 0; c < CONFIG_T::n_filt; c++) { + if (filt_mask > 0) + data_window[c * CONFIG_T::pool_height * CONFIG_T::pool_width + filt_mask.to_uint() - 1].write( + in_elem[p * CONFIG_T::n_filt + c]); + } + + if (filt_mask == CONFIG_T::pool_height * CONFIG_T::pool_width) { + FiltLoop: + for (unsigned c = 0; c < CONFIG_T::n_filt; c++) { + PoolLoop: + for (unsigned f = 0; f < CONFIG_T::pool_height * CONFIG_T::pool_width; f++) { + pool_window[f] = data_window[c * CONFIG_T::pool_height * CONFIG_T::pool_width + f].read(); + } + if (res_T::size / CONFIG_T::n_filt == + 1) { // Saves resources if we don't pack output, compiler will remove the else branch + res_pack[c] = + reduce_pool( + pool_window); + } else { + res_pack[outputs_ready * CONFIG_T::n_filt + c] = + reduce_pool( + pool_window); + } + } + if (res_T::size / CONFIG_T::n_filt == + 1) { // Saves resources if we don't pack output, compiler will remove the else branch + res.write(res_pack); + } else { + if (outputs_ready == (res_T::size / CONFIG_T::n_filt) - 1) { + res.write(res_pack); + outputs_ready = 0; + } else { + outputs_ready++; + } + } + } + } +} + +template +void pooling2d_encoded_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_height == CONFIG_T::stride_height && CONFIG_T::pool_width == CONFIG_T::stride_width); + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + unsigned outputs_ready = 0; + + static ac_channel + data_window[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt]; + // constexpr int win_depth = CONFIG_T::pool_height * CONFIG_T::out_width; + // for (unsigned i_out = 0; i_out < CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt; i_out++) { + // #pragma HLS STREAM variable=data_window[i_out] depth=win_depth + // } + + constexpr int pack_factor = (data_T::size / CONFIG_T::n_filt) * (res_T::size / CONFIG_T::n_filt == 1); + (void)pack_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (pack_factor); i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (res_T::size / CONFIG_T::n_filt == 1) { + //#pragma HLS PIPELINE II=pack_factor + } + compute_pool_encoded_2d(i_ih, i_iw, data.read(), data_window, res, res_pack, + outputs_ready); + } + } +} + +// ************************************************* +// Line Buffer Implementation (Phil's) +// ************************************************* +template +void compute_pool_buffer_2d(const data_T &in_elem, + ap_shift_reg + line_buffer[MAX(CONFIG_T::pool_height - 1, 1)][CONFIG_T::n_filt], + ac_channel &res) { + //#pragma HLS INLINE + const static int lShiftX = CONFIG_T::pool_width - 1; + const static int lShiftY = CONFIG_T::pool_height - 1; + static int pX = 0; // pixel X + static int pY = 0; // pixel Y + static int sX = 0; // stride X + static int sY = 0; // stride Y + + typename data_T::value_type pool_window[CONFIG_T::pool_height * CONFIG_T::pool_width]; + //#pragma HLS ARRAY_PARTITION variable=pool_window complete + + static typename data_T::value_type kernel_data[CONFIG_T::pool_height * CONFIG_T::pool_width * CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable = kernel_data complete dim = 0 + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + + // Add pixel into line buffer, return pooling kernels + nnet::shift_line_buffer(in_elem, line_buffer, kernel_data); + + // Can compute pooling output + if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > lShiftY - 1 && pX > lShiftX - 1) { + FiltLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + //#pragma HLS PIPELINE + + // Retrieve data for current channel + PoolLoop: + for (unsigned i_ihw = 0; i_ihw < CONFIG_T::pool_height * CONFIG_T::pool_width; i_ihw++) { + pool_window[i_ihw] = kernel_data[i_ihw * CONFIG_T::n_filt + i_ic]; + } + + // Compute Pooling + res_pack[i_ic] = + reduce_pool( + pool_window); + } + + // Write to output + res.write(res_pack); + } + + // Counter Housekeeping + if (pX + 1 == CONFIG_T::in_width) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + if (pY + 1 == CONFIG_T::in_height) { // Reached bottom of image + pY = 0; + sY = 0; + } else { // Next line + pY = pY + 1; + // Update stride (threshold) ? subtract stride : increment stride + sY = ((sY - lShiftY) == 0) ? sY - CONFIG_T::stride_height + 1 : sY + 1; + } + } else { + pX = pX + 1; + // Update stride (threshold) ? subtract stride : increment stride + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +template +void pooling2d_buffer_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_height == CONFIG_T::stride_height && CONFIG_T::pool_width == CONFIG_T::stride_width); + + static ap_shift_reg line_buffer[MAX(CONFIG_T::pool_height - 1, 1)] + [CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + //#pragma HLS LOOP_FLATTEN + //#pragma HLS PIPELINE + + compute_pool_buffer_2d(data.read(), line_buffer, res); + } + } +} + +template void pooling2d_cl(ac_channel &data, ac_channel &res) { + //#pragma HLS inline region + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + pooling2d_buffer_cl(data, res); + break; + case conv_implementation::encoded: + pooling2d_encoded_cl(data, res); + break; + } +} + +// ************************************************* +// Pooling 1D +// ************************************************* + +template +void compute_pool_encoded_1d(const unsigned w_idx, const data_T &in_elem, + ac_channel data_window[CONFIG_T::pool_width * CONFIG_T::n_filt], + ac_channel &res, res_T &res_pack, unsigned &outputs_ready) { + // Nearest W without unused pixels on the right + constexpr unsigned nW = + ((CONFIG_T::n_in - CONFIG_T::pool_width) / CONFIG_T::stride_width) * CONFIG_T::stride_width + CONFIG_T::pool_width; + // Scaled W that behaves like original W + constexpr unsigned sW = + (DIV_ROUNDUP(CONFIG_T::pool_width, CONFIG_T::stride_width) - 1) * CONFIG_T::stride_width + CONFIG_T::pool_width; + +#ifdef __SYNTHESIS__ + bool initialized = false; + unsigned pool_table_width[CONFIG_T::n_in]; +#else + static bool initialized = false; + static unsigned pool_table_width[CONFIG_T::n_in]; +#endif + if (!initialized) { + init_pool_table(pool_table_width); + initialized = true; + } + + //#pragma HLS INLINE + + if (data_T::size / CONFIG_T::n_filt > 1) { + //#pragma HLS ARRAY_PARTITION variable=pool_table_width complete + } + + typename CONFIG_T::accum_t pool_window[CONFIG_T::pool_width]; + //#pragma HLS ARRAY_PARTITION variable=pool_window complete + + const unsigned wp_idx = w_idx * (data_T::size / CONFIG_T::n_filt); + +PixelLoop: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_filt; p++) { + //#pragma HLS PIPELINE + + ac_int filt_mask = 0; + if (wp_idx + p < nW) { + filt_mask = pool_table_width[wp_idx + p] + 1; + } + + CopyDataFilt: + for (unsigned c = 0; c < CONFIG_T::n_filt; c++) { + if (filt_mask > 0) + data_window[c * CONFIG_T::pool_width + filt_mask.to_uint() - 1].write(in_elem[p * CONFIG_T::n_filt + c]); + } + + if (filt_mask == CONFIG_T::pool_width) { + FiltLoop: + for (unsigned c = 0; c < CONFIG_T::n_filt; c++) { + PoolLoop: + for (unsigned f = 0; f < CONFIG_T::pool_width; f++) { + pool_window[f] = data_window[c * CONFIG_T::pool_width + f].read(); + } + if (res_T::size / CONFIG_T::n_filt == + 1) { // Saves resources if we don't pack output, compiler will remove the else branch + res_pack[c] = reduce_pool(pool_window); + } else { + res_pack[outputs_ready * CONFIG_T::n_filt + c] = + reduce_pool(pool_window); + } + } + if (res_T::size / CONFIG_T::n_filt == + 1) { // Saves resources if we don't pack output, compiler will remove the else branch + res.write(res_pack); + } else { + if (outputs_ready == (res_T::size / CONFIG_T::n_filt) - 1) { + res.write(res_pack); + outputs_ready = 0; + } else { + outputs_ready++; + } + } + } + } +} + +template +void pooling1d_encoded_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_width == CONFIG_T::stride_width); + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + unsigned outputs_ready = 0; + + ac_channel data_window[CONFIG_T::pool_width * CONFIG_T::n_filt]; + // constexpr int win_depth = CONFIG_T::n_out; + // for (unsigned i_out = 0; i_out < CONFIG_T::pool_width * CONFIG_T::n_filt; i_out++) { + // #pragma HLS STREAM variable=data_window[i_out] depth=win_depth + // } + + constexpr int pack_factor = data_T::size / CONFIG_T::n_filt; + +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::n_in / (pack_factor); i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (res_T::size / CONFIG_T::n_filt == 1) { + //#pragma HLS PIPELINE II=pack_factor + } + compute_pool_encoded_1d(i_iw, data.read(), data_window, res, res_pack, outputs_ready); + } +} + +// ************************************************* +// Line Buffer Implementation (Phil's) 1D +// ************************************************* +template +void compute_pool_buffer_1d(const data_T &in_elem, ac_channel &res) { + //#pragma HLS INLINE + const static int lShiftX = CONFIG_T::pool_width - 1; + // Counters + static int pX = 0; + static int sX = 0; + + typename data_T::value_type pool_window[CONFIG_T::pool_width]; + //#pragma HLS ARRAY_PARTITION variable=pool_window complete + + static typename data_T::value_type kernel_data[CONFIG_T::pool_width * CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable = kernel_data complete dim = 0 + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + + // Add pixel into line buffer, return pooling kernels + // 1D case line buffer not necessary. Put directly into the kernel_data buffer + nnet::kernel_shift_1d(in_elem, kernel_data); + + // Can compute pooling output + if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { + FiltLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + //#pragma HLS PIPELINE + + // Retrieve data for current channel + PoolLoop: + for (unsigned i_iw = 0; i_iw < CONFIG_T::pool_width; i_iw++) { + pool_window[i_iw] = kernel_data[i_iw * CONFIG_T::n_filt + i_ic]; + } + + // Compute Pooling + res_pack[i_ic] = reduce_pool(pool_window); + } + + // Write to output + res.write(res_pack); + } + + // Counter Housekeeping + if (pX + 1 == CONFIG_T::n_in) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + } else { + pX = pX + 1; + // Update stride (threshold) ? subtract stride : increment stride + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +template +void pooling1d_buffer_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::n_in; i_iw++) { + //#pragma HLS LOOP_FLATTEN + //#pragma HLS PIPELINE + compute_pool_buffer_1d(data.read(), res); + } +} + +template void pooling1d_cl(ac_channel &data, ac_channel &res) { + //#pragma HLS inline region + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + pooling1d_buffer_cl(data, res); + break; + case conv_implementation::encoded: + pooling1d_encoded_cl(data, res); + break; + } +} + +// ************************************************* +// Global max/average pooling +// ************************************************* + +template T reduce_global_pool(T x, T y[N]) { + //#pragma HLS INLINE + if (CONFIG_T::pool_op == Max) { + Op_max op_max; + T y_max = reduce>(y, op_max); + return (x > y_max) ? x : y_max; + } else { + Op_add op_add; + T y_sum = reduce>(y, op_add); + return x + y_sum; + } +} + +template +void compute_global_pool(const data_T &in_elem, typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]) { +PoolFilt: + for (unsigned c = 0; c < CONFIG_T::n_filt; c++) { + + typename CONFIG_T::accum_t data_pack[data_T::size / CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=data_pack complete dim=0 + + PixelLoop: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_filt; p++) { + data_pack[p] = in_elem[p * CONFIG_T::n_filt + c]; + } + data_window[c] = reduce_global_pool( + data_window[c], data_pack); + } +} + +template +void global_pooling2d_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_height == CONFIG_T::stride_height && CONFIG_T::pool_width == CONFIG_T::stride_width); + + typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=data_window complete + + typename CONFIG_T::accum_t init = 0; + if (CONFIG_T::pool_op == Max) { + // init = hls::numeric_limits::min(); + init.template set_val(); + } + +PoolInitLoop: + for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) { + data_window[i_init] = init; + } + +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_filt); i_iw++) { + //#pragma HLS LOOP_FLATTEN + compute_global_pool(data.read(), data_window); + } + } + + if (CONFIG_T::pool_op == Max) { + MaxPoolRes: + for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) { + //#pragma HLS PIPELINE + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + MaxPoolPack: + for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) { + res_pack[i_pack] = data_window[i_pack]; + } + res.write(res_pack); + } + } else { + AvgPoolRes: + for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) { + //#pragma HLS PIPELINE + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + AvgPoolPack: + for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) { + res_pack[i_pack] = data_window[i_pack] / (CONFIG_T::in_height * CONFIG_T::in_width); + } + res.write(res_pack); + } + } +} + +template +void global_pooling1d_cl(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::pool_width == CONFIG_T::stride_width); + + typename CONFIG_T::accum_t data_window[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=data_window complete + + typename CONFIG_T::accum_t init = 0; + if (CONFIG_T::pool_op == Max) { + // init = hls::numeric_limits::min(); + init.template set_val(); + } + +PoolInitLoop: + for (unsigned i_init = 0; i_init < CONFIG_T::n_filt; i_init++) { + data_window[i_init] = init; + } + +ReadInput: + for (unsigned i_iw = 0; i_iw < CONFIG_T::n_in / (data_T::size / CONFIG_T::n_filt); i_iw++) { + //#pragma HLS LOOP_FLATTEN + compute_global_pool(data.read(), data_window); + } + + if (CONFIG_T::pool_op == Max) { + MaxPoolRes: + for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) { + //#pragma HLS PIPELINE + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + MaxPoolPack: + for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) { + res_pack[i_pack] = data_window[i_pack]; + } + res.write(res_pack); + } + } else { + AvgPoolRes: + for (unsigned i_res = 0; i_res < CONFIG_T::n_filt / res_T::size; i_res++) { + //#pragma HLS PIPELINE + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + AvgPoolPack: + for (unsigned i_pack = 0; i_pack < res_T::size; i_pack++) { + res_pack[i_pack] = data_window[i_pack] / CONFIG_T::n_in; + } + res.write(res_pack); + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_recr_activations.h b/hls4ml/templates/catapult/nnet_utils/nnet_recr_activations.h new file mode 100755 index 0000000000..fd2019f3d5 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_recr_activations.h @@ -0,0 +1,56 @@ +#ifndef NNET_RECR_ACTIVATION_H_ +#define NNET_RECR_ACTIVATION_H_ + +#include "ac_channel.h" +#include "nnet_activation.h" +#include "nnet_common.h" +#include "nnet_helpers.h" +#include + +namespace nnet { + +namespace activation { + +template class Activation { + public: + // ************************************************* + // Blank Activation + // ************************************************* + static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {} // Nothing to do here +}; + +template class relu : public Activation { + public: + // ************************************************* + // Relu Activation + // ************************************************* + static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + nnet::relu(data, res); + } +}; + +template class sigmoid : public Activation { + public: + // ************************************************* + // Sigmoid Activation + // ************************************************* + static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + nnet::sigmoid(data, res); + } +}; + +template class tanh : public Activation { + public: + // ************************************************* + // TanH Activation + // ************************************************* + static void activation(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { + nnet::tanh(data, res); + } +}; + +} // namespace activation + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_recurrent.h b/hls4ml/templates/catapult/nnet_utils/nnet_recurrent.h new file mode 100755 index 0000000000..f08d4d1050 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_recurrent.h @@ -0,0 +1,572 @@ + +#ifndef NNET_RECURSIVE_H_ +#define NNET_RECURSIVE_H_ + +#include "ac_channel.h" +#include "nnet_activation.h" +#include "nnet_common.h" +#include "nnet_dense.h" +#include "nnet_recr_activations.h" + +namespace nnet { + +struct lstm_config { + // Internal data type definitions + typedef float weight_t; + typedef float bias_t; + + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_parts = 20; + static const unsigned n_out = 2; + static const unsigned n_state = 2; + static const unsigned n_4state = 8; + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const unsigned n_zeros = 0; + static const bool store_weights_in_bram = false; + static const bool use_static = true; + + template using activation_recr = nnet::activation::relu; + template using activation = nnet::activation::relu; +}; +// Long Short term Memory NN (LSTM) +// Resources: +// https://github.com/nicodjimenez/lstm/blob/master/lstm.py +// https://github.com/llSourcell/LSTM_Networks/blob/master/LSTM%20Demo.ipynb +// https://en.wikipedia.org/wiki/Long_short-term_memory +// Notes: +// - LSTM naming conventions adopted from the above links +// - s_newstate = activation(U*input + W*state) +// - h_output = activation(U*input + W*state)*activation(s_newstate) +// - If softmax is needed on output, perform *outside* this operations +// Originall had a version allows for the state in each layer to be saved, moved this to above (this requires are LARGE +// dense network at the end) +template +void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], + res_T s_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) { + // Initialize the state variable -- will maintain state between function calls + + typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 4]; + typename CONFIG_T::accum_t tmpres_state[CONFIG_T::n_state * 4]; + typename CONFIG_T::accum_t tmpres_ifo[CONFIG_T::n_state * 3]; // activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_c[CONFIG_T::n_state]; // activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_ifo[CONFIG_T::n_state * 3]; // i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_c[CONFIG_T::n_state]; // c-matrix (keras notation) + typename CONFIG_T::accum_t s_actstate[CONFIG_T::n_state]; + + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=s_newstate complete + //#pragma HLS ARRAY_PARTITION variable=tmpres complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_ifo complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_c complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_ifo complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_c complete + //#pragma HLS ARRAY_PARTITION variable=s_actstate complete + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_newstate, tmpres_state, param_r, param_br); + + for (int iacc = 0; iacc < (3 * CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc; + if (iacc > 2 * CONFIG_T::n_state - 1) + index = iacc + CONFIG_T::n_state; + inputacc_ifo[iacc] = tmpres[index] + tmpres_state[index]; + } + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state * 2; + inputacc_c[iacc] = tmpres[index] + tmpres_state[index]; + } + + CONFIG_T::template activation_recr::activation( + inputacc_ifo, tmpres_ifo); + + // Now for the confusion matrix + CONFIG_T::template activation::activation( + inputacc_c, tmpres_c); + + // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues) + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + s_newstate[iacc] = tmpres_c[iacc] * tmpres_ifo[iacc] + s_newstate[iacc] * tmpres_ifo[iacc + (CONFIG_T::n_state)]; + } + // Operation: h=act(s)*o + CONFIG_T::template activation::activation( + s_newstate, s_actstate); + + for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) { + //#pragma HLS UNROLL + h_newstate[iacc] = tmpres_ifo[iacc + 2 * (CONFIG_T::n_state)] * s_actstate[iacc]; + } +} + +template +void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], + res_T s_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) { + static res_T h_state[CONFIG_T::n_state]; + static res_T s_state[CONFIG_T::n_state]; + // Initialize the state variable -- will maintain state between function calls + typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 4]; + typename CONFIG_T::accum_t tmpres_state[CONFIG_T::n_state * 4]; + typename CONFIG_T::accum_t tmpres_ifo[CONFIG_T::n_state * 3]; // activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_c[CONFIG_T::n_state]; // activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_ifo[CONFIG_T::n_state * 3]; // i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_c[CONFIG_T::n_state]; // c-matrix (keras notation) + typename CONFIG_T::accum_t s_actstate[CONFIG_T::n_state]; + + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=s_newstate complete + //#pragma HLS ARRAY_PARTITION variable=h_state complete + //#pragma HLS ARRAY_PARTITION variable=s_state complete + //#pragma HLS ARRAY_PARTITION variable=tmpres complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_ifo complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_c complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_ifo complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_c complete + //#pragma HLS ARRAY_PARTITION variable=s_actstate complete + + if (reset_state) { + for (int i_state = 0; i_state < (CONFIG_T::n_state); i_state++) { + //#pragma HLS UNROLL + s_state[i_state] = 0; + h_state[i_state] = 0; + } + } + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_state, tmpres_state, param_r, + param_br); + + for (int iacc = 0; iacc < (3 * CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc; + if (iacc > 2 * CONFIG_T::n_state - 1) + index = iacc + CONFIG_T::n_state; + inputacc_ifo[iacc] = tmpres[index] + tmpres_state[index]; + } + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state * 2; + inputacc_c[iacc] = tmpres[index] + tmpres_state[index]; + } + + CONFIG_T::template activation_recr::activation( + inputacc_ifo, tmpres_ifo); + + // Now for the confusion matrix + CONFIG_T::template activation::activation( + inputacc_c, tmpres_c); + + // Operation: s=g*i+sold*f (update state with buffer to avoid timing issues) + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + s_state[iacc] = tmpres_c[iacc] * tmpres_ifo[iacc] + s_state[iacc] * tmpres_ifo[iacc + (CONFIG_T::n_state)]; + s_newstate[iacc] = s_state[iacc]; + } + // Operation: h=act(s)*o + CONFIG_T::template activation::activation( + s_state, s_actstate); + + for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) { + //#pragma HLS UNROLL + h_state[iacc] = tmpres_ifo[iacc + 2 * (CONFIG_T::n_state)] * s_actstate[iacc]; + h_newstate[iacc] = h_state[iacc]; + } +} + +template +void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) { + + res_T h_newstate[CONFIG_T::n_state]; + res_T s_newstate[CONFIG_T::n_state]; + data_T data_in[CONFIG_T::n_in]; + bool reset_state = true; + + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=s_newstate complete + + for (int ii = 0; ii < CONFIG_T::n_state; ii++) { + //#pragma HLS UNROLL + h_newstate[ii] = 0; + s_newstate[ii] = 0; + } + for (int iloop = 0; iloop < CONFIG_T::n_sequence; iloop++) { + for (int j = 0; j < CONFIG_T::n_in; j++) { + //#pragma HLS UNROLL + data_in[j] = data[j + iloop * CONFIG_T::n_in]; + } + if (CONFIG_T::use_static) + nnet::lstm_static(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, + param_br); + else + nnet::lstm(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, + param_br); + if (CONFIG_T::n_sequence_out > 1) + for (int i = CONFIG_T::n_state * iloop, j = 0; i < (CONFIG_T::n_state * (iloop + 1)); i++, j++) { + //#pragma HLS UNROLL + res[i] = h_newstate[j]; + } + reset_state = false; + } + if (CONFIG_T::n_sequence_out == 1) + for (int i = 0; i < (CONFIG_T::n_state); i++) { + //#pragma HLS UNROLL + res[i] = h_newstate[i]; + } +} + +template +void lstm_stack(ac_channel &data_stream, ac_channel &res_stream, + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) { + + typename res_T::value_type h_newstate[CONFIG_T::n_state]; + typename res_T::value_type s_newstate[CONFIG_T::n_state]; + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=s_newstate complete + + for (int ii = 0; ii < CONFIG_T::n_state; ii++) { + //#pragma HLS UNROLL + h_newstate[ii] = 0; + s_newstate[ii] = 0; + } + + typename data_T::value_type data_in[CONFIG_T::n_in]; + bool reset_state = true; + +DataPropagation: + for (int i_in = 0; i_in < CONFIG_T::n_sequence * CONFIG_T::n_in / data_T::size; i_in++) { + if (CONFIG_T::n_sequence * CONFIG_T::n_in / data_T::size > 1) { + // //#pragma HLS PIPELINE + } + data_T data_pack = data_stream.read(); + DataPack: + for (int i_pack = 0; i_pack < data_T::size; i_pack++) { + //#pragma HLS UNROLL + data_in[i_pack] = data_pack[i_pack]; + } + if (CONFIG_T::use_static) + nnet::lstm_static( + reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); + else + nnet::lstm( + reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); + if (CONFIG_T::n_sequence_out > 1) { + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + ResPack_sequences: + for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + //#pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + reset_state = false; + } + + if (CONFIG_T::n_sequence_out == 1) { + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + ResPack: + for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + //#pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } +} + +// Struct for the GRU template + +struct gru_config { + // Internal data type definitions + typedef float weight_t; + typedef float bias_t; + typedef float accum_t; + + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_out = 2; + static const unsigned n_state = 2; + static const unsigned n_sequence = 2; + static const unsigned n_4state = 8; + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const bool use_static = true; + static const unsigned n_zeros = 0; + + template using activation_recr = nnet::activation::relu; + template using activation = nnet::activation::relu; +}; + +template +void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], // TODO - Check the layout of the param + // weights - refer page in copy!! + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) { + // Initialize the state variable -- will maintain state between function calls + typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3]; + typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3]; + typename CONFIG_T::accum_t tmpres_state_h[CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres_zr[CONFIG_T::n_state * 2]; // activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_h[CONFIG_T::n_state]; // activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_zr[CONFIG_T::n_state * 2]; // i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_h[CONFIG_T::n_state]; // c-matrix (keras notation) + + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=tmpres complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state_zr complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state_h complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_zr complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_h complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_zr complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_h complete + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_newstate, tmpres_state_zr, param_zr, + param_br); + + // Adding the individual vectors from the multiplication of tmpres = Wx*x(t); tmpres_state_zr = Wh*h(t-1); tmpres + // initialized with biases -- DONE + for (int iacc = 0; iacc < (2 * CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc; + inputacc_zr[iacc] = tmpres[index] + tmpres_state_zr[index]; + } + + // Activation function Sub layer -- START + CONFIG_T::template activation_recr::activation(inputacc_zr, tmpres_zr); + + // Activation function Sub layer -- END + + // Hadamrd product of r(t) = inputacc_zr[2*n_state:n_state] and h(t-1) = h_newstate + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + tmpres_state_h[iacc] = tmpres_zr[iacc + (CONFIG_T::n_state)] * tmpres_state_zr[iacc + (2 * CONFIG_T::n_state)]; + } + + // Assuming reset_after is false + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state * 2; + inputacc_h[iacc] = tmpres[index] + tmpres_state_h[iacc]; + } + + // Now run the activation on this guy + CONFIG_T::template activation::activation(inputacc_h, tmpres_h); + + // Mix the stat with the previous state + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + h_newstate[iacc] = (res_T)(tmpres_h[iacc] * (1 - tmpres_zr[iacc]) + h_newstate[iacc] * tmpres_zr[iacc]); + } +} + +template +void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) { + // Initialize the state variable -- will maintain state between function calls + + static res_T h_state[CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3]; + typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3]; + typename CONFIG_T::accum_t tmpres_state_h[CONFIG_T::n_state]; + typename CONFIG_T::accum_t tmpres_zr[CONFIG_T::n_state * 2]; // activated i,f,o matrices (keras notation) + typename CONFIG_T::accum_t tmpres_h[CONFIG_T::n_state]; // activated c-matrix (keras notation) + typename CONFIG_T::accum_t inputacc_zr[CONFIG_T::n_state * 2]; // i,f,o matrices (keras notation) + typename CONFIG_T::accum_t inputacc_h[CONFIG_T::n_state]; // c-matrix (keras notation) + + //#pragma HLS ARRAY_PARTITION variable=h_state complete + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + //#pragma HLS ARRAY_PARTITION variable=tmpres complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state_zr complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_state_h complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_zr complete + //#pragma HLS ARRAY_PARTITION variable=tmpres_h complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_zr complete + //#pragma HLS ARRAY_PARTITION variable=inputacc_h complete + + if (reset_state) { + for (int i_h_state = 0; i_h_state < (CONFIG_T::n_state); i_h_state++) { + //#pragma HLS UNROLL + h_state[i_h_state] = 0; + } + } + + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_state, tmpres_state_zr, param_zr, + param_br); + + // Adding the individual vectors from the multiplication of tmpres = Wx*x(t); tmpres_state_zr = Wh*h(t-1); tmpres + // initialized with biases -- DONE + for (int iacc = 0; iacc < (2 * CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc; + inputacc_zr[iacc] = tmpres[index] + tmpres_state_zr[index]; + } + + // Activation function Sub layer -- START + CONFIG_T::template activation_recr::activation(inputacc_zr, tmpres_zr); + + // Activation function Sub layer -- END + + // Hadamrd product of r(t) = inputacc_zr[2*n_state:n_state] and h(t-1) = h_newstate + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + tmpres_state_h[iacc] = tmpres_zr[iacc + (CONFIG_T::n_state)] * tmpres_state_zr[iacc + (2 * CONFIG_T::n_state)]; + } + + // Assuming reset_after is false + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + int index = iacc + CONFIG_T::n_state * 2; + inputacc_h[iacc] = tmpres[index] + tmpres_state_h[iacc]; + } + + // Now run the activation on this guy + CONFIG_T::template activation::activation(inputacc_h, tmpres_h); + + // Mix the stat with the previous state + for (int iacc = 0; iacc < (CONFIG_T::n_state); iacc++) { + //#pragma HLS UNROLL + h_state[iacc] = (res_T)(tmpres_h[iacc] * (1 - tmpres_zr[iacc]) + h_state[iacc] * tmpres_zr[iacc]); + h_newstate[iacc] = h_state[iacc]; + } +} + +template +void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) { + + res_T h_state[CONFIG_T::n_state]; + data_T data_in[CONFIG_T::n_in]; + bool reset_state = true; + + //#pragma HLS ARRAY_PARTITION variable=h_state complete + //#pragma HLS ARRAY_PARTITION variable=data_in complete + + for (int ii = 0; ii < CONFIG_T::n_state; ii++) { + //#pragma HLS UNROLL + h_state[ii] = 0; + } + for (int iloop = 0; iloop < CONFIG_T::n_sequence; iloop++) { + for (int j = 0; j < CONFIG_T::n_in; j++) { + //#pragma HLS UNROLL + data_in[j] = data[j + iloop * CONFIG_T::n_in]; + } + if (CONFIG_T::use_static) + nnet::gru_static(reset_state, data_in, h_state, param, param_zr, param_b, param_br); + else + nnet::gru(reset_state, data_in, h_state, param, param_zr, param_b, param_br); + if (CONFIG_T::n_sequence_out > 1) + for (int i = CONFIG_T::n_state * iloop, j = 0; i < (CONFIG_T::n_state * (iloop + 1)); i++, j++) { + //#pragma HLS UNROLL + res[i] = h_state[j]; + } + reset_state = false; + } + if (CONFIG_T::n_sequence_out == 1) + for (int i = 0; i < (CONFIG_T::n_state); i++) { + //#pragma HLS UNROLL + res[i] = h_state[i]; + } +} + +template +void gru_stack(ac_channel &data_stream, ac_channel &res_stream, + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], + typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) { + + typename res_T::value_type h_newstate[CONFIG_T::n_state]; + //#pragma HLS ARRAY_PARTITION variable=h_newstate complete + for (int ii = 0; ii < CONFIG_T::n_state; ii++) { + //#pragma HLS UNROLL + h_newstate[ii] = 0; + } + + typename data_T::value_type data_in[CONFIG_T::n_in]; + bool reset_state = true; + +DataPropagation: + for (int i_in = 0; i_in < CONFIG_T::n_sequence * CONFIG_T::n_in / data_T::size; i_in++) { + if (CONFIG_T::n_sequence * CONFIG_T::n_in / data_T::size > 1) { + // //#pragma HLS PIPELINE + } + data_T data_pack = data_stream.read(); + DataPack: + for (int i_pack = 0; i_pack < data_T::size; i_pack++) { + //#pragma HLS UNROLL + data_in[i_pack] = data_pack[i_pack]; + } + if (CONFIG_T::use_static) + nnet::gru_static( + reset_state, data_in, h_newstate, param, param_zr, param_b, param_br); + else + nnet::gru(reset_state, data_in, h_newstate, + param, param_zr, param_b, param_br); + if (CONFIG_T::n_sequence_out > 1) { + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + ResPack_sequences: + for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + //#pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } + reset_state = false; + } + + if (CONFIG_T::n_sequence_out == 1) { + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + ResPack: + for (int i_pack = 0; i_pack < res_T::size; i_pack++) { + //#pragma HLS UNROLL + res_pack[i_pack] = h_newstate[i_pack]; + } + res_stream.write(res_pack); + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_sepconv1d_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv1d_stream.h new file mode 100644 index 0000000000..eb5ef9f7db --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv1d_stream.h @@ -0,0 +1,127 @@ +#ifndef NNET_SEPARABLE_CONV1D_STREAM_H_ +#define NNET_SEPARABLE_CONV1D_STREAM_H_ + +#include "ac_channel.h" +#include "nnet_common.h" +#include "nnet_conv1d_stream.h" +#include "nnet_sepconv_stream.h" + +namespace nnet { + +template +void depthwise_conv_1d_encoded_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + ac_channel data_window[CONFIG_T::filt_width * CONFIG_T::n_chan]; + // const int win_depth = CONFIG_T::out_width; + // for (unsigned i_out = 0; i_out < CONFIG_T::filt_width * CONFIG_T::n_chan; i_out++) { + // #pragma HLS STREAM variable=data_window[i_out] depth=win_depth + // } + + //#pragma HLS ARRAY_PARTITION variable=CONFIG_T::pixels complete + + res_T res_pack; + //#pragma HLS DATA_PACK variable=res_pack + unsigned outputs_ready = 0; + + ac_int pixel_idx[data_T::size / CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=pixel_idx complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_scaled_indices_1d(i_iw, pixel_idx); + compute_depthwise_output_encoded(data.read(), data_window, res, res_pack, outputs_ready, + weights, biases, pixel_idx); + } +} + +template +void depthwise_conv_1d_buffer_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency); + (void)ce_reuse_factor; +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + //#pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_depthwise_output_buffer_1d(data.read(), res, weights, biases); + } +} + +template +void depthwise_conv_1d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + #pragma HLS inline recursive + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + depthwise_conv_1d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + depthwise_conv_1d_encoded_cl(data, res, weights, biases); + break; + } +} + +template +void pointwise_conv_1d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::filt_width == 1); + + //#pragma HLS ARRAY_PARTITION variable=weights complete + //#pragma HLS ARRAY_PARTITION variable=biases complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + if (i_iw % CONFIG_T::stride_width == 0) { + pointwise_mult_buffer(data.read(), res, weights, biases); + } else { + data.read(); + } + } +} + +template +void separable_conv_1d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::depthwise_config::weight_t + depthwise_weights[CONFIG_T::depthwise_config::filt_width * CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::weight_t + pointwise_weights[CONFIG_T::pointwise_config::n_chan * CONFIG_T::pointwise_config::n_filt], + typename CONFIG_T::depthwise_config::bias_t depthwise_biases[CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::bias_t pointwise_biases[CONFIG_T::pointwise_config::n_filt]) { + //#pragma HLS DATAFLOW + + ac_channel depthwise_res; + unsigned res_depth = CONFIG_T::depthwise_config::out_width; + //#pragma HLS STREAM variable=depthwise_res depth=res_depth + + depthwise_conv_1d_cl(data, depthwise_res, depthwise_weights, + depthwise_biases); + pointwise_conv_1d_cl(depthwise_res, res, pointwise_weights, + pointwise_biases); +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d.h b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d.h new file mode 100644 index 0000000000..d98dd8c315 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d.h @@ -0,0 +1,82 @@ +#ifndef NNET_SEPARABLE_CONV2D_H_ +#define NNET_SEPARABLE_CONV2D_H_ + +#include "nnet_common.h" +#include + +namespace nnet { + +template +void depthwise_conv_2d_cl( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_chan], + typename CONFIG_T::weight_t depthwise_weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t depthwise_biases[CONFIG_T::n_chan]) { + const int in_height = CONFIG_T::in_height; + const int in_width = CONFIG_T::in_width; + const int n_chan = CONFIG_T::n_chan; + const int filt_height = CONFIG_T::filt_height; + const int filt_width = CONFIG_T::filt_width; + const int out_height = CONFIG_T::out_height; + const int out_width = CONFIG_T::out_width; + + // constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; (void)ce_reuse_factor; + + // do { + + //#pragma HLS ARRAY_PARTITION variable=res complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=depthwise_biases complete dim=0 + //#pragma HLS ARRAY_PARTITION variable=depthwise_weights complete dim=0 + for (int h = 0; h < in_height - filt_height + 1; h++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor rewind + for (int w = 0; w < in_width - filt_width + 1; w++) { + //#pragma HLS UNROLL + for (int c = 0; c < n_chan; c++) { + //#pragma HLS UNROLL + res_T sum = depthwise_biases[c]; + + // Apply the filter + for (int i = 0; i < filt_height; i++) { + //#pragma HLS UNROLL + for (int j = 0; j < filt_width; j++) { + //#pragma HLS UNROLL + int data_idx = (h + i) * in_width * n_chan + (w + j) * n_chan + c; + int weight_idx = i * filt_width * n_chan + j * n_chan + c; + sum += data[data_idx] * depthwise_weights[weight_idx]; + } + } + + int res_idx = (h * out_width * n_chan) + w * n_chan + c; + res[res_idx] = sum; + } + } + } + // } while (false); +} + +template +void separable_conv_2d_cl(data_T data[CONFIG_T::depthwise_config::in_height * CONFIG_T::depthwise_config::in_width * + CONFIG_T::depthwise_config::n_chan], + res_T res[CONFIG_T::pointwise_config::out_height * CONFIG_T::pointwise_config::out_width * + CONFIG_T::pointwise_config::n_filt], + typename CONFIG_T::depthwise_config::weight_t + depthwise_weights[CONFIG_T::depthwise_config::filt_height * + CONFIG_T::depthwise_config::filt_width * CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::weight_t + pointwise_weights[CONFIG_T::pointwise_config::n_chan * CONFIG_T::pointwise_config::n_filt], + typename CONFIG_T::depthwise_config::bias_t depthwise_biases[CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::bias_t pointwise_biases[CONFIG_T::pointwise_config::n_filt]) { + + //#pragma HLS INLINE region + + dw_res_T depthwise_results[CONFIG_T::depthwise_config::out_height * CONFIG_T::depthwise_config::out_width * + CONFIG_T::depthwise_config::n_chan]; + depthwise_conv_2d_cl(data, depthwise_results, depthwise_weights, + depthwise_biases); + pointwise_conv_2d_cl(depthwise_results, res, pointwise_weights, + pointwise_biases); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d_stream.h new file mode 100644 index 0000000000..a4f7d4faa9 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv2d_stream.h @@ -0,0 +1,152 @@ +#ifndef NNET_SEPARABLE_CONV2D_STREAM_H_ +#define NNET_SEPARABLE_CONV2D_STREAM_H_ + +#include "nnet_common.h" +#include "nnet_conv2d_stream.h" +#include "nnet_sepconv_stream.h" +#include "nnet_types.h" +#include + +namespace nnet { + +template +void depthwise_conv_2d_encoded_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::filt_height == CONFIG_T::filt_width); + + static ac_channel + data_window[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + // const int win_depth = CONFIG_T::filt_height * CONFIG_T::out_width; + // for (unsigned i_out = 0; i_out < CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; i_out++) { + // #pragma HLS STREAM variable=data_window[i_out] depth=win_depth + // } + + // #pragma HLS ARRAY_PARTITION variable=CONFIG_T::pixels complete + + res_T res_pack; + // PRAGMA_DATA_PACK(res_pack) + unsigned outputs_ready = 0; + + ac_int pixel_idx[data_T::size / CONFIG_T::n_chan]; + // #pragma HLS ARRAY_PARTITION variable=pixel_idx complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + // #pragma HLS LOOP_FLATTEN + // if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + // #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + // } + compute_scaled_indices_2d(i_ih, i_iw, pixel_idx); + compute_depthwise_output_encoded(data.read(), data_window, res, res_pack, outputs_ready, + weights, biases, pixel_idx); + } + } +} + +// Line Buffer Implementation (Phil's) +template +void depthwise_conv_2d_buffer_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + + static ap_shift_reg line_buffer[CONFIG_T::filt_height - 1] + [CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency); + (void)ce_reuse_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + //#pragma HLS LOOP_FLATTEN + // if (CONFIG_T::strategy == nnet::latency) { + // #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + // } + if (CONFIG_T::filt_height > 1) { + compute_depthwise_output_buffer_2d(data.read(), line_buffer, res, weights, biases); + } else { + compute_depthwise_output_buffer_1d(data.read(), res, weights, biases); + } + } + } +} + +template +void depthwise_conv_2d_cl( + ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + // #pragma HLS inline recursive + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + depthwise_conv_2d_buffer_cl(data, res, weights, biases); + break; + case conv_implementation::encoded: + depthwise_conv_2d_encoded_cl(data, res, weights, biases); + break; + } +} + +template +void pointwise_conv_2d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + assert(CONFIG_T::pad_top == 0 && CONFIG_T::pad_bottom == 0 && CONFIG_T::pad_left == 0 && CONFIG_T::pad_right == 0); + assert(CONFIG_T::filt_height == 1 && CONFIG_T::filt_width == 1); + + // #pragma HLS ARRAY_PARTITION variable=weights complete + // #pragma HLS ARRAY_PARTITION variable=biases complete + + constexpr int ce_reuse_factor = + CONFIG_T::reuse_factor * (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1); + (void)ce_reuse_factor; +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width / (data_T::size / CONFIG_T::n_chan); i_iw++) { + if (CONFIG_T::strategy == nnet::latency && data_T::size / CONFIG_T::n_chan == 1) { + // #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + if (i_ih % CONFIG_T::stride_height == 0 && i_iw % CONFIG_T::stride_width == 0) { + pointwise_mult_buffer(data.read(), res, weights, biases); + } else { + data.read(); + } + } + } +} + +template +void separable_conv_2d_cl(ac_channel &data, ac_channel &res, + typename CONFIG_T::depthwise_config::weight_t + depthwise_weights[CONFIG_T::depthwise_config::filt_height * + CONFIG_T::depthwise_config::filt_width * CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::weight_t + pointwise_weights[CONFIG_T::pointwise_config::n_chan * CONFIG_T::pointwise_config::n_filt], + typename CONFIG_T::depthwise_config::bias_t depthwise_biases[CONFIG_T::depthwise_config::n_chan], + typename CONFIG_T::pointwise_config::bias_t pointwise_biases[CONFIG_T::pointwise_config::n_filt]) { + // #pragma HLS DATAFLOW + + static ac_channel depthwise_res; + unsigned res_depth = CONFIG_T::depthwise_config::out_height * CONFIG_T::depthwise_config::out_width; + // #pragma HLS STREAM variable=depthwise_res depth=res_depth + + depthwise_conv_2d_cl(data, depthwise_res, depthwise_weights, + depthwise_biases); + pointwise_conv_2d_cl(depthwise_res, res, pointwise_weights, + pointwise_biases); +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_sepconv_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv_stream.h new file mode 100644 index 0000000000..753d260a77 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_sepconv_stream.h @@ -0,0 +1,315 @@ +#ifndef NNET_SEPARABLE_CONV_STREAM_H_ +#define NNET_SEPARABLE_CONV_STREAM_H_ + +#include "nnet_common.h" +#include "nnet_conv_stream.h" +#include +#include + +namespace nnet { + +template +void depthwise_product(data_T data[CONFIG_T::kernel_size * CONFIG_T::n_chan], res_T res[CONFIG_T::n_chan], + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + // #pragma HLS INLINE + + typename CONFIG_T::accum_t mult[CONFIG_T::kernel_size * CONFIG_T::n_chan]; + typename CONFIG_T::accum_t acc[CONFIG_T::n_chan]; + + // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases + // #pragma HLS function_instantiate variable=weights + + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; + + // Add dummy loop to which the pipeline pragma can be applied + do { + + //#pragma HLS ARRAY_PARTITION variable=mult complete + + //#pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::multiplier_limit + + // Do the matrix-multiply + Product: + for (int ii = 0; ii < CONFIG_T::kernel_size * CONFIG_T::n_chan; ii++) { + // #pragma HLS UNROLL + mult[ii] = CONFIG_T::mult_config::template product::product( + data[ii], weights[ii]); + } + + // Initialize accumulator with input biases + ResetAccum: + for (int iacc = 0; iacc < CONFIG_T::n_chan; iacc++) { + //#pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc]; + } + + // Accumulate multiplication result + Accum1: + for (int ii = 0; ii < CONFIG_T::kernel_size; ii++) { + Accum2: + for (int jj = 0; jj < CONFIG_T::n_chan; jj++) { + int index = ii * CONFIG_T::n_chan + jj; + acc[jj] += mult[index]; + } + } + + // Cast to "res_t" type + Result: + for (int ires = 0; ires < CONFIG_T::n_chan; ires++) { + //#pragma HLS UNROLL + res[ires] = cast(acc[ires]); + } + } while (0); +} + +template +void depthwise_mult_buffer(ac_channel data_window[CONFIG_T::kernel_size * CONFIG_T::n_chan], + res_T &res_pack, ac_channel &res_stream, unsigned &outputs_ready, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + //#pragma HLS INLINE + + typename data_T::value_type data[CONFIG_T::kernel_size * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=data complete + typename res_T::value_type res[CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=res complete + +InitData: + for (int id = 0; id < CONFIG_T::kernel_size * CONFIG_T::n_chan; id++) { + //#pragma HLS UNROLL + data[id] = data_window[id].read(); + } + + //#pragma HLS INLINE recursive + if (CONFIG_T::strategy == nnet::latency) { + depthwise_product(data, res, weights, biases); + } else { + assert("Resource strategy for DepthwiseConv2D is not supported." && false); + } + +CastLoop: + for (unsigned jj = 0; jj < CONFIG_T::n_chan; jj++) { + //#pragma HLS UNROLL + if (res_T::size / CONFIG_T::n_chan == 1) { + res_pack[jj] = res[jj]; + } else { + res_pack[outputs_ready * CONFIG_T::n_chan + jj] = res[jj]; + } + } + + if (res_T::size / CONFIG_T::n_chan == 1) { + res_stream.write(res_pack); + } else { + if (outputs_ready == (res_T::size / CONFIG_T::n_chan) - 1) { + res_stream.write(res_pack); + outputs_ready = 0; + } else { + outputs_ready++; + } + } +} + +template +void compute_depthwise_output_encoded( + const data_T &in_elem, ac_channel data_window[CONFIG_T::kernel_size * CONFIG_T::n_chan], + ac_channel &res, res_T &res_pack, unsigned &outputs_ready, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan], ac_int *pixel_idx) { + //#pragma HLS INLINE + + constexpr int ce_reuse_factor = CONFIG_T::reuse_factor; + (void)ce_reuse_factor; +MultLoop: + for (unsigned p = 0; p < data_T::size / CONFIG_T::n_chan; p++) { + //#pragma HLS PIPELINE II=CONFIG_T::reuse_factor + CopyDataFilt: + for (unsigned f = 0; f < CONFIG_T::kernel_size; f++) { + //#pragma HLS UNROLL + CopyDataChan: + for (unsigned c = 0; c < CONFIG_T::n_chan; c++) { + //#pragma HLS UNROLL + if (pixel_idx[p][f]) + data_window[f * CONFIG_T::n_chan + c].write(in_elem[p * CONFIG_T::n_chan + c]); + } + } + if (pixel_idx[p][CONFIG_T::kernel_size - 1]) { + depthwise_mult_buffer(data_window, res_pack, res, outputs_ready, weights, biases); + } + } +} + +template +void pointwise_mult_buffer(const data_T &data_pack, ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + //#pragma HLS INLINE + + typename data_T::value_type data[CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=data complete + + typename res_T::value_type res[CONFIG_T::n_filt]; + //#pragma HLS ARRAY_PARTITION variable=res complete + + res_T res_pack; + // PRAGMA_DATA_PACK(res_pack) + +InitData: + for (int id = 0; id < CONFIG_T::n_chan; id++) { + //#pragma HLS UNROLL + data[id] = data_pack[id]; + } + + //#pragma HLS INLINE recursive + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + data, res, weights, biases); + } else { + dense_resource( + data, res, weights, biases); + } + +CastLoop: + for (unsigned jj = 0; jj < CONFIG_T::n_filt; jj++) { + //#pragma HLS UNROLL + res_pack[jj] = res[jj]; + } + + res_stream.write(res_pack); +} + +// Line Buffer Implementation (Phil's) +template +void compute_depthwise_output_buffer_1d(const data_T &in_elem, ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + //#pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::filt_width - 1; + + // Counters + static int pX = 0; + static int sX = 0; + + static typename data_T::value_type kernel_data[CONFIG_T::filt_width * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + // PRAGMA_DATA_PACK(res_pack) + + // Add pixel to buffer + nnet::kernel_shift_1d(in_elem, kernel_data); + + // Check to see if we have a full kernel + if ((sX - lShiftX) == 0 && pX > lShiftX - 1) { + // Dense multiply + //#pragma HLS INLINE recursive + if (CONFIG_T::strategy == nnet::latency) { + depthwise_product(kernel_data, res_out, + weights, biases); + } else { + assert("Resource strategy for DepthwiseConv1D is not supported." && false); + } + + // Pack output + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + //#pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + + // Write output to stream when output ready + res_stream.write(res_pack); + } + + // Pointer Housekeeping + if (pX + 1 == CONFIG_T::in_width) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + } else { + pX = pX + 1; + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +template +void compute_depthwise_output_buffer_2d(const data_T &in_elem, + ap_shift_reg + line_buffer[MAX(CONFIG_T::filt_height - 1, 1)][CONFIG_T::n_chan], + ac_channel &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) { + //#pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::filt_width - 1; + const static int lShiftY = CONFIG_T::filt_height - 1; + + // counters + static int pX = 0; // pixel X + static int pY = 0; // pixel Y + + static int sX = 0; // stride X + static int sY = 0; // stride Y + + static typename data_T::value_type kernel_data[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_chan]; + //#pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + // PRAGMA_DATA_PACK(res_pack) + + // Add pixel to buffer + nnet::shift_line_buffer(in_elem, line_buffer, kernel_data); + + // Check to see if we have a full kernel + if ((sX - lShiftX) == 0 && (sY - lShiftY) == 0 && pY > lShiftY - 1 && pX > lShiftX - 1) { + // Dense multiply + //#pragma HLS INLINE recursive + if (CONFIG_T::strategy == nnet::latency) { + depthwise_product(kernel_data, res_out, + weights, biases); + } else { + assert("Resource strategy for DepthwiseConv2D is not supported." && false); + } + + // Pack output + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + //#pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + + // Write output to stream when output ready + res_stream.write(res_pack); + } + + // Pointer Housekeeping + if (pX + 1 == CONFIG_T::in_width) // Includes padding, end of line (padded) + { + pX = 0; + sX = 0; + if (pY + 1 == CONFIG_T::in_height) { // Reached bottom of image + pY = 0; + sY = 0; + } else { + pY = pY + 1; + sY = ((sY - lShiftY) == 0) ? sY - CONFIG_T::stride_height + 1 : sY + 1; + } + } else { + pX = pX + 1; + sX = ((sX - lShiftX) == 0) ? sX - CONFIG_T::stride_width + 1 : sX + 1; + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h new file mode 100644 index 0000000000..c76bfba5a6 --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_stream.h @@ -0,0 +1,156 @@ + +#ifndef NNET_STREAM_H +#define NNET_STREAM_H + +#include "ac_channel.h" + +namespace nnet { + +struct broadcast_config { + static const unsigned in_height = 1; + static const unsigned in_width = 1; + static const unsigned in_chan = 3; + static const unsigned out_height = 2; + static const unsigned out_width = 2; + static const unsigned out_chan = 3; +}; + +template +void clone_stream(ac_channel &data, ac_channel &res1, ac_channel &res2) { +// CloneLoop: for (int i = 0; i < N / data_T::size; i++) { +//#pragma HLS PIPELINE +#ifndef __SYNTHESIS__ + while (data.available(1)) +#endif + { + data_T in_data = data.read(); + res_T out_data; + // res_T out_data2; + //#pragma HLS DATA_PACK variable=out_data1 + //#pragma HLS DATA_PACK variable=out_data2 + + ClonePack: + for (int j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + out_data[j] = in_data[j]; + // out_data2[j] = in_data[j]; + } + + res1.write(out_data); + res2.write(out_data); + } +} + +template void repack_stream(ac_channel &data, ac_channel &res) { + if (data_T::size == res_T::size) { + for (int i = 0; i < N / data_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + for (int j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + out_data[j] = in_data[j]; + } + + res.write(out_data); + } + } else if (data_T::size > res_T::size) { + constexpr unsigned pack_diff = data_T::size / res_T::size; + for (int i = 0; i < N / data_T::size; i++) { + if (N / data_T::size > 1) { + //#pragma HLS PIPELINE + } + + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + + for (int j = 0; j < pack_diff; j++) { + //#pragma HLS PIPELINE + + res_T out_data; + for (int k = 0; k < res_T::size; k++) { + //#pragma HLS UNROLL + out_data[k] = in_data[j * res_T::size + k]; + } + res.write(out_data); + } + } + } else { // data_T::size < res_T::size + res_T out_data; + constexpr unsigned pack_diff = res_T::size / data_T::size; + unsigned pack_cnt = 0; + for (int i = 0; i < N / data_T::size; i++) { + //#pragma HLS PIPELINE + + data_T in_data = data.read(); + for (int j = 0; j < data_T::size; j++) { + //#pragma HLS UNROLL + out_data[pack_cnt * data_T::size + j] = in_data[j]; + } + + if (pack_cnt == pack_diff - 1) { + res.write(out_data); + pack_cnt = 0; + } else { + pack_cnt++; + } + } + } +} + +template +void broadcast_stream_1x1xC(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::in_height == 1 && CONFIG_T::in_width == 1 && CONFIG_T::in_chan == CONFIG_T::out_chan); + int n_dupl = (CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::out_chan) / + (CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan); +BroadcastLoop: + for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan / data_T::size; i++) { + //#pragma HLS PIPELINE + data_T in_data = data.read(); + for (int j = 0; j < n_dupl; j++) { + //#pragma HLS PIPELINE + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + for (int k = 0; k < res_T::size; k++) { + //#pragma HLS UNROLL + out_data[k] = in_data[k]; + } + res.write(out_data); + } + } +} + +template +void broadcast_stream_HxWx1(ac_channel &data, ac_channel &res) { + assert(CONFIG_T::in_chan == 1 && CONFIG_T::in_height == CONFIG_T::out_height && + CONFIG_T::in_width == CONFIG_T::out_width); +BroadcastLoop: + for (int i = 0; i < CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::in_chan / data_T::size; i++) { + //#pragma HLS PIPELINE + data_T in_data = data.read(); + res_T out_data; + //#pragma HLS DATA_PACK variable=out_data + for (int k = 0; k < res_T::size; k++) { + //#pragma HLS UNROLL + out_data[k] = in_data[0]; + } + res.write(out_data); + } +} + +template +void broadcast_stream(ac_channel &data, ac_channel &res) { + if (CONFIG_T::in_height == 1 && CONFIG_T::in_width == 1 && CONFIG_T::in_chan == CONFIG_T::out_chan) { + broadcast_stream_1x1xC(data, res); + } else if (CONFIG_T::in_chan == 1 && CONFIG_T::in_height == CONFIG_T::out_height && + CONFIG_T::in_width == CONFIG_T::out_width) { + broadcast_stream_HxWx1(data, res); + } +} +} // namespace nnet + +#endif diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_types.h b/hls4ml/templates/catapult/nnet_utils/nnet_types.h new file mode 100644 index 0000000000..d761891fdc --- /dev/null +++ b/hls4ml/templates/catapult/nnet_utils/nnet_types.h @@ -0,0 +1,64 @@ +#ifndef NNET_TYPES_H_ +#define NNET_TYPES_H_ + +#include +#include +#include + +namespace nnet { + +// Fixed-size array +template struct array { + typedef T value_type; + static const unsigned size = N; + + T data[N]; + + T &operator[](size_t pos) { return data[pos]; } + + const T &operator[](size_t pos) const { return data[pos]; } + + array &operator=(const array &other) { + if (&other == this) + return *this; + + assert(N == other.size && "Array sizes must match."); + + for (unsigned i = 0; i < N; i++) { + //#pragma HLS UNROLL + data[i] = other[i]; + } + return *this; + } +}; + +// Generic lookup-table implementation, for use in approximations of math functions +template class lookup_table { + public: + lookup_table(T from, T to) : range_start(from), range_end(to), base_div(ac_int<16, false>(N) / T(to - from)) { + T step = (range_end - range_start) / ac_int<16, false>(N); + for (size_t i = 0; i < N; i++) { + T num = range_start + ac_int<16, false>(i) * step; + T sample = func(num); + samples[i] = sample; + } + } + + T operator()(T n) const { + int index = (n - range_start) * base_div; + if (index < 0) + index = 0; + else if (index > N - 1) + index = N - 1; + return samples[index]; + } + + private: + T samples[N]; + const T range_start, range_end; + ac_fixed<20, 16, true> base_div; +}; + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado_accelerator/build_lib.sh b/hls4ml/templates/vivado_accelerator/build_lib.sh old mode 100644 new mode 100755 diff --git a/hls4ml/writer/__init__.py b/hls4ml/writer/__init__.py index f4eed945a1..f16cccc9fa 100644 --- a/hls4ml/writer/__init__.py +++ b/hls4ml/writer/__init__.py @@ -1,3 +1,4 @@ +from hls4ml.writer.catapult_writer import CatapultWriter from hls4ml.writer.quartus_writer import QuartusWriter from hls4ml.writer.symbolic_writer import SymbolicExpressionWriter from hls4ml.writer.vitis_writer import VitisWriter @@ -9,4 +10,5 @@ register_writer('VivadoAccelerator', VivadoAcceleratorWriter) register_writer('Vitis', VitisWriter) register_writer('Quartus', QuartusWriter) +register_writer('Catapult', CatapultWriter) register_writer('SymbolicExpression', SymbolicExpressionWriter) diff --git a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py new file mode 100755 index 0000000000..48d44e4a59 --- /dev/null +++ b/hls4ml/writer/catapult_writer.py @@ -0,0 +1,929 @@ +import glob +import os +import tarfile +from collections import OrderedDict +from shutil import copyfile, copytree, rmtree + +import numpy as np +import yaml + +from hls4ml.backends import get_backend +from hls4ml.writer.writers import Writer + +config_filename = 'hls4ml_config.yml' + + +class CatapultWriter(Writer): + def print_array_to_cpp(self, var, odir, write_txt_file=True): + """Write a weights array to C++ header files. + + Args: + var (WeightVariable): Weight to write + odir (str): Output directory + write_txt_file (bool, optional): Write txt files in addition to .h files. Defaults to True. + """ + + h_file = open(f"{odir}/firmware/weights/{var.name}.h", "w") + if write_txt_file: + txt_file = open(f"{odir}/firmware/weights/{var.name}.txt", "w") + + # meta data + h_file.write(f"//Numpy array shape {var.shape}\n") + h_file.write(f"//Min {np.min(var.min):.12f}\n") + h_file.write(f"//Max {np.max(var.max):.12f}\n") + h_file.write(f"//Number of zeros {var.nzeros}\n") + h_file.write("\n") + + h_file.write(f"#ifndef {var.name.upper()}_H_\n") + h_file.write(f"#define {var.name.upper()}_H_\n") + h_file.write("\n") + + if write_txt_file: + h_file.write("#ifndef __SYNTHESIS__\n") + h_file.write("// global extern pointer only - actual array allocated in myproject_test.cpp\n") + h_file.write("extern " + var.definition_cpp() + ";\n") + h_file.write("#else\n") + + h_file.write(var.definition_cpp() + " = {") + + # fill c++ array. + # not including internal brackets for multidimensional case + sep = '' + for x in var: + h_file.write(sep + x) + if write_txt_file: + txt_file.write(sep + x) + sep = ", " + h_file.write("};\n") + if write_txt_file: + h_file.write("#endif\n") + txt_file.close() + h_file.write("\n#endif\n") + h_file.close() + + def write_output_dir(self, model): + """Write the base output directory + + Args: + model (ModelGraph): the hls4ml model. + """ + if not os.path.isdir(f"{model.config.get_output_dir()}/firmware/weights"): + os.makedirs(f"{model.config.get_output_dir()}/firmware/weights") + + @staticmethod + def _make_array_pragma(variable, model): + """ + Layers in hls_model.py can specify output array partitioning through the `pragma` attribute. + If `pragma` is a string: options are 'partition', 'reshape', or 'stream'. + If `pragma` is a tuple: (mode, type, factor) where mode is 'partition' or 'reshape', type is + 'complete', 'cyclic', or 'block', and factor is an integer only used when the type is not 'complete'. + """ + + config = variable.pragma + if type(config) is tuple: + mode = config[0] + if mode in ['partition', 'reshape']: + typ = config[1] + if typ != 'complete': + factor = config[2] + elif mode == 'stream': + depth = config[1] + else: + mode = config + typ = 'complete' + factor = 0 + + if mode in ['partition', 'reshape']: + if typ == 'complete': + template = '// #pragma HLS ARRAY_{mode} variable={name} {type} dim={dim}' + else: + template = '// #pragma HLS ARRAY_{mode} variable={name} {type} factor={factor} dim={dim}' + + return template.format(mode=mode.upper(), name=variable.name, type=typ, factor=factor, dim=0) + + elif mode == 'stream': + fifo = model.config.get_config_value("FIFO") + if fifo is not None: + retstr = f'#pragma hls_resource {variable.name}:cns variables="{variable.name}"' + retstr += f' map_to_module="{fifo}" // depth="{depth}"' + return retstr + else: + return '' + else: + return '' + + @staticmethod + def _make_array_fifo_pragma(variable, model): + config = variable.pragma + factor = '' + if type(config) is tuple: + mode = config[0] + if mode in ['partition', 'reshape']: + typ = config[1] + if typ != 'complete': + factor = config[2] + elif mode == 'stream': + depth = config[1] + else: + mode = config + typ = 'complete' + factor = 0 + + if mode == 'stream': + fifo = model.config.get_config_value("FIFO") + if fifo is not None: + return f'// #pragma hls_fifo_depth {depth} {factor}' + else: + return '' + else: + return '' + + def write_project_cpp(self, model): + """Write the main architecture source file (myproject.cpp) + + Args: + model (ModelGraph): the hls4ml model. + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + + fout = open(f'{model.config.get_output_dir()}/firmware/layer_summary.txt', 'w') + outstr = "" + outstr = outstr + "{}".format("Layer Name").ljust(25) + outstr = outstr + " {}".format("Layer Class").ljust(20) + outstr = outstr + " {}".format("Input Type").ljust(40) + outstr = outstr + " {}".format("Input Shape").ljust(15) + outstr = outstr + " {}".format("Output Type").ljust(40) + outstr = outstr + " {}".format("Output Shape").ljust(15) + # outstr = outstr + " {}".format("Weight Type").ljust(24) + # outstr = outstr + " {}".format("Bias Type").ljust(24) + outstr = outstr + " {}".format("Filter Shape").ljust(15) + outstr = outstr + " {}".format("Stride").ljust(10) + outstr = outstr + " {}".format("IOType").ljust(15) + outstr = outstr + " {}".format("Reuse").ljust(10) + + fout.write(outstr + "\n") + input_shape = "" + input_datatype = "" + for layer in model.get_layers(): + datatype = layer.get_output_variable().type.precision.definition_cpp() + " " + shape = "" + # layer.get_output_variable().type.precision.width + # layer.get_output_variable().type.precision.integer + # layer.get_output_variable().type.precision.sign + for _k, v in layer.get_output_variable().get_shape(): + shape = shape + "[" + str(v) + "]" + + if layer.attributes.layer.class_name != 'Input': + my_class_name = layer.class_name + if layer.attributes.layer.class_name == 'Activation': + my_class_name = layer.get_attr('activation') + + # filter_datatype = "" + # print(layer.weights.__dir__()) + # layer_precision = layer.get_layer_precision() + # for wname, weights in layer.weights.items(): + # print(wname) + # print(weights.type.name) + # print(weights.type.precision.definition_cpp()) + # #print(weights.type.precision.__dir__()) + # print(weights.type.precision.width) + # if 'ACFixed' in weights.type.precision.__class__: + # print(weights.type.precision.integer) + # print(weights.type.precision.signed) + # print(weights.data_length) + + filter = "" + filt_width = layer.get_attr('filt_width') + filt_height = layer.get_attr('filt_height') + if filt_width is not None: + filter = "[" + str(filt_width) + "]" + if filt_height is not None: + filter = filter + "[" + str(filt_height) + "]" + + stride = "" + stride_width = layer.get_attr('stride_width') + if stride_width is not None: + stride = str(stride_width) + + outstr = "" + outstr = outstr + f"{layer.name}".ljust(25) + outstr = outstr + f" {my_class_name}".ljust(20) + outstr = outstr + f" {input_datatype}".ljust(40) + outstr = outstr + f" {input_shape}".ljust(15) + outstr = outstr + f" {datatype}".ljust(40) + outstr = outstr + f" {shape}".ljust(15) + # outstr = outstr + " {}".format("weight type").ljust(24) + # outstr = outstr + " {}".format("bias type").ljust(24) + outstr = outstr + f" {filter}".ljust(15) + outstr = outstr + f" {stride}".ljust(10) + outstr = outstr + " {}".format(layer.model.config.get_config_value('IOType')).ljust(15) + outstr = outstr + f" {str(layer.model.config.get_reuse_factor(layer))}".ljust(10) + fout.write(outstr + "\n") + + input_shape = shape + input_datatype = datatype + + fout.close() + + f = open(os.path.join(filedir, '../templates/catapult/firmware/myproject.cpp')) + fout = open(f'{model.config.get_output_dir()}/firmware/{model.config.get_project_name()}.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + indent = ' ' + + for line in f.readlines(): + # Add headers to weights and biases + if 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + elif '// hls-fpga-machine-learning insert header' in line: + inputs_str = ', '.join([i.definition_cpp(as_reference=True) for i in model_inputs]) + outputs_str = ', '.join([o.definition_cpp(as_reference=True) for o in model_outputs]) + brams_str = ', \n'.join([indent + b.definition_cpp(as_reference=False) for b in model_brams]) + + newline = '' + newline += indent + inputs_str + ',\n' + newline += indent + outputs_str + if len(model_brams) > 0: + newline += ',\n' + brams_str + newline += '\n' + + elif '// hls-fpga-machine-learning insert load weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + if w.weight_class == 'CompressedWeightVariable': + newline += indent + ' nnet::load_compressed_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.nonzeros, w.name, w.name + ) + elif w.weight_class == 'ExponentWeightVariable': + newline += indent + ' nnet::load_exponent_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + else: + newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + + # Add Interface Synthesis resource pragmas + elif '// hls-fpga-machine-learning insert IFSynPragmas' in line: + newline = line + all_inputs = [i.name for i in model_inputs] + all_outputs = [o.name for o in model_outputs] + all_brams = [b.name for b in model_brams] + io_type = model.config.get_config_value("IOType") + + if io_type == 'io_serial' or io_type == 'io_stream': + # Eventually this will be amba.ccs_axi4stream_in and amba.ccs_axi4stream_out + for dut_input in all_inputs: + newline += f'#pragma hls_resource {dut_input}:rsc variables="{dut_input}"' + newline += ' map_to_module="ccs_ioport.ccs_in_wait"\n' + for dut_output in all_outputs: + newline += f'#pragma hls_resource {dut_output}:rsc variables="{dut_output}"' + newline += ' map_to_module="ccs_ioport.ccs_out_wait"\n' + + # Add input/output type + elif '// hls-fpga-machine-learning insert IO' in line: + newline = line + all_inputs = [i.name for i in model_inputs] + all_outputs = [o.name for o in model_outputs] + all_brams = [b.name for b in model_brams] + io_type = model.config.get_config_value("IOType") + + if io_type == 'io_parallel': + for i in model_inputs: + newline += indent + self._make_array_pragma(i, model) + '\n' + for o in model_outputs: + newline += indent + self._make_array_pragma(o, model) + '\n' + # TODO discussed adding a handle for setting the interface mode for individual input and output arrays + # Probably the handle doesn't need to be exposed to the user but should be just set in hls_model.py + newline += indent + '// #pragma HLS INTERFACE ap_vld port={},{} \n'.format( + ','.join(all_inputs), ','.join(all_outputs) + ) + if model.config.model_strategy.lower() == 'dataflow': + newline += indent + '// #pragma HLS DATAFLOW \n' + else: + newline += indent + '// #pragma HLS PIPELINE \n' + if io_type == 'io_stream': + newline += indent + '// #pragma HLS INTERFACE axis port={},{} \n'.format( + ','.join(all_inputs), ','.join(all_outputs) + ) + if all_brams: + newline += indent + '// #pragma HLS INTERFACE bram port={} \n'.format(','.join(all_brams)) + newline += indent + '// #pragma HLS DATAFLOW \n' + + elif '// hls-fpga-machine-learning insert layers' in line: + io_type = model.config.get_config_value("IOType") + newline = line + '\n' + for layer in model.get_layers(): + vars = layer.get_variables() + for var in vars: + if var not in model_inputs and var not in model_outputs: + def_cpp = var.definition_cpp() + if def_cpp is not None: + if var.pragma: + newline += ' ' + self._make_array_fifo_pragma(var, model) + '\n' + if io_type == 'io_serial' or io_type == 'io_stream': + newline += ' static ' + def_cpp + '; \n' + else: + newline += ' ' + def_cpp + '; \n' + if var.pragma: + newline += ' ' + self._make_array_pragma(var, model) + '\n' + func = layer.get_attr('function_cpp', None) + if func: + if not isinstance(func, (list, set)): + func = [func] + if len(func) == 1: + newline += ' ' + func[0] + ' // ' + layer.name + '\n' + else: + newline += ' // ' + layer.name + '\n' + for line in func: + newline += ' ' + line + '\n' + if model.config.trace_output and layer.get_attr('trace', False): + newline += '#ifndef __SYNTHESIS__\n' + for var in vars: + newline += ' nnet::save_layer_output<{}>({}, "{}", {});\n'.format( + var.type.name, var.name, layer.name, var.size_cpp() + ) + newline += '#endif\n' + newline += '\n' + + # Just copy line + else: + newline = line + + fout.write(newline) + + f.close() + fout.close() + + def write_project_header(self, model): + """Write the main architecture header file (myproject.h) + + Args: + model (ModelGraph): the hls4ml model. + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + f = open(os.path.join(filedir, '../templates/catapult/firmware/myproject.h')) + fout = open(f'{model.config.get_output_dir()}/firmware/{model.config.get_project_name()}.h', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + indent = ' ' + + for line in f.readlines(): + if 'MYPROJECT' in line: + newline = line.replace('MYPROJECT', format(model.config.get_project_name().upper())) + elif 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + elif '// hls-fpga-machine-learning insert header' in line: + inputs_str = ', '.join([i.definition_cpp(as_reference=True) for i in model_inputs]) + outputs_str = ', '.join([o.definition_cpp(as_reference=True) for o in model_outputs]) + brams_str = ', \n'.join([indent + b.definition_cpp(as_reference=False) for b in model_brams]) + + newline = '' + newline += indent + inputs_str + ',\n' + newline += indent + outputs_str + if len(model_brams) > 0: + newline += ',\n' + brams_str + newline += '\n' + else: + newline = line + fout.write(newline) + + f.close() + fout.close() + + def write_defines(self, model): + """Write the C++ type definitions file (defines.h) + + Args: + model (ModelGraph): the hls4ml model. + """ + filedir = os.path.dirname(os.path.abspath(__file__)) + f = open(os.path.join(filedir, '../templates/catapult/firmware/defines.h')) + fout = open(f'{model.config.get_output_dir()}/firmware/defines.h', 'w') + + for line in f.readlines(): + # Insert numbers + if '// hls-fpga-machine-learning insert numbers' in line: + newline = line + + defines_list = [] + for layer in model.get_layers(): + defines = '' + for k, v in layer.get_output_variable().get_shape(): + defines += f'#define {k} {v}\n' + + defines_list.append(defines) + + newline += ''.join(defines_list) + + elif '// hls-fpga-machine-learning insert layer-precision' in line: + newline = line + all_precision = OrderedDict() + for layer in model.get_layers(): + layer_precision = layer.get_layer_precision() + for type_name, type_var in layer_precision.items(): + # Ensure that layer's types doesn't override existing types + # This can happen in case of InplaceVariable types + if type_name not in all_precision: + all_precision[type_name] = type_var + for used_type in all_precision.values(): + newline += used_type.definition_cpp() + + else: + newline = line + fout.write(newline) + f.close() + fout.close() + + def write_parameters(self, model): + """Write the C++ layer config file (parameters.h) + + Args: + model (ModelGraph): the hls4ml model. + """ + filedir = os.path.dirname(os.path.abspath(__file__)) + f = open(os.path.join(filedir, '../templates/catapult/firmware/parameters.h')) + fout = open(f'{model.config.get_output_dir()}/firmware/parameters.h', 'w') + + for line in f.readlines(): + if '// hls-fpga-machine-learning insert includes' in line: + newline = line + for include in sorted(set(sum((layer.get_attr('include_header', []) for layer in model.get_layers()), []))): + newline += '#include "%s"\n' % include + + elif '// hls-fpga-machine-learning insert weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + if w.storage.lower() != 'bram': + newline += f'#include "weights/{w.name}.h"\n' + + elif "// hls-fpga-machine-learning insert layer-config" in line: + newline = line + for layer in model.get_layers(): + config = layer.get_attr('config_cpp', None) + if config: + newline += '// ' + layer.name + '\n' + newline += config + '\n' + else: + newline = line + fout.write(newline) + f.close() + fout.close() + + def write_weights(self, model): + """Write the weights into header files + + Args: + model (ModelGraph): the hls4ml model. + """ + for layer in model.get_layers(): + for weights in layer.get_weights(): + self.print_array_to_cpp(weights, model.config.get_output_dir()) + + def __make_dat_file(self, original_path, project_path): + """ + Convert other input/output data types into a dat file, which is + a text file with the falttened matrix printed out. Note that ' ' is + assumed to be the delimiter. + """ + + # Take in data from current supported data files + if original_path[-3:] == "npy": + data = np.load(original_path) + else: + raise Exception("Unsupported input/output data files.") + + # Faltten data, just keep first dimension + data = data.reshape(data.shape[0], -1) + + def print_data(f): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + f.write(str(data[i][j]) + " ") + f.write("\n") + + # Print out in dat file + with open(project_path, "w") as f: + print_data(f) + + def write_test_bench(self, model): + """Write the testbench files (myproject_test.cpp and input/output .dat files) + + Args: + model (ModelGraph): the hls4ml model. + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + + if not os.path.exists(f'{model.config.get_output_dir()}/tb_data/'): + os.mkdir(f'{model.config.get_output_dir()}/tb_data/') + + input_data = model.config.get_config_value('InputData') + output_predictions = model.config.get_config_value('OutputPredictions') + + if input_data: + if input_data[-3:] == "dat": + copyfile(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + else: + self.__make_dat_file(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + + if output_predictions: + if output_predictions[-3:] == "dat": + copyfile(output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat') + else: + self.__make_dat_file( + output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat' + ) + + f = open(os.path.join(filedir, '../templates/catapult/myproject_test.cpp')) + fout = open(f'{model.config.get_output_dir()}/{model.config.get_project_name()}_test.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + for line in f.readlines(): + indent = ' ' * (len(line) - len(line.lstrip(' '))) + + # Insert numbers + if 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + elif '// hls-fpga-machine-learning insert bram' in line: + newline = line + for bram in model_brams: + newline += f'#include \"firmware/weights/{bram.name}.h\"\n' + + elif '// hls-fpga-machine-learning insert declare weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + newline += w.definition_cpp() + ";\n" + + elif '// hls-fpga-machine-learning insert load weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + if w.weight_class == 'CompressedWeightVariable': + newline += indent + ' nnet::load_compressed_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.nonzeros, w.name, w.name + ) + elif w.weight_class == 'ExponentWeightVariable': + newline += indent + ' nnet::load_exponent_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + else: + newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + + elif '// hls-fpga-machine-learning insert data' in line: + newline = line + offset = 0 + for inp in model_inputs: + newline += ' ' + inp.definition_cpp() + ';\n' + newline += ' nnet::copy_data(in, {});\n'.format( + inp.type.name, offset, inp.size_cpp(), inp.name + ) + offset += inp.size() + for out in model_outputs: + newline += ' ' + out.definition_cpp() + ';\n' + elif '// hls-fpga-machine-learning insert random' in line: + newline = line + for inp in model_inputs: + newline += ' ' + inp.definition_cpp() + ';\n' + newline += f' nnet::fill_random<{inp.type.name}, {inp.size_cpp()}>({inp.name});\n' + for out in model_outputs: + newline += ' ' + out.definition_cpp() + ';\n' + elif '// hls-fpga-machine-learning insert zero' in line: + newline = line + for inp in model_inputs: + newline += ' ' + inp.definition_cpp() + ';\n' + newline += f' nnet::fill_zero<{inp.type.name}, {inp.size_cpp()}>({inp.name});\n' + for out in model_outputs: + newline += ' ' + out.definition_cpp() + ';\n' + elif '// hls-fpga-machine-learning insert top-level-function' in line: + newline = line + + input_vars = ','.join([i.name for i in model_inputs]) + output_vars = ','.join([o.name for o in model_outputs]) + bram_vars = ','.join([b.name for b in model_brams]) + + # Concatenate the input, output, and bram variables. Filter out empty/null values + all_vars = ','.join(filter(None, [input_vars, output_vars, bram_vars])) + + top_level = indent + f'{model.config.get_project_name()}({all_vars});\n' + + newline += top_level + elif '// hls-fpga-machine-learning insert predictions' in line: + newline = line + for out in model_outputs: + newline += indent + f'for(int i = 0; i < {out.size_cpp()}; i++) {{\n' + newline += indent + ' std::cout << pr[i] << " ";\n' + newline += indent + '}\n' + newline += indent + 'std::cout << std::endl;\n' + elif '// hls-fpga-machine-learning insert tb-output' in line: + newline = line + for out in model_outputs: + newline += indent + 'nnet::print_result<{}, {}>({}, fout);\n'.format( + out.type.name, out.size_cpp(), out.name + ) # TODO enable this + elif ( + '// hls-fpga-machine-learning insert output' in line + or '// hls-fpga-machine-learning insert quantized' in line + ): + newline = line + for out in model_outputs: + newline += indent + 'nnet::print_result<{}, {}>({}, std::cout, true);\n'.format( + out.type.name, out.size_cpp(), out.name + ) + else: + newline = line + fout.write(newline) + f.close() + fout.close() + + def write_bridge(self, model): + """Write the Python-C++ bridge (myproject_bridge.cpp) + + Args: + model (ModelGraph): the hls4ml model. + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + f = open(os.path.join(filedir, '../templates/catapult/myproject_bridge.cpp')) + fout = open(f'{model.config.get_output_dir()}/{model.config.get_project_name()}_bridge.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + indent = ' ' + + for line in f.readlines(): + if 'MYPROJECT' in line: + newline = line.replace('MYPROJECT', format(model.config.get_project_name().upper())) + elif 'myproject' in line: + newline = line.replace('myproject', format(model.config.get_project_name())) + elif '// hls-fpga-machine-learning insert bram' in line: + newline = line + for bram in model_brams: + newline += f'#include \"firmware/weights/{bram.name}.h\"\n' + elif '// hls-fpga-machine-learning insert declare weights' in line: + newline = line + for layer in model.get_layers(): + for w in layer.get_weights(): + newline += w.definition_cpp() + ";\n" + elif '// hls-fpga-machine-learning insert header' in line: + dtype = line.split('#', 1)[1].strip() + inputs_str = ', '.join([f'{dtype} {i.name}[{i.size_cpp()}]' for i in model_inputs]) + outputs_str = ', '.join([f'{dtype} {o.name}[{o.size_cpp()}]' for o in model_outputs]) + + newline = '' + newline += indent + inputs_str + ',\n' + newline += indent + outputs_str + '\n' + elif '// hls-fpga-machine-learning insert wrapper' in line: + dtype = line.split('#', 1)[1].strip() + newline = '' + for i in model_inputs: + newline += indent + '{var};\n'.format(var=i.definition_cpp(name_suffix='_ap')) + newline += indent + 'nnet::convert_data<{}, {}, {}>({}, {}_ap);\n'.format( + dtype, i.type.name, i.size_cpp(), i.name, i.name + ) + newline += '\n' + + for o in model_outputs: + newline += indent + '{var};\n'.format(var=o.definition_cpp(name_suffix='_ap')) + + newline += '\n' + + input_vars = ','.join([i.name + '_ap' for i in model_inputs]) + bram_vars = ','.join([b.name for b in model_brams]) + output_vars = ','.join([o.name + '_ap' for o in model_outputs]) + + # Concatenate the input, output, and bram variables. Filter out empty/null values + all_vars = ','.join(filter(None, [input_vars, output_vars, bram_vars])) + + top_level = indent + f'{model.config.get_project_name()}({all_vars});\n' + newline += top_level + + newline += '\n' + + for o in model_outputs: + newline += indent + 'nnet::convert_data<{}, {}, {}>({}_ap, {});\n'.format( + o.type.name, dtype, o.size_cpp(), o.name, o.name + ) + elif '// hls-fpga-machine-learning insert trace_outputs' in line: + newline = '' + for layer in model.get_layers(): + func = layer.get_attr('function_cpp', None) + if func and model.config.trace_output and layer.get_attr('trace', False): + vars = layer.get_variables() + for var in vars: + newline += ( + indent + + 'nnet::trace_outputs->insert(std::pair(' + + f'"{layer.name}", (void *) malloc({var.size_cpp()} * element_size)));\n' + ) + + else: + newline = line + fout.write(newline) + + f.close() + fout.close() + + def write_build_script(self, model): + """Write the TCL/Shell build scripts (build_prj.tcl, build_lib.sh) + + Args: + model (ModelGraph): the hls4ml model. + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + + # build_prj.tcl + srcpath = os.path.join(filedir, '../templates/catapult/build_prj.tcl') + dstpath = f'{model.config.get_output_dir()}/build_prj.tcl' + # copyfile(srcpath, dstpath) + f = open(srcpath) + fout = open(dstpath, 'w') + for line in f.readlines(): + indent = line[: len(line) - len(line.lstrip())] + line = line.replace('myproject', model.config.get_project_name()) + line = line.replace('CATAPULT_DIR', model.config.get_project_dir()) + if '#hls-fpga-machine-learning insert techlibs' in line: + if model.config.get_config_value('Technology') is None: + if model.config.get_config_value('Part') is not None: + line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part')) + elif model.config.get_config_value('ASICLibs') is not None: + line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs')) + else: + if model.config.get_config_value('Technology') == 'asic': + line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs')) + else: + line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part')) + elif '#hls-fpga-machine-learning insert invoke_args' in line: + tb_in_file = model.config.get_config_value('InputData') + tb_out_file = model.config.get_config_value('OutputPredictions') + invoke_args = '$sfd/firmware/weights' + if tb_in_file is not None: + invoke_args = invoke_args + f' $sfd/tb_data/{tb_in_file}' + if tb_out_file is not None: + invoke_args = invoke_args + f' $sfd/tb_data/{tb_out_file}' + line = indent + f'flow package option set /SCVerify/INVOKE_ARGS "{invoke_args}"\n' + elif 'set hls_clock_period 5' in line: + line = indent + 'set hls_clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')) + fout.write(line) + f.close() + fout.close() + + # build_lib.sh + f = open(os.path.join(filedir, '../templates/catapult/build_lib.sh')) + fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w') + + for line in f.readlines(): + line = line.replace('myproject', model.config.get_project_name()) + line = line.replace('mystamp', model.config.get_config_value('Stamp')) + + fout.write(line) + f.close() + fout.close() + + def write_nnet_utils(self, model): + """Copy the nnet_utils, AP types headers and any custom source to the project output directory + + Args: + model (ModelGraph): the hls4ml model. + """ + + # nnet_utils + filedir = os.path.dirname(os.path.abspath(__file__)) + + srcpath = os.path.join(filedir, '../templates/catapult/nnet_utils/') + dstpath = f'{model.config.get_output_dir()}/firmware/nnet_utils/' + + if not os.path.exists(dstpath): + os.mkdir(dstpath) + + headers = [os.path.basename(h) for h in glob.glob(srcpath + '*.h')] + + if model.config.get_config_value('DontCopyNNET') is not None: + h = 'nnet_code_gen.h' + copyfile(srcpath + h, dstpath + h) + return + + for h in headers: + copyfile(srcpath + h, dstpath + h) + + print("Copying NNET files to local firmware directory") + + filedir = os.path.dirname(os.path.abspath(__file__)) + for pkg in ('ac_types', 'ac_math', 'ac_simutils'): + dstpath = f'{model.config.get_output_dir()}/firmware/{pkg}/' + + # backward compatibility, look in root dir + srcpath = os.path.join(filedir, '../../' + pkg + '/') + if not os.path.exists(srcpath): + # look next in Catapult-specific templates + srcpath = os.path.join(filedir, '../templates/catapult/' + pkg + '/') + + if os.path.exists(srcpath): + if os.path.exists(dstpath): + rmtree(dstpath) + print("... copying AC " + pkg + " headers from " + srcpath) + copytree(srcpath, dstpath) + else: + print("... skipping copy of " + pkg + " headers - assumed to located in Catapult install tree") + + # custom source + filedir = os.path.dirname(os.path.abspath(__file__)) + + custom_source = get_backend('Catapult').get_custom_source() + for dst, srcpath in custom_source.items(): + dstpath = f'{model.config.get_output_dir()}/firmware/{dst}' + copyfile(srcpath, dstpath) + + def write_generated_code(self, model): + """Write the generated code (nnet_code_gen.h) + + Args: + model (ModelGraph): the hls4ml model. + """ + path = f'{model.config.get_output_dir()}/firmware/nnet_utils/nnet_code_gen.h' + f = open(path) + contents = f.readlines() + f.close() + f = open(path, 'w') + + for line in contents: + if '// hls4ml insert code' in line: + newline = line + for layer in model.get_layers(): + for generated_code in layer.code.values(): + newline += str(generated_code) + else: + newline = line + f.write(newline) + f.close() + + def write_yml(self, model): + """Write the config to the YAML file + + Args: + model (ModelGraph): the hls4ml model. + """ + + def keras_model_representer(dumper, keras_model): + model_path = model.config.get_output_dir() + '/keras_model.h5' + keras_model.save(model_path) + return dumper.represent_scalar('!keras_model', model_path) + + try: + from tensorflow.keras import Model as KerasModel + + yaml.add_multi_representer(KerasModel, keras_model_representer) + except Exception: + pass + + with open(model.config.get_output_dir() + '/' + config_filename, 'w') as file: + yaml.dump(model.config.config, file) + + def write_tar(self, model): + """Write the generated project as a .tar.gz archive + + Args: + model (ModelGraph): the hls4ml model. + """ + + if not os.path.exists(model.config.get_output_dir() + '.tar.gz'): + with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive: + archive.add(model.config.get_output_dir(), recursive=True) + else: + print("Project .tar.gz archive already exists") + + def write_hls(self, model): + print('Writing HLS project') + self.write_output_dir(model) + self.write_project_cpp(model) + self.write_project_header(model) + self.write_weights(model) + self.write_defines(model) + self.write_parameters(model) + self.write_test_bench(model) + self.write_bridge(model) + self.write_build_script(model) + self.write_nnet_utils(model) + self.write_generated_code(model) + self.write_yml(model) + self.write_tar(model) + print('Done') diff --git a/test/pytest/ci-template.yml b/test/pytest/ci-template.yml index 5477da933a..50e9f799f6 100644 --- a/test/pytest/ci-template.yml +++ b/test/pytest/ci-template.yml @@ -5,7 +5,8 @@ - k8s-default before_script: - source ~/.bashrc - - if [ $EXAMPLEMODEL == 1 ]; then git submodule init; git submodule update; fi + - git submodule update --init --recursive hls4ml/templates/catapult/ + - if [ $EXAMPLEMODEL == 1 ]; then git submodule update --init example-models; fi - conda activate hls4ml-testing - pip install .[testing,sr,optimization] script: diff --git a/test/pytest/test_activations.py b/test/pytest/test_activations.py index caaaed636a..5ab9481e1a 100644 --- a/test/pytest/test_activations.py +++ b/test/pytest/test_activations.py @@ -12,7 +12,7 @@ # Variable 'name' is simply used as an identifier for the activation -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult', 'Quartus']) @pytest.mark.parametrize('shape, io_type', [((8,), 'io_parallel'), ((8,), 'io_stream'), ((8, 8, 3), 'io_stream')]) @pytest.mark.parametrize( 'activation, name', diff --git a/test/pytest/test_batchnorm.py b/test/pytest/test_batchnorm.py index c0ef0705ae..727d2ee574 100644 --- a/test/pytest/test_batchnorm.py +++ b/test/pytest/test_batchnorm.py @@ -29,7 +29,7 @@ def model(request): @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('model', [True, False], indirect=True) def test_batchnorm(model, data, backend, io_type): default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>' diff --git a/test/pytest/test_batchnorm_pytorch.py b/test/pytest/test_batchnorm_pytorch.py index a7a0c80247..93cda2729c 100644 --- a/test/pytest/test_batchnorm_pytorch.py +++ b/test/pytest/test_batchnorm_pytorch.py @@ -21,7 +21,7 @@ def data(): @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) def test_batchnorm(data, backend, io_type): model = nn.Sequential( nn.BatchNorm1d(in_shape), diff --git a/test/pytest/test_clone_flatten.py b/test/pytest/test_clone_flatten.py index 12f30985bf..5f631d027f 100644 --- a/test/pytest/test_clone_flatten.py +++ b/test/pytest/test_clone_flatten.py @@ -28,7 +28,7 @@ def keras_model(): @pytest.fixture @pytest.mark.parametrize('io_type', ['io_stream']) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult']) def hls_model(keras_model, backend, io_type): hls_config = hls4ml.utils.config_from_keras_model( keras_model, diff --git a/test/pytest/test_cnn_mnist.py b/test/pytest/test_cnn_mnist.py index ab3365f228..27b966f51d 100644 --- a/test/pytest/test_cnn_mnist.py +++ b/test/pytest/test_cnn_mnist.py @@ -61,7 +61,7 @@ def keras_model(mnist_data): ('Vitis', 'io_parallel', 'resource'), ('Vitis', 'io_parallel', 'latency'), ('Vitis', 'io_stream', 'latency'), - ('Vitis', 'io_stream', 'resource'), + ('Vitis', 'io_stream', 'latency'), ], ) def test_mnist_cnn(keras_model, mnist_data, backend, io_type, strategy): diff --git a/test/pytest/test_conv1d.py b/test/pytest/test_conv1d.py index 79beb01a2c..48357a42a1 100644 --- a/test/pytest/test_conv1d.py +++ b/test/pytest/test_conv1d.py @@ -41,6 +41,8 @@ def keras_model(): ('Vitis', 'io_parallel', 'latency'), ('Vitis', 'io_stream', 'latency'), ('Vitis', 'io_stream', 'resource'), + ('Catapult', 'io_stream', 'latency'), + ('Catapult', 'io_stream', 'resource'), ], ) def hls_model(keras_model, backend, io_type, strategy): @@ -91,6 +93,8 @@ def hls_model(keras_model, backend, io_type, strategy): ('Vitis', 'io_parallel', 'latency'), ('Vitis', 'io_stream', 'latency'), ('Vitis', 'io_stream', 'resource'), + ('Catapult', 'io_stream', 'latency'), + ('Catapult', 'io_stream', 'resource'), ], ) def test_accuracy(data, keras_model, hls_model): diff --git a/test/pytest/test_embed.py b/test/pytest/test_embed.py index fd8e39cdb9..a27fc45b93 100644 --- a/test/pytest/test_embed.py +++ b/test/pytest/test_embed.py @@ -25,7 +25,7 @@ def keras_model(): @pytest.fixture -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def hls_model(keras_model, backend, io_type): hls_config = hls4ml.utils.config_from_keras_model(keras_model, default_precision='ap_fixed<16,6>', granularity='name') @@ -39,7 +39,7 @@ def hls_model(keras_model, backend, io_type): return hls_model -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_embedding_accuracy(data, keras_model, hls_model): X = data diff --git a/test/pytest/test_globalpooling.py b/test/pytest/test_globalpooling.py index c402a53cdf..b99f0d8212 100644 --- a/test/pytest/test_globalpooling.py +++ b/test/pytest/test_globalpooling.py @@ -32,7 +32,7 @@ def keras_model_1d(request): return model, model_type, keepdims -@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado']) +@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado', 'Catapult']) @pytest.mark.parametrize( 'keras_model_1d', [ @@ -87,7 +87,7 @@ def keras_model_2d(request): return model, model_type, keepdims -@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado']) +@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado', 'Catapult']) @pytest.mark.parametrize( 'keras_model_2d', [ diff --git a/test/pytest/test_keras_h5_loader.py b/test/pytest/test_keras_h5_loader.py index b53bb3a668..0c42adee31 100644 --- a/test/pytest/test_keras_h5_loader.py +++ b/test/pytest/test_keras_h5_loader.py @@ -9,7 +9,7 @@ test_root_path = Path(__file__).parent -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) def test_keras_h5_loader(backend): input_shape = (10,) model = tf.keras.models.Sequential( diff --git a/test/pytest/test_keras_nested_model.py b/test/pytest/test_keras_nested_model.py index 8c4670ad51..66fa81e2f9 100755 --- a/test/pytest/test_keras_nested_model.py +++ b/test/pytest/test_keras_nested_model.py @@ -127,7 +127,7 @@ def randX_20_15(): return randX(20, 15) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_nested_model(randX_20_15, backend, io_type): n_in = 15 @@ -150,7 +150,7 @@ def test_nested_model(randX_20_15, backend, io_type): np.testing.assert_allclose(y_keras.ravel(), y_hls4ml.ravel(), rtol=1e-2, atol=0.02) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_sub_nested_model(randX_20_15, backend, io_type): n_in = 15 diff --git a/test/pytest/test_pointwiseconv.py b/test/pytest/test_pointwiseconv.py index b7fee0a4ab..060b9877de 100644 --- a/test/pytest/test_pointwiseconv.py +++ b/test/pytest/test_pointwiseconv.py @@ -31,6 +31,8 @@ ('Vivado', 'io_stream', 'resource'), ('Vitis', 'io_stream', 'latency'), ('Vitis', 'io_stream', 'resource'), + ('Catapult', 'io_stream', 'latency'), + ('Catapult', 'io_stream', 'resource'), ], ) def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy): @@ -87,6 +89,8 @@ def test_pointwiseconv1d(chans, padds, strides, backend, io_type, strategy): ('Vivado', 'io_parallel', 'latency'), ('Vivado', 'io_stream', 'latency'), ('Vivado', 'io_stream', 'resource'), + ('Catapult', 'io_stream', 'latency'), + ('Catapult', 'io_stream', 'resource'), ], ) def test_pointwiseconv2d(chans, padds, strides, backend, io_type, strategy): diff --git a/test/pytest/test_pooling.py b/test/pytest/test_pooling.py index 1f958696d8..d7de80a5a7 100644 --- a/test/pytest/test_pooling.py +++ b/test/pytest/test_pooling.py @@ -32,7 +32,7 @@ def keras_model_1d(request): return model, model_type, pads -@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado']) +@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado', 'Catapult']) @pytest.mark.parametrize( 'keras_model_1d', [ @@ -87,7 +87,7 @@ def keras_model_2d(request): return model, model_type, pads -@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado']) +@pytest.mark.parametrize('backend', ['Quartus', 'Vitis', 'Vivado', 'Catapult']) @pytest.mark.parametrize( 'keras_model_2d', [ diff --git a/test/pytest/test_repack_stream.py b/test/pytest/test_repack_stream.py index 12d44a66b7..04cc9867a9 100644 --- a/test/pytest/test_repack_stream.py +++ b/test/pytest/test_repack_stream.py @@ -9,7 +9,7 @@ test_root_path = Path(__file__).parent -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) def test_repack_precision(backend: str): inp = keras.Input(shape=(3, 3), name='inp') out = keras.layers.Reshape((3, 3), name='reshape')(inp) @@ -41,7 +41,7 @@ def test_repack_precision(backend: str): assert repack_precision.signed is True, 'Precision mismatch' -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('strategy', ['Latency', 'Resource']) def test_repack(backend: str, strategy: str): inp1 = keras.Input(shape=(4,), name='inp1') diff --git a/test/pytest/test_reshape.py b/test/pytest/test_reshape.py index 3c421c1474..ac277bb491 100755 --- a/test/pytest/test_reshape.py +++ b/test/pytest/test_reshape.py @@ -21,7 +21,7 @@ def randX_20_10(): return randX(20, 10) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_reshape_parallel(randX_20_10, backend, io_type): model = tf.keras.models.Sequential( diff --git a/test/pytest/test_sepconv1d.py b/test/pytest/test_sepconv1d.py index a75d854283..64b72db48a 100644 --- a/test/pytest/test_sepconv1d.py +++ b/test/pytest/test_sepconv1d.py @@ -25,7 +25,7 @@ @pytest.mark.parametrize('kernels', kernel_options) @pytest.mark.parametrize('bias', bias_options) @pytest.mark.parametrize('io_type', io_type_options) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult']) def test_sepconv1d(conv1d, chans, padds, strides, kernels, bias, io_type, backend): model = tf.keras.models.Sequential() input_shape = (28, 3) diff --git a/test/pytest/test_sepconv2d.py b/test/pytest/test_sepconv2d.py index 1ce85c5016..2fa2d94afe 100644 --- a/test/pytest/test_sepconv2d.py +++ b/test/pytest/test_sepconv2d.py @@ -25,7 +25,7 @@ @pytest.mark.parametrize("kernels", kernel_options) @pytest.mark.parametrize("bias", bias_options) @pytest.mark.parametrize("io_type", io_type_options) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Catapult']) def test_sepconv2d(conv2d, chans, padds, strides, kernels, bias, io_type, backend): model = tf.keras.models.Sequential() input_shape = (28, 28, 3) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 3cab00745c..19c9042465 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -19,7 +19,7 @@ def generate_data(input_shape): return np.clip(d, -32, 31) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('strategy', ['stable', 'latency', 'argmax']) @pytest.mark.parametrize( 'input_bits,input_shape,table_bits,io_type', @@ -65,7 +65,7 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl assert acc_hls4ml >= 0.98 -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_softmax_skipped(backend, io_type): X = np.random.rand(100, 10) diff --git a/test/pytest/test_softsign.py b/test/pytest/test_softsign.py index a23e89e7da..217865fe46 100644 --- a/test/pytest/test_softsign.py +++ b/test/pytest/test_softsign.py @@ -10,7 +10,7 @@ test_root_path = Path(__file__).parent -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('input_shape, io_type', [((8,), 'io_parallel'), ((8,), 'io_stream'), ((8, 8, 3), 'io_stream')]) def test_softsign(backend, input_shape, io_type): X = np.random.rand(1000, *input_shape) diff --git a/test/pytest/test_upsampling.py b/test/pytest/test_upsampling.py index 8ec5cabda9..9051d582bd 100644 --- a/test/pytest/test_upsampling.py +++ b/test/pytest/test_upsampling.py @@ -46,7 +46,7 @@ def keras_model_2d(): @pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('model_type', ['1d', '2d']) def test_upsampling(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type, io_type, backend): if model_type == '1d': diff --git a/test/pytest/test_zeropadding.py b/test/pytest/test_zeropadding.py index 962a3334a6..95f7d79a7d 100644 --- a/test/pytest/test_zeropadding.py +++ b/test/pytest/test_zeropadding.py @@ -50,7 +50,7 @@ def keras_model_2d(): @pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel']) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult']) @pytest.mark.parametrize('model_type', ['1d', '2d']) def test_zeropadding(keras_model_1d, keras_model_2d, data_1d, data_2d, model_type, io_type, backend): if model_type == '1d':