diff --git a/src/finn/transformation/streamline/__init__.py b/src/finn/transformation/streamline/__init__.py
index 2e68de698..39ef87f81 100644
--- a/src/finn/transformation/streamline/__init__.py
+++ b/src/finn/transformation/streamline/__init__.py
@@ -76,8 +76,8 @@ def apply(self, model):
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveMulPastMaxPool(),
-            MoveScalarLinearPastInvariants(),
             AbsorbSignBiasIntoMultiThreshold(),
+            MoveScalarLinearPastInvariants(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
             MoveAddPastConv(),
diff --git a/src/finn/transformation/streamline/absorb.py b/src/finn/transformation/streamline/absorb.py
index 4c280d8f2..447fedf01 100644
--- a/src/finn/transformation/streamline/absorb.py
+++ b/src/finn/transformation/streamline/absorb.py
@@ -30,6 +30,8 @@
 import qonnx.core.data_layout as DataLayout
 import warnings
 from onnx import helper as oh
+# Protobuf onnx graph node type
+from onnx import NodeProto  # noqa
 from qonnx.core.datatype import DataType
 
 # QONNX wrapper of ONNX model graphs
@@ -261,7 +265,7 @@ def apply(self, model):
 
 
 class Absorb1BitMulIntoMatMul(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding matrix
+    """Absorb bipolar or binary multiplications into the preceding matrix
     multiply."""
 
     def apply(self, model):
@@ -270,16 +274,28 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer
+            # is considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in
+            # all branches, but this is currently neither checked nor rewired.
-            if n.op_type == "MatMul":
+            if n.op_type == "MatMul" and not model.is_fork_node(n):
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip MatMuls with missing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for
+                # the initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip Muls with missing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -298,7 +314,7 @@ def apply(self, model):
 
 
 class Absorb1BitMulIntoConv(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding convolution."""
+    """Absorb bipolar or binary multiplications into the preceding convolution."""
 
     def apply(self, model):
         graph = model.graph
@@ -306,16 +322,28 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer
+            # is considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in
+            # all branches, but this is currently neither checked nor rewired.
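+            # Schematically, the absorption below relies on a constant 1-bit
+            # scale A commuting with the convolution (sketch, for a scalar or
+            # suitably broadcast channelwise A):
+            #   Mul(Conv(x, W), A) == Conv(x, A * W)
+            # so the scale can be folded into the conv weights ahead of time.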
-            if n.op_type == "Conv":
+            if n.op_type == "Conv" and not model.is_fork_node(n):
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip Convs with missing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for
+                # the initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip Muls with missing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     is_scalar = np.prod(A.shape) == 1
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 2c54518ed..77ecf15a2 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -116,58 +116,133 @@ def apply(self, model):
         return model, graph_modified
 
 
+# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
+def is_scalar(tensor):
+    return tensor is not None and all(x == 1 for x in tensor.shape)
+
+
+# Tests whether a node is a scalar multiplication with a constant scale factor
+def is_const_scalar_mul(node, model):
+    # Only handle existing Mul type nodes
+    if node is not None and node.op_type == "Mul":
+        # The constant must be an initializer
+        # Note: Assumes the constant parameter to always be the second input
+        scale = model.get_initializer(node.input[1])
+        # Test for existence of a constant scale factor
+        return scale is not None and is_scalar(scale)
+    # Did not match the operator type
+    return False
+
+
+# Refactored version of the MoveScalarMulPastMatMul transform, capable of
+# transforming two-input MatMuls, such as those found in the attention
+# operator
 class MoveScalarMulPastMatMul(Transformation):
     """Move scalar mul operations past matmul operations. We want to have muls
     next to each other such that they can be collapsed into a single mul."""
 
+    # Applies the transform to a whole model graph
     def apply(self, model):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n):
-                consumer = model.find_consumer(n.output[0])
-                if (
-                    consumer is not None
-                    and consumer.op_type == "MatMul"
-                    and not model.is_join_node(consumer)
-                ):
-                    mul_weight_name = n.input[1]
-                    matmul_weight_name = consumer.input[1]
-                    A = model.get_initializer(mul_weight_name)
-                    W = model.get_initializer(matmul_weight_name)
-                    if (A is None) or (W is None):
-                        warnings.warn("MatMul or Mul params are not constant, skipping")
+
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # First pattern matching condition: For the transform to be
+            # applicable, the node has to be a MatMul operator
+            if node.op_type == "MatMul":
+                # Note: When touching the following code, remember to treat
+                # both branches equivalently!
+                # TODO: Can this be enforced or at least be made easier by
+                # extracting common code patterns to a function?
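+                # Schematically, with s a scalar constant, the rewrite
+                # performed below for either input side is:
+                #   MatMul(Mul(x, s), y)  ==>  Mul(MatMul(x, y), s)
+                #   MatMul(x, Mul(y, s))  ==>  Mul(MatMul(x, y), s)
+                # which is valid since a scalar factor commutes with the
+                # matrix product.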
+
+                # Get the left hand side and right hand side inputs
+                # Note: Assumes the ordering of left to right inputs to match
+                # indices 0 to 1. However, it does not "hurt" if it is
+                # reversed as both sides are treated equivalently.
+                lhs = model.find_producer(node.input[0])
+                rhs = model.find_producer(node.input[1])
+
+                # Give precedence to the left hand side input testing for the
+                # presence of a scalar multiplication
+                if is_const_scalar_mul(lhs, model):
+                    # Cannot handle fork nodes: We would have to distribute
+                    # the Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
+                    if model.is_fork_node(lhs):
+                        # Softly skip this node
                         continue
-                    start_name = n.input[0]
-                    middle_name = n.output[0]
-                    end_name = consumer.output[0]
-                    mm_out_shape = model.get_tensor_shape(end_name)
-                    if all(x == 1 for x in A.shape):
-                        # if the mul is scalar, we can simply swap the order of ops
-                        # make and insert new nodes
-                        new_matmul = oh.make_node(
-                            "MatMul",
-                            [start_name, matmul_weight_name],
-                            [middle_name],
-                            name=consumer.name,
-                        )
-                        new_mul = oh.make_node(
-                            "Mul",
-                            [middle_name, mul_weight_name],
-                            [end_name],
-                            name=n.name,
-                        )
-                        graph.node.insert(node_ind, new_matmul)
-                        graph.node.insert(node_ind + 1, new_mul)
-                        model.set_tensor_shape(middle_name, mm_out_shape)
-                        # remove old nodes
-                        graph.node.remove(n)
-                        graph.node.remove(consumer)
-                        graph_modified = True
+                    # Unpack the connection pattern of a scalar mul feeding
+                    # the lhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul
+                    # complex
+                    a, b, c = lhs.input[0], lhs.input[1], node.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = lhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul
+                    # and matmul operator order
+                    matmul = oh.make_node("MatMul", [a, c], [m], node.name)
+                    mul = oh.make_node("Mul", [m, b], [o], lhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(lhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True
+                    # Cannot further modify the node (i.e., the rhs) as the
+                    # index and state of the nodes changed and need to be
+                    # queried again from the graph.node at the start of the
+                    # next iteration.
+                    continue
+
+                # Next try whether the right hand side matches the pattern of
+                # a scalar multiplication
+                if is_const_scalar_mul(rhs, model):
+                    # Cannot handle fork nodes: We would have to distribute
+                    # the Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
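+                    # Concretely, a fork here means the scalar mul output
+                    # feeds the MatMul and at least one other consumer;
+                    # moving the mul past only the MatMul would change the
+                    # value seen by that other consumer.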
+                    if model.is_fork_node(rhs):
+                        # Softly skip this node
+                        continue
+                    # Unpack the connection pattern of a scalar mul feeding
+                    # the rhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul
+                    # complex
+                    a, b, c = node.input[0], rhs.input[0], rhs.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = rhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul
+                    # and matmul operator order
+                    matmul = oh.make_node("MatMul", [a, b], [m], node.name)
+                    mul = oh.make_node("Mul", [m, c], [o], rhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(rhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True
+
+        # Finalize the transformation by inferring shapes again (as these
+        # might have changed)
         model = model.transform(InferShapes())
-        return (model, graph_modified)
+        # Return the transformed model and indicate whether the graph
+        # actually has been transformed
+        return model, graph_modified
 
 
 class MoveScalarAddPastMatMul(Transformation):
@@ -617,6 +692,7 @@ def apply(self, model):
                 graph_modified = True
             else:
                 continue
+
         # Note: Running shape inference is necessary as shape annotations
         # have been deleted above
         model = model.transform(InferShapes())
@@ -634,6 +710,7 @@ class MoveScalarLinearPastInvariants(Transformation):
     GlobalAveragePool
     """
 
+    # Op-types of currently supported invariants
     SUPPORTED_INVARIANTS = {
         "GlobalAveragePool",
diff --git a/tests/transformation/streamline/test_move_scalar_past_matmul.py b/tests/transformation/streamline/test_move_scalar_past_matmul.py
index e4f4357ff..515e9b946 100644
--- a/tests/transformation/streamline/test_move_scalar_past_matmul.py
+++ b/tests/transformation/streamline/test_move_scalar_past_matmul.py
@@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul():
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
 
 
+@pytest.mark.streamline
+def test_move_scalar_mul_past_join_matmul():
+    top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2])
+    top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1])
+    mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1])
+    mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1])
+    modelproto = qonnx_make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in1, top_in2],
+            outputs=[top_out],
+            value_info=[mul1_param, mul2_param],
+            nodes=[
+                oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]),
+                oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]),
+                oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32))
+    model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32))
+    new_model = model.transform(MoveScalarMulPastMatMul())
+    inp_dict = {
+        "top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32),
+        "top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32),
+    }
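+    # Expected outcome (sketch): the join MatMul now comes first and both
+    # scalar muls have been moved past it, i.e. MatMul -> Mul -> Mul, which
+    # a subsequent CollapseRepeatedMul pass could merge into a single Mul.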
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+    assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0]
+
+
 @pytest.mark.streamline
 def test_move_scalar_add_past_matmul():
     top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])