forked from pytorch/pytorch
Commit
[Pytorch][Ondevice quantization] Add device side API to convert model (pytorch#83807)

Summary:
This diff adds a device-side API that converts a model to its quantized equivalent. The input model must have been prepared AOT for quantization. The API is implemented by:
- running the reset observers method,
- running the observe method,
- running the quantize method, and
- replacing the original method, e.g. forward, with its quantized equivalent.

Test Plan: test/quantization/jit/test_ondevice_quantization.py

Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818)

Pull Request resolved: pytorch#83807
Approved by: https://github.com/iseeyuan
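For context, here is a minimal usage sketch of the new API. The model file name "model_prepared.ptl" and the main() wrapper are assumptions for illustration, not part of this commit; the module is assumed to have been prepared AOT and augmented with bundled inputs.

    #include <torch/csrc/jit/mobile/import.h>
    #include <torch/csrc/jit/mobile/module.h>
    #include <torch/csrc/jit/mobile/quantization.h>

    int main() {
      // Load a mobile module that was prepared AOT for on-device PTQ.
      torch::jit::mobile::Module m =
          torch::jit::_load_for_mobile("model_prepared.ptl");

      // Convert in place: resets observers, observes the bundled inputs,
      // quantizes, and swaps "forward" for its quantized equivalent.
      torch::jit::mobile::quantization::PTQQuanizationHelper helper;
      helper.quantize_dynamic(m, "forward");

      // "forward" now dispatches to the quantized implementation.
      return 0;
    }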
1 parent eebdcb5, commit cfd18e1

Showing 10 changed files with 256 additions and 43 deletions.
@@ -0,0 +1,66 @@
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch {
namespace jit {
namespace mobile {
namespace quantization {
void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
  std::string reset_observers_method_name = "reset_observers_" + method_name;
  std::string observe_method_name = "observe_" + method_name;
  std::string quantize_method_name = "quantize_" + method_name;
  std::string quantized_method_name = "quantized_" + method_name;

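  // Fail fast if the module was not prepared AOT for on-device PTQ:
  // each helper method checked below is generated at preparation time.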
  TORCH_CHECK(
      m.find_method(reset_observers_method_name).has_value(),
      "PTQ ready module must have ",
      reset_observers_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(observe_method_name).has_value(),
      "PTQ ready module must have ",
      observe_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantize_method_name).has_value(),
      "PTQ ready module must have ",
      quantize_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantized_method_name).has_value(),
      "PTQ ready module must have ",
      quantized_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs").has_value(),
      "PTQ ready module must have get_all_bundled_inputs method.");

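  // Calibrate with the first bundled input set: reset the observers,
  // observe the inputs to collect statistics, then quantize.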
  auto inputs = m.run_method("get_all_bundled_inputs")
                    .toList()
                    .get(0)
                    .toTupleRef()
                    .elements()
                    .vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

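  // Replace the original method with its quantized equivalent (after
  // checking the schemas match), then strip the helper methods.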
  m.compareMethodSchemas(method_name, quantized_method_name);
  m.unsafeRemoveMethod(method_name);
  const Function& to_be_copied =
      m.find_method(quantized_method_name).value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}
} // namespace quantization
} // namespace mobile
} // namespace jit
} // namespace torch
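A note on the swap sequence above: compareMethodSchemas guards the replacement, so the original method is only exchanged for the quantized body when the two signatures agree, which keeps existing call sites working unchanged. Removing the reset_observers_, observe_, and quantize_ helpers afterwards avoids shipping calibration-only code in the converted model.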