From 3822a472eff63c08419a828294ef3c9265bc7920 Mon Sep 17 00:00:00 2001
From: "Han Qi (qihqi)"
Date: Wed, 18 May 2022 00:42:56 +0000
Subject: [PATCH] Python function to extract information on mobile::Module
 from flatbuffer (#77624)

Summary:
Includes the following refactors:

1. Common operator loading and validation logic that was duplicated in the
   pickle and flatbuffer loaders is moved to function.h/cpp.
2. Allow loading a function without wiring up its operators. This will be
   used to implement get_bundled_input and friends for flatbuffer.

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/69fa49f1230f80d1a0667e0a6ac8aca2746431b6

Reviewed By: cccclai

Differential Revision: D36348549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77624
Approved by: https://github.com/cccclai
---
 test/cpp/jit/test_lite_interpreter.cpp      |  1 +
 test/jit/test_save_load.py                  | 28 +++++++++++
 torch/csrc/init_flatbuffer_module.cpp       | 13 +++++
 torch/csrc/jit/mobile/code.h                |  2 +
 torch/csrc/jit/mobile/flatbuffer_loader.cpp | 24 ++++-----
 torch/csrc/jit/mobile/flatbuffer_loader.h   | 17 +++++++
 torch/csrc/jit/mobile/function.cpp          | 49 +++++++++++++++---
 torch/csrc/jit/mobile/function.h            | 10 +++-
 torch/csrc/jit/mobile/module.cpp            | 35 +++++++++++++
 torch/csrc/jit/mobile/module.h              | 10 ++++
 torch/csrc/jit/mobile/parse_operators.cpp   | 55 ++-------------------
 torch/jit/__init__.py                       |  7 ++-
 torch/jit/_serialization.py                 | 42 +++++++++++++---
 13 files changed, 216 insertions(+), 77 deletions(-)

diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp
index a07cc8af5aa707..d01c611bbaeca8 100644
--- a/test/cpp/jit/test_lite_interpreter.cpp
+++ b/test/cpp/jit/test_lite_interpreter.cpp
@@ -2091,6 +2091,7 @@ TEST(LiteInterpreterUpgraderTest, Upgrader) {
   std::vector<mobile::Function> upgrader_functions;
 
   for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
+    byteCodeFunctionWithOperator.function.initialize_operators(true);
     ASSERT_EQ(
         byteCodeFunctionWithOperator.function.get_code().operators_.size(),
         byteCodeFunctionWithOperator.function.get_code().op_names_.size());
diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py
index bbe7e0a7016f6e..47cbc0fd9b3a52 100644
--- a/test/jit/test_save_load.py
+++ b/test/jit/test_save_load.py
@@ -918,6 +918,34 @@ def forward(self) -> Optional[FooTuple]:
         output = m_loaded()
         self.assertEqual(output, None)
 
+    def test_module_info_flatbuffer(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super(Foo, self).__init__()
+                self.foo = torch.nn.Linear(2, 2)
+                self.bar = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.foo(x)
+                x = self.bar(x)
+                return x
+
+        first_script_module = torch.jit.script(Foo())
+        first_saved_module = io.BytesIO()
+        torch.jit.save_jit_module_to_flatbuffer(
+            first_script_module, first_saved_module)
+        first_saved_module.seek(0)
+
+        expected = {
+            'bytecode_version': 4,
+            'operator_version': 4,
+            'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
+            'type_names': set(),
+            'opname_to_num_args': {'aten::linear': 3}}
+        self.assertEqual(
+            torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
+            expected)
+
     def test_save_load_params_buffers_submodules(self):
         """
         Check that parameters, buffers, and submodules are the same after loading.
diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp
index 4bc0446a2c23e7..77bb302423febc 100644
--- a/torch/csrc/init_flatbuffer_module.cpp
+++ b/torch/csrc/init_flatbuffer_module.cpp
@@ -99,5 +99,18 @@ extern "C"
         reinterpret_cast<char*>(detached_buffer.data()),
         detached_buffer.size());
   });
+  pym.def(
+      "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
+        py::gil_scoped_acquire acquire;
+        py::dict result;
+        mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
+            &flatbuffer_content[0]);
+        result["bytecode_version"] = minfo.bytecode_version;
+        result["operator_version"] = minfo.operator_version;
+        result["function_names"] = minfo.function_names;
+        result["type_names"] = minfo.type_names;
+        result["opname_to_num_args"] = minfo.opname_to_num_args;
+        return result;
+      });
 
   return module;
 }
diff --git a/torch/csrc/jit/mobile/code.h b/torch/csrc/jit/mobile/code.h
index 7e01bd7199e124..128b193b63aa53 100644
--- a/torch/csrc/jit/mobile/code.h
+++ b/torch/csrc/jit/mobile/code.h
@@ -30,6 +30,8 @@ struct Code {
   // be done in parseMethods().
   std::vector<mobile::Function*> functions_;
   size_t register_size_ = 0; // Aggregated output size.
+  // initialized means that the operators_ array is filled with operators
+  bool initialized = false;
 };
 
 } // namespace mobile
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 555e8ff034b9a6..d2edc5f31ae86c 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -240,8 +240,6 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
     function->append_constant(getIValue(i));
   }
 
-  std::unordered_set<std::string> unsupported_op_names;
-
   appendUpgraderFunctions(function.get());
   // 2. Decides if upgrader is needed
   const uint32_t operator_version = module_->operator_version();
@@ -254,19 +252,13 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
       num_args = op->num_args_serialized();
     }
 
-    auto op_found = function->append_operator(
+    function->append_operator(
         op->name()->str(), op->overload_name()->str(), num_args);
-
-    if (!op_found) {
-      unsupported_op_names.emplace(
-          op->name()->str() + "/" + op->overload_name()->str());
-    }
   }
 
-  TORCH_CHECK(
-      unsupported_op_names.empty(),
-      "Unsupported ops: ",
-      c10::Join(", ", unsupported_op_names));
+  if (should_load_operators_) {
+    function->initialize_operators(true);
+  }
 
   for (const auto i : *method->type_annotations()) {
     function->append_type(getOrCreateTypeAnnotations(i));
@@ -725,5 +717,13 @@ uint64_t get_bytecode_version(const std::string& filename) {
   return flatbuffer_module->bytecode_version();
 }
 
+mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
+  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
+  FlatbufferLoader loader;
+  loader.setShouldLoadOperators(false);
+  mobile::Module m = loader.parseModule(ff_module);
+  return mobile::get_module_info(m);
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
index e68ad7fa34aa96..7e883264da5007 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -30,6 +30,11 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 // Parse a mobile::Module from flatbuffer's in-memory Module representation.
 // The caller is assumed to manage the lifetimes of Module.
 // This function does step 3 described above.
+// If should_copy_tensor_memory is true, then the returned module will NOT
+// have references to flatbuffer_module, so it can be discarded.
+// If should_copy_tensor_memory is false, then the returned module will have
+// tensors that point inside flatbuffer_module; the caller needs to make
+// sure that flatbuffer_module outlives the returned Module.
 TORCH_API mobile::Module initialize_mobile_module(
     mobile::serialization::Module* flatbuffer_module,
     c10::optional<at::Device> device = c10::nullopt,
     bool should_copy_tensor_memory = false);
@@ -66,6 +71,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
 TORCH_API uint64_t get_bytecode_version(std::istream& in);
 TORCH_API uint64_t get_bytecode_version(const std::string& filename);
 
+TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
+    char* flatbuffer_content);
+
 class TORCH_API FlatbufferLoader {
  public:
   FlatbufferLoader();
@@ -118,6 +126,14 @@ class TORCH_API FlatbufferLoader {
     should_copy_tensor_memory_ = should_copy_tensor_memory;
   }
 
+  // Whether or not to load the operators in functions.
+  // Not loading operators is useful because if an operator is not found
+  // we throw an exception, and sometimes we want to print out which
+  // operators are included before that, for debugging.
+  void setShouldLoadOperators(bool should_load_operators) {
+    should_load_operators_ = should_load_operators;
+  }
+
   std::shared_ptr<mobile::CompilationUnit> mcu_;
   std::shared_ptr<CompilationUnit> cu_;
@@ -141,6 +157,7 @@ class TORCH_API FlatbufferLoader {
   mobile::serialization::Module* module_ = nullptr;
   bool module_parsed_ = false;
   bool should_copy_tensor_memory_ = false;
+  bool should_load_operators_ = true;
 };
 
 } // namespace jit
diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp
index cc038d2536eba6..c963fcb85828cd 100644
--- a/torch/csrc/jit/mobile/function.cpp
+++ b/torch/csrc/jit/mobile/function.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -43,20 +44,52 @@ void Function::append_instruction(OpCode op, int X, int N) {
   code_.instructions_.emplace_back(op, X, N);
 }
 
-bool Function::append_operator(
+void Function::append_operator(
     const std::string& name,
     const std::string& overload_name,
     const c10::optional<int>& num_specified_args) {
   // Keep the original opname in code_
   code_.op_names_.emplace_back(name, overload_name);
-  const auto& opname = code_.op_names_.back();
   code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
-  auto func = makeOperatorFunction(opname, num_specified_args);
-  if (!func.has_value()) {
-    return false;
+}
+
+std::string operator_str(const c10::OperatorName& opname) {
+  std::string result = opname.name;
+  if (!opname.overload_name.empty()) {
+    result += "." + opname.overload_name;
   }
-  code_.operators_.emplace_back(*func);
-  return true;
+  return result;
+}
+
+bool Function::initialize_operators(bool should_check_operators) {
+  if (code_.initialized) {
+    return true;
+  }
+  std::unordered_set<std::string> unsupported_op_names;
+  code_.operators_.resize(code_.op_names_.size());
+  bool all_ops_supported = true;
+  for (int i = 0; i < code_.op_names_.size(); i++) {
+    const auto& opname = code_.op_names_[i];
+    int num_args = code_.operator_input_sizes_[i];
+    c10::optional<int> num_specified_args =
+        num_args < 0 ? c10::nullopt : c10::optional<int>(num_args);
+    auto func = makeOperatorFunction(opname, num_specified_args);
+    if (!func.has_value()) {
+      unsupported_op_names.insert(operator_str(opname));
+      all_ops_supported = false;
+      break;
+    } else {
+      code_.operators_[i] = *func;
+    }
+  }
+  if (should_check_operators) {
+    TORCH_CHECK(
+        unsupported_op_names.empty(),
+        "Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
+        c10::Join(", ", unsupported_op_names));
+  }
+  code_.initialized = all_ops_supported;
+  return all_ops_supported;
 }
 
 void Function::append_constant(const c10::IValue& constant) {
@@ -96,6 +129,7 @@ const c10::FunctionSchema& Function::getSchema() const {
 }
 
 void Function::run(Stack& stack) {
+  initialize_operators(/* should_check_operators */ true);
   if (hasSchema()) { // if we have a schema then resolve optional args if any
     getSchema().checkAndNormalizeInputs(
         stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
@@ -114,6 +148,7 @@ size_t Function::num_inputs() const {
 }
 
 bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
+  initialize_operators(true);
   f(code_);
   return true;
 }
diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h
index f832a4373a603b..fb6f77fa64d766 100644
--- a/torch/csrc/jit/mobile/function.h
+++ b/torch/csrc/jit/mobile/function.h
@@ -34,7 +34,7 @@ class TORCH_API Function : public torch::jit::Function {
   // misaligned. Therefore only use ONE variant at time.
   void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
   void append_instruction(OpCode op, int X, int N);
-  bool append_operator(
+  void append_operator(
       const std::string& name,
       const std::string& overload_name,
       const c10::optional<int>& num_specified_args);
@@ -63,6 +63,12 @@ class TORCH_API Function : public torch::jit::Function {
       const std::vector<c10::TypePtr>& types,
       const size_t register_size);
 
+  // If not initialized, initialize by loading the operators.
+  // Returns true if all ops are loaded; returns false if some op is not
+  // found in the current runtime. In that case, the ops that were not
+  // found are filled in unsupported_op_names.
+  bool initialize_operators(bool should_check_operators);
+
  private:
   c10::QualifiedName name_;
   Code code_;
@@ -73,6 +79,8 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
     c10::OperatorName opname,
     c10::optional<int> num_specified_args);
 
+TORCH_API std::string operator_str(const c10::OperatorName& opname);
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp
index 27f013808a78bb..1483af299c43c7 100644
--- a/torch/csrc/jit/mobile/module.cpp
+++ b/torch/csrc/jit/mobile/module.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -263,6 +264,40 @@ c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
   return stack.front();
 }
 
+c10::optional<std::string> print_type(const c10::Type& t) {
+  auto namedType = t.cast<c10::NamedType>();
+  if (namedType && namedType->name()) {
+    return namedType->name().value().qualifiedName();
+  }
+  if (auto dyn = t.castRaw<c10::DynamicType>()) {
+    return dyn->fallback()->annotation_str();
+  }
+  return c10::nullopt;
+}
+
+TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
+  ModuleInfo minfo;
+  minfo.operator_version = module.min_operator_version();
+  minfo.bytecode_version = module.bytecode_version();
+  std::vector<std::string> type_name_list;
+  for (const auto& func_ptr : module.compilation_unit().methods()) {
+    const auto& function = *func_ptr;
+    for (int i = 0; i < function.get_code().op_names_.size(); i++) {
+      const auto& op = function.get_code().op_names_[i];
+      minfo.opname_to_num_args[mobile::operator_str(op)] =
+          function.get_code().operator_input_sizes_[i];
+    }
+    for (const c10::TypePtr& tp : function.get_code().types_) {
+      type_name_list.push_back(tp->annotation_str(print_type));
+    }
+    minfo.function_names.insert(function.qualname().qualifiedName());
+  }
+  c10::TypeParser parser(type_name_list);
+  parser.parseList();
+  minfo.type_names = parser.getContainedTypes();
+  return minfo;
+}
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h
index 4a1c14cd2d3ad4..01c76e146581be 100644
--- a/torch/csrc/jit/mobile/module.h
+++ b/torch/csrc/jit/mobile/module.h
@@ -163,6 +163,16 @@ class TORCH_API Module {
   // Extra handle for the module to delete when itself is deleted
   std::shared_ptr<char> mem_to_delete_;
 };
+
+struct TORCH_API ModuleInfo {
+  uint64_t bytecode_version;
+  uint64_t operator_version;
+  std::unordered_map<std::string, int> opname_to_num_args;
+  std::unordered_set<std::string> function_names;
+  std::unordered_set<std::string> type_names;
+};
+TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/parse_operators.cpp b/torch/csrc/jit/mobile/parse_operators.cpp
index 2c61ad03c26ee7..47acf09f106f17 100644
--- a/torch/csrc/jit/mobile/parse_operators.cpp
+++ b/torch/csrc/jit/mobile/parse_operators.cpp
@@ -5,27 +5,10 @@ namespace torch {
 namespace jit {
 namespace mobile {
 
-std::string operator_str(
-    const std::string& name,
-    const std::string& overloadname) {
-  std::string result = name;
-  if (!overloadname.empty()) {
-    result += "." + overloadname;
-  }
-  return result;
-}
-
-/**
- * Loads operators by looking them up in the Dispatcher and returns
- * the set of operator names (with overload) that are not supported
- * by the current runtime.
- */
-std::unordered_set<std::string> load_and_find_unsupported_operator_names(
+void parseOperators(
     c10::ivalue::TupleElements&& ops_list,
+    const uint64_t& module_load_options,
     mobile::Function* function) {
-  std::unordered_set<std::string> unsupported_op_names;
-  // ops_list is the list of operator names that were read in from
-  // bytecode.plk for the method that is currently being processed.
   for (auto& op : std::move(ops_list)) {
     auto op_item = std::move(*std::move(op).toTuple()).elements();
     TORCH_CHECK(
@@ -37,41 +20,13 @@ std::unordered_set<std::string> load_and_find_unsupported_operator_names(
     if (op_item.size() > 2) {
       num_args = op_item[2].toInt();
     }
-    auto op_found = function->append_operator(
+    function->append_operator(
         op_item[0].toString()->string(),
         op_item[1].toString()->string(),
         num_args);
-    if (!op_found) {
-      unsupported_op_names.emplace(operator_str(
-          op_item[0].toString()->string(), op_item[1].toString()->string()));
-    }
-  }
-  return unsupported_op_names;
-}
-
-void print_unsupported_ops_and_throw(
-    const std::unordered_set<std::string>& unsupported_ops) {
-  std::string error_message("{");
-  for (const auto& op_name : unsupported_ops) {
-    error_message += op_name + ", ";
-  }
-  error_message += "}";
-  TORCH_CHECK(
-      false,
-      "Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
-      error_message);
-}
-
-void parseOperators(
-    c10::ivalue::TupleElements&& ops_list,
-    const uint64_t& module_load_options,
-    mobile::Function* function) {
-  std::unordered_set<std::string> unsupported_op_names =
-      load_and_find_unsupported_operator_names(std::move(ops_list), function);
-  if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
-      !unsupported_op_names.empty()) {
-    print_unsupported_ops_and_throw(unsupported_op_names);
   }
+  function->initialize_operators(
+      (module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK));
 }
 
 } // namespace mobile
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 5f3ab73324e693..9d0e94542bf4a2 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -49,7 +49,12 @@
 )
 from torch.jit._async import fork, wait
 from torch.jit._decomposition_utils import _register_decomposition
-from torch.jit._serialization import save, load, jit_module_from_flatbuffer, save_jit_module_to_flatbuffer
+from torch.jit._serialization import (
+    save,
+    load,
+    jit_module_from_flatbuffer,
+    save_jit_module_to_flatbuffer,
+)
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph, set_fusion_strategy
 from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
 from torch.jit._ir_utils import _InsertPoint
diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py
index 3911cb411c5b2a..5e6517682be50a 100644
--- a/torch/jit/_serialization.py
+++ b/torch/jit/_serialization.py
@@ -184,13 +184,17 @@ def validate_map_location(map_location=None):
     return map_location
 
 
-def jit_module_from_flatbuffer(f):
+def get_ff_module():
     try:
         import torch._C_flatbuffer as ff
+        return ff
     except ImportError:
         print("Please include //caffe2:_C_flatbuffer as dependency.")
         raise
 
+
+def jit_module_from_flatbuffer(f):
+    ff = get_ff_module()
     if isinstance(f, string_classes):
         if not os.path.exists(f):  # type: ignore[type-var]
             raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
@@ -242,14 +246,40 @@ def forward(self, x):
 
         # Save to file
         torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
     """
-    try:
-        import torch._C_flatbuffer as ff
-    except ImportError:
-        print("Please include //caffe2:_C_flatbuffer as dependency.")
-        raise
+    ff = get_ff_module()
 
     if isinstance(f, str) or isinstance(f, pathlib.Path):
         f = str(f)
         ff._save_jit_module(m._c, f)
     else:
         s = ff._save_jit_module_to_bytes(m._c)
         f.write(s)
+
+
+def get_flatbuffer_module_info(path_or_file):
+    r"""Get some information regarding a model file in flatbuffer format.
+
+    Args:
+        path_or_file: Either str, Path, or a file-like object (BytesIO OK).
+            If it's a str or Path, the file referenced by that path is read
+            as bytes.
+
+    Returns:
+        A dict with metadata on what that file contains, currently looking
+        like this:
+        {
+            'bytecode_version': 4,  # int
+            'operator_version': 4,  # int
+            'function_names': {
+                '__torch__.___torch_mangle_0.Foo.forward'},  # set
+            'type_names': set(),  # set
+            'opname_to_num_args': {'aten::linear': 3}  # Dict[str, int]
+        }
+    """
+    ff = get_ff_module()
+    if isinstance(path_or_file, str) or isinstance(path_or_file, pathlib.Path):
+        with open(path_or_file, 'rb') as f:
+            all_bytes = f.read()
+    else:
+        all_bytes = path_or_file.read()
+    return ff._get_module_info_from_flatbuffer(all_bytes)
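
Usage example (not part of the diff; a minimal sketch of the Python API this
patch adds, mirroring test_module_info_flatbuffer above — the toy Foo module
is illustrative):

    import io
    import torch

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)

    m = torch.jit.script(Foo())
    buf = io.BytesIO()
    torch.jit.save_jit_module_to_flatbuffer(m, buf)
    buf.seek(0)

    # Reads metadata only; operators are not wired up (setShouldLoadOperators
    # is false on the loader), so this works even in runtimes where some of
    # the model's ops are missing.
    info = torch.jit._serialization.get_flatbuffer_module_info(buf)
    print(info['bytecode_version'], info['opname_to_num_args'])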