From 5ce18b427df888f223cf199ef99e7077c30527c3 Mon Sep 17 00:00:00 2001 From: joannapng Date: Mon, 24 Jun 2024 12:56:41 +0300 Subject: [PATCH 1/2] find bias quant initializer and delete remaining bias quant nodes --- src/qonnx/transformation/extract_conv_bias.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..cfaed135 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -33,14 +33,10 @@ class ExtractBiasFromConv(Transformation): - """ - Extracts the (optional) Bias from a Conv(Transpose) node and inserts it behind the - Conv(Transpose) node as an Add node. - """ - def apply(self, model): graph = model.graph node_ind = 0 + nodes_to_remove = [] for n in graph.node: node_ind += 1 if n.op_type in ["Conv", "ConvTranspose"]: @@ -49,9 +45,25 @@ def apply(self, model): # Extract bias bias = model.get_initializer(n.input[2]) if bias is None: - warnings.warn(f"Could not extract bias from node {n}") - continue + name = n.input[2] + name = name.replace("out", "param") + bias = model.get_initializer(name) + if bias is None: + warnings.warn(f"Could not extract bias from node") + continue + # Find the node that provides this input + bias_input_name = n.input[2] + producer_node = None + for pn in graph.node: + if bias_input_name in pn.output: + producer_node = pn + break + + if producer_node is not None: + # Mark the producer node for removal + nodes_to_remove.append(producer_node) + # Insert bias as Add node behind the Conv node out_shape = model.get_tensor_shape(n.output[0]) # Reshape bias tensor @@ -80,6 +92,9 @@ def apply(self, model): # Repoint Conv output and remove bias tensor n.output[0] = act_add_tensor.name n.input.remove(n.input[2]) + + for node_to_remove in nodes_to_remove: + graph.node.remove(node_to_remove) return model, True From 34ea9c319861681547777be4628652459df41139 Mon Sep 17 00:00:00 2001 From: joannapng Date: Mon, 24 Jun 2024 13:10:34 +0300 Subject: [PATCH 2/2] use producer node instead of name convention --- src/qonnx/transformation/extract_conv_bias.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index cfaed135..0bac74dd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -45,24 +45,15 @@ def apply(self, model): # Extract bias bias = model.get_initializer(n.input[2]) if bias is None: - name = n.input[2] - name = name.replace("out", "param") - bias = model.get_initializer(name) + producer = model.find_producer(n.input[2]) + bias = model.get_initializer(producer.input[0]) if bias is None: warnings.warn(f"Could not extract bias from node") continue - # Find the node that provides this input - bias_input_name = n.input[2] - producer_node = None - for pn in graph.node: - if bias_input_name in pn.output: - producer_node = pn - break - - if producer_node is not None: + if producer is not None: # Mark the producer node for removal - nodes_to_remove.append(producer_node) + nodes_to_remove.append(producer) # Insert bias as Add node behind the Conv node out_shape = model.get_tensor_shape(n.output[0])