From a36ccdeadf74be201fa4e17b49f5f704fa4d2fa9 Mon Sep 17 00:00:00 2001 From: Francesco Brivio Date: Tue, 7 May 2024 12:22:37 +0200 Subject: [PATCH] implement balanced tree reduce for xilinxhls backend --- .../xilinxhls/firmware/BDT_unrolled.h | 52 +++++++++++++++---- .../hls-template/firmware/BDT_unrolled.cpp | 2 +- conifer/backends/xilinxhls/writer.py | 4 +- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/conifer/backends/xilinxhls/firmware/BDT_unrolled.h b/conifer/backends/xilinxhls/firmware/BDT_unrolled.h index 20b6468..ca32862 100644 --- a/conifer/backends/xilinxhls/firmware/BDT_unrolled.h +++ b/conifer/backends/xilinxhls/firmware/BDT_unrolled.h @@ -5,8 +5,37 @@ namespace BDT{ +/* --- +* Balanced tree reduce implementation. +* Reduces an array of inputs to a single value using the template binary operator 'Op', +* for example summing all elements with OpAdd, or finding the maximum with OpMax +* Use only when the input array is fully unrolled. Or, slice out a fully unrolled section +* before applying and accumulate the result over the rolled dimension. +* Required for emulation to guarantee equality of ordering. +* --- */ +constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); } + +constexpr int pow2(int x) { return x == 0 ? 1 : 2 * pow2(x - 1); } + +template T reduce(const T *x, Op op) { + static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0; + static constexpr int rightN = N - leftN > 0 ? N - leftN : 0; + if (N == 1) { + return x[0]; + } + if (N == 2) { + return op(x[0], x[1]); + } + return op(reduce(x, op), reduce(x + leftN, op)); +} + +template class OpAdd { + public: + T operator()(T a, T b) { return a + b; } +}; + +// Number of trees given number of classes constexpr int fn_classes(int n_classes){ - // Number of trees given number of classes return n_classes == 2 ? 1 : n_classes; } @@ -99,23 +128,24 @@ struct BDT{ public: score_t normalisation; score_t init_predict[fn_classes(n_classes)]; + OpAdd op_add; - void tree_scores(input_t x, score_t scores[n_trees][fn_classes(n_classes)]) const; + void tree_scores(input_t x, score_t scores[fn_classes(n_classes)][n_trees]) const; void decision_function(input_t x, score_t score[fn_classes(n_classes)]) const{ - score_t scores[n_trees][fn_classes(n_classes)]; + score_t scores[fn_classes(n_classes)][n_trees]; #pragma HLS ARRAY_PARTITION variable=scores dim=0 + // Get predictions scores + tree_scores(x, scores); + // Reduce + Reduce: for(int j = 0; j < fn_classes(n_classes); j++){ + // Init predictions score[j] = init_predict[j]; + // Sum predictions from trees via "reduce" method + score[j] += reduce>(scores[j], op_add); } - tree_scores(x, scores); - Trees: - for(int i = 0; i < n_trees; i++){ - Classes: - for(int j = 0; j < fn_classes(n_classes); j++){ - score[j] += scores[i][j]; - } - } + // Normalize predictions for(int j = 0; j < fn_classes(n_classes); j++){ score[j] *= normalisation; } diff --git a/conifer/backends/xilinxhls/hls-template/firmware/BDT_unrolled.cpp b/conifer/backends/xilinxhls/hls-template/firmware/BDT_unrolled.cpp index 8f31d0d..633c3a2 100644 --- a/conifer/backends/xilinxhls/hls-template/firmware/BDT_unrolled.cpp +++ b/conifer/backends/xilinxhls/hls-template/firmware/BDT_unrolled.cpp @@ -2,7 +2,7 @@ #include "parameters.h" template<> -void BDT::BDT::tree_scores(input_arr_t x, score_t scores[n_trees][fn_classes(n_classes)]) const { +void BDT::BDT::tree_scores(input_arr_t x, score_t scores[fn_classes(n_classes)][n_trees]) const { // conifer insert tree_scores } diff --git a/conifer/backends/xilinxhls/writer.py b/conifer/backends/xilinxhls/writer.py index 0a41b57..e9b28f4 100644 --- a/conifer/backends/xilinxhls/writer.py +++ b/conifer/backends/xilinxhls/writer.py @@ -139,7 +139,7 @@ def write_bdt_h(self): newline = '' for it, trees in enumerate(self.trees): for ic, tree in enumerate(trees): - newline += f' scores[{it}][{ic}] = tree_{it}_{ic}.decision_function(x);\n' + newline += f' scores[{ic}][{it}] = tree_{ic}_{it}.decision_function(x);\n' else: newline = line fout.write(newline) @@ -227,7 +227,7 @@ def _write_parameters_h_unrolled(self, fout): for iclass, tree in enumerate(trees): fout.write(f'static const BDT::Tree<{itree*nc+iclass}, {tree.n_nodes()}, {tree.n_leaves()}') fout.write(f', input_arr_t, score_t, threshold_t>') - fout.write(f' tree_{itree}_{iclass} = {{\n') + fout.write(f' tree_{iclass}_{itree} = {{\n') # loop over fields for ifield, field in enumerate(tree_fields): newline = ' {'