-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
13 changed files
with
401 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#ifndef COMPUTATION_MUTANT_GENERATOR_H | ||
#define COMPUTATION_MUTANT_GENERATOR_H | ||
|
||
#include "graph.h" | ||
#include "operator.h" | ||
|
||
namespace refactor::computation { | ||
|
||
using OpVec = std::vector<Arc<MyOperator>>; | ||
using TensorVec = std::vector<Rc<refactor::graph_topo::LinkedGraph<Node, Edge>::Edge>>; | ||
|
||
/// Folds one value into a running hash.
/// Result is always in [0, 2147483647), since the sum is reduced mod 2^31 - 1.
inline uint64_t hashAppend(uint64_t a, uint64_t b) {
    constexpr uint64_t kMulA = 10000019, kMulB = 10000079, kMod = 2147483647;
    return (a * kMulA + b * kMulB) % kMod;
}

/// Hashes a whole vector by folding each element, left to right,
/// into an accumulator that starts at zero.
template<typename T> inline uint64_t hashVector(const std::vector<T> &vec) {
    uint64_t acc = 0;
    for (const auto &element : vec) {
        acc = hashAppend(acc, element);
    }
    return acc;
}
|
||
/// Searches for "mutant" graphs: alternative operator sequences (up to
/// `maxDepth` ops drawn from `opList`) whose computed output matches the
/// reference graph's output within `equalThreshold`.
class MutantGenerator {
    float equalThreshold;          // minimum fraction of element-wise matches to accept a candidate output
    size_t maxDepth;               // maximum number of operators in a candidate graph
    size_t numValidTensors = 0;    // count of tensors currently usable as operator inputs during the search
    OpVec opList;                  // candidate operators supplied by the caller
    std::vector<OpVec> opStorage;  // per-depth copy of opList (one slot per search depth)
    OpVec opFinger;                // stack of operators chosen along the current DFS path
    TensorVec validTensors;        // edges (tensors) produced so far; inputs first, then intermediates
    std::set<uint64_t> opHashMaps; // hashes of (op, input-pair) combos already tried on the current path

public:
    // Resets all search state and stores the threshold, depth limit, and operator set.
    void init(float, size_t, OpVec) noexcept;
    // Entry point: builds a search graph sharing inGraph's inputs, then runs the DFS.
    void run(Graph const &, std::vector<Graph> &) noexcept;
    // Depth-first enumeration of operator/input combinations with backtracking.
    void dfs(size_t, Graph const &, Graph &, std::vector<Graph> &) noexcept;
    // True when the candidate graph reproduces every output of the reference graph.
    bool is_mutant(Graph &, Graph const &) noexcept;
    // Element-wise comparison; accepts when the match ratio reaches equalThreshold.
    bool approx_equal(const Tensor &, const Tensor &) const noexcept;
    // Records the (op, inputs) combination; true if it was already tried.
    bool have_same_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
    // Removes the record added by have_same_op when the DFS backtracks.
    void delete_hash_op(Arc<MyOperator> const &, size_t, size_t) noexcept;
};
}// namespace refactor::computation | ||
|
||
#endif// COMPUTATION_MUTANT_GENERATOR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
#include "computation/mutant_generator.h"
// BUG FIX: the original body was `1024x1024`, which is not a valid C++ token
// sequence — any expansion of this macro would fail to compile. The macro is
// not referenced in this translation unit (dfs hard-codes {1024, 1024});
// consider removing it or using it there.
#define MAX_SIZE (1024 * 1024)
|
||
namespace refactor::computation { | ||
using K = MutantGenerator; | ||
|
||
/// Stores the search parameters and resets all per-run state.
/// @param equalThreshold_  minimum element-match ratio for accepting an output
/// @param maxDepth_        maximum number of operators in a candidate graph
/// @param opList_          candidate operators (taken by value; moved into place)
void K::init(float equalThreshold_, size_t maxDepth_, OpVec opList_) noexcept {
    equalThreshold = equalThreshold_;
    maxDepth = maxDepth_;
    // By-value sink parameter: move instead of copying the operator vector.
    opList = std::move(opList_);
    opFinger.clear();
    validTensors.clear();
    opHashMaps.clear();
    // One candidate-operator list per search depth; assign replaces the
    // previous contents, so no separate opStorage.clear() is needed.
    opStorage.assign(maxDepth, opList);
}
|
||
/// Builds a fresh candidate graph that shares the reference graph's global
/// inputs (no nodes yet), registers those inputs as the initial valid tensors,
/// and starts the DFS at depth 0.
void K::run(Graph const &inGraph, std::vector<Graph> &outGraphs) noexcept {
    using namespace refactor::graph_topo;

    // init global inputs: copy each input edge's payload into the builder maps.
    std::unordered_map<size_t, Edge> edges;
    auto edgeIndex = std::vector<size_t>{};
    auto inputs = inGraph.internal().linked().inputs();
    auto outputs = inGraph.internal().linked().outputs();
    // is_mutant() only matches a single output edge; enforce that up front.
    ASSERT(outputs.size() == 1, "Do not support more than one output.");
    numValidTensors = inputs.size();
    for (size_t i = 0; i < numValidTensors; ++i) {
        edgeIndex.emplace_back(i);
        edges.insert({i, inputs[i]->info()});
    }
    // init graph: topology with inputs only — nodes/outputs are added during dfs.
    Builder<size_t, Node, size_t, Edge>
        builder = {{}, edgeIndex, {}, {}, edges};
    // NOTE(review): build() already returns a prvalue, so the std::move is
    // redundant (but harmless).
    Graph curGraph(std::move(builder.build()));
    // Track the candidate graph's own input edges — dfs connects new nodes to
    // these, so they must come from curGraph, not inGraph.
    for (size_t i = 0; i < numValidTensors; ++i) {
        validTensors.emplace_back(curGraph.internal().linked().inputs()[i]);
    }
    dfs(0, inGraph, curGraph, outGraphs);
}
|
||
/// Depth-first search over operator/input-pair choices.
/// At each depth, tries every candidate operator on every ordered pair of
/// distinct valid tensors; each successful application appends a node/edge to
/// curGraph, recurses, then backtracks by undoing every state change in
/// reverse order. Matching candidates are copied into outGraphs.
void K::dfs(size_t depth, Graph const &inGraph, Graph &curGraph, std::vector<Graph> &outGraphs) noexcept {
    // Check before the depth cutoff so a graph completed at exactly maxDepth
    // is still recorded. A match terminates this branch (no deeper search).
    if (is_mutant(curGraph, inGraph)) {
        outGraphs.emplace_back(curGraph);
        return;
    }
    if (depth >= maxDepth) {
        return;
    }
    //auto g_ = curGraph.internal().linked();
    for (size_t index = 0; index < opStorage[depth].size(); ++index) {
        auto op = opStorage[depth][index];
        // Only binary operators are currently supported by this search.
        if (op->numInputs == 2) {
            // Ordered pairs: (i, j) and (j, i) are both tried, since operators
            // like MatMul are not commutative.
            for (size_t i = 0; i < numValidTensors; ++i) {
                for (size_t j = 0; j < numValidTensors; ++j) {
                    if (i == j) {
                        continue;
                    }
                    auto x = validTensors[i]->info().tensor;
                    auto y = validTensors[j]->info().tensor;
                    //fmt::println("{},{}, {}, {}", i, j, reinterpret_cast<void *>(x.get()), reinterpret_cast<void *>(y.get()));
                    // Scratch output buffer; {1024, 1024} is an upper-bound
                    // shape — compute() is expected to set the real shape.
                    // TODO confirm: should use MAX_SIZE instead of hard-coding.
                    auto out = Tensor::share(x->dataType, {1024, 1024}, LayoutType::Others);
                    out->malloc();
                    // Skip if the op rejects these inputs, or this exact
                    // (op, i, j) combination was already tried on this path.
                    if (!op->compute(*x, *y, *out) || have_same_op(op, i, j)) {
                        out->free();
                        continue;
                    }
                    // ---- apply: extend graph + search state ----
                    numValidTensors++;
                    opFinger.push_back(op);
                    auto name = fmt::format("{}", depth);
                    auto newEdge = curGraph.internal().linked().shareEdge({out, "tensor_" + name});
                    auto newNode = curGraph.internal().linked().pushNode({op->clone(), "op_" + name},
                                                                        {newEdge});
                    newNode->connect(0, validTensors[i]);
                    newNode->connect(1, validTensors[j]);
                    validTensors.push_back(newEdge);
                    //fmt::println("{}", curGraph.internal().linked().toString([](Node const &o) -> std::string { return std::string(o.op->name()); }));
                    //fmt::println("{}", reinterpret_cast<void *>(validTensors[j]->info().tensor.get()));
                    dfs(depth + 1, inGraph, curGraph, outGraphs);
                    // ---- backtrack: undo in reverse order of the apply ----
                    curGraph.internal().linked().eraseNode(newNode);
                    //curGraph.internal().linked();
                    validTensors.pop_back();
                    opFinger.pop_back();
                    delete_hash_op(op, i, j);
                    numValidTensors--;
                }
            }
        }
    }
}
|
||
/// Returns true when, for every reference output, some intermediate tensor of
/// the candidate graph (i.e. not one of the shared inputs) approximately
/// matches it; on success the matching edges become curGraph's outputs.
bool K::is_mutant(Graph &curGraph, Graph const &inGraph) noexcept {
    fmt::println("=======================output graph =================");
    fmt::println("{}", curGraph.internal().linked().toString([](Node const &o) -> std::string { return std::string(o.op->name()); }));
    fmt::println("Edges info :");
    for (size_t i = 0; i < numValidTensors; ++i) {
        fmt::println("{}. \"{}\" Shape is {}", i, validTensors[i]->info().name,
                     vec2str(validTensors[i]->info().tensor->shape));
    }
    auto inputs = inGraph.internal().linked().inputs();
    auto outputs = inGraph.internal().linked().outputs();
    std::vector<refactor::Rc<refactor::graph_topo::LinkedGraph<Node, Edge>::Edge>> outEdges;
    for (auto output : outputs) {
        // FIX: original used `int found = -1` and assigned a size_t index into
        // it (signed/unsigned narrowing); use a size_t sentinel instead.
        auto found = validTensors.size();
        auto &tensor = *output->info().tensor;
        // Start past the shared inputs: an input that happens to equal the
        // output must not count as a mutant's result.
        for (size_t i = inputs.size(); i < validTensors.size(); ++i) {
            if (approx_equal(tensor, *(validTensors[i]->info().tensor))) {
                found = i;
                break;
            }
        }
        if (found == validTensors.size()) {
            fmt::println("!!!!!!!compare false ");
            return false;
        }
        outEdges.emplace_back(validTensors[found]);
    }
    curGraph.internal().linked().setOutputs(outEdges);
    fmt::println("=======================compare true =================");
    return true;
}
|
||
bool K::approx_equal(const Tensor &a, const Tensor &b) const noexcept { | ||
if (a.shape != b.shape) { | ||
return false; | ||
} | ||
size_t equal = 0, total = 0; | ||
auto dataA = a.data->get<float>(); | ||
auto dataB = b.data->get<float>(); | ||
for (size_t i = 0; i < a.elementsSize(); ++i) { | ||
if (dataA[i] == dataB[i]) { | ||
equal++; | ||
} | ||
total++; | ||
} | ||
if (float(equal) / total >= equalThreshold) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
/// Records the (operator type, lhs index, rhs index) combination for the
/// current DFS path; returns true if it was already recorded (duplicate).
bool K::have_same_op(Arc<MyOperator> const &op, size_t a, size_t b) noexcept {
    //fmt::println("{}", reinterpret_cast<void *>(op->base.get()));
    // Fingerprint the combination; hash order matters, so (a, b) != (b, a).
    std::vector<size_t> hashInfo = {op->base->opTypeId(), a, b};
    auto res = hashVector(hashInfo);
    // FIX: single insert instead of find-then-insert (one lookup, not two);
    // also drops the pointless std::move on a trivially-copyable uint64_t.
    // insert().second is false exactly when the key was already present.
    return !opHashMaps.insert(res).second;
}
|
||
/// Removes the record added by have_same_op for this combination so the same
/// choice can be revisited on a different DFS branch.
void K::delete_hash_op(Arc<MyOperator> const &op, size_t a, size_t b) noexcept {
    std::vector<size_t> hashInfo = {op->base->opTypeId(), a, b};
    auto res = hashVector(hashInfo);
    // FIX: original declared a dead `auto it = opHashMaps.find(res);` that was
    // immediately shadowed by the if-init declaration. erase(key) does the
    // find-and-erase in one call and is a no-op when the key is absent.
    opHashMaps.erase(res);
}
}// namespace refactor::computation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "computation/graph.h" | ||
#include "computation/mutant_generator.h" | ||
#include "computation/operators/concat.h" | ||
#include "computation/operators/mat_mul.h" | ||
#include <gtest/gtest.h> | ||
#include <numeric> | ||
|
||
namespace refactor::computation { | ||
|
||
/// Builds the reference graph for the mutant-generator test:
///   matmul_1(t1[4x5], t0[5x6]) -> t3[4x6]
///   matmul_2(t1[4x5], t2[5x7]) -> t4[4x7]
///   concat(t3, t4, axis=1)     -> t5[4x13]
/// Inputs t0..t2 are filled with 1,2,3,... and t5 holds the hand-computed
/// expected output that the generator must reproduce.
refactor::graph_topo::Builder<size_t, Node, size_t, Edge> TestInGraphBuild() {
    auto nodes = std::unordered_map<size_t, Node>{};
    nodes[0] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul_1"};
    nodes[1] = Node{std::make_unique<MatMul>(1.0, 1.0, false, false), "matmul_2"};
    nodes[2] = Node{std::make_unique<Concat>(1, 2), "concat"};

    auto tensor0 = Tensor::share(DataType::F32, {5, 6}, LayoutType::Others);
    auto tensor1 = Tensor::share(DataType::F32, {4, 5}, LayoutType::Others);
    auto tensor2 = Tensor::share(DataType::F32, {5, 7}, LayoutType::Others);
    auto tensor3 = Tensor::share(DataType::F32, {4, 6}, LayoutType::Others);
    auto tensor4 = Tensor::share(DataType::F32, {4, 7}, LayoutType::Others);
    auto tensor5 = Tensor::share(DataType::F32, {4, 13}, LayoutType::Others);
    // initialize inputs data: consecutive values 1.0, 2.0, ... per tensor
    auto data0 = reinterpret_cast<float *>(tensor0->malloc());
    auto data1 = reinterpret_cast<float *>(tensor1->malloc());
    auto data2 = reinterpret_cast<float *>(tensor2->malloc());
    std::iota(data0, data0 + tensor0->elementsSize(), 1.0);
    std::iota(data1, data1 + tensor1->elementsSize(), 1.0);
    std::iota(data2, data2 + tensor2->elementsSize(), 1.0);
    // initialize outputs data: precomputed concat(t1*t0, t1*t2) for the iota inputs
    float outputData[]{255.0, 270.0, 285.0, 300.0, 315.0, 330.0, 295.0, 310.0, 325.0, 340.0, 355.0, 370.0, 385.0, 580.0, 620.0, 660.0, 700.0,
                       740.0, 780.0, 670.0, 710.0, 750.0, 790.0, 830.0, 870.0, 910.0, 905.0, 970.0, 1035.0, 1100.0, 1165.0, 1230.0, 1045.0, 1110.0, 1175.0, 1240.0,
                       1305.0, 1370.0, 1435.0, 1230.0, 1320.0, 1410.0, 1500.0, 1590.0, 1680.0, 1420.0, 1510.0, 1600.0, 1690.0, 1780.0, 1870.0, 1960.};
    std::memcpy(tensor5->malloc(), outputData, tensor5->bytesSize());

    // Builder pieces: {node -> (input edges, output edges)}, global inputs,
    // global outputs, node payloads, edge payloads (tensor + name).
    return {
        {{0, {{1, 0}, {3}}},
         {1, {{1, 2}, {4}}},
         {2, {{3, 4}, {5}}}},
        {0, 1, 2},
        {5},
        std::move(nodes),
        {
            {0, {tensor0, "input_tensor_0"}},
            {1, {tensor1, "input_tensor_1"}},
            {2, {tensor2, "input_tensor_2"}},
            {3, {tensor3, "matmul0_output"}},
            {4, {tensor4, "matmul1_output"}},
            {5, {tensor5, "output"}},
        },
    };
}
|
||
/// End-to-end check: the generator, given MatMul and Concat as candidates and
/// a depth limit of 3, should rediscover graphs equivalent to the reference.
TEST(Graph, MutantGenerator) {
    auto graphTopo = TestInGraphBuild().build();
    fmt::println("{}", graphTopo.topology.toString());
    Graph g(std::move(graphTopo));
    // create mutant generator
    MutantGenerator mutant;
    OpVec oplist = {std::make_shared<MatMulBox>(), std::make_shared<ConcatBox>()};
    // threshold 1.0: every output element must match exactly
    mutant.init(1.0, 3, oplist);
    std::vector<Graph> outGraph = {};
    // FIX: run() takes `Graph const &`, so the original `std::move(g)` moved
    // nothing and only suggested a transfer of ownership that never happened.
    mutant.run(g, outGraph);
    for (size_t i = 0; i < outGraph.size(); ++i) {
        fmt::println("{}", outGraph[i].internal().linked().toString());
    }
}
}// namespace refactor::computation |