diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..0bac74dd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -33,14 +33,10 @@ class ExtractBiasFromConv(Transformation): - """ - Extracts the (optional) Bias from a Conv(Transpose) node and inserts it behind the - Conv(Transpose) node as an Add node. - """ - def apply(self, model): graph = model.graph node_ind = 0 + nodes_to_remove = [] for n in graph.node: node_ind += 1 if n.op_type in ["Conv", "ConvTranspose"]: @@ -49,9 +45,16 @@ def apply(self, model): # Extract bias bias = model.get_initializer(n.input[2]) if bias is None: - warnings.warn(f"Could not extract bias from node {n}") - continue + producer = model.find_producer(n.input[2]) + bias = model.get_initializer(producer.input[0]) + if bias is None: + warnings.warn(f"Could not extract bias from node") + continue + if producer is not None: + # Mark the producer node for removal + nodes_to_remove.append(producer) + # Insert bias as Add node behind the Conv node out_shape = model.get_tensor_shape(n.output[0]) # Reshape bias tensor @@ -80,6 +83,9 @@ def apply(self, model): # Repoint Conv output and remove bias tensor n.output[0] = act_add_tensor.name n.input.remove(n.input[2]) + + for node_to_remove in nodes_to_remove: + graph.node.remove(node_to_remove) return model, True