Streamlining of Scaled Dot-Product Attention #12

Merged · 20 commits · merged Feb 6, 2025
Commits (20) — changes shown from all commits
fdd89a6
[Streamline] Prefer AbsorbSignBiasIntoMultiThreshold transform
iksnagreb Sep 30, 2023
be33bbc
[Streamline] Refactor MoveScalarMulPastMatMul to handle join-node matmul
iksnagreb Sep 30, 2023
09c1993
Remove misplaced/outdated comment
iksnagreb Sep 30, 2023
9dade0c
[Streamline] Soften initializer tests in Absorb1BitMulIntoMatMul/Conv
iksnagreb Sep 30, 2023
8bae5d7
Address some linting issues
iksnagreb Oct 19, 2023
b22ebe3
[Tests] Add test for MoveScalarMulPastMatMul handling join nodes
iksnagreb Oct 19, 2023
c10fa1d
[Deps] Update qonnx version to include FoldTransposeIntoQuantInit fix
iksnagreb Oct 27, 2023
475a27b
[Streamline] Fix FoldQuantWeights input order and shape annotations
iksnagreb Nov 13, 2023
bd6a8f8
[Streamline] Fix AbsorbAddIntoMultiThreshold assumed input order
iksnagreb Nov 13, 2023
1f7dd4c
[Streamline] Add support for Slice to MoveScalarLinearPastInvariants
iksnagreb Nov 15, 2023
b3e50d7
[Streamline] Absorb1BitMulIntoMatMul/Conv does not handle fork-nodes
iksnagreb Nov 17, 2023
0413368
[Deps] Temporarily switch qonnx to my fork including necessary fixes
iksnagreb Nov 17, 2023
2bf7949
Make quantized activation handlers data layout aware
iksnagreb Nov 20, 2023
8783fd4
[Deps] Update qonnx
iksnagreb Nov 20, 2023
2bf37f1
[Deps] Update qonnx
iksnagreb Dec 13, 2023
a4fc498
[Deps] Update qonnx
iksnagreb Mar 13, 2024
6c56382
Fix some typos
iksnagreb Apr 4, 2024
15a9daa
Merge remote-tracking branch 'xilinx/dev' into feature/attention-stre…
iksnagreb Jan 20, 2025
311ac68
Merge remote-tracking branch 'eki-project/dev' into feature/attention…
iksnagreb Feb 6, 2025
95ed158
Merge branch 'dev' into feature/attention-streamline
iksnagreb Feb 6, 2025
2 changes: 1 addition & 1 deletion src/finn/transformation/streamline/__init__.py
@@ -76,8 +76,8 @@ def apply(self, model):
             BatchNormToAffine(),
             ConvertSignToThres(),
             MoveMulPastMaxPool(),
-            MoveScalarLinearPastInvariants(),
             AbsorbSignBiasIntoMultiThreshold(),
+            MoveScalarLinearPastInvariants(),
             MoveAddPastMul(),
             MoveScalarAddPastMatMul(),
             MoveAddPastConv(),
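The reordering above runs AbsorbSignBiasIntoMultiThreshold before MoveScalarLinearPastInvariants (cf. the commit "[Streamline] Prefer AbsorbSignBiasIntoMultiThreshold transform"), so sign-bias patterns are matched before scalar linear ops get moved around. Since the transforms in this list are applied strictly in sequence, ordering matters; a minimal usage sketch, with a hypothetical model path:

```python
# Minimal sketch: Streamline() applies the transformations in list order.
# "attention.onnx" is a placeholder path, not part of this PR.
from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.streamline import Streamline

model = ModelWrapper("attention.onnx")
model = model.transform(Streamline())  # runs the reordered transform list
```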
44 changes: 36 additions & 8 deletions src/finn/transformation/streamline/absorb.py
@@ -30,6 +30,10 @@
 import qonnx.core.data_layout as DataLayout
 import warnings
 from onnx import helper as oh
+# Protobuf onnx graph node type
+from onnx import NodeProto  # noqa
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.core.datatype import DataType

 # QONNX wrapper of ONNX model graphs
@@ -261,7 +265,7 @@ def apply(self, model):


 class Absorb1BitMulIntoMatMul(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding matrix
+    """Absorb bipolar or binary multiplications into the preceding matrix
     multiply."""

     def apply(self, model):
@@ -270,16 +274,28 @@ def apply(self, model):
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "MatMul":
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer is
+            # considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in all
+            # branches, but this is not checked nor rewired at all right now.
+            if n.op_type == "MatMul" and not model.is_fork_node(n):
                 matmul_weight_name = n.input[1]
                 W = model.get_initializer(matmul_weight_name)
                 Wdt = model.get_tensor_datatype(matmul_weight_name)
-                assert W is not None, "Initializer for matmul weights is not set."
+                # Just skip matmuls with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for the
+                # initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     if is_1bit:
                         Wnew = A * W
@@ -298,24 +314,36 @@ def apply(self, model):


 class Absorb1BitMulIntoConv(Transformation):
-    """Absorb bipolar or binary multiplications into the preciding convolution."""
+    """Absorb bipolar or binary multiplications into the preceding convolution."""

     def apply(self, model):
         graph = model.graph
         node_ind = 0
         graph_modified = False
         for n in graph.node:
             node_ind += 1
-            if n.op_type == "Conv":
+            # Note: Join-node test is implicitly covered by testing for the
+            # initializer below
+            # Note: This cannot handle fork-nodes, as only the first consumer is
+            # considered below.
+            # TODO: Fork-nodes could be handled if the muls are the same in all
+            # branches, but this is not checked nor rewired at all right now.
+            if n.op_type == "Conv" and not model.is_fork_node(n):
                 conv_weight_name = n.input[1]
                 W = model.get_initializer(conv_weight_name)
                 Wdt = model.get_tensor_datatype(conv_weight_name)
-                assert W is not None, "Initializer for conv weights is not set."
+                # Just skip convs with non-existing weight initializers
+                if W is None:
+                    continue
                 consumer = model.find_consumer(n.output[0])
+                # Note: Join-node test is implicitly covered by testing for the
+                # initializer below
                 if consumer is not None and consumer.op_type == "Mul":
                     mul_weight_name = consumer.input[1]
                     A = model.get_initializer(mul_weight_name)
-                    assert A is not None, "Initializer for mul weights is not set."
+                    # Just skip muls with non-existing scale initializers
+                    if A is None:
+                        continue
                     is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
                     is_scalar = np.prod(A.shape) == 1
                     actual_ndims = len(tuple(filter(lambda x: x > 1, A.shape)))
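Both transforms now skip, instead of assert on, missing weight or scale initializers, and bail out on fork-nodes via model.is_fork_node. The absorption itself rests on a simple identity: a constant Mul following a MatMul can be folded into the weights as Wnew = A * W. A minimal numpy sketch of that identity; names and shapes are illustrative, not from the codebase:

```python
# Sketch of the identity behind Absorb1BitMulIntoMatMul: for a constant
# (here bipolar, i.e. 1-bit) scale A, MatMul followed by Mul equals a
# single MatMul with pre-scaled weights, so the Mul node can be removed.
import numpy as np

x = np.random.rand(1, 4).astype(np.float32)  # activation entering the MatMul
W = np.random.rand(4, 8).astype(np.float32)  # MatMul weight initializer
A = np.float32(-1.0)                         # bipolar (1-bit) scale constant

before = (x @ W) * A  # original graph: MatMul -> Mul
after = x @ (A * W)   # streamlined graph: Mul absorbed into the weights
assert np.allclose(before, after)
```

This also shows why fork-nodes are excluded: with several consumers of the MatMul output, folding the Mul of just one branch would change the others.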
163 changes: 120 additions & 43 deletions src/finn/transformation/streamline/reorder.py
@@ -116,58 +116,133 @@ def apply(self, model):
         return model, graph_modified


+# Tests whether a tensor is a scalar, i.e., whether all dimensions are 1
+def is_scalar(tensor):
+    return tensor is not None and all(x == 1 for x in tensor.shape)
+
+
+# Tests whether a node is a scalar multiplication with a constant scale factor
+def is_const_scalar_mul(node, model):
+    # Only handle existing Mul type nodes
+    if node is not None and node.op_type == "Mul":
+        # The constant must be an initializer
+        # Note: Assumes the constant parameter to always be the second input
+        scale = model.get_initializer(node.input[1])
+        # Test for existence of a constant scale factor
+        return scale is not None and is_scalar(scale)
+    # Did not match the operator type
+    return False
+
+
+# Refactored version of the MoveScalarMulPastMatMul transform capable of
+# transforming two-input MatMul, like those being part of the attention operator
 class MoveScalarMulPastMatMul(Transformation):
     """Move scalar mul operations past matmul operations. We want to have muls
     next to each other such that they can be collapsed into a single mul."""

+    # Applies the transform to a whole model graph
     def apply(self, model):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Mul" and not model.is_fork_node(n) and not model.is_join_node(n):
-                consumer = model.find_consumer(n.output[0])
-                if (
-                    consumer is not None
-                    and consumer.op_type == "MatMul"
-                    and not model.is_join_node(consumer)
-                ):
-                    mul_weight_name = n.input[1]
-                    matmul_weight_name = consumer.input[1]
-                    A = model.get_initializer(mul_weight_name)
-                    W = model.get_initializer(matmul_weight_name)
-                    if (A is None) or (W is None):
-                        warnings.warn("MatMul or Mul params are not constant, skipping")
+
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # First pattern matching condition: For the transform to be
+            # applicable, the node has to be a MatMul operator
+            if node.op_type == "MatMul":
+                # Note: When touching the following code, remember to treat both
+                # branches equivalently!
+                # TODO: Can this be enforced or at least be made easier by
+                # extracting common code patterns to a function?
+
+                # Get the left hand side and right hand side inputs
+                # Note: Assumes the ordering of left to right inputs to match
+                # indices 0 to 1. However, it does not "hurt" if it is
+                # reversed as both sides are treated equivalently.
+                lhs = model.find_producer(node.input[0])
+                rhs = model.find_producer(node.input[1])
+
+                # Give precedence to the left hand side input testing for the
+                # presence of a scalar multiplication
+                if is_const_scalar_mul(lhs, model):
+                    # Cannot handle fork nodes: We would have to distribute the
+                    # Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
+                    if model.is_fork_node(lhs):
+                        # Softly skip this node
                         continue
-                    start_name = n.input[0]
-                    middle_name = n.output[0]
-                    end_name = consumer.output[0]
-                    mm_out_shape = model.get_tensor_shape(end_name)
-                    if all(x == 1 for x in A.shape):
-                        # if the mul is scalar, we can simply swap the order of ops
-                        # make and insert new nodes
-                        new_matmul = oh.make_node(
-                            "MatMul",
-                            [start_name, matmul_weight_name],
-                            [middle_name],
-                            name=consumer.name,
-                        )
-                        new_mul = oh.make_node(
-                            "Mul",
-                            [middle_name, mul_weight_name],
-                            [end_name],
-                            name=n.name,
-                        )
-                        graph.node.insert(node_ind, new_matmul)
-                        graph.node.insert(node_ind + 1, new_mul)
-                        model.set_tensor_shape(middle_name, mm_out_shape)
-                        # remove old nodes
-                        graph.node.remove(n)
-                        graph.node.remove(consumer)
-                        graph_modified = True
+                    # Unpack the connection pattern of a scalar mul feeding the
+                    # lhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul complex
+                    a, b, c = lhs.input[0], lhs.input[1], node.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = lhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul and
+                    # matmul operator order
+                    matmul = oh.make_node("MatMul", [a, c], [m], node.name)
+                    mul = oh.make_node("Mul", [m, b], [o], lhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(lhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True
+                    # Cannot further modify the node (i.e., the rhs) as the
+                    # index and state of the nodes changed and need to be
+                    # queried again from the graph.node at the start of the next
+                    # iteration.
+                    continue
+
+                # Next try whether the right hand side matches the pattern of a
+                # scalar multiplication
+                if is_const_scalar_mul(rhs, model):
+                    # Cannot handle fork nodes: We would have to distribute the
+                    # Mul into all branches
+                    # TODO: Maybe reconsider this at some point, there is
+                    # probably nothing preventing this in general, it is just
+                    # more difficult and apparently not necessary right now.
+                    if model.is_fork_node(rhs):
+                        # Softly skip this node
+                        continue
+                    # Unpack the connection pattern of a scalar mul feeding the
+                    # rhs input of the matmul
+                    # Names of the three input tensors to the mul-matmul complex
+                    a, b, c = node.input[0], rhs.input[0], rhs.input[1]
+                    # Names of the intermediate and the global output
+                    m, o = rhs.output[0], node.output[0]  # noqa: Duplicate code
+                    # Rewire the operator connections locally, swapping mul and
+                    # matmul operator order
+                    matmul = oh.make_node("MatMul", [a, b], [m], node.name)
+                    mul = oh.make_node("Mul", [m, c], [o], rhs.name)
+                    # Insert the rewired nodes into the graph
+                    graph.node.insert(index, matmul)
+                    graph.node.insert(index + 1, mul)
+                    # Adapt the shape of the intermediate tensor as it changed
+                    # according to the output shape of the matmul
+                    model.set_tensor_shape(m, model.get_tensor_shape(o))
+                    # Remove the old nodes from the graph
+                    graph.node.remove(rhs)
+                    graph.node.remove(node)
+                    # The graph has been modified, this needs to be reported
+                    # back to the caller
+                    graph_modified = True

+        # Finalize the transformation by inferring shapes again (as these might
+        # have changed)
         model = model.transform(InferShapes())
-        return (model, graph_modified)
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed
+        return model, graph_modified


 class MoveScalarAddPastMatMul(Transformation):
@@ -617,6 +692,7 @@ def apply(self, model):
                 graph_modified = True
             else:
                 continue
+
         # Note: Running shape inference is necessary as shape annotations have
         # been deleted above
         model = model.transform(InferShapes())
@@ -634,6 +710,7 @@ class MoveScalarLinearPastInvariants(Transformation):
     GlobalAveragePool
     """

+    # Op-types of currently supported invariants
     SUPPORTED_INVARIANTS = {
         "GlobalAveragePool",
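The refactored MoveScalarMulPastMatMul matches on the MatMul itself and inspects both producers, so it also handles two-input (join-node) MatMuls such as the query-key product inside scaled dot-product attention. The rewiring is valid because constant scalars factor out of matrix products. A minimal numpy sketch; the matrices and scale values are illustrative, not from the PR:

```python
# Sketch of the identity behind MoveScalarMulPastMatMul, including the
# join-node case: (s1 * Q) @ (s2 * K) == (s1 * s2) * (Q @ K). Moving the
# scalar Muls past the MatMul leaves them adjacent, so later streamlining
# passes can collapse them into a single Mul.
import numpy as np

s1, s2 = np.float32(0.125), np.float32(2.0)  # constant scalar scales
Q = np.random.rand(4, 8).astype(np.float32)  # e.g. queries
K = np.random.rand(8, 4).astype(np.float32)  # e.g. transposed keys

before = (s1 * Q) @ (s2 * K)  # scalar Muls feeding both MatMul inputs
after = (s1 * s2) * (Q @ K)   # Muls moved past the MatMul
assert np.allclose(before, after, rtol=1e-5)
```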
37 changes: 37 additions & 0 deletions tests/transformation/streamline/test_move_scalar_past_matmul.py
@@ -72,6 +72,43 @@ def test_move_scalar_mul_past_matmul():
     assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]


+@pytest.mark.streamline
+def test_move_scalar_mul_past_join_matmul():
+    top_in1 = oh.make_tensor_value_info("top_in1", TensorProto.FLOAT, [1, 2])
+    top_in2 = oh.make_tensor_value_info("top_in2", TensorProto.FLOAT, [2, 1])
+    mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [1, 1])
+    mul2_param = oh.make_tensor_value_info("mul2_param", TensorProto.FLOAT, [1, 1])
+    top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 1])
+    modelproto = qonnx_make_model(
+        oh.make_graph(
+            name="test",
+            inputs=[top_in1, top_in2],
+            outputs=[top_out],
+            value_info=[mul1_param, mul2_param],
+            nodes=[
+                oh.make_node("Mul", ["top_in1", "mul1_param"], ["middle1"]),
+                oh.make_node("Mul", ["top_in2", "mul2_param"], ["middle2"]),
+                oh.make_node("MatMul", ["middle1", "middle2"], ["top_out"]),
+            ],
+        )
+    )
+    model = ModelWrapper(modelproto)
+    model = model.transform(InferShapes())
+    model.set_initializer("mul1_param", np.asarray([[3]], dtype=np.float32))
+    model.set_initializer("mul2_param", np.asarray([[3]], dtype=np.float32))
+    new_model = model.transform(MoveScalarMulPastMatMul())
+    inp_dict = {
+        "top_in1": np.asarray([[-1.0, 1.0]], dtype=np.float32),
+        "top_in2": np.asarray([[1.0], [-1.0]], dtype=np.float32),
+    }
+    assert ox.compare_execution(model, new_model, inp_dict)
+    assert new_model.graph.node[0].op_type == "MatMul"
+    assert new_model.graph.node[1].op_type == "Mul"
+    assert new_model.graph.node[2].op_type == "Mul"
+    assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0]
+    assert new_model.graph.node[1].output[0] == new_model.graph.node[2].input[0]
+
+
 @pytest.mark.streamline
 def test_move_scalar_add_past_matmul():
     top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2])