Skip to content

Commit

Permalink
[Test] fix changes return style for inference cost
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar committed May 21, 2024
1 parent 0ca12ce commit a4e7e35
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 76 deletions.
2 changes: 1 addition & 1 deletion src/qonnx/util/inference_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def inference_cost(
if "unsupported" in res:
res["unsupported"] = str(res["unsupported"])
combined_results[i] = res
else:
elif i in ["optype_cost", "node_cost"]:
per_optype_or_node_breakdown = {}
for optype, op_res in res.items():
bops, macs = compute_bops_and_macs(op_res)
Expand Down
152 changes: 82 additions & 70 deletions tests/analysis/test_inference_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,90 +34,102 @@
# Expected inference_cost() results per test model. Each model maps
# "expected_sparse" / "expected_dense" (with / without sparsity discounting)
# to the cost dict found under the "total_cost" key of the analysis output.
_cnv_w2a2_sparse_cost = {
    "op_mac_SCALEDINT<8>_INT2": 1345500.0,
    "mem_w_INT2": 908033.0,
    "mem_o_SCALEDINT<32>": 57600.0,
    "op_mac_INT2_INT2": 35615771.0,
    "mem_o_INT32": 85002.0,
    "unsupported": "set()",
    "discount_sparsity": True,
    "total_bops": 163991084.0,
    "total_macs": 36961271.0,
    "total_mem_w_bits": 1816066.0,
    "total_mem_w_elems": 908033.0,
    "total_mem_o_bits": 4563264.0,
    "total_mem_o_elems": 142602.0,
}

_cnv_w2a2_dense_cost = {
    "op_mac_SCALEDINT<8>_INT2": 1555200.0,
    "mem_w_INT2": 1542848.0,
    "mem_o_SCALEDINT<32>": 57600.0,
    "op_mac_INT2_INT2": 57906176.0,
    "mem_o_INT32": 85002.0,
    "unsupported": "set()",
    "discount_sparsity": False,
    "total_bops": 256507904.0,
    "total_macs": 59461376.0,
    "total_mem_w_bits": 3085696.0,
    "total_mem_w_elems": 1542848.0,
    "total_mem_o_bits": 4563264.0,
    "total_mem_o_elems": 142602.0,
}

_tfc_w2a2_sparse_cost = {
    "op_mac_INT2_INT2": 22355.0,
    "mem_w_INT2": 22355.0,
    "mem_o_INT32": 202.0,
    "unsupported": "set()",
    "discount_sparsity": True,
    "total_bops": 89420.0,
    "total_macs": 22355.0,
    "total_mem_w_bits": 44710.0,
    "total_mem_w_elems": 22355.0,
    "total_mem_o_bits": 6464.0,
    "total_mem_o_elems": 202.0,
}

_tfc_w2a2_dense_cost = {
    "op_mac_INT2_INT2": 59008.0,
    "mem_w_INT2": 59008.0,
    "mem_o_INT32": 202.0,
    "unsupported": "set()",
    "discount_sparsity": False,
    "total_bops": 236032.0,
    "total_macs": 59008.0,
    "total_mem_w_bits": 118016.0,
    "total_mem_w_elems": 59008.0,
    "total_mem_o_bits": 6464.0,
    "total_mem_o_elems": 202.0,
}

_vgg10_sparse_cost = {
    "unsupported": "set()",
    "discount_sparsity": True,
    "op_mac_SCALEDINT<8>_SCALEDINT<8>": 12620311.0,
    "mem_w_SCALEDINT<8>": 155617.0,
    "mem_o_SCALEDINT<32>": 130328.0,
    "total_bops": 807699904.0,
    "total_macs": 12620311.0,
    "total_mem_w_bits": 1244936.0,
    "total_mem_w_elems": 155617.0,
    "total_mem_o_bits": 4170496.0,
    "total_mem_o_elems": 130328.0,
}

_vgg10_dense_cost = {
    "unsupported": "set()",
    "discount_sparsity": False,
    "op_mac_SCALEDINT<8>_SCALEDINT<8>": 12864512.0,
    "mem_w_SCALEDINT<8>": 159104.0,
    "mem_o_SCALEDINT<32>": 130328.0,
    "total_bops": 823328768.0,
    "total_macs": 12864512.0,
    "total_mem_w_bits": 1272832.0,
    "total_mem_w_elems": 159104.0,
    "total_mem_o_bits": 4170496.0,
    "total_mem_o_elems": 130328.0,
}

model_details_infcost = {
    "FINN-CNV_W2A2": {
        "expected_sparse": {"total_cost": _cnv_w2a2_sparse_cost},
        "expected_dense": {"total_cost": _cnv_w2a2_dense_cost},
    },
    "FINN-TFC_W2A2": {
        "expected_sparse": {"total_cost": _tfc_w2a2_sparse_cost},
        "expected_dense": {"total_cost": _tfc_w2a2_dense_cost},
    },
    "RadioML_VGG10": {
        "expected_sparse": {"total_cost": _vgg10_sparse_cost},
        "expected_dense": {"total_cost": _vgg10_dense_cost},
    },
}
Expand Down
2 changes: 1 addition & 1 deletion tests/analysis/test_matmul_mac_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ def test_matmul_mac_cost():
cleaned_model = cleanup_model(model)
# Two Matmul layers with shape (i_shape, w_shape, o_shape),
# L1: ([4, 64, 32], [4, 32, 64], [4, 64, 64]) and L2: ([4, 64, 64], [4, 64, 32], [4, 64, 32])
inf_cost_dict = infc.inference_cost(cleaned_model, discount_sparsity=False)
inf_cost_dict = infc.inference_cost(cleaned_model, discount_sparsity=False)["total_cost"]
mac_cost = inf_cost_dict["op_mac_FLOAT32_FLOAT32"] # Expected mac cost 4*32*64*64 + 4*64*64*32 = 1048576
assert mac_cost == 1048576.0, "Error: discrepancy in mac cost."
4 changes: 2 additions & 2 deletions tests/transformation/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_pruning_mnv1():
# do cleanup including folding quantized weights
model = cleanup_model(model, False)
inp, golden = get_golden_in_and_output("MobileNetv1-w4a4")
cost0 = inference_cost(model, discount_sparsity=False)
cost0 = inference_cost(model, discount_sparsity=False)["total_cost"]
assert cost0["op_mac_SCALEDINT<8>_SCALEDINT<8>"] == 10645344.0
assert cost0["mem_w_SCALEDINT<8>"] == 864.0
assert cost0["op_mac_SCALEDINT<4>_SCALEDINT<4>"] == 556357408.0
Expand All @@ -105,7 +105,7 @@ def test_pruning_mnv1():
}

