feat: support auto optimize demo
bitzyz committed Nov 14, 2023
1 parent 65f9c5e commit 0032b8e
Showing 15 changed files with 541 additions and 1 deletion.
68 changes: 67 additions & 1 deletion src/01graph_topo/include/graph_topo/linked_graph.hpp
@@ -1,8 +1,8 @@
#ifndef GRAPH_TOPO_LINKED_GRAPH_H
#define GRAPH_TOPO_LINKED_GRAPH_H

#include "container.h"
#include "common.h"
#include "container.h"
#include <algorithm>
#include <sstream>
#include <unordered_map>
@@ -30,6 +30,7 @@ namespace refactor::graph_topo {
static auto shareEdge(TE) -> Rc<Edge>;

std::string toString() const;
std::string toString(std::string func(TN const &)) const;
Graph<TN, TE> intoGraph() const;
std::vector<Rc<Node>> const &nodes() const;
std::vector<Rc<Edge>> const &inputs() const;
@@ -43,6 +44,7 @@ namespace refactor::graph_topo {
void eraseNode(Rc<Node>);
size_t cleanup(bool useless(TE const &) = nullptr);
bool sort();
LinkedGraph clone(TN cloneNode(TN const &), TE cloneEdge(TE const &)) const;
};

template<class TN, class TE>
@@ -127,6 +129,34 @@ namespace refactor::graph_topo {
return ss.str();
}

LINKED_GRAPH_FN toString(std::string func(TN const &)) const->std::string {
std::unordered_map<void *, size_t> indices;
std::stringstream ss;
auto f = [&indices, &ss](Rc<Edge> const &e) {
if (e) {
auto [it, ok] = indices.try_emplace(e.get(), indices.size());
ss << it->second << ' ';
} else {
ss << "? ";
}
};
ss << "*. -> ( ";
for (auto const &e : _inputs) { f(e); }
ss << ')' << std::endl;
for (auto i : range0_(_nodes.size())) {
auto n = _nodes[i];
ss << i << ". ( ";
for (auto const &e : n->_inputs) { f(e); }
ss << ") -> ( ";
for (auto const &e : n->_outputs) { f(e); }
ss << ')' << func(n->_info) << std::endl;
}
ss << "*. <- ( ";
for (auto const &e : _outputs) { f(e); }
ss << ')' << std::endl;
return ss.str();
}

LINKED_GRAPH_FN nodes() const->std::vector<Rc<Node>> const & {
return _nodes;
}
@@ -236,6 +266,42 @@ namespace refactor::graph_topo {
return true;
}

LINKED_GRAPH_FN clone(
TN cloneNode(TN const &),
TE cloneEdge(TE const &))
const->LinkedGraph {
LinkedGraph ans;
ans._inputs.reserve(_inputs.size());
ans._nodes.reserve(_nodes.size());
ans._outputs.reserve(_outputs.size());
std::unordered_map<void *, Rc<Edge>> edges;
auto mapEdge = [&](Rc<Edge> const &e) {
return edges.try_emplace(e.get(), shareEdge(cloneEdge(e->_info))).first->second;
};
for (auto const &e : _inputs) {
ans._inputs.emplace_back(mapEdge(e));
}
for (auto const &n : _nodes) {
std::vector<Rc<Edge>> outputs;
outputs.reserve(n->_outputs.size());
for (auto const &e : n->_outputs) {
outputs.emplace_back(mapEdge(e));
}
auto n_ = ans.pushNode(cloneNode(n->_info), std::move(outputs));
for (auto i : range0_(n->_inputs.size())) {
if (auto it = edges.find(n->_inputs[i].get()); it != edges.end()) {
n_->connect(i, it->second);
} else {
n_->connect(i, mapEdge(n->_inputs[i]));
}
}
}
for (auto const &e : _outputs) {
ans._outputs.emplace_back(std::move(edges.at(e.get())));
}
return ans;
}

LINKED_GRAPH_FN Node::share(TN info, std::vector<Rc<Edge>> outputs)
->Rc<Node> {
auto ans = Rc<Node>(new Node(std::move(info), std::move(outputs)));
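The two additions above combine naturally: clone deep-copies the topology through caller-supplied copy functions, and the labeled toString overload prints each node through a formatter. A minimal sketch, assuming std::string node payloads and int edge payloads purely for illustration:

    #include "graph_topo/linked_graph.hpp"
    #include <string>

    using G = refactor::graph_topo::LinkedGraph<std::string, int>;

    std::string cloneAndDump(G const &g) {
        // clone() re-creates every node and edge exactly once (keyed by the
        // original edges' addresses), so shared edges stay shared in the copy.
        auto copy = g.clone([](std::string const &n) { return n; },
                            [](int const &e) { return e; });
        // The new overload labels each node line with the caller's formatter.
        return copy.toString([](std::string const &n) { return n; });
    }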
1 change: 1 addition & 0 deletions src/02mem_manager/include/mem_manager/blob.hh
@@ -19,6 +19,7 @@ namespace refactor::mem_manager {

static std::pair<std::shared_ptr<Blob>, void *> share(size_t);
operator void const *() const noexcept;
operator void *() noexcept;
template<class T> T const *get() const noexcept {
return reinterpret_cast<T const *>(_ptr);
}
1 change: 1 addition & 0 deletions src/02mem_manager/src/blob.cc
@@ -14,5 +14,6 @@ namespace refactor::mem_manager {
return {std::move(blob), ptr};
}
Blob::operator void const *() const noexcept { return _ptr; }
Blob::operator void *() noexcept { return _ptr; }

}// namespace refactor::mem_manager
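The non-const conversion mirrors the existing const one and is what lets the kernel-invocation code later in this commit build writable void * arrays straight from a Blob. A minimal sketch, assuming the blob was allocated with at least `bytes` bytes:

    #include "mem_manager/blob.hh"
    #include <cstddef>
    #include <cstring>

    // Zero-fill a blob's storage through the new mutable conversion.
    void zeroFill(refactor::mem_manager::Blob &blob, size_t bytes) {
        std::memset(static_cast<void *>(blob), 0, bytes);
    }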
1 change: 1 addition & 0 deletions src/05computation/include/computation/graph.h
@@ -30,6 +30,7 @@ namespace refactor::computation {

kernel::Graph lower(Target) const;
auto internal() const -> decltype(_internal) const &;
auto internal() -> decltype(_internal) &;
};

}// namespace refactor::computation
38 changes: 38 additions & 0 deletions src/05computation/include/computation/graph_mutant.h
@@ -0,0 +1,38 @@
#ifndef COMPUTATION_GRAPH_MUTANT_H
#define COMPUTATION_GRAPH_MUTANT_H

#include "graph_topo.h"
#include "kernel/graph.h"
#include "operator.h"

namespace refactor::computation {
using kernel::Shape;
using kernel::Tensor;

struct Node {
Arc<MyOperator> op;
std::string name;
};

struct Edge {
Arc<Tensor> tensor;
std::string name;
};

class GraphMutant {
graph_topo::LinkedGraph<Node, Edge> _internal;

public:
explicit GraphMutant(graph_topo::Graph<Node, Edge>) noexcept;
GraphMutant(graph_topo::GraphTopo, std::vector<Node>, std::vector<Edge>) noexcept;
GraphMutant(graph_topo::LinkedGraph<Node, Edge>) noexcept;

GraphMutant clone() noexcept;

auto internal() const -> decltype(_internal) const &;
auto internal() -> decltype(_internal) &;
};

}// namespace refactor::computation

#endif// COMPUTATION_GRAPH_MUTANT_H
44 changes: 44 additions & 0 deletions src/05computation/include/computation/mutant_generator.h
@@ -0,0 +1,44 @@
#ifndef COMPUTATION_MUTANT_GENERATOR_H
#define COMPUTATION_MUTANT_GENERATOR_H

#include "graph_mutant.h"
#include "operator.h"

namespace refactor::computation {

using OpVec = std::vector<Arc<MyOperator>>;
using TensorVec = std::vector<Rc<refactor::graph_topo::LinkedGraph<Node, Edge>::Edge>>;

inline uint64_t hashAppend(uint64_t a, uint64_t b) {
return (a * 10000019 + b * 10000079) % 2147483647;
}

template<typename T> inline uint64_t hashVector(const std::vector<T> &vec) {
uint64_t ret = 0;
for (auto v : vec)
ret = hashAppend(ret, v);
return ret;
}

class MutantGenerator {
float equalThreshold;
size_t maxDepth;
size_t numValidTensors = 0;
OpVec opList;
std::vector<OpVec> opStorage;
OpVec opFinger;
TensorVec validTensors;
std::set<uint64_t> opHashMaps;

public:
void init(float, size_t, OpVec) noexcept;
void run(GraphMutant const &, std::vector<GraphMutant> &) noexcept;
void dfs(size_t, GraphMutant const &, GraphMutant &, std::vector<GraphMutant> &) noexcept;
bool is_mutant(GraphMutant &, GraphMutant const &) noexcept;
bool approx_equal(const Tensor &, const Tensor &) const noexcept;
bool have_same_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
void delete_hash_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
};
}// namespace refactor::computation

#endif// COMPUTATION_MUTANT_GENERATOR_H
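hashVector folds elements left to right through hashAppend, so it is order-sensitive; the generator relies on that to key each attempted (opTypeId, lhs, rhs) triple separately from its swapped counterpart. A small sketch (the value 7 stands in for an arbitrary opTypeId):

    // Swapping operands walks different intermediate states, so the two
    // orders hash apart (with overwhelming likelihood), and dfs() can try
    // both op(x, y) and op(y, x).
    uint64_t hA = hashVector<uint64_t>({7, 0, 1});
    uint64_t hB = hashVector<uint64_t>({7, 1, 0});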
16 changes: 16 additions & 0 deletions src/05computation/include/computation/operator.h
@@ -6,7 +6,9 @@

namespace refactor::computation {
using kernel::LayoutType;
using kernel::Shape;
using kernel::Target;
using kernel::Tensor;

class Operator {
public:
@@ -42,6 +44,20 @@ namespace refactor::computation {

using OpBox = std::unique_ptr<Operator>;

struct MyOperator {
size_t numInputs = 2;
Arc<Operator> base;

MyOperator() noexcept = default;
explicit MyOperator(size_t num) noexcept : numInputs(num) {}
// Polymorphic base: a virtual destructor keeps deletion through
// MyOperator pointers well-defined.
virtual ~MyOperator() = default;
virtual std::unique_ptr<Operator> clone() const = 0;
virtual bool compute(Tensor const &, Tensor const &, Tensor &) const = 0;
virtual Shape verify(Tensor const &, Tensor const &) const = 0;
};

}// namespace refactor::computation

#endif// COMPUTATION_OPERATOR_H
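For reference, the contract a MyOperator subclass must satisfy — verify returns the output shape or {} to reject an operand pair, and compute returns false when operand data is missing — sketched here with a hypothetical element-wise AddBox that is not part of this commit:

    struct AddBox final : public MyOperator {
        Shape verify(Tensor const &a, Tensor const &b) const noexcept final {
            // Reject mismatched pairs by returning an empty shape.
            return (a.shape == b.shape && a.dataType == b.dataType) ? a.shape : Shape{};
        }
        bool compute(Tensor const &a, Tensor const &b, Tensor &out) const noexcept final {
            if (a.data == nullptr || b.data == nullptr) { return false; }
            auto x = a.data->get<float>();
            auto y = b.data->get<float>();
            auto z = static_cast<float *>(static_cast<void *>(*out.data));
            for (size_t i = 0; i < a.elementsSize(); ++i) { z[i] = x[i] + y[i]; }
            return true;
        }
        std::unique_ptr<Operator> clone() const final { return nullptr; }// sketch only
    };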
11 changes: 11 additions & 0 deletions src/05computation/include/computation/operators/concat.h
@@ -15,6 +15,17 @@ namespace refactor::computation {
kernel::CollectorBox candidateKernels(Target) const noexcept final;
};

using refactor::kernel::Tensor;
struct ConcatBox final : public MyOperator {

ConcatBox() noexcept : MyOperator() {
base = std::make_shared<Concat>(1, 2);
}
std::unique_ptr<Operator> clone() const final;
bool compute(Tensor const &, Tensor const &, Tensor &) const noexcept final;
Shape verify(Tensor const &, Tensor const &) const noexcept final;
};
}// namespace refactor::computation

#endif// COMPUTATION_CONCAT_H
12 changes: 12 additions & 0 deletions src/05computation/include/computation/operators/mat_mul.h
@@ -22,6 +22,18 @@ namespace refactor::computation {
kernel::CollectorBox candidateKernels(Target) const noexcept final;
};

using refactor::kernel::Tensor;
struct MatMulBox final : public MyOperator {

MatMulBox() noexcept : MyOperator() {
base = std::make_shared<MatMul>(1.0, 1.0, false, false);
}
std::unique_ptr<Operator> clone() const final;
bool compute(Tensor const &, Tensor const &, Tensor &) const noexcept final;
Shape verify(Tensor const &, Tensor const &) const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_MAT_MUL_H
1 change: 1 addition & 0 deletions src/05computation/src/graph.cc
@@ -67,5 +67,6 @@ namespace refactor::computation {
}

auto Graph::internal() const -> decltype(_internal) const & { return _internal; }
auto Graph::internal() -> decltype(_internal) & { return _internal; }

}// namespace refactor::computation
28 changes: 28 additions & 0 deletions src/05computation/src/graph_mutant.cc
@@ -0,0 +1,28 @@
#include "computation/graph_mutant.h"

namespace refactor::computation {

GraphMutant::GraphMutant(graph_topo::Graph<Node, Edge> internal) noexcept
: _internal(std::move(internal)) {}
GraphMutant::GraphMutant(graph_topo::GraphTopo topology,
std::vector<Node> nodes,
std::vector<Edge> edges) noexcept
: GraphMutant(graph_topo::Graph<Node, Edge>{
std::move(topology),
std::move(nodes),
std::move(edges),
}) {}

GraphMutant::GraphMutant(graph_topo::LinkedGraph<Node, Edge> internal) noexcept
: _internal(std::move(internal)) {}
GraphMutant GraphMutant::clone() noexcept {
auto internal = this->_internal.clone([](Node const &o) -> Node { return o; },
[](Edge const &e) -> Edge { return e; });
GraphMutant newGraph(std::move(internal));
return newGraph;
}

auto GraphMutant::internal() const -> decltype(_internal) const & { return _internal; }
auto GraphMutant::internal() -> decltype(_internal) & { return _internal; }

}// namespace refactor::computation
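Note what clone() shares: the linked-graph structure is duplicated, but the lambdas copy Node and Edge by value, so the Arc-held operator and tensor objects stay shared between original and copy. A quick sketch, assuming the graph has at least one input:

    // Structural deep copy, shared payloads: matching input edges of the
    // original and the clone refer to the same underlying Tensor.
    bool sharesTensors(GraphMutant &g) {
        auto copy = g.clone();
        return g.internal().inputs()[0]->info().tensor.get()
            == copy.internal().inputs()[0]->info().tensor.get();// true
    }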
185 changes: 185 additions & 0 deletions src/05computation/src/mutant_generator.cc
@@ -0,0 +1,185 @@
#include "computation/mutant_generator.h"
#define MAX_SIZE (1024 * 1024)

namespace refactor::computation {
using K = MutantGenerator;

void K::init(float equalThreshold_, size_t maxDepth_, OpVec opList_) noexcept {
equalThreshold = equalThreshold_;
maxDepth = maxDepth_;
opList = opList_;
opFinger.clear();
opStorage.clear();
validTensors.clear();
opHashMaps.clear();
for (size_t i = 0; i < maxDepth; ++i) {
opStorage.push_back(opList);
}
}

void K::run(GraphMutant const &inGraph, std::vector<GraphMutant> &outGraphs) noexcept {
using namespace refactor::graph_topo;

// init global inputs
std::unordered_map<size_t, Edge> edges;
auto edgeIndex = std::vector<size_t>{};
auto inputs = inGraph.internal().inputs();
auto outputs = inGraph.internal().outputs();
ASSERT(outputs.size() == 1, "graphs with more than one output are not supported");
numValidTensors = inputs.size();
for (size_t i = 0; i < numValidTensors; ++i) {
edgeIndex.emplace_back(i);
edges.insert({i, inputs[i]->info()});
}
// init graph
Builder<size_t, Node, size_t, Edge>
builder = {{}, edgeIndex, {}, {}, edges};
GraphMutant curGraph(builder.build());
for (size_t i = 0; i < numValidTensors; ++i) {
validTensors.emplace_back(curGraph.internal().inputs()[i]);
}
dfs(0, inGraph, curGraph, outGraphs);
}

void K::dfs(size_t depth, GraphMutant const &inGraph, GraphMutant &curGraph, std::vector<GraphMutant> &outGraphs) noexcept {
if (is_mutant(curGraph, inGraph)) {
// If any tensor other than a global output has no successor node, the graph is redundant.
int count = 0;
for (size_t i = 0; i < numValidTensors; ++i) {
if (validTensors[i]->targets().size() == 0) {
count++;
}
}
if (count > 1) {
curGraph.internal().cleanup();
return;
}
auto g = curGraph.clone();
fmt::println("=======zyz======ok======");
fmt::println("{}", curGraph.internal().toString([](Node const &o) -> std::string { return std::string(o.op->base->name()); }));
for (size_t i = 0; i < numValidTensors; ++i) {
fmt::println("{}. \"{}\" Shape is {}", i, validTensors[i]->info().name,
vec2str(validTensors[i]->info().tensor->shape));
}
outGraphs.emplace_back(std::move(g));
curGraph.internal().setOutputs({});
return;
}
if (depth >= maxDepth) {
return;
}
//auto g_ = curGraph.internal();
for (size_t index = 0; index < opStorage[depth].size(); ++index) {
auto op = opStorage[depth][index];
if (op->numInputs == 2) {
for (size_t i = 0; i < numValidTensors; ++i) {
for (size_t j = 0; j < numValidTensors; ++j) {
if (i == j) {
continue;
}
auto x = validTensors[i]->info().tensor;
auto y = validTensors[j]->info().tensor;
auto ans = op->verify(*x, *y);
if (ans.size() == 0) {
continue;
}
//fmt::println("{},{}, {}, {}", i, j, reinterpret_cast<void *>(x.get()), reinterpret_cast<void *>(y.get()));
auto out = Tensor::share(x->dataType, ans, LayoutType::Others);
out->malloc();
if (!op->compute(*x, *y, *out) || have_same_op(op, i, j)) {
out->free();
continue;
}
numValidTensors++;
opFinger.push_back(op);
auto name = fmt::format("{}", depth);
auto newEdge = curGraph.internal().shareEdge({out, "tensor_" + name});
auto newNode = curGraph.internal().pushNode({op, "op_" + name},
{newEdge});
newNode->connect(0, validTensors[i]);
newNode->connect(1, validTensors[j]);
validTensors.push_back(newEdge);
//fmt::println("{}", curGraph.internal().toString([](Node const &o) -> std::string { return std::string(o.op->name()); }));
//fmt::println("{}", reinterpret_cast<void *>(validTensors[j]->info().tensor.get()));
dfs(depth + 1, inGraph, curGraph, outGraphs);
curGraph.internal().eraseNode(newNode);
validTensors.pop_back();
opFinger.pop_back();
delete_hash_op(op, i, j);
numValidTensors--;
}
}
}
}
}

bool K::is_mutant(GraphMutant &curGraph, GraphMutant const &inGraph) noexcept {
// fmt::println("=======================output graph =================");
// fmt::println("{}", curGraph.internal().toString([](Node const &o) -> std::string { return std::string(o.op->base->name()); }));
// fmt::println("Edges info :");
// for (size_t i = 0; i < numValidTensors; ++i) {
// fmt::println("{}. \"{}\" Shape is {}", i, validTensors[i]->info().name,
// vec2str(validTensors[i]->info().tensor->shape));
// }
auto inputs = inGraph.internal().inputs();
auto outputs = inGraph.internal().outputs();
TensorVec outEdges;
for (auto const &output : outputs) {
int found = -1;
auto &tensor = *output->info().tensor;
for (size_t i = inputs.size(); i < validTensors.size(); ++i) {
if (approx_equal(tensor, *(validTensors[i]->info().tensor))) {
found = static_cast<int>(i);
break;
}
}
if (found == -1) {
// fmt::println("!!!!!!!compare false ");
return false;
}
outEdges.emplace_back(validTensors[found]);
}
curGraph.internal().setOutputs(outEdges);
// fmt::println("=======================compare true =================");
return true;
}

bool K::approx_equal(const Tensor &a, const Tensor &b) const noexcept {
if (a.shape != b.shape) {
return false;
}
size_t equal = 0, total = 0;
auto dataA = a.data->get<float>();
auto dataB = b.data->get<float>();
for (size_t i = 0; i < a.elementsSize(); ++i) {
if (dataA[i] == dataB[i]) {
equal++;
}
total++;
}
if (float(equal) / total >= equalThreshold) {
return true;
}
return false;
}

bool K::have_same_op(Arc<MyOperator> const &op, size_t a, size_t b) noexcept {
//fmt::println("{}", reinterpret_cast<void *>(op->base.get()));
std::vector<size_t> hashInfo = {op->base->opTypeId(), a, b};
auto res = hashVector(hashInfo);
if (opHashMaps.find(res) != opHashMaps.end()) {
return true;
}
opHashMaps.insert(std::move(res));
return false;
}

void K::delete_hash_op(Arc<MyOperator> const &op, size_t a, size_t b) noexcept {
std::vector<size_t> hashInfo = {op->base->opTypeId(), a, b};
auto res = hashVector(hashInfo);
if (auto it = opHashMaps.find(res); it != opHashMaps.end()) {
opHashMaps.erase(it);
}
}
}// namespace refactor::computation
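The search in dfs is exhaustive up to maxDepth: at depth d it tries every operator in opStorage[d] against every ordered pair of distinct valid tensors, and each accepted candidate appends one tensor before recursing, so with n_d valid tensors the branching factor is bounded by

    candidates(d) <= |opStorage[d]| * n_d * (n_d - 1),  with n_{d+1} = n_d + 1.

For the test below (two operators, three inputs, depth 3) that is at most 2*3*2 = 12 root branches, 2*4*3 = 24 per surviving branch at depth 1, and so on; verify, compute, and the (op, i, j) hash prune most of them.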
34 changes: 34 additions & 0 deletions src/05computation/src/operators/concat.cc
@@ -14,4 +14,38 @@ namespace refactor::computation {
using Collector_ = kernel::ConcatCollector;
return std::make_unique<Collector_>(target, axis);
}

Shape ConcatBox::verify(Tensor const &a, Tensor const &b) const noexcept {
Shape ans = {};
if (a.rank() != 2 || b.rank() != 2) {
return ans;
}
if (a.shape[0] != b.shape[0]) {
return ans;
}
if (a.dataType != b.dataType) {
return ans;
}
ans = {a.shape[0], a.shape[1] + b.shape[1]};
return ans;
}
bool ConcatBox::compute(Tensor const &a, Tensor const &b, Tensor &out) const noexcept {

if (a.data == nullptr || b.data == nullptr) {
return false;
}
//compute
auto kernels = this->base->candidateKernels(Target::Cpu)->filter({a, b}, {out});
ASSERT(kernels.size() != 0, "do not support this kernel");
runtime::Resources res;
auto rou = kernels[0]->lower(res);
void const *inputs[]{*a.data, *b.data};
void *outputs[]{*out.data};
rou(res, inputs, outputs);
return true;
}

std::unique_ptr<Operator> ConcatBox::clone() const {
return std::make_unique<Concat>(*dynamic_cast<Concat const *>(base.get()));
}
}// namespace refactor::computation
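ConcatBox::verify only accepts rank-2 operands with matching leading dimension and data type, concatenating along axis 1 (consistent with the Concat base constructed in the header). A worked sketch, mirroring the shapes in the test below:

    // {4, 6} ++ {4, 7} along axis 1 -> {4, 13}; any mismatch yields {}.
    auto a = Tensor::share(DataType::F32, {4, 6}, LayoutType::Others);
    auto b = Tensor::share(DataType::F32, {4, 7}, LayoutType::Others);
    ConcatBox cat;
    auto shape = cat.verify(*a, *b);// {4, 13}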
35 changes: 35 additions & 0 deletions src/05computation/src/operators/mat_mul.cc
@@ -1,5 +1,6 @@
#include "computation/operators/mat_mul.h"
#include "kernel/collectors/mat_mul.h"
#include "runtime/resource.h"

namespace refactor::computation {
using Op = MatMul;
@@ -14,4 +15,38 @@ namespace refactor::computation {
return std::make_unique<kernel::MatMulCollector>(target, alpha, beta, transA, transB);
}

Shape MatMulBox::verify(Tensor const &a, Tensor const &b) const noexcept {
Shape ans = {};
if (a.rank() != 2 || b.rank() != 2) {
return ans;
}
if (a.shape[1] != b.shape[0]) {
return ans;
}
if (a.dataType != b.dataType) {
return ans;
}
ans = {a.shape[0], b.shape[1]};
return ans;
}
bool MatMulBox::compute(Tensor const &a, Tensor const &b, Tensor &out) const noexcept {

if (a.data == nullptr || b.data == nullptr) {
return false;
}
//compute
auto kernels = this->base->candidateKernels(Target::Cpu)->filter({a, b}, {out});
ASSERT(kernels.size() != 0, "do not support this kernel");
runtime::Resources res;
auto rou = kernels[0]->lower(res);
void const *inputs[]{*a.data, *b.data};
void *outputs[]{*out.data};
rou(res, inputs, outputs);
return true;
}

std::unique_ptr<Operator> MatMulBox::clone() const {
return std::make_unique<MatMul>(*dynamic_cast<MatMul const *>(base.get()));
}
}// namespace refactor::computation
67 changes: 67 additions & 0 deletions src/05computation/test/test_mutant_generator.cpp
@@ -0,0 +1,67 @@
#include "computation/graph_mutant.h"
#include "computation/mutant_generator.h"
#include "computation/operators/concat.h"
#include "computation/operators/mat_mul.h"
#include <gtest/gtest.h>
#include <numeric>

namespace refactor::computation {

refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestInGraphBuild() {
auto nodes = std::unordered_map<size_t, Node>{};
nodes[0] = Node{std::make_shared<MatMulBox>(), "matmul_1"};
nodes[1] = Node{std::make_shared<MatMulBox>(), "matmul_2"};
nodes[2] = Node{std::make_shared<ConcatBox>(), "concat"};

auto tensor0 = Tensor::share(DataType::F32, {5, 6}, LayoutType::Others);
auto tensor1 = Tensor::share(DataType::F32, {4, 5}, LayoutType::Others);
auto tensor2 = Tensor::share(DataType::F32, {5, 7}, LayoutType::Others);
auto tensor3 = Tensor::share(DataType::F32, {4, 6}, LayoutType::Others);
auto tensor4 = Tensor::share(DataType::F32, {4, 7}, LayoutType::Others);
auto tensor5 = Tensor::share(DataType::F32, {4, 13}, LayoutType::Others);
// initialize inputs data
auto data0 = reinterpret_cast<float *>(tensor0->malloc());
auto data1 = reinterpret_cast<float *>(tensor1->malloc());
auto data2 = reinterpret_cast<float *>(tensor2->malloc());
std::iota(data0, data0 + tensor0->elementsSize(), 1.0);
std::iota(data1, data1 + tensor1->elementsSize(), 1.0);
std::iota(data2, data2 + tensor2->elementsSize(), 1.0);
// initialize outputs data
float outputData[]{255.0, 270.0, 285.0, 300.0, 315.0, 330.0, 295.0, 310.0, 325.0, 340.0, 355.0, 370.0, 385.0, 580.0, 620.0, 660.0, 700.0,
740.0, 780.0, 670.0, 710.0, 750.0, 790.0, 830.0, 870.0, 910.0, 905.0, 970.0, 1035.0, 1100.0, 1165.0, 1230.0, 1045.0, 1110.0, 1175.0, 1240.0,
1305.0, 1370.0, 1435.0, 1230.0, 1320.0, 1410.0, 1500.0, 1590.0, 1680.0, 1420.0, 1510.0, 1600.0, 1690.0, 1780.0, 1870.0, 1960.};
std::memcpy(tensor5->malloc(), outputData, tensor5->bytesSize());

return {
{{0, {{1, 0}, {3}}},
{1, {{1, 2}, {4}}},
{2, {{3, 4}, {5}}}},
{0, 1, 2},
{5},
std::move(nodes),
{
{0, {tensor0, "input_tensor_0"}},
{1, {tensor1, "input_tensor_1"}},
{2, {tensor2, "input_tensor_2"}},
{3, {tensor3, "matmul0_output"}},
{4, {tensor4, "matmul1_output"}},
{5, {tensor5, "output"}},
},
};
}

TEST(Graph, MutantGenerator) {
auto graphTopo = TestInGraphBuild().build();
fmt::println("{}", graphTopo.topology.toString());
GraphMutant g(std::move(graphTopo));
// create mutant generator
MutantGenerator mutant;
OpVec oplist = {std::make_shared<MatMulBox>(), std::make_shared<ConcatBox>()};
mutant.init(1.0, 3, oplist);
std::vector<GraphMutant> outGraph = {};
mutant.run(g, outGraph);
// for (size_t i = 0; i < outGraph.size(); ++i) {
// fmt::println("{}", outGraph[i].internal().toString());
// }
}
}// namespace refactor::computation
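As a cross-check on the hard-coded expectation: node 0 computes tensor1 (4×5) × tensor0 (5×6), so the first expected value is row 0 of tensor1 dotted with column 0 of tensor0:

    // 1*1 + 2*7 + 3*13 + 4*19 + 5*25 == 255, the first entry of outputData.
    static_assert(1 * 1 + 2 * 7 + 3 * 13 + 4 * 19 + 5 * 25 == 255);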
