From 0032b8ed4602c3d6ec8bbf4c1f0d0fceab10ce8c Mon Sep 17 00:00:00 2001
From: zhangyunze
Date: Mon, 13 Nov 2023 16:49:22 +0800
Subject: [PATCH] feat: support auto optimize demo

---
 .../include/graph_topo/linked_graph.hpp       |  68 ++++++-
 src/02mem_manager/include/mem_manager/blob.hh |   1 +
 src/02mem_manager/src/blob.cc                 |   1 +
 src/05computation/include/computation/graph.h |   1 +
 .../include/computation/graph_mutant.h        |  38 ++++
 .../include/computation/mutant_generator.h    |  44 +++++
 .../include/computation/operator.h            |  16 ++
 .../include/computation/operators/concat.h    |  11 ++
 .../include/computation/operators/mat_mul.h   |  12 ++
 src/05computation/src/graph.cc                |   1 +
 src/05computation/src/graph_mutant.cc         |  28 +++
 src/05computation/src/mutant_generator.cc     | 185 ++++++++++++++++++
 src/05computation/src/operators/concat.cc     |  34 ++++
 src/05computation/src/operators/mat_mul.cc    |  35 ++++
 .../test/test_mutant_generator.cpp            |  67 +++++++
 15 files changed, 541 insertions(+), 1 deletion(-)
 create mode 100644 src/05computation/include/computation/graph_mutant.h
 create mode 100644 src/05computation/include/computation/mutant_generator.h
 create mode 100644 src/05computation/src/graph_mutant.cc
 create mode 100644 src/05computation/src/mutant_generator.cc
 create mode 100644 src/05computation/test/test_mutant_generator.cpp

diff --git a/src/01graph_topo/include/graph_topo/linked_graph.hpp b/src/01graph_topo/include/graph_topo/linked_graph.hpp
index 1978731b..db8383f4 100644
--- a/src/01graph_topo/include/graph_topo/linked_graph.hpp
+++ b/src/01graph_topo/include/graph_topo/linked_graph.hpp
@@ -1,8 +1,8 @@
 #ifndef GRAPH_TOPO_LINKED_GRAPH_H
 #define GRAPH_TOPO_LINKED_GRAPH_H
 
-#include "container.h"
 #include "common.h"
+#include "container.h"
 #include <sstream>
 #include <unordered_map>
 #include <vector>
@@ -30,6 +30,7 @@ namespace refactor::graph_topo {
        static auto shareEdge(TE) -> Rc<Edge>;
 
        std::string toString() const;
+       std::string toString(std::string func(TN const &)) const;
        Graph<TN, TE> intoGraph() const;
        std::vector<Rc<Node>> const &nodes() const;
        std::vector<Rc<Edge>> const &inputs() const;
@@ -43,6 +44,7 @@ namespace refactor::graph_topo {
        void eraseNode(Rc<Node>);
        size_t cleanup(bool useless(TE const &) = nullptr);
        bool sort();
+       LinkedGraph clone(TN cloneNode(TN const &), TE cloneEdge(TE const &)) const;
    };
 
    template<class TN, class TE>
@@ -127,6 +129,34 @@ namespace refactor::graph_topo {
        return ss.str();
    }
 
+   LINKED_GRAPH_FN toString(std::string func(TN const &)) const->std::string {
+       std::unordered_map<void *, size_t> indices;
+       std::stringstream ss;
+       auto f = [&indices, &ss](Rc<Edge> const &e) {
+           if (e) {
+               auto [it, ok] = indices.try_emplace(e.get(), indices.size());
+               ss << it->second << ' ';
+           } else {
+               ss << "? ";
+           }
+       };
+       ss << "*. -> ( ";
+       for (auto const &e : _inputs) { f(e); }
+       ss << ')' << std::endl;
+       for (auto i : range0_(_nodes.size())) {
+           auto n = _nodes[i];
+           ss << i << ". ( ";
+           for (auto const &e : n->_inputs) { f(e); }
+           ss << ") -> ( ";
+           for (auto const &e : n->_outputs) { f(e); }
+           ss << ')' << func(n->_info) << std::endl;
+       }
+       ss << "*. <- ( ";
+       for (auto const &e : _outputs) { f(e); }
+       ss << ')' << std::endl;
+       return ss.str();
+   }
+
    LINKED_GRAPH_FN nodes() const->std::vector<Rc<Node>> const & {
        return _nodes;
    }
@@ -236,6 +266,42 @@ namespace refactor::graph_topo {
        return true;
    }
 
+   LINKED_GRAPH_FN clone(
+       TN cloneNode(TN const &),
+       TE cloneEdge(TE const &))
+       const->LinkedGraph {
+       LinkedGraph ans;
+       ans._inputs.reserve(_inputs.size());
+       ans._nodes.reserve(_nodes.size());
+       ans._outputs.reserve(_outputs.size());
+       std::unordered_map<void *, Rc<Edge>> edges;
+       auto mapEdge = [&](Rc<Edge> const &e) {
+           return edges.try_emplace(e.get(), shareEdge(cloneEdge(e->_info))).first->second;
+       };
+       for (auto const &e : _inputs) {
+           ans._inputs.emplace_back(mapEdge(e));
+       }
+       for (auto const &n : _nodes) {
+           std::vector<Rc<Edge>> outputs;
+           outputs.reserve(n->_outputs.size());
+           for (auto const &e : n->_outputs) {
+               outputs.emplace_back(mapEdge(e));
+           }
+           auto n_ = ans.pushNode(cloneNode(n->_info), std::move(outputs));
+           for (auto i : range0_(n->_inputs.size())) {
+               if (auto it = edges.find(n->_inputs[i].get()); it != edges.end()) {
+                   n_->connect(i, it->second);
+               } else {
+                   n_->connect(i, mapEdge(n->_inputs[i]));
+               }
+           }
+       }
+       for (auto const &e : _outputs) {
+           ans._outputs.emplace_back(std::move(edges.at(e.get())));
+       }
+       return ans;
+   }
+
    LINKED_GRAPH_FN Node::share(TN info, std::vector<Rc<Edge>> outputs)
        ->Rc<Node> {
        auto ans = Rc<Node>(new Node(std::move(info), std::move(outputs)));
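The clone added here rebuilds the topology node by node and lets the caller decide how node and edge payloads are copied. Note that both parameters are plain function pointers (TN cloneNode(TN const &)), so only capture-less lambdas convert. A minimal usage sketch, not from this patch, assuming an existing LinkedGraph<int, int> named original:

    using G = refactor::graph_topo::LinkedGraph<int, int>;
    // Payload-copy callbacks: both copy by value here, but cloneEdge could
    // e.g. duplicate backing storage for a deep copy.
    G copy = original.clone(
        [](int const &n) -> int { return n; },
        [](int const &e) -> int { return e; });
    // The clone shares no Node/Edge wrappers with the original, so
    // pushNode/eraseNode/connect on one graph never disturb the other.
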
<- ( "; + for (auto const &e : _outputs) { f(e); } + ss << ')' << std::endl; + return ss.str(); + } + LINKED_GRAPH_FN nodes() const->std::vector> const & { return _nodes; } @@ -236,6 +266,42 @@ namespace refactor::graph_topo { return true; } + LINKED_GRAPH_FN clone( + TN cloneNode(TN const &), + TE cloneEdge(TE const &)) + const->LinkedGraph { + LinkedGraph ans; + ans._inputs.reserve(_inputs.size()); + ans._nodes.reserve(_nodes.size()); + ans._outputs.reserve(_outputs.size()); + std::unordered_map> edges; + auto mapEdge = [&](Rc const &e) { + return edges.try_emplace(e.get(), shareEdge(cloneEdge(e->_info))).first->second; + }; + for (auto const &e : _inputs) { + ans._inputs.emplace_back(mapEdge(e)); + } + for (auto const &n : _nodes) { + std::vector> outputs; + outputs.reserve(n->_outputs.size()); + for (auto const &e : n->_outputs) { + outputs.emplace_back(mapEdge(e)); + } + auto n_ = ans.pushNode(cloneNode(n->_info), std::move(outputs)); + for (auto i : range0_(n->_inputs.size())) { + if (auto it = edges.find(n->_inputs[i].get()); it != edges.end()) { + n_->connect(i, it->second); + } else { + n_->connect(i, mapEdge(n->_inputs[i])); + } + } + } + for (auto const &e : _outputs) { + ans._outputs.emplace_back(std::move(edges.at(e.get()))); + } + return ans; + } + LINKED_GRAPH_FN Node::share(TN info, std::vector> outputs) ->Rc { auto ans = Rc(new Node(std::move(info), std::move(outputs))); diff --git a/src/02mem_manager/include/mem_manager/blob.hh b/src/02mem_manager/include/mem_manager/blob.hh index 3108ae91..7e29a3c9 100644 --- a/src/02mem_manager/include/mem_manager/blob.hh +++ b/src/02mem_manager/include/mem_manager/blob.hh @@ -19,6 +19,7 @@ namespace refactor::mem_manager { static std::pair, void *> share(size_t); operator void const *() const noexcept; + operator void *() noexcept; template T const *get() const noexcept { return reinterpret_cast(_ptr); } diff --git a/src/02mem_manager/src/blob.cc b/src/02mem_manager/src/blob.cc index 2435a82e..b307a4ed 100644 --- a/src/02mem_manager/src/blob.cc +++ b/src/02mem_manager/src/blob.cc @@ -14,5 +14,6 @@ namespace refactor::mem_manager { return {std::move(blob), ptr}; } Blob::operator void const *() const noexcept { return _ptr; } + Blob::operator void *() noexcept { return _ptr; } }// namespace refactor::mem_manager diff --git a/src/05computation/include/computation/graph.h b/src/05computation/include/computation/graph.h index 67dbc9af..5d428c62 100644 --- a/src/05computation/include/computation/graph.h +++ b/src/05computation/include/computation/graph.h @@ -30,6 +30,7 @@ namespace refactor::computation { kernel::Graph lower(Target) const; auto internal() const -> decltype(_internal) const &; + auto internal() -> decltype(_internal) &; }; }// namespace refactor::computation diff --git a/src/05computation/include/computation/graph_mutant.h b/src/05computation/include/computation/graph_mutant.h new file mode 100644 index 00000000..d9aac26f --- /dev/null +++ b/src/05computation/include/computation/graph_mutant.h @@ -0,0 +1,38 @@ +#ifndef COMPUTATION_GRAPH_MUTANT_H +#define COMPUTATION_GRAPH_MUTANT_H + +#include "graph_topo.h" +#include "kernel/graph.h" +#include "operator.h" + +namespace refactor::computation { + using kernel::Shape; + using kernel::Tensor; + + struct Node { + Arc op; + std::string name; + }; + + struct Edge { + Arc tensor; + std::string name; + }; + + class GraphMutant { + graph_topo::LinkedGraph _internal; + + public: + explicit GraphMutant(graph_topo::Graph) noexcept; + GraphMutant(graph_topo::GraphTopo, std::vector, 
diff --git a/src/05computation/include/computation/mutant_generator.h b/src/05computation/include/computation/mutant_generator.h
new file mode 100644
index 00000000..06e8212c
--- /dev/null
+++ b/src/05computation/include/computation/mutant_generator.h
@@ -0,0 +1,44 @@
+#ifndef COMPUTATION_MUTANT_GENERATOR_H
+#define COMPUTATION_MUTANT_GENERATOR_H
+
+#include "graph_mutant.h"
+#include "operator.h"
+
+namespace refactor::computation {
+
+    using OpVec = std::vector<Arc<MyOperator>>;
+    using TensorVec = std::vector<Rc<graph_topo::LinkedGraph<Node, Edge>::Edge>>;
+
+    inline uint64_t hashAppend(uint64_t a, uint64_t b) {
+        return (a * 10000019 + b * 10000079) % 2147483647;
+    }
+
+    template<class T> inline uint64_t hashVector(const std::vector<T> &vec) {
+        uint64_t ret = 0;
+        for (auto v : vec)
+            ret = hashAppend(ret, v);
+        return ret;
+    }
+
+    class MutantGenerator {
+        float equalThreshold;
+        size_t maxDepth;
+        size_t numValidTensors = 0;
+        OpVec opList;
+        std::vector<OpVec> opStorage;
+        OpVec opFinger;
+        TensorVec validTensors;
+        std::set<uint64_t> opHashMaps;
+
+    public:
+        void init(float, size_t, OpVec) noexcept;
+        void run(GraphMutant const &, std::vector<GraphMutant> &) noexcept;
+        void dfs(size_t, GraphMutant const &, GraphMutant &, std::vector<GraphMutant> &) noexcept;
+        bool is_mutant(GraphMutant &, GraphMutant const &) noexcept;
+        bool approx_equal(const Tensor &, const Tensor &) const noexcept;
+        bool have_same_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
+        void delete_hash_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
+    };
+}// namespace refactor::computation
+
+#endif// COMPUTATION_MUTANT_GENERATOR_H
\ No newline at end of file
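hashAppend folds two integers modulo the Mersenne prime 2^31 - 1 = 2147483647, and hashVector reduces a vector left-to-right, so the digest is order-sensitive: (op, i, j) and (op, j, i) produce different keys, which is exactly what have_same_op/delete_hash_op rely on to prune duplicate candidates during the search. A worked sketch with a hypothetical opTypeId of 7 and operand indices 1 and 2:

    // h0 = 0
    // h1 = (0  * 10000019 + 7 * 10000079) % 2147483647 = 70000553
    // h2 = (h1 * 10000019 + 1 * 10000079) % 2147483647
    // h3 = (h2 * 10000019 + 2 * 10000079) % 2147483647   -> final key
    uint64_t key = hashVector(std::vector<size_t>{7, 1, 2});
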
diff --git a/src/05computation/include/computation/operator.h b/src/05computation/include/computation/operator.h
index 02a14fbe..99ab5edf 100644
--- a/src/05computation/include/computation/operator.h
+++ b/src/05computation/include/computation/operator.h
@@ -6,7 +6,9 @@
 namespace refactor::computation {
 
    using kernel::LayoutType;
+   using kernel::Shape;
    using kernel::Target;
+   using kernel::Tensor;
 
    class Operator {
    public:
@@ -42,6 +44,20 @@ namespace refactor::computation {
 
    using OpBox = std::unique_ptr<Operator>;
 
+   struct MyOperator {
+       size_t numInputs = 2;
+       Arc<Operator> base;
+
+       MyOperator() : numInputs(2) {}
+       MyOperator(size_t num) : numInputs(num) {}
+       //MyOperator(const MyOperator &other) : numInputs(other.numInputs) {}
+       //MyOperator(MyOperator &&other) : numInputs(other.numInputs) {}
+       // virtual std::unique_ptr<MyOperator> create() const = 0;
+       virtual std::unique_ptr<MyOperator> clone() const = 0;
+       virtual bool compute(Tensor const &, Tensor const &, Tensor &) const = 0;
+       virtual Shape verify(Tensor const &, Tensor const &) const = 0;
+   };
+
 }// namespace refactor::computation
 
 #endif// COMPUTATION_OPERATOR_H

diff --git a/src/05computation/include/computation/operators/concat.h b/src/05computation/include/computation/operators/concat.h
index 1ab1602d..ff9c5511 100644
--- a/src/05computation/include/computation/operators/concat.h
+++ b/src/05computation/include/computation/operators/concat.h
@@ -15,6 +15,17 @@ namespace refactor::computation {
        kernel::CollectorBox candidateKernels(Target) const noexcept final;
    };
 
+   using refactor::kernel::Tensor;
+   struct ConcatBox final : public MyOperator {
+       // Arc<Operator> base;
+
+       ConcatBox() noexcept : MyOperator() {
+           base = std::make_shared<Concat>(1, 2);
+       }
+       std::unique_ptr<MyOperator> clone() const final;
+       bool compute(Tensor const &, Tensor const &, Tensor &) const noexcept final;
+       Shape verify(Tensor const &, Tensor const &) const noexcept final;
+   };
 }// namespace refactor::computation
 
 #endif// COMPUTATION_CONCAT_H

diff --git a/src/05computation/include/computation/operators/mat_mul.h b/src/05computation/include/computation/operators/mat_mul.h
index 6d2efbad..1f6aec9c 100644
--- a/src/05computation/include/computation/operators/mat_mul.h
+++ b/src/05computation/include/computation/operators/mat_mul.h
@@ -22,6 +22,18 @@ namespace refactor::computation {
        kernel::CollectorBox candidateKernels(Target) const noexcept final;
    };
 
+   using refactor::kernel::Tensor;
+   struct MatMulBox final : public MyOperator {
+       // Arc<Operator> base;
+
+       MatMulBox() noexcept : MyOperator() {
+           base = std::make_shared<MatMul>(1.0, 1.0, false, false);
+       }
+       std::unique_ptr<MyOperator> clone() const final;
+       bool compute(Tensor const &, Tensor const &, Tensor &) const noexcept final;
+       Shape verify(Tensor const &, Tensor const &) const noexcept final;
+   };
+
 }// namespace refactor::computation
 
 #endif// COMPUTATION_MAT_MUL_H
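Both Box types follow the same verify contract: return the inferred output shape, or an empty Shape to signal that the operands are incompatible, which dfs later checks via ans.size() == 0. A standalone illustration of the two rank-2 rules, with plain std::vector standing in for kernel::Shape:

    #include <cstdint>
    #include <vector>
    using Dims = std::vector<int64_t>;

    // MatMulBox rule: {4, 5} x {5, 7} -> {4, 7}; empty on mismatch.
    Dims matmul2D(Dims const &a, Dims const &b) {
        if (a.size() != 2 || b.size() != 2 || a[1] != b[0]) { return {}; }
        return {a[0], b[1]};
    }
    // ConcatBox rule (axis 1): {4, 6} | {4, 7} -> {4, 13}; empty on mismatch.
    Dims concat2D(Dims const &a, Dims const &b) {
        if (a.size() != 2 || b.size() != 2 || a[0] != b[0]) { return {}; }
        return {a[0], a[1] + b[1]};
    }
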
diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc
index d43d49c3..1b394f99 100644
--- a/src/05computation/src/graph.cc
+++ b/src/05computation/src/graph.cc
@@ -67,5 +67,6 @@ namespace refactor::computation {
    }
 
    auto Graph::internal() const -> decltype(_internal) const & { return _internal; }
+   auto Graph::internal() -> decltype(_internal) & { return _internal; }
 
 }// namespace refactor::computation

diff --git a/src/05computation/src/graph_mutant.cc b/src/05computation/src/graph_mutant.cc
new file mode 100644
index 00000000..e01d7857
--- /dev/null
+++ b/src/05computation/src/graph_mutant.cc
@@ -0,0 +1,28 @@
+#include "computation/graph_mutant.h"
+
+namespace refactor::computation {
+
+    GraphMutant::GraphMutant(graph_topo::Graph<Node, Edge> internal) noexcept
+        : _internal(std::move(internal)) {}
+    GraphMutant::GraphMutant(graph_topo::GraphTopo topology,
+                             std::vector<Node> nodes,
+                             std::vector<Edge> edges) noexcept
+        : GraphMutant(graph_topo::Graph<Node, Edge>{
+              std::move(topology),
+              std::move(nodes),
+              std::move(edges),
+          }) {}
+    GraphMutant::GraphMutant(graph_topo::LinkedGraph<Node, Edge> internal) noexcept
+        : _internal(std::move(internal)) {}
+
+    GraphMutant GraphMutant::clone() noexcept {
+        auto internal = this->_internal.clone([](Node const &o) -> Node { return o; },
+                                              [](Edge const &e) -> Edge { return e; });
+        GraphMutant newGraph(std::move(internal));
+        return newGraph;
+    }
+
+    auto GraphMutant::internal() const -> decltype(_internal) const & { return _internal; }
+    auto GraphMutant::internal() -> decltype(_internal) & { return _internal; }
+
+}// namespace refactor::computation

diff --git a/src/05computation/src/mutant_generator.cc b/src/05computation/src/mutant_generator.cc
new file mode 100644
index 00000000..fb1f639c
--- /dev/null
+++ b/src/05computation/src/mutant_generator.cc
@@ -0,0 +1,185 @@
+#include "computation/mutant_generator.h"
+#define MAX_SIZE (1024 * 1024)
+
+namespace refactor::computation {
+    using K = MutantGenerator;
+
+    void K::init(float equalThreshold_, size_t maxDepth_, OpVec opList_) noexcept {
+        equalThreshold = equalThreshold_;
+        maxDepth = maxDepth_;
+        opList = opList_;
+        opFinger.clear();
+        opStorage.clear();
+        validTensors.clear();
+        opHashMaps.clear();
+        for (size_t i = 0; i < maxDepth; ++i) {
+            opStorage.push_back(opList);
+        }
+    }
+
+    void K::run(GraphMutant const &inGraph, std::vector<GraphMutant> &outGraphs) noexcept {
+        using namespace refactor::graph_topo;
+
+        // init global inputs
+        std::unordered_map<size_t, Edge> edges;
+        auto edgeIndex = std::vector<size_t>{};
+        auto inputs = inGraph.internal().inputs();
+        auto outputs = inGraph.internal().outputs();
+        ASSERT(outputs.size() == 1, "Do not support more than one output.");
+        numValidTensors = inputs.size();
+        for (size_t i = 0; i < numValidTensors; ++i) {
+            edgeIndex.emplace_back(i);
+            edges.insert({i, inputs[i]->info()});
+        }
+        // init graph
+        Builder<size_t, size_t, Node, Edge> builder = {{}, edgeIndex, {}, {}, edges};
+        GraphMutant curGraph(builder.build());
+        for (size_t i = 0; i < numValidTensors; ++i) {
+            validTensors.emplace_back(curGraph.internal().inputs()[i]);
+        }
+        dfs(0, inGraph, curGraph, outGraphs);
+    }
+
+    void K::dfs(size_t depth, GraphMutant const &inGraph, GraphMutant &curGraph, std::vector<GraphMutant> &outGraphs) noexcept {
+        if (is_mutant(curGraph, inGraph)) {
+            // if any tensor other than the global output has no successor node, this graph is redundant
+            int count = 0;
+            for (size_t i = 0; i < numValidTensors; ++i) {
+                if (validTensors[i]->targets().size() == 0) {
+                    count++;
+                }
+            }
+            if (count > 1) {
+                curGraph.internal().cleanup();
+                return;
+            }
+            auto g = curGraph.clone();
+            fmt::println("======== found a mutant graph ========");
+            fmt::println("{}", curGraph.internal().toString([](Node const &o) -> std::string { return std::string(o.op->base->name()); }));
+            for (size_t i = 0; i < numValidTensors; ++i) {
+                fmt::println("{}. \"{}\" Shape is {}", i, validTensors[i]->info().name,
+                             vec2str(validTensors[i]->info().tensor->shape));
+            }
+            outGraphs.emplace_back(std::move(g));
+            curGraph.internal().setOutputs({});
+            return;
+        }
+        if (depth >= maxDepth) {
+            return;
+        }
+        //auto g_ = curGraph.internal();
+        for (size_t index = 0; index < opStorage[depth].size(); ++index) {
+            auto op = opStorage[depth][index];
+            if (op->numInputs == 2) {
+                for (size_t i = 0; i < numValidTensors; ++i) {
+                    for (size_t j = 0; j < numValidTensors; ++j) {
+                        if (i == j) {
+                            continue;
+                        }
+                        auto x = validTensors[i]->info().tensor;
+                        auto y = validTensors[j]->info().tensor;
+                        auto ans = op->verify(*x, *y);
+                        if (ans.size() == 0) {
+                            continue;
+                        }
+                        //fmt::println("{},{}, {}, {}", i, j, reinterpret_cast<void *>(x.get()), reinterpret_cast<void *>(y.get()));
+                        auto out = Tensor::share(x->dataType, ans, LayoutType::Others);
+                        out->malloc();
+                        if (!op->compute(*x, *y, *out) || have_same_op(op, i, j)) {
+                            out->free();
+                            continue;
+                        }
+                        numValidTensors++;
+                        opFinger.push_back(op);
+                        auto name = fmt::format("{}", depth);
+                        auto newEdge = curGraph.internal().shareEdge({out, "tensor_" + name});
+                        auto newNode = curGraph.internal().pushNode({op, "op_" + name},
+                                                                    {newEdge});
+                        newNode->connect(0, validTensors[i]);
+                        newNode->connect(1, validTensors[j]);
+                        validTensors.push_back(newEdge);
+                        //fmt::println("{}", curGraph.internal().toString([](Node const &o) -> std::string { return std::string(o.op->name()); }));
+                        //fmt::println("{}", reinterpret_cast<void *>(validTensors[j]->info().tensor.get()));
+                        dfs(depth + 1, inGraph, curGraph, outGraphs);
+                        curGraph.internal().eraseNode(newNode);
+                        validTensors.pop_back();
+                        opFinger.pop_back();
+                        delete_hash_op(op, i, j);
+                        numValidTensors--;
+                    }
+                }
+            }
+        }
+    }
\"{}\" Shape is {}", i, validTensors[i]->info().name, + // vec2str(validTensors[i]->info().tensor->shape)); + // } + auto inputs = inGraph.internal().inputs(); + auto outputs = inGraph.internal().outputs(); + std::vector::Edge>> outEdges; + for (auto output : outputs) { + int found = -1; + auto &tensor = *output->info().tensor; + for (size_t i = inputs.size(); i < validTensors.size(); ++i) { + if (approx_equal(tensor, *(validTensors[i]->info().tensor))) { + found = i; + break; + } + } + if (found == -1) { + // fmt::println("!!!!!!!compare false "); + return false; + } + outEdges.emplace_back(validTensors[found]); + } + curGraph.internal().setOutputs(outEdges); + // fmt::println("=======================compare true ================="); + return true; + } + + bool K::approx_equal(const Tensor &a, const Tensor &b) const noexcept { + if (a.shape != b.shape) { + return false; + } + size_t equal = 0, total = 0; + auto dataA = a.data->get(); + auto dataB = b.data->get(); + for (size_t i = 0; i < a.elementsSize(); ++i) { + if (dataA[i] == dataB[i]) { + equal++; + } + total++; + } + if (float(equal) / total >= equalThreshold) { + return true; + } + return false; + } + + bool K::have_same_op(Arc const &op, size_t a, size_t b) noexcept { + //fmt::println("{}", reinterpret_cast(op->base.get())); + std::vector hashInfo = {op->base->opTypeId(), a, b}; + auto res = hashVector(hashInfo); + if (opHashMaps.find(res) != opHashMaps.end()) { + return true; + } + opHashMaps.insert(std::move(res)); + return false; + } + + void K::delete_hash_op(Arc const &op, size_t a, size_t b) noexcept { + std::vector hashInfo = {op->base->opTypeId(), a, b}; + auto res = hashVector(hashInfo); + auto it = opHashMaps.find(res); + if (auto it = opHashMaps.find(res); it != opHashMaps.end()) { + opHashMaps.erase(it); + } + } +}// namespace refactor::computation \ No newline at end of file diff --git a/src/05computation/src/operators/concat.cc b/src/05computation/src/operators/concat.cc index d7ccbe2e..214aca8b 100644 --- a/src/05computation/src/operators/concat.cc +++ b/src/05computation/src/operators/concat.cc @@ -14,4 +14,38 @@ namespace refactor::computation { using Collector_ = kernel::ConcatCollector; return std::make_unique(target, axis); } + + Shape ConcatBox::verify(Tensor const &a, Tensor const &b) const noexcept { + Shape ans = {}; + if (a.rank() != 2 || b.rank() != 2) { + return ans; + } + if (a.shape[0] != b.shape[0]) { + return ans; + } + if (a.dataType != b.dataType) { + return ans; + } + ans = {a.shape[0], a.shape[1] + b.shape[1]}; + return ans; + } + bool ConcatBox::compute(Tensor const &a, Tensor const &b, Tensor &out) const noexcept { + + if (a.data == nullptr || b.data == nullptr) { + return false; + } + //compute + auto kernels = this->base->candidateKernels(Target::Cpu)->filter({a, b}, {out}); + ASSERT(kernels.size() != 0, "do not supposrt this kernel"); + runtime::Resources res; + auto rou = kernels[0]->lower(res); + void const *inputs[]{*a.data, *b.data}; + void *outputs[]{*out.data}; + rou(res, inputs, outputs); + return true; + } + + std::unique_ptr ConcatBox::clone() const { + return std::make_unique(*dynamic_cast(base.get())); + } }// namespace refactor::computation diff --git a/src/05computation/src/operators/mat_mul.cc b/src/05computation/src/operators/mat_mul.cc index f260cbed..de9ea7a5 100644 --- a/src/05computation/src/operators/mat_mul.cc +++ b/src/05computation/src/operators/mat_mul.cc @@ -1,5 +1,6 @@ #include "computation/operators/mat_mul.h" #include "kernel/collectors/mat_mul.h" +#include 
"runtime/resource.h" namespace refactor::computation { using Op = MatMul; @@ -14,4 +15,38 @@ namespace refactor::computation { return std::make_unique(target, alpha, beta, transA, transB); } + Shape MatMulBox::verify(Tensor const &a, Tensor const &b) const noexcept { + Shape ans = {}; + if (a.rank() != 2 || b.rank() != 2) { + return ans; + } + if (a.shape[1] != b.shape[0]) { + return ans; + } + if (a.dataType != b.dataType) { + return ans; + } + ans = {a.shape[0], + b.shape[1]}; + return ans; + } + bool MatMulBox::compute(Tensor const &a, Tensor const &b, Tensor &out) const noexcept { + + if (a.data == nullptr || b.data == nullptr) { + return false; + } + //compute + auto kernels = this->base->candidateKernels(Target::Cpu)->filter({a, b}, {out}); + ASSERT(kernels.size() != 0, "do not supposrt this kernel"); + runtime::Resources res; + auto rou = kernels[0]->lower(res); + void const *inputs[]{*a.data, *b.data}; + void *outputs[]{*out.data}; + rou(res, inputs, outputs); + return true; + } + + std::unique_ptr MatMulBox::clone() const { + return std::make_unique(*dynamic_cast(base.get())); + } }// namespace refactor::computation diff --git a/src/05computation/test/test_mutant_generator.cpp b/src/05computation/test/test_mutant_generator.cpp new file mode 100644 index 00000000..32206d67 --- /dev/null +++ b/src/05computation/test/test_mutant_generator.cpp @@ -0,0 +1,67 @@ +#include "computation/graph_mutant.h" +#include "computation/mutant_generator.h" +#include "computation/operators/concat.h" +#include "computation/operators/mat_mul.h" +#include +#include + +namespace refactor::computation { + + refactor::graph_topo::Builder TestInGraphBuild() { + auto nodes = std::unordered_map{}; + nodes[0] = Node{std::make_shared(), "matmul_1"}; + nodes[1] = Node{std::make_shared(), "matmul_2"}; + nodes[2] = Node{std::make_shared(), "concat"}; + + auto tensor0 = Tensor::share(DataType::F32, {5, 6}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {4, 5}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {5, 7}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {4, 6}, LayoutType::Others); + auto tensor4 = Tensor::share(DataType::F32, {4, 7}, LayoutType::Others); + auto tensor5 = Tensor::share(DataType::F32, {4, 13}, LayoutType::Others); + // initialize inputs data + auto data0 = reinterpret_cast(tensor0->malloc()); + auto data1 = reinterpret_cast(tensor1->malloc()); + auto data2 = reinterpret_cast(tensor2->malloc()); + std::iota(data0, data0 + tensor0->elementsSize(), 1.0); + std::iota(data1, data1 + tensor1->elementsSize(), 1.0); + std::iota(data2, data2 + tensor2->elementsSize(), 1.0); + // initialize outputs data + float outputData[]{255.0, 270.0, 285.0, 300.0, 315.0, 330.0, 295.0, 310.0, 325.0, 340.0, 355.0, 370.0, 385.0, 580.0, 620.0, 660.0, 700.0, + 740.0, 780.0, 670.0, 710.0, 750.0, 790.0, 830.0, 870.0, 910.0, 905.0, 970.0, 1035.0, 1100.0, 1165.0, 1230.0, 1045.0, 1110.0, 1175.0, 1240.0, + 1305.0, 1370.0, 1435.0, 1230.0, 1320.0, 1410.0, 1500.0, 1590.0, 1680.0, 1420.0, 1510.0, 1600.0, 1690.0, 1780.0, 1870.0, 1960.}; + std::memcpy(tensor5->malloc(), outputData, tensor5->bytesSize()); + + return { + {{0, {{1, 0}, {3}}}, + {1, {{1, 2}, {4}}}, + {2, {{3, 4}, {5}}}}, + {0, 1, 2}, + {5}, + std::move(nodes), + { + {0, {tensor0, "input_tensor_0"}}, + {1, {tensor1, "input_tensor_1"}}, + {2, {tensor2, "input_tensor_2"}}, + {3, {tensor3, "matmul0_output"}}, + {4, {tensor4, "matmul1_output"}}, + {5, {tensor5, "output"}}, + }, + }; + } + + TEST(Graph, 
diff --git a/src/05computation/test/test_mutant_generator.cpp b/src/05computation/test/test_mutant_generator.cpp
new file mode 100644
index 00000000..32206d67
--- /dev/null
+++ b/src/05computation/test/test_mutant_generator.cpp
@@ -0,0 +1,67 @@
+#include "computation/graph_mutant.h"
+#include "computation/mutant_generator.h"
+#include "computation/operators/concat.h"
+#include "computation/operators/mat_mul.h"
+#include <gtest/gtest.h>
+#include <numeric>
+
+namespace refactor::computation {
+
+    refactor::graph_topo::Builder<size_t, size_t, Node, Edge> TestInGraphBuild() {
+        auto nodes = std::unordered_map<size_t, Node>{};
+        nodes[0] = Node{std::make_shared<MatMulBox>(), "matmul_1"};
+        nodes[1] = Node{std::make_shared<MatMulBox>(), "matmul_2"};
+        nodes[2] = Node{std::make_shared<ConcatBox>(), "concat"};
+
+        auto tensor0 = Tensor::share(DataType::F32, {5, 6}, LayoutType::Others);
+        auto tensor1 = Tensor::share(DataType::F32, {4, 5}, LayoutType::Others);
+        auto tensor2 = Tensor::share(DataType::F32, {5, 7}, LayoutType::Others);
+        auto tensor3 = Tensor::share(DataType::F32, {4, 6}, LayoutType::Others);
+        auto tensor4 = Tensor::share(DataType::F32, {4, 7}, LayoutType::Others);
+        auto tensor5 = Tensor::share(DataType::F32, {4, 13}, LayoutType::Others);
+        // initialize inputs data
+        auto data0 = reinterpret_cast<float *>(tensor0->malloc());
+        auto data1 = reinterpret_cast<float *>(tensor1->malloc());
+        auto data2 = reinterpret_cast<float *>(tensor2->malloc());
+        std::iota(data0, data0 + tensor0->elementsSize(), 1.0);
+        std::iota(data1, data1 + tensor1->elementsSize(), 1.0);
+        std::iota(data2, data2 + tensor2->elementsSize(), 1.0);
+        // initialize outputs data
+        float outputData[]{255.0, 270.0, 285.0, 300.0, 315.0, 330.0, 295.0, 310.0, 325.0, 340.0, 355.0, 370.0, 385.0, 580.0, 620.0, 660.0, 700.0,
+                           740.0, 780.0, 670.0, 710.0, 750.0, 790.0, 830.0, 870.0, 910.0, 905.0, 970.0, 1035.0, 1100.0, 1165.0, 1230.0, 1045.0, 1110.0, 1175.0, 1240.0,
+                           1305.0, 1370.0, 1435.0, 1230.0, 1320.0, 1410.0, 1500.0, 1590.0, 1680.0, 1420.0, 1510.0, 1600.0, 1690.0, 1780.0, 1870.0, 1960.};
+        std::memcpy(tensor5->malloc(), outputData, tensor5->bytesSize());
+
+        return {
+            {{0, {{1, 0}, {3}}},
+             {1, {{1, 2}, {4}}},
+             {2, {{3, 4}, {5}}}},
+            {0, 1, 2},
+            {5},
+            std::move(nodes),
+            {
+                {0, {tensor0, "input_tensor_0"}},
+                {1, {tensor1, "input_tensor_1"}},
+                {2, {tensor2, "input_tensor_2"}},
+                {3, {tensor3, "matmul0_output"}},
+                {4, {tensor4, "matmul1_output"}},
+                {5, {tensor5, "output"}},
+            },
+        };
+    }
+
+    TEST(Graph, MutantGenerator) {
+        auto graphTopo = TestInGraphBuild().build();
+        fmt::println("{}", graphTopo.topology.toString());
+        GraphMutant g(std::move(graphTopo));
+        // create mutant generator
+        MutantGenerator mutant;
+        OpVec oplist = {std::make_shared<MatMulBox>(), std::make_shared<ConcatBox>()};
+        mutant.init(1.0, 3, oplist);
+        std::vector<GraphMutant> outGraph = {};
+        mutant.run(g, outGraph);
+        // for (size_t i = 0; i < outGraph.size(); ++i) {
+        //     fmt::println("{}", outGraph[i].internal().toString());
+        // }
+    }
+}// namespace refactor::computation
\ No newline at end of file
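To inspect the mutants the search produces, the commented-out loop at the end of the test can be enabled; using the new toString overload from this patch additionally prints each node's operator name (same capture-less lambda as in dfs):

    for (size_t i = 0; i < outGraph.size(); ++i) {
        fmt::println("{}", outGraph[i].internal().toString(
                               [](Node const &o) -> std::string { return std::string(o.op->base->name()); }));
    }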