
Commit

Merge pull request #12 from eki-project/feature/attention-streamline
Streamlining of Scaled Dot-Product Attention
iksnagreb authored Feb 6, 2025
2 parents a0b9007 + 95ed158 commit 1e3085f
Showing 4 changed files with 194 additions and 52 deletions.
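The streamlining relies on scalar multiplications commuting with matrix multiplication. As a minimal sketch of that identity (illustration only, not part of the commit; numpy stands in for the ONNX operators):

import numpy as np

# Query/key-like inputs of an attention head (shapes are arbitrary examples)
q = np.random.rand(1, 4, 8).astype(np.float32)
k = np.random.rand(1, 8, 4).astype(np.float32)
# Constant scalar scale factors, e.g., quantization scales
a, b = np.float32(0.5), np.float32(0.25)

# Moving both scalar Muls past the two-input MatMul leaves them adjacent,
# where they can later be collapsed into a single Mul
assert np.allclose((a * q) @ (b * k), (a * b) * (q @ k))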
2 changes: 1 addition & 1 deletion src/finn/transformation/streamline/__init__.py
@@ -76,8 +76,8 @@ def apply(self, model):
BatchNormToAffine(),
ConvertSignToThres(),
MoveMulPastMaxPool(),
MoveScalarLinearPastInvariants(),
AbsorbSignBiasIntoMultiThreshold(),
MoveScalarLinearPastInvariants(),
MoveAddPastMul(),
MoveScalarAddPastMatMul(),
MoveAddPastConv(),
44 changes: 36 additions & 8 deletions src/finn/transformation/streamline/absorb.py
@@ -30,6 +30,10 @@
import qonnx.core.data_layout as DataLayout
import warnings
from onnx import helper as oh
# Protobuf onnx graph node type
from onnx import NodeProto # noqa
# QONNX wrapper of ONNX model graphs
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType

# QONNX wrapper of ONNX model graphs
@@ -261,7 +265,7 @@ def apply(self, model):


class Absorb1BitMulIntoMatMul(Transformation):
"""Absorb bipolar or binary multiplications into the preciding matrix
"""Absorb bipolar or binary multiplications into the preceding matrix
multiply."""

def apply(self, model):
@@ -270,16 +274,28 @@ def apply(self, model):
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MatMul":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is not checked nor rewired at all right now.
if n.op_type == "MatMul" and not model.is_fork_node(n):
matmul_weight_name = n.input[1]
W = model.get_initializer(matmul_weight_name)
Wdt = model.get_tensor_datatype(matmul_weight_name)
assert W is not None, "Initializer for matmul weights is not set."
# Just skip matmuls with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
if is_1bit:
Wnew = A * W
@@ -298,24 +314,36 @@ def apply(self, model):


class Absorb1BitMulIntoConv(Transformation):
"""Absorb bipolar or binary multiplications into the preciding convolution."""
"""Absorb bipolar or binary multiplications into the preceding convolution."""

def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Conv":
# Note: Join-node test is implicitly covered by testing for the
# initializer below
# Note: This cannot handle fork-nodes, as only the first consumer is
# considered below.
# TODO: Fork-nodes could be handled if the muls are the same in all
# branches, but this is not checked nor rewired at all right now.
if n.op_type == "Conv" and not model.is_fork_node(n):
conv_weight_name = n.input[1]
W = model.get_initializer(conv_weight_name)
Wdt = model.get_tensor_datatype(conv_weight_name)
assert W is not None, "Initializer for conv weights is not set."
# Just skip convs with non-existing weight initializers
if W is None:
continue
consumer = model.find_consumer(n.output[0])
# Note: Join-node test is implicitly covered by testing for the
# initializer below
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None, "Initializer for mul weights is not set."
# Just skip muls with non-existing scale initializers
if A is None:
continue
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
is_scalar = np.prod(A.shape) == 1
actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
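Both absorption transforms fold the 1-bit scale into the preceding weights via Wnew = A * W. A minimal sketch of why this is sound (illustration only, not from the diff; assumes the scale broadcasts along the output dimension):

import numpy as np

x = np.random.rand(1, 4).astype(np.float32)
W = np.random.rand(4, 3).astype(np.float32)
A = np.asarray([[1.0, -1.0, 1.0]], dtype=np.float32)  # bipolar, 1-bit scale

# Scaling the MatMul output is equivalent to scaling the weight columns,
# so the Mul node can be absorbed into the MatMul weight initializer
assert np.allclose((x @ W) * A, x @ (A * W))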
163 changes: 120 additions & 43 deletions src/finn/transformation/streamline/reorder.py
@@ -116,58 +116,133 @@ def apply(self, model):
return model, graph_modified


# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
def is_scalar(tensor):
return tensor is not None and all(x == 1 for x in tensor.shape)


# Tests whether a node is a scalar multiplication with a constant scale factor
def is_const_scalar_mul(node, model):
# Only handle existing Mul type nodes
if node is not None and node.op_type == "Mul":
# The constant must be an initializer
# Note: Assumes the constant parameter to always be the second input
scale = model.get_initializer(node.input[1])
# Test for existence of a constant scale factor
return scale is not None and is_scalar(scale)
# Did not match the operator type
return False
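# Example (sketch, not in the diff): is_scalar accepts any all-ones shape,
# e.g., np.ones((1, 1, 1)) passes while np.ones((1, 2)) does not, and
# is_const_scalar_mul additionally requires the scale to be an initializer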


# Refactored version of the MoveScalarMulPastMatMul transform capable of
# transforming two-input MatMul, such as those found in the attention operator
class MoveScalarMulPastMatMul(Transformation):
"""Move scalar mul operations past matmul operations. We want to have muls
next to each other such that they can be collapsed into a single mul."""

# Applies the transform to a whole model graph
def apply(self, model):
# Get the model graph out of the model wrapper object
graph = model.graph
node_ind = 0
# Keep track of whether the graph has been modified
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n):
consumer = model.find_consumer(n.output[0])
if (
consumer is not None
and consumer.op_type == "MatMul"
and not model.is_join_node(consumer)
):
mul_weight_name = n.input[1]
matmul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
W = model.get_initializer(matmul_weight_name)
if (A is None) or (W is None):
warnings.warn("MatMul or Mul params are not constant, skipping")

# Iterate all nodes in the graph keeping track of the index
for index, node in enumerate(graph.node):
# First pattern matching condition: For the transform to be
# applicable, the node has to be a MatMul operator
if node.op_type == "MatMul":
# Note: When touching the following code, remember to treat both
# branches equivalently!
# TODO: Can this be enforced or at least be made easier by
# extracting common code patterns to a function?

# Get the left hand side and right hand side inputs
# Note: Assumes the ordering of left to right inputs to match
# indices 0 to 1. However, it does not "hurt" if it is
# reversed as both sides are treated equivalently.
lhs = model.find_producer(node.input[0])
rhs = model.find_producer(node.input[1])

# Give precedence to the left hand side input testing for the
# presence of a scalar multiplication
if is_const_scalar_mul(lhs, model):
# Cannot handle fork nodes: We would have to distribute the
# Mul into all branches
# TODO: Maybe reconsider this at some point, there is
# probably nothing preventing this in general, it is just
# more difficult and apparently not necessary right now.
if model.is_fork_node(lhs):
# Softly skip this node
continue
start_name = n.input[0]
middle_name = n.output[0]
end_name = consumer.output[0]
mm_out_shape = model.get_tensor_shape(end_name)
if all(x == 1 for x in A.shape):
# if the mul is scalar, we can simply swap the order of ops
# make and insert new nodes
new_matmul = oh.make_node(
"MatMul",
[start_name, matmul_weight_name],
[middle_name],
name=consumer.name,
)
new_mul = oh.make_node(
"Mul",
[middle_name, mul_weight_name],
[end_name],
name=n.name,
)
graph.node.insert(node_ind, new_matmul)
graph.node.insert(node_ind + 1, new_mul)
model.set_tensor_shape(middle_name, mm_out_shape)
# remove old nodes
graph.node.remove(n)
graph.node.remove(consumer)
graph_modified = True
# Unpack the connection pattern of a scalar mul feeding the
# lhs input of the matmul
# Names of the three input tensors to the mul-matmul complex
a, b, c = lhs.input[0], lhs.input[1], node.input[1]
# Names of the intermediate and the global output
m, o = lhs.output[0], node.output[0] # noqa: Duplicate code
# Rewire the operator connections locally, swapping mul and
# matmul operator order
matmul = oh.make_node("MatMul", [a, c], [m], node.name)
mul = oh.make_node("Mul", [m, b], [o], lhs.name)
# Insert the rewired nodes into the graph
graph.node.insert(index, matmul)
graph.node.insert(index + 1, mul)
# Adapt the shape of the intermediate tensor as it changed
# according to the output shape of the matmul
model.set_tensor_shape(m, model.get_tensor_shape(o))
# Remove the old nodes from the graph
graph.node.remove(lhs)
graph.node.remove(node)
# The graph has been modified, this needs to be reported
# back to the caller
graph_modified = True
# Cannot further modify the node (i.e., the rhs) as the
# index and state of the nodes changed and need to be
# queried again from the graph.node at the start of the next
# iteration.
continue

# Next try whether the right hand side matches the pattern of a
# scalar multiplication
if is_const_scalar_mul(rhs, model):
# Cannot handle fork nodes: We would have to distribute the
# Mul into all branches
# TODO: Maybe reconsider this at some point, there is
# probably nothing preventing this in general, it is just
# more difficult and apparently not necessary right now.
if model.is_fork_node(rhs):
# Softly skip this node
continue
# Unpack the connection pattern of a scalar mul feeding the
# rhs input of the matmul
# Names of the three input tensors to the mul-matmul complex
a, b, c = node.input[0], rhs.input[0], rhs.input[1]
# Names of the intermediate and the global output
m, o = rhs.output[0], node.output[0] # noqa: Duplicate code
# Rewire the operator connections locally, swapping mul and
# matmul operator order
matmul = oh.make_node("MatMul", [a, b], [m], node.name)
mul = oh.make_node("Mul", [m, c], [o], rhs.name)
# Insert the rewired nodes into the graph
graph.node.insert(index, matmul)
graph.node.insert(index + 1, mul)
# Adapt the shape of the intermediate tensor as it changed
# according to the output shape of the matmul
model.set_tensor_shape(m, model.get_tensor_shape(o))
# Remove the old nodes from the graph
graph.node.remove(rhs)
graph.node.remove(node)
# The graph has been modified, this needs to be reported
# back to the caller
graph_modified = True

# Finalize the transformation by inferring shapes again (as these might
# have changed)
model = model.transform(InferShapes())
return (model, graph_modified)
# Return the transformed model and indicate whether the graph actually
# has been transformed
return model, graph_modified


class MoveScalarAddPastMatMul(Transformation):
@@ -617,6 +692,7 @@ def apply(self, model):
graph_modified = True
else:
continue

# Note: Running shape inference is necessary as shape annotations have
# been deleted above
model = model.transform(InferShapes())
@@ -634,6 +710,7 @@ class MoveScalarLinearPastInvariants(Transformation):
GlobalAveragePool
"""

# Op-types of currently supported invariants
SUPPORTED_INVARIANTS = {
"GlobalAveragePool",
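For context, applying the refactored transform follows the usual FINN pattern (usage sketch; the import path is inferred from the file changed above, and the model file name is a placeholder):

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.streamline.reorder import MoveScalarMulPastMatMul

# Load an ONNX model containing scalar Mul -> MatMul patterns, e.g., an
# exported attention operator ("attention.onnx" is a hypothetical file)
model = ModelWrapper("attention.onnx")
model = model.transform(MoveScalarMulPastMatMul())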
37 changes: 37 additions & 0 deletions tests/transformation/streamline/test_move_scalar_past_matmul.py
@@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul():
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]


@pytest.mark.streamline
def test_move_scalar_mul_past_join_matmul():
top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2])
top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1])
mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1])
mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1])
top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1])
modelproto = qonnx_make_model(
oh.make_graph(
name="test",
inputs=[top_in1, top_in2],
outputs=[top_out],
value_info=[mul1_param, mul2_param],
nodes=[
oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]),
oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]),
oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]),
],
)
)
model = ModelWrapper(modelproto)
model = model.transform(InferShapes())
model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32))
model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32))
new_model = model.transform(MoveScalarMulPastMatMul())
inp_dict = {
"top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32),
"top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32),
}
assert ox.compare_execution(model, new_model, inp_dict)
assert new_model.graph.node[0].op_type == "MatMul"
assert new_model.graph.node[1].op_type == "Mul"
assert new_model.graph.node[2].op_type == "Mul"
assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0]


@pytest.mark.streamline
def test_move_scalar_add_past_matmul():
top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])
