diff --git a/caffe2/contrib/transform/CMakeLists.txt b/caffe2/contrib/transform/CMakeLists.txt index 61621f89566..31081f3aab2 100644 --- a/caffe2/contrib/transform/CMakeLists.txt +++ b/caffe2/contrib/transform/CMakeLists.txt @@ -1,6 +1,7 @@ if(USE_TRANSFORMS) message(STATUS "Include Graph Transformations") set(Caffe2_CONTRIB_TRANSFORMS_CPU_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/transform.cc" "${CMAKE_CURRENT_SOURCE_DIR}/graph.cc" ) diff --git a/caffe2/contrib/transform/graph.cc b/caffe2/contrib/transform/graph.cc index 42e5e1bbc4f..7ee77535ae3 100644 --- a/caffe2/contrib/transform/graph.cc +++ b/caffe2/contrib/transform/graph.cc @@ -172,6 +172,22 @@ NetDef Graph::GetNetDef() { return netdef; } +void Graph::DeactivateSubgraph(std::vector<int> subgraph) { + for (int idx : subgraph) { + // Unlink idx from each neighbor's edge map (idx's own maps are left as-is). + for (const auto& edge : node(idx).parents) { + int parent = edge.first; + node(parent).children.erase(idx); + } + for (const auto& edge : node(idx).children) { + int child = edge.first; + node(child).parents.erase(idx); + } + // Finally, flag the node itself as inactive. + node(idx).active = false; + } +} + } // namespace transform } // namespace caffe2 diff --git a/caffe2/contrib/transform/graph.h b/caffe2/contrib/transform/graph.h index 01e48ec3477..75c0ed713fe 100644 --- a/caffe2/contrib/transform/graph.h +++ b/caffe2/contrib/transform/graph.h @@ -17,6 +17,17 @@ namespace transform { */ struct Node { public: + // Empty constructor, needed so containers of Node can be resized. + Node() {} + + // Constructs a node from its op, active flag, and parent/child edge maps. + Node( + const OperatorDef& op, + bool active, + std::map<int, string> parents, + std::map<int, string> children) + : op(op), active(active), parents(parents), children(children) {} + // The OperatorDef which this node represents. OperatorDef op; @@ -24,8 +35,8 @@ struct Node { bool active = true; // Stores a pair (idx, blob), the index of the child, and the blob of edge. 
- std::map<int, string> children; std::map<int, string> parents; + std::map<int, string> children; }; /** @@ -83,6 +94,11 @@ struct Graph { */ NetDef GetNetDef(); + /** + * Deactivate a subgraph, and get rid of all edges into this subgraph. + */ + void DeactivateSubgraph(std::vector<int> subgraph); + const size_t size() const { return nodes_.size(); } diff --git a/caffe2/contrib/transform/graph_test.cc b/caffe2/contrib/transform/graph_test.cc index c2dd8c8aecb..67d937b0ebd 100644 --- a/caffe2/contrib/transform/graph_test.cc +++ b/caffe2/contrib/transform/graph_test.cc @@ -22,7 +22,6 @@ class DummyOp final : public OperatorBase { }; REGISTER_CPU_OPERATOR(DummyOp1, DummyOp); -REGISTER_CUDA_OPERATOR(DummyOp1, DummyOp); OPERATOR_SCHEMA(DummyOp1) .NumInputs(0, INT_MAX) @@ -30,7 +29,6 @@ OPERATOR_SCHEMA(DummyOp1) .AllowInplace({{0, 0}, {1, 1}}); REGISTER_CPU_OPERATOR(DummyOp2, DummyOp); -REGISTER_CUDA_OPERATOR(DummyOp2, DummyOp); OPERATOR_SCHEMA(DummyOp2) .NumInputs(0, INT_MAX) @@ -38,7 +36,6 @@ OPERATOR_SCHEMA(DummyOp2) .AllowInplace({{0, 0}, {1, 1}}); REGISTER_CPU_OPERATOR(DummyOp3, DummyOp); -REGISTER_CUDA_OPERATOR(DummyOp3, DummyOp); OPERATOR_SCHEMA(DummyOp3) .NumInputs(0, INT_MAX) diff --git a/caffe2/contrib/transform/transform.cc b/caffe2/contrib/transform/transform.cc new file mode 100644 index 00000000000..606803596f7 --- /dev/null +++ b/caffe2/contrib/transform/transform.cc @@ -0,0 +1,103 @@ +#include "caffe2/contrib/transform/transform.h" + +#include "caffe2/core/common.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/net.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +using transform::Graph; + +CAFFE_DEFINE_REGISTRY(TransformRegistry, Transform); + +std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) { + std::vector<std::vector<int>> matches; + + // Consider every possible node as the starting point. + for (int idx = 0; idx < static_cast<int>(graph.size()); ++idx) { + // The current working subgraph. We will try to add new nodes to this, + // when invoking the PatternRule. 
+ std::vector<int> subgraph; + + // The largest "validated" subgraph found so far. + // This will be mutated by PatternMatchHelper. + std::vector<int> best_subgraph; + + // Only begin to match if the start node is accepted. + if (PatternRule(graph, subgraph, idx)) { + subgraph.push_back(idx); + PatternMatchHelper(graph, &subgraph, &best_subgraph); + subgraph.pop_back(); + } + if (!best_subgraph.empty()) { // match found + matches.push_back(best_subgraph); + } + } + return matches; +} + +void Transform::TryNeighbors( + const Graph& graph, + const std::map<int, string>& neighbors, + std::vector<int>* subgraph_ptr, + std::vector<int>* best_subgraph_ptr) { + auto& subgraph = *subgraph_ptr; + for (const auto& edge : neighbors) { + int j = edge.first; + if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) { + if (PatternRule(graph, subgraph, j)) { + subgraph.push_back(j); + PatternMatchHelper(graph, subgraph_ptr, best_subgraph_ptr); + subgraph.pop_back(); + } + } + } +} + +void Transform::PatternMatchHelper( + const Graph& graph, + std::vector<int>* subgraph_ptr, + std::vector<int>* best_subgraph_ptr) { + CHECK(subgraph_ptr); + auto& subgraph = *subgraph_ptr; + CHECK(best_subgraph_ptr); + auto& best_subgraph = *best_subgraph_ptr; + + // If the current subgraph is valid, and the largest we've seen so far, + // make it the best_subgraph. + if (ValidatorRule(graph, subgraph) && + subgraph.size() > best_subgraph.size()) { + best_subgraph = subgraph; + } + + // Try adding each parent and child of every node in the subgraph. Iterate by + // index: TryNeighbors may push_back() into subgraph and reallocate it. + for (size_t k = 0; k < subgraph.size(); ++k) { + const auto& cur = graph.node(subgraph[k]); + TryNeighbors(graph, cur.children, subgraph_ptr, best_subgraph_ptr); + TryNeighbors(graph, cur.parents, subgraph_ptr, best_subgraph_ptr); + } +} + +void Transform::ReplacePattern( + const std::vector<std::vector<int>>& matches, + Graph* graph) { + // Simply try to apply the replace rule upon every match. 
+ for (const auto& match : matches) { + if (!ReplaceRule(match, graph)) { + CAFFE_THROW("Replace failed!"); + } + } +} + +// The simple interface - performs the transformation upon a NetDef, and returns +// the result. +NetDef Transform::ApplyTo(const NetDef& orig_net) { + Graph g(orig_net); + const auto matches = PatternMatch(g); + ReplacePattern(matches, &g); + return g.GetNetDef(); +} + +} // namespace caffe2 diff --git a/caffe2/contrib/transform/transform.h b/caffe2/contrib/transform/transform.h new file mode 100644 index 00000000000..b59e9714893 --- /dev/null +++ b/caffe2/contrib/transform/transform.h @@ -0,0 +1,113 @@ +#pragma once + +#include "caffe2/contrib/transform/graph.h" +#include "caffe2/core/common.h" +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/utils/proto_utils.h" + +namespace caffe2 { + +/** + * The Transform Base Object + * + * A Transform is an operation which manipulates a Caffe2 NetDef. + * You can consider it as a function: Transform.ApplyTo(NetDef) -> NetDef + * + * A Transform Operation does 4 things: + * 1) Creates a Graph object from a NetDef, which stores connections. + * 2) Pattern Matches on the Graph, to find subgraphs it wants to change. + * 3) Replaces the subgraphs that it's matched with new operators. + * 4) Creates a NetDef from the changed Graph, and returns it. + * + * The effect of a Transform is defined by its 3 protected virtual functions. + * 1) PatternRule determines for an ordered subgraph and a node, whether to + * consider adding the node to the subgraph. + * 2) ValidatorRule determines, for an ordered subgraph, whether it is a + * match. + * 3) ReplaceRule mutates the graph, based on a matched subgraph. + * + * This is the base class for all derived classes to base off. To create your + * own transform, write your implementations for PatternRule, ValidatorRule, and + * ReplaceRule. + */ +class Transform { + public: + Transform() {} + + /** + * Apply a Transform onto a NetDef. 
+ * Returns the transformed NetDef. + */ + NetDef ApplyTo(const NetDef& orig_net_def); + + virtual ~Transform() {} + + /** + * Generates all matches (stored as ordered subgraphs) and returns them. + * + * A match is stored as vector<int>, which is a mapping to OperatorDefs + * in Graph. The order matters. + */ + std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph); + + /** + * Applies the replace rule onto each of the matches found. + */ + void ReplacePattern( + const std::vector<std::vector<int>>& matches, + transform::Graph* graph); + + protected: + /** + * The PatternRule essentially answers: + * Given the current subgraph (ordered), should we append the new node at idx? + */ + virtual bool PatternRule( + const transform::Graph& g, + const std::vector<int>& subgraph, + int idx) { + CAFFE_NOT_IMPLEMENTED; + } + + /** + * The ValidatorRule essentially answers: + * Given a subgraph, can we accept it? + */ + virtual bool ValidatorRule( + const transform::Graph& g, + const std::vector<int>& subgraph) { + CAFFE_NOT_IMPLEMENTED; + } + + /** + * The ReplaceRule actually mutates the graph, and applies the transformation + * upon the subgraph. + */ + virtual bool ReplaceRule( + const std::vector<int>& subgraph, + transform::Graph* g_ptr) { + CAFFE_NOT_IMPLEMENTED; + } + + private: + /** + * A helper function for PatternMatch, which keeps track of the best subgraph + * so far. + */ + void PatternMatchHelper( + const transform::Graph& graph, + std::vector<int>* subgraph_ptr, + std::vector<int>* best_subgraph_ptr); + /** + * Attempts to append each neighbor to the end of the subgraph. 
+ */ + void TryNeighbors( + const transform::Graph& graph, + const std::map<int, string>& neighbors, + std::vector<int>* subgraph_ptr, + std::vector<int>* best_subgraph_ptr); +}; + +CAFFE_DECLARE_REGISTRY(TransformRegistry, Transform); + +} // namespace caffe2 diff --git a/caffe2/contrib/transform/transform_test.cc b/caffe2/contrib/transform/transform_test.cc new file mode 100644 index 00000000000..99b1cf5715b --- /dev/null +++ b/caffe2/contrib/transform/transform_test.cc @@ -0,0 +1,232 @@ +#include <atomic> +#include <gtest/gtest.h> +#include "caffe2/contrib/transform/transform.h" +#include "caffe2/core/net.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +namespace { + +using transform::Graph; + +static std::atomic<int> counter; + +class DummyOp final : public OperatorBase { + public: + using OperatorBase::OperatorBase; + bool Run(int /* unused */) override { + counter.fetch_add(1); + return true; + } +}; + +REGISTER_CPU_OPERATOR(TDummyOp1, DummyOp); + +OPERATOR_SCHEMA(TDummyOp1) + .NumInputs(0, INT_MAX) + .NumOutputs(0, INT_MAX) + .AllowInplace({{0, 0}, {1, 1}}); + +REGISTER_CPU_OPERATOR(TDummyOp2, DummyOp); + +OPERATOR_SCHEMA(TDummyOp2) + .NumInputs(0, INT_MAX) + .NumOutputs(0, INT_MAX) + .AllowInplace({{0, 0}, {1, 1}}); + +REGISTER_CPU_OPERATOR(TDummyOp3, DummyOp); + +OPERATOR_SCHEMA(TDummyOp3) + .NumInputs(0, INT_MAX) + .NumOutputs(0, INT_MAX) + .AllowInplace({{0, 0}, {1, 1}}); + +/** + * This dummy transform will find all subgraphs of shape (TDummyOp1 -> + * TDummyOp2) and replaces them with (TDummyOp3). Simple unit test. + */ +class DummyTransform : public Transform { + public: + // Finds all patterns of the form (TDummyOp1 -> TDummyOp2) + bool PatternRule(const Graph& g, const std::vector<int>& subgraph, int idx) + override { + if (subgraph.size() >= pattern_chain.size()) { + return false; + } + // which index are we trying to append the new node to? 
+ int pattern_idx = subgraph.size(); + // type doesn't match + if (g.node(idx).op.type() != pattern_chain[pattern_idx]) { + return false; + } + // not the head, and doesn't have exactly 1 parent + if (pattern_idx > 0 && g.node(idx).parents.size() != 1) { + return false; + } + // not the tail, and doesn't have exactly 1 child + if (pattern_idx < pattern_chain.size() - 1 && + g.node(idx).children.size() != 1) { + return false; + } + + return true; + } + + // Checks if the subgraph matched is (TDummyOp1 -> TDummyOp2) + bool ValidatorRule(const Graph& g, const std::vector<int>& subgraph) + override { + if (subgraph.size() == 2) { + if (g.node(subgraph[0]).op.type() == "TDummyOp1" && + g.node(subgraph[1]).op.type() == "TDummyOp2") { + return true; + } + } + return false; + } + + // Replaces a match of (TDummyOp1 -> TDummyOp2) with (TDummyOp3) + bool ReplaceRule(const std::vector<int>& match, Graph* g_ptr) override { + CHECK(g_ptr); + auto& g = *g_ptr; + OperatorDef new_op; + new_op.set_type("TDummyOp3"); + int new_idx = g.size(); + + std::map<int, string> new_op_children; + std::map<int, string> new_op_parents; + + // for each node parent in the head of the match, connect it to our new node + for (const auto& edge : g.node(match[0]).parents) { + int parent = edge.first; + string blob = edge.second; + g.node(parent).children[new_idx] = blob; + new_op_parents[parent] = blob; + } + for (const string& blob : g.node(match[0]).op.input()) { + new_op.add_input(blob); + } + + // for each child in the tail of the match, connect it to our new node + for (const auto& edge : g.node(match[1]).children) { + int child = edge.first; + string blob = edge.second; + g.node(child).parents[new_idx] = blob; + new_op_children[child] = blob; + } + for (const string& blob : g.node(match[1]).op.output()) { + new_op.add_output(blob); + } + + g.DeactivateSubgraph(match); + + g.push_node(transform::Node(new_op, true, new_op_parents, new_op_children)); + return true; + } + + private: + const std::vector<string> pattern_chain = {"TDummyOp1", 
"TDummyOp2"}; +}; + +// Adds an operator def to a netdef. +// Returns the ptr, if you want to add anything extra (such as device_option) +OperatorDef* AddOp( + NetDef* netdef_ptr, + string op_type, + std::vector<string> inputs, + std::vector<string> outputs) { + CHECK(netdef_ptr); + auto& netdef = *netdef_ptr; + auto op_ptr = netdef.add_op(); + auto& op = *op_ptr; + op.set_type(op_type); + for (const string& inp : inputs) { + op.add_input(inp); + } + for (const string& outp : outputs) { + op.add_output(outp); + } + return op_ptr; +} + +TEST(TransformTest, TestPatternMatch) { + Workspace ws; + ws.CreateBlob("in"); + NetDef netdef; + + AddOp(&netdef, "TDummyOp1", {"in"}, {"mid1"}); + AddOp(&netdef, "TDummyOp2", {"mid1"}, {"mid2"}); + AddOp(&netdef, "TDummyOp1", {"mid2"}, {"mid3"}); + AddOp(&netdef, "TDummyOp2", {"mid3"}, {"out"}); + + DummyTransform t; + Graph g(netdef); + auto matches = t.PatternMatch(g); + + EXPECT_EQ(matches.size(), 2); + EXPECT_EQ(matches[0][0], 0); + EXPECT_EQ(matches[0][1], 1); + EXPECT_EQ(matches[1][0], 2); + EXPECT_EQ(matches[1][1], 3); +} + +TEST(TransformTest, TestReplacePattern) { + Workspace ws; + ws.CreateBlob("in"); + NetDef netdef; + + AddOp(&netdef, "TDummyOp1", {"in"}, {"mid1"}); + AddOp(&netdef, "TDummyOp2", {"mid1"}, {"mid2"}); + AddOp(&netdef, "TDummyOp1", {"mid2"}, {"mid3"}); + AddOp(&netdef, "TDummyOp2", {"mid3"}, {"out"}); + + DummyTransform t; + Graph g(netdef); + std::vector<std::vector<int>> matches = {{0, 1}, {2, 3}}; + t.ReplacePattern(matches, &g); + + EXPECT_EQ(g.size(), 6); + EXPECT_FALSE(g.is_node_active(0)); + EXPECT_FALSE(g.is_node_active(1)); + EXPECT_FALSE(g.is_node_active(2)); + EXPECT_FALSE(g.is_node_active(3)); + EXPECT_TRUE(g.is_node_active(4)); + EXPECT_TRUE(g.is_node_active(5)); + + EXPECT_EQ(g.node(4).children.size(), 1); + EXPECT_EQ(g.node(4).parents.size(), 0); + EXPECT_TRUE(g.node(4).children.count(5)); + + NetDef replaced_netdef = g.GetNetDef(); + + EXPECT_EQ(replaced_netdef.op().size(), 2); + EXPECT_EQ(replaced_netdef.op(0).type(), 
"TDummyOp3"); + EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); + EXPECT_EQ(replaced_netdef.op(1).type(), "TDummyOp3"); + EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); +} + +TEST(TransformTest, TestTransformApply) { + Workspace ws; + ws.CreateBlob("in"); + NetDef netdef; + + AddOp(&netdef, "TDummyOp1", {"in"}, {"mid1"}); + AddOp(&netdef, "TDummyOp2", {"mid1"}, {"mid2"}); + AddOp(&netdef, "TDummyOp1", {"mid2"}, {"mid3"}); + AddOp(&netdef, "TDummyOp2", {"mid3"}, {"out"}); + + DummyTransform t; + + NetDef replaced_netdef = t.ApplyTo(netdef); + + EXPECT_EQ(replaced_netdef.op().size(), 2); + EXPECT_EQ(replaced_netdef.op(0).type(), "TDummyOp3"); + EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); + EXPECT_EQ(replaced_netdef.op(1).type(), "TDummyOp3"); + EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); +} + +} // namespace + +} // namespace caffe2