[Pytorch][Ondevice quantization] Add device side API to convert model (pytorch#83807)

Summary:
This diff adds a device-side API that converts a model to its quantized
equivalent. The input model must have been prepared AOT for quantization.

The API works by:
- Running the reset observers method
- Running the observe method
- Running the quantize method
- Replacing the original method, e.g. forward, with its quantized equivalent
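
As a sketch of the intended end-to-end flow (assembled from the test in this
diff; `OnDevicePTQUtils.ptq_dynamic_quantize` is the test helper used there,
and `model`/`qconfig_dict` are assumed to be a scriptable module with
`get_example_inputs()` and a dynamic-quantization qconfig mapping):

```python
import io

import torch
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils import bundled_inputs

# AOT (server-side): rewrite the scripted model so it carries the
# reset_observers_/observe_/quantize_/quantized_ variants of forward.
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)

# Bundle a representative input; the device-side pass calibrates with it.
(example,) = model.get_example_inputs()
rand_input = bundled_inputs.bundle_randn(example.size(), dtype=example.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)

# Device side: load for the lite interpreter and convert in place.
m = _load_for_lite_interpreter(buffer)
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
output = m(example)  # forward now dispatches to the quantized method
```

Everything above `_load_for_lite_interpreter` runs AOT; the last three lines
are what run on device.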

Test Plan:
test/quantization/jit/test_ondevice_quantization.py

Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818)
Pull Request resolved: pytorch#83807
Approved by: https://github.com/iseeyuan
kimishpatel authored and pytorchmergebot committed Aug 29, 2022
1 parent eebdcb5 commit cfd18e1
Showing 10 changed files with 256 additions and 43 deletions.
1 change: 1 addition & 0 deletions buckbuild.bzl
@@ -1417,6 +1417,7 @@ def define_buck_targets(
"torch/csrc/autograd/VariableTypeManual.cpp",
"torch/csrc/autograd/FunctionsManual.cpp",
"torch/csrc/api/src/data/datasets/mnist.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",
2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -564,6 +564,7 @@ torch_mobile_core = [
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/parse_operators.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
"torch/csrc/jit/runtime/register_prim_ops.cpp",
"torch/csrc/jit/runtime/register_special_ops.cpp",
@@ -612,6 +613,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/jit/mobile/observer.cpp",
"torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/parse_operators.cpp",
"torch/csrc/jit/mobile/quantization.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
121 changes: 78 additions & 43 deletions test/quantization/jit/test_ondevice_quantization.py
@@ -2,6 +2,7 @@
# Owner(s): ["oncall: quantization"]

import torch
import torch._C_flatbuffer

from torch.ao.quantization import (
default_dynamic_qconfig,
@@ -22,11 +23,13 @@
LinearAddModel,
)

from torch.jit.mobile import _load_for_lite_interpreter
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule

from torch.testing import FileCheck
from torch.utils import bundled_inputs as bundled_inputs

import io
from typing import Dict

class myMod(torch.nn.Module):
def __init__(self, weight):
@@ -396,7 +399,7 @@ def _check_against_ref_dynamic_ptq(self, model):
self.assertTrue(thrown)


def _check_serialization_deserialization(self, model):
def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
model.eval()
inputs = model.get_example_inputs()
ref_m = torch.jit.script(model)
@@ -410,27 +413,40 @@ def _check_serialization_deserialization(self, model):
ref_m = torch.jit.load(buffer)
ref_output = ref_m(*inputs)

m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))

# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer)
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
self.assertTrue(torch.allclose(ref_output, output))
if not check_device_side_api:
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
else:
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
first_input, = inputs
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer)
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
self.assertFalse(m.find_method("quantized_forward"))
self.assertFalse(m.find_method("quantize_forward"))
self.assertFalse(m.find_method("observe_forward"))
self.assertFalse(m.find_method("reset_observers_forward"))
output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, output))

# Now serialize to flatbuffer, load it back, and check
extra_files: Dict[str, str] = {}
module_bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, extra_files)
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(module_bytes))
fb_output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, fb_output))

model.eval()
inputs = model.get_example_inputs()
@@ -445,27 +461,41 @@ def _check_serialization_deserialization(self, model):
ref_m = torch.jit.load(buffer)
ref_output = ref_m(*inputs)

m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
if not check_device_side_api:
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
m.reset_observers_forward()
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
output = m.quantized_forward(*inputs)
self.assertTrue(torch.allclose(ref_output, output))
else:
# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
first_input, = inputs
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer)
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
self.assertFalse(m.find_method("quantized_forward"))
self.assertFalse(m.find_method("quantize_forward"))
self.assertFalse(m.find_method("observe_forward"))
self.assertFalse(m.find_method("reset_observers_forward"))
output = m(*inputs)
self.assertTrue(torch.allclose(ref_output, output))

# check for lite interpreter
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer)
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
self.assertTrue(torch.allclose(ref_output, output))

def _check_serialization_deserialization(self, model):
self._check_serdes_and_device_side_api_helper(model, False)


def _check_device_side_api(self, model):
self._check_serdes_and_device_side_api_helper(model, True)


def test_quantize_forward(self):
@@ -492,3 +522,8 @@ def test_against_offdevice_dynamic_ptq(self):
def test_serialization_deserialization(self):
model = MyConvLinearModule()
self._check_serialization_deserialization(model)


def test_device_side_api(self):
model = MyConvLinearModule()
self._check_device_side_api(model)
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -298,6 +298,7 @@ def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...
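
The new `_quantize_ondevice_ptq_dynamic` binding mutates the loaded module in
place. A minimal sketch of its contract, assuming `m` was loaded with
`_load_for_lite_interpreter` from a PTQ-prepared model with bundled inputs
(mirroring the test above):

```python
# After the call, forward is the quantized method and the four generated
# PTQ helper methods have been removed from the module.
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
for name in ("reset_observers_forward", "observe_forward",
             "quantize_forward", "quantized_forward"):
    assert not m.find_method(name)
output = m(*inputs)  # runs the quantized forward
```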
44 changes: 44 additions & 0 deletions torch/csrc/jit/mobile/module.cpp
@@ -43,6 +43,50 @@ Method Module::get_method(const std::string& name) const {
AT_ERROR("Method '", name, "' is not defined.");
}

bool Module::compareMethodSchemas(
const std::string& name_1,
const std::string& name_2) {
c10::optional<c10::FunctionSchema> schema_1, schema_2;
for (const auto& fn : cu_->methods()) {
if (fn->name() == name_1) {
schema_1 = fn->getSchema();
}
if (fn->name() == name_2) {
schema_2 = fn->getSchema();
}
}
if (schema_1.has_value() && schema_2.has_value()) {
return (schema_1 == schema_2);
}
return false;
}

void Module::unsafeRemoveMethod(const std::string& basename) {
int64_t i = 0;
for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
if ((cu_->methods()[i])->name() == basename) {
break;
}
}
TORCH_CHECK(
i < static_cast<int64_t>(cu_->methods().size()),
"Method ",
basename,
" not found.");
object_->type()->unsafeRemoveMethod(basename);
cu_->unsafeRemoveFunction(i);
}

void Module::unsafeCopyMethod(
const std::string& new_method_name,
const Function& to_be_copied) {
TORCH_CHECK(
!find_method(new_method_name).has_value(),
"Trying to replace existing method.");
const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
c10::QualifiedName qualified_method_name(
tobe_copied_name.prefix(), new_method_name);
std::unique_ptr<Function> new_fn = std::make_unique<Function>(
qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
object_->type()->addMethod(new_fn.get());
cu_->register_function(std::move(new_fn));
}

c10::optional<Method> Module::find_method(const std::string& basename) const {
for (const auto& fn : cu_->methods()) {
if (fn->name() == basename) {
18 changes: 18 additions & 0 deletions torch/csrc/jit/mobile/module.h
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch {
namespace jit {
@@ -42,6 +43,10 @@ class CompilationUnit {
Function* find_function(const c10::QualifiedName& qn);
const Function* find_function(const c10::QualifiedName& qn) const;

void unsafeRemoveFunction(const int64_t index) {
methods_.erase(methods_.begin() + index);
}

private:
std::vector<std::unique_ptr<Function>> methods_;
};
@@ -71,6 +76,7 @@ class TORCH_API Module {
return get_method("forward")(std::move(inputs));
}
c10::optional<Method> find_method(const std::string& basename) const;

const std::string name() const {
return object_->name();
}
@@ -152,6 +158,18 @@ class TORCH_API Module {
}

private:
friend class quantization::PTQQuanizationHelper;

bool compareMethodSchemas(
const std::string& name_1,
const std::string& name_2);

void unsafeRemoveMethod(const std::string& basename);

void unsafeCopyMethod(
const std::string& new_method_name,
const Function& to_be_copied);

c10::intrusive_ptr<c10::ivalue::Object> object_;
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
66 changes: 66 additions & 0 deletions torch/csrc/jit/mobile/quantization.cpp
@@ -0,0 +1,66 @@
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch {
namespace jit {
namespace mobile {
namespace quantization {

void PTQQuanizationHelper::quantize_dynamic(
torch::jit::mobile::Module& m,
const std::string& method_name) {
at::globalContext().setReleaseWeightsWhenPrepacking(false);
std::string reset_observers_method_name = "reset_observers_" + method_name;
std::string observe_method_name = "observe_" + method_name;
std::string quantize_method_name = "quantize_" + method_name;
std::string quantized_method_name = "quantized_" + method_name;

TORCH_CHECK(
m.find_method(reset_observers_method_name).has_value(),
"PTQ ready module must have ",
reset_observers_method_name,
" method.");
TORCH_CHECK(
m.find_method(observe_method_name),
"PTQ ready module must have ",
observe_method_name,
" method.");
TORCH_CHECK(
m.find_method(quantize_method_name),
"PTQ ready module must have ",
quantize_method_name,
" method.");
TORCH_CHECK(
m.find_method(quantized_method_name),
"PTQ ready module must have ",
quantized_method_name,
" method.");
TORCH_CHECK(
m.find_method("get_all_bundled_inputs"),
"PTQ ready module must have get_all_bundled_inputs method.");

auto inputs = m.run_method("get_all_bundled_inputs")
.toList()
.get(0)
.toTupleRef()
.elements()
.vec();
m.get_method(reset_observers_method_name)({});
m.get_method(observe_method_name)(inputs);
m.get_method(quantize_method_name)(inputs);

TORCH_CHECK(
m.compareMethodSchemas(method_name, quantized_method_name),
"Quantized method schema must match original method schema.");
m.unsafeRemoveMethod(method_name);
const Function& to_be_copied =
m.find_method(quantized_method_name).value().function();
m.unsafeCopyMethod(method_name, to_be_copied);
m.unsafeRemoveMethod(quantized_method_name);
m.unsafeRemoveMethod(quantize_method_name);
m.unsafeRemoveMethod(observe_method_name);
m.unsafeRemoveMethod(reset_observers_method_name);
}
} // namespace quantization
} // namespace mobile
} // namespace jit
} // namespace torch
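
The helper relies on a fixed naming convention derived from the target method
name. As a hypothetical Python illustration of the TORCH_CHECKs above (the
`is_ptq_ready` helper does not exist in the codebase), a module is convertible
for a given `method_name` when:

```python
def is_ptq_ready(module, method_name="forward"):
    # Hypothetical check mirroring PTQQuanizationHelper::quantize_dynamic:
    # the four generated methods plus bundled inputs must all be present.
    required = [
        f"reset_observers_{method_name}",
        f"observe_{method_name}",
        f"quantize_{method_name}",
        f"quantized_{method_name}",
        "get_all_bundled_inputs",
    ]
    return all(module.find_method(name) for name in required)
```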
