This repository has been archived by the owner on Feb 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: The Implementation of Graph Transformations, with the PatternMatch and ReplaceMatch rules. Reviewed By: akyrola Differential Revision: D5404144 fbshipit-source-id: 2bab68e6bff2e841ea9fb64df5d92ea945e704af
- Loading branch information
1 parent
f5bbac6
commit 7df0d66
Showing
7 changed files
with
482 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
#include "caffe2/contrib/transform/transform.h" | ||
|
||
#include "caffe2/core/common.h" | ||
#include "caffe2/core/logging.h" | ||
#include "caffe2/core/net.h" | ||
#include "caffe2/proto/caffe2.pb.h" | ||
|
||
namespace caffe2 { | ||
|
||
using transform::Graph; | ||
|
||
CAFFE_DEFINE_REGISTRY(TransformRegistry, Transform); | ||
|
||
std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) { | ||
std::vector<std::vector<int>> matches; | ||
|
||
// Consider every possible node as the starting point. | ||
for (int idx = 0; idx < graph.size(); ++idx) { | ||
// The current working subgraph. We will try to add new nodes to this, | ||
// when invoking the PatternRule. | ||
std::vector<int> subgraph; | ||
|
||
// The largest "validated" subgraph found so far. | ||
// This will be mutated by PatternMatchHelper. | ||
std::vector<int> best_subgraph; | ||
|
||
// Only begin to match if the start node is accepted. | ||
if (PatternRule(graph, subgraph, idx)) { | ||
subgraph.push_back(idx); | ||
PatternMatchHelper(graph, &subgraph, &best_subgraph); | ||
subgraph.pop_back(); | ||
} | ||
if (best_subgraph.size() > 0) { // match found | ||
matches.push_back(best_subgraph); | ||
} | ||
} | ||
return matches; | ||
} | ||
|
||
void Transform::TryNeighbors( | ||
const Graph& graph, | ||
const std::map<int, string>& neighbors, | ||
std::vector<int>* subgraph_ptr, | ||
std::vector<int>* best_subgraph_ptr) { | ||
auto& subgraph = *subgraph_ptr; | ||
for (const auto& edge : neighbors) { | ||
int j = edge.first; | ||
if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) { | ||
if (PatternRule(graph, subgraph, j)) { | ||
subgraph.push_back(j); | ||
PatternMatchHelper(graph, subgraph_ptr, best_subgraph_ptr); | ||
subgraph.pop_back(); | ||
} | ||
} | ||
} | ||
} | ||
|
||
void Transform::PatternMatchHelper( | ||
const Graph& graph, | ||
std::vector<int>* subgraph_ptr, | ||
std::vector<int>* best_subgraph_ptr) { | ||
CHECK(subgraph_ptr); | ||
auto& subgraph = *subgraph_ptr; | ||
CHECK(best_subgraph_ptr); | ||
auto& best_subgraph = *best_subgraph_ptr; | ||
|
||
// If the current subgraph is valid, and the largest we've seen so far, | ||
// make it the best_subgraph. | ||
if (ValidatorRule(graph, subgraph) && | ||
subgraph.size() > best_subgraph.size()) { | ||
best_subgraph = subgraph; | ||
} | ||
|
||
// Try adding each parent and child of every node in the subgraph, | ||
// and see if we can accept it. | ||
for (int i : subgraph) { | ||
TryNeighbors( | ||
graph, graph.node(i).children, subgraph_ptr, best_subgraph_ptr); | ||
TryNeighbors(graph, graph.node(i).parents, subgraph_ptr, best_subgraph_ptr); | ||
} | ||
} | ||
|
||
void Transform::ReplacePattern( | ||
const std::vector<vector<int>>& matches, | ||
Graph* graph) { | ||
// Simply try to apply the replace rule upon every match. | ||
for (const auto& match : matches) { | ||
if (!ReplaceRule(match, graph)) { | ||
CAFFE_THROW("Replace failed!"); | ||
} | ||
} | ||
} | ||
|
||
// The simple interface - performs the transformation upon a NetDef, and returns | ||
// the result. | ||
NetDef Transform::ApplyTo(const NetDef& orig_net) { | ||
Graph g(orig_net); | ||
const auto matches = PatternMatch(g); | ||
ReplacePattern(matches, &g); | ||
return g.GetNetDef(); | ||
} | ||
|
||
} // namespace Caffe2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
#pragma once | ||
|
||
#include "caffe2/contrib/transform/graph.h" | ||
#include "caffe2/core/common.h" | ||
#include "caffe2/proto/caffe2.pb.h" | ||
#include "caffe2/utils/proto_utils.h" | ||
|
||
namespace caffe2 { | ||
|
||
/** | ||
* The Transform Base Object | ||
* | ||
* A Transform is an operation which manipulates a Caffe2 NetDef. | ||
* You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef | ||
* | ||
* A Transform Operation does 4 things: | ||
* 1) Creates a Graph object from a NetDef, which stores connections. | ||
* 2) Pattern Matches on the Graph, to find subgraphs it wants to change. | ||
* 3) Replaces the subgraphs that it's matched with new operators. | ||
* 4) Creates a NetDef from the changed Graph, and returns it. | ||
* | ||
* The effect of a Transform is defined by its 3 protected virtual functions. | ||
* 1) PatternRule determines for an ordered subgraph and a node, whether to | ||
* consider adding the node to the subgraph. | ||
* 2) ValidatorRule determines, for an ordered subgraph, whether it is a | ||
* match. | ||
* 3) ReplaceRule mutates the graph, based on a matched subgraph. | ||
* | ||
* This is the base class for all derived classes to base off. To create your | ||
* own transform, write your implementations for PatternRule, ValidatorRule, and | ||
* ReplaceRule. | ||
*/ | ||
class Transform { | ||
public: | ||
Transform() {} | ||
|
||
/** | ||
* Apply a Transform onto a NetDef. | ||
* Returns the transformed NetDef. | ||
*/ | ||
NetDef ApplyTo(const NetDef& orig_net_def); | ||
|
||
virtual ~Transform() {} | ||
|
||
/** | ||
* Generates all matches (stored as ordered subgraphs) and returns them. | ||
* | ||
* A match is stored as vector<int>, which is a mapping to OperatorDefs | ||
* in Graph. The order matters. | ||
*/ | ||
std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph); | ||
|
||
/** | ||
* Applies the replace rule onto each of the matches found. | ||
*/ | ||
void ReplacePattern( | ||
const std::vector<std::vector<int>>& matches, | ||
transform::Graph* graph); | ||
|
||
protected: | ||
/** | ||
* The PatternRule essentially answers: | ||
* Given the current subgraph (ordered), should we append the new node at idx? | ||
*/ | ||
virtual bool PatternRule( | ||
const transform::Graph& g, | ||
const std::vector<int>& subgraph, | ||
int idx) { | ||
CAFFE_NOT_IMPLEMENTED; | ||
} | ||
|
||
/** | ||
* The ValidatorRule essentially answers: | ||
* Given a subgraph, can we accept it? | ||
*/ | ||
virtual bool ValidatorRule( | ||
const transform::Graph& g, | ||
const std::vector<int>& subgraph) { | ||
CAFFE_NOT_IMPLEMENTED; | ||
} | ||
|
||
/** | ||
* The ReplaceRule actually mutates the graph, and applies the transformation | ||
* upon the subgraph. | ||
*/ | ||
virtual bool ReplaceRule( | ||
const std::vector<int>& subgraph, | ||
transform::Graph* g_ptr) { | ||
CAFFE_NOT_IMPLEMENTED; | ||
} | ||
|
||
private: | ||
/** | ||
* A helper function for PatternMatch, which keeps track of the best subgraph | ||
* so far. | ||
*/ | ||
void PatternMatchHelper( | ||
const transform::Graph& graph, | ||
std::vector<int>* subgraph_ptr, | ||
std::vector<int>* best_subgraph_ptr); | ||
/** | ||
* Attempts to append each neighbor to the end of the subgraph. | ||
*/ | ||
void TryNeighbors( | ||
const transform::Graph& graph, | ||
const std::map<int, string>& neighbors, | ||
std::vector<int>* subgraph_ptr, | ||
std::vector<int>* best_subgraph_ptr); | ||
}; | ||
|
||
CAFFE_DECLARE_REGISTRY(TransformRegistry, Transform); | ||
|
||
} // namespace |
Oops, something went wrong.