diff --git a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst
index fdcf44c6d9..3627855cfb 100644
--- a/docs/finn/source_code/finn.custom_op.fpgadataflow.rst
+++ b/docs/finn/source_code/finn.custom_op.fpgadataflow.rst
@@ -203,6 +203,14 @@ finn.custom\_op.fpgadataflow.thresholding\_batch
:undoc-members:
:show-inheritance:
+finn.custom\_op.fpgadataflow.thresholding\_binary\_search
+-----------------------------------------------------------
+
+.. automodule:: finn.custom_op.fpgadataflow.thresholding_binary_search
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
finn.custom\_op.fpgadataflow.tlastmarker
-----------------------------------------------
diff --git a/finn-rtllib/thresholding/component.xml b/finn-rtllib/thresholding/component.xml
new file mode 100644
index 0000000000..e28a3a2c2d
--- /dev/null
+++ b/finn-rtllib/thresholding/component.xml
@@ -0,0 +1,1002 @@
+
+
+ amd.com
+ finn
+ thresholding_axi
+ 1.0
+
+
+ ap_clk
+
+
+
+
+
+
+ CLK
+
+
+ ap_clk
+
+
+
+
+
+ ASSOCIATED_RESET
+ ap_rst_n
+
+
+ ASSOCIATED_BUSIF
+ s_axilite:s_axis:m_axis
+
+
+ FREQ_TOLERANCE_HZ
+ -1
+
+
+
+
+ m_axis
+
+
+
+
+
+
+ TDATA
+
+
+ m_axis_tdata
+
+
+
+
+ TVALID
+
+
+ m_axis_tvalid
+
+
+
+
+ TREADY
+
+
+ m_axis_tready
+
+
+
+
+
+ s_axis
+
+
+
+
+
+
+ TDATA
+
+
+ s_axis_tdata
+
+
+
+
+ TVALID
+
+
+ s_axis_tvalid
+
+
+
+
+ TREADY
+
+
+ s_axis_tready
+
+
+
+
+
+ s_axilite
+
+
+
+
+
+
+
+
+ AWADDR
+
+
+ s_axilite_AWADDR
+
+
+
+
+ AWVALID
+
+
+ s_axilite_AWVALID
+
+
+
+
+ AWREADY
+
+
+ s_axilite_AWREADY
+
+
+
+
+ WDATA
+
+
+ s_axilite_WDATA
+
+
+
+
+ WSTRB
+
+
+ s_axilite_WSTRB
+
+
+
+
+ WVALID
+
+
+ s_axilite_WVALID
+
+
+
+
+ WREADY
+
+
+ s_axilite_WREADY
+
+
+
+
+ BRESP
+
+
+ s_axilite_BRESP
+
+
+
+
+ BVALID
+
+
+ s_axilite_BVALID
+
+
+
+
+ BREADY
+
+
+ s_axilite_BREADY
+
+
+
+
+ ARADDR
+
+
+ s_axilite_ARADDR
+
+
+
+
+ ARVALID
+
+
+ s_axilite_ARVALID
+
+
+
+
+ ARREADY
+
+
+ s_axilite_ARREADY
+
+
+
+
+ RDATA
+
+
+ s_axilite_RDATA
+
+
+
+
+ RRESP
+
+
+ s_axilite_RRESP
+
+
+
+
+ RVALID
+
+
+ s_axilite_RVALID
+
+
+
+
+ RREADY
+
+
+ s_axilite_RREADY
+
+
+
+
+
+ ap_rst_n
+
+
+
+
+
+
+ RST
+
+
+ ap_rst_n
+
+
+
+
+
+ POLARITY
+ ACTIVE_LOW
+
+
+
+
+
+
+ s_axilite
+ s_axilite
+
+ reg0
+ reg0
+ 0x0
+ 4096
+ 32
+ register
+
+
+
+
+
+
+ xilinx_anylanguagesynthesis
+ Synthesis
+ :vivado.xilinx.com:synthesis
+ Verilog
+ thresholding_axi_wrapper
+
+ xilinx_anylanguagesynthesis_view_fileset
+
+
+
+ viewChecksum
+ fd0bd85b
+
+
+
+
+ xilinx_anylanguagebehavioralsimulation
+ Simulation
+ :vivado.xilinx.com:simulation
+ Verilog
+ thresholding_axi_wrapper
+
+ xilinx_anylanguagebehavioralsimulation_view_fileset
+
+
+
+ viewChecksum
+ fd0bd85b
+
+
+
+
+ xilinx_xpgui
+ UI Layout
+ :vivado.xilinx.com:xgui.ui
+
+ xilinx_xpgui_view_fileset
+
+
+
+ viewChecksum
+ fc6b9b63
+
+
+
+
+ xilinx_utilityxitfiles
+ Utility XIT/TTCL
+ :vivado.xilinx.com:xit.util
+
+ xilinx_utilityxitfiles_view_fileset
+
+
+
+ viewChecksum
+ 8b0215cd
+
+
+
+
+
+
+ ap_clk
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ ap_rst_n
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_AWVALID
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_AWREADY
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_AWADDR
+
+ in
+
+ 5
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_WVALID
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_WREADY
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_WDATA
+
+ in
+
+ 31
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_WSTRB
+
+ in
+
+ 3
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 1
+
+
+
+
+ s_axilite_BVALID
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_BREADY
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_BRESP
+
+ out
+
+ 1
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_ARVALID
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_ARREADY
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_ARADDR
+
+ in
+
+ 5
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_RVALID
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_RREADY
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ s_axilite_RDATA
+
+ out
+
+ 31
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axilite_RRESP
+
+ out
+
+ 1
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axis_tready
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axis_tvalid
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ s_axis_tdata
+
+ in
+
+ 15
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 0
+
+
+
+
+ m_axis_tready
+
+ in
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+ 1
+
+
+
+
+ m_axis_tvalid
+
+ out
+
+
+ std_logic
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+ m_axis_tdata
+
+ out
+
+ 7
+ 0
+
+
+
+ std_logic_vector
+ xilinx_anylanguagesynthesis
+ xilinx_anylanguagebehavioralsimulation
+
+
+
+
+
+
+
+ N
+ N
+ 4
+
+
+ K
+ K
+ 16
+
+
+ C
+ C
+ 1
+
+
+ PE
+ Pe
+ 1
+
+
+ SIGNED
+ Signed
+ true
+
+
+ FPARG
+ Fparg
+ false
+
+
+ BIAS
+ Bias
+ 0
+
+
+ CF
+ Cf
+ 1
+
+
+ ADDR_BITS
+ Addr Bits
+ 6
+
+
+ O_BITS
+ O Bits
+ 4
+
+
+
+
+
+ choice_list_9d8b0d81
+ ACTIVE_HIGH
+ ACTIVE_LOW
+
+
+
+
+ xilinx_anylanguagesynthesis_view_fileset
+
+ hdl/thresholding.sv
+ systemVerilogSource
+
+
+ hdl/thresholding_axi.sv
+ systemVerilogSource
+
+
+ hdl/thresholding_axi_wrapper.v
+ verilogSource
+ CHECKSUM_7b8c102d
+
+
+ hdl/axilite_if.v
+ verilogSource
+ CHECKSUM_69d1ba26
+ xil_defaultlib
+
+
+
+ xilinx_anylanguagebehavioralsimulation_view_fileset
+
+ hdl/thresholding.sv
+ systemVerilogSource
+
+
+ hdl/thresholding_axi.sv
+ systemVerilogSource
+
+
+ hdl/thresholding_axi_wrapper.v
+ verilogSource
+
+
+ hdl/axilite_if.v
+ verilogSource
+ USED_IN_ipstatic
+ xil_defaultlib
+
+
+
+ xilinx_xpgui_view_fileset
+
+ xgui/thresholding_axi_v1_0.tcl
+ tclSource
+ CHECKSUM_fc6b9b63
+ XGUI_VERSION_2
+
+
+
+ xilinx_utilityxitfiles_view_fileset
+
+ gui/thresholding_axi_v1_0.gtcl
+ GTCL
+
+
+
+ MultiThreshold
+
+
+ N
+ Output Precision
+ 4
+
+
+ K
+ Input Precision
+ 16
+
+
+ C
+ Channels
+ 1
+
+
+ PE
+ Pe
+ 1
+
+
+ SIGNED
+ Signed Inputs
+ true
+
+
+ FPARG
+ Floating-Point Inputs
+ false
+
+
+ BIAS
+ Bias
+ 0
+
+
+ CF
+ Channel Fold
+ 1
+
+
+
+ false
+
+
+
+
+
+ ADDR_BITS
+ Address Bits
+ 6
+
+
+
+ false
+
+
+
+
+
+ O_BITS
+ Output Value Width
+ 4
+
+
+
+ false
+
+
+
+
+
+ Component_Name
+ thresholding_axi_wrapper_v1_0
+
+
+
+
+
+ virtex7
+ qvirtex7
+ versal
+ kintex7
+ kintex7l
+ qkintex7
+ qkintex7l
+ akintex7
+ artix7
+ artix7l
+ aartix7
+ qartix7
+ zynq
+ qzynq
+ azynq
+ spartan7
+ aspartan7
+ virtexu
+ zynquplus
+ virtexuplus
+ virtexuplusHBM
+ virtexuplus58g
+ kintexuplus
+ artixuplus
+ kintexu
+
+
+ /UserIP
+
+ thresholding_axi
+ level_1
+ package_project
+ 2
+
+ user.org:user:thresholding_axi_wrapper:1.0
+
+ 2023-06-27T05:47:20Z
+
+
+
+
+
+ 2022.2
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/finn-rtllib/thresholding/gui/thresholding_axi_v1_0.gtcl b/finn-rtllib/thresholding/gui/thresholding_axi_v1_0.gtcl
new file mode 100644
index 0000000000..90d73ede7e
--- /dev/null
+++ b/finn-rtllib/thresholding/gui/thresholding_axi_v1_0.gtcl
@@ -0,0 +1,4 @@
+# This file is automatically written. Do not modify.
+proc gen_USERPARAMETER_CF_VALUE {C PE } {expr $C/$PE}
+proc gen_USERPARAMETER_ADDR_BITS_VALUE {C PE N } {expr int(ceil(log($C/$PE)/log(2))+ceil(log($PE)/log(2))+$N+2)}
+proc gen_USERPARAMETER_O_BITS_VALUE {BIAS N } {expr int(ceil($BIAS >= 0? log(pow(2,$N)+$BIAS)/log(2) : 1+log(-$BIAS >= pow(2,$N-1)? -$BIAS : pow(2,$N)+$BIAS)/log(2)))}
diff --git a/finn-rtllib/thresholding/hdl/axilite_if.v b/finn-rtllib/thresholding/hdl/axilite_if.v
new file mode 100644
index 0000000000..bdd4de288e
--- /dev/null
+++ b/finn-rtllib/thresholding/hdl/axilite_if.v
@@ -0,0 +1,210 @@
+/*
+ Copyright (c) 2020, Xilinx
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of FINN nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+module axi4lite_if
+#(
+ parameter ADDR_WIDTH = 32,
+ parameter DATA_WIDTH = 32,//AXI4 spec requires this to be strictly 32 or 64
+ parameter IP_DATA_WIDTH = 64//can be any power-of-2 multiple of DATA_WIDTH
+)
+(
+//system signals
+input aclk,
+input aresetn,//active low, asynchronous assertion and synchronous deassertion
+
+//Write channels
+//write address
+output reg awready,
+input awvalid,
+input [ADDR_WIDTH-1:0] awaddr,
+input [2:0] awprot,
+//write data
+output reg wready,
+input wvalid,
+input [DATA_WIDTH-1:0] wdata,
+input [(DATA_WIDTH/8)-1:0] wstrb,
+//burst response
+input bready,
+output reg bvalid,
+output reg [1:0] bresp,//NOTE: 00 = OKAY, 10 = SLVERR (write error)
+
+//Read channels
+//read address
+output reg arready,
+input arvalid,
+input [ADDR_WIDTH-1:0] araddr,
+input [2:0] arprot,
+//read data
+input rready,
+output reg rvalid,
+output reg [1:0] rresp,//NOTE: 00 = OKAY, 10 = SLVERR (read error)
+output reg [DATA_WIDTH-1:0] rdata,
+
+//IP-side interface
+output reg ip_en,
+output reg ip_wen,
+output reg [ADDR_WIDTH-1:0] ip_addr,
+output [IP_DATA_WIDTH-1:0] ip_wdata,
+input ip_rack,
+input [IP_DATA_WIDTH-1:0] ip_rdata
+);
+
+localparam RESP_OKAY = 2'b00;
+localparam RESP_SLVERR = 2'b10;
+//get ceil(log2(ceil(IP_DATA_WIDTH/DATA_WIDTH)))
+localparam NFOLDS_LOG = $clog2((IP_DATA_WIDTH + DATA_WIDTH - 1) / DATA_WIDTH);
+
+reg internal_ren;
+reg internal_wen;
+reg internal_wack;
+reg [ADDR_WIDTH-1:0] internal_raddr;
+reg [ADDR_WIDTH-1:0] internal_waddr;
+reg [DATA_WIDTH-1:0] internal_wdata;
+wire [DATA_WIDTH-1:0] internal_rdata;
+reg internal_error = 0;
+
+//check DATA_WIDTH
+initial begin
+ if(DATA_WIDTH != 32 & DATA_WIDTH != 64) begin
+ $display("AXI4Lite DATA_WIDTH must be 32 or 64");
+ $finish;
+ end
+end
+
+//transaction state machine
+localparam STATE_IDLE = 0,
+ STATE_READ = 1,
+ STATE_WRITE = 2;
+
+reg [1:0] state;
+
+always @(posedge aclk or negedge aresetn)
+ if(~aresetn)
+ state <= STATE_IDLE;
+ else case(state)
+ STATE_IDLE:
+ if(awvalid & wvalid)
+ state <= STATE_WRITE;
+ else if(arvalid)
+ state <= STATE_READ;
+ STATE_READ:
+ if(rvalid & rready)
+ state <= STATE_IDLE;
+ STATE_WRITE:
+ if(bvalid & bready)
+ state <= STATE_IDLE;
+ default: state <= STATE_IDLE;
+ endcase
+
+//write-related internal signals
+always @(*) begin
+ internal_waddr = awaddr >> $clog2(DATA_WIDTH/8);
+ internal_wdata = wdata;
+ internal_wen = (state == STATE_IDLE) & awvalid & wvalid;
+end
+
+always @(posedge aclk) begin
+ awready <= internal_wen;
+ wready <= internal_wen;
+end
+
+//read-related internal signals
+always @(*) begin
+ internal_raddr = araddr >> $clog2(DATA_WIDTH/8);
+ internal_ren = (state == STATE_IDLE) & ~internal_wen & arvalid;
+end
+
+always @(posedge aclk)
+ arready <= internal_ren;
+
+wire write_to_last_fold;
+
+always @(posedge aclk) begin
+ ip_wen <= write_to_last_fold;
+ ip_en <= internal_ren | write_to_last_fold;
+ if(internal_ren | write_to_last_fold)
+ ip_addr <= internal_ren ? (internal_raddr >> NFOLDS_LOG) : (internal_waddr >> NFOLDS_LOG);
+ internal_wack <= internal_wen;
+end
+
+genvar i;
+reg [(1<> (internal_rfold*DATA_WIDTH);
+ always @(posedge aclk)
+ if(internal_ren)
+ internal_rfold <= internal_raddr[NFOLDS_LOG-1:0];
+ for(i=0; i<(1<
+ *
+ * @description
+ * Produces the N-bit count of those among 2^N-1 thresholds that are not
+ * larger than the corresponding input:
+ * y = Σ(T_i <= x)
+ * The result is computed by binary search. The runtime-configurable
+ * thresholds must be written in ascending order:
+ * i < j => T_i < T_j
+ * The design supports channel folding allowing each input to be processed
+ * with respect to a selectable set of thresholds. The corresponding
+ * threshold configuration relies on a channel address prefix. Inputs are
+ * accompanied by a channel selector.
+ *
+ * Parameter Layout as seen on AXI-Lite (row by row):
+ * | Base \ Offs | 0 1 2 ... 2^N-2 2^N-1
+ * ---------+--------------------------------+------------------------------------
+ * Chnl #0 | 0 | T_0 T_1 T_2 ... T_{2^N-2} 'x
+ * Chnl #1 | 2^N | T_0 T_1 T_2 ... T_{2^N-2} 'x
+ * Chnl #c | ((c/PE)*$clog2(PE) + c%PE)*2^N | T_0 T_1 T_2 ... T_{2^N-2} 'x
+ *
+ *****************************************************************************/
+module thresholding #(
+ int unsigned N, // output precision
+ int unsigned K, // input/threshold precision
+ int unsigned C, // number of channels
+ int unsigned PE, // parallel processing elements
+
+ bit SIGNED = 1, // signed inputs
+ bit FPARG = 0, // floating-point inputs: [sign] | exponent | mantissa
+ int BIAS = 0, // offsetting the output [0, 2^N-1] -> [BIAS, 2^N-1 + BIAS]
+
+ // Initial Thresholds
+ parameter THRESHOLDS_PATH = "",
+ bit USE_CONFIG = 1,
+
+ // Force Use of On-Chip Memory Blocks
+ int unsigned DEPTH_TRIGGER_URAM = 0, // if non-zero, local mems of this depth or more go into URAM (prio)
+ int unsigned DEPTH_TRIGGER_BRAM = 0, // if non-zero, local mems of this depth or more go into BRAM
+ bit DEEP_PIPELINE = 0,
+
+ localparam int unsigned CF = C/PE, // Channel fold
+ localparam int unsigned O_BITS = BIAS >= 0?
+ /* unsigned */ $clog2(2**N+BIAS) :
+ /* signed */ 1+$clog2(-BIAS >= 2**(N-1)? -BIAS : 2**N+BIAS)
+)(
+ // Global Control
+ input logic clk,
+ input logic rst,
+
+ // Threshold Configuration
+ input logic cfg_en,
+ input logic cfg_we,
+ input logic [$clog2(CF)+$clog2(PE)+N-1:0] cfg_a,
+ input logic [K-1:0] cfg_d,
+ output logic cfg_rack,
+ output logic [K-1:0] cfg_q,
+
+ // Input Stream
+ output logic irdy,
+ input logic ivld,
+ input logic [PE-1:0][K-1:0] idat,
+
+ // Output Stream
+ input logic ordy,
+ output logic ovld,
+ output logic [PE-1:0][O_BITS-1:0] odat
+);
+
+ // Parameter Constraints Checking
+ initial begin
+ if(CF*PE != C) begin
+ $error("Parallelism PE=%0d is not a multiple of channel count C=%0d.", PE, C);
+ $finish;
+ end
+ end
+
+ // Operations within Pipeline
+ typedef enum logic [1:0] {
+ NOP = 2'b00, // No operation
+ TH = 2'b01, // Thresholding
+ WR = 2'b11, // Write (initialization)
+ RB = 2'b10, // Readback (validation)
+ CFG = 2'b1x // Config op (pointer-preserving)
+ } op_e;
+
+ // Pipeline Link Type
+ typedef logic [$clog2(CF)+N-1:0] ptr_t;
+ typedef logic [K -1:0] val_t;
+ typedef struct packed {
+ op_e op;
+ ptr_t ptr; // WR/RB: address; TH: result
+ val_t val; // WR/RB: threshold value; TH: input value
+ } pipe_t;
+
+ //-----------------------------------------------------------------------
+ // Pipeline Feed
+ // - configuration always takes precedence
+ // - number of pending thresholding ops capped to N+3
+ // across pipeline and output FIFO: pipe:N + A:1 + B:1 + 1
+ localparam int unsigned MAX_PENDING = (DEEP_PIPELINE+1)*N + 3;
+ pipe_t pipe[PE][N+1];
+ if(1) begin : blkFeed
+
+ // Thresholding Input Guard ensuring Output FIFO is never overrun
+ logic signed [$clog2(MAX_PENDING):0] GuardSem = MAX_PENDING-1; // MAX_PENDING-1, ..., 0, -1
+ uwire th_full = GuardSem[$left(GuardSem)];
+ always_ff @(posedge clk) begin
+ if(rst) GuardSem <= MAX_PENDING-1;
+ else begin
+ automatic logic dec = !(USE_CONFIG && cfg_en) && !th_full && ivld;
+ automatic logic inc = ovld && ordy;
+ GuardSem <= GuardSem + (inc == dec? 0 : inc? 1 : -1);
+ end
+ end
+
+ // PE Configuration Address Decoding
+ uwire cfg_sel[PE];
+ if(PE == 1) assign cfg_sel[0] = 1;
+ else begin
+ for(genvar pe = 0; pe < PE; pe++) begin
+ assign cfg_sel[pe] = USE_CONFIG && cfg_en && (cfg_a[N+:$clog2(PE)] == pe);
+ end
+ end
+
+ uwire ptr_t iptr;
+ assign iptr[0+:N] = cfg_a[0+:N];
+ if(CF > 1) begin
+ // Channel Fold Rotation
+ logic [$clog2(CF)-1:0] CnlCnt = 0;
+ logic CnlLst = 0;
+ always_ff @(posedge clk) begin
+ if(rst) begin
+ CnlCnt <= 0;
+ CnlLst <= 0;
+ end
+ else if(!(USE_CONFIG && cfg_en) && !th_full && ivld) begin
+ CnlCnt <= CnlCnt + (CnlLst? 1-CF : 1);
+ CnlLst <= CnlCnt == CF-2;
+ end
+ end
+
+ assign iptr[N+:$clog2(CF)] = USE_CONFIG && cfg_en? cfg_a[N+$clog2(PE)+:$clog2(CF)] : CnlCnt;
+ end
+
+ for(genvar pe = 0; pe < PE; pe++) begin
+ assign pipe[pe][0] = '{
+ op: USE_CONFIG && cfg_en?
+ (!cfg_sel[pe]? NOP : cfg_we? WR : RB) :
+ (ivld && !th_full? TH : NOP),
+ ptr: iptr,
+ val: !(USE_CONFIG && cfg_en)? idat[pe] : cfg_we? cfg_d : 0
+ };
+ end
+
+ assign irdy = !(USE_CONFIG && cfg_en) && !th_full;
+ end : blkFeed
+
+ //-----------------------------------------------------------------------
+ // Free-Running Thresholding Pipeline
+ for(genvar stage = 0; stage < N; stage++) begin : genStages
+
+ localparam int unsigned SN = N-1-stage;
+ for(genvar pe = 0; pe < PE; pe++) begin : genPE
+ uwire pipe_t p = pipe[pe][stage];
+ uwire cs = (p.ptr[SN:0] == 2**SN-1);
+
+ // Threshold Memory
+ val_t Thresh; // Read-out register
+ if(1) begin : blkThresh
+ localparam int unsigned DEPTH = CF * 2**stage;
+ localparam RAM_STYLE =
+ DEPTH_TRIGGER_URAM && (DEPTH >= DEPTH_TRIGGER_URAM)? "ultra" :
+ DEPTH_TRIGGER_BRAM && (DEPTH >= DEPTH_TRIGGER_BRAM)? "block" :
+ // If BRAM trigger defined, force distributed memory below if Vivado may be tempted to use BRAM nonetheless.
+ DEPTH_TRIGGER_BRAM && (DEPTH >= 64)? "distributed" : "auto";
+
+ (* RAM_STYLE = RAM_STYLE *)
+ val_t Threshs[DEPTH];
+ if(THRESHOLDS_PATH != "") begin
+ localparam FILE = $sformatf("%s/threshs_%0d_%0d.dat", THRESHOLDS_PATH, pe, stage);
+ initial $readmemh(FILE, Threshs);
+ end
+
+ if(USE_CONFIG) begin : genThreshMem
+ uwire we = (p.op ==? WR) && cs;
+ if((CF == 1) && (stage == 0)) begin
+ always @(posedge clk) begin
+ if(we) Threshs[0] <= p.val;
+ end
+ end
+ else begin
+ uwire [$clog2(CF)+stage-1:0] addr = p.ptr[$clog2(CF)+N-1:SN+1];
+ always @(posedge clk) begin
+ if(we) Threshs[addr] <= p.val;
+ end
+ end
+ end : genThreshMem
+
+ if((CF == 1) && (stage == 0)) begin
+ assign Thresh = Threshs[0];
+ end
+ else begin
+ uwire [$clog2(CF)+stage-1:0] addr = p.ptr[$clog2(CF)+N-1:SN+1];
+ always_ff @(posedge clk) begin
+ Thresh <= Threshs[addr];
+ end
+ end
+
+ end : blkThresh
+
+ // Pipeline State
+ pipe_t P = '{ op: NOP, default: 'x };
+ logic Reval = 0;
+ always_ff @(posedge clk) begin
+ if(rst) begin
+ P <= '{ op: NOP, default: 'x };
+ Reval <= 0;
+ end
+ else begin
+ P <= p;
+ Reval <= (p.op ==? RB) && cs;
+ end
+ end
+
+ logic cmp;
+ if(!SIGNED) assign cmp = $unsigned(Thresh) <= $unsigned(P.val);
+ else if(!FPARG) assign cmp = $signed(Thresh) <= $signed(P.val);
+ else begin : blkSignedFloat
+ uwire mag_eq = Thresh[K-2:0] == P.val[K-2:0];
+ uwire mag_le = Thresh[K-2:0] <= P.val[K-2:0];
+ always_comb begin
+ unique case({Thresh[K-1], P.val[K-1]})
+ 2'b00: cmp = mag_le;
+ 2'b01: cmp = 0;
+ 2'b10: cmp = 1;
+ 2'b11: cmp = !mag_le || mag_eq;
+ default: cmp = 'x;
+ endcase
+ end
+ end : blkSignedFloat
+
+ // Pipeline State Update
+ pipe_t pp;
+ always_comb begin
+ pp = P;
+ if(P.op !=? CFG) pp.ptr[SN] = cmp;
+ if(Reval) pp.val = Thresh;
+ end
+
+ // Pipeline State Forward (potentially additional register)
+ pipe_t pf;
+ if(!DEEP_PIPELINE) assign pf = pp;
+ else begin
+ pipe_t Pf = '{ op: NOP, default: 'x };
+ always_ff @(posedge clk) begin
+ if(rst) Pf <= '{ op: NOP, default: 'x };
+ else Pf <= pp;
+ end
+ assign pf = Pf;
+ end
+
+ assign pipe[pe][stage+1] = pf;
+
+ end : genPE
+ end : genStages
+
+ //-----------------------------------------------------------------------
+ // Configuration Readback
+ always_comb begin
+ cfg_rack = 0;
+ cfg_q = 0;
+ foreach(pipe[pe]) begin
+ automatic pipe_t p = pipe[pe][N];
+ cfg_rack |= p.op ==? RB;
+ cfg_q |= p.val;
+ end
+ end
+
+ //-----------------------------------------------------------------------
+ // Stream Output through FIFO
+ // - Depth of N + Output Reg to allow pipe to drain entirely under backpressure
+ // - Typically mapped to an SRL shift register
+ if(1) begin : blkStreamOutput
+ localparam int unsigned A_DEPTH = MAX_PENDING - 1;
+ logic [PE-1 : 0][N-1 : 0] ADat[A_DEPTH];
+ logic signed [$clog2(A_DEPTH):0] APtr = '1; // -1, 0, 1, ..., A_DEPTH-1
+ uwire avld = !APtr[$left(APtr)];
+
+ logic [PE-1:0][N-1:0] BDat = 'x;
+ logic BVld = 0;
+
+ uwire aload = pipe[0][N].op ==? TH;
+ uwire bload = !BVld || ordy;
+
+ always_ff @(posedge clk) begin
+ if(aload) begin
+ assert(APtr < $signed(A_DEPTH-1)) else begin
+ $error("Overrun after failing stream guard.");
+ $stop;
+ end
+ foreach(pipe[pe]) ADat[0][pe] <= pipe[pe][N].ptr;
+ for(int unsigned i = 1; i < A_DEPTH; i++) ADat[i] <= ADat[i-1];
+ end
+ end
+ always_ff @(posedge clk) begin
+ if(rst) APtr <= '1;
+ else APtr <= APtr + (aload == (avld && bload)? 0 : aload? 1 : -1);
+ end
+ always_ff @(posedge clk) begin
+ if(rst) begin
+ BDat <= 'x;
+ BVld <= 0;
+ end
+ else if(bload) begin
+ BDat <= ADat[APtr];
+ BVld <= avld;
+ end
+ end
+
+ assign ovld = BVld;
+ for(genvar pe = 0; pe < PE; pe++) begin
+ assign odat[pe] = BDat[pe] + BIAS;
+ end
+ end : blkStreamOutput
+
+endmodule : thresholding
diff --git a/finn-rtllib/thresholding/hdl/thresholding_axi.sv b/finn-rtllib/thresholding/hdl/thresholding_axi.sv
new file mode 100644
index 0000000000..1f235b9486
--- /dev/null
+++ b/finn-rtllib/thresholding/hdl/thresholding_axi.sv
@@ -0,0 +1,164 @@
+/******************************************************************************
+ * Copyright (C) 2022, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * @brief All-AXI interface adapter for thresholding module.
+ * @author Thomas B. Preußer
+ *
+ * @description
+ * This AXI adapter fits the core thresholding functionality:
+ * - with AXI stream data interfaces with flow control
+ * - with implicit round-robin channel rotation as used by FINN, and
+ * - performs aligned byte address to parameter word address translation.
+ *****************************************************************************/
+
+module thresholding_axi #(
+ int unsigned N, // output precision
+ int unsigned K, // input/threshold precision
+ int unsigned C = 1, // Channels
+ int unsigned PE = 1, // Processing Parallelism, requires C = k*PE
+
+ bit SIGNED = 1, // signed inputs
+ bit FPARG = 0, // floating-point inputs: [sign] | exponent | mantissa
+ int BIAS = 0, // offsetting the output [0, 2^N-1] -> [BIAS, 2^N-1 + BIAS]
+
+ // Initial Thresholds
+ parameter THRESHOLDS_PATH = "",
+
+ bit USE_AXILITE, // Implement AXI-Lite for threshold read/write
+
+ // Force Use of On-Chip Memory Blocks
+ int unsigned DEPTH_TRIGGER_URAM = 0, // if non-zero, local mems of this depth or more go into URAM (prio)
+ int unsigned DEPTH_TRIGGER_BRAM = 0, // if non-zero, local mems of this depth or more go into BRAM
+ bit DEEP_PIPELINE = 0,
+
+ localparam int unsigned CF = C/PE, // Channel Fold
+ localparam int unsigned ADDR_BITS = $clog2(CF) + $clog2(PE) + N + 2,
+ localparam int unsigned O_BITS = BIAS >= 0?
+ /* unsigned */ $clog2(2**N+BIAS) :
+ /* signed */ 1+$clog2(-BIAS >= 2**(N-1)? -BIAS : 2**N+BIAS)
+)(
+ //- Global Control ------------------
+ input logic ap_clk,
+ input logic ap_rst_n,
+
+ //- AXI Lite ------------------------
+ // Writing
+ input logic s_axilite_AWVALID,
+ output logic s_axilite_AWREADY,
+ input logic [ADDR_BITS-1:0] s_axilite_AWADDR, // lowest 2 bits (byte selectors) are ignored
+
+ input logic s_axilite_WVALID,
+ output logic s_axilite_WREADY,
+ input logic [31:0] s_axilite_WDATA,
+ input logic [ 3:0] s_axilite_WSTRB,
+
+ output logic s_axilite_BVALID,
+ input logic s_axilite_BREADY,
+ output logic [1:0] s_axilite_BRESP,
+
+ // Reading
+ input logic s_axilite_ARVALID,
+ output logic s_axilite_ARREADY,
+ input logic [ADDR_BITS-1:0] s_axilite_ARADDR,
+
+ output logic s_axilite_RVALID,
+ input logic s_axilite_RREADY,
+ output logic [31:0] s_axilite_RDATA,
+ output logic [ 1:0] s_axilite_RRESP,
+
+ //- AXI Stream - Input --------------
+ output logic s_axis_tready,
+ input logic s_axis_tvalid,
+ input logic [((PE*K+7)/8)*8-1:0] s_axis_tdata,
+
+ //- AXI Stream - Output -------------
+ input logic m_axis_tready,
+ output logic m_axis_tvalid,
+ output logic [((PE*O_BITS+7)/8)*8-1:0] m_axis_tdata
+);
+
+ //-----------------------------------------------------------------------
+ // AXI-lite Configuration Interface
+ uwire cfg_en;
+ uwire cfg_we;
+ uwire [ADDR_BITS-3:0] cfg_a;
+ uwire [K -1:0] cfg_d;
+ uwire cfg_rack;
+ uwire [K -1:0] cfg_q;
+
+ if(USE_AXILITE) begin
+ uwire [ADDR_BITS-1:0] cfg_a0;
+ axi4lite_if #(.ADDR_WIDTH(ADDR_BITS), .DATA_WIDTH(32), .IP_DATA_WIDTH(K)) axi (
+ .aclk(ap_clk), .aresetn(ap_rst_n),
+
+ .awready(s_axilite_AWREADY), .awvalid(s_axilite_AWVALID), .awaddr(s_axilite_AWADDR), .awprot('x),
+ .wready(s_axilite_WREADY), .wvalid(s_axilite_WVALID), .wdata(s_axilite_WDATA), .wstrb(s_axilite_WSTRB),
+ .bready(s_axilite_BREADY), .bvalid(s_axilite_BVALID), .bresp(s_axilite_BRESP),
+
+ .arready(s_axilite_ARREADY), .arvalid(s_axilite_ARVALID), .araddr(s_axilite_ARADDR), .arprot('x),
+ .rready(s_axilite_RREADY), .rvalid(s_axilite_RVALID), .rresp(s_axilite_RRESP), .rdata(s_axilite_RDATA),
+
+ .ip_en(cfg_en), .ip_wen(cfg_we), .ip_addr(cfg_a0), .ip_wdata(cfg_d),
+ .ip_rack(cfg_rack), .ip_rdata(cfg_q)
+ );
+ assign cfg_a = cfg_a0[ADDR_BITS-3:0];
+ always_ff @(posedge ap_clk) begin
+ assert(!ap_rst_n || !cfg_en || (cfg_a0[ADDR_BITS-2+:2] === 3'h0)) else begin
+ $error("%m: Spurious high address bits.");
+ $stop;
+ end
+ end
+ end
+ else begin
+ assign cfg_en = 0;
+ assign cfg_we = 'x;
+ assign cfg_a = 'x;
+ assign cfg_d = 'x;
+ end
+
+ //-----------------------------------------------------------------------
+ // Kernel Implementation
+ thresholding #(
+ .N(N), .K(K), .C(C), .PE(PE),
+ .SIGNED(SIGNED), .FPARG(FPARG), .BIAS(BIAS),
+ .THRESHOLDS_PATH(THRESHOLDS_PATH), .USE_CONFIG(USE_AXILITE),
+ .DEPTH_TRIGGER_URAM(DEPTH_TRIGGER_URAM), .DEPTH_TRIGGER_BRAM(DEPTH_TRIGGER_BRAM),
+ .DEEP_PIPELINE(DEEP_PIPELINE)
+ ) impl (
+ .clk(ap_clk), .rst(!ap_rst_n),
+
+ .cfg_en, .cfg_we, .cfg_a, .cfg_d,
+ .cfg_rack, .cfg_q,
+
+ .irdy(s_axis_tready), .ivld(s_axis_tvalid), .idat(s_axis_tdata),
+ .ordy(m_axis_tready), .ovld(m_axis_tvalid), .odat(m_axis_tdata)
+ );
+
+endmodule : thresholding_axi
diff --git a/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v b/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v
new file mode 100644
index 0000000000..3f0b012ef1
--- /dev/null
+++ b/finn-rtllib/thresholding/hdl/thresholding_template_wrapper.v
@@ -0,0 +1,120 @@
+/**
+ * Copyright (c) 2023, Xilinx
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * * Neither the name of FINN nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * @author Thomas B. Preußer
+ * @brief Verilog wrapper for IP packaging.
+ */
+
+module thresholding_template_wrapper #(
+ parameter N = $N$, // output precision
+ parameter K = $M$, // input/threshold precision
+ parameter C = $C$, // Channels
+ parameter PE = $PE$, // Processing Parallelism, requires C = k*PE
+
+ parameter SIGNED = $SIGNED$, // signed inputs
+ parameter FPARG = 0, // floating-point inputs: [sign] | exponent | mantissa
+ parameter BIAS = $BIAS$, // offsetting the output [0, 2^N-1] -> [BIAS, 2^N-1 + BIAS]
+
+ parameter THRESHOLDS_PATH = $THRESHOLDS_PATH$, // Directory with initial threshold data
+ parameter USE_AXILITE = $USE_AXILITE$, // Implement AXI-Lite for threshold read/write
+
+ // Force Use of On-Chip Memory Blocks
+ parameter DEPTH_TRIGGER_URAM = $DEPTH_TRIGGER_URAM$, // if non-zero, local mems of this depth or more go into URAM (prio)
+ parameter DEPTH_TRIGGER_BRAM = $DEPTH_TRIGGER_BRAM$, // if non-zero, local mems of this depth or more go into BRAM
+ parameter DEEP_PIPELINE = $DEEP_PIPELINE$, // [bit] extra pipeline stages for easier timing closure
+
+ parameter O_BITS = $O_BITS$
+)(
+ // Global Control
+ (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF s_axilite:in0_V:out_V, ASSOCIATED_RESET ap_rst_n" *)
+ (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *)
+ input ap_clk,
+ (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *)
+ input ap_rst_n,
+
+ //- AXI Lite ------------------------
+ // Writing
+ input s_axilite_AWVALID,
+ output s_axilite_AWREADY,
+ input [$clog2(C/PE) + $clog2(PE) + N + 1:0] s_axilite_AWADDR, // lowest 2 bits (byte selectors) are ignored
+
+ input s_axilite_WVALID,
+ output s_axilite_WREADY,
+ input [31:0] s_axilite_WDATA,
+ input [ 3:0] s_axilite_WSTRB,
+
+ output s_axilite_BVALID,
+ input s_axilite_BREADY,
+ output [1:0] s_axilite_BRESP,
+
+ // Reading
+ input s_axilite_ARVALID,
+ output s_axilite_ARREADY,
+ input [$clog2(C/PE) + $clog2(PE) + N + 1:0] s_axilite_ARADDR,
+
+ output s_axilite_RVALID,
+ input s_axilite_RREADY,
+ output [31:0] s_axilite_RDATA,
+ output [ 1:0] s_axilite_RRESP,
+
+ //- AXI Stream - Input --------------
+ output in0_V_tready,
+ input in0_V_tvalid,
+ input [((PE*K+7)/8)*8-1:0] in0_V_tdata,
+
+ //- AXI Stream - Output -------------
+ input out_V_tready,
+ output out_V_tvalid,
+ output [((PE*O_BITS+7)/8)*8-1:0] out_V_tdata
+);
+
+ thresholding_axi #(
+ .N(N), .K(K), .C(C), .PE(PE),
+ .SIGNED(SIGNED),
+ .FPARG(FPARG),
+ .BIAS(BIAS),
+ .THRESHOLDS_PATH(THRESHOLDS_PATH),
+ .USE_AXILITE(USE_AXILITE),
+ .DEPTH_TRIGGER_URAM(DEPTH_TRIGGER_URAM),
+ .DEPTH_TRIGGER_BRAM(DEPTH_TRIGGER_BRAM),
+ .DEEP_PIPELINE(DEEP_PIPELINE)
+ ) core (
+ .ap_clk(ap_clk), .ap_rst_n(ap_rst_n),
+
+ .s_axilite_AWVALID(s_axilite_AWVALID), .s_axilite_AWREADY(s_axilite_AWREADY), .s_axilite_AWADDR(s_axilite_AWADDR),
+ .s_axilite_WVALID(s_axilite_WVALID), .s_axilite_WREADY(s_axilite_WREADY), .s_axilite_WDATA(s_axilite_WDATA), .s_axilite_WSTRB(s_axilite_WSTRB),
+ .s_axilite_BVALID(s_axilite_BVALID), .s_axilite_BREADY(s_axilite_BREADY), .s_axilite_BRESP(s_axilite_BRESP),
+
+ .s_axilite_ARVALID(s_axilite_ARVALID), .s_axilite_ARREADY(s_axilite_ARREADY), .s_axilite_ARADDR(s_axilite_ARADDR),
+ .s_axilite_RVALID(s_axilite_RVALID), .s_axilite_RREADY(s_axilite_RREADY), .s_axilite_RDATA(s_axilite_RDATA), .s_axilite_RRESP(s_axilite_RRESP),
+ .s_axis_tready(in0_V_tready), .s_axis_tvalid(in0_V_tvalid), .s_axis_tdata(in0_V_tdata),
+ .m_axis_tready(out_V_tready), .m_axis_tvalid(out_V_tvalid), .m_axis_tdata(out_V_tdata)
+ );
+
+endmodule // thresholding_template_wrapper
diff --git a/finn-rtllib/thresholding/sim/thresh_gen.sv b/finn-rtllib/thresholding/sim/thresh_gen.sv
new file mode 100644
index 0000000000..a8a18be691
--- /dev/null
+++ b/finn-rtllib/thresholding/sim/thresh_gen.sv
@@ -0,0 +1,45 @@
+module thresh_gen;
+ localparam int unsigned K = 9;
+ localparam int unsigned N = 4;
+ localparam int unsigned C = 6;
+
+ typedef logic [K-1:0] thresh_t;
+ localparam thresh_t THRESHOLDS[C][2**N-1] = '{
+ '{ 'h00, 'h01, 'h02, 'h03, 'h04, 'h05, 'h06, 'h07, 'h08, 'h09, 'h0a, 'h0b, 'h0c, 'h0d, 'h0e },
+ '{ 'h10, 'h11, 'h12, 'h13, 'h14, 'h15, 'h16, 'h17, 'h18, 'h19, 'h1a, 'h1b, 'h1c, 'h1d, 'h1e },
+ '{ 'h20, 'h21, 'h22, 'h23, 'h24, 'h25, 'h26, 'h27, 'h28, 'h29, 'h2a, 'h2b, 'h2c, 'h2d, 'h2e },
+ '{ 'h30, 'h31, 'h32, 'h33, 'h34, 'h35, 'h36, 'h37, 'h38, 'h39, 'h3a, 'h3b, 'h3c, 'h3d, 'h3e },
+ '{ 'h40, 'h41, 'h42, 'h43, 'h44, 'h45, 'h46, 'h47, 'h48, 'h49, 'h4a, 'h4b, 'h4c, 'h4d, 'h4e },
+ '{ 'h50, 'h51, 'h52, 'h53, 'h54, 'h55, 'h56, 'h57, 'h58, 'h59, 'h5a, 'h5b, 'h5c, 'h5d, 'h5e }
+ };
+ localparam THRESHOLDS_PATH = ".";
+
+ localparam int unsigned PE = 2;
+ localparam int unsigned CF = C/PE;
+
+ for(genvar stage = 0; stage < N; stage++) begin
+ localparam int unsigned SN = N-1-stage;
+ for(genvar pe = 0; pe < PE; pe++) begin
+ initial begin
+ automatic string file = $sformatf("%s/threshs_%0d_%0d.dat", THRESHOLDS_PATH, pe, stage);
+
+ automatic thresh_t threshs[CF * 2**stage];
+ for(int unsigned c = 0; c < CF; c++) begin
+ for(int unsigned i = 0; i < 2**stage; i++) begin
+ threshs[(c << stage) + i] = THRESHOLDS[c*PE + pe][(i<<(N-stage)) + 2**SN-1];
+ end
+ end
+
+ $writememh(file, threshs);
+ end
+ end
+ end
+
+ // Quit after running all initializers
+ initial begin
+ #1ns;
+ $display("Generation done.");
+ $finish;
+ end
+
+endmodule : thresh_gen
diff --git a/finn-rtllib/thresholding/sim/thresholding.tcl b/finn-rtllib/thresholding/sim/thresholding.tcl
new file mode 100644
index 0000000000..82dc59deb1
--- /dev/null
+++ b/finn-rtllib/thresholding/sim/thresholding.tcl
@@ -0,0 +1,17 @@
+create_project -force thresholding thresholding.vivado -part xcvc1902-vsva2197-2MP-e-S
+set_property board_part xilinx.com:vck190:part0:2.2 [current_project]
+
+read_verilog hdl/axilite_if.v
+read_verilog -sv { hdl/thresholding.sv hdl/thresholding_axi.sv }
+
+set simset [current_fileset -simset]
+set_property -name xsim.simulate.log_all_signals -value true -objects $simset
+set_property -name xsim.simulate.runtime -value all -objects $simset
+add_files -fileset $simset { sim/thresholding_tb.sv sim/thresholding_axi_tb.sv }
+
+foreach top { thresholding_tb thresholding_axi_tb } {
+ set_property top $top $simset
+
+ launch_simulation
+ close_sim
+}
diff --git a/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv b/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv
new file mode 100644
index 0000000000..918f539d15
--- /dev/null
+++ b/finn-rtllib/thresholding/sim/thresholding_axi_tb.sv
@@ -0,0 +1,314 @@
+/******************************************************************************
+ * Copyright (C) 2022, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * @brief Testbench for thresholding_axi.
+ * @author Monica Chiosa
+ *
+ */
+
+module thresholding_axi_tb #(
+ int unsigned N = 4, // output precision
+ int unsigned C = 6, // number of channels
+ int unsigned PE = 2,
+ real M0 = 7.3, // slope of the uniform thresholding line
+ real B0 = 3.1, // offset of the uniform thresholding line
+ bit THROTTLED = 1,
+
+ localparam int unsigned CF = C/PE, // Channel Fold
+ localparam int unsigned ADDR_BITS = $clog2(CF) + $clog2(PE) + N + 2
+);
+
+ //-----------------------------------------------------------------------
+ // Design Geometry
+
+ // For each channel = [0,channel):
+ // M_channel = M0 + CX*channel
+ // B_channel = B0 + CX*channel
+ // Input/threshold precision computed according with the maximum posible value
+ localparam real CX = 1.375;
+ localparam int unsigned K = $clog2((2**N-1)*(M0+C*CX) + (B0+C*CX)); // unused sign + magnitude
+ localparam int unsigned C_BITS = C < 2? 1 : $clog2(C);
+
+ localparam int unsigned MST_STRM_WROUNDS = 503;
+
+ typedef int unsigned threshs_t[C][2**N-1];
+ function threshs_t init_thresholds();
+ automatic threshs_t res;
+ for(int unsigned c = 0; c < C; c++) begin
+ automatic real m = M0 + c*CX;
+ automatic real b = B0 + c*CX;
+ foreach(res[c][i]) begin
+ res[c][i] = int'($ceil(m*i + b));
+ end
+ end
+ return res;
+ endfunction : init_thresholds
+ localparam threshs_t THRESHS = init_thresholds();
+
+ //-----------------------------------------------------------------------
+ // Clock and Reset Control
+ logic clk = 0;
+ always #5ns clk = !clk;
+ logic rst = 1;
+ initial begin
+ #10ns;
+ @(posedge clk);
+ rst <= 0;
+ end
+
+ //-----------------------------------------------------------------------
+ // DUT
+ logic s_axilite_AWVALID;
+ uwire s_axilite_AWREADY;
+ logic [ADDR_BITS-1:0] s_axilite_AWADDR; // lowest 2 bits (byte selectors) are ignored
+ logic s_axilite_WVALID;
+ uwire s_axilite_WREADY;
+ logic [ 31:0] s_axilite_WDATA;
+ uwire s_axilite_BVALID;
+ logic s_axilite_BREADY;
+ uwire [ 1:0] s_axilite_BRESP;
+ logic s_axilite_ARVALID;
+ uwire s_axilite_ARREADY;
+ logic [ADDR_BITS-1:0] s_axilite_ARADDR;
+ uwire s_axilite_RVALID;
+ uwire s_axilite_RREADY = 1;
+ uwire [ 31:0] s_axilite_RDATA;
+ uwire [ 1:0] s_axilite_RRESP;
+
+ uwire irdy;
+ logic ivld;
+ logic [PE-1:0][K-1:0] idat;
+
+ logic ordy = 0;
+ uwire ovld;
+ uwire [PE-1:0][N-1:0] odat;
+
+ thresholding_axi #(.N(N), .K(K), .C(C), .PE(PE), .SIGNED(0), .USE_AXILITE(1)) dut (
+ .ap_clk(clk), .ap_rst_n(!rst),
+
+ // Configuration
+ .s_axilite_AWVALID, .s_axilite_AWREADY, .s_axilite_AWADDR,
+ .s_axilite_WVALID, .s_axilite_WREADY, .s_axilite_WDATA, .s_axilite_WSTRB('1),
+ .s_axilite_BVALID, .s_axilite_BREADY, .s_axilite_BRESP,
+ .s_axilite_ARVALID, .s_axilite_ARREADY, .s_axilite_ARADDR,
+ .s_axilite_RVALID, .s_axilite_RREADY, .s_axilite_RDATA, .s_axilite_RRESP,
+
+ // Stream Processing
+ .s_axis_tready(irdy), .s_axis_tvalid(ivld), .s_axis_tdata(idat),
+ .m_axis_tready(ordy), .m_axis_tvalid(ovld), .m_axis_tdata(odat)
+ );
+
+ //-----------------------------------------------------------------------
+ // Input Stimuli
+ typedef logic [PE-1:0][K-1:0] input_t;
+ typedef logic [$clog2(CF)+$clog2(PE)+N-1:0] addr_t;
+ input_t QW[$]; // Input Feed Tracing
+ addr_t QC[$];
+
+ int unsigned error_cnt = 0;
+ bit done = 0;
+ initial begin
+ // Report testbench details
+ $display("Testbench - tresholding K=%0d -> N=%0d", K, N);
+ for(int unsigned c = 0; c < C; c++) begin
+ $write("Channel #%0d: Thresholds = {", c);
+ for(int unsigned i = 0; i < 2**N-1; i++) $write(" %0d", THRESHS[c][i]);
+ $display(" }");
+ end
+
+ // Config
+ s_axilite_AWVALID = 0;
+ s_axilite_AWADDR = 'x;
+ s_axilite_WVALID = 0;
+ s_axilite_WDATA = 'x;
+ s_axilite_BREADY = 0;
+ s_axilite_ARVALID = 0;
+ s_axilite_ARADDR = 'x;
+
+ // Stream Input
+ ivld = 0;
+ idat = 'x;
+
+ @(posedge clk iff !rst);
+
+ // Threshold Configuration
+ for(int unsigned c = 0; c < C; c+=PE) begin
+ automatic addr_t addr = 0;
+ if(CF > 1) addr[N+$clog2(PE)+:$clog2(CF)] = c/PE;
+ for(int unsigned pe = 0; pe < PE; pe++) begin
+ if(PE > 1) addr[N+:$clog2(PE)] = pe;
+ for(int unsigned t = 0; t < 2**N-1; t++) begin
+ addr[0+:N] = t;
+ fork
+ begin
+ s_axilite_AWVALID <= 1;
+ s_axilite_AWADDR <= { addr, 2'b00 };
+ @(posedge clk iff s_axilite_AWREADY);
+ s_axilite_AWVALID <= 0;
+ s_axilite_AWADDR <= 'x;
+ end
+ begin
+ s_axilite_WVALID <= 1;
+ s_axilite_WDATA <= THRESHS[c+pe][t];
+ @(posedge clk iff s_axilite_WREADY);
+ s_axilite_WVALID <= 0;
+ s_axilite_WDATA <= 'x;
+ end
+ begin
+ s_axilite_BREADY <= 1;
+ @(posedge clk iff s_axilite_BVALID);
+ assert(s_axilite_BRESP == '0) else begin
+ $error("Error on parameter write.");
+ $stop;
+ end
+ s_axilite_BREADY <= 0;
+ end
+ join
+ end
+ end
+ end
+
+ fork
+ // Intermittent configuration readback
+ while(!done) begin
+ if(($urandom()%37) != 0) begin
+ s_axilite_ARVALID <= 0;
+ s_axilite_ARADDR <= 'x;
+ @(posedge clk);
+ end
+ else begin
+ automatic addr_t addr = $urandom()%(N-1);
+ if(PE > 1) addr[N+:$clog2(PE)] = $urandom()%PE;
+ if(CF > 1) addr[N+$clog2(PE)+:$clog2(CF)] = $urandom()%CF;
+
+ s_axilite_ARVALID <= 1;
+ s_axilite_ARADDR <= { addr, 2'b00 };
+ @(posedge clk iff s_axilite_ARREADY);
+
+ QC.push_back(addr);
+ end
+ end
+
+ // AXI4Stream MST Writes input values
+ repeat(MST_STRM_WROUNDS) begin
+ automatic input_t dat;
+
+ while(THROTTLED && ($urandom()%7 == 0)) @(posedge clk);
+
+ std::randomize(dat);
+ ivld <= 1;
+ idat <= dat;
+ @(posedge clk iff irdy);
+ ivld <= 0;
+ idat <= 'x;
+ QW.push_back(dat);
+ end
+ join_any
+ done <= 1;
+ repeat(N+6) @(posedge clk);
+
+ assert(QW.size() == 0) else begin
+ $error("Missing %0d outputs.", QW.size());
+ $stop;
+ end
+ assert(QC.size() == 0) else begin
+ $error("Missing %0d readback replies.", QC.size());
+ $stop;
+ end
+
+ $display("Test completed: %0d errors in %0d tests.", error_cnt, MST_STRM_WROUNDS);
+ $display("=========================================");
+ $finish;
+ end
+
+ // Output Checker -------------------------------------------------------
+
+ // Configuration Readback
+ always_ff @(posedge clk iff s_axilite_RVALID) begin
+ assert(s_axilite_RRESP == '0) else begin
+ $error("Read back error.");
+ $stop;
+ end
+ assert(QC.size()) begin
+ automatic addr_t addr = QC.pop_front();
+ automatic int unsigned cnl =
+ (CF == 1? 0 : addr[N+$clog2(PE)+:$clog2(CF)] * PE) +
+ (PE == 1? 0 : addr[N+:$clog2(PE)]);
+ automatic logic [K-1:0] exp = THRESHS[cnl][addr[0+:N]];
+ assert(s_axilite_RDATA == exp) else begin
+ $error("Readback mismatch on #%0d.%0d: %0d instead of %0d", cnl, addr[0+:N], s_axilite_RDATA, exp);
+ $stop;
+ end
+ end
+ else begin
+ $error("Spurious readback output.");
+ $stop;
+ end
+ end
+
+ // Stream Output
+ int unsigned OCnl = 0;
+ always @(posedge clk) begin
+ if(rst) begin
+ OCnl <= 0;
+ ordy <= 1'b0;
+ end
+ else begin
+ if(!ordy || ovld) ordy <= ($urandom()%5 != 0) || !THROTTLED;
+
+ if(ordy && ovld) begin
+ assert(QW.size()) begin
+ automatic input_t x = QW.pop_front();
+
+ for(int unsigned pe = 0; pe < PE; pe++) begin
+ automatic int unsigned cnl = OCnl + pe;
+
+ $display("Mapped CNL=%0d DAT=%3d -> #%2d", cnl, x[pe], odat[pe]);
+ assert(
+ ((odat[pe] == 0) || (THRESHS[cnl][odat[pe]-1] <= x[pe])) &&
+ ((odat[pe] == 2**N-1) || (x[pe] < THRESHS[cnl][odat[pe]]))
+ ) else begin
+ $error("Output error on presumed input CNL=%0d DAT=0x%0x -> #%0d", cnl, x[pe], odat[pe]);
+ error_cnt++;
+ $stop;
+ end
+ end
+ end
+ else begin
+ $error("Spurious output.");
+ $stop;
+ end
+
+ OCnl <= (OCnl + PE)%C;
+ end
+ end
+ end
+
+endmodule: thresholding_axi_tb
diff --git a/finn-rtllib/thresholding/sim/thresholding_tb.sv b/finn-rtllib/thresholding/sim/thresholding_tb.sv
new file mode 100644
index 0000000000..e42145f10e
--- /dev/null
+++ b/finn-rtllib/thresholding/sim/thresholding_tb.sv
@@ -0,0 +1,274 @@
+/******************************************************************************
+ * Copyright (C) 2022, Advanced Micro Devices, Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+ * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+ * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * @brief Testbench for thresholding_axi.
+ * @author Monica Chiosa
+ *
+ */
+
+module thresholding_tb #(
+ int unsigned K = 10, // input precision
+ int unsigned N = 4, // output precision
+ int unsigned C = 6, // number of channels
+ int unsigned PE = 2,
+
+ localparam int unsigned CF = C/PE // Channel Fold
+);
+ localparam bit DEEP_PIPELINE = 1;
+
+ localparam int unsigned MST_STRM_WROUNDS = 507;
+ localparam bit THROTTLED = 1;
+
+ //-----------------------------------------------------------------------
+ // Clock and Reset Control
+ logic clk = 0;
+ always #5ns clk = !clk;
+ logic rst = 1;
+ initial begin
+ #10ns;
+ @(posedge clk);
+ rst <= 0;
+ end
+
+ //-----------------------------------------------------------------------
+ // Parallel Instances differing in Data Type
+ typedef logic [K -1:0] val_t;
+ typedef val_t threshs_t[C][2**N-1];
+ typedef val_t [PE-1:0] input_t;
+ typedef logic [$clog2(CF)+$clog2(PE)+N-1:0] addr_t;
+ logic [0:2] term = '0;
+ always_comb begin
+ if(&term) $finish;
+ end
+ for(genvar i = 0; i < 3; i++) begin : genTypes
+ localparam bit SIGNED = i>0;
+ localparam bit FPARG = i>1;
+
+ //- DUT -------------------------
+ logic cfg_en;
+ logic cfg_we;
+ logic [$clog2(C)+N-1:0] cfg_a;
+ logic [K-1:0] cfg_d;
+ uwire cfg_rack;
+ uwire [K-1:0] cfg_q;
+
+ uwire irdy;
+ logic ivld;
+ logic [PE-1:0][K-1:0] idat;
+
+ logic ordy = 0;
+ uwire ovld;
+ uwire [PE-1:0][N-1:0] odat;
+
+ thresholding #(.N(N), .K(K), .C(C), .PE(PE), .SIGNED(SIGNED), .FPARG(FPARG), .USE_CONFIG(1), .DEEP_PIPELINE(DEEP_PIPELINE)) dut (
+ .clk, .rst,
+
+ // Configuration
+ .cfg_en, .cfg_we, .cfg_a, .cfg_d,
+ .cfg_rack, .cfg_q,
+
+ // Stream Processing
+ .irdy, .ivld, .idat,
+ .ordy, .ovld, .odat
+ );
+
+ //- Stimulus Driver -------------
+ threshs_t THRESHS;
+ function val_t sigord(input val_t x);
+ automatic val_t res = x;
+ if(SIGNED) begin
+ if(FPARG && x[K-1]) res[K-2:0] = ~x[K-2:0];
+ res[K-1] = !x[K-1];
+ end
+ return res;
+ endfunction : sigord
+
+ input_t QW[$]; // Input tracing
+ addr_t QC[$]; // Readback tracking
+ int unsigned error_cnt = 0;
+ bit done = 0;
+ initial begin
+
+ // Generate thresholds
+ std::randomize(THRESHS);
+ foreach(THRESHS[c]) begin
+ val_t row[2**N-1] = THRESHS[c];
+ row.sort with (sigord(item));
+ THRESHS[c] = row;
+ end
+
+ // Report test case details
+ $display("[%0d] Thresholding %s%s%0d -> uint%0d", i, SIGNED? "s" : "u", FPARG? "fp" : "int", K, N);
+ for(int unsigned c = 0; c < C; c++) begin
+ $write("[%0d] Channel #%0d: Thresholds = {", i, c);
+ for(int unsigned i = 0; i < 2**N-1; i++) $write(" %0X", THRESHS[c][i]);
+ $display(" }");
+ end
+
+ // Config
+ cfg_en = 0;
+ cfg_we = 'x;
+ cfg_a = 'x;
+ cfg_d = 'x;
+
+ // Stream Input
+ ivld = 0;
+ idat = 'x;
+
+ @(posedge clk iff !rst);
+
+ // Threshold Configuratin
+ cfg_en <= 1;
+ cfg_we <= 1;
+ for(int unsigned c = 0; c < C; c+=PE) begin
+ if(CF > 1) cfg_a[N+$clog2(PE)+:$clog2(CF)] <= c/PE;
+ for(int unsigned pe = 0; pe < PE; pe++) begin
+ if(PE > 1) cfg_a[N+:$clog2(PE)] = pe;
+ for(int unsigned t = 0; t < 2**N-1; t++) begin
+ cfg_a[0+:N] <= t;
+ cfg_d <= THRESHS[c+pe][t];
+ @(posedge clk);
+ end
+ end
+ end
+ cfg_d <= 'x;
+
+ fork
+ // Intermittent configuration readback
+ while(!done) begin
+ cfg_en <= 0;
+ cfg_we <= 'x;
+ cfg_a <= 'x;
+ @(posedge clk);
+ if(($urandom()%41) == 0) begin
+ automatic addr_t addr = $urandom()%(N-1);
+ if(PE > 1) addr[N+:$clog2(PE)] = $urandom()%PE;
+ if(CF > 1) addr[N+$clog2(PE)+:$clog2(CF)] = $urandom()%CF;
+
+ cfg_en <= 1;
+ cfg_we <= 0;
+ cfg_a <= addr;
+ @(posedge clk);
+ QC.push_back(addr);
+ end
+ end
+
+ // AXI4Stream MST Writes input values
+ repeat(MST_STRM_WROUNDS) begin
+ automatic input_t dat;
+
+ while(THROTTLED && ($urandom()%7 == 0)) @(posedge clk);
+
+ std::randomize(dat);
+ ivld <= 1;
+ idat <= dat;
+ @(posedge clk iff irdy);
+ ivld <= 0;
+ idat <= 'x;
+ QW.push_back(dat);
+ end
+ join_any
+ done <= 1;
+ repeat((DEEP_PIPELINE+1)*N+6) @(posedge clk);
+
+ assert(QW.size() == 0) else begin
+ $error("[%0d] Missing %0d outputs.", i, QW.size());
+ $stop;
+ end
+ assert(QC.size() == 0) else begin
+ $error("[%0d] Missing %0d readback replies.", i, QC.size());
+ $stop;
+ end
+
+ $display("[%0d] Test completed: %0d errors in %0d tests.", i, error_cnt, MST_STRM_WROUNDS);
+ $display("=============================================");
+ term[i] <= 1;
+ end
+
+ //- Readback Checker --------------
+ always_ff @(posedge clk iff cfg_rack) begin
+ assert(QC.size()) begin
+ automatic addr_t addr = QC.pop_front();
+ automatic int unsigned cnl =
+ (CF == 1? 0 : addr[N+$clog2(PE)+:$clog2(CF)] * PE) +
+ (PE == 1? 0 : addr[N+:$clog2(PE)]);
+ automatic logic [K-1:0] exp = THRESHS[cnl][addr[0+:N]];
+ assert(cfg_q == exp) else begin
+ $error("[%0d] Readback mismatch on #%0d.%0d: %0d instead of %0d", i, cnl, addr[0+:N], cfg_q, exp);
+ $stop;
+ end
+ end
+ else begin
+ $error("[%0d] Spurious readback output.", i);
+ $stop;
+ end
+ end
+
+ // Output Checker
+ int unsigned OCnl = 0;
+ always @(posedge clk) begin
+ if(rst) begin
+ OCnl <= 0;
+ ordy <= 1'b0;
+ end
+ else begin
+ if(!ordy || ovld) ordy <= ($urandom()%5 != 0) || !THROTTLED;
+
+ if(ordy && ovld) begin
+ assert(QW.size()) begin
+ automatic input_t x = QW.pop_front();
+
+ for(int unsigned pe = 0; pe < PE; pe++) begin
+ automatic int unsigned cnl = OCnl + pe;
+
+ $display("[%0d] Mapped CNL=%0d DAT=%3x -> #%2d", i, cnl, x[pe], odat[pe]);
+ assert(
+ ((odat[pe] == 0) || (sigord(THRESHS[cnl][odat[pe]-1]) <= sigord(x[pe]))) &&
+ ((odat[pe] == 2**N-1) || (sigord(x[pe]) < sigord(THRESHS[cnl][odat[pe]])))
+ ) else begin
+ $error("[%0d] Output error on presumed input CNL=%0d DAT=0x%0x -> #%0d", i, cnl, x[pe], odat[pe]);
+ error_cnt++;
+ $stop;
+ end
+ end
+ end
+ else begin
+ $error("[%0d] Spurious output.", i);
+ $stop;
+ end
+
+ OCnl <= (OCnl + PE)%C;
+ end
+ end
+ end
+
+ end : genTypes
+
+endmodule: thresholding_tb
diff --git a/finn-rtllib/thresholding/xgui/thresholding_axi_v1_0.tcl b/finn-rtllib/thresholding/xgui/thresholding_axi_v1_0.tcl
new file mode 100644
index 0000000000..338304fa40
--- /dev/null
+++ b/finn-rtllib/thresholding/xgui/thresholding_axi_v1_0.tcl
@@ -0,0 +1,187 @@
+
+# Loading additional proc with user specified bodies to compute parameter values.
+source [file join [file dirname [file dirname [info script]]] gui/thresholding_axi_v1_0.gtcl]
+
+# Definitional proc to organize widgets for parameters.
+proc init_gui { IPINST } {
+ ipgui::add_param $IPINST -name "Component_Name"
+ #Adding Page
+ set Page_0 [ipgui::add_page $IPINST -name "Page 0"]
+ ipgui::add_param $IPINST -name "ADDR_BITS" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "BIAS" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "C" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "CF" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "FPARG" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "K" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "N" -parent ${Page_0}
+ ipgui::add_param $IPINST -name "O_BITS" -parent ${Page_0}
+ set PE [ipgui::add_param $IPINST -name "PE" -parent ${Page_0}]
+ set_property tooltip {PE Count} ${PE}
+ ipgui::add_param $IPINST -name "SIGNED" -parent ${Page_0}
+
+
+}
+
+proc update_PARAM_VALUE.ADDR_BITS { PARAM_VALUE.ADDR_BITS PARAM_VALUE.C PARAM_VALUE.PE PARAM_VALUE.N } {
+ # Procedure called to update ADDR_BITS when any of the dependent parameters in the arguments change
+
+ set ADDR_BITS ${PARAM_VALUE.ADDR_BITS}
+ set C ${PARAM_VALUE.C}
+ set PE ${PARAM_VALUE.PE}
+ set N ${PARAM_VALUE.N}
+ set values(C) [get_property value $C]
+ set values(PE) [get_property value $PE]
+ set values(N) [get_property value $N]
+ set_property value [gen_USERPARAMETER_ADDR_BITS_VALUE $values(C) $values(PE) $values(N)] $ADDR_BITS
+}
+
+proc validate_PARAM_VALUE.ADDR_BITS { PARAM_VALUE.ADDR_BITS } {
+ # Procedure called to validate ADDR_BITS
+ return true
+}
+
+proc update_PARAM_VALUE.CF { PARAM_VALUE.CF PARAM_VALUE.C PARAM_VALUE.PE } {
+ # Procedure called to update CF when any of the dependent parameters in the arguments change
+
+ set CF ${PARAM_VALUE.CF}
+ set C ${PARAM_VALUE.C}
+ set PE ${PARAM_VALUE.PE}
+ set values(C) [get_property value $C]
+ set values(PE) [get_property value $PE]
+ set_property value [gen_USERPARAMETER_CF_VALUE $values(C) $values(PE)] $CF
+}
+
+proc validate_PARAM_VALUE.CF { PARAM_VALUE.CF } {
+ # Procedure called to validate CF
+ return true
+}
+
+proc update_PARAM_VALUE.O_BITS { PARAM_VALUE.O_BITS PARAM_VALUE.BIAS PARAM_VALUE.N } {
+ # Procedure called to update O_BITS when any of the dependent parameters in the arguments change
+
+ set O_BITS ${PARAM_VALUE.O_BITS}
+ set BIAS ${PARAM_VALUE.BIAS}
+ set N ${PARAM_VALUE.N}
+ set values(BIAS) [get_property value $BIAS]
+ set values(N) [get_property value $N]
+ set_property value [gen_USERPARAMETER_O_BITS_VALUE $values(BIAS) $values(N)] $O_BITS
+}
+
+proc validate_PARAM_VALUE.O_BITS { PARAM_VALUE.O_BITS } {
+ # Procedure called to validate O_BITS
+ return true
+}
+
+proc update_PARAM_VALUE.BIAS { PARAM_VALUE.BIAS } {
+ # Procedure called to update BIAS when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.BIAS { PARAM_VALUE.BIAS } {
+ # Procedure called to validate BIAS
+ return true
+}
+
+proc update_PARAM_VALUE.C { PARAM_VALUE.C } {
+ # Procedure called to update C when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.C { PARAM_VALUE.C } {
+ # Procedure called to validate C
+ return true
+}
+
+proc update_PARAM_VALUE.FPARG { PARAM_VALUE.FPARG } {
+ # Procedure called to update FPARG when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.FPARG { PARAM_VALUE.FPARG } {
+ # Procedure called to validate FPARG
+ return true
+}
+
+proc update_PARAM_VALUE.K { PARAM_VALUE.K } {
+ # Procedure called to update K when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.K { PARAM_VALUE.K } {
+ # Procedure called to validate K
+ return true
+}
+
+proc update_PARAM_VALUE.N { PARAM_VALUE.N } {
+ # Procedure called to update N when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.N { PARAM_VALUE.N } {
+ # Procedure called to validate N
+ return true
+}
+
+proc update_PARAM_VALUE.PE { PARAM_VALUE.PE } {
+ # Procedure called to update PE when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.PE { PARAM_VALUE.PE } {
+ # Procedure called to validate PE
+ return true
+}
+
+proc update_PARAM_VALUE.SIGNED { PARAM_VALUE.SIGNED } {
+ # Procedure called to update SIGNED when any of the dependent parameters in the arguments change
+}
+
+proc validate_PARAM_VALUE.SIGNED { PARAM_VALUE.SIGNED } {
+ # Procedure called to validate SIGNED
+ return true
+}
+
+
+proc update_MODELPARAM_VALUE.N { MODELPARAM_VALUE.N PARAM_VALUE.N } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.N}] ${MODELPARAM_VALUE.N}
+}
+
+proc update_MODELPARAM_VALUE.K { MODELPARAM_VALUE.K PARAM_VALUE.K } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.K}] ${MODELPARAM_VALUE.K}
+}
+
+proc update_MODELPARAM_VALUE.C { MODELPARAM_VALUE.C PARAM_VALUE.C } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.C}] ${MODELPARAM_VALUE.C}
+}
+
+proc update_MODELPARAM_VALUE.PE { MODELPARAM_VALUE.PE PARAM_VALUE.PE } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.PE}] ${MODELPARAM_VALUE.PE}
+}
+
+proc update_MODELPARAM_VALUE.SIGNED { MODELPARAM_VALUE.SIGNED PARAM_VALUE.SIGNED } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.SIGNED}] ${MODELPARAM_VALUE.SIGNED}
+}
+
+proc update_MODELPARAM_VALUE.FPARG { MODELPARAM_VALUE.FPARG PARAM_VALUE.FPARG } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.FPARG}] ${MODELPARAM_VALUE.FPARG}
+}
+
+proc update_MODELPARAM_VALUE.BIAS { MODELPARAM_VALUE.BIAS PARAM_VALUE.BIAS } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.BIAS}] ${MODELPARAM_VALUE.BIAS}
+}
+
+proc update_MODELPARAM_VALUE.CF { MODELPARAM_VALUE.CF PARAM_VALUE.CF } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.CF}] ${MODELPARAM_VALUE.CF}
+}
+
+proc update_MODELPARAM_VALUE.ADDR_BITS { MODELPARAM_VALUE.ADDR_BITS PARAM_VALUE.ADDR_BITS } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.ADDR_BITS}] ${MODELPARAM_VALUE.ADDR_BITS}
+}
+
+proc update_MODELPARAM_VALUE.O_BITS { MODELPARAM_VALUE.O_BITS PARAM_VALUE.O_BITS } {
+ # Procedure called to set VHDL generic/Verilog parameter value(s) based on TCL parameter value
+ set_property value [get_property value ${PARAM_VALUE.O_BITS}] ${MODELPARAM_VALUE.O_BITS}
+}
diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index 56d4230a3a..425b9bf4f6 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -59,6 +59,9 @@
from finn.custom_op.fpgadataflow.streamingfifo import StreamingFIFO
from finn.custom_op.fpgadataflow.streamingmaxpool_batch import StreamingMaxPool_Batch
from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
+from finn.custom_op.fpgadataflow.thresholding_binary_search import (
+ Thresholding_Binary_Search,
+)
from finn.custom_op.fpgadataflow.tlastmarker import TLastMarker
from finn.custom_op.fpgadataflow.upsampler import UpsampleNearestNeighbour_Batch
from finn.custom_op.fpgadataflow.vectorvectoractivation import VectorVectorActivation
@@ -80,6 +83,7 @@
custom_op["Pool_Batch"] = Pool_Batch
custom_op["FMPadding_Batch"] = FMPadding_Batch
custom_op["Thresholding_Batch"] = Thresholding_Batch
+custom_op["Thresholding_Binary_Search"] = Thresholding_Binary_Search
custom_op["AddStreams_Batch"] = AddStreams_Batch
custom_op["LabelSelect_Batch"] = LabelSelect_Batch
custom_op["DuplicateStreams_Batch"] = DuplicateStreams_Batch
diff --git a/src/finn/custom_op/fpgadataflow/thresholding_binary_search.py b/src/finn/custom_op/fpgadataflow/thresholding_binary_search.py
new file mode 100755
index 0000000000..d02b778823
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/thresholding_binary_search.py
@@ -0,0 +1,579 @@
+# Copyright (C) 2022, Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+import os
+import warnings
+from qonnx.core.datatype import DataType
+from qonnx.util.basic import interleave_matrix_outer_dim_from_partitions
+
+from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
+from finn.util.basic import find_next_power_of_2, get_rtlsim_trace_depth, make_build_dir
+from finn.util.data_packing import (
+ npy_to_rtlsim_input,
+ pack_innermost_dim_as_hex_string,
+ rtlsim_output_to_npy,
+)
+
+try:
+ from pyverilator import PyVerilator
+except ModuleNotFoundError:
+ PyVerilator = None
+
+"""@package thresholding_binary_search
+- ONNX i/o tensor shape assumptions for Thresholding:
+- input 0 is the input tensor, shape (..., NumChannels)
+- input 1 is the threshold tensor, shape (NumChannels, n_thres)
+- output 0 is the output tensor, shape (..., NumChannels) - same as input
+- the '...' here can be any shape (representing groups of vectors)
+
+This module creates an RTL IP, HLS is not supported. See 'thresholding_batch'
+for a HLS equivalent.
+"""
+
+
+class Thresholding_Binary_Search(HLSCustomOp):
+ """Class that corresponds to finn-rtllib 'thresholding' function."""
+
+ def __init__(self, onnx_node, **kwargs):
+ super().__init__(onnx_node, **kwargs)
+
+ def get_nodeattr_types(self):
+ my_attrs = {
+ # parallelization; channels thresholded per cycle
+ "PE": ("i", True, 0),
+ # number of channels (each may have different thresholds)
+ "NumChannels": ("i", True, 0),
+ # number of steps in thresholding function. Used only in decoupled mode
+ "numSteps": ("i", True, 1),
+ # FINN DataTypes for inputs, outputs
+ "inputDataType": ("s", True, ""),
+ "weightDataType": ("s", True, ""),
+ "outputDataType": ("s", True, ""),
+ # number of input vectors, examples:
+ # [1] is a single vector (like a FC layer with batch=1)
+ # [4] is four vectors (like a FC layer with batch=4)
+ # [1, 4, 4] is four * four vectors (like a conv layer with batch=1)
+ "numInputVectors": ("ints", False, [1]),
+ # name of the top module in verilog template. Used by PyVerilator
+ # and IPI generation
+ "gen_top_module": ("s", False, ""),
+ # bias to be applied to outputs of the node
+ "activation_bias": ("i", False, 0),
+ }
+ my_attrs.update(super().get_nodeattr_types())
+ return my_attrs
+
+ def calc_tmem(self):
+ """Calculates and returns TMEM."""
+ num_channels = self.get_nodeattr("NumChannels")
+ pe = self.get_nodeattr("PE")
+ return num_channels // pe
+
+ def make_shape_compatible_op(self, model):
+ oshape = self.get_normal_output_shape()
+ return super().make_const_shape_op(oshape)
+
+ def infer_node_datatype(self, model):
+ """Used for FINN DataType inference: set the output tensors' datatypes
+ accordingly for this node"""
+ node = self.onnx_node
+ idt = model.get_tensor_datatype(node.input[0])
+ if idt != self.get_input_datatype():
+ warn_str = "inputDataType changing for %s: %s -> %s " % (
+ node.name,
+ str(self.get_input_datatype().name),
+ str(idt.name),
+ )
+ warnings.warn(warn_str)
+ self.set_nodeattr("inputDataType", idt.name)
+ # set output datatype from property
+ odt = self.get_output_datatype()
+ model.set_tensor_datatype(node.output[0], odt)
+
+ def verify_node(self):
+ """Required by the FINN nalysis module. Checks if custom ops in graph
+ are correctly built, with all attributes and inputs."""
+ return []
+
+ def bram_estimation(self):
+ return 0
+
+ def lut_estimation(self):
+ return 0
+
+ def get_input_datatype(self, ind=0):
+ return DataType[self.get_nodeattr("inputDataType")]
+
+ def get_output_datatype(self, ind=0):
+ return DataType[self.get_nodeattr("outputDataType")]
+
+ def get_weight_datatype(self):
+ """The term 'weights' and 'thresholds' are used interchangably in this class."""
+ return DataType[self.get_nodeattr("weightDataType")]
+
+ def minimize_accumulator_width(self, model):
+ "Minimize threshold width ('accumulator width' here due to convention)"
+ thresholds = model.get_initializer(self.onnx_node.input[1])
+ threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+ min_threshold = thresholds.min()
+ max_threshold = thresholds.max()
+ min_input = self.get_input_datatype().min()
+ max_input = self.get_input_datatype().max()
+ # get range required by threshold values
+ tdt_min = min(min_input, min_threshold)
+ tdt_max = max(max_input, max_threshold)
+ if tdt_min < 0:
+ if abs(tdt_min) > tdt_max:
+ tdt = DataType.get_smallest_possible(tdt_min)
+ else:
+ tdt = DataType.get_smallest_possible(-tdt_max - 1)
+ else:
+ tdt = DataType.get_smallest_possible(tdt_max)
+ assert np.vectorize(tdt.allowed)(
+ threshold_tensor
+ ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
+ self.set_nodeattr("weightDataType", tdt.name)
+ return DataType[self.get_nodeattr("weightDataType")]
+
+ def get_instream_width(self, ind=0):
+ i_bits = self.get_input_datatype().bitwidth()
+ return i_bits * self.get_nodeattr("PE")
+
+ def get_outstream_width(self, ind=0):
+ o_bits = self.get_output_datatype().bitwidth()
+ return o_bits * self.get_nodeattr("PE")
+
+ def get_weightstream_width(self):
+ """Returns weight stream width"""
+ pe = self.get_nodeattr("PE")
+ wp = self.get_weight_datatype().bitwidth()
+ n_thres_steps = self.get_nodeattr("numSteps")
+ w_width = pe * wp * n_thres_steps
+ return w_width
+
+ def get_folded_input_shape(self, ind=0):
+ fold = self.calc_tmem()
+ pe = self.get_nodeattr("PE")
+ vecs = list(self.get_nodeattr("numInputVectors"))
+ folded_input_shape = tuple(vecs + [fold, pe])
+ return folded_input_shape
+
+ def get_folded_output_shape(self, ind=0):
+ # same shape as input
+ return self.get_folded_input_shape()
+
+ def get_normal_input_shape(self, ind=0):
+ num_channels = self.get_nodeattr("NumChannels")
+ vecs = list(self.get_nodeattr("numInputVectors"))
+ normal_input_shape = tuple(vecs + [num_channels])
+ return normal_input_shape
+
+ def get_normal_output_shape(self, ind=0):
+ # same shape as input
+ return self.get_normal_input_shape()
+
+ def get_number_output_values(self):
+ return 0
+
+ def get_exp_cycles(self):
+ return 0
+
+ def get_hls_compatible_threshold_tensor(self, orig_thres_matrix):
+ """Convert the original numpy weight matrix orig_weight_matrix into
+ a form suitable for passing to the hlslib call:
+ * ensure MH % PE == 0
+ * for unsigned inputs, ensure thresholds are positive
+ * interleave rows between PEs
+ * reshape into (PE, TMEM, n_thres_steps) and return
+ """
+ mh = self.get_nodeattr("NumChannels")
+ pe = self.get_nodeattr("PE")
+ tmem = mh // pe
+ assert mh % pe == 0, "Requirement NumChannels divisable by PE is violated."
+ assert (
+ orig_thres_matrix.ndim == 2
+ ), """Threshold matrix dimension is
+ not as expected (2)."""
+ n_thres_steps = orig_thres_matrix.shape[1]
+ assert n_thres_steps == self.get_nodeattr(
+ "numSteps"
+ ), "Mismatch in threshold steps"
+ if not self.get_input_datatype().signed():
+ # ensure all thresholds are nonnegative
+ assert (orig_thres_matrix >= 0).all()
+ # ensure all thresholds are integer
+ assert np.equal(
+ np.mod(orig_thres_matrix, 1), 0
+ ).all(), "Need int threshold tensor"
+ ret = orig_thres_matrix
+ # ensure channels = mh , duplicating if necessary
+ if ret.shape[0] == 1:
+ ret = np.tile(ret, (mh, 1))
+ assert (
+ ret.shape[0] == mh
+ ), "Channels of threshold matrix are not as expected (mh)"
+ # distribute rows between PEs
+ ret = interleave_matrix_outer_dim_from_partitions(ret, pe)
+ assert (
+ ret.shape[0] == pe
+ ), """First dimension after distribution of the
+ rows between PEs is not as expected (pe)"""
+ assert (
+ ret.shape[1] == tmem
+ ), """Second dimension after distribution of the
+ rows between PEs is not as expected (tmem)"""
+ assert (
+ ret.shape[2] == n_thres_steps
+ ), """Third dimension after distribution of the
+ rows between PEs is not as expected (n_thres_steps)"""
+ return ret.reshape(1, pe, tmem, n_thres_steps)
+
+ def prepare_codegen_rtl_values(self):
+ """All dictionary values produced in this function are to replace
+ their key value(s) in the RTL template files"""
+ code_gen_dict = {}
+
+ # Identify the module name
+ code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"] = [
+ self.get_verilog_top_module_name() + "_axi_wrapper"
+ ]
+ # Set the top module name - AXI wrapper
+ code_gen_dict["$TOP_MODULE$"] = code_gen_dict["$MODULE_NAME_AXI_WRAPPER$"]
+
+ # Identify the module variables
+ output_data_type = self.get_nodeattr("outputDataType") # output precision
+ input_data_type = self.get_nodeattr(
+ "inputDataType"
+ ) # input/threshold precision
+ num_channels = self.get_nodeattr("NumChannels") # number of channels
+ bias = self.get_nodeattr("activation_bias") # activation bias value
+ pe = self.get_nodeattr("PE")
+
+ code_gen_dict["$N$"] = [
+ str(DataType[output_data_type].bitwidth())
+ ] # output precision - convert bitwidth to string
+ code_gen_dict["$M$"] = [
+ str(DataType[input_data_type].bitwidth())
+ ] # input/threshold precision - convert bitwidth to string
+ code_gen_dict["$C$"] = [str(num_channels)] # number of channels
+ code_gen_dict["$BIAS$"] = [str(bias)] # activation bias value
+ code_gen_dict["$PE$"] = [str(pe)] # requires C = M*PE
+
+ # Is the input datatype signed or unsigned?
+ # The thresholding core needs to know this when comparing weights to inputs
+ if self.get_input_datatype().signed():
+ code_gen_dict["$SIGNED$"] = [str(1)]
+ else:
+ code_gen_dict["$SIGNED$"] = [str(0)]
+
+ return code_gen_dict
+
+ def get_rtl_file_list(self):
+ """Thresholding binary search RTL file list"""
+ return ["thresholding.sv", "thresholding_axi.sv", "thresholding_axi_wrapper.v"]
+
+ def get_rtl_file_paths(self):
+ """Get full path of all RTL files"""
+ rtl_root_dir = os.environ["FINN_ROOT"] + "/finn-rtllib/thresholding/hdl/"
+ rtl_file_list = self.get_rtl_file_list()
+ rtl_file_paths = [rtl_root_dir + file for file in rtl_file_list]
+ return rtl_file_paths
+
+ def get_rtl_template_data(self, path):
+ """Return RTL file contents as a template"""
+ with open(path, "r") as f:
+ template = f.read()
+ return template
+
+ def fill_in_rtl_template_data(self, replace_dict, template_data):
+ """Use attribute values to finn in RTL template placeholders"""
+ template_data_cp = template_data
+ for key in replace_dict:
+ replacement_line = "\n".join(replace_dict[key])
+ template_data_cp = template_data_cp.replace(key, replacement_line)
+ return template_data_cp
+
+ def dump_rtl_data(self, dest_dir, filename, data):
+ """Dump filled-in-template RTL files for future synthesis step"""
+ with open(os.path.join(dest_dir, filename), "w") as f:
+ f.write(data)
+ return
+
+ def generate_hdl(self):
+ """Prepare HDL files from templates for synthesis"""
+ # Generate a dictionary of values to put in RTL template
+ code_gen_dict = self.prepare_codegen_rtl_values()
+
+ # Retrieve the destination directory for the final RTL files
+ code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+
+ for rtl_file_path in self.get_rtl_file_paths():
+ # read in original RTL template file
+ template_data = self.get_rtl_template_data(rtl_file_path)
+ # apply code generation to templates
+ data = self.fill_in_rtl_template_data(code_gen_dict, template_data)
+ # dump filled-in template to destination directory for compilation
+ file_only_path = rtl_file_path.split("/")[-1]
+ self.dump_rtl_data(code_gen_dir, file_only_path, data)
+
+ # Before we return - set the 'gen_top_module' attribute for use later
+ # by PyVerilator and IPI generation
+ self.set_nodeattr("gen_top_module", code_gen_dict["$TOP_MODULE$"][0])
+ return
+
+ def code_generation_ipgen(self, model, fpgapart, clk):
+ self.generate_hdl()
+
+ # set ipgen_path and ip_path so that HLS-Synth transformation
+ # and stich_ip transformation do not complain
+ # i.e. during the HLSSynthIP() transformation
+ code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+ self.set_nodeattr("ipgen_path", code_gen_dir)
+ self.set_nodeattr("ip_path", code_gen_dir)
+ return
+
+ def prepare_rtlsim(self):
+ """Creates a Verilator emulation library for the RTL code generated
+ for this node, sets the rtlsim_so attribute to its path and returns
+ a PyVerilator wrapper around it."""
+
+ if PyVerilator is None:
+ raise ImportError("Installation of PyVerilator is required.")
+
+ code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+ verilog_paths = [code_gen_dir]
+ verilog_files = self.get_rtl_file_list()
+
+ # build the Verilator emulation library
+ sim = PyVerilator.build(
+ verilog_files,
+ build_dir=make_build_dir("pyverilator_" + self.onnx_node.name + "_"),
+ verilog_path=verilog_paths,
+ trace_depth=get_rtlsim_trace_depth(),
+ top_module_name=self.get_nodeattr("gen_top_module"),
+ )
+
+ # save generated lib filename in attribute
+ self.set_nodeattr("rtlsim_so", sim.lib._name)
+ return sim
+
+ def execute_node(self, context, graph):
+ # Perform input checks
+ if self.get_nodeattr("exec_mode") != "rtlsim":
+ raise Exception(
+ "Invalid exec_mode value: {}; exec_mode must be set to '{}'".format(
+ self.get_nodeattr("exec_mode"), "rtlsim"
+ )
+ )
+
+ node = self.onnx_node
+ code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+
+ # create a npy file fore each input of the node (in_ind is input index)
+ in_ind = 0
+ for inputs in node.input:
+ # it is assumed that the first input of the node is the data input
+ # the second input are the weights
+ # the third input are the thresholds
+ if in_ind == 0:
+ assert (
+ str(context[inputs].dtype) == "float32"
+ ), """Input datatype is
+ not float32 as expected."""
+ expected_inp_shape = self.get_folded_input_shape()
+ reshaped_input = context[inputs].reshape(expected_inp_shape)
+
+ if self.get_input_datatype() == DataType["BIPOLAR"]:
+ # store bipolar activations as binary
+ reshaped_input = (reshaped_input + 1) / 2
+ export_idt = DataType["BINARY"]
+ else:
+ export_idt = self.get_input_datatype()
+
+ # make copy before saving the array
+ reshaped_input = reshaped_input.copy()
+ np.save(
+ os.path.join(code_gen_dir, "input_{}.npy".format(in_ind)),
+ reshaped_input,
+ )
+ elif in_ind > 2:
+ raise Exception("Unexpected input found for Thresholding_Binary_Search")
+ in_ind += 1
+
+ # Create a PyVerilator wrapper of the RTLSim .so
+ sim = self.get_rtlsim()
+ nbits = self.get_instream_width()
+ inp = npy_to_rtlsim_input(
+ "{}/input_0.npy".format(code_gen_dir), export_idt, nbits
+ )
+
+ super().reset_rtlsim(sim)
+ super().toggle_clk(sim)
+
+ wnbits = self.get_weightstream_width()
+ export_wdt = self.get_weight_datatype()
+ wei = npy_to_rtlsim_input(
+ "{}/thresholds.npy".format(code_gen_dir), export_wdt, wnbits
+ )
+ num_w_reps = np.prod(self.get_nodeattr("numInputVectors"))
+ io_dict = {
+ "inputs": {"in0": inp, "weights": wei * num_w_reps},
+ "outputs": {"s_axis": []},
+ }
+ self.rtlsim_multi_io(sim, io_dict)
+ output = io_dict["outputs"]["out"]
+
+ # Manage output data
+ odt = self.get_output_datatype()
+ target_bits = odt.bitwidth()
+ packed_bits = self.get_outstream_width()
+ out_npy_path = "{}/output.npy".format(code_gen_dir)
+ out_shape = self.get_folded_output_shape()
+
+ rtlsim_output_to_npy(
+ output, out_npy_path, odt, out_shape, packed_bits, target_bits
+ )
+
+ # load and reshape output
+ output = np.load(out_npy_path)
+ oshape = self.get_normal_output_shape()
+ output = np.asarray([output], dtype=np.float32).reshape(*oshape)
+ context[node.output[0]] = output
+ return
+
+ def code_generation_ipi(self):
+ """Constructs and returns the TCL commands for node instantiation as an RTL
+ block."""
+ cmd = []
+ rtl_file_list = self.get_rtl_file_list()
+ code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
+
+ for rtl_file in rtl_file_list:
+ cmd.append(
+ "add_files -norecurse %s" % (os.path.join(code_gen_dir, rtl_file))
+ )
+
+ # Create an RTL block, not an IP core (-type ip)
+ cmd.append(
+ "create_bd_cell -type module -reference %s %s"
+ % (self.get_nodeattr("gen_top_module"), self.onnx_node.name)
+ )
+
+ return cmd
+
+ def get_verilog_top_module_intf_names(self):
+ """Return a dict of names of input and output interfaces.
+ The keys reflect the protocols each interface implements:
+ 'clk', 'rst', 'm_axis', 's_axis', 'aximm', 'axilite'.
+ Values are lists of tuples (axis, aximm) or names (axilite):
+ 'axis' tuples correspond to the list of node inputs in order,
+ each tuple is (interface_name, interface_width_bits).
+ axilite always assumed to be 32 bits and is not tuple (name only).
+ Each block must have at most one aximm and one axilite."""
+
+ intf_names = super().get_verilog_top_module_intf_names()
+ intf_names["axilite"] = ["s_axilite"]
+ return intf_names
+
+ def get_dynamic_config(self, model, address_stride=1):
+ """Returns a configuration dictionary containing axilite write commands
+ in order to program the thresholds into the RTL core during runtime.
+ The default address stride for the weights is 1 byte."""
+
+ thresholds = model.get_initializer(self.onnx_node.input[1])
+ num_channels, num_weights_per_channel = thresholds.shape
+
+ weight_addr_boundary = find_next_power_of_2(num_weights_per_channel)
+ # Make sure that the next power of 2 (output) is greater than the input
+ assert weight_addr_boundary >= num_weights_per_channel
+
+ config = {}
+ channel_cntr = 0
+ for channel in thresholds:
+ channel_start_addr = channel_cntr * weight_addr_boundary * address_stride
+ weight_cntr = 0
+ addr = 0
+ for weight in channel:
+ key_name = "{}_{}{}_{}{}".format(
+ "axilite", "ch", str(channel_cntr), "w", str(weight_cntr)
+ )
+ config[key_name] = (
+ channel_start_addr + addr,
+ int(
+ str(
+ pack_innermost_dim_as_hex_string(
+ [weight],
+ self.get_weight_datatype(),
+ self.get_weight_datatype().bitwidth(),
+ )
+ ),
+ 0,
+ ),
+ )
+
+ weight_cntr += 1
+ addr += address_stride
+
+ channel_cntr += 1
+
+ return config
+
+ def ipgen_singlenode_code(self):
+ """Normally: Builds the bash script for IP generation."""
+ """This is needed for the HLSSynthIP() transformation.
+ This is an IP, not a HLS node, so therefore provide an empty hook
+ to prevent any HLS synthesis."""
+ pass
+
+ def global_includes(self):
+ pass
+
+ def defines(self, var):
+ pass
+
+ def read_npy_data(self):
+ pass
+
+ def strm_decl(self):
+ pass
+
+ def docompute(self):
+ pass
+
+ def dataoutstrm(self):
+ pass
+
+ def save_as_npy(self):
+ pass
+
+ def blackboxfunction(self):
+ pass
+
+ def pragmas(self):
+ pass
diff --git a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
index ef02453498..a50cbbaed1 100644
--- a/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
+++ b/src/finn/transformation/fpgadataflow/convert_to_hls_layers.py
@@ -1019,9 +1019,10 @@ def apply(self, model):
class InferThresholdingLayer(Transformation):
"""Convert any MultiThreshold into a standalone thresholding HLS layer."""
- def __init__(self, mem_mode="const"):
+ def __init__(self, mem_mode="const", use_rtl_variant=False):
super().__init__()
self.mem_mode = mem_mode
+ self.use_rtl_variant = use_rtl_variant
def apply(self, model):
graph = model.graph
@@ -1073,27 +1074,65 @@ def apply(self, model):
)
actval = int(actval)
assert (not odt.signed()) or (actval < 0), (
- node.name + ": Signed output requres actval < 0"
- )
- # create and insert new Thresholding_Batch node
- new_node = helper.make_node(
- "Thresholding_Batch",
- [thl_input, thl_threshold],
- [thl_output],
- domain="finn.custom_op.fpgadataflow",
- backend="fpgadataflow",
- NumChannels=ifc,
- PE=pe,
- numSteps=thl_thres_shape[1],
- inputDataType=idt.name,
- # weightDataType can be tightened by MinimizeAccumulatorWidth
- weightDataType=idt.name,
- outputDataType=odt.name,
- numInputVectors=list(thl_in_shape[:-1]),
- ActVal=actval,
- mem_mode=self.mem_mode,
- name="Thresholding_Batch_" + node.name,
+ node.name + ": Signed output requires actval < 0"
)
+
+ # Ensure that RTL variant is not inserted for unsupported configuration
+ is_rtl_variant_compatible = True
+
+ # Perform checks for RTL variant if chosen
+ if self.use_rtl_variant:
+ assert self.mem_mode == "decoupled", (
+ """%s : RTL Thresholding only supports 'decoupled' memory
+ mode."""
+ % node.name
+ )
+
+ if self.use_rtl_variant and is_rtl_variant_compatible:
+ new_node = helper.make_node(
+ "Thresholding_Binary_Search",
+ [thl_input, thl_threshold],
+ [thl_output],
+ domain="finn.custom_op.fpgadataflow",
+ backend="fpgadataflow",
+ NumChannels=ifc,
+ PE=pe,
+ numSteps=thl_thres_shape[1],
+ inputDataType=idt.name,
+ weightDataType=idt.name,
+ outputDataType=odt.name,
+ numInputVectors=list(thl_in_shape[:-1]),
+ activation_bias=actval,
+ mem_mode=self.mem_mode,
+ name="Thresholding_Binary_Search_" + node.name,
+ )
+ else:
+ if self.use_rtl_variant:
+ warnings.warn(
+ """%s : RTL Thresholding requested for unsupported
+ configuration. Falling back to HLS implementation."""
+ % node.name
+ )
+
+ # create and insert new Thresholding_Batch node
+ new_node = helper.make_node(
+ "Thresholding_Batch",
+ [thl_input, thl_threshold],
+ [thl_output],
+ domain="finn.custom_op.fpgadataflow",
+ backend="fpgadataflow",
+ NumChannels=ifc,
+ PE=pe,
+ numSteps=thl_thres_shape[1],
+ inputDataType=idt.name,
+ weightDataType=idt.name,
+ outputDataType=odt.name,
+ numInputVectors=list(thl_in_shape[:-1]),
+ ActVal=actval,
+ mem_mode=self.mem_mode,
+ name="Thresholding_Batch_" + node.name,
+ )
+
graph.node.insert(insert_point, new_node)
# remove old node
graph.node.remove(node)
diff --git a/src/finn/util/basic.py b/src/finn/util/basic.py
index 1796738c58..5252422dcf 100644
--- a/src/finn/util/basic.py
+++ b/src/finn/util/basic.py
@@ -228,3 +228,22 @@ def is_exe(fpath):
return exe_file
return None
+
+
+def find_next_power_of_2(n):
+ """For any integer 'n', find the next greatest power of 2"""
+ # Negative values will loop infinitely below - return 0
+ if n <= 0:
+ return 0
+ # If '1' is requested, output will be '0' in the loop below, avoid this now.
+ elif n == 1:
+ return 2 # i.e. 2**1
+
+ # decrement 'n' (to handle cases when `n` itself is a power of 2)
+ n = n - 1
+
+ # loop until only one bit is left
+ while n & n - 1:
+ # unset rightmost bit
+ n = n & n - 1
+ return n << 1
diff --git a/tests/fpgadataflow/test_convert_to_hls_thresholding.py b/tests/fpgadataflow/test_convert_to_hls_thresholding.py
new file mode 100755
index 0000000000..9c233bdd06
--- /dev/null
+++ b/tests/fpgadataflow/test_convert_to_hls_thresholding.py
@@ -0,0 +1,276 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+from pyverilator.util.axi_utils import axilite_write, reset_rtlsim
+from qonnx.core.datatype import DataType
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.general.multithreshold import multithreshold
+from qonnx.custom_op.registry import getCustomOp
+from qonnx.transformation.general import GiveUniqueNodeNames
+from qonnx.transformation.infer_datatypes import InferDataTypes
+from qonnx.transformation.infer_shapes import InferShapes
+from qonnx.util.basic import gen_finn_dt_tensor
+from test_fpgadataflow_thresholding_binary_search import (
+ make_single_thresholding_binary_search_modelwrapper,
+)
+
+import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
+from finn.core.rtlsim_exec import rtlsim_exec
+from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+
+test_fpga_part = "xczu3eg-sbva484-1-e"
+target_clk_ns = 5
+
+
+# Helper functions
+def sort_thresholds_increasing(thresholds):
+ return np.sort(thresholds, axis=1)
+
+
+def generate_random_threshold_values(input_data_type, num_input_channels, num_steps):
+ return np.random.randint(
+ input_data_type.min(),
+ input_data_type.max() + 1,
+ (num_input_channels, num_steps),
+ ).astype(np.float32)
+
+
+def generate_pe_value(fold, num_input_channels):
+ if fold == -1:
+ fold = num_input_channels
+ pe = num_input_channels // fold
+ assert num_input_channels % pe == 0
+ return pe
+
+
+# n = batch, c = channel, h = height, w = width of feature map
+# Standard = NCHW; FINN = NHWC
+# Convert from NCHW to NHWC
+def convert_np_array_to_finn_data_layout(data):
+ return np.transpose(data, (0, 2, 3, 1))
+
+
+# n = batch, c = channel, h = height, w = width of feature map
+# Standard = NCHW; FINN = NHWC
+# Convert from NHWC to NCHW
+def convert_np_array_to_standard_data_layout(data):
+ return np.transpose(data, (0, 3, 1, 2))
+
+
+def make_single_multithresholding_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+):
+ NumChannels = thresholds.shape[0]
+
+ inp = helper.make_tensor_value_info(
+ "inp", TensorProto.FLOAT, num_input_vecs + [NumChannels]
+ )
+ outp = helper.make_tensor_value_info(
+ "outp", TensorProto.FLOAT, num_input_vecs + [NumChannels]
+ )
+
+ node_inp_list = ["inp", "thresh"]
+
+ Multithresholding_node = helper.make_node(
+ "MultiThreshold",
+ node_inp_list,
+ ["outp"],
+ domain="qonnx.custom_op.general",
+ out_dtype=output_data_type.name,
+ out_bias=float(activation_bias),
+ out_scale=1.0,
+ )
+
+ graph = helper.make_graph(
+ nodes=[Multithresholding_node],
+ name="multithresholding_graph",
+ inputs=[inp],
+ outputs=[outp],
+ )
+
+ model = helper.make_model(graph, producer_name="multithresholding-model")
+ model = ModelWrapper(model)
+ model = model.transform(InferShapes())
+ model = model.transform(InferDataTypes())
+ model = model.transform(GiveUniqueNodeNames())
+
+ model.set_tensor_datatype("inp", input_data_type)
+ model.set_tensor_datatype("outp", output_data_type)
+
+ model.set_tensor_datatype("thresh", input_data_type)
+ model.set_initializer("thresh", thresholds)
+ return model
+
+
+# N.B. Fold values where C % PE != 0 fail
+@pytest.mark.parametrize("activation", [DataType["INT4"], DataType["BIPOLAR"]])
+@pytest.mark.parametrize("input_data_type", [DataType["INT16"], DataType["UINT16"]])
+@pytest.mark.parametrize("fold", [-1, 1, 2, 4, 6])
+@pytest.mark.parametrize("num_input_channels", [16])
+@pytest.mark.fpgadataflow
+@pytest.mark.vivado
+def test_convert_to_hls_tbs_rtl_variant(
+ activation,
+ input_data_type,
+ fold,
+ num_input_channels,
+):
+ # Handle inputs to the test
+ pe = generate_pe_value(fold, num_input_channels)
+ num_steps = activation.get_num_possible_values() - 1
+
+ # See convert_to_hls_layers::InferThresholdingLayer:
+ # assert (not odt.signed()) or (actval < 0)
+ # This implies that it expects a negative activation, BIPOLAR does not provide that
+ if activation == DataType["BIPOLAR"]:
+ pytest.skip(
+ "Only negative activations are supported for "
+ "RTL Thresholding Binary Search node"
+ )
+
+ # Other non-input parameters
+ num_input_vecs = [1, 2, 2]
+ output_data_type = activation
+ if output_data_type == DataType["BIPOLAR"]:
+ activation_bias = 0
+ else:
+ activation_bias = output_data_type.min()
+
+ # generate random input data
+ tensor_shape = tuple(num_input_vecs + [num_input_channels])
+ x = gen_finn_dt_tensor(input_data_type, tensor_shape)
+
+ # Generate random thresholds and sort in ascending order
+ thresholds = generate_random_threshold_values(
+ input_data_type, num_input_channels, num_steps
+ )
+
+ # provide non-decreasing/ascending thresholds
+ thresholds = sort_thresholds_increasing(thresholds)
+
+ x_nhwc = convert_np_array_to_standard_data_layout(x)
+ y = multithreshold(x_nhwc, thresholds)
+
+ # convert back to NHWC for comparison to hw outputs
+ y = convert_np_array_to_finn_data_layout(y)
+ if activation == DataType["BIPOLAR"]:
+ # binary to bipolar
+ y = 2 * y - 1
+ else:
+ # signed offset
+ y += activation.min()
+
+ # Generate model from input parameters to the test
+ model = make_single_thresholding_binary_search_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+ )
+
+ model = model.transform(InsertFIFO(True))
+ model = model.transform(GiveUniqueNodeNames())
+ model = model.transform(PrepareIP(test_fpga_part, target_clk_ns))
+ model = model.transform(HLSSynthIP())
+ model = model.transform(CreateStitchedIP(test_fpga_part, target_clk_ns))
+
+ # Retrieve the axilite programming sequence for weights - for decoupled mode only
+ tbs_node = model.get_nodes_by_op_type("Thresholding_Binary_Search")[0]
+ tbs_inst = getCustomOp(tbs_node)
+ config = tbs_inst.get_dynamic_config(model, 4)
+
+ # Reshape generated data (not from model)
+ oshape = model.get_tensor_shape("outp")
+ y_expected = y.reshape(oshape)
+
+ # Helper function that delivers the hook to program the thresholds via AXI-Lite
+ def config_hook(config):
+ if config is None:
+ return None
+
+ def write_thresh_config(sim):
+ # axi_name = "s_axilite_0_" # works
+ axi_name = getCustomOp(
+ model.get_nodes_by_op_type("Thresholding_Binary_Search")[0]
+ ).get_verilog_top_module_intf_names()["axilite"][0]
+ axi_name += "_0_"
+
+ # Write config registers to the Threshold memory.
+ # The dictionary defines (addr, value) tuples.
+ for config_entry in config.values():
+ addr = config_entry[0]
+ val = config_entry[1]
+ axilite_write(sim, addr, val, basename=axi_name)
+
+ reset_rtlsim(sim)
+
+ return write_thresh_config
+
+ input_dict = {"inp": x}
+ rtlsim_exec(model, input_dict, pre_hook=config_hook(config))
+ y_produced = input_dict["outp"]
+ assert (y_produced == y_expected).all()
+
+ # Make a Multithreshold graph and convert to thresholding binary search node
+ new_model = make_single_multithresholding_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+ )
+
+ # Recreate the model using the ConvertToHLS transform
+ new_model = new_model.transform(
+ to_hls.InferThresholdingLayer(mem_mode="decoupled", use_rtl_variant=True)
+ )
+ new_model = new_model.transform(InsertFIFO(True))
+ new_model = new_model.transform(GiveUniqueNodeNames())
+ new_model = new_model.transform(PrepareIP(test_fpga_part, target_clk_ns))
+ new_model = new_model.transform(HLSSynthIP())
+ new_model = new_model.transform(CreateStitchedIP(test_fpga_part, target_clk_ns))
+
+ input_dict = {"inp": x}
+ rtlsim_exec(new_model, input_dict, pre_hook=config_hook(config))
+ y_produced_new = input_dict["outp"]
+ assert (y_produced_new == y_expected).all()
diff --git a/tests/fpgadataflow/test_fpgadataflow_thresholding_binary_search.py b/tests/fpgadataflow/test_fpgadataflow_thresholding_binary_search.py
new file mode 100755
index 0000000000..24b60f5ea5
--- /dev/null
+++ b/tests/fpgadataflow/test_fpgadataflow_thresholding_binary_search.py
@@ -0,0 +1,287 @@
+# Copyright (C) 2022, Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+import numpy as np
+from onnx import TensorProto, helper
+from pyverilator.util.axi_utils import axilite_write, reset_rtlsim
+from qonnx.core.datatype import DataType
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.general.multithreshold import multithreshold
+from qonnx.custom_op.registry import getCustomOp
+from qonnx.transformation.general import GiveUniqueNodeNames
+from qonnx.util.basic import gen_finn_dt_tensor
+
+from finn.core.rtlsim_exec import rtlsim_exec
+from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
+from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
+from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
+from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
+from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
+from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+
+test_fpga_part = "xczu3eg-sbva484-1-e"
+target_clk_ns = 5
+
+
+# Helper functions
+def sort_thresholds_increasing(thresholds):
+ return np.sort(thresholds, axis=1)
+
+
+def generate_random_threshold_values(input_data_type, num_input_channels, num_steps):
+ return np.random.randint(
+ input_data_type.min(),
+ input_data_type.max() + 1,
+ (num_input_channels, num_steps),
+ ).astype(np.float32)
+
+
+def generate_pe_value(fold, num_input_channels):
+ if fold == -1:
+ fold = num_input_channels
+ pe = num_input_channels // fold
+ assert num_input_channels % pe == 0
+ return pe
+
+
+# n = batch, c = channel, h = height, w = width of feature map
+# Standard = NCHW; FINN = NHWC
+# Convert from NCHW to NHWC
+def convert_np_array_to_finn_data_layout(data):
+ return np.transpose(data, (0, 2, 3, 1))
+
+
+# n = batch, c = channel, h = height, w = width of feature map
+# Standard = NCHW; FINN = NHWC
+# Convert from NHWC to NCHW
+def convert_np_array_to_standard_data_layout(data):
+ return np.transpose(data, (0, 3, 1, 2))
+
+
+def make_single_thresholding_binary_search_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+):
+
+ NumChannels = thresholds.shape[0]
+
+ inp = helper.make_tensor_value_info(
+ "inp", TensorProto.FLOAT, num_input_vecs + [NumChannels]
+ )
+ outp = helper.make_tensor_value_info(
+ "outp", TensorProto.FLOAT, num_input_vecs + [NumChannels]
+ )
+
+ node_inp_list = ["inp", "thresh"]
+
+ Thresholding_node = helper.make_node(
+ "Thresholding_Binary_Search",
+ node_inp_list,
+ ["outp"],
+ domain="finn.custom_op.fpgadataflow",
+ backend="fpgadataflow",
+ NumChannels=NumChannels,
+ PE=pe,
+ numSteps=thresholds.shape[1],
+ inputDataType=input_data_type.name,
+ weightDataType=input_data_type.name,
+ outputDataType=output_data_type.name,
+ activation_bias=activation_bias,
+ numInputVectors=num_input_vecs,
+ )
+ graph = helper.make_graph(
+ nodes=[Thresholding_node],
+ name="thresholding_graph",
+ inputs=[inp],
+ outputs=[outp],
+ )
+
+ model = helper.make_model(graph, producer_name="thresholding-model")
+ model = ModelWrapper(model)
+
+ model.set_tensor_datatype("inp", input_data_type)
+ model.set_tensor_datatype("outp", output_data_type)
+
+ model.set_tensor_datatype("thresh", input_data_type)
+ model.set_initializer("thresh", thresholds)
+ return model
+
+
+# Test brief: Test that PrepareRTLSim() runs successfully. This function is not
+# tested in test_fpgadataflow_thresholding_binary_search()
+@pytest.mark.fpgadataflow
+@pytest.mark.vivado
+def test_fpgadataflow_thresholding_binary_search_prepare_rtlsim():
+ input_data_type = DataType["INT16"]
+ act = DataType["INT4"]
+ fold = -1
+ num_input_channels = 16
+
+ # Handle inputs to the test
+ pe = generate_pe_value(fold, num_input_channels)
+ num_steps = act.get_num_possible_values() - 1
+
+ # Generate random, non-decreasing thresholds
+ thresholds = generate_random_threshold_values(
+ input_data_type, num_input_channels, num_steps
+ )
+ thresholds = sort_thresholds_increasing(thresholds)
+
+ # Other non-input parameters
+ num_input_vecs = [1, 2, 2]
+ output_data_type = act
+ if output_data_type == DataType["BIPOLAR"]:
+ activation_bias = 0
+ else:
+ activation_bias = output_data_type.min()
+
+ # Generate model from input parameters to the test
+ model = make_single_thresholding_binary_search_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+ )
+
+ model = model.transform(SetExecMode("rtlsim"))
+ model = model.transform(GiveUniqueNodeNames())
+ model = model.transform(PrepareIP(test_fpga_part, target_clk_ns))
+ model = model.transform(HLSSynthIP())
+ model = model.transform(PrepareRTLSim())
+ return
+
+
+# Test brief: Create a Thresholding binary search layer using various parameters
+# and test against a SW generated & simulated dataset
+# N.B. Fold values where C % PE != 0 fail
+@pytest.mark.parametrize("activation", [DataType["INT4"], DataType["BIPOLAR"]])
+@pytest.mark.parametrize("input_data_type", [DataType["INT16"], DataType["UINT16"]])
+@pytest.mark.parametrize("fold", [-1, 1, 2, 4, 6])
+@pytest.mark.parametrize("num_input_channels", [16])
+@pytest.mark.fpgadataflow
+@pytest.mark.vivado
+@pytest.mark.slow
+def test_fpgadataflow_thresholding_binary_search(
+ activation, input_data_type, fold, num_input_channels
+):
+ # Handle inputs to the test
+ pe = generate_pe_value(fold, num_input_channels)
+ num_steps = activation.get_num_possible_values() - 1
+
+ # Other non-input parameters
+ num_input_vecs = [1, 2, 2]
+ output_data_type = activation
+ if output_data_type == DataType["BIPOLAR"]:
+ activation_bias = 0
+ else:
+ activation_bias = output_data_type.min()
+
+ # generate random input data
+ tensor_shape = tuple(num_input_vecs + [num_input_channels])
+ x = gen_finn_dt_tensor(input_data_type, tensor_shape)
+
+ # Generate random thresholds and sort in ascending order
+ thresholds = generate_random_threshold_values(
+ input_data_type, num_input_channels, num_steps
+ )
+
+ # provide non-decreasing/ascending thresholds
+ thresholds = sort_thresholds_increasing(thresholds)
+
+ x_nhwc = convert_np_array_to_standard_data_layout(x)
+ y = multithreshold(x_nhwc, thresholds)
+
+ # convert back to NHWC for comparison to hw outputs
+ y = convert_np_array_to_finn_data_layout(y)
+ if activation == DataType["BIPOLAR"]:
+ # binary to bipolar
+ y = 2 * y - 1
+ else:
+ # signed offset
+ y += activation.min()
+
+ # Generate model from input parameters to the test
+ model = make_single_thresholding_binary_search_modelwrapper(
+ thresholds,
+ pe,
+ input_data_type,
+ output_data_type,
+ activation_bias,
+ num_input_vecs,
+ )
+
+ model = model.transform(InsertFIFO(True))
+ model = model.transform(GiveUniqueNodeNames())
+ model = model.transform(PrepareIP(test_fpga_part, target_clk_ns))
+ model = model.transform(HLSSynthIP())
+ model = model.transform(CreateStitchedIP(test_fpga_part, target_clk_ns))
+
+ # Retrieve the axilite programming sequence for weights - for decoupled mode only
+ tbs_node = model.get_nodes_by_op_type("Thresholding_Binary_Search")[0]
+ tbs_inst = getCustomOp(tbs_node)
+ config = tbs_inst.get_dynamic_config(model, 4)
+
+ # Reshape generated data (not from model)
+ oshape = model.get_tensor_shape("outp")
+ y_expected = y.reshape(oshape)
+
+ # Helper function that delivers the hook to program the thresholds via AXI-Lite
+ def config_hook(config):
+ if config is None:
+ return None
+
+ def write_thresh_config(sim):
+ # axi_name = "s_axilite_0_" # works
+ axi_name = getCustomOp(
+ model.get_nodes_by_op_type("Thresholding_Binary_Search")[0]
+ ).get_verilog_top_module_intf_names()["axilite"][0]
+ axi_name += "_0_"
+
+ # Write config registers to the Threshold memory.
+ # The dictionary defines (addr, value) tuples.
+ for config_entry in config.values():
+ addr = config_entry[0]
+ val = config_entry[1]
+ axilite_write(sim, addr, val, basename=axi_name)
+
+ reset_rtlsim(sim)
+
+ return write_thresh_config
+
+ input_dict = {"inp": x}
+ rtlsim_exec(model, input_dict, pre_hook=config_hook(config))
+ y_produced = input_dict["outp"]
+ assert (y_produced == y_expected).all()
diff --git a/tests/util/test_basic.py b/tests/util/test_basic.py
new file mode 100755
index 0000000000..97a8c50261
--- /dev/null
+++ b/tests/util/test_basic.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2023, Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+import finn.util.basic as basic
+
+
+@pytest.mark.util
+def test_next_power_of_2():
+ test_vector = [
+ {"input": -2, "expected_result": 0},
+ {"input": -1, "expected_result": 0},
+ {"input": 0, "expected_result": 0},
+ {"input": 1, "expected_result": 2},
+ {"input": 2, "expected_result": 2},
+ {"input": 3, "expected_result": 4},
+ {"input": 4, "expected_result": 4},
+ {"input": 7, "expected_result": 8},
+ {"input": 8, "expected_result": 8},
+ {"input": 11, "expected_result": 16},
+ {"input": 15, "expected_result": 16},
+ {"input": 16, "expected_result": 16},
+ {"input": 18, "expected_result": 32},
+ {"input": 27, "expected_result": 32},
+ {"input": 31, "expected_result": 32},
+ {"input": 32, "expected_result": 32},
+ {"input": 42, "expected_result": 64},
+ {"input": 65, "expected_result": 128},
+ ]
+
+ for test_dict in test_vector:
+ output = basic.find_next_power_of_2(test_dict["input"])
+ assert output >= test_dict["input"]
+ assert output == test_dict["expected_result"]