forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement NetDef <--> JIT IR converters. (pytorch#16967)
Summary: Currently the converters are very straightforward, i.e. there is no code for trying to preserve semantics, we're purely perform conversion from one format to another. Two things that we might want to add/change: 1. Add semantic conversion as well (but probably it would be a good idea to keep it separate as a temporary thing). 2. Make sure we don't mess with value names, as they are crucial for current uses of NetDefs. Pull Request resolved: pytorch#16967 Differential Revision: D14062537 Pulled By: ZolotukhinM fbshipit-source-id: 88b184ee7276779e5e9152b149d69857515ad98a
- Loading branch information
1 parent
decc089
commit d25fee3
Showing
6 changed files
with
384 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/jit/netdef_converter.h> | ||
#include "test/cpp/jit/test_base.h" | ||
|
||
#include <sstream> | ||
#include <string> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
void testNetDefConverter(std::ostream& out = std::cout) { | ||
{ | ||
// Check a simple net conversion back and forth. | ||
|
||
// Create a simple graph: | ||
// graph(%0 : Tensor | ||
// %1 : Tensor) { | ||
// %2 : Tensor = aten::mul(%0, %1) | ||
// %3 : int = prim::Constant[value=1]() | ||
// %4 : Tensor = aten::add(%0, %2, %3) | ||
// return (%2, %4); | ||
// } | ||
auto graph = std::make_shared<Graph>(); | ||
auto a = graph->addInput(); | ||
auto b = graph->addInput(); | ||
auto c = graph->insert(aten::mul, {a, b}); | ||
auto d = graph->insert(aten::add, {a, c}); | ||
graph->registerOutput(c); | ||
graph->registerOutput(d); | ||
|
||
// Convert it to netdef and check the result | ||
caffe2::NetDef net; | ||
convertIRToNetDef(&net, *graph); | ||
AT_ASSERT(net.op().size() == 3); | ||
AT_ASSERT(net.external_input().size() == 2); | ||
AT_ASSERT(net.external_output().size() == 2); | ||
|
||
const caffe2::OperatorDef& MulOp = net.op().Get(0); | ||
AT_ASSERT(MulOp.input().size() == 2); | ||
AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0)); | ||
AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1)); | ||
AT_ASSERT(MulOp.output().size() == 1); | ||
|
||
const caffe2::OperatorDef& ConstNode = net.op().Get(1); | ||
AT_ASSERT(ConstNode.input().size() == 0); | ||
AT_ASSERT(ConstNode.output().size() == 1); | ||
AT_ASSERT(ConstNode.arg().size() == 1); | ||
AT_ASSERT(ConstNode.arg().Get(0).name() == "value"); | ||
AT_ASSERT(ConstNode.arg().Get(0).i() == 1); | ||
|
||
const caffe2::OperatorDef& AddOp = net.op().Get(2); | ||
AT_ASSERT(AddOp.input().size() == 3); | ||
AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0)); | ||
AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0)); | ||
AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0)); | ||
|
||
AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0)); | ||
AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0)); | ||
|
||
// Convert NetDef back to IR and check if we get the original. | ||
Graph graph2; | ||
std::unordered_map<std::string, Value*> vmap; | ||
convertNetDefToIR(net, &graph2, &vmap); | ||
|
||
Node* mul = graph2.outputs()[0]->node(); | ||
Node* add = graph2.outputs()[1]->node(); | ||
AT_ASSERT(mul->kind() == c->node()->kind()); | ||
AT_ASSERT(add->kind() == d->node()->kind()); | ||
AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]); | ||
AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]); | ||
AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]); | ||
AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]); | ||
} | ||
{ | ||
// Check attributes conversion | ||
auto graph = std::make_shared<Graph>(); | ||
auto a = graph->addInput(); | ||
auto b = graph->addInput(); | ||
Node* node = | ||
graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2); | ||
graph->insertNode(node); | ||
|
||
node->i_(Symbol::fromQualString("attr::i_attr"), 42); | ||
node->f_(Symbol::fromQualString("attr::f_attr"), 3.0); | ||
node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!"); | ||
|
||
node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7}); | ||
node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14}); | ||
node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"}); | ||
|
||
graph->registerOutput(node->outputs()[0]); | ||
graph->registerOutput(node->outputs()[1]); | ||
|
||
// Convert it to netdef and check the result | ||
caffe2::NetDef net; | ||
convertIRToNetDef(&net, *graph); | ||
const caffe2::OperatorDef& Op = net.op().Get(0); | ||
AT_ASSERT(Op.arg().Get(0).name() == "i_attr"); | ||
AT_ASSERT(Op.arg().Get(0).i() == 42); | ||
AT_ASSERT(Op.arg().Get(1).name() == "f_attr"); | ||
AT_ASSERT(Op.arg().Get(1).f() == 3.0); | ||
AT_ASSERT(Op.arg().Get(2).name() == "s_attr"); | ||
AT_ASSERT(Op.arg().Get(2).s() == "Hello!"); | ||
|
||
AT_ASSERT(Op.arg().Get(3).name() == "is_attr"); | ||
AT_ASSERT(Op.arg().Get(3).ints().size() == 3); | ||
AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14); | ||
AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18); | ||
AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7); | ||
|
||
AT_ASSERT(Op.arg().Get(4).name() == "fs_attr"); | ||
AT_ASSERT(Op.arg().Get(4).floats().size() == 2); | ||
AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001); | ||
|
||
AT_ASSERT(Op.arg().Get(5).name() == "ss_attr"); | ||
AT_ASSERT(Op.arg().Get(5).strings().size() == 2); | ||
AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer"); | ||
|
||
AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0)); | ||
AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1)); | ||
|
||
// Convert NetDef back to IR and check if we get the original. | ||
Graph graph2; | ||
std::unordered_map<std::string, Value*> vmap; | ||
convertNetDefToIR(net, &graph2, &vmap); | ||
|
||
AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node()); | ||
Node* n = graph2.outputs()[0]->node(); | ||
AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42); | ||
AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0); | ||
AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!"); | ||
AT_ASSERT( | ||
n->is(Symbol::fromQualString("attr::is_attr")) == | ||
std::vector<long>({14, 18, 7})); | ||
AT_ASSERT( | ||
fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001); | ||
AT_ASSERT( | ||
fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001); | ||
AT_ASSERT( | ||
n->ss(Symbol::fromQualString("attr::ss_attr")) == | ||
std::vector<std::string>({"Winter", "Summer"})); | ||
} | ||
} | ||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
#include <torch/csrc/jit/netdef_converter.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
static AttributeKind getArgKind(const caffe2::Argument& arg) { | ||
if (arg.has_i()) { | ||
return AttributeKind::i; | ||
} else if (arg.has_f()) { | ||
return AttributeKind::f; | ||
} else if (arg.has_s()) { | ||
return AttributeKind::s; | ||
} else if (arg.has_t()) { | ||
return AttributeKind::t; | ||
} else if (arg.has_n()) { | ||
return AttributeKind::g; | ||
} else if (arg.ints().size()) { | ||
return AttributeKind::is; | ||
} else if (arg.floats().size()) { | ||
return AttributeKind::fs; | ||
} else if (arg.strings().size()) { | ||
return AttributeKind::ss; | ||
} else if (arg.tensors().size()) { | ||
return AttributeKind::ts; | ||
} else if (arg.nets().size()) { | ||
return AttributeKind::gs; | ||
} | ||
// Unknown type. | ||
abort(); | ||
} | ||
|
||
static void convertArg(const caffe2::Argument& arg, Node* node) { | ||
std::string attrName = "attr::" + arg.name(); | ||
auto attrSymbol = Symbol::fromQualString(attrName); | ||
AttributeKind kind = getArgKind(arg); | ||
switch (kind) { | ||
case AttributeKind::i: { | ||
node->i_(attrSymbol, (long)arg.i()); | ||
break; | ||
} | ||
case AttributeKind::f: { | ||
node->f_(attrSymbol, arg.f()); | ||
break; | ||
} | ||
case AttributeKind::s: { | ||
node->s_(attrSymbol, arg.s()); | ||
break; | ||
} | ||
case AttributeKind::is: { | ||
std::vector<long> is(arg.ints().begin(), arg.ints().end()); | ||
node->is_(attrSymbol, is); | ||
break; | ||
} | ||
case AttributeKind::fs: { | ||
std::vector<double> fs(arg.floats().begin(), arg.floats().end()); | ||
node->fs_(attrSymbol, fs); | ||
break; | ||
} | ||
case AttributeKind::ss: { | ||
std::vector<std::string> ss(arg.strings().begin(), arg.strings().end()); | ||
node->ss_(attrSymbol, ss); | ||
break; | ||
} | ||
default: { | ||
std::cout << "Unsupported type '" << toString(kind) << "' of attribute '" | ||
<< attrName << "'" | ||
<< " in node:" << std::endl; | ||
node->dump(); | ||
abort(); | ||
} | ||
} | ||
} | ||
|
||
void convertNetDefToIR( | ||
const caffe2::NetDef& net, | ||
Graph* g, | ||
std::unordered_map<std::string, Value*>* valueMapPtr, | ||
const std::string& prefix) { | ||
std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr; | ||
valueMap.clear(); | ||
|
||
for (const auto& inputName : net.external_input()) { | ||
AT_ASSERT(!valueMap.count(inputName)); | ||
valueMap[inputName] = g->addInput(); | ||
} | ||
|
||
for (const auto& op : net.op()) { | ||
std::string name = prefix + op.type(); | ||
Node* node = | ||
g->create(Symbol::fromQualString(name), {}, op.output().size()); | ||
g->insertNode(node); | ||
|
||
for (const auto& input : op.input()) { | ||
AT_ASSERT(valueMap.count(input)); | ||
node->addInput(valueMap[input]); | ||
} | ||
int idx = 0; | ||
for (const auto& output : op.output()) { | ||
// If output already exists in valueMap, overwrite it. This way we will | ||
// have the last definition of a value named 'output' in valueMap. | ||
valueMap[output] = node->outputs()[idx++]; | ||
} | ||
for (const auto& arg : op.arg()) { | ||
convertArg(arg, node); | ||
} | ||
} | ||
|
||
for (const auto& outputName : net.external_output()) { | ||
AT_ASSERT(valueMap.count(outputName)); | ||
g->registerOutput(valueMap.at(outputName)); | ||
} | ||
} | ||
|
||
static void convertAttrToCaffe2Arg( | ||
const Node* node, | ||
const Symbol& name, | ||
caffe2::Argument* arg) { | ||
arg->set_name(name.toUnqualString()); | ||
switch (node->kindOf(name)) { | ||
case AttributeKind::i: { | ||
arg->set_i(node->i(name)); | ||
break; | ||
} | ||
case AttributeKind::f: { | ||
arg->set_f(node->f(name)); | ||
break; | ||
} | ||
case AttributeKind::s: { | ||
arg->set_s(node->s(name)); | ||
break; | ||
} | ||
case AttributeKind::is: { | ||
for (long i : node->is(name)) { | ||
arg->add_ints(i); | ||
} | ||
break; | ||
} | ||
case AttributeKind::fs: { | ||
for (double f : node->fs(name)) { | ||
arg->add_floats(f); | ||
} | ||
break; | ||
} | ||
case AttributeKind::ss: { | ||
for (const std::string& s : node->ss(name)) { | ||
arg->add_strings(s); | ||
} | ||
break; | ||
} | ||
default: { | ||
std::cout << "Unsupported type '" << toString(node->kindOf(name)) | ||
<< "' of attribute '" << name.toUnqualString() << "'" | ||
<< " in node:" << std::endl; | ||
node->dump(); | ||
abort(); | ||
} | ||
} | ||
} | ||
|
||
static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) { | ||
caffe2::OperatorDef op; | ||
op.set_type(node->kind().toQualString()); | ||
for (const Value* input : node->inputs()) { | ||
op.add_input(input->uniqueName()); | ||
} | ||
for (const Value* output : node->outputs()) { | ||
op.add_output(output->uniqueName()); | ||
} | ||
std::vector<Symbol> names = node->attributeNames(); | ||
for (const Symbol& name : names) { | ||
caffe2::Argument* arg = op.add_arg(); | ||
convertAttrToCaffe2Arg(node, name, arg); | ||
} | ||
*net->add_op() = op; | ||
} | ||
|
||
void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) { | ||
net->mutable_op()->Clear(); | ||
|
||
for (const Value* value : g.inputs()) { | ||
net->add_external_input(value->uniqueName()); | ||
} | ||
|
||
for (const Node* node : g.nodes()) { | ||
convertNodeToCaffe2Op(node, net); | ||
} | ||
|
||
for (const Value* value : g.outputs()) { | ||
net->add_external_output(value->uniqueName()); | ||
} | ||
} | ||
|
||
} // namespace jit | ||
} // namespace torch |
Oops, something went wrong.