forked from Xilinx/brevitas
Feat (export): extend quantized ONNX, remove PyXIR DPUv1, rename StdONNX
Signed-off-by: Alessandro Pappalardo <[email protected]>
Showing 58 changed files, with 1,013 additions and 1,078 deletions.
Two of the changed files are reproduced below; the remaining diffs are not included in this capture. First, the export entry-point module, where the DPUv1, DPUv2, and standard-ONNX wrappers and the is_ongoing_* helpers are removed in favor of @wraps-decorated entry points for the PyXIR, XIR, generic Brevitas ONNX, standard QOp ONNX, and PyTorch backends:
@@ -1,52 +1,39 @@
-from brevitas import config
+from functools import wraps

 from .onnx.finn.manager import FINNManager
-from .onnx.standard.manager import StdONNXManager
-from .onnx.vitis_ai.pyxir.dpuv1.manager import DPUv1Manager
-from .onnx.vitis_ai.pyxir.dpuv2.manager import DPUv2Manager
+from .onnx.generic.manager import BrevitasONNXManager
+from .onnx.standard.qoperator.manager import StdQOpONNXManager
 from .onnx.vitis_ai.pyxir.manager import PyXIRManager
+from .onnx.vitis_ai.xir.manager import XIRManager
 from .onnx.debug import enable_debug
 from .pytorch.manager import PytorchQuantManager


+@wraps(FINNManager.export)
 def export_finn_onnx(*args, **kwargs):
     return FINNManager.export(*args, **kwargs)


-def export_dpuv1_onnx(*args, **kwargs):
-    return DPUv1Manager.export(*args, **kwargs)
-
-
-def export_dpuv2_onnx(*args, **kwargs):
-    return DPUv2Manager.export(*args, **kwargs)
-
-
-def export_standard_onnx(*args, **kwargs):
-    return StdONNXManager.export(*args, **kwargs)
-
-
-def jit_trace_dpuv1(*args, **kwargs):
-    return DPUv1Manager.jit_inference_trace(*args, **kwargs)
-
-
-def is_ongoing_export():
-    return config._ONGOING_EXPORT is not None
+@wraps(PyXIRManager.export)
+def export_pyxir_onnx(*args, **kwargs):
+    return PyXIRManager.export(*args, **kwargs)


-def is_ongoing_finn_export():
-    return config._ONGOING_EXPORT == FINNManager.target_name
+@wraps(XIRManager.export)
+def export_xir(*args, **kwargs):
+    return XIRManager.export(*args, **kwargs)


-def is_ongoing_stdonnx_export():
-    return config._ONGOING_EXPORT == StdONNXManager.target_name
+@wraps(BrevitasONNXManager.export)
+def export_brevitas_onnx(*args, **kwargs):
+    return BrevitasONNXManager.export(*args, **kwargs)


-def is_ongoing_pyxir_export():
-    if config._ONGOING_EXPORT is not None:
-        return PyXIRManager.target_name in config._ONGOING_EXPORT
-    else:
-        return False
+@wraps(StdQOpONNXManager.export)
+def export_standard_qop_onnx(*args, **kwargs):
+    return StdQOpONNXManager.export(*args, **kwargs)


-def is_ongoing_pytorch_export():
-    return config._ONGOING_EXPORT == PytorchQuantManager.target_name
+@wraps(PytorchQuantManager.export)
+def export_pytorch_quant(*args, **kwargs):
+    return PytorchQuantManager.export(*args, **kwargs)
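For context, a minimal usage sketch of the renamed entry points. This is my illustration, not part of the commit: the import path and the export signature (input_shape, export_path) are assumptions that may differ across Brevitas versions.

    import torch
    from brevitas.nn import QuantLinear                # quantized layer from Brevitas
    from brevitas.export import export_brevitas_onnx   # assumed module path

    model = QuantLinear(10, 4, bias=True)
    # Export signature assumed to be (module, input_shape=..., export_path=...)
    export_brevitas_onnx(model, input_shape=(1, 10), export_path='quant_linear.onnx')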
The next changed file was deleted by this commit; its contents are not shown in this capture.
The second file reproduced here is new (@@ -0,0 +1,105 @@): a handler module defining the base class and validation mixins shared by the export backends. Its contents follow, broken into sections:
from abc import ABC
import math

import torch
from torch.nn import Module
from torch import Tensor


__all__ = [
    'BaseHandler',
    'BitWidthHandlerMixin',
    'ZeroPointHandlerMixin'
]


class BaseHandler(Module, ABC):

    def attach_debug_info(self, module):
        pass

    def prepare_for_export(self, module):
        pass

    def reset(self):
        pass
class BitWidthHandlerMixin(object):

    @classmethod
    def validate_bit_width(cls, bit_width: Tensor, reference: int, le_then=False):
        if bit_width is None:
            raise RuntimeError("Bit width cannot be None")
        bit_width = int(bit_width.item())
        if bit_width > reference:
            raise RuntimeError(f"Bit width {bit_width} is not supported.")
        elif bit_width < reference and not le_then:
            raise RuntimeError(f"Bit width {bit_width} is not supported, should be {reference}b.")
        return bit_width

    @classmethod
    def validate_8b_bit_width(cls, bit_width: Tensor, le_then=False):
        return cls.validate_bit_width(bit_width, 8, le_then)

    @classmethod
    def validate_16b_bit_width(cls, bit_width: Tensor, le_then=False):
        return cls.validate_bit_width(bit_width, 16, le_then)

    @classmethod
    def validate_32b_bit_width(cls, bit_width: Tensor, le_then=False):
        return cls.validate_bit_width(bit_width, 32, le_then)
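A quick check of the bit-width validation above, for illustration only (not part of the commit, and assuming the mixins above are in scope). The le_then flag reads as "less than or equal", so lower bit widths only pass when it is set:

    import torch
    BitWidthHandlerMixin.validate_8b_bit_width(torch.tensor(8.))                # -> 8
    BitWidthHandlerMixin.validate_8b_bit_width(torch.tensor(4.), le_then=True)  # -> 4
    BitWidthHandlerMixin.validate_8b_bit_width(torch.tensor(4.))                # raises RuntimeError

The new file continues with the scale mixin: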
class ScaleHandlerMixin(object):

    @classmethod
    def validate_scalar_scale(cls, scale: Tensor):
        if scale is None:
            raise RuntimeError("Scale cannot be None.")
        if scale.view(-1).shape[0] != 1:
            raise RuntimeError("Only per-tensor scaling is supported.")
        return scale.item()

    @classmethod
    def validate_scalar_int_exponent(cls, scale: Tensor):
        cls.validate_scalar_scale(scale)
        exponent = math.log2(scale)
        if not exponent.is_integer():
            raise RuntimeError("Only power-of-two scale factors are supported.")
        exponent = int(exponent)
        return exponent

    @classmethod
    def validate_neg_scalar_int_exponent(cls, scale: Tensor):
        return - cls.validate_scalar_int_exponent(scale)
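Since a power-of-two scale factor satisfies scale = 2**exponent, a scale of 2**-7 maps to exponent -7, and the negated variant returns 7, as fixed-point formats typically expect. A worked check, again my illustration rather than part of the commit:

    import torch
    ScaleHandlerMixin.validate_scalar_int_exponent(torch.tensor(2. ** -7))      # -> -7
    ScaleHandlerMixin.validate_neg_scalar_int_exponent(torch.tensor(2. ** -7))  # -> 7

The file closes with the zero-point mixin: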
class ZeroPointHandlerMixin(object):

    @classmethod
    def zero_point_with_dtype(cls, signed, zero_point):
        if not signed:
            if (zero_point < 0).any():
                raise RuntimeError("Zero points have to be positive under unsigned quantization")
            return zero_point.type(torch.uint8)
        else:
            return zero_point.type(torch.int8)

    @classmethod
    def quant_input_zero_point(cls, module):
        signed = module.is_quant_input_signed
        zero_point = module.quant_input_zero_point()
        return cls.zero_point_with_dtype(signed, zero_point)

    @classmethod
    def quant_weight_zero_point(cls, module):
        signed = module.is_quant_weight_signed
        zero_point = module.quant_weight_zero_point()
        return cls.zero_point_with_dtype(signed, zero_point)

    @classmethod
    def quant_output_zero_point(cls, module):
        signed = module.is_quant_output_signed
        zero_point = module.quant_output_zero_point()
        return cls.zero_point_with_dtype(signed, zero_point)
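Finally, an illustration of the zero-point dtype mapping (mine, not part of the commit): unsigned quantization casts to uint8 after rejecting negative entries, while signed quantization casts to int8:

    import torch
    ZeroPointHandlerMixin.zero_point_with_dtype(False, torch.tensor([128]))  # -> tensor([128], dtype=torch.uint8)
    ZeroPointHandlerMixin.zero_point_with_dtype(True, torch.tensor([0]))     # -> tensor([0], dtype=torch.int8)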