diff --git a/src/qonnx/analysis/inference_cost.py b/src/qonnx/analysis/inference_cost.py index 98e03428..847058b7 100644 --- a/src/qonnx/analysis/inference_cost.py +++ b/src/qonnx/analysis/inference_cost.py @@ -201,10 +201,10 @@ def inference_cost_upsample(model, node, discount_sparsity): return ret -def inference_cost(model, discount_sparsity=True): +def inference_cost(model, discount_sparsity=True, cost_breakdown=False): "Ensure all nodes have unique names prior to calling this analysis pass." - node_costs = {} + ret, node_costs, nodes_per_optype = {}, {}, {} zero_cost_ops = [ "MaxPool", "AveragePool", @@ -240,13 +240,24 @@ def inference_cost(model, discount_sparsity=True): if node.op_type in inference_cost_fxn_map.keys(): node_cost = inference_cost_fxn_map[node.op_type](model, node, discount_sparsity) node_costs[node.name] = node_cost + if node.op_type not in nodes_per_optype.keys(): + new_optype = {} + new_optype[node.name] = node_cost + nodes_per_optype[node.op_type] = new_optype + else: + nodes_per_optype[node.op_type][node.name] = node_cost elif node.op_type in zero_cost_ops: continue else: unsupported_ops.add(node.op_type) - - ret = aggregate_dict_keys(node_costs) - ret["unsupported"] = unsupported_ops - ret["discount_sparsity"] = discount_sparsity - + total = aggregate_dict_keys(node_costs) + total["unsupported"] = unsupported_ops + total["discount_sparsity"] = discount_sparsity + ret["total_cost"] = total + if cost_breakdown: + optype_cost = {} + for optype, resources in nodes_per_optype.items(): + optype_cost[optype] = aggregate_dict_keys(resources) + ret["optype_cost"] = optype_cost + ret["node_cost"] = node_costs return ret diff --git a/src/qonnx/analysis/l0_resource_estimates.py b/src/qonnx/analysis/l0_resource_estimates.py new file mode 100644 index 00000000..fdee029e --- /dev/null +++ b/src/qonnx/analysis/l0_resource_estimates.py @@ -0,0 +1,227 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# # All rights reserved. +# # +# # Redistribution and use in source and binary forms, with or without +# # modification, are permitted provided that the following conditions are met: +# # +# # * Redistributions of source code must retain the above copyright notice, this +# # list of conditions and the following disclaimer. +# # +# # * Redistributions in binary form must reproduce the above copyright notice, +# # this list of conditions and the following disclaimer in the documentation +# # and/or other materials provided with the distribution. +# # +# # * Neither the name of qonnx nor the names of its +# # contributors may be used to endorse or promote products derived from +# # this software without specific prior written permission. +# # +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
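Before moving on to the new resource-estimation helpers, here is a minimal sketch of what the new cost_breakdown flag returns, using the util-level wrapper that is updated later in this patch. The model path is a placeholder for any cleaned QONNX/ONNX model; the keys follow the dictionaries assembled in the analysis pass above.

from qonnx.util.inference_cost import inference_cost

# "model.onnx" is a placeholder path, not part of this patch.
costs = inference_cost("model.onnx", discount_sparsity=False, cost_breakdown=True)
total = costs["total_cost"]        # aggregated over the whole model
per_optype = costs["optype_cost"]  # aggregated per op type, e.g. "Conv"
per_node = costs["node_cost"]      # one entry per (uniquely named) node
print(total["total_macs"], total["total_mem_w_bits"])
for optype, agg in per_optype.items():
    print(optype, agg["total_macs"])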
+
+from qonnx.core.datatype import DataType
+
+"""DSP Type: a) None:
+       For fixed-point and floating-point datatypes
+           1) When dsp_type is None, all operations will be processed using LUTs.
+           2) LUTs are calculated as: 1.1*b_width1*b_width2
+           3) Examples:
+               a) op_mac_INT4_INT2: 1.1*4*2 = 8.8 LUTs.
+               b) op_mac_INT8_INT8: 1.1*8*8 = 70.4 LUTs.
+               c) op_mac_INT8_FLOAT16: 1.1*8*16 = 140.8 LUTs.
+               d) op_mac_FLOAT16_FLOAT16: 1.1*16*16 = 281.6 LUTs.
+
+    b) DSP48:
+       For fixed-point datatypes
+           1) Bitwidths below 4 are promoted to 4, e.g. INT2 uses the same resources as INT4.
+           2) INT4: one DSP48 + 200 LUTs can accommodate 4 (4x4)-bit MACs.
+              So the estimate is (0.25)*mac_count DSP48s and (200/4)*mac_count = 50*mac_count LUTs.
+           3) Bitwidths between 5 and 8 are promoted to 8, e.g. INT6 uses the same resources as INT8.
+           4) INT8: one DSP48 + 200 LUTs can accommodate 2 (8x8)-bit MACs.
+              So the estimate is (0.5)*mac_count DSP48s and (200/2)*mac_count = 100*mac_count LUTs.
+       For floating-point datatypes
+           1) FLOAT32: 2 DSPs + 700 LUTs can accommodate 1 MAC.
+           2) FLOAT16: 1 DSP + 400 LUTs can accommodate 1 MAC.
+    c) DSP58:
+       For fixed-point datatypes
+           1) INT8: one DSP58 can accommodate 3 (8x8)-bit MACs,
+              so the estimate is (0.33)*mac_count DSP58s.
+           2) INT4: one DSP58 can accommodate 4 (4x4)-bit MACs,
+              so the estimate is (0.25)*mac_count DSP58s.
+           3) INT16: 1 MAC requires 1 DSP.
+       For floating-point datatypes
+           1) FLOAT32: 1 MAC requires 1 DSP.
+           2) FLOAT16: 1 MAC requires 1 DSP.
+    Mapping strategy for On-Chip Memory (bits_per_res):
+       a) 1 "BRAM" can accommodate 36*1024 = 36864 bits.
+       b) 1 "URAM" can accommodate 288*1024 = 294912 bits.
+       c) 1 LUT can accommodate 64 bits.
+"""
+resource_table = {
+    "FLOAT32": {"NONE": (0, 1100), "DSP48": (2, 700), "DSP58": (1, 0)},
+    "FLOAT16": {"NONE": (0, 1100), "DSP48": (1, 400), "DSP58": (1, 0)},
+    "INT32": {"NONE": (0, 1100), "DSP48": (1, 0), "DSP58": (1, 0)},
+    "INT16": {"NONE": (0, 282), "DSP48": (1, 0), "DSP58": (1, 0)},
+    "INT8": {"NONE": (0, 71), "DSP48": (0.5, 100), "DSP58": (0.33, 0)},
+    "INT4": {"NONE": (0, 18), "DSP48": (0.25, 50), "DSP58": (0.25, 0)},
+}
+
+bits_per_res = {"BRAM": 36864, "URAM": 294912, "LUT": 64}
+
+
+def ocm_resources(num_mem_bits, d_factor):
+    """Estimates the number of URAMs and BRAMs required for the
+    on-chip memory, depending on the distribution factor.
+    Args:
+        num_mem_bits (int): Number of memory bits.
+        d_factor (float): Distribution factor between 0 and 1,
+                          used to distribute memory between BRAM and URAM.
+    Returns:
+        A dictionary of OCM resources containing the LUT, BRAM and URAM requirements.
+    """
+    if d_factor is None:
+        luts_req = num_mem_bits / bits_per_res["LUT"]  # neither bram nor uram.
+        ocm_res = {"LUT": luts_req}
+    elif d_factor == 1:  # everything in uram.
+        uram_req = num_mem_bits / bits_per_res["URAM"]  # URAM: 288kbit/URAM
+        ocm_res = {"URAM": uram_req}
+    elif d_factor == 0:  # everything in bram 36kbit/BRAM
+        bram_req = num_mem_bits / bits_per_res["BRAM"]
+        ocm_res = {"BRAM": bram_req}
+    else:  # both bram and uram.
+        uram_por, bram_por = d_factor, 1 - d_factor
+        bram_req = (bram_por * num_mem_bits) / bits_per_res["BRAM"]
+        uram_req = (uram_por * num_mem_bits) / bits_per_res["URAM"]
+        ocm_res = {"BRAM": bram_req, "URAM": uram_req}
+    return ocm_res
+
+
+def promoting_datatype(dtype, b_width):
+    """Datatype promotion criterion. Only used when DSPs are used for processing.
+    Args:
+        dtype (str): containing "INT" or "FLOAT".
+ b_width (int): precision of the respective datatype. + Returns: + Returns promoted datatype and precision value.""" + + if "INT" in dtype: + promoted_dtype = "INT" + if b_width <= 4: + promoted_bwidth = 4 + elif 4 < b_width <= 8: + promoted_bwidth = 8 + elif 8 < b_width <= 16: + promoted_bwidth = 16 + else: + promoted_bwidth = 32 + elif "FLOAT" in dtype: + promoted_dtype = "FLOAT" + if b_width <= 16: + promoted_bwidth = 16 + else: + promoted_bwidth = 32 + else: + raise Exception("Unsupported data type") + + return promoted_dtype, promoted_bwidth + + +def dtype_casting(dtype1, dtype2, b_width1, b_width2): + """Implementing datatype promotion.""" + + promoted_dtype1, promoted_bwidth1 = promoting_datatype(dtype1, b_width1) # either INT or FLOAT + promoted_dtype2, promoted_bwidth2 = promoting_datatype(dtype2, b_width2) + + if promoted_dtype1 == promoted_dtype2: # same datatype + if promoted_bwidth1 == promoted_bwidth2: # same precision. + dtype = promoted_dtype1 + str(promoted_bwidth1) # can also use dtype_2 + new_bwidth2 + else: # different precision. + if promoted_bwidth1 >= promoted_bwidth2: + dtype = promoted_dtype1 + str(promoted_bwidth1) + else: + dtype = promoted_dtype2 + str(promoted_bwidth2) + else: # dtype_1 != dtype_2 (Different datatype and same/different precision) + if promoted_dtype1 == "FLOAT": # with different datatypes, using float and it's respective precision. + dtype = promoted_dtype1 + str(promoted_bwidth1) + else: + dtype = promoted_dtype2 + str(promoted_bwidth2) + + return dtype + + +def core_resources(inf_cost, dsp_type, bwidth_lower_limit, bwidth_upper_limit): + """Provide estimate resources required for the processing ("CORE"), assuming maximum unfolding. + Args: + inf_cost (dict): Inference cost dict. + dsp_type (str): None OR "DSP48" OR "DSP58". Default to None. + bwidth_lower_limit (int): Default to 8. It indicates bit values less than 8 will be processed using LUTs. + bwidth_upper_limit (int): Default to 32. It indicates bit values less than 32 will be processed using LUTs. + Returns: + A dictionary containing CORE resource estimates.""" + + dsp_res_mac = 0 + lut_res_mac = 0 + for i in inf_cost.keys(): + if "op_mac" in i: + mac_count = inf_cost[i] + detail_list = i.split("_") + dtype1, dtype2 = detail_list[-1], detail_list[-2] + b_width1, b_width2 = DataType[dtype1].bitwidth(), DataType[dtype2].bitwidth() + if dsp_type is None: # Computing everything in LUTs. + lut_res_mac += 1.1 * b_width1 * b_width2 * mac_count + dsp_comp = "DSP" # default name for DSP and dsp_res_mac = 0 + else: # dsp_type == "DSP48" or dsp_type == "DSP58" + if (b_width1 < bwidth_lower_limit or b_width2 < bwidth_lower_limit) or ( + b_width1 > bwidth_upper_limit or b_width2 > bwidth_upper_limit + ): # Computing everything in LUTs. + lut_res_mac += 1.1 * b_width1 * b_width2 * mac_count # dsp_res_mac = 0 + else: + casted_dtype = dtype_casting(dtype1, dtype2, b_width1, b_width2) + casted_bwidth = DataType[casted_dtype].bitwidth() + if casted_bwidth > bwidth_upper_limit: # Computing everything in LUTs. + lut_res_mac += ( + 1.1 * b_width1 * b_width2 * mac_count + ) # original bwidth values are used, since dsp_res_mac = 0. + else: + dsp_res_mac += ( + resource_table[casted_dtype][dsp_type][0] * mac_count + ) # at index zero, we expect to have dsp factor. + lut_res_mac += ( + resource_table[casted_dtype][dsp_type][1] * mac_count + ) # at index one, we expect to have lut factor. + dsp_comp = dsp_type # assigning name as per dsp type. 
+ else: + continue + + core_res = {"LUT": lut_res_mac, dsp_comp: dsp_res_mac} + + return core_res + + +def l0_resource_estimates(inf_cost, dsp_type=None, bwidth_lower_limit=8, bwidth_upper_limit=32, d_factor=None): + """Provide estimate resources required for the processing ("CORE") and memory ("OCM"), assuming maximum unfolding. + Args: + inf_cost (dict): Inference cost dict. + dsp_type (str): None OR "DSP48" OR "DSP58". Default to None. + bram_type (str): Default to "BRAM". It can be BRAM, BRAM36, BRAM_36K, BRAM_18K. + bwidth_lower_limit (int): Default to 8. It indicates bit values less than 8 will be processed using LUTs. + bwidth_upper_limit (int): Default to 32. It indicates bit values less than 32 will be processed using LUTs. + d_factor (float): Default to 1. It can have values between 0 and 1. + Returns: + A dictionary containing CORE and OCM resource estimates.""" + + core_res = core_resources(inf_cost, dsp_type, bwidth_lower_limit, bwidth_upper_limit) + + num_mem_bits = inf_cost["total_mem_w_bits"] + ocm_res = ocm_resources(num_mem_bits, d_factor) + + est_res_req = {"CORE": core_res, "OCM": ocm_res} + + return est_res_req diff --git a/src/qonnx/analysis/l1_resource_estimates.py b/src/qonnx/analysis/l1_resource_estimates.py new file mode 100644 index 00000000..85f30d37 --- /dev/null +++ b/src/qonnx/analysis/l1_resource_estimates.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import math + +from qonnx.analysis.l0_resource_estimates import core_resources +from qonnx.util.l0_performance_estimate import l0_performance_estimate + +resource_map = { + "res_limit": {"LUT": 0.7, "BRAM": 0.80, "URAM": 0.80, "DSP48": 0.80, "DSP58": 0.80}, + "bits_per_res": {"BRAM": 36864, "URAM": 294912, "LUT": 64}, +} + + +def rounding_down(value, decimal_place): + fact = 10**decimal_place + return math.floor(value * fact) / fact + + +def memory_distribution(resource_budget, total_cost, node_cost, unused_resources): + """This function is used for memory resource allocation to each node. 
In this function, we + are trying to project either one of the following resources URAM, BRAM, LUTRAM to a particular node. + However, for some nodes a combination of resources could be projected. + Parameters:- + resource_budget (dict):- Resource budget of an FPGA in a dictionary format. + Example (ve2802) :-{"LUT": 520704, "BRAM": 600, "URAM": 264, "DSP58": 1312} + total_cost (dict):- Inference cost dictionary of the full model. + per_node_cost (dict):- Inference cost associated with each layer/node in the neural network. + unused_resources (dict):- Dictionary in the format of resource budget to keep record of resources used, + provided internally. + """ + + res_limit, bits_per_res = resource_map["res_limit"], resource_map["bits_per_res"] + + mem_port = ( + node_cost["total_mem_w_bits"] / total_cost["total_mem_w_bits"] + ) # mem portion of this layer in the model (between 0 and 1). + + available_bram = ( + mem_port * (res_limit["BRAM"] * resource_budget["BRAM"]) + unused_resources["BRAM"] + ) # Approximate Bram available for this node. + available_uram = ( + mem_port * (res_limit["URAM"] * resource_budget["URAM"]) + unused_resources["URAM"] + ) # Approximate Uram available for this node. + available_lut = ( + mem_port * (res_limit["LUT"] * resource_budget["LUT"]) + unused_resources["LUT"] + ) # Approximate LUT available for this node. + + req_bram = node_cost["total_mem_w_bits"] / bits_per_res["BRAM"] # BRAM required for this node. + req_uram = node_cost["total_mem_w_bits"] / bits_per_res["URAM"] # URAM required for this node. + req_lut = node_cost["total_mem_w_bits"] / bits_per_res["LUT"] # LUT required for this node. + + if req_bram < 1 and available_lut > req_lut: + per_node_res_bdgt = {"BRAM": 0, "URAM": 0, "LUT": math.ceil(req_lut)} + else: + if available_uram > math.ceil(req_uram) and available_bram > math.ceil(req_bram): + uram_bit_waste = (math.ceil(req_uram) - req_uram) * bits_per_res["URAM"] + bram_bit_waste = (math.ceil(req_bram) - req_bram) * bits_per_res["BRAM"] + min_bit_waste = min(uram_bit_waste, bram_bit_waste) + if min_bit_waste == bram_bit_waste: + per_node_res_bdgt = {"BRAM": math.ceil(req_bram), "URAM": 0, "LUT": 0} + else: + per_node_res_bdgt = {"BRAM": 0, "URAM": math.ceil(req_uram), "LUT": 0} + elif available_uram > math.ceil(req_uram) and available_bram < math.ceil(req_bram): + per_node_res_bdgt = {"BRAM": 0, "URAM": math.ceil(req_uram), "LUT": 0} + elif available_uram < math.ceil(req_uram) and available_bram > math.ceil(req_bram): + per_node_res_bdgt = {"BRAM": math.ceil(req_bram), "URAM": 0, "LUT": 0} + else: + bits_in_bram, bits_in_uram = available_bram * bits_per_res["BRAM"], available_uram * bits_per_res["URAM"] + total_available_bits = bits_in_bram + bits_in_uram + if total_available_bits >= node_cost["total_mem_w_bits"]: + uram_dist_factor = bits_in_uram / total_available_bits + bram_dist_factor = 1 - uram_dist_factor + bram_dist = bram_dist_factor * node_cost["total_mem_w_bits"] / bits_per_res["BRAM"] + uram_dist = uram_dist_factor * node_cost["total_mem_w_bits"] / bits_per_res["URAM"] + per_node_res_bdgt = {"BRAM": math.ceil(bram_dist), "URAM": math.ceil(uram_dist), "LUT": 0} + else: + per_node_res_bdgt = {"BRAM": math.ceil(available_bram), "URAM": math.ceil(available_uram), "LUT": 0} + + unused_resources = { + "BRAM": (available_bram - per_node_res_bdgt["BRAM"]), # -ve value indicates over budget in unused resources. 
+ "URAM": (available_uram - per_node_res_bdgt["URAM"]), + "LUT": (available_lut - per_node_res_bdgt["LUT"]), + } + + return per_node_res_bdgt, unused_resources + + +def l1_resource_estimates( + resource_budget, + total_cost, + per_node_cost, + bwidth_lower_limit=8, + bwidth_upper_limit=32, + clock_freq=300000000, + decimal_place=4, +): + """ + This function provides an estimate about the resources used by each layer of the neural network + when deployed on a particular FPGA. + + Parameters:- + resource_budget (dict):- Resource budget of an FPGA in a dictionary format. + Example (ve2802) :-{"LUT": 520704, "BRAM": 600, "URAM": 264, "DSP58": 1312} + total_cost (dict):- Inference cost dictionary of the full model. + per_node_cost (dict):- Inference cost associated with each layer/node in the neural network. + bwidth_lower_limit (int):- Lower cut off value to use DSPs instead of LUTs. + bwidth_upper_limit (int):- Upper cut off value to use LUTs instead of DSPs. + clock_freq (int):- Default at 300 million (single pump) + decimal_place (int):- used to round upto "decimal_place" values after decimal. + """ + + # 300 MHZ + all_node_res_bdgt = {} + # rounding upto + + dsp_type = next((key for key in resource_budget if "DSP" in key), None) + + unused_resources = {"LUT": 0, "BRAM": 0, "URAM": 0} # passing unused resources to the following layers. + + resource_performance = l0_performance_estimate( + resource_budget, # Inference per sec (slowdown factor * clock cycle) + total_cost, + bwidth_lower_limit, + bwidth_upper_limit, + clock_freq, + ) + + slowdown_factor = resource_performance[1] / clock_freq + + for node_name, node_cost in per_node_cost.items(): + per_node_res_bdgt, unused_resources = memory_distribution(resource_budget, total_cost, node_cost, unused_resources) + + core_res = core_resources(node_cost, dsp_type, bwidth_lower_limit, bwidth_upper_limit) + + for i, j in core_res.items(): + core_amount = slowdown_factor * j + if i in per_node_res_bdgt.keys(): + per_node_res_bdgt[i] += rounding_down(core_amount, decimal_place) + else: + per_node_res_bdgt[i] = rounding_down(core_amount, decimal_place) + all_node_res_bdgt[node_name] = per_node_res_bdgt + return all_node_res_bdgt diff --git a/src/qonnx/data/onnx/l1_resource_estimates/network1.onnx b/src/qonnx/data/onnx/l1_resource_estimates/network1.onnx new file mode 100644 index 00000000..130a4fd6 Binary files /dev/null and b/src/qonnx/data/onnx/l1_resource_estimates/network1.onnx differ diff --git a/src/qonnx/util/inference_cost.py b/src/qonnx/util/inference_cost.py index 86428c76..57d5292d 100644 --- a/src/qonnx/util/inference_cost.py +++ b/src/qonnx/util/inference_cost.py @@ -70,8 +70,24 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"): return total_mem_bits, total_mem_elems +def assign_mem_bits_and_elems(res_dict): + mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res_dict, "mem_w") + mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res_dict, "mem_o") + res_dict["total_mem_w_bits"] = mem_w_bits + res_dict["total_mem_w_elems"] = mem_w_elems + res_dict["total_mem_o_bits"] = mem_o_bits + res_dict["total_mem_o_elems"] = mem_o_elems + return res_dict + + def inference_cost( - model_filename_or_wrapper, *, output_json=None, output_onnx=None, preprocess=True, discount_sparsity=True + model_filename_or_wrapper, + *, + output_json=None, + output_onnx=None, + preprocess=True, + discount_sparsity=True, + cost_breakdown=False ): """Return the inference cost estimate metric for given ONNX model. 
Supports the Quant op for weight/activation quantization. @@ -83,8 +99,9 @@ def inference_cost( :param preprocess: If set, run preprocessing steps such as shape inference, datatype inference and constant folding. Strongly recommended. :param discount_sparsity: If set, will discount op cost of MAC ops with a - constant zero weight, and the mem cost of constant zero weights. - """ + constant zero weight, and the mem cost of constant zero weights.""" + + combined_results = {} if isinstance(model_filename_or_wrapper, ModelWrapper): model = model_filename_or_wrapper else: @@ -104,25 +121,36 @@ def inference_cost( model = model.transform(GiveReadableTensorNames()) if output_onnx is not None: model.save(output_onnx) - ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity)) - bops, macs = compute_bops_and_macs(ret) - mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(ret, "mem_w") - mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(ret, "mem_o") - ret["total_bops"] = bops - ret["total_macs"] = macs - ret["total_mem_w_bits"] = mem_w_bits - ret["total_mem_w_elems"] = mem_w_elems - ret["total_mem_o_bits"] = mem_o_bits - ret["total_mem_o_elems"] = mem_o_elems - - if "unsupported" in ret: - ret["unsupported"] = str(ret["unsupported"]) - - if output_json is not None: - with open(output_json, "w") as f: - json.dump(ret, f, sort_keys=True, indent=2) - - return ret + ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown)) + for i, res in ret.items(): + if i == "total_cost": + bops, macs = compute_bops_and_macs(res) + res = assign_mem_bits_and_elems(res) + res["total_bops"] = bops + res["total_macs"] = macs + if "unsupported" in res: + res["unsupported"] = str(res["unsupported"]) + if output_json is not None: + with open(output_json, "w") as f: + json.dump(res, f, sort_keys=True, indent=2) + combined_results[i] = res + elif i == "optype_cost": + per_optype_breakdown = {} + for optype, op_res in res.items(): + bops, macs = compute_bops_and_macs(op_res) + op_res = assign_mem_bits_and_elems(op_res) + op_res["total_bops"] = bops + op_res["total_macs"] = macs + per_optype_breakdown[optype] = op_res + combined_results[i] = per_optype_breakdown + else: + per_node_breakdown = {} + for node_name in res.keys(): + node_res = res[node_name] + node_res = assign_mem_bits_and_elems(node_res) + per_node_breakdown[node_name] = node_res + combined_results[i] = per_node_breakdown + return combined_results def main(): diff --git a/src/qonnx/util/l0_performance_estimate.py b/src/qonnx/util/l0_performance_estimate.py new file mode 100644 index 00000000..a3ba9058 --- /dev/null +++ b/src/qonnx/util/l0_performance_estimate.py @@ -0,0 +1,146 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+from qonnx.analysis.l0_resource_estimates import l0_resource_estimates
+
+"""Calculate the estimated amount of resources required for a model (from its inference cost dict).
+    The estimates are divided into two parts:
+    1) CORE: for processing.
+    2) OCM: on-chip memory.
+    First, a memory check is performed to verify that enough memory is available to accommodate the model on the FPGA.
+    Then, inferences per second are calculated from the resources required for processing (CORE).
+    Args:
+        resource_budget (dict): Representing the resources available on the respective FPGA.
+        inf_cost (dict): Inference cost dict.
+        Forwarded to l0_resource_estimates: dsp_type (str), bram_type (str), bwidth_lower_limit (int),
+        bwidth_upper_limit (int), d_factor (float)
+        clock_freq (int): Default 300000000 (300 MHz).
+    Returns:
+        A dictionary containing CORE and OCM resource estimates.
+    Examples:
+        1) est_res_req: {'CORE': {'LUT': 1198735769600.0, 'DSP48': 3450357760.0},
+        'OCM': {'BRAM': 8798, 'URAM': 672}}
+
+        2) resource_budget: {'LUT': 4397752190000, 'BRAM': 1182, 'URAM': 0, 'DSP48': 500000}
+"""
+resource_map = {
+    "res_limit": {
+        "LUT": 0.7,
+        "BRAM": 0.80,
+        "URAM": 0.80,
+        "DSP48": 0.80,
+        "DSP58": 0.80,
+    },
+    "bits_per_res": {"BRAM": 36864, "URAM": 294912, "LUT": 64},
+}
+
+
+def d_fact(resource_budget, bits_per_res, uram_type, bram_type):
+    "Determine the d_factor for l0_resource_estimates."
+    if uram_type == bram_type is None:
+        d_factor = None
+    elif uram_type is None and bram_type is not None:
+        d_factor = 0
+    elif bram_type is None and uram_type is not None:
+        d_factor = 1
+    else:
+        available_bits_uram = resource_budget[uram_type] * bits_per_res[uram_type]
+        available_bits_bram = resource_budget[bram_type] * bits_per_res[bram_type]
+        d_factor = available_bits_uram / (available_bits_uram + available_bits_bram)
+    return d_factor
+
+
+def l0_performance_estimate(
+    resource_budget,
+    inf_cost,
+    bwidth_lower_limit=8,
+    bwidth_upper_limit=32,
+    clock_freq=300000000,
+):
+    expected_inference = {}
+    dsp_type = next((key for key in resource_budget if "DSP" in key), None)
+    bram_type = next((key for key in resource_budget if "BRAM" in key), None)
+    uram_type = next((key for key in resource_budget if "URAM" in key), None)
+    res_limit, bits_per_res = resource_map["res_limit"], resource_map["bits_per_res"]
+    d_factor = d_fact(resource_budget, bits_per_res, uram_type, bram_type)
+    est_res_req = l0_resource_estimates(inf_cost, dsp_type, bwidth_lower_limit, bwidth_upper_limit, d_factor)
+    ocm_res_req, core_res_req = est_res_req["OCM"], est_res_req["CORE"]
+    luts_for_mem = (1 - res_limit["LUT"]) * resource_budget["LUT"]  # LUT share reserved for memory (outside the CORE LUT limit).
+ for type, res in ocm_res_req.items(): + if type == "LUT": + luts_req = res + resource_tally = luts_for_mem - luts_req + if resource_tally >= 0: + luts_for_mem = luts_for_mem - luts_req + memory_check = True + else: + luts_for_mem = 0 + memory_check = False + break + else: + if type in resource_budget.keys(): + resource_tally = (res_limit[type] * resource_budget[type]) - res + if resource_tally >= 0: # do param fit on ocm. + memory_check = True + else: + luts_req = (bits_per_res[type] / bits_per_res["LUT"]) * abs(resource_tally) + resource_tally = (res_limit["LUT"] * luts_for_mem) - luts_req + if resource_tally >= 0: + print(f"{type} out of budget, using luts") + memory_check = True + luts_for_mem = luts_for_mem - luts_req + else: + luts_for_mem = 0 + memory_check = False + break + else: + luts_req = (bits_per_res[type] / bits_per_res["LUT"]) * res + resource_tally = luts_for_mem - luts_req + if resource_tally >= 0: + print(f"{type} not available in the budget, using luts") + luts_for_mem = luts_for_mem - luts_req + memory_check = True + else: + luts_for_mem = 0 + memory_check = False + break + + if memory_check is True: + for i in core_res_req.keys(): + if core_res_req[i] > 0: + inf_sec = ((res_limit[i] * resource_budget[i]) / core_res_req[i]) * clock_freq + expected_inference[i] = inf_sec + else: + continue + min_infc_res = min(expected_inference, key=expected_inference.get) + min_infc_sec = expected_inference[min_infc_res] + ret = (min_infc_res, min_infc_sec) + else: + ret = "Memory out of budget" + return ret diff --git a/tests/analysis/test_inference_cost_breakdown.py b/tests/analysis/test_inference_cost_breakdown.py new file mode 100644 index 00000000..3bca2758 --- /dev/null +++ b/tests/analysis/test_inference_cost_breakdown.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
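For reference, a small sketch of the l0_performance_estimate flow defined above, with made-up budget and cost numbers (no real device or model implied):

from qonnx.util.l0_performance_estimate import l0_performance_estimate

# Hypothetical single-layer cost: 1e6 INT8xINT8 MACs and 8e6 weight bits.
budget = {"LUT": 500000, "BRAM": 1000, "URAM": 100, "DSP48": 2000}
inf_cost = {"op_mac_INT8_INT8": 1.0e6, "total_mem_w_bits": 8.0e6}
# OCM check: the weight bits are split between URAM and BRAM via d_fact and fit within 80% of the budget.
# CORE (INT8 on DSP48): 0.5 DSPs + 100 LUTs per MAC -> 5e5 DSP48s and 1e8 LUTs.
# Throughput per resource is (res_limit * budget / required) * clock_freq; the minimum wins:
#   LUT:   (0.7 * 500000 / 1e8) * 300e6 = 1.05e6 inf/s
#   DSP48: (0.8 * 2000 / 5e5)   * 300e6 = 9.6e5 inf/s  -> bottleneck
print(l0_performance_estimate(budget, inf_cost))  # expected: ("DSP48", 960000.0)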
+ +import pytest + +import os +import urllib.request + +from qonnx.analysis.inference_cost import aggregate_dict_keys +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.util.cleanup import cleanup +from qonnx.util.inference_cost import inference_cost as infca + +download_url = "https://github.com/onnx/models/raw/main/validated/vision/" +download_url += "classification/resnet/model/resnet18-v1-7.onnx?download=" + +model_details = { + "resnet18-v1-7": { + "description": "Resnet18 Opset version 7.", + "url": download_url, + }, +} + + +def download_model(test_model, do_cleanup=False, return_modelwrapper=False): + qonnx_url = model_details[test_model]["url"] + # download test data + dl_dir = "/tmp" + dl_file = dl_dir + f"/{test_model}.onnx" + ret = dl_file + if not os.path.isfile(dl_file): + urllib.request.urlretrieve(qonnx_url, dl_file) + if do_cleanup: + out_file = dl_dir + f"/{test_model}_clean.onnx" + cleanup(dl_file, out_file=out_file, override_inpsize=1) + ret = out_file + if return_modelwrapper: + ret = ModelWrapper(ret) + return ret + + +@pytest.mark.parametrize("test_model", model_details.keys()) +def test_inference_cost_breakdown(test_model): + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + inf_cost = infca(model, discount_sparsity=False, cost_breakdown=True) + t_cost = inf_cost["total_cost"] # total cost + op_cost = aggregate_dict_keys(inf_cost["optype_cost"]) # cost per optype + n_cost = aggregate_dict_keys(inf_cost["node_cost"]) # cost per node. + assert ( + t_cost["op_mac_FLOAT32_FLOAT32"] == op_cost["op_mac_FLOAT32_FLOAT32"] == n_cost["op_mac_FLOAT32_FLOAT32"] + ), "inf discrepancy" + assert t_cost["total_mem_w_bits"] == op_cost["total_mem_w_bits"] == n_cost["total_mem_w_bits"], "inf discrepancy" + assert t_cost["total_mem_w_elems"] == op_cost["total_mem_w_elems"] == n_cost["total_mem_w_elems"], "inf discrepancy" + assert t_cost["total_mem_o_bits"] == op_cost["total_mem_o_bits"] == n_cost["total_mem_o_bits"], "inf discrepancy" + assert t_cost["total_mem_o_elems"] == op_cost["total_mem_o_elems"] == n_cost["total_mem_o_elems"], "inf discrepancy" diff --git a/tests/analysis/test_l0_resource_estimates.py b/tests/analysis/test_l0_resource_estimates.py new file mode 100644 index 00000000..d5a503f5 --- /dev/null +++ b/tests/analysis/test_l0_resource_estimates.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import os +import urllib.request + +from qonnx.analysis.l0_resource_estimates import l0_resource_estimates +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.util.cleanup import cleanup +from qonnx.util.inference_cost import inference_cost + +download_url = "https://github.com/onnx/models/raw/main/validated/vision/" +download_url += "classification/resnet/model/resnet18-v1-7.onnx?download=" + +model_details = { + "resnet18-v1-7": { + "description": "Resnet18 Opset version 7.", + "url": download_url, + }, +} + + +def download_model(test_model, do_cleanup=False, return_modelwrapper=False): + qonnx_url = model_details[test_model]["url"] + # download test data + dl_dir = "/tmp" + dl_file = dl_dir + f"/{test_model}.onnx" + ret = dl_file + if not os.path.isfile(dl_file): + urllib.request.urlretrieve(qonnx_url, dl_file) + if do_cleanup: + out_file = dl_dir + f"/{test_model}_clean.onnx" + cleanup(dl_file, out_file=out_file, override_inpsize=1) + ret = out_file + if return_modelwrapper: + ret = ModelWrapper(ret) + return ret + + +@pytest.mark.parametrize("test_model", model_details.keys()) +def test_l0_resource_estimates(test_model): + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + inf_cost = inference_cost(model, discount_sparsity=False) + infc_dict = inf_cost["total_cost"] + estimates1 = l0_resource_estimates( + infc_dict, + dsp_type="DSP48", + d_factor=0.7, + ) + estimates2 = l0_resource_estimates(infc_dict) # default values for dsp_type, bram, d_fator are None. + estimates3 = l0_resource_estimates( + infc_dict, + dsp_type="DSP58", + d_factor=0.5, + ) + ocm_res1, core_res1 = estimates1["OCM"], estimates1["CORE"] + ocm_res2, core_res2 = estimates2["OCM"], estimates2["CORE"] + ocm_res3, core_res3 = estimates3["OCM"], estimates3["CORE"] + assert infc_dict["total_mem_w_bits"] == (ocm_res1["BRAM"] * 36 + ocm_res1["URAM"] * 288) * 1024, "discrepancy" + assert infc_dict["total_mem_w_bits"] == (ocm_res2["LUT"] * 64), "discrepancy" + assert infc_dict["total_mem_w_bits"] == (ocm_res3["BRAM"] * 36 + ocm_res3["URAM"] * 288) * 1024, "discrepancy" + # 1 mac fp32*fp32 requires 2 DSP48 and 700 LUTs. + assert core_res1["LUT"] == 700 * infc_dict["op_mac_FLOAT32_FLOAT32"], "discrepancy" + assert core_res1["DSP48"] == 2 * infc_dict["op_mac_FLOAT32_FLOAT32"], "discrepancy" + # when dsp_type is None all processing is performed in LUTs. + assert core_res2["LUT"] == 1.1 * 32 * 32 * infc_dict["op_mac_FLOAT32_FLOAT32"], "discrepancy" + # 1 mac fp32*fp32 requires 1 DSP58 and no LUTs. + assert core_res3["DSP58"] == infc_dict["op_mac_FLOAT32_FLOAT32"], "discrepancy" + assert core_res3["LUT"] == 0.0, "discrepancy" diff --git a/tests/analysis/test_l1_resource_estimates.py b/tests/analysis/test_l1_resource_estimates.py new file mode 100644 index 00000000..99a848ca --- /dev/null +++ b/tests/analysis/test_l1_resource_estimates.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from pkgutil import get_data + +from qonnx.analysis.l1_resource_estimates import l1_resource_estimates +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.util.inference_cost import inference_cost + +details = { + "details": { + "description": "Cleaned mnv1 model.", + "resource_budget": {"LUT": 520704, "BRAM": 600, "URAM": 264, "DSP58": 1312}, # ve2802 + "res_match": {"LUT": 0, "BRAM": 0, "URAM": 0, "DSP58": 0}, + "res_limit": {"LUT": 0.70, "BRAM": 0.80, "URAM": 0.80, "DSP58": 0.80}, + } +} + + +@pytest.mark.parametrize("test_details", details.keys()) +def test_l1_resource_estimates(test_details): + model_path = get_data("qonnx", "data/onnx/l1_resource_estimates/network1.onnx") # cleaned mnv1. + wrapped_model = ModelWrapper(model_path) + inf_cost = inference_cost(wrapped_model, discount_sparsity=False, cost_breakdown=True) + total_cost, per_node_cost = inf_cost["total_cost"], inf_cost["node_cost"] + r_bdgt = details[test_details]["resource_budget"] + res_match = details[test_details]["res_match"] + res_limit = details[test_details]["res_limit"] + all_node_res_bdgt = l1_resource_estimates(r_bdgt, total_cost, per_node_cost) + assert len(all_node_res_bdgt) == 28, "Discrepancy" # Total number of layers/nodes in the network. 
+ for name, res in all_node_res_bdgt.items(): + res_match["LUT"] += res["LUT"] + res_match["BRAM"] += res["BRAM"] + res_match["URAM"] += res["URAM"] + res_match["DSP58"] += res["DSP58"] + + assert ( + res_limit["LUT"] * r_bdgt["LUT"] >= res_match["LUT"] and res_match["LUT"] == 3100.0000 + ), "Over Budget or Discrepancy" + assert ( + res_limit["BRAM"] * r_bdgt["BRAM"] >= res_match["BRAM"] and res_match["BRAM"] == 288.0000 + ), "Over Budget and Discrepancy" + assert ( + res_limit["URAM"] * r_bdgt["URAM"] >= res_match["URAM"] and res_match["URAM"] == 0.0000 + ), "Over Budget or Discrepancy" + assert ( + res_limit["DSP58"] * r_bdgt["DSP58"] >= res_match["DSP58"] and res_match["DSP58"] == 1049.5988 + ), "Over Budget or Discrepancy" diff --git a/tests/util/test_l0_performance_estimate.py b/tests/util/test_l0_performance_estimate.py new file mode 100644 index 00000000..5ecbca5a --- /dev/null +++ b/tests/util/test_l0_performance_estimate.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +import os +import urllib.request + +from qonnx.analysis.l0_resource_estimates import l0_resource_estimates +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.util.cleanup import cleanup +from qonnx.util.inference_cost import inference_cost +from qonnx.util.l0_performance_estimate import d_fact, l0_performance_estimate + +download_url = "https://github.com/onnx/models/raw/main/validated/vision/" +download_url += "classification/resnet/model/resnet18-v1-7.onnx?download=" + +model_details = { + "resnet18-v1-7": { + "description": "Resnet18 Opset version 7.", + "url": download_url, + "r_bdgt1": {"LUT": 375000, "BRAM": 1152, "URAM": 300, "DSP48": 2800}, + "r_bdgt2": {"LUT": 375000, "BRAM": 11520, "URAM": 1200, "DSP48": 2800}, # created for the test. 
+ "bits_per_res": {"BRAM": 36864, "URAM": 294912, "LUT": 64}, + "res_limit": { + "LUT": 0.7, + "BRAM": 0.80, + "URAM": 0.80, + "DSP48": 0.80, + "DSP58": 0.80, + }, + }, +} + + +def download_model(test_model, do_cleanup=False, return_modelwrapper=False): + qonnx_url = model_details[test_model]["url"] + # download test data + dl_dir = "/tmp" + dl_file = dl_dir + f"/{test_model}.onnx" + ret = dl_file + if not os.path.isfile(dl_file): + urllib.request.urlretrieve(qonnx_url, dl_file) + if do_cleanup: + out_file = dl_dir + f"/{test_model}_clean.onnx" + cleanup(dl_file, out_file=out_file, override_inpsize=1) + ret = out_file + if return_modelwrapper: + ret = ModelWrapper(ret) + return ret + + +def performance_check(test_model, r_bdgt, core_res_req, clock_freq): + expected_inference = {} + res_limit = model_details[test_model]["res_limit"] + for i in core_res_req: + inf_sec = ((res_limit[i] * r_bdgt[i]) / core_res_req[i]) * clock_freq + expected_inference[i] = inf_sec + min_infc_res = min(expected_inference, key=expected_inference.get) + min_infc_sec = expected_inference[min_infc_res] + ret = (min_infc_res, min_infc_sec) + return ret + + +@pytest.mark.parametrize("test_model", model_details.keys()) +def test_l0_performance_estimate(test_model): + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + inf_cost = inference_cost(model, discount_sparsity=False) + infc_dict = inf_cost["total_cost"] + test_details = model_details[test_model] + bits_per_res = test_details["bits_per_res"] + r_bdgt1 = test_details["r_bdgt1"] + r_bdgt2 = test_details["r_bdgt2"] + per_est1 = l0_performance_estimate( + r_bdgt1, + infc_dict, + ) + d_factor2 = d_fact(r_bdgt2, bits_per_res, uram_type="URAM", bram_type="BRAM") + res_est_req2 = l0_resource_estimates( + infc_dict, + d_factor=d_factor2, + dsp_type="DSP48", + ) + core_res_req2 = res_est_req2["CORE"] + perf_check2 = performance_check(test_model, r_bdgt2, core_res_req2, clock_freq=300000000) + per_est2 = l0_performance_estimate( + r_bdgt2, + infc_dict, + ) + assert per_est1 == "Memory out of budget", "Discrepancy" # memory check fails + assert per_est2 == perf_check2, "Discrepancy" # memory check will pass.