Python function to extract information on mobile::Module from flatbuffer (pytorch#77624)

Summary:
Includes the following refactors:
1. Common operator-validation loading logic that was duplicated in the pickle
   and flatbuffer loaders is moved to function.h/cpp.
2. Allow loading a function without wiring up its operators.

This function will be used to implement get_bundled_input and friends
for flatbuffer.
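
A minimal sketch of the decoupled flow that refactor (2) enables (hypothetical standalone example; assumes the mobile::Function API as changed in this diff, and the constructor shown is illustrative):

// Phase 1: record the operator name and arg count without resolving it
// against the runtime's operator registry, so nothing throws here.
torch::jit::mobile::Function fn(c10::QualifiedName("__torch__.Foo.forward"));
fn.append_operator("aten::linear", /*overload_name=*/"", /*num_specified_args=*/3);

// Phase 2 (optional, deferred until run()/call()): wire up the operators;
// with should_check_operators=true, a missing op throws at this point.
bool all_found = fn.initialize_operators(/*should_check_operators=*/true);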

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/69fa49f1230f80d1a0667e0a6ac8aca2746431b6

Reviewed By: cccclai

Differential Revision: D36348549

Pull Request resolved: pytorch#77624
Approved by: https://github.com/cccclai
qihqi authored and pytorchmergebot committed May 18, 2022
1 parent fc4c3c9 commit 3822a47
Showing 13 changed files with 216 additions and 77 deletions.
1 change: 1 addition & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -2091,6 +2091,7 @@ TEST(LiteInterpreterUpgraderTest, Upgrader) {
  std::vector<mobile::Function> upgrader_functions;

  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    byteCodeFunctionWithOperator.function.initialize_operators(true);
    ASSERT_EQ(
        byteCodeFunctionWithOperator.function.get_code().operators_.size(),
        byteCodeFunctionWithOperator.function.get_code().op_names_.size());
28 changes: 28 additions & 0 deletions test/jit/test_save_load.py
@@ -918,6 +918,34 @@ def forward(self) -> Optional[FooTuple]:
        output = m_loaded()
        self.assertEqual(output, None)

    def test_module_info_flatbuffer(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save_jit_module_to_flatbuffer(
            first_script_module, first_saved_module)
        first_saved_module.seek(0)
        expected = {
            'bytecode_version': 4,
            'operator_version': 4,
            'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
            'type_names': set(),
            'opname_to_num_args': {'aten::linear': 3}}
        self.assertEqual(
            torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
            expected)

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
13 changes: 13 additions & 0 deletions torch/csrc/init_flatbuffer_module.cpp
@@ -99,5 +99,18 @@ extern "C"
            reinterpret_cast<char*>(detached_buffer.data()),
            detached_buffer.size());
      });
  pym.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
    py::gil_scoped_acquire acquire;
    py::dict result;
    mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
        &flatbuffer_content[0]);
    result["bytecode_version"] = minfo.bytecode_version;
    result["operator_version"] = minfo.operator_version;
    result["function_names"] = minfo.function_names;
    result["type_names"] = minfo.type_names;
    result["opname_to_num_args"] = minfo.opname_to_num_args;
    return result;
  });

  return module;
}
2 changes: 2 additions & 0 deletions torch/csrc/jit/mobile/code.h
@@ -30,6 +30,8 @@ struct Code {
  // be done in parseMethods().
  std::vector<mobile::Function*> functions_;
  size_t register_size_ = 0; // Aggregated output size.
  // initialized means the operators_ array has been filled with resolved operators
  bool initialized = false;
};

} // namespace mobile
24 changes: 12 additions & 12 deletions torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -240,8 +240,6 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    function->append_constant(getIValue(i));
  }

-  std::unordered_set<std::string> unsupported_op_names;

  appendUpgraderFunctions(function.get());
  // 2. Decides if upgrader is needed
  const uint32_t operator_version = module_->operator_version();
@@ -254,19 +252,13 @@
      num_args = op->num_args_serialized();
    }

-    auto op_found = function->append_operator(
+    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);

-    if (!op_found) {
-      unsupported_op_names.emplace(
-          op->name()->str() + "/" + op->overload_name()->str());
-    }
  }

-  TORCH_CHECK(
-      unsupported_op_names.empty(),
-      "Unsupported ops: ",
-      c10::Join(", ", unsupported_op_names));
+  if (should_load_operators_) {
+    function->initialize_operators(true);
+  }

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
@@ -725,5 +717,13 @@ uint64_t get_bytecode_version(const std::string& filename) {
  return flatbuffer_module->bytecode_version();
}

mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
  FlatbufferLoader loader;
  loader.setShouldLoadOperators(false);
  mobile::Module m = loader.parseModule(ff_module);
  return mobile::get_module_info(m);
}

} // namespace jit
} // namespace torch
17 changes: 17 additions & 0 deletions torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -30,6 +30,11 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetime of the Module.
// This function does step 3 described above.
// If should_copy_tensor_memory is true, then the returned module will NOT
// have references to flatbuffer_module, so flatbuffer_module can be discarded.
// If should_copy_tensor_memory is false, then the returned module will have
// tensors that point into flatbuffer_module; the caller needs to make
// sure that flatbuffer_module outlives the returned Module.
TORCH_API mobile::Module initialize_mobile_module(
    mobile::serialization::Module* flatbuffer_module,
    c10::optional<at::Device> device = c10::nullopt,
@@ -66,6 +71,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);

TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
    char* flatbuffer_content);
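
A minimal usage sketch for this entry point (illustrative only; "model.ff" is a hypothetical path, and this assumes the whole flatbuffer fits in a writable in-memory buffer, since the loader accesses it through a mutable pointer):

#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>

int main() {
  // Read the flatbuffer-serialized module into a mutable buffer.
  std::ifstream in("model.ff", std::ios::binary);
  std::string content(
      (std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());

  // Operator loading is disabled internally, so this works even when some
  // ops are missing from the current runtime.
  torch::jit::mobile::ModuleInfo info =
      torch::jit::get_module_info_from_flatbuffer(&content[0]);
  for (const auto& kv : info.opname_to_num_args) {
    std::printf("%s -> %d args\n", kv.first.c_str(), kv.second);
  }
  return 0;
}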

class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
@@ -118,6 +126,14 @@ class TORCH_API FlatbufferLoader {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  // Whether or not to load operators in functions.
  // Not loading operators is useful: loading an operator that is missing
  // from the runtime throws an exception, and sometimes we want to print
  // out which operators a model contains before that happens, for debugging.
  void setShouldLoadOperators(bool should_load_operators) {
    should_load_operators_ = should_load_operators;
  }
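
A sketch of that debugging flow (fragment; assumes a parsed mobile::serialization::Module*, the namespaces used in this file, and <cstdio>):

void debug_print_ops(mobile::serialization::Module* flatbuffer_module) {
  FlatbufferLoader loader;
  loader.setShouldLoadOperators(false); // parse without resolving operators
  mobile::Module m = loader.parseModule(flatbuffer_module);
  // Print every operator the module references before anything can throw.
  for (const auto& fn : m.compilation_unit().methods()) {
    for (const auto& op : fn->get_code().op_names_) {
      std::printf("%s\n", mobile::operator_str(op).c_str());
    }
  }
}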

std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;

@@ -141,6 +157,7 @@
mobile::serialization::Module* module_ = nullptr;
bool module_parsed_ = false;
bool should_copy_tensor_memory_ = false;
bool should_load_operators_ = true;
};

} // namespace jit
49 changes: 42 additions & 7 deletions torch/csrc/jit/mobile/function.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>

@@ -43,20 +44,52 @@ void Function::append_instruction(OpCode op, int X, int N) {
  code_.instructions_.emplace_back(op, X, N);
}

-bool Function::append_operator(
+void Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    const c10::optional<int>& num_specified_args) {
  // Keep the original opname in code_
  code_.op_names_.emplace_back(name, overload_name);
-  const auto& opname = code_.op_names_.back();
-  auto func = makeOperatorFunction(opname, num_specified_args);
-  if (!func.has_value()) {
-    return false;
-  }
-  code_.operators_.emplace_back(*func);
-  return true;
+  code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
+}
+
+std::string operator_str(const c10::OperatorName& opname) {
+  std::string result = opname.name;
+  if (!opname.overload_name.empty()) {
+    result += "." + opname.overload_name;
+  }
+  return result;
}
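
For reference, how operator_str joins the two parts (illustrative values; the un-overloaded form matches the 'aten::linear' key asserted in the Python test above):

c10::OperatorName with_overload("aten::add", "Tensor");
c10::OperatorName no_overload("aten::linear", "");
// operator_str(with_overload) == "aten::add.Tensor"
// operator_str(no_overload) == "aten::linear"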

bool Function::initialize_operators(bool should_check_operators) {
  if (code_.initialized) {
    return true;
  }
  std::unordered_set<std::string> unsupported_op_names;
  code_.operators_.resize(code_.op_names_.size());
  bool all_ops_supported = true;
  for (int i = 0; i < code_.op_names_.size(); i++) {
    const auto& opname = code_.op_names_[i];
    int num_args = code_.operator_input_sizes_[i];
    c10::optional<int> num_specified_args =
        num_args < 0 ? c10::nullopt : c10::optional<int>(num_args);
    auto func = makeOperatorFunction(opname, num_specified_args);
    if (!func.has_value()) {
      unsupported_op_names.insert(operator_str(opname));
      all_ops_supported = false;
      break;
    } else {
      code_.operators_[i] = *func;
    }
  }
  if (should_check_operators) {
    TORCH_CHECK(
        unsupported_op_names.empty(),
        "Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
        c10::Join(", ", unsupported_op_names));
  }
  code_.initialized = all_ops_supported;
  return all_ops_supported;
}

void Function::append_constant(const c10::IValue& constant) {
@@ -96,6 +129,7 @@ const c10::FunctionSchema& Function::getSchema() const {
}

void Function::run(Stack& stack) {
  initialize_operators(/* should_check_operators */ true);
  if (hasSchema()) { // if we have a schema then resolve optional args if any
    getSchema().checkAndNormalizeInputs<c10::DynamicType>(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
@@ -114,6 +148,7 @@ size_t Function::num_inputs() const {
}

bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
  initialize_operators(true);
  f(code_);
  return true;
}
10 changes: 9 additions & 1 deletion torch/csrc/jit/mobile/function.h
@@ -34,7 +34,7 @@ class TORCH_API Function : public torch::jit::Function {
// misaligned. Therefore only use ONE variant at time.
void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
void append_instruction(OpCode op, int X, int N);
-  bool append_operator(
+  void append_operator(
      const std::string& name,
      const std::string& overload_name,
      const c10::optional<int>& num_specified_args);
@@ -63,6 +63,12 @@ class TORCH_API Function : public torch::jit::Function {
const std::vector<c10::TypePtr>& types,
const size_t register_size);

  // If not initialized yet, initialize by loading the operators.
  // Return true if all ops are loaded; return false if some op is not found
  // in the current runtime. In that case the missing ops are collected in
  // unsupported_op_names and reported when should_check_operators is true.
  bool initialize_operators(bool should_check_operators);

private:
c10::QualifiedName name_;
Code code_;
Expand All @@ -73,6 +79,8 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
c10::OperatorName opname,
c10::optional<int> num_specified_args);

TORCH_API std::string operator_str(const c10::OperatorName& opname);

} // namespace mobile
} // namespace jit
} // namespace torch
35 changes: 35 additions & 0 deletions torch/csrc/jit/mobile/module.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <exception>

@@ -263,6 +264,40 @@ c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
  return stack.front();
}

c10::optional<std::string> print_type(const c10::Type& t) {
  auto namedType = t.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return namedType->name().value().qualifiedName();
  }
  if (auto dyn = t.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str();
  }
  return c10::nullopt;
}

TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
  ModuleInfo minfo;
  minfo.operator_version = module.min_operator_version();
  minfo.bytecode_version = module.bytecode_version();
  std::vector<std::string> type_name_list;
  for (const auto& func_ptr : module.compilation_unit().methods()) {
    const auto& function = *func_ptr;
    for (int i = 0; i < function.get_code().op_names_.size(); i++) {
      const auto& op = function.get_code().op_names_[i];
      minfo.opname_to_num_args[mobile::operator_str(op)] =
          function.get_code().operator_input_sizes_[i];
    }
    for (const c10::TypePtr& tp : function.get_code().types_) {
      type_name_list.push_back(tp->annotation_str(print_type));
    }
    minfo.function_names.insert(function.qualname().qualifiedName());
  }
  c10::TypeParser parser(type_name_list);
  parser.parseList();
  minfo.type_names = parser.getContainedTypes();
  return minfo;
}

} // namespace mobile
} // namespace jit
} // namespace torch
10 changes: 10 additions & 0 deletions torch/csrc/jit/mobile/module.h
@@ -163,6 +163,16 @@ class TORCH_API Module {
// Extra handle for the module to delete when itself is deleted
std::shared_ptr<char> mem_to_delete_;
};

struct TORCH_API ModuleInfo {
  uint64_t bytecode_version;
  uint64_t operator_version;
  std::unordered_map<std::string, int> opname_to_num_args;
  std::unordered_set<std::string> function_names;
  std::unordered_set<std::string> type_names;
};
TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
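
For the Foo module in test_module_info_flatbuffer above, the populated struct would look like this (values taken from that Python test; the variable is a hypothetical expected value):

mobile::ModuleInfo expected;
expected.bytecode_version = 4;
expected.operator_version = 4;
expected.function_names = {"__torch__.___torch_mangle_0.Foo.forward"};
expected.type_names = {};
expected.opname_to_num_args = {{"aten::linear", 3}};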

} // namespace mobile
} // namespace jit
} // namespace torch