Showing 12 changed files with 1,002 additions and 53 deletions.
@@ -0,0 +1,25 @@
from refactor_graph.onnx import make_compiler
from onnx import load
import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description="Run the Refactor compiler and export the serialized model."
    )
    parser.add_argument(
        "--model", type=str, required=True, help="Path to the model file."
    )
    parser.add_argument("--output", type=str, default="./", help="Path to save the output files.")
    args = parser.parse_args()
    return (
        args.model,
        args.output,
    )

def main():
    model_path, output_path = parse_args()
    compiler = make_compiler(load(model_path))
    compiler.serialize(output_path)

if __name__ == "__main__":
    main()
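Assuming this is the make_serialize.py invoked by run.sh below (the model path here is a placeholder), a typical export looks like:

    python3 make_serialize.py --model path/to/model.onnx --output ./serialized/

The --output directory receives the serialized graph files that to_onnx.py later reads back; run.sh creates that directory with mkdir -p before calling this script.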
@@ -0,0 +1,32 @@
#!/bin/bash

while getopts ":i:o:" opt; do
    case $opt in
        i)
            model_path=$OPTARG
            ;;
        o)
            output_path=$OPTARG
            ;;
        \?)
            echo "Invalid option: -$OPTARG"
            exit 1
            ;;
    esac
done
if [ -z "$model_path" ] || [ -z "$output_path" ]; then
    echo "Model path and output path are required."
    exit 1
fi

# Make sure the output directory exists
mkdir -p "$output_path"

# Run the first Python script and save its output into the output directory
python3 make_serialize.py --model "$model_path" --output "$output_path"

# Run the second Python script and save its output into the output directory
python3 to_onnx.py --input "$output_path"

# Print a completion message
echo "Models have been run successfully. Outputs are saved in $output_path."
@@ -0,0 +1,277 @@
import mmap
import argparse
from onnx import TensorProto, NodeProto, save_model
from onnx.helper import (
    make_model,
    make_node,
    make_graph,
    make_tensor_value_info,
    make_tensor,
    make_opsetid,
)
from onnx.checker import check_model


class Topo:
    def __init__(self, bytes: bytes):
        list = bytes.strip().split(b"<-")
        self.inputs = [int(s.strip(b"%")) for s in list[1].split()]
        self.outputs = [int(s.strip(b"%")) for s in list[0].split()]

    def __str__(self) -> str:
        return f"{self.inputs} <- {self.outputs}"


class Tensor:
    def __init__(self, bytes_: bytes):
        list = bytes_.split(b"\t")
        self.name = str(list[1].strip(), "utf-8")

        def map_dt(dt: bytes) -> TensorProto.DataType:
            match dt:
                case b"F32":
                    return TensorProto.FLOAT
                case b"U8":
                    return TensorProto.UINT8
                case b"I8":
                    return TensorProto.INT8
                case b"U16":
                    return TensorProto.UINT16
                case b"I16":
                    return TensorProto.INT16
                case b"I32":
                    return TensorProto.INT32
                case b"I64":
                    return TensorProto.INT64
                case b"String":
                    return TensorProto.STRING
                case b"Bool":
                    return TensorProto.BOOL
                case b"FP16":
                    return TensorProto.FLOAT16
                case b"F64":
                    return TensorProto.DOUBLE
                case b"U32":
                    return TensorProto.UINT32
                case b"U64":
                    return TensorProto.UINT64
                case b"Complex64":
                    return TensorProto.COMPLEX64
                case b"Complex128":
                    return TensorProto.COMPLEX128
                case b"BF16":
                    return TensorProto.BFLOAT16
                case _:
                    return TensorProto.UNDEFINED

        self.dt = map_dt(list[2].strip())
        layout = list[3].strip()
        if layout != b"NCHW" and layout != b"ELSE":
            raise ValueError("Unsupported layout")
        range = list[4].strip().split()
        self.offset = int(range[0], 0)
        self.size = int(range[1], 0)
        self.shape = [int(s) for s in split_array(list[5])]

    def __str__(self) -> str:
        return f"{self.name} (dt = {self.dt}) {self.shape} {self.offset}..{self.offset + self.size}"


class Operator:
    def __init__(self, bytes: bytes):
        list = bytes.split(b"\t")
        self.name = str(list[1].strip(), "utf-8")
        list = list[2].split(b"(", 1)
        self.type = str(list[0].strip(), "utf-8")
        list = list[1].rsplit(b")", 1)
        self.meta = list[0].strip()
        self.topo = Topo(list[1])

    def __str__(self) -> str:
        return f"{self.type}: {self.name}, meta = {self.meta}, topo = {self.topo}"

    def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
        if self.type == "BatchNormalization":
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                    epsilon=float(self.meta.split(b"=")[0]),
                ),
                [],
            )
        if self.type == "Conv":
            meta = [int(x) for x in split_array(self.meta)]
            rank = int(len(meta) / 4)
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                    dilations=meta[0:rank],
                    strides=meta[rank : 2 * rank],
                    pads=meta[2 * rank : 4 * rank],
                ),
                [],
            )
        if self.type == "Relu":
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                ),
                [],
            )
        if self.type == "MaxPool":
            meta = self.meta.split(b",")
            ceil_mode = (
                1 if meta[0] == b"true" else (0 if meta[0] == b"false" else None)
            )
            kernel_shape = [int(x) for x in split_array(meta[1])]
            meta = [int(x) for x in split_array(meta[2])]
            rank = int(len(meta) / 4)
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                    ceil_mode=ceil_mode,
                    kernel_shape=kernel_shape,
                    dilations=meta[0:rank],
                    strides=meta[rank : 2 * rank],
                    pads=meta[2 * rank : 4 * rank],
                ),
                [],
            )
        if self.type == "Add":
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                ),
                [],
            )
        if self.type == "GlobalAveragePool":
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                ),
                [],
            )
        if self.type == "MatMul":
            meta = self.meta.split(b",")
            alpha = float(meta[0].split(b"=")[0].strip())
            beta = float(meta[1].split(b"=")[0].strip())
            transA = 1 if meta[2].strip() == b"AT" else 0
            transB = 1 if meta[3].strip() == b"BT" else 0
            if alpha != 1 or beta != 0 or transA == 1 or transB == 1:
                return (
                    make_node(
                        "Gemm",
                        [tensors[i].name for i in self.topo.inputs],
                        [tensors[i].name for i in self.topo.outputs],
                        self.name,
                        alpha=alpha,
                        beta=beta,
                        transA=transA,
                        transB=transB,
                    ),
                    [],
                )
            else:
                return (
                    make_node(
                        self.type,
                        [tensors[i].name for i in self.topo.inputs],
                        [tensors[i].name for i in self.topo.outputs],
                        self.name,
                    ),
                    [],
                )
        if self.type == "Reshape" or self.type == "Identity":
            output = tensors[self.topo.outputs[0]]
            shape_name = f"{output.name}_shape"
            shape = output.shape
            shape = make_tensor(shape_name, TensorProto.INT64, [len(shape)], shape)
            return (
                make_node(
                    "Reshape",
                    [tensors[self.topo.inputs[0]].name, shape_name],
                    [tensors[i].name for i in self.topo.outputs],
                    self.name,
                ),
                [shape],
            )
        raise ValueError(f"Unsupported operator {self.type}")


def parse_args():
    parser = argparse.ArgumentParser(description="Rebuild an ONNX model from the serialized graph files.")
    parser.add_argument(
        "--input",
        type=str,
        default="./",
        help="Path to the directory holding the serialized files.",
    )
    args = parser.parse_args()
    return args.input


def split_array(arr: bytes):
    return (x for x in arr.strip().strip(b"[").strip(b"]").split())


def main():
    path = parse_args()
    info_path = path + "/graph.info"
    data_path = path + "/graph.data"
    outputfile = path + "/model_refactor.onnx"
    with open(info_path, "r") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
            # graph.info: one operator per line, a blank line, the graph topology,
            # one separator line, then one tensor per line.
            operators = []
            for line in iter(m.readline, b""):
                if line == b"\n":
                    break
                operators.append(Operator(line))
            graph = Topo(m.readline().strip().strip(b"graph. "))
            _ = m.readline()
            tensors = [Tensor(line) for line in iter(m.readline, b"")]

    with open(data_path, "r") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
            nodes = []
            initializer = [
                make_tensor(
                    t.name,
                    t.dt,
                    t.shape,
                    vals=m[t.offset : t.offset + t.size],
                    raw=True,
                )
                for t in tensors
                if t.size != 0
            ]
            for o in operators:
                node, init = o.to_node(tensors)
                nodes.append(node)
                initializer.extend(init)
            graph = make_graph(
                nodes,
                "graph",
                [
                    make_tensor_value_info(t.name, t.dt, t.shape)
                    for t in (tensors[i] for i in graph.inputs)
                ],
                [
                    make_tensor_value_info(t.name, t.dt, t.shape)
                    for t in (tensors[i] for i in graph.outputs)
                ],
                initializer,
            )
            model = make_model(graph, opset_imports=[make_opsetid(domain="", version=13)])
            check_model(model)
            save_model(model, outputfile)


if __name__ == "__main__":
    main()
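This converter expects --input to point at the directory written by the serializer: it reads graph.info (operator, topology, and tensor records) and graph.data (raw tensor bytes) and writes the rebuilt model to model_refactor.onnx in the same directory, e.g. (the directory name is a placeholder):

    python3 to_onnx.py --input ./serialized/

Only the operator types handled in Operator.to_node (BatchNormalization, Conv, Relu, MaxPool, Add, GlobalAveragePool, MatMul/Gemm, and Reshape/Identity) can be rebuilt; any other type raises ValueError.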
src/05computation/include/computation/pass/conv_to_matmul.h (145 additions & 0 deletions)
@@ -0,0 +1,145 @@
#ifndef COMPUTATION_CONV_TO_MATMUL_H
#define COMPUTATION_CONV_TO_MATMUL_H

#include "../graph.h"
#include "computation/operators/conv.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/reshape.h"
#include "computation/operators/transpose.h"
#include "computation/pass/converter.h"
#include <cstring>// std::memcpy

namespace refactor::computation {
    class ConvToMatmul : public Converter {

    public:
        /*
         *   input        weight
         *     |            |
         *     |            |
         * transpose    transpose
         *     |            |
         *     |            |
         *  reshape      reshape
         *      \          /
         *       \        /
         *         matmul
         *           |
         *        reshape
         *           |
         *       transpose
         *           |
         *         output
         */
        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
            auto nodesList = g->internal().nodes();
            size_t count = 0;
            for (auto opMatch : nodesList) {
                if (opMatch->info().op == nullptr) {
                    continue;
                }
                size_t optype = opMatch->info().op->opTypeId();
                if (optype != Conv::typeId()) {
                    continue;
                }
                auto convOp = dynamic_cast<Conv *>(opMatch->info().op.get());
                auto input = opMatch->inputs()[0]->info().tensor;
                auto weight = opMatch->inputs()[1]->info().tensor;
                auto shape = weight->shape;
                // only rewrite 1x1 convolutions
                if (shape.size() != 4 || shape[2] != 1 || shape[3] != 1) {
                    continue;
                }
                auto attr = convOp->attributes;
                auto poolAttrRank = attr.rank();
                auto poolAttrDilation = attr.dilations();
                auto poolAttrStride = attr.strides();
                auto poolAttrPad = attr.pads();
                bool flag = false;
                for (auto i : range0_(poolAttrRank)) {
                    if (poolAttrDilation[i] != 1 || poolAttrStride[i] != 1) {
                        flag = true;
                        break;
                    }
                    if (poolAttrPad[i] != 0 || poolAttrPad[i + poolAttrRank] != 0) {
                        flag = true;
                        break;
                    }
                }
                if (flag) { continue; }
                // create transpose op
                absl::InlinedVector<uint32_t, 4> perm1 = {0, 2, 3, 1};
                Shape shape1 = {input->shape[0], input->shape[2], input->shape[3], input->shape[1]};
                auto newTransposeOp1 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm1), fmt::format("ConvToMatmul_transpose1_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape1), fmt::format("ConvToMatmul_transpose1_{}_out", count)})});
                newTransposeOp1->connect(0, opMatch->inputs()[0]);
                absl::InlinedVector<uint32_t, 4> perm2 = {1, 0, 2, 3};
                Shape shape2 = {weight->shape[1], weight->shape[0], weight->shape[2], weight->shape[3]};
                auto newTransposeOp2 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm2), fmt::format("ConvToMatmul_transpose2_{}", count)},
                    {g->internal().shareEdge({Tensor::share(weight->dataType, shape2), fmt::format("ConvToMatmul_transpose2_{}_out", count)})});
                newTransposeOp2->connect(0, opMatch->inputs()[1]);
                // create reshape op
                Shape shape3 = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]};
                Shape shape4 = {weight->shape[1], weight->shape[0]};
                int64_t data1[2] = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]};
                int64_t data2[2] = {weight->shape[1], weight->shape[0]};
                auto [data1_, ptr1] = refactor::kernel::Blob::share(sizeof(int64_t) * 2);
                auto [data2_, ptr2] = refactor::kernel::Blob::share(sizeof(int64_t) * 2);
                std::memcpy(ptr1, data1, sizeof(data1));// fill the blob with the reshape target shape
                std::memcpy(ptr2, data2, sizeof(data2));
                auto newReshapeEdge1 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data1_), fmt::format("ConvToMatmul_reshape1_shape_{}", count)});
                auto newReshapeEdge2 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data2_), fmt::format("ConvToMatmul_reshape2_shape_{}", count)});
                auto newReshapeOp1 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape1_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape3), fmt::format("ConvToMatmul_reshape1_{}_out", count)})});
                auto newReshapeOp2 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape2_{}", count)},
                    {g->internal().shareEdge({Tensor::share(weight->dataType, shape4), fmt::format("ConvToMatmul_reshape2_{}_out", count)})});
                newReshapeOp1->connect(0, newTransposeOp1->outputs()[0]);
                newReshapeOp1->connect(1, newReshapeEdge1);
                newReshapeOp2->connect(0, newTransposeOp2->outputs()[0]);
                newReshapeOp2->connect(1, newReshapeEdge2);
                // create matmul op
                Shape shape5 = {input->shape[0] * input->shape[2] * input->shape[3], weight->shape[0]};
                auto newMatMulOp = g->internal().pushNode(
                    {std::make_unique<MatMul>(1.0, 1.0, false, false), fmt::format("ConvToMatmul_matmul_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape5), fmt::format("ConvToMatmul_matmul_{}_out", count)})});
                newMatMulOp->connect(0, newReshapeOp1->outputs()[0]);
                newMatMulOp->connect(1, newReshapeOp2->outputs()[0]);
                // create reshape op
                Shape shape6 = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]};
                int64_t data3[4] = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]};
                auto [data3_, ptr3] = refactor::kernel::Blob::share(sizeof(int64_t) * 4);
                std::memcpy(ptr3, data3, sizeof(data3));// fill the blob with the final reshape target shape
                auto newReshapeEdge3 = g->internal().shareEdge({Tensor::share(DataType::I64, {4}, LayoutType::Others, data3_), fmt::format("ConvToMatmul_reshape3_shape_{}", count)});
                auto newReshapeOp3 = g->internal().pushNode(
                    {std::make_unique<Reshape>(), fmt::format("ConvToMatmul_reshape3_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape6), fmt::format("ConvToMatmul_reshape3_{}_out", count)})});
                newReshapeOp3->connect(0, newMatMulOp->outputs()[0]);
                newReshapeOp3->connect(1, newReshapeEdge3);
                // create transpose op
                absl::InlinedVector<uint32_t, 4> perm3 = {0, 3, 1, 2};
                Shape shape7 = {input->shape[0], weight->shape[0], input->shape[2], input->shape[3]};
                auto newTransposeOp3 = g->internal().pushNode(
                    {std::make_unique<Transpose>(perm3), fmt::format("ConvToMatmul_transpose3_{}", count)},
                    {g->internal().shareEdge({Tensor::share(input->dataType, shape7), fmt::format("ConvToMatmul_transpose3_{}_out", count)})});
                newTransposeOp3->connect(0, newReshapeOp3->outputs()[0]);
                if (opMatch->outputs()[0]->targets().size() == 0) {// global output
                    g->internal().replaceOutput(opMatch->outputs()[0], newTransposeOp3->outputs()[0]);
                } else {
                    for (auto node : opMatch->outputs()[0]->targets()) {
                        auto it = std::find(node->inputs().begin(), node->inputs().end(), opMatch->outputs()[0]);
                        node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], newTransposeOp3->outputs()[0]);
                    }
                }
                g->internal().eraseNode(opMatch);
                count++;
            }
            return true;
        };
    };

}// namespace refactor::computation
#endif// COMPUTATION_CONV_TO_MATMUL_H
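To make the rewrite concrete, take the shapes from the ConvToMatMul1 test further down: input 1x3x5x5 (NCHW) and weight 2x3x1x1. The pass builds

    transpose {0,2,3,1}: 1x3x5x5 -> 1x5x5x3, then reshape -> 25x3
    transpose {1,0,2,3}: 2x3x1x1 -> 3x2x1x1, then reshape -> 3x2
    matmul: (25x3) x (3x2) -> 25x2, reshape -> 1x5x5x2, transpose {0,3,1,2} -> 1x2x5x5

which reproduces the convolution output, since a 1x1 convolution with unit strides and dilations and zero padding is just a per-pixel matrix multiply over the channel dimension.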
@@ -0,0 +1,40 @@
#ifndef COMPUTATION_CONVERTER_H
#define COMPUTATION_CONVERTER_H

#include "../graph.h"

namespace refactor::computation {

    class Converter {
    public:
        Converter() = default;
        virtual ~Converter() = default;
        virtual bool execute(const std::shared_ptr<GraphMutant> &) const = 0;
        static Converter *get(std::string_view key) {
            //fmt::println("{}", storage().size());
            if (storage().find(key) != storage().end()) {
                return storage().at(key).get();
            }
            return nullptr;
        };
        static void add(std::shared_ptr<Converter> converter, std::string_view key) {
            storage().insert(std::make_pair(key, converter));
        };
        static std::unordered_map<std::string_view, std::shared_ptr<Converter>> &storage() {
            static std::unordered_map<std::string_view, std::shared_ptr<Converter>> passStorage;
            return passStorage;
        }
    };

    template<class T>
    class ConverterRegister {
    public:
        ConverterRegister(const char *claim) {
            T *instance = new T;
            Converter::add(std::shared_ptr<Converter>(instance), claim);
        }
    };

}// namespace refactor::computation

#endif// COMPUTATION_CONVERTER_H
src/05computation/include/computation/pass/matmul_transpose.h (94 additions & 0 deletions)
@@ -0,0 +1,94 @@
#ifndef COMPUTATION_MATMUL_TRANSPOSE_H
#define COMPUTATION_MATMUL_TRANSPOSE_H

#include "../graph.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/transpose.h"
#include "computation/pass/converter.h"

namespace refactor::computation {
    class MatMulTransposeFuse : public Converter {
    public:
        virtual bool execute(const std::shared_ptr<GraphMutant> &g) const override {
            auto nodesList = g->internal().nodes();
            for (auto opMatch : nodesList) {
                if (opMatch->info().op == nullptr) {
                    continue;
                }
                size_t optype = opMatch->info().op->opTypeId();
                if (optype != MatMul::typeId()) {
                    continue;
                }
                auto matmulOp = dynamic_cast<MatMul *>(opMatch->info().op.get());
                if (opMatch->predecessors().size() != 0) {
                    for (size_t i = 0; i < opMatch->inputs().size(); ++i) {
                        if (auto preOp = opMatch->inputs()[i]->source();
                            preOp != nullptr && preOp->info().op->opTypeId() == Transpose::typeId()) {
                            auto transposeOp = dynamic_cast<Transpose *>(preOp->info().op.get());
                            auto axis = transposeOp->perm;
                            bool flag = false;
                            if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) {
                                flag = true;
                            }
                            for (size_t index = 0; index < axis.size() - 2; ++index) {
                                if (index == axis[index]) {
                                    continue;
                                }
                                flag = false;
                                break;
                            }
                            if (flag) {
                                if (i == 0) {
                                    matmulOp->transA = !matmulOp->transA;
                                } else {
                                    matmulOp->transB = !matmulOp->transB;
                                }
                                opMatch->reconnect(opMatch->inputs()[i], preOp->inputs()[0]);
                                g->internal().eraseNode(preOp);
                            }
                        }
                    }
                }
                if (opMatch->successors().size() == 1) {
                    if (auto postOp = *(opMatch->outputs()[0]->targets().begin());
                        postOp != nullptr && postOp->info().op->opTypeId() == Transpose::typeId()) {
                        auto transposeOp = dynamic_cast<Transpose *>(postOp->info().op.get());
                        auto axis = transposeOp->perm;
                        bool flag = false;
                        if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) {
                            flag = true;
                        }
                        for (size_t index = 0; index < axis.size() - 2; ++index) {
                            if (index == axis[index]) {
                                continue;
                            }
                            flag = false;
                            break;
                        }
                        if (flag) {
                            matmulOp->transA = !matmulOp->transA;
                            matmulOp->transB = !matmulOp->transB;
                            auto inputsA = opMatch->inputs()[0];
                            auto inputsB = opMatch->inputs()[1];
                            opMatch->connect(0, inputsB);
                            opMatch->connect(1, inputsA);
                            opMatch->outputs()[0]->info().tensor->shape = postOp->outputs()[0]->info().tensor->shape;
                            if (postOp->outputs()[0]->targets().size() == 0) {// global output
                                g->internal().replaceOutput(postOp->outputs()[0], opMatch->outputs()[0]);
                            } else {
                                for (auto node : postOp->outputs()[0]->targets()) {
                                    auto it = std::find(node->inputs().begin(), node->inputs().end(), postOp->outputs()[0]);
                                    node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], opMatch->outputs()[0]);
                                }
                            }
                            g->internal().eraseNode(postOp);
                        }
                    }
                }
            }
            return true;
        };
    };

}// namespace refactor::computation
#endif// COMPUTATION_MATMUL_TRANSPOSE_H
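In short: a Transpose whose perm only swaps the last two axes and leaves every leading (batch) axis in place (e.g. {0, 1, 3, 2}, which is what the index == axis[index] loop checks) can be folded into the MatMul. On an input edge it becomes a flip of that operand's transA/transB flag; on the output edge it is absorbed by swapping the two operands and adjusting both flags, using the identity (A·B)^T = B^T·A^T. Transposes that permute batch axes are left untouched.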
@@ -0,0 +1,18 @@
#ifndef COMPUTATION_PASS_REGISTER_H
#define COMPUTATION_PASS_REGISTER_H
#include "pass/conv_to_matmul.h"
#include "pass/converter.h"
#include "pass/matmul_transpose.h"

namespace refactor::computation {

    void register_() {
#define REGISTER(PASS, NAME) static ConverterRegister<PASS> NAME("" #NAME);
        REGISTER(MatMulTransposeFuse, MatMulTransposeFuse)
        REGISTER(ConvToMatmul, ConvToMatmul)
    };


}// namespace refactor::computation

#endif
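A minimal sketch of how the registry is meant to be used, assuming the caller is inside namespace refactor::computation and that graph stands in for a std::shared_ptr<GraphMutant> it has already built (neither the caller nor graph appears in this commit):

    register_();                                          // instantiate the static ConverterRegister objects above
    if (auto *pass = Converter::get("MatMulTransposeFuse")) {
        pass->execute(graph);                             // fold adjacent Transpose nodes into MatMul flags
    }
    if (auto *pass = Converter::get("ConvToMatmul")) {
        pass->execute(graph);                             // rewrite 1x1 convolutions as MatMul
    }

The lookup keys are exactly the NAME arguments of the REGISTER lines, since the macro stringizes NAME when constructing each ConverterRegister.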
src/05computation/test/test_pass/test_cont_to_matmul.cpp (114 additions & 0 deletions)
@@ -0,0 +1,114 @@
#include "computation/graph.h"
#include "computation/operators/conv.h"
#include "computation/operators/simple_unary.h"
#include <gtest/gtest.h>

namespace refactor::computation {

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestConvToMatMulGraphBuild1() {
        auto nodes = std::unordered_map<size_t, Node>{};
        int64_t dilations[2] = {1, 1};
        int64_t strides[2] = {1, 1};
        int64_t pads[4] = {0, 0, 0, 0};
        nodes[0] = Node{std::make_unique<Conv>(PoolAttributes(2, &dilations[0], &pads[0], &strides[0])), "conv"};
        nodes[1] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);

        return {
            {
                {0, {{0, 1}, {2}}},
                {1, {{2}, {3}}},
            },
            {0, 1},// global inputs
            {3},   // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input"}},
                {1, {tensor1, "weight"}},
                {2, {tensor2, "conv_output"}},
                {3, {tensor3, "output"}},
            },
        };
    }

    TEST(Graph, ConvToMatMul1) {
        auto graphTopo = TestConvToMatMulGraphBuild1().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
        ASSERT_EQ(g_.nodes.size(), 8);
        ASSERT_EQ(g_.edges.size(), 13);
    }

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestConvToMatMulGraphBuild2() {
        auto nodes = std::unordered_map<size_t, Node>{};
        nodes[0] = Node{std::make_unique<Conv>(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv0"};
        nodes[1] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu0"};
        nodes[2] = Node{std::make_unique<Conv>(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv1"};
        nodes[3] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu1"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others);
        auto tensor4 = Tensor::share(DataType::F32, {4, 3, 1, 1}, LayoutType::Others);
        auto tensor5 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others);
        auto tensor6 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others);

        return {
            {
                {0, {{0, 1}, {2}}},
                {1, {{2}, {3}}},
                {2, {{3, 4}, {5}}},
                {3, {{5}, {6}}},
            },
            {0, 1, 4},// global inputs
            {6},      // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input0"}},
                {1, {tensor1, "weight0"}},
                {2, {tensor2, "conv0_output"}},
                {3, {tensor3, "relu0_output"}},
                {4, {tensor4, "weight1"}},
                {5, {tensor5, "conv1_output"}},
                {6, {tensor6, "output"}},
            },
        };
    }

    TEST(Graph, ConvToMatMul2) {
        auto graphTopo = TestConvToMatMulGraphBuild2().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
        ASSERT_EQ(g_.nodes.size(), 16);
        ASSERT_EQ(g_.edges.size(), 25);
    }
}// namespace refactor::computation
src/05computation/test/test_pass/test_matmul_transpose_fuse.cpp (163 additions & 0 deletions)
@@ -0,0 +1,163 @@
#include "computation/graph.h"
#include "computation/operators/mat_mul.h"
#include "computation/operators/simple_unary.h"
#include "computation/operators/transpose.h"
#include <gtest/gtest.h>

namespace refactor::computation {

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestMatMulTransposeGraphBuild1() {
        absl::InlinedVector<uint32_t, 4> perm = {0, 1, 3, 2};
        auto nodes = std::unordered_map<size_t, Node>{};
        nodes[0] = Node{std::make_unique<Transpose>(perm), "transpose0"};
        nodes[1] = Node{std::make_unique<Transpose>(perm), "transpose1"};
        nodes[2] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {1, 3, 5, 3}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others);
        auto tensor4 = Tensor::share(DataType::F32, {2, 3, 5, 5}, LayoutType::Others);

        return {
            {
                {0, {{0}, {2}}},
                {1, {{1}, {3}}},
                {2, {{2, 3}, {4}}},
            },
            {0, 1},// global inputs
            {4},   // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input0"}},
                {1, {tensor1, "input1"}},
                {2, {tensor2, "input0_transpose"}},
                {3, {tensor3, "input1_transpose"}},
                {4, {tensor4, "output"}},
            },
        };
    }

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestMatMulTransposeGraphBuild2() {
        absl::InlinedVector<uint32_t, 4> perm = {0, 1, 3, 2};
        auto nodes = std::unordered_map<size_t, Node>{};
        nodes[0] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};
        nodes[1] = Node{std::make_unique<Transpose>(perm), "transpose1"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {2, 3, 3, 4}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {2, 3, 4, 3}, LayoutType::Others);

        return {
            {
                {0, {{0, 1}, {2}}},
                {1, {{2}, {3}}},
            },
            {0, 1},// global inputs
            {3},   // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input0"}},
                {1, {tensor1, "input1"}},
                {2, {tensor2, "matmul_output"}},
                {3, {tensor3, "output"}},
            },
        };
    }

    refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestMatMulTransposeGraphBuild3() {
        absl::InlinedVector<uint32_t, 4> perm = {0, 1, 3, 2};
        auto nodes = std::unordered_map<size_t, Node>{};
        nodes[0] = Node{std::make_unique<Transpose>(perm), "transpose0"};
        nodes[1] = Node{std::make_unique<Transpose>(perm), "transpose1"};
        nodes[2] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul"};
        nodes[3] = Node{std::make_unique<Transpose>(perm), "transpose3"};
        nodes[4] = Node{std::make_unique<SimpleUnary>(refactor::kernel::SimpleUnaryType::Relu), "relu"};

        auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 4}, LayoutType::Others);
        auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others);
        auto tensor2 = Tensor::share(DataType::F32, {1, 3, 4, 3}, LayoutType::Others);
        auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others);
        auto tensor4 = Tensor::share(DataType::F32, {2, 3, 4, 5}, LayoutType::Others);
        auto tensor5 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others);
        auto tensor6 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others);

        return {
            {
                {0, {{0}, {2}}},
                {1, {{1}, {3}}},
                {2, {{2, 3}, {4}}},
                {3, {{4}, {5}}},
                {4, {{5}, {6}}},
            },
            {0, 1},// global inputs
            {6},   // global outputs
            std::move(nodes),
            {
                {0, {tensor0, "input0"}},
                {1, {tensor1, "input1"}},
                {2, {tensor2, "input0_transpose"}},
                {3, {tensor3, "input1_transpose"}},
                {4, {tensor4, "matmul_output"}},
                {5, {tensor5, "transpose_output"}},
                {6, {tensor6, "output"}},
            },
        };
    }

    TEST(Graph, MatMulTranspose1) {
        auto graphTopo = TestMatMulTransposeGraphBuild1().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
    }

    TEST(Graph, MatMulTranspose2) {
        auto graphTopo = TestMatMulTransposeGraphBuild2().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
    }

    TEST(Graph, MatMulTranspose3) {
        auto graphTopo = TestMatMulTransposeGraphBuild3().build();
        fmt::println("{}", graphTopo.topology.toString());
        Graph g(std::move(graphTopo));
        g.optimize();
        auto const &g_ = g.internal().contiguous();
        fmt::println("{}", g_.topology.toString());
        fmt::println("Nodes info :");
        for (size_t i = 0; i < g_.nodes.size(); ++i) {
            fmt::println("{}. \"{}\"", i, g_.nodes[i].name);
        }
        fmt::println("\n Edges info :");
        for (size_t i = 0; i < g_.edges.size(); ++i) {
            fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name,
                         vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name());
        }
    }
}// namespace refactor::computation