Skip to content

Commit

Permalink
Implement NetDef <--> JIT IR converters. (pytorch#16967)
Browse files Browse the repository at this point in the history
Summary:
Currently the converters are very straightforward, i.e. there is no code for trying to
preserve semantics, we're purely perform conversion from one format to another.

Two things that we might want to add/change:
1. Add semantic conversion as well (but probably it would be a good idea to keep
it separate as a temporary thing).
2. Make sure we don't mess with value names, as they are crucial for current
uses of NetDefs.
Pull Request resolved: pytorch#16967

Differential Revision: D14062537

Pulled By: ZolotukhinM

fbshipit-source-id: 88b184ee7276779e5e9152b149d69857515ad98a
  • Loading branch information
Mikhail Zolotukhin authored and facebook-github-bot committed Feb 14, 2019
1 parent decc089 commit d25fee3
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 0 deletions.
3 changes: 3 additions & 0 deletions test/cpp/jit/gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>

using namespace torch;
using namespace torch::jit;
Expand Down Expand Up @@ -34,6 +35,8 @@ JIT_TEST(SubgraphUtils)
JIT_TEST(AliasAnalysis)
JIT_TEST(AliasTracker)

JIT_TEST(NetDefConverter)

JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)

Expand Down
2 changes: 2 additions & 0 deletions test/cpp/jit/no-gtest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <test/cpp/jit/test_alias_analysis.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>

#include <sstream>
#include <string>
Expand Down Expand Up @@ -37,6 +38,7 @@ std::string runJITCPPTests() {
testRegisterFusionCachesKernel();
testAliasAnalysis();
testAliasTracker();
testNetDefConverter(out);
return out.str();
}

Expand Down
146 changes: 146 additions & 0 deletions test/cpp/jit/test_netdef_converter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#pragma once

#include <torch/csrc/jit/netdef_converter.h>
#include "test/cpp/jit/test_base.h"

#include <sstream>
#include <string>

namespace torch {
namespace jit {

void testNetDefConverter(std::ostream& out = std::cout) {
{
// Check a simple net conversion back and forth.

// Create a simple graph:
// graph(%0 : Tensor
// %1 : Tensor) {
// %2 : Tensor = aten::mul(%0, %1)
// %3 : int = prim::Constant[value=1]()
// %4 : Tensor = aten::add(%0, %2, %3)
// return (%2, %4);
// }
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->addInput();
auto c = graph->insert(aten::mul, {a, b});
auto d = graph->insert(aten::add, {a, c});
graph->registerOutput(c);
graph->registerOutput(d);

// Convert it to netdef and check the result
caffe2::NetDef net;
convertIRToNetDef(&net, *graph);
AT_ASSERT(net.op().size() == 3);
AT_ASSERT(net.external_input().size() == 2);
AT_ASSERT(net.external_output().size() == 2);

const caffe2::OperatorDef& MulOp = net.op().Get(0);
AT_ASSERT(MulOp.input().size() == 2);
AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
AT_ASSERT(MulOp.output().size() == 1);

const caffe2::OperatorDef& ConstNode = net.op().Get(1);
AT_ASSERT(ConstNode.input().size() == 0);
AT_ASSERT(ConstNode.output().size() == 1);
AT_ASSERT(ConstNode.arg().size() == 1);
AT_ASSERT(ConstNode.arg().Get(0).name() == "value");
AT_ASSERT(ConstNode.arg().Get(0).i() == 1);

const caffe2::OperatorDef& AddOp = net.op().Get(2);
AT_ASSERT(AddOp.input().size() == 3);
AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));

AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));

// Convert NetDef back to IR and check if we get the original.
Graph graph2;
std::unordered_map<std::string, Value*> vmap;
convertNetDefToIR(net, &graph2, &vmap);

Node* mul = graph2.outputs()[0]->node();
Node* add = graph2.outputs()[1]->node();
AT_ASSERT(mul->kind() == c->node()->kind());
AT_ASSERT(add->kind() == d->node()->kind());
AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
}
{
// Check attributes conversion
auto graph = std::make_shared<Graph>();
auto a = graph->addInput();
auto b = graph->addInput();
Node* node =
graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2);
graph->insertNode(node);

node->i_(Symbol::fromQualString("attr::i_attr"), 42);
node->f_(Symbol::fromQualString("attr::f_attr"), 3.0);
node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!");

node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7});
node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14});
node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"});

graph->registerOutput(node->outputs()[0]);
graph->registerOutput(node->outputs()[1]);

// Convert it to netdef and check the result
caffe2::NetDef net;
convertIRToNetDef(&net, *graph);
const caffe2::OperatorDef& Op = net.op().Get(0);
AT_ASSERT(Op.arg().Get(0).name() == "i_attr");
AT_ASSERT(Op.arg().Get(0).i() == 42);
AT_ASSERT(Op.arg().Get(1).name() == "f_attr");
AT_ASSERT(Op.arg().Get(1).f() == 3.0);
AT_ASSERT(Op.arg().Get(2).name() == "s_attr");
AT_ASSERT(Op.arg().Get(2).s() == "Hello!");

AT_ASSERT(Op.arg().Get(3).name() == "is_attr");
AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);

AT_ASSERT(Op.arg().Get(4).name() == "fs_attr");
AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);

AT_ASSERT(Op.arg().Get(5).name() == "ss_attr");
AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer");

AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));

// Convert NetDef back to IR and check if we get the original.
Graph graph2;
std::unordered_map<std::string, Value*> vmap;
convertNetDefToIR(net, &graph2, &vmap);

AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
Node* n = graph2.outputs()[0]->node();
AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42);
AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0);
AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!");
AT_ASSERT(
n->is(Symbol::fromQualString("attr::is_attr")) ==
std::vector<long>({14, 18, 7}));
AT_ASSERT(
fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001);
AT_ASSERT(
fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001);
AT_ASSERT(
n->ss(Symbol::fromQualString("attr::ss_attr")) ==
std::vector<std::string>({"Winter", "Summer"}));
}
}
} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/constants.cpp
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
Expand Down
194 changes: 194 additions & 0 deletions torch/csrc/jit/netdef_converter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#include <torch/csrc/jit/netdef_converter.h>

namespace torch {
namespace jit {

static AttributeKind getArgKind(const caffe2::Argument& arg) {
if (arg.has_i()) {
return AttributeKind::i;
} else if (arg.has_f()) {
return AttributeKind::f;
} else if (arg.has_s()) {
return AttributeKind::s;
} else if (arg.has_t()) {
return AttributeKind::t;
} else if (arg.has_n()) {
return AttributeKind::g;
} else if (arg.ints().size()) {
return AttributeKind::is;
} else if (arg.floats().size()) {
return AttributeKind::fs;
} else if (arg.strings().size()) {
return AttributeKind::ss;
} else if (arg.tensors().size()) {
return AttributeKind::ts;
} else if (arg.nets().size()) {
return AttributeKind::gs;
}
// Unknown type.
abort();
}

static void convertArg(const caffe2::Argument& arg, Node* node) {
std::string attrName = "attr::" + arg.name();
auto attrSymbol = Symbol::fromQualString(attrName);
AttributeKind kind = getArgKind(arg);
switch (kind) {
case AttributeKind::i: {
node->i_(attrSymbol, (long)arg.i());
break;
}
case AttributeKind::f: {
node->f_(attrSymbol, arg.f());
break;
}
case AttributeKind::s: {
node->s_(attrSymbol, arg.s());
break;
}
case AttributeKind::is: {
std::vector<long> is(arg.ints().begin(), arg.ints().end());
node->is_(attrSymbol, is);
break;
}
case AttributeKind::fs: {
std::vector<double> fs(arg.floats().begin(), arg.floats().end());
node->fs_(attrSymbol, fs);
break;
}
case AttributeKind::ss: {
std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
node->ss_(attrSymbol, ss);
break;
}
default: {
std::cout << "Unsupported type '" << toString(kind) << "' of attribute '"
<< attrName << "'"
<< " in node:" << std::endl;
node->dump();
abort();
}
}
}

void convertNetDefToIR(
const caffe2::NetDef& net,
Graph* g,
std::unordered_map<std::string, Value*>* valueMapPtr,
const std::string& prefix) {
std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
valueMap.clear();

for (const auto& inputName : net.external_input()) {
AT_ASSERT(!valueMap.count(inputName));
valueMap[inputName] = g->addInput();
}

for (const auto& op : net.op()) {
std::string name = prefix + op.type();
Node* node =
g->create(Symbol::fromQualString(name), {}, op.output().size());
g->insertNode(node);

for (const auto& input : op.input()) {
AT_ASSERT(valueMap.count(input));
node->addInput(valueMap[input]);
}
int idx = 0;
for (const auto& output : op.output()) {
// If output already exists in valueMap, overwrite it. This way we will
// have the last definition of a value named 'output' in valueMap.
valueMap[output] = node->outputs()[idx++];
}
for (const auto& arg : op.arg()) {
convertArg(arg, node);
}
}

for (const auto& outputName : net.external_output()) {
AT_ASSERT(valueMap.count(outputName));
g->registerOutput(valueMap.at(outputName));
}
}

static void convertAttrToCaffe2Arg(
const Node* node,
const Symbol& name,
caffe2::Argument* arg) {
arg->set_name(name.toUnqualString());
switch (node->kindOf(name)) {
case AttributeKind::i: {
arg->set_i(node->i(name));
break;
}
case AttributeKind::f: {
arg->set_f(node->f(name));
break;
}
case AttributeKind::s: {
arg->set_s(node->s(name));
break;
}
case AttributeKind::is: {
for (long i : node->is(name)) {
arg->add_ints(i);
}
break;
}
case AttributeKind::fs: {
for (double f : node->fs(name)) {
arg->add_floats(f);
}
break;
}
case AttributeKind::ss: {
for (const std::string& s : node->ss(name)) {
arg->add_strings(s);
}
break;
}
default: {
std::cout << "Unsupported type '" << toString(node->kindOf(name))
<< "' of attribute '" << name.toUnqualString() << "'"
<< " in node:" << std::endl;
node->dump();
abort();
}
}
}

static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
caffe2::OperatorDef op;
op.set_type(node->kind().toQualString());
for (const Value* input : node->inputs()) {
op.add_input(input->uniqueName());
}
for (const Value* output : node->outputs()) {
op.add_output(output->uniqueName());
}
std::vector<Symbol> names = node->attributeNames();
for (const Symbol& name : names) {
caffe2::Argument* arg = op.add_arg();
convertAttrToCaffe2Arg(node, name, arg);
}
*net->add_op() = op;
}

void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
net->mutable_op()->Clear();

for (const Value* value : g.inputs()) {
net->add_external_input(value->uniqueName());
}

for (const Node* node : g.nodes()) {
convertNodeToCaffe2Op(node, net);
}

for (const Value* value : g.outputs()) {
net->add_external_output(value->uniqueName());
}
}

} // namespace jit
} // namespace torch
Loading

0 comments on commit d25fee3

Please sign in to comment.