Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

Commit

Permalink
Implementation for Graph Transforms
Browse files Browse the repository at this point in the history
Summary: The Implementation of Graph Transformations, with the PatternMatch and ReplaceMatch rules.

Reviewed By: akyrola

Differential Revision: D5404144

fbshipit-source-id: 2bab68e6bff2e841ea9fb64df5d92ea945e704af
  • Loading branch information
benzyx authored and facebook-github-bot committed Jul 21, 2017
1 parent f5bbac6 commit 7df0d66
Show file tree
Hide file tree
Showing 7 changed files with 482 additions and 4 deletions.
1 change: 1 addition & 0 deletions caffe2/contrib/transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
if(USE_TRANSFORMS)
message(STATUS "Include Graph Transformations")
set(Caffe2_CONTRIB_TRANSFORMS_CPU_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/transform.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/graph.cc"
)

Expand Down
16 changes: 16 additions & 0 deletions caffe2/contrib/transform/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ NetDef Graph::GetNetDef() {
return netdef;
}

void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
for (int idx : subgraph) {
// remove all edges connected to inactive node
for (const auto& edge : node(idx).parents) {
int parent = edge.first;
node(parent).children.erase(idx);
}
for (const auto& edge : node(idx).children) {
int child = edge.first;
node(child).parents.erase(idx);
}
// actually mark flags as false
node(idx).active = false;
}
}

} // namespace transform

} // namespace caffe2
18 changes: 17 additions & 1 deletion caffe2/contrib/transform/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,26 @@ namespace transform {
*/
struct Node {
public:
// Empty constructor for resize
Node() {}

// Alternate constructor
Node(
const OperatorDef& op,
bool active,
std::map<int, string> parents,
std::map<int, string> children)
: op(op), active(active), parents(parents), children(children) {}

// The OperatorDef which this node represents.
OperatorDef op;

// Keeps track of if an operator has been deleted through a transformation.
bool active = true;

// Stores a pair (idx, blob), the index of the child, and the blob of edge.
std::map<int, string> children;
std::map<int, string> parents;
std::map<int, string> children;
};

/**
Expand Down Expand Up @@ -83,6 +94,11 @@ struct Graph {
*/
NetDef GetNetDef();

/**
* Deactivate a subgraph, and get rid of all edges into this subgraph.
*/
void DeactivateSubgraph(std::vector<int> subgraph);

const size_t size() const {
return nodes_.size();
}
Expand Down
3 changes: 0 additions & 3 deletions caffe2/contrib/transform/graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,20 @@ class DummyOp final : public OperatorBase {
};

REGISTER_CPU_OPERATOR(DummyOp1, DummyOp);
REGISTER_CUDA_OPERATOR(DummyOp1, DummyOp);

OPERATOR_SCHEMA(DummyOp1)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});

REGISTER_CPU_OPERATOR(DummyOp2, DummyOp);
REGISTER_CUDA_OPERATOR(DummyOp2, DummyOp);

OPERATOR_SCHEMA(DummyOp2)
.NumInputs(0, INT_MAX)
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});

REGISTER_CPU_OPERATOR(DummyOp3, DummyOp);
REGISTER_CUDA_OPERATOR(DummyOp3, DummyOp);

OPERATOR_SCHEMA(DummyOp3)
.NumInputs(0, INT_MAX)
Expand Down
103 changes: 103 additions & 0 deletions caffe2/contrib/transform/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include "caffe2/contrib/transform/transform.h"

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2.pb.h"

namespace caffe2 {

using transform::Graph;

CAFFE_DEFINE_REGISTRY(TransformRegistry, Transform);

std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) {
std::vector<std::vector<int>> matches;

// Consider every possible node as the starting point.
for (int idx = 0; idx < graph.size(); ++idx) {
// The current working subgraph. We will try to add new nodes to this,
// when invoking the PatternRule.
std::vector<int> subgraph;

// The largest "validated" subgraph found so far.
// This will be mutated by PatternMatchHelper.
std::vector<int> best_subgraph;

// Only begin to match if the start node is accepted.
if (PatternRule(graph, subgraph, idx)) {
subgraph.push_back(idx);
PatternMatchHelper(graph, &subgraph, &best_subgraph);
subgraph.pop_back();
}
if (best_subgraph.size() > 0) { // match found
matches.push_back(best_subgraph);
}
}
return matches;
}

void Transform::TryNeighbors(
const Graph& graph,
const std::map<int, string>& neighbors,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr) {
auto& subgraph = *subgraph_ptr;
for (const auto& edge : neighbors) {
int j = edge.first;
if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) {
if (PatternRule(graph, subgraph, j)) {
subgraph.push_back(j);
PatternMatchHelper(graph, subgraph_ptr, best_subgraph_ptr);
subgraph.pop_back();
}
}
}
}