model = model.transform(PruneChannels(prune_spec))
cost1 = inference_cost(model, discount_sparsity=False)
cost1 = inference_cost(model, discount_sparsity=False)["total_cost"]
assert cost1["op_mac_SCALEDINT<8>_SCALEDINT<8>"] == 7318674.0
assert cost1["mem_w_SCALEDINT<8>"] == 594.0
assert cost1["op_mac_SCALEDINT<4>_SCALEDINT<4>"] == 546053216.0
Expand Down
4 changes: 2 additions & 2 deletions tests/transformation/test_quantize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def to_verify(model, test_details):
def test_quantize_graph(test_model):
test_details = model_details[test_model]
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
original_model_inf_cost = inference_cost(model, discount_sparsity=False)
original_model_inf_cost = inference_cost(model, discount_sparsity=False)["total_cost"]
nodes_pos = test_details["test_input"]
model = model.transform(QuantizeGraph(nodes_pos))
quantnodes_added = len(model.get_nodes_by_op_type("Quant"))
assert quantnodes_added == 10 # 10 positions are specified.
verification = to_verify(model, nodes_pos)
assert verification == "Success"
inf_cost = inference_cost(model, discount_sparsity=False)
inf_cost = inference_cost(model, discount_sparsity=False)["total_cost"]
assert (
inf_cost["total_macs"] == original_model_inf_cost["total_macs"]
) # "1814073344.0" must be same as the original model.
Expand Down

0 comments on commit a4e7e35

Please sign in to comment.