Remove flatbuffer types/headers from flatbuffer_loader.h (pytorch#82893)
This completely hides the flatbuffer types and headers from users of flatbuffer_loader/serializer, turning the flatbuffer types into an internal implementation detail.
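
In sketch form (hypothetical and condensed; the parameter list is abridged from the diff below, and the real header declares several more entry points), the public surface now looks like this, with no flatbuffers type in sight:

```cpp
// flatbuffer_loader.h, condensed sketch: callers hand in raw bytes, and
// every flatbuffers type stays behind this boundary in the .cpp.
#include <torch/csrc/jit/mobile/module.h>

namespace torch {
namespace jit {

// Parse a mobile::Module out of flatbuffer-serialized bytes.
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t size,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory);

} // namespace jit
} // namespace torch
```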

A follow-up diff will fix up the Buck files to hide the dependencies more thoroughly.

While doing this, I found another use of a flatbuffer-defined name (`FLATBUFFERS_MAX_ALIGNMENT`), which highlighted the issues described in T128189662.
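
The remedy, visible in flatbuffer_loader.cpp below, is the usual decoupling move: the library defines a constant it owns and statically pins it to the third-party guarantee. Condensed sketch (the value 16 is assumed here):

```cpp
// Public header: an alignment constant we own, not flatbuffers'.
constexpr size_t kFlatbufferDataAlignmentBytes = 16;

// flatbuffer_loader.cpp, the only place flatbuffers headers remain:
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
```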

Differential Revision: [D38292794](https://our.internmc.facebook.com/intern/diff/D38292794/)
Pull Request resolved: pytorch#82893
Approved by: https://github.com/qihqi
dbort authored and pytorchmergebot committed Aug 7, 2022
1 parent 0c7ca2d commit 1d56ea5
Showing 3 changed files with 150 additions and 158 deletions.
13 changes: 8 additions & 5 deletions torch/csrc/init_flatbuffer_module.cpp
@@ -21,21 +21,24 @@

namespace py = pybind11;

+using torch::jit::kFlatbufferDataAlignmentBytes;

static std::shared_ptr<char> copyStr(const std::string& bytes) {
-  size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
-      FLATBUFFERS_MAX_ALIGNMENT;
+  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
+      kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
  std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
+      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
      _aligned_free);
#elif defined(__APPLE__)
  void* p;
-  ::posix_memalign(&p, FLATBUFFERS_MAX_ALIGNMENT, size);
+  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
  TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
  std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
  std::shared_ptr<char> bytes_copy(
-      static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
+      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
+      free);
#endif
  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  return bytes_copy;
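// Worked example of the size round-up above, assuming
// kFlatbufferDataAlignmentBytes == 16 (flatbuffers' maximum alignment):
//   bytes.size() == 40  ->  size == (40 / 16 + 1) * 16 == 48
//   bytes.size() == 48  ->  size == (48 / 16 + 1) * 16 == 64
// The copy is therefore always aligned, padded to a multiple of the
// alignment, and at least one byte larger than the source.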
153 changes: 122 additions & 31 deletions torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -1,5 +1,19 @@
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>

+#ifdef FLATBUFFERS_VERSION_MAJOR
+#error "flatbuffer_loader.h must not include any flatbuffers headers"
+#endif // FLATBUFFERS_VERSION_MAJOR

+#include <array>
+#include <istream>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/ivalue.h>
Expand All @@ -12,8 +26,10 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/mobile/file_format.h>
+#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
@@ -28,35 +44,110 @@
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#endif

-#if defined(HAVE_MMAP)
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <unistd.h>
-#endif

#ifdef _WIN32
#include <malloc.h>
#else
#include <cstdlib>
#endif

-#include <string>
-#include <vector>

namespace torch {
namespace jit {

+// Our own alignment requirement does not need to be exactly the same as what
+// flatbuffers supports, but what flatbuffers supports needs to satisfy our
+// requirement.
+static_assert(
+    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
+    "Sizes must be compatible");
+static_assert(
+    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
+        kFlatbufferDataAlignmentBytes,
+    "Must be a power of 2");

+namespace {

static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";

-template <typename T, typename U>
-std::vector<T> parseListNative(const U* list) {
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
-  return {list->items()->begin(), list->items()->end()};
-}
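// (parseListNative is re-added further down in this diff, just before
// parseIntList.)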
+class FlatbufferLoader final {
+ public:
+  FlatbufferLoader();

+  typedef IValue (
+      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
+  void registerIValueParser(
+      mobile::serialization::IValueUnion ivalue_type,
+      IValueParser parser);
+  mobile::Module parseModule(mobile::serialization::Module* module);

+  void extractJitSourceAndConstants(
+      ExtraFilesMap* jit_sources,
+      std::vector<IValue>* constants);

+  typedef TypePtr (*TypeResolver)(
+      const std::string& type_str,
+      std::shared_ptr<CompilationUnit> cu);

+  void internal_registerTypeResolver(TypeResolver type_resolver);

+  IValue& getIValue(uint32_t pos) {
+    TORCH_CHECK(pos < all_ivalues_.size());
+    return all_ivalues_[pos];
+  }

+  mobile::Function* getFunction(uint32_t pos) {
+    return all_functions_[pos];
+  }

+  ClassTypePtr getType(uint32_t pos) {
+    TORCH_CHECK(pos < all_types_.size());
+    return all_types_[pos];
+  }

+  c10::Storage getStorage(uint32_t index);
+  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
+  ClassTypePtr getOrCreateClassTypeForObject(
+      const mobile::serialization::Object* object);

+  const mobile::serialization::Module* getCurrentFlatbufferInput() {
+    return module_;
+  }

+  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
+    should_copy_tensor_memory_ = should_copy_tensor_memory;
+  }

+  std::shared_ptr<mobile::CompilationUnit> mcu_;
+  std::shared_ptr<CompilationUnit> cu_;

+ private:
+  IValue parseIValue(const mobile::serialization::IValue* ivalue);
+  std::unique_ptr<mobile::Function> parseFunction(
+      const mobile::serialization::Function* method);
+  void parseAndPopulate(
+      uint32_t i,
+      const mobile::serialization::IValue* ivalue);

+  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
+  std::vector<ClassTypePtr> all_types_;
+  std::unordered_set<uint32_t> initialized_types_;
+  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
+  std::vector<bool> storage_loaded_;
+  std::vector<c10::Storage> storages_;
+  std::vector<IValue> all_ivalues_;
+  std::array<
+      IValueParser,
+      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
+      ivalue_parsers_;
+  TypeResolver type_resolver_ = nullptr;
+  mobile::serialization::Module* module_ = nullptr;
+  bool module_parsed_ = false;
+  bool should_copy_tensor_memory_ = false;
+  // IValues at indices [0, mobile_ivalue_size_) come from the mobile module.
+  uint32_t mobile_ivalue_size_ = 0;
+};
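// A minimal sketch (with assumed enum-member names) of the dispatch the
// ivalue_parsers_ jump table above enables:
//
//   FlatbufferLoader loader;
//   loader.registerIValueParser(
//       mobile::serialization::IValueUnion::List, // tag of the payload type
//       parseList); // one of the file-local parsers declared below
//
// parseIValue() can then route any serialized value through
// ivalue_parsers_[static_cast<uint8_t>(tag)] with no switch statement.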

IValue parseList(
FlatbufferLoader&,
@@ -225,15 +316,13 @@ mobile::Module FlatbufferLoader::parseModule(
return m;
}

-namespace {
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
-} // namespace

std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
const mobile::serialization::Function* method) {
@@ -266,9 +355,7 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
op->name()->str(), op->overload_name()->str(), num_args);
}

-  if (should_load_operators_) {
-    function->initialize_operators(true);
-  }
+  function->initialize_operators(true);
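  // (With the should_load_operators_ guard removed, operators are now
  // always initialized eagerly when a function is parsed.)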

for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
@@ -434,6 +521,12 @@ IValue parseList(
return res;
}

+template <typename T, typename U>
+std::vector<T> parseListNative(const U* list) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
+  return {list->items()->begin(), list->items()->end()};
+}

IValue parseIntList(
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue) {
@@ -641,6 +734,8 @@ void FlatbufferLoader::extractJitSourceAndConstants(
parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}

+} // namespace

mobile::Module parse_and_initialize_mobile_module(
void* data,
size_t,
@@ -649,6 +744,8 @@ mobile::Module parse_and_initialize_mobile_module(
bool should_copy_tensor_memory) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): If not copying, enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.
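  // A hypothetical shape for that check (not part of this commit):
  //   TORCH_CHECK(
  //       reinterpret_cast<uintptr_t>(data) % kFlatbufferDataAlignmentBytes == 0,
  //       "Flatbuffer data is not properly aligned");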

FlatbufferLoader loader;
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
@@ -687,6 +784,8 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
ExtraFilesMap* extra_files) {
TORCH_CHECK(
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
+  // TODO(T128189662): Enforce that data is aligned to
+  // kFlatbufferDataAlignmentBytes, and add unit tests.

FlatbufferLoader loader;
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
@@ -699,16 +798,6 @@ mobile::Module parse_and_initialize_mobile_module_for_jit(
return m;
}

-mobile::Module initialize_mobile_module(
-    mobile::serialization::Module* flatbuffer_module,
-    c10::optional<at::Device>,
-    bool should_copy_tensor_memory) {
-  auto flatbufferLoader = FlatbufferLoader();
-  flatbufferLoader.setShouldCopyTensorMemory(should_copy_tensor_memory);
-  mobile::Module m = flatbufferLoader.parseModule(flatbuffer_module);
-  return m;
-}

mobile::Module load_mobile_module_from_file(
const std::string& filename,
c10::optional<c10::Device> device,
@@ -786,7 +875,8 @@ mobile::Module load_mobile_module_from_stream_with_copy(
std::move(data), size, device, extra_files);
}

-static mobile::Module parse_flatbuffer_no_object(
+namespace {
+mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device) {
@@ -815,6 +905,7 @@ static mobile::Module parse_flatbuffer_no_object(
m.set_delete_memory(std::move(data));
return m;
}
+} // namespace

bool register_flatbuffer_loader() {
load_flatbuffer_bytes = parse_and_initialize_mobile_module;