void Transform::PatternMatchHelper(
const Graph& graph,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr) {
CHECK(subgraph_ptr);
auto& subgraph = *subgraph_ptr;
CHECK(best_subgraph_ptr);
auto& best_subgraph = *best_subgraph_ptr;

// If the current subgraph is valid, and the largest we've seen so far,
// make it the best_subgraph.
if (ValidatorRule(graph, subgraph) &&
subgraph.size() > best_subgraph.size()) {
best_subgraph = subgraph;
}

// Try adding each parent and child of every node in the subgraph,
// and see if we can accept it.
for (int i : subgraph) {
TryNeighbors(
graph, graph.node(i).children, subgraph_ptr, best_subgraph_ptr);
TryNeighbors(graph, graph.node(i).parents, subgraph_ptr, best_subgraph_ptr);
}
}

void Transform::ReplacePattern(
const std::vector<vector<int>>& matches,
Graph* graph) {
// Simply try to apply the replace rule upon every match.
for (const auto& match : matches) {
if (!ReplaceRule(match, graph)) {
CAFFE_THROW("Replace failed!");
}
}
}

// The simple interface - performs the transformation upon a NetDef, and returns
// the result.
NetDef Transform::ApplyTo(const NetDef& orig_net) {
Graph g(orig_net);
const auto matches = PatternMatch(g);
ReplacePattern(matches, &g);
return g.GetNetDef();
}

} // namespace Caffe2
113 changes: 113 additions & 0 deletions caffe2/contrib/transform/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once

#include "caffe2/contrib/transform/graph.h"
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

/**
* The Transform Base Object
*
* A Transform is an operation which manipulates a Caffe2 NetDef.
* You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef
*
* A Transform Operation does 4 things:
* 1) Creates a Graph object from a NetDef, which stores connections.
* 2) Pattern Matches on the Graph, to find subgraphs it wants to change.
* 3) Replaces the subgraphs that it's matched with new operators.
* 4) Creates a NetDef from the changed Graph, and returns it.
*
* The effect of a Transform is defined by its 3 protected virtual functions.
* 1) PatternRule determines for an ordered subgraph and a node, whether to
* consider adding the node to the subgraph.
* 2) ValidatorRule determines, for an ordered subgraph, whether it is a
* match.
* 3) ReplaceRule mutates the graph, based on a matched subgraph.
*
* This is the base class for all derived classes to base off. To create your
* own transform, write your implementations for PatternRule, ValidatorRule, and
* ReplaceRule.
*/
class Transform {
public:
Transform() {}

/**
* Apply a Transform onto a NetDef.
* Returns the transformed NetDef.
*/
NetDef ApplyTo(const NetDef& orig_net_def);

virtual ~Transform() {}

/**
* Generates all matches (stored as ordered subgraphs) and returns them.
*
* A match is stored as vector<int>, which is a mapping to OperatorDefs
* in Graph. The order matters.
*/
std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);

/**
* Applies the replace rule onto each of the matches found.
*/
void ReplacePattern(
const std::vector<std::vector<int>>& matches,
transform::Graph* graph);

protected:
/**
* The PatternRule essentially answers:
* Given the current subgraph (ordered), should we append the new node at idx?
*/
virtual bool PatternRule(
const transform::Graph& g,
const std::vector<int>& subgraph,
int idx) {
CAFFE_NOT_IMPLEMENTED;
}

/**
* The ValidatorRule essentially answers:
* Given a subgraph, can we accept it?
*/
virtual bool ValidatorRule(
const transform::Graph& g,
const std::vector<int>& subgraph) {
CAFFE_NOT_IMPLEMENTED;
}

/**
* The ReplaceRule actually mutates the graph, and applies the transformation
* upon the subgraph.
*/
virtual bool ReplaceRule(
const std::vector<int>& subgraph,
transform::Graph* g_ptr) {
CAFFE_NOT_IMPLEMENTED;
}

private:
/**
* A helper function for PatternMatch, which keeps track of the best subgraph
* so far.
*/
void PatternMatchHelper(
const transform::Graph& graph,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
/**
* Attempts to append each neighbor to the end of the subgraph.
*/
void TryNeighbors(
const transform::Graph& graph,
const std::map<int, string>& neighbors,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
};

CAFFE_DECLARE_REGISTRY(TransformRegistry, Transform);

} // namespace
Loading

0 comments on commit 7df0d66

Please sign in to comment.