@@ -3228,7 +3228,7 @@ Z
- zero_point (brevitas.quant_tensor.QuantTensor attribute)
-- zero_point_with_dtype() (brevitas.export.common.handler.TypedZeroPointHandler class method)
+- zero_point_with_dtype() (brevitas.export.handler.ZeroPointHandlerMixin class method)
- ZeroZeroPoint (class in brevitas.core.zero_point)
diff --git a/src/brevitas/core/bit_width/const.py b/src/brevitas/core/bit_width/const.py
index 3d5f0f2dc..7f2aa1787 100644
--- a/src/brevitas/core/bit_width/const.py
+++ b/src/brevitas/core/bit_width/const.py
@@ -74,7 +74,7 @@ class BitWidthConst(brevitas.jit.ScriptModule):
def __init__(self, bit_width: int) -> None:
super(BitWidthConst, self).__init__()
assert isinstance(bit_width, int)
- self.bit_width = StatelessBuffer(torch.tensor(float(int(bit_width))))
+ self.bit_width = StatelessBuffer(torch.tensor(float(bit_width)))
@brevitas.jit.script_method
def forward(self) -> Tensor:
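The `float(int(bit_width))` double cast was redundant: the assert directly above already guarantees `bit_width` is an `int`. A minimal standalone sketch of the simplified pattern; the `make_bit_width_buffer` helper is hypothetical, for illustration only:

```python
import torch

def make_bit_width_buffer(bit_width: int) -> torch.Tensor:
    # hypothetical helper mirroring BitWidthConst.__init__:
    # the isinstance check already makes the extra int() cast a no-op
    assert isinstance(bit_width, int)
    return torch.tensor(float(bit_width))

print(make_bit_width_buffer(8))  # tensor(8.)
```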
diff --git a/src/brevitas/export/__init__.py b/src/brevitas/export/__init__.py
index 876219779..5ec601802 100644
--- a/src/brevitas/export/__init__.py
+++ b/src/brevitas/export/__init__.py
@@ -1,52 +1,39 @@
-from brevitas import config
+from functools import wraps
+
from .onnx.finn.manager import FINNManager
-from .onnx.standard.manager import StdONNXManager
-from .onnx.vitis_ai.pyxir.dpuv1.manager import DPUv1Manager
-from .onnx.vitis_ai.pyxir.dpuv2.manager import DPUv2Manager
+from .onnx.generic.manager import BrevitasONNXManager
+from .onnx.standard.qoperator.manager import StdQOpONNXManager
from .onnx.vitis_ai.pyxir.manager import PyXIRManager
from .onnx.vitis_ai.xir.manager import XIRManager
from .onnx.debug import enable_debug
from .pytorch.manager import PytorchQuantManager
+@wraps(FINNManager.export)
def export_finn_onnx(*args, **kwargs):
return FINNManager.export(*args, **kwargs)
-def export_dpuv1_onnx(*args, **kwargs):
- return DPUv1Manager.export(*args, **kwargs)
-
-
-def export_dpuv2_onnx(*args, **kwargs):
- return DPUv2Manager.export(*args, **kwargs)
-
-
-def export_standard_onnx(*args, **kwargs):
- return StdONNXManager.export(*args, **kwargs)
-
-
-def jit_trace_dpuv1(*args, **kwargs):
- return DPUv1Manager.jit_inference_trace(*args, **kwargs)
-
-
-def is_ongoing_export():
- return config._ONGOING_EXPORT is not None
+@wraps(PyXIRManager.export)
+def export_pyxir_onnx(*args, **kwargs):
+ return PyXIRManager.export(*args, **kwargs)
-def is_ongoing_finn_export():
- return config._ONGOING_EXPORT == FINNManager.target_name
+@wraps(XIRManager.export)
+def export_xir(*args, **kwargs):
+ return XIRManager.export(*args, **kwargs)
-def is_ongoing_stdonnx_export():
- return config._ONGOING_EXPORT == StdONNXManager.target_name
+@wraps(BrevitasONNXManager.export)
+def export_brevitas_onnx(*args, **kwargs):
+ return BrevitasONNXManager.export(*args, **kwargs)
-def is_ongoing_pyxir_export():
- if config._ONGOING_EXPORT is not None:
- return PyXIRManager.target_name in config._ONGOING_EXPORT
- else:
- return False
+@wraps(StdQOpONNXManager.export)
+def export_standard_qop_onnx(*args, **kwargs):
+ return StdQOpONNXManager.export(*args, **kwargs)
-def is_ongoing_pytorch_export():
- return config._ONGOING_EXPORT == PytorchQuantManager.target_name
\ No newline at end of file
+@wraps(PytorchQuantManager.export)
+def export_pytorch_quant(*args, **kwargs):
+ return PytorchQuantManager.export(*args, **kwargs)
\ No newline at end of file
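The surviving entry points are now thin wrappers decorated with `functools.wraps`, so each one exposes the signature and docstring of the underlying manager's `export` classmethod. A hypothetical usage sketch (the layer, shapes, and path are placeholders, and `onnx`/`onnxoptimizer` must be installed):

```python
import torch
from brevitas.nn import QuantLinear
from brevitas.export import export_brevitas_onnx

model = QuantLinear(16, 8, bias=True)  # placeholder quantized layer
# export_brevitas_onnx forwards to BrevitasONNXManager.export
onnx_model = export_brevitas_onnx(
    model, input_shape=(1, 16), export_path='quant_linear.onnx')
```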
diff --git a/src/brevitas/export/common/handler.py b/src/brevitas/export/common/handler.py
deleted file mode 100644
index c0e48b9a7..000000000
--- a/src/brevitas/export/common/handler.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch
-from torch import Tensor
-
-from ..base import BaseHandler
-
-
-class Validate8BitHandler(BaseHandler):
-
- @classmethod
- def validate_8b_bit_width(cls, bit_width: Tensor):
- if bit_width is None:
- raise RuntimeError("Bit width cannot be None")
- bit_width = int(bit_width.item())
- if bit_width != 8:
- raise RuntimeError("Only 8b bit width supported")
- return bit_width
-
-
-class TypedZeroPointHandler(BaseHandler):
-
- @classmethod
- def zero_point_with_dtype(cls, signed, zero_point):
- if not signed:
- if (zero_point < 0).any():
- raise RuntimeError("Zero points have to be positive under unsigned quantization")
- return zero_point.type(torch.uint8)
- else:
- return zero_point.type(torch.int8)
-
- @classmethod
- def quant_input_zero_point(cls, module):
- signed = module.is_quant_input_signed
- zero_point = module.quant_input_zero_point()
- return cls.zero_point_with_dtype(signed, zero_point)
-
- @classmethod
- def quant_weight_zero_point(cls, module):
- signed = module.is_quant_weight_signed
- zero_point = module.quant_weight_zero_point()
- return cls.zero_point_with_dtype(signed, zero_point)
-
- @classmethod
- def quant_output_zero_point(cls, module):
- signed = module.is_quant_output_signed
- zero_point = module.quant_output_zero_point()
- return cls.zero_point_with_dtype(signed, zero_point)
diff --git a/src/brevitas/export/handler.py b/src/brevitas/export/handler.py
new file mode 100644
index 000000000..71c080c81
--- /dev/null
+++ b/src/brevitas/export/handler.py
@@ -0,0 +1,105 @@
+from abc import ABC
+import math
+
+import torch
+from torch.nn import Module
+from torch import Tensor
+
+
+__all__ = [
+ 'BaseHandler',
+ 'BitWidthHandlerMixin',
+ 'ZeroPointHandlerMixin'
+]
+
+
+class BaseHandler(Module, ABC):
+
+ def attach_debug_info(self, module):
+ pass
+
+ def prepare_for_export(self, module):
+ pass
+
+ def reset(self):
+ pass
+
+
+class BitWidthHandlerMixin(object):
+
+ @classmethod
+ def validate_bit_width(cls, bit_width: Tensor, reference: int, le_then=False):
+ if bit_width is None:
+ raise RuntimeError("Bit width cannot be None")
+ bit_width = int(bit_width.item())
+ if bit_width > reference:
+ raise RuntimeError(f"Bit width {bit_width} is not supported.")
+ elif bit_width < reference and not le_then:
+ raise RuntimeError(f"Bit width {bit_width} is not supported, should be {reference}b.")
+ return bit_width
+
+ @classmethod
+ def validate_8b_bit_width(cls, bit_width: Tensor, le_then=False):
+ return cls.validate_bit_width(bit_width, 8, le_then)
+
+ @classmethod
+ def validate_16b_bit_width(cls, bit_width: Tensor, le_then=False):
+ return cls.validate_bit_width(bit_width, 16, le_then)
+
+ @classmethod
+ def validate_32b_bit_width(cls, bit_width: Tensor, le_then=False):
+ return cls.validate_bit_width(bit_width, 32, le_then)
+
+
+class ScaleHandlerMixin(object):
+
+ @classmethod
+ def validate_scalar_scale(cls, scale: Tensor):
+ if scale is None:
+ raise RuntimeError("Scale cannot be None.")
+ if scale.view(-1).shape[0] != 1:
+ raise RuntimeError("Only per-tensor scaling is supported.")
+ return scale.item()
+
+ @classmethod
+ def validate_scalar_int_exponent(cls, scale: Tensor):
+ cls.validate_scalar_scale(scale)
+ exponent = math.log2(scale)
+ if not exponent.is_integer():
+ raise RuntimeError("Only power-of-two scale factors are supported.")
+ exponent = int(exponent)
+ return exponent
+
+ @classmethod
+ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
+ return - cls.validate_scalar_int_exponent(scale)
+
+
+class ZeroPointHandlerMixin(object):
+
+ @classmethod
+ def zero_point_with_dtype(cls, signed, zero_point):
+ if not signed:
+ if (zero_point < 0).any():
+ raise RuntimeError("Zero points have to be positive under unsigned quantization")
+ return zero_point.type(torch.uint8)
+ else:
+ return zero_point.type(torch.int8)
+
+ @classmethod
+ def quant_input_zero_point(cls, module):
+ signed = module.is_quant_input_signed
+ zero_point = module.quant_input_zero_point()
+ return cls.zero_point_with_dtype(signed, zero_point)
+
+ @classmethod
+ def quant_weight_zero_point(cls, module):
+ signed = module.is_quant_weight_signed
+ zero_point = module.quant_weight_zero_point()
+ return cls.zero_point_with_dtype(signed, zero_point)
+
+ @classmethod
+ def quant_output_zero_point(cls, module):
+ signed = module.is_quant_output_signed
+ zero_point = module.quant_output_zero_point()
+ return cls.zero_point_with_dtype(signed, zero_point)
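A quick behavioral sketch of the new mixins, assuming the module imports as below. Note that `ScaleHandlerMixin` is defined here but not listed in `__all__`, so it has to be imported by name:

```python
import torch
from brevitas.export.handler import BitWidthHandlerMixin, ScaleHandlerMixin

# exact match required by default; le_then=True also accepts smaller widths
BitWidthHandlerMixin.validate_8b_bit_width(torch.tensor(8.))                 # -> 8
BitWidthHandlerMixin.validate_8b_bit_width(torch.tensor(4.), le_then=True)   # -> 4

# power-of-two scales validate to their (optionally negated) integer exponent
ScaleHandlerMixin.validate_scalar_int_exponent(torch.tensor(0.25))       # -> -2
ScaleHandlerMixin.validate_neg_scalar_int_exponent(torch.tensor(0.25))   # -> 2
```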
diff --git a/src/brevitas/export/base.py b/src/brevitas/export/manager.py
similarity index 98%
rename from src/brevitas/export/base.py
rename to src/brevitas/export/manager.py
index 66b9c8791..4fa381e4e 100644
--- a/src/brevitas/export/base.py
+++ b/src/brevitas/export/manager.py
@@ -139,18 +139,6 @@ def _restore_requires_grad(m: Module, previous_state):
p.requires_grad_(previous_state[n])
-class BaseHandler(Module, ABC):
-
- def attach_debug_info(self, module):
- pass
-
- def prepare_for_export(self, module):
- pass
-
- def reset(self):
- pass
-
-
class BaseManager(ABC):
target_name = None
diff --git a/src/brevitas/export/onnx/base.py b/src/brevitas/export/onnx/base.py
deleted file mode 100644
index cfc296711..000000000
--- a/src/brevitas/export/onnx/base.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from typing import Tuple, Union, Optional
-from abc import ABC
-from packaging import version
-from contextlib import ExitStack
-from io import BytesIO
-
-try:
- import onnx
- import onnxoptimizer as opt
-except ModuleNotFoundError:
- onnx = None
- opt = None
-
-import torch
-import torch.onnx
-from torch import Tensor
-from torch.nn import Module
-
-from brevitas import torch_version
-from brevitas.quant_tensor import QuantTensor
-from ..base import BaseManager, _set_layer_export_mode, ExportContext
-from ..base import _override_inp_caching_mode, _restore_inp_caching_mode
-
-
-class ONNXBaseManager(BaseManager, ABC):
-
- model_transforms = []
- onnx_passes = []
-
- @classmethod
- def apply_model_transforms(cls, model):
- for tranform in cls.model_transforms:
- model = tranform(model)
- return model
-
- @classmethod
- def solve_keep_initializers_as_inputs(cls, export_kwargs):
- # See https://github.com/pytorch/pytorch/commit/7583519b870e33ee3182f330c1bb8663559697b6
- ka = 'keep_initializers_as_inputs'
- if torch_version >= version.parse('1.3.0') and ka not in export_kwargs:
- export_kwargs[ka] = True
-
- @classmethod
- def solve_enable_onnx_checker(cls, export_kwargs):
- ka = 'enable_onnx_checker'
- if torch_version >= version.parse('1.5.0') and ka not in export_kwargs:
- export_kwargs[ka] = False
-
- @classmethod
- def export(
- cls,
- module: Module,
- input_shape: Optional[Tuple[int, ...]] = None,
- export_path: Optional[str] = None,
- input_t: Optional[Union[Tensor, QuantTensor]] = None,
- **kwargs):
- return cls.export_onnx(module, input_shape, export_path, input_t, **kwargs)
-
- @classmethod
- def export_onnx(
- cls,
- module: Module,
- input_shape: Optional[Tuple[int, ...]] = None,
- export_path: Optional[str] = None,
- input_t: Optional[Union[Tensor, QuantTensor]] = None,
- **kwargs):
- """
- * input_shape : tuple describing the shape of network input e.g. (1, 1, 28, 28)
- * export_path : ONNX filename to export to
- * input_t : if specified, do an initial forward pass with this value. this
- may be necessary for QuantTensor caching.
- * torch_onnx_kwargs : will be passed as kwargs to torch.onnx.export
- """
-
- if onnx is None or opt is None:
- raise ModuleNotFoundError("Installation of onnx and onnxoptimizer is required.")
- if input_shape is None and input_t is None:
- raise RuntimeError("Export requires to pass in either input_shape or input_t")
- if input_shape is not None and input_t is not None:
- raise RuntimeError("Export accepts either an input shape or an input tensor, not both")
-
- cls.solve_keep_initializers_as_inputs(kwargs)
- cls.solve_enable_onnx_checker(kwargs)
-
- with torch.no_grad():
- with ExportContext(cls):
- training_state = module.training
- module = module.eval()
- module.apply(cls.set_export_handler)
- if input_t is None:
- input_t = torch.empty(input_shape, dtype=torch.float)
- # do a forward pass with the dummy input to e.g. store input/output shapes
- cls._cache_inp_out(module, input_t)
- # Dequantize QuantTensor, if any
- if isinstance(input_t, QuantTensor):
- input_t = input_t.value
- # enable export mode, this triggers collecting export values into handlers
- module.apply(lambda m: cls.set_export_mode(m, enabled=True))
- # temporarily disable input caching to avoid collectives empty debug values
- module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
- # perform export pass
- with ExitStack() as stack:
- for mgr in cls._trace_patches():
- stack.enter_context(mgr)
- if export_path is not None:
- torch.onnx.export(module, input_t, export_path, **kwargs)
- else:
- model_bytes = BytesIO()
- torch.onnx.export(module, input_t, model_bytes, **kwargs)
- # restore the model to previous properties
- module.apply(lambda m: _restore_inp_caching_mode(m))
- module.apply(lambda m: cls.set_export_mode(m, enabled=False))
- module.train(training_state)
- # do some cleanup on the exported ONNX model
- if export_path is not None:
- model = onnx.load(export_path)
- else:
- model = onnx.ModelProto.FromString(model_bytes.getvalue())
- model = opt.optimize(model, cls.onnx_passes)
- model = cls.apply_model_transforms(model)
- if export_path is not None:
- onnx.save(model, export_path)
- return model
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/finn/function/__init__.py b/src/brevitas/export/onnx/finn/function/__init__.py
index e69de29bb..937bb278a 100644
--- a/src/brevitas/export/onnx/finn/function/__init__.py
+++ b/src/brevitas/export/onnx/finn/function/__init__.py
@@ -0,0 +1 @@
+DOMAIN_STRING = "finn.custom_op.general"
diff --git a/src/brevitas/export/onnx/finn/function/acc.py b/src/brevitas/export/onnx/finn/function/acc.py
index 9b7a52ad0..d244d654f 100644
--- a/src/brevitas/export/onnx/finn/function/acc.py
+++ b/src/brevitas/export/onnx/finn/function/acc.py
@@ -1,8 +1,10 @@
import torch
from torch.autograd import Function
+from . import DOMAIN_STRING
-class QuantAvgPool2dPlaceholderFunction(Function):
+
+class QuantAvgPool2dFn(Function):
@staticmethod
def symbolic(g, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
@@ -10,7 +12,7 @@ def symbolic(g, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_t
x = g.op('Div', x, scale, activation_qnt_s=qnt_type)
ret = g.op(
'QuantAvgPool2d', x,
- domain_s="finn.custom_op.general",
+ domain_s=DOMAIN_STRING,
kernel_i=kernel,
stride_i=stride,
signed_i=signed,
diff --git a/src/brevitas/export/onnx/finn/function/act.py b/src/brevitas/export/onnx/finn/function/act.py
index d33a5df2a..d3be089fa 100644
--- a/src/brevitas/export/onnx/finn/function/act.py
+++ b/src/brevitas/export/onnx/finn/function/act.py
@@ -1,19 +1,24 @@
from torch.autograd import Function
+from . import DOMAIN_STRING
-class QuantHardTanhPlaceholderFunction(Function):
+
+class QuantHardTanhFn(Function):
@staticmethod
def symbolic(g, input, qnt_type, thres, bias, scale):
if qnt_type == "BIPOLAR":
return g.op(
'MultiThreshold', input, thres,
- domain_s="finn.custom_op.general",
+ domain_s=DOMAIN_STRING,
out_dtype_s=qnt_type,
out_scale_f=2.0,
out_bias_f=-1.0)
else:
- ret = g.op('MultiThreshold', input, thres, domain_s="finn.custom_op.general", out_dtype_s=qnt_type)
+ ret = g.op(
+ 'MultiThreshold', input, thres,
+ domain_s=DOMAIN_STRING,
+ out_dtype_s=qnt_type)
if bias is not None:
ret = g.op('Add', ret, bias)
if scale is not None:
@@ -25,7 +30,7 @@ def forward(ctx, input, qnt_type, thres, bias, scale):
return input.clamp(0)
-class QuantReLUPlaceholderFunction(Function):
+class QuantReLUFn(Function):
@staticmethod
def symbolic(g, input, qnt_type, thres, bias, scale):
diff --git a/src/brevitas/export/onnx/finn/function/parameter.py b/src/brevitas/export/onnx/finn/function/parameter.py
index 98a336614..21ff8f117 100644
--- a/src/brevitas/export/onnx/finn/function/parameter.py
+++ b/src/brevitas/export/onnx/finn/function/parameter.py
@@ -2,7 +2,7 @@
from torch.autograd import Function
-class QuantizedLinearPlaceholderFunction(Function):
+class QuantizedLinearFn(Function):
@staticmethod
def symbolic(g, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, bias):
@@ -23,11 +23,13 @@ def symbolic(g, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_sha
def forward(ctx, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, bias):
return torch.empty(out_shape, dtype=torch.float, device=x.device)
-class QuantizedConvNdPlaceholderFunction(Function):
+
+class QuantizedConvNdFn(Function):
@staticmethod
def symbolic(
- g, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides, bias, kernel_shape, groups, dilations):
+ g, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides,
+ bias, kernel_shape, groups, dilations):
ret = g.op(
'Conv', x, W,
weight_qnt_s=w_qnt_type,
@@ -50,5 +52,6 @@ def symbolic(
@staticmethod
def forward(
- ctx, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides, bias, kernel_shape, groups, dilations):
+ ctx, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides,
+ bias, kernel_shape, groups, dilations):
return torch.empty(out_shape, dtype=torch.float, device=x.device)
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/finn/handler/acc.py b/src/brevitas/export/onnx/finn/handler/acc.py
index 95afcc4b3..1d08243b1 100644
--- a/src/brevitas/export/onnx/finn/handler/acc.py
+++ b/src/brevitas/export/onnx/finn/handler/acc.py
@@ -2,7 +2,7 @@
from brevitas.nn import QuantAvgPool2d
from .base import FINNQuantIOHandler
-from ..function.acc import QuantAvgPool2dPlaceholderFunction
+from ..function.acc import QuantAvgPool2dFn
class FINNQuantAvgPool2dHandler(FINNQuantIOHandler):
@@ -48,5 +48,5 @@ def prepare_for_export(self, module):
'qnt_type': self.quant_input_type(module)}
def symbolic_execution(self, inp: Tensor):
- ret = QuantAvgPool2dPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = QuantAvgPool2dFn.apply(inp, *self.symbolic_kwargs.values())
return ret
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/finn/handler/act.py b/src/brevitas/export/onnx/finn/handler/act.py
index b46bbaac7..b5148e175 100644
--- a/src/brevitas/export/onnx/finn/handler/act.py
+++ b/src/brevitas/export/onnx/finn/handler/act.py
@@ -4,10 +4,11 @@
from torch import Tensor
from brevitas.nn import QuantReLU, QuantHardTanh, QuantIdentity
-from .base import FINNQuantInputHandler, FINNQuantIOHandler
-from ..function.act import QuantReLUPlaceholderFunction, QuantHardTanhPlaceholderFunction
+from .base import FINNQuantInputHandler
+from ..function.act import QuantReLUFn, QuantHardTanhFn
from ..utils import finn_datatype
+
class FINNQuantReLUHandler(FINNQuantInputHandler):
handled_layer = QuantReLU
@@ -49,7 +50,7 @@ def prepare_for_export(self, module: QuantReLU):
'scale': self.quant_act_scale(module)}
def symbolic_execution(self, inp: Tensor):
- ret = QuantReLUPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = QuantReLUFn.apply(inp, *self.symbolic_kwargs.values())
return ret
@@ -126,7 +127,7 @@ def prepare_for_export(self, module: QuantHardTanh):
'scale': self.quant_act_scale(module)}
def symbolic_execution(self, inp: Tensor):
- ret = QuantHardTanhPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = QuantHardTanhFn.apply(inp, *self.symbolic_kwargs.values())
return ret
diff --git a/src/brevitas/export/onnx/finn/handler/parameter.py b/src/brevitas/export/onnx/finn/handler/parameter.py
index 386cfbf2a..ccbb4aaef 100644
--- a/src/brevitas/export/onnx/finn/handler/parameter.py
+++ b/src/brevitas/export/onnx/finn/handler/parameter.py
@@ -6,10 +6,10 @@
from brevitas.nn import QuantLinear, QuantConv2d, QuantConv1d
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
-from brevitas.export.onnx.handler import Kernel2dApplHandler, Kernel1dApplHandler
+from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin, Kernel1dApplHandlerMixin
from .base import FINNQuantIOHandler
-from ..function.parameter import QuantizedLinearPlaceholderFunction
-from ..function.parameter import QuantizedConvNdPlaceholderFunction
+from ..function.parameter import QuantizedLinearFn
+from ..function.parameter import QuantizedConvNdFn
from ..utils import finn_datatype
QuantConvNd = Union[QuantConv1d, QuantConv2d]
@@ -18,7 +18,7 @@
class FINNQuantWBIOLHandler(FINNQuantIOHandler, ABC):
@staticmethod
- def sanity_check(module: QuantWBIOL):
+ def validate(module: QuantWBIOL):
assert module.is_weight_quant_enabled
assert not module.is_input_quant_enabled
assert not module.is_output_quant_enabled
@@ -80,7 +80,7 @@ def quant_output_shape(module: QuantLinear):
return shape
def prepare_for_export(self, module):
- self.sanity_check(module)
+ self.validate(module)
self.symbolic_kwargs = {
'Wt': self.int_weight_transposed(module),
'w_qnt_scale': self.quant_weight_scale(module),
@@ -91,7 +91,7 @@ def prepare_for_export(self, module):
'bias': self.maybe_int_bias(module)}
def symbolic_execution(self, inp: Tensor):
- ret = QuantizedLinearPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = QuantizedLinearFn.apply(inp, *self.symbolic_kwargs.values())
return ret
@@ -120,7 +120,7 @@ def maybe_int_bias(module: QuantWBIOL):
return bias
def prepare_for_export(self, module: QuantConvNd):
- self.sanity_check(module)
+ self.validate(module)
maybe_int_bias = self.maybe_int_bias(module)
maybe_quant_bias_scale = self.maybe_quant_bias_scale(module)
if (maybe_quant_bias_scale is not None
@@ -142,11 +142,11 @@ def prepare_for_export(self, module: QuantConvNd):
'dilations': self.dilation(module)}
def symbolic_execution(self, inp: Tensor):
- ret = QuantizedConvNdPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = QuantizedConvNdFn.apply(inp, *self.symbolic_kwargs.values())
return ret
-class FINNQuantConv1dHandler(FINNQuantConvNdHandler, Kernel1dApplHandler):
+class FINNQuantConv1dHandler(FINNQuantConvNdHandler, Kernel1dApplHandlerMixin):
handled_layer = QuantConv1d
@staticmethod
@@ -157,7 +157,7 @@ def quant_weight_scale(module: QuantConv1d):
return quant_weight_scale
-class FINNQuantConv2dHandler(FINNQuantConvNdHandler, Kernel2dApplHandler):
+class FINNQuantConv2dHandler(FINNQuantConvNdHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d
@staticmethod
diff --git a/src/brevitas/export/onnx/finn/manager.py b/src/brevitas/export/onnx/finn/manager.py
index ee28cfeb7..9d05e0704 100644
--- a/src/brevitas/export/onnx/finn/manager.py
+++ b/src/brevitas/export/onnx/finn/manager.py
@@ -5,8 +5,8 @@
from torch.nn import Module, Sequential
from torch.autograd import Function
-from brevitas.export.onnx.base import ONNXBaseManager, onnx
-from brevitas.export.base import _set_layer_export_handler, _set_layer_export_mode
+from brevitas.export.onnx.manager import ONNXBaseManager, onnx
+from brevitas.export.manager import _set_layer_export_handler, _set_layer_export_mode
from brevitas.quant_tensor import QuantTensor
from ..transform import move_domain_attributes_into_domain
diff --git a/src/brevitas/export/onnx/finn/transform.py b/src/brevitas/export/onnx/finn/transform.py
index cb2ca399a..93bde9a66 100644
--- a/src/brevitas/export/onnx/finn/transform.py
+++ b/src/brevitas/export/onnx/finn/transform.py
@@ -2,7 +2,7 @@
from torch.nn import Module
-from ..base import onnx
+from ..manager import onnx
def move_quant_attributes_into_annotations(model: Module):
diff --git a/src/brevitas/export/onnx/generic/function.py b/src/brevitas/export/onnx/generic/function.py
index 9fac62559..71641d87f 100644
--- a/src/brevitas/export/onnx/generic/function.py
+++ b/src/brevitas/export/onnx/generic/function.py
@@ -1,52 +1,70 @@
+import torch
from torch.autograd import Function
-from brevitas.core.quant import IntQuant, DecoupledIntQuant
+from brevitas.function import binary_sign
+from brevitas.core.bit_width import BitWidthConst
+from brevitas.core.quant import IntQuant, TruncIntQuant
+from brevitas.quant.solver.common import solve_float_to_int_impl_from_enum
-class QuantPlaceholderFunction(Function):
+
+DOMAIN_STRING = "onnx.brevitas"
+
+
+class BrevitasBinaryQuantFn(Function):
@staticmethod
- def symbolic(g, x, scale, zero_point, bit_width, narrow_range, signed):
+ def symbolic(g, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
ret = g.op(
- 'Quant',
- x, scale, zero_point, bit_width,
- signed_i=int(signed),
- narrow_i=int(narrow_range))
+ 'BipolarQuant',
+ x, scale,
+ domain_s=DOMAIN_STRING)
return ret
@staticmethod
- def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed):
- quant = IntQuant(narrow_range=narrow_range, signed=signed)
- x = quant(scale, zero_point, bit_width, x)
- return x
+ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
+ y = binary_sign(x) * scale
+ return y
+
-class DecoupledQuantPlaceholderFunction(Function):
+class BrevitasQuantFn(Function):
@staticmethod
- def symbolic(g, x, pre_scale, pre_zero_point, scale, zero_point, bit_width, narrow_range, signed):
+ def symbolic(g, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
ret = g.op(
- 'DecoupledQuant',
- x, pre_scale, pre_zero_point, scale, zero_point, bit_width,
+ 'Quant',
+ x, scale, zero_point, bit_width,
+ domain_s=DOMAIN_STRING,
+ rounding_mode_s=rounding_mode,
signed_i=int(signed),
narrow_i=int(narrow_range))
return ret
@staticmethod
- def forward(ctx, x, pre_scale, pre_zero_point, scale, zero_point, bit_width, narrow_range, signed):
- quant = DecoupledIntQuant(narrow_range=narrow_range, signed=signed)
- x = quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
- return x
+ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
+ float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
+ quant = IntQuant(
+ float_to_int_impl=float_to_int_impl(), narrow_range=narrow_range, signed=signed)
+ y = quant(scale, zero_point, bit_width, x)
+ return y
-class TruncPlaceholderFunction(Function):
+class BrevitasTruncFn(Function):
@staticmethod
- def symbolic(g, x, scale, zero_point, bit_width):
+ def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
ret = g.op(
'Trunc',
- x, scale, zero_point, bit_width)
+ x, scale, zero_point, input_bit_width, output_bit_width,
+ rounding_mode_s=rounding_mode,
+ domain_s=DOMAIN_STRING)
return ret
@staticmethod
- def forward(ctx, x, scale, zero_point, bit_width):
- return x
\ No newline at end of file
+ def forward(ctx, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
+ float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
+ trunc = TruncIntQuant(
+ float_to_int_impl=float_to_int_impl(),
+ bit_width_impl=BitWidthConst(int(output_bit_width)))
+ y_tuple = trunc(x, scale, zero_point, input_bit_width)
+ return y_tuple[0]
\ No newline at end of file
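Unlike the old placeholder functions, the renamed Fns run the actual quantizer in `forward`, so traced outputs are numerically faithful; `rounding_mode` strings such as `'ROUND'` are resolved through `solve_float_to_int_impl_from_enum`, exactly as in the forward above. A sketch of invoking the fake-quant path directly, with illustrative 8b signed per-tensor parameters:

```python
import torch
from brevitas.export.onnx.generic.function import BrevitasQuantFn

x = torch.randn(4)
y = BrevitasQuantFn.apply(
    x,
    torch.tensor(0.1),  # scale
    torch.tensor(0.),   # zero_point
    torch.tensor(8.),   # bit_width
    False,              # narrow_range
    True,               # signed
    'ROUND')            # rounding_mode
print(y)  # x quantized and dequantized with the given parameters
```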
diff --git a/src/brevitas/export/onnx/generic/handler.py b/src/brevitas/export/onnx/generic/handler.py
index f2aff49b7..9818197cc 100644
--- a/src/brevitas/export/onnx/generic/handler.py
+++ b/src/brevitas/export/onnx/generic/handler.py
@@ -1,97 +1,126 @@
from abc import ABC
+from copy import copy
from torch import Tensor
from brevitas.export.onnx.handler import ONNXBaseHandler
-from brevitas.proxy import WeightQuantProxyFromInjector, DecoupledWeightQuantProxyFromInjector
+from brevitas.proxy import WeightQuantProxyFromInjector
+from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import BiasQuantProxyFromInjector
from brevitas.proxy import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
-from .function import QuantPlaceholderFunction, TruncPlaceholderFunction
-from .function import DecoupledQuantPlaceholderFunction
+from .function import BrevitasQuantFn, BrevitasBinaryQuantFn, BrevitasTruncFn
-class StaticQuantProxyHandler(ONNXBaseHandler, ABC):
+class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC):
+
+ def validate(self, module):
+ if module.bit_width() == 1:
+ assert module.zero_point() == 0, "Zero-point not supported for binary quant."
def prepare_for_export(self, module):
if module.is_quant_enabled:
+ self.validate(module)
self.symbolic_kwargs = {
'scale': module.scale(),
'zero_point': module.zero_point(),
'bit_width': module.bit_width(),
'narrow_range': module.is_narrow_range,
- 'signed': module.is_signed}
+ 'signed': module.is_signed,
+ 'rounding_mode': module.rounding_mode}
def symbolic_execution(self, x: Tensor):
- x = QuantPlaceholderFunction.apply(x, *self.symbolic_kwargs.values())
scale = self.symbolic_kwargs['scale']
zero_point = self.symbolic_kwargs['zero_point']
bit_width = self.symbolic_kwargs['bit_width']
+ if bit_width == 1:
+ x = BrevitasBinaryQuantFn.apply(x, *self.symbolic_kwargs.values())
+ else:
+ x = BrevitasQuantFn.apply(x, *self.symbolic_kwargs.values())
return x, scale, zero_point, bit_width
-class DecoupledStaticQuantProxyHandler(ONNXBaseHandler, ABC):
+class BrevitasWeightQuantProxyHandler(BrevitasQuantProxyHandler):
+ handled_layer = WeightQuantProxyFromInjector
- def prepare_for_export(self, module):
- if module.is_quant_enabled:
- self.symbolic_kwargs = {
- 'pre_scale': module.pre_scale(),
- 'pre_zero_point': module.pre_zero_point(),
- 'scale': module.scale(),
- 'zero_point': module.zero_point(),
- 'bit_width': module.bit_width(),
- 'narrow_range': module.is_narrow_range,
- 'signed': module.is_signed}
+ def __init__(self):
+ super().__init__()
+ self.quant_weights = None
+
+ def reset(self):
+ super().reset()
+ self.quant_weights = None
+
+ def prepare_for_export(self, module: WeightQuantProxyFromInjector):
+ super().prepare_for_export(module)
+ quant_weights = {tm.weight: tm.quant_weight().value for tm in module.tracked_module_list}
+ self.quant_weights = quant_weights
+ # override rounding mode since quantization has been pre-applied
+ self.symbolic_kwargs['rounding_mode'] = 'ROUND'
def symbolic_execution(self, x: Tensor):
- x = DecoupledQuantPlaceholderFunction.apply(x, *self.symbolic_kwargs.values())
- scale = self.symbolic_kwargs['scale']
- zero_point = self.symbolic_kwargs['zero_point']
- pre_scale = self.symbolic_kwargs['pre_scale']
- pre_zero_point = self.symbolic_kwargs['pre_zero_point']
- bit_width = self.symbolic_kwargs['bit_width']
- return x, pre_scale, pre_zero_point, scale, zero_point, bit_width
+ quant_weight = self.quant_weights[x]
+ return super().symbolic_execution(quant_weight)
-class DecoupledWeightQuantProxyHandler(DecoupledStaticQuantProxyHandler):
+class BrevitasDecoupledWeightQuantProxyHandler(BrevitasWeightQuantProxyHandler):
handled_layer = DecoupledWeightQuantProxyFromInjector
+ def __init__(self):
+ super().__init__()
+ self.extra_kwargs = {}
-class WeightQuantProxyHandler(StaticQuantProxyHandler):
- handled_layer = WeightQuantProxyFromInjector
+ def reset(self):
+ super().reset()
+ self.extra_kwargs = {}
+ def prepare_for_export(self, module: DecoupledWeightQuantProxyFromInjector):
+ super().prepare_for_export(module)
+ self.extra_kwargs['pre_scale'] = module.pre_scale()
+ self.extra_kwargs['pre_zero_point'] = module.pre_zero_point()
-class ActQuantProxyHandler(StaticQuantProxyHandler):
+ def symbolic_execution(self, x: Tensor):
+ out, scale, zero_point, bit_width = super().symbolic_execution(x)
+ pre_scale = self.extra_kwargs['pre_scale']
+ pre_zero_point = self.extra_kwargs['pre_zero_point']
+ return out, pre_scale, pre_zero_point, scale, zero_point, bit_width
+
+
+class BrevitasActQuantProxyHandler(BrevitasQuantProxyHandler):
handled_layer = ActQuantProxyFromInjector
-class BiasQuantProxyHandler(StaticQuantProxyHandler):
+class BrevitasBiasQuantProxyHandler(BrevitasQuantProxyHandler):
handled_layer = BiasQuantProxyFromInjector
def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
- scale = self.symbolic_kwargs.pop('scale')
- bit_width = self.symbolic_kwargs.pop('bit_width')
- zero_point = self.symbolic_kwargs.pop('zero_point')
+ # avoid an in-place pop in case the proxy is shared
+ symbolic_kwargs = copy(self.symbolic_kwargs)
+ scale = symbolic_kwargs.pop('scale')
+ bit_width = symbolic_kwargs.pop('bit_width')
+ zero_point = symbolic_kwargs.pop('zero_point')
if scale is None:
assert input_scale is not None, 'Input scale required for bias export'
scale = input_scale
if bit_width is None:
assert input_bit_width is not None, 'Input bit_width required for bias export'
bit_width = input_bit_width
- x = QuantPlaceholderFunction.apply(
- x, scale, zero_point, bit_width, *self.symbolic_kwargs.values())
- return x, scale, zero_point, bit_width
+ y = BrevitasQuantFn.apply(
+ x, scale, zero_point, bit_width, *symbolic_kwargs.values())
+ return y, scale, zero_point, bit_width
-class TruncQuantProxyHandler(ONNXBaseHandler):
+class BrevitasTruncQuantProxyHandler(ONNXBaseHandler):
handled_layer = TruncQuantProxyFromInjector
def prepare_for_export(self, module: TruncQuantProxyFromInjector):
self.symbolic_kwargs = {
- 'bit_width': module.bit_width()}
-
- def symbolic_execution(self, x: Tensor, scale: Tensor, zero_point: Tensor, bit_width: Tensor):
- x = TruncPlaceholderFunction.apply(
- x, scale, zero_point, *self.symbolic_kwargs.values())
- return x, scale, zero_point, self.symbolic_kwargs['bit_width']
\ No newline at end of file
+ 'output_bit_width': module.bit_width(),
+ 'rounding_mode': module.rounding_mode}
+
+ def symbolic_execution(
+ self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor):
+ y = BrevitasTruncFn.apply(
+ x, scale, zero_point, input_bit_width, *self.symbolic_kwargs.values())
+ return y, scale, zero_point, self.symbolic_kwargs['output_bit_width']
\ No newline at end of file
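The `copy()` in `BrevitasBiasQuantProxyHandler` matters because a quant proxy, and hence its handler, can be shared by several layers: popping from the original `symbolic_kwargs` would drop `'scale'` after the first `symbolic_execution` call. A self-contained illustration of that failure mode, with placeholder values:

```python
from copy import copy

symbolic_kwargs = {'scale': None, 'zero_point': 0, 'bit_width': None,
                   'narrow_range': False, 'signed': True, 'rounding_mode': 'ROUND'}

for _ in range(2):  # e.g. two layers sharing one bias quant proxy
    kwargs = copy(symbolic_kwargs)  # shallow copy per call
    scale = kwargs.pop('scale')     # popping the copy leaves the original intact

assert 'scale' in symbolic_kwargs   # still available for the next call
```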
diff --git a/src/brevitas/export/onnx/generic/manager.py b/src/brevitas/export/onnx/generic/manager.py
index a30c14bf0..9191a5551 100644
--- a/src/brevitas/export/onnx/generic/manager.py
+++ b/src/brevitas/export/onnx/generic/manager.py
@@ -1,15 +1,21 @@
+from typing import Tuple, Union, Optional
from torch.nn import Module
+from torch import Tensor
-from brevitas.export.onnx.base import ONNXBaseManager
+from brevitas.export.onnx.manager import ONNXBaseManager
from brevitas.export.onnx.transform import move_domain_attributes_into_domain
-from brevitas.export.base import _set_proxy_export_handler, _set_proxy_export_mode
+from brevitas.export.manager import _set_proxy_export_handler, _set_proxy_export_mode
-from .handler import ActQuantProxyHandler, BiasQuantProxyHandler, WeightQuantProxyHandler
-from .handler import TruncQuantProxyHandler, DecoupledWeightQuantProxyHandler
+from .handler import BrevitasActQuantProxyHandler
+from .handler import BrevitasBiasQuantProxyHandler
+from .handler import BrevitasWeightQuantProxyHandler
+from .handler import BrevitasTruncQuantProxyHandler
+from .handler import BrevitasDecoupledWeightQuantProxyHandler
class BrevitasONNXManager(ONNXBaseManager):
target_name = 'brevitas'
+ dequantize_tracing_input = False
model_transforms = [
move_domain_attributes_into_domain]
@@ -21,11 +27,11 @@ class BrevitasONNXManager(ONNXBaseManager):
"eliminate_unused_initializer"]
handlers = [
- ActQuantProxyHandler,
- BiasQuantProxyHandler,
- WeightQuantProxyHandler,
- DecoupledWeightQuantProxyHandler,
- TruncQuantProxyHandler
+ BrevitasActQuantProxyHandler,
+ BrevitasBiasQuantProxyHandler,
+ BrevitasWeightQuantProxyHandler,
+ BrevitasDecoupledWeightQuantProxyHandler,
+ BrevitasTruncQuantProxyHandler
]
@classmethod
@@ -36,4 +42,4 @@ def set_export_mode(cls, module: Module, enabled: bool):
@classmethod
def set_export_handler(cls, module: Module):
# proxy level export
- _set_proxy_export_handler(cls, module)
+ _set_proxy_export_handler(cls, module)
\ No newline at end of file
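Setting `dequantize_tracing_input = False` means a `QuantTensor` passed as `input_t` is handed to `torch.onnx.export` as-is (wrapped in a tuple) rather than being reduced to its `.value` first. A hypothetical end-to-end sketch:

```python
import torch
from brevitas.nn import QuantIdentity
from brevitas.export import export_brevitas_onnx

model = QuantIdentity(return_quant_tensor=True)
qt = model(torch.randn(1, 8))  # a QuantTensor to trace with
# no export_path: the ONNX model is returned from an in-memory buffer
onnx_model = export_brevitas_onnx(model, input_t=qt)
```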
diff --git a/src/brevitas/export/onnx/handler.py b/src/brevitas/export/onnx/handler.py
index 5ec901713..20a52f3d8 100644
--- a/src/brevitas/export/onnx/handler.py
+++ b/src/brevitas/export/onnx/handler.py
@@ -1,13 +1,18 @@
from abc import ABC, abstractmethod
from torch import Tensor
-from torch.nn import Module
-from .debug import DebugMarkerFunction
-from ..base import BaseHandler
+from brevitas.export.onnx.debug import DebugMarkerFunction
+from brevitas.export.handler import BaseHandler
+__all__ = [
+ 'Kernel1dApplHandlerMixin',
+ 'Kernel2dApplHandlerMixin',
+ 'ONNXBaseHandler'
+]
-class Kernel1dApplHandler(ABC):
+
+class Kernel1dApplHandlerMixin(ABC):
@staticmethod
def padding(module):
@@ -42,7 +47,7 @@ def kernel_shape(module):
return list(module.kernel_size)
-class Kernel2dApplHandler(ABC):
+class Kernel2dApplHandlerMixin(ABC):
@staticmethod
def padding(module):
@@ -91,6 +96,9 @@ def prepare_for_export(self, module):
def symbolic_execution(self, *args, **kwargs):
pass
+ def reset(self):
+ self.symbolic_kwargs = None
+
def attach_debug_info(self, m):
self.export_debug_name = m.export_debug_name
self.debug_input = m.cache_inference_quant_inp and not m.cache_quant_io_metadata_only
diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py
new file mode 100644
index 000000000..be758b21f
--- /dev/null
+++ b/src/brevitas/export/onnx/manager.py
@@ -0,0 +1,128 @@
+from typing import Tuple, Union, Optional
+from abc import ABC
+from packaging import version
+from contextlib import ExitStack
+from io import BytesIO
+import warnings
+
+try:
+ import onnx
+ import onnxoptimizer as opt
+except ModuleNotFoundError:
+ onnx = None
+ opt = None
+
+import torch
+import torch.onnx
+from torch import Tensor
+from torch.nn import Module
+
+from brevitas import torch_version
+from brevitas.quant_tensor import QuantTensor
+from ..manager import BaseManager, ExportContext
+from ..manager import _override_inp_caching_mode, _restore_inp_caching_mode
+
+
+class ONNXBaseManager(BaseManager, ABC):
+
+ model_transforms = []
+ onnx_passes = []
+ dequantize_tracing_input = True
+
+ @classmethod
+ def apply_model_transforms(cls, model):
+ for transform in cls.model_transforms:
+ model = transform(model)
+ return model
+
+ @classmethod
+ def solve_keep_initializers_as_inputs(cls, export_kwargs):
+ # See https://github.com/pytorch/pytorch/commit/7583519b870e33ee3182f330c1bb8663559697b6
+ ka = 'keep_initializers_as_inputs'
+ if torch_version >= version.parse('1.3.0') and ka not in export_kwargs:
+ export_kwargs[ka] = True
+
+ @classmethod
+ def solve_enable_onnx_checker(cls, export_kwargs):
+ ka = 'enable_onnx_checker'
+ if torch_version >= version.parse('1.5.0') and ka not in export_kwargs:
+ export_kwargs[ka] = False
+
+ @classmethod
+ def export_onnx(
+ cls,
+ module: Module,
+ input_shape: Optional[Tuple[int, ...]] = None,
+ export_path: Optional[str] = None,
+ input_t: Optional[Union[Tensor, QuantTensor]] = None,
+ disable_warnings=True,
+ **kwargs):
+
+ if onnx is None or opt is None:
+ raise ModuleNotFoundError("Installation of onnx and onnxoptimizer is required.")
+ if input_shape is None and input_t is None:
+ raise RuntimeError("Export requires to pass in either input_shape or input_t")
+ if input_shape is not None and input_t is not None:
+ raise RuntimeError("Export accepts either an input shape or an input tensor, not both")
+
+ cls.solve_keep_initializers_as_inputs(kwargs)
+ cls.solve_enable_onnx_checker(kwargs)
+
+ with torch.no_grad():
+ with ExportContext(cls):
+ with warnings.catch_warnings():
+ if disable_warnings:
+ warnings.simplefilter("ignore")
+ training_state = module.training
+ module = module.eval()
+ module.apply(cls.set_export_handler)
+ if input_t is None:
+ input_t = torch.empty(input_shape, dtype=torch.float)
+ # do a forward pass with the dummy input to e.g. store input/output shapes
+ cls._cache_inp_out(module, input_t)
+ # Dequantize QuantTensor, if any and enabled
+ if isinstance(input_t, QuantTensor):
+ if cls.dequantize_tracing_input:
+ input_t = input_t.value
+ else:
+ input_t = (input_t,)
+ # enable export mode, this triggers collecting export values into handlers
+ module.apply(lambda m: cls.set_export_mode(m, enabled=True))
+ # temporarily disable input caching to avoid collecting empty debug values
+ module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
+ # perform export pass
+ with ExitStack() as stack:
+ for mgr in cls._trace_patches():
+ stack.enter_context(mgr)
+ if export_path is not None:
+ export_target = export_path
+ else:
+ model_bytes = BytesIO()
+ export_target = model_bytes
+ torch.onnx.export(module, input_t, export_target, **kwargs)
+
+ # restore the model to previous properties
+ module.apply(lambda m: _restore_inp_caching_mode(m))
+ module.apply(lambda m: cls.set_export_mode(m, enabled=False))
+ module.train(training_state)
+
+ # do some cleanup on the exported ONNX model
+ if export_path is not None:
+ model = onnx.load(export_path)
+ else:
+ model = onnx.ModelProto.FromString(model_bytes.getvalue())
+ model = opt.optimize(model, cls.onnx_passes)
+ model = cls.apply_model_transforms(model)
+ if export_path is not None:
+ onnx.save(model, export_path)
+ return model
+
+ @classmethod
+ def export(
+ cls,
+ module: Module,
+ input_shape: Optional[Tuple[int, ...]] = None,
+ export_path: Optional[str] = None,
+ input_t: Optional[Union[Tensor, QuantTensor]] = None,
+ **kwargs):
+ return cls.export_onnx(module, input_shape, export_path, input_t, **kwargs)
\ No newline at end of file
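The `solve_*` helpers only fill in defaults and never override caller-supplied kwargs; which defaults get set depends on the installed torch version. They can be exercised directly on the base manager:

```python
from brevitas.export.onnx.manager import ONNXBaseManager

kwargs = {'opset_version': 11}
ONNXBaseManager.solve_keep_initializers_as_inputs(kwargs)  # sets default on torch >= 1.3
ONNXBaseManager.solve_enable_onnx_checker(kwargs)          # sets default on torch >= 1.5
print(kwargs)  # the user-provided opset_version is preserved
```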
diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index 88cf2fe6f..44b6bc7bd 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -1,24 +1,26 @@
import torch
from torch.autograd import Function
+
from . import OPSET
+
AXIS_OPSET = 11
-class DequantizeLinearFunction(Function):
+class DequantizeLinearFn(Function):
@staticmethod
def symbolic(
g, x,
input_scale,
input_zero_point,
- axis):
- if axis is not None and OPSET >= AXIS_OPSET:
+ input_axis):
+ if input_axis is not None and OPSET >= AXIS_OPSET:
ret = g.op(
'DequantizeLinear', x,
input_scale,
input_zero_point,
- axis_i=axis)
+ axis_i=input_axis)
else:
ret = g.op(
'DequantizeLinear', x,
@@ -31,24 +33,25 @@ def forward(
ctx, int_x,
input_scale,
input_zero_point,
- axis):
+ input_axis):
return int_x.float()
-class QuantizeLinearFunction(Function):
+class QuantizeLinearFn(Function):
@staticmethod
def symbolic(
g, x,
output_scale,
ouput_zero_point,
- axis):
- if axis is not None and OPSET >= AXIS_OPSET:
+ output_dtype,
+ output_axis):
+ if output_axis is not None and OPSET >= AXIS_OPSET:
ret = g.op(
'QuantizeLinear', x,
output_scale,
ouput_zero_point,
- axis_i=axis)
+ axis_i=output_axis)
else:
ret = g.op(
'QuantizeLinear', x,
@@ -61,115 +64,7 @@ def forward(
ctx, x,
output_scale,
ouput_zero_point,
- axis):
- return x.int()
-
-
-class QLinearConvFunction(Function):
+ output_dtype,
+ output_axis):
+ return x.type(output_dtype)
- @staticmethod
- def symbolic(
- g, int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- ouput_zero_point,
- int_bias,
- out_shape,
- kernel_size,
- padding,
- stride,
- groups,
- dilation):
- if int_bias is not None:
- ret = g.op(
- 'QLinearConv', int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- ouput_zero_point,
- int_bias,
- kernel_shape_i=kernel_size,
- pads_i=padding,
- strides_i=stride,
- group_i=groups,
- dilations_i=dilation)
- else:
- ret = g.op(
- 'QLinearConv', int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- ouput_zero_point,
- kernel_shape_i=kernel_size,
- pads_i=padding,
- strides_i=stride,
- group_i=groups,
- dilations_i=dilation)
- return ret
-
- @staticmethod
- def forward(
- ctx, int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- output_zero_point,
- bias,
- out_shape,
- kernel_size,
- padding,
- stride,
- groups,
- dilation):
- return torch.empty(out_shape, dtype=output_zero_point.dtype, device=int_x.device)
-
-
-class QLinearMatMulFunction(Function):
-
- @staticmethod
- def symbolic(
- g, int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- ouput_zero_point,
- out_shape):
- ret = g.op(
- 'QLinearMatMul', int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- ouput_zero_point)
- return ret
-
- @staticmethod
- def forward(
- ctx, int_x,
- input_scale,
- input_zero_point,
- int_weight,
- weight_scale,
- weight_zero_point,
- output_scale,
- output_zero_point,
- out_shape):
- return torch.empty(out_shape, dtype=output_zero_point.dtype, device=int_x.device)
\ No newline at end of file
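Both remaining Fns are tracing stubs: `forward` only casts (`x.type(output_dtype)` on quantize, `.float()` on dequantize) so shapes and dtypes propagate during tracing, while the real arithmetic is defined by the exported ONNX QuantizeLinear/DequantizeLinear nodes. A per-tensor round-trip sketch (axis of `None`):

```python
import torch
from brevitas.export.onnx.standard.function import QuantizeLinearFn, DequantizeLinearFn

x = torch.randn(2, 4)
scale = torch.tensor(0.05)
zero_point = torch.zeros(1, dtype=torch.uint8)

q = QuantizeLinearFn.apply(x, scale, zero_point, torch.uint8, None)  # stub: dtype cast only
x_deq = DequantizeLinearFn.apply(q, scale, zero_point, None)         # stub: back to float
```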
diff --git a/src/brevitas/export/onnx/standard/handler.py b/src/brevitas/export/onnx/standard/handler.py
new file mode 100644
index 000000000..68f513d90
--- /dev/null
+++ b/src/brevitas/export/onnx/standard/handler.py
@@ -0,0 +1,16 @@
+from abc import ABC
+
+
+from brevitas.export.onnx.handler import ONNXBaseHandler
+from brevitas.export.handler import BitWidthHandlerMixin, ZeroPointHandlerMixin
+
+
+class StdONNXQuantLayerHandler(BitWidthHandlerMixin, ZeroPointHandlerMixin, ONNXBaseHandler, ABC):
+
+ @classmethod
+ def quant_axis(cls, scale):
+ for i, s in enumerate(scale.shape):
+ if s != 1:
+ return i
+ return None
+
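`quant_axis` now iterates over (index, size) pairs, fixing the old version below that looped over sizes and returned a size instead of an index: the first non-singleton dimension of the scale is taken as the channel axis, and an all-singleton (per-tensor) scale yields `None`. For example:

```python
import torch
from brevitas.export.onnx.standard.handler import StdONNXQuantLayerHandler

print(StdONNXQuantLayerHandler.quant_axis(torch.ones(64, 1, 1, 1)))  # 0 (per-channel)
print(StdONNXQuantLayerHandler.quant_axis(torch.ones(1, 1, 1, 1)))   # None (per-tensor)
```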
diff --git a/src/brevitas/export/onnx/standard/manager.py b/src/brevitas/export/onnx/standard/manager.py
deleted file mode 100644
index 3b12fe83f..000000000
--- a/src/brevitas/export/onnx/standard/manager.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from typing import Tuple, Optional, Union
-from packaging import version
-
-from torch import Tensor
-from torch.nn import functional as F
-from torch.nn import Module
-
-from brevitas import torch_version
-from brevitas.quant_tensor import QuantTensor
-from brevitas.export.onnx.base import ONNXBaseManager
-from brevitas.export.base import _set_layer_export_handler, _set_layer_export_mode
-
-from .function import QuantizeLinearFunction, DequantizeLinearFunction
-from .handler.base import StdONNXQuantLayerHandler
-from .handler.parameter import StdONNXQuantConv2dHandler
-from .handler.parameter import StdONNXQuantConv1dHandler
-from .handler.parameter import StdONNXQuantLinearHandler
-from .handler.act import StdONNXQuantReLUHandler
-from .handler.act import StdONNXQuantHardTanhHandler
-from .handler.act import StdONNXQuantIdentityHandler
-from .handler.act import StdONNXQuantTanhHandler
-from .handler.act import StdONNXQuantSigmoidHandler
-from .handler.pool import StdONNXQuantMaxPool1d
-from .handler.pool import StdONNXQuantMaxPool2d
-from . import OPSET
-
-
-class StdONNXManager(ONNXBaseManager):
- target_name = 'StdONNX'
-
- _fn_to_cache = [
- F.relu,
- F.relu6,
- F.hardtanh,
- F.max_pool1d,
- F.max_pool2d,
- F.max_pool3d,
- F.adaptive_max_pool1d,
- F.adaptive_max_pool2d,
- F.adaptive_max_pool3d,
- ]
-
- handlers = [
- StdONNXQuantConv1dHandler,
- StdONNXQuantConv2dHandler,
- StdONNXQuantLinearHandler,
- StdONNXQuantReLUHandler,
- StdONNXQuantHardTanhHandler,
- StdONNXQuantIdentityHandler,
- StdONNXQuantTanhHandler,
- StdONNXQuantSigmoidHandler,
- StdONNXQuantMaxPool1d,
- StdONNXQuantMaxPool2d]
-
- onnx_passes = [
- # remove unused graph inputs & initializers
- "eliminate_unused_initializer"]
-
- @classmethod
- def solve_enable_onnx_checker(cls, export_kwargs):
- if torch_version >= version.parse('1.5.0'):
- export_kwargs['enable_onnx_checker'] = True
-
- @classmethod
- def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
- cached_io = cls._fn_cache.pop(0)
- if cached_io is not None:
- cached_inp, cached_out = cached_io
- if cached_inp is not None:
- deq_kwargs = StdONNXQuantLayerHandler.dequant_symbolic_kwargs_from_cached_io(
- cached_inp)
- input = DequantizeLinearFunction.apply(input, *deq_kwargs.values())
- output = fn(input, *args, **kwargs)
- if cached_out is not None:
- q_kwargs = StdONNXQuantLayerHandler.quant_symbolic_kwargs_from_cached_io(cached_out)
- output = QuantizeLinearFunction.apply(output, *q_kwargs.values())
- else:
- output = fn(input, *args, **kwargs)
- return output
-
- @classmethod
- def set_export_mode(cls, module: Module, enabled: bool):
- _set_layer_export_mode(module, enabled)
-
- @classmethod
- def set_export_handler(cls, module: Module):
- _set_layer_export_handler(cls, module)
-
- @classmethod
- def export_onnx(
- cls,
- module: Module,
- input_shape: Tuple[int, ...],
- export_path: Optional[str] = None,
- input_t: Optional[Union[Tensor, QuantTensor]] = None,
- **kwargs):
- output = super().export_onnx(
- module, input_shape, export_path, input_t, opset_version=OPSET, **kwargs)
- return output
\ No newline at end of file
diff --git a/src/brevitas/export/common/__init__.py b/src/brevitas/export/onnx/standard/qdq/__init__.py
similarity index 100%
rename from src/brevitas/export/common/__init__.py
rename to src/brevitas/export/onnx/standard/qdq/__init__.py
diff --git a/src/brevitas/export/onnx/standard/handler/__init__.py b/src/brevitas/export/onnx/standard/qoperator/__init__.py
similarity index 100%
rename from src/brevitas/export/onnx/standard/handler/__init__.py
rename to src/brevitas/export/onnx/standard/qoperator/__init__.py
diff --git a/src/brevitas/export/onnx/standard/qoperator/function.py b/src/brevitas/export/onnx/standard/qoperator/function.py
new file mode 100644
index 000000000..827957272
--- /dev/null
+++ b/src/brevitas/export/onnx/standard/qoperator/function.py
@@ -0,0 +1,116 @@
+import torch
+from torch.autograd import Function
+
+
+class QLinearConvFn(Function):
+
+ @staticmethod
+ def symbolic(
+ g, int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ output_dtype,
+ int_bias,
+ out_shape,
+ kernel_size,
+ padding,
+ stride,
+ groups,
+ dilation):
+ if int_bias is not None:
+ ret = g.op(
+ 'QLinearConv', int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ int_bias,
+ kernel_shape_i=kernel_size,
+ pads_i=padding,
+ strides_i=stride,
+ group_i=groups,
+ dilations_i=dilation)
+ else:
+ ret = g.op(
+ 'QLinearConv', int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ kernel_shape_i=kernel_size,
+ pads_i=padding,
+ strides_i=stride,
+ group_i=groups,
+ dilations_i=dilation)
+ return ret
+
+ @staticmethod
+ def forward(
+ ctx, int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ output_dtype,
+ int_bias,
+ out_shape,
+ kernel_size,
+ padding,
+ stride,
+ groups,
+ dilation):
+ return torch.empty(out_shape, dtype=output_dtype, device=int_x.device)
+
+
+class QLinearMatMulFn(Function):
+
+ @staticmethod
+ def symbolic(
+ g, int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ output_dtype,
+ out_shape):
+ ret = g.op(
+ 'QLinearMatMul', int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point)
+ return ret
+
+ @staticmethod
+ def forward(
+ ctx, int_x,
+ input_scale,
+ input_zero_point,
+ int_weight,
+ weight_scale,
+ weight_zero_point,
+ output_scale,
+ output_zero_point,
+ output_dtype,
+ out_shape):
+ return torch.empty(out_shape, dtype=output_dtype, device=int_x.device)
\ No newline at end of file
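Like the Q/DQ functions, these forwards are tracing stubs: they just allocate the declared output shape with the requested dtype, while the actual QLinearConv/QLinearMatMul semantics live in the emitted ONNX ops. A shape-only sketch with illustrative arguments:

```python
import torch
from brevitas.export.onnx.standard.qoperator.function import QLinearConvFn

out = QLinearConvFn.apply(
    torch.zeros(1, 3, 8, 8, dtype=torch.uint8),            # int_x
    torch.tensor(0.1), torch.zeros(1, dtype=torch.uint8),  # input scale / zero point
    torch.zeros(4, 3, 3, 3, dtype=torch.int8),             # int_weight
    torch.tensor(0.1), torch.zeros(1, dtype=torch.int8),   # weight scale / zero point
    torch.tensor(0.2), torch.zeros(1, dtype=torch.uint8),  # output scale / zero point
    torch.uint8,                                           # output_dtype
    None,                                                  # int_bias
    (1, 4, 6, 6),                                          # out_shape
    [3, 3], [0, 0, 0, 0], [1, 1], 1, [1, 1])               # kernel, pads, strides, groups, dilations
print(out.shape)  # torch.Size([1, 4, 6, 6])
```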
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/__init__.py b/src/brevitas/export/onnx/standard/qoperator/handler/__init__.py
similarity index 100%
rename from src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/__init__.py
rename to src/brevitas/export/onnx/standard/qoperator/handler/__init__.py
diff --git a/src/brevitas/export/onnx/standard/handler/act.py b/src/brevitas/export/onnx/standard/qoperator/handler/act.py
similarity index 77%
rename from src/brevitas/export/onnx/standard/handler/act.py
rename to src/brevitas/export/onnx/standard/qoperator/handler/act.py
index 872e9d51f..2dfeac13e 100644
--- a/src/brevitas/export/onnx/standard/handler/act.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/act.py
@@ -1,16 +1,15 @@
-from typing import Union
-from abc import ABC, abstractmethod
+from abc import ABC
import torch
from torch import Tensor
from brevitas.nn import QuantReLU, QuantIdentity, QuantHardTanh, QuantTanh, QuantSigmoid
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL
-from ..function import QuantizeLinearFunction, DequantizeLinearFunction
-from .base import StdONNXQuantLayerHandler
+from brevitas.export.onnx.standard.function import QuantizeLinearFn, DequantizeLinearFn
+from .base import StdQOpONNXQuantLayerHandler
-class StdONNXQuantNLALHandler(StdONNXQuantLayerHandler, ABC):
+class StdQOpONNXQuantNLALHandler(StdQOpONNXQuantLayerHandler, ABC):
@classmethod
def validate(cls, module: QuantNLAL):
@@ -34,7 +33,7 @@ def prepare_for_export(self, module: QuantNLAL):
input_redequant_symbolic_kwargs = {
'input_scale': input_quant_symbolic_kwargs['output_scale'],
'input_zero_point': input_quant_symbolic_kwargs['output_zero_point'],
- 'axis': input_quant_symbolic_kwargs['axis']}
+ 'input_axis': input_quant_symbolic_kwargs['output_axis']}
else:
input_redequant_symbolic_kwargs = None
@@ -50,43 +49,43 @@ def input_symbolic_execution(self, inp: Tensor):
input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs']
input_redequant_symbolic_kwargs = self.symbolic_kwargs['input_redequant_symbolic_kwargs']
if input_dequant_symbolic_kwargs is not None:
- inp = DequantizeLinearFunction.apply(inp, *input_dequant_symbolic_kwargs.values())
+ inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values())
if input_quant_symbolic_kwargs is not None:
- inp = QuantizeLinearFunction.apply(inp, *input_quant_symbolic_kwargs.values())
- inp = DequantizeLinearFunction.apply(inp, *input_redequant_symbolic_kwargs.values())
+ inp = QuantizeLinearFn.apply(inp, *input_quant_symbolic_kwargs.values())
+ inp = DequantizeLinearFn.apply(inp, *input_redequant_symbolic_kwargs.values())
return inp
def output_symbolic_execution(self, out: Tensor):
output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs']
output_dequant_symbolic_kwargs = self.symbolic_kwargs['output_dequant_symbolic_kwargs']
- out = QuantizeLinearFunction.apply(out, *output_quant_symbolic_kwargs.values())
+ out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values())
if output_dequant_symbolic_kwargs is not None:
- out = DequantizeLinearFunction.apply(out, *output_dequant_symbolic_kwargs.values())
+ out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values())
return out
-class StdONNXQuantReLUHandler(StdONNXQuantNLALHandler):
+class StdQOpONNXQuantReLUHandler(StdQOpONNXQuantNLALHandler):
handled_layer = QuantReLU
def op_symbolic_execution(self, inp: Tensor):
return torch.relu(inp)
-class StdONNXQuantTanhHandler(StdONNXQuantNLALHandler):
+class StdQOpONNXQuantTanhHandler(StdQOpONNXQuantNLALHandler):
handled_layer = QuantTanh
def op_symbolic_execution(self, inp: Tensor):
return torch.tanh(inp)
-class StdONNXQuantSigmoidHandler(StdONNXQuantNLALHandler):
+class StdQOpONNXQuantSigmoidHandler(StdQOpONNXQuantNLALHandler):
handled_layer = QuantSigmoid
def op_symbolic_execution(self, inp: Tensor):
return torch.sigmoid(inp)
-class StdONNXQuantIdentityHandler(StdONNXQuantLayerHandler):
+class StdQOpONNXQuantIdentityHandler(StdQOpONNXQuantLayerHandler):
handled_layer = QuantIdentity
@classmethod
@@ -122,12 +121,12 @@ def output_symbolic_execution(self, out: Tensor):
output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs']
output_dequant_symbolic_kwargs = self.symbolic_kwargs['output_dequant_symbolic_kwargs']
if input_dequant_symbolic_kwargs:
- out = DequantizeLinearFunction.apply(out, *input_dequant_symbolic_kwargs.values())
- out = QuantizeLinearFunction.apply(out, *output_quant_symbolic_kwargs.values())
+ out = DequantizeLinearFn.apply(out, *input_dequant_symbolic_kwargs.values())
+ out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values())
if output_dequant_symbolic_kwargs is not None:
- out = DequantizeLinearFunction.apply(out, *output_dequant_symbolic_kwargs.values())
+ out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values())
return out
-class StdONNXQuantHardTanhHandler(StdONNXQuantIdentityHandler):
+class StdQOpONNXQuantHardTanhHandler(StdQOpONNXQuantIdentityHandler):
handled_layer = QuantHardTanh
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/standard/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
similarity index 75%
rename from src/brevitas/export/onnx/standard/handler/base.py
rename to src/brevitas/export/onnx/standard/qoperator/handler/base.py
index 7b6c2c9c8..8772edd6c 100644
--- a/src/brevitas/export/onnx/standard/handler/base.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
@@ -1,14 +1,14 @@
from abc import ABC, abstractmethod
+import torch
from torch import Tensor
-from brevitas.export.onnx.handler import ONNXBaseHandler
-from brevitas.export.common.handler import Validate8BitHandler, TypedZeroPointHandler
-from ..function import QuantizeLinearFunction, DequantizeLinearFunction
+from brevitas.export.onnx.standard.function import QuantizeLinearFn, DequantizeLinearFn
+from brevitas.export.onnx.standard.handler import StdONNXQuantLayerHandler
-class StdONNXQuantLayerHandler(Validate8BitHandler, TypedZeroPointHandler, ONNXBaseHandler, ABC):
+class StdQOpONNXQuantLayerHandler(StdONNXQuantLayerHandler, ABC):
@abstractmethod
def op_symbolic_execution(self, inp: Tensor):
@@ -26,6 +26,13 @@ def output_symbolic_execution(self, out: Tensor):
def op_symbolic_kwargs(cls, module):
raise NotImplementedError # optional method
+ @classmethod
+ def torch_8b_dtype(cls, is_signed):
+ if is_signed:
+ return torch.int8
+ else:
+ return torch.uint8
+
@classmethod
def quant_output_shape(cls, module):
cached_out = module._cached_out
@@ -33,26 +40,20 @@ def quant_output_shape(cls, module):
raise RuntimeError("Caching of outputs is required to export QuantConv2d")
return cached_out.shape
- @classmethod
- def quant_axis(cls, scale):
- for i in scale.shape:
- if i != 1:
- return i
- return None
-
@classmethod
def output_quant_symbolic_kwargs(cls, module):
return {
'output_scale': module.quant_output_scale(),
'output_zero_point': cls.quant_output_zero_point(module),
- 'axis': cls.quant_axis(module.quant_output_scale())}
+ 'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
+ 'output_axis': cls.quant_axis(module.quant_output_scale())}
@classmethod
def output_dequant_symbolic_kwargs(cls, module):
return {
'input_scale': module.quant_output_scale(),
'input_zero_point': cls.quant_output_zero_point(module),
- 'axis': cls.quant_axis(module.quant_output_scale())}
+ 'input_axis': cls.quant_axis(module.quant_output_scale())}
@classmethod
def input_quant_symbolic_kwargs(cls, module):
@@ -60,7 +61,8 @@ def input_quant_symbolic_kwargs(cls, module):
return {
'output_scale': module.quant_input_scale(),
'output_zero_point': cls.quant_input_zero_point(module),
- 'axis': cls.quant_axis(module.quant_input_scale())}
+ 'output_dtype': cls.torch_8b_dtype(module.is_quant_input_signed),
+ 'output_axis': cls.quant_axis(module.quant_input_scale())}
else:
return None
@@ -78,7 +80,7 @@ def dequant_symbolic_kwargs_from_cached_io(cls, cached_io):
'input_scale': cached_io.scale,
'input_zero_point': cls.zero_point_with_dtype(
cached_io.signed, cached_io.zero_point),
- 'axis': cls.quant_axis(cached_io.scale)}
+ 'input_axis': cls.quant_axis(cached_io.scale)}
@classmethod
def quant_symbolic_kwargs_from_cached_io(cls, cached_io):
@@ -87,7 +89,8 @@ def quant_symbolic_kwargs_from_cached_io(cls, cached_io):
'output_scale': cached_io.scale,
'output_zero_point': cls.zero_point_with_dtype(
cached_io.signed, cached_io.zero_point),
- 'axis': cls.quant_axis(cached_io.scale)}
+ 'output_dtype': cls.torch_8b_dtype(cached_io.signed),
+ 'output_axis': cls.quant_axis(cached_io.scale)}
def symbolic_execution(self, inp: Tensor):
inp = self.input_symbolic_execution(inp)
@@ -96,7 +99,7 @@ def symbolic_execution(self, inp: Tensor):
return ret
-class StdONNXQuantWrapperHandler(StdONNXQuantLayerHandler, ABC):
+class StdQOpONNXQuantWrapperHandler(StdQOpONNXQuantLayerHandler, ABC):
@classmethod
def validate(cls, module):
@@ -119,12 +122,12 @@ def prepare_for_export(self, module):
def input_symbolic_execution(self, inp: Tensor):
input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs']
- inp = DequantizeLinearFunction.apply(inp, *input_dequant_symbolic_kwargs.values())
+ inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values())
return inp
def output_symbolic_execution(self, out: Tensor):
output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs']
if output_quant_symbolic_kwargs is not None:
- out = QuantizeLinearFunction.apply(out, *output_quant_symbolic_kwargs.values())
+ out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values())
return out
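One contract these base-handler kwargs rely on: each *_symbolic_kwargs dict is expanded positionally via *kwargs.values(), so the insertion order of keys must match the Fn signature exactly, which is why output_dtype is slotted between the zero point and the axis. A minimal sketch of that contract, using a hypothetical ToyQuantizeLinearFn rather than the real QuantizeLinearFn:

import torch
from torch.autograd import Function

class ToyQuantizeLinearFn(Function):
    # Hypothetical stand-in; argument order mirrors the kwargs dicts above.

    @staticmethod
    def forward(ctx, x, output_scale, output_zero_point, output_dtype, output_axis):
        # The handler expands its dict positionally, so the dict's insertion
        # order must match this signature exactly.
        y = torch.round(x / output_scale) + output_zero_point
        y = y.clamp(0, 255) if output_dtype == torch.uint8 else y.clamp(-128, 127)
        return y.to(output_dtype)

kwargs = {
    'output_scale': torch.tensor(0.05),
    'output_zero_point': torch.tensor(128.),
    'output_dtype': torch.uint8,
    'output_axis': None}  # per-tensor quantization, no axis
out = ToyQuantizeLinearFn.apply(torch.randn(4), *kwargs.values())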
diff --git a/src/brevitas/export/onnx/standard/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
similarity index 82%
rename from src/brevitas/export/onnx/standard/handler/parameter.py
rename to src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index a6c03f201..b59dca069 100644
--- a/src/brevitas/export/onnx/standard/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -6,13 +6,13 @@
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.nn import QuantConv2d, QuantConv1d, QuantLinear
-from brevitas.export.onnx.handler import Kernel2dApplHandler, Kernel1dApplHandler
-from ..function import QuantizeLinearFunction, DequantizeLinearFunction
-from ..function import QLinearConvFunction, QLinearMatMulFunction
-from .base import StdONNXQuantLayerHandler
+from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin, Kernel1dApplHandlerMixin
+from brevitas.export.onnx.standard.function import QuantizeLinearFn, DequantizeLinearFn
+from ..function import QLinearConvFn, QLinearMatMulFn
+from .base import StdQOpONNXQuantLayerHandler
-class StdONNXQuantWBIOLHandler(StdONNXQuantLayerHandler, ABC):
+class StdQOpONNXQuantWBIOLHandler(StdQOpONNXQuantLayerHandler, ABC):
@staticmethod
def int_weight(module: QuantWBIOL):
@@ -34,7 +34,7 @@ def int_bias(module: QuantWBIOL):
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_weight_quant_enabled
assert module.is_output_quant_enabled
- cls.validate_8b_bit_width(module.quant_weight_bit_width())
+ cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
cls.validate_8b_bit_width(module.quant_input_bit_width())
cls.validate_8b_bit_width(module.quant_output_bit_width())
if module.bias is not None and requires_quant_bias:
@@ -47,13 +47,13 @@ def input_symbolic_execution(self, inp: Tensor):
input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs']
if input_dequant_symbolic_kwargs is not None:
assert input_quant_symbolic_kwargs is not None
- inp = DequantizeLinearFunction.apply(inp, *input_dequant_symbolic_kwargs.values())
+ inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values())
if input_quant_symbolic_kwargs is not None:
- inp = QuantizeLinearFunction.apply(inp, *input_quant_symbolic_kwargs.values())
+ inp = QuantizeLinearFn.apply(inp, *input_quant_symbolic_kwargs.values())
return inp
-class StdONNXQuantLinearHandler(StdONNXQuantWBIOLHandler):
+class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler):
handled_layer = QuantLinear
@classmethod
@@ -66,6 +66,7 @@ def op_symbolic_kwargs(cls, module: QuantLinear):
'weight_zero_point': cls.quant_weight_zero_point(module),
'output_scale': module.quant_output_scale(),
'output_zero_point': cls.quant_output_zero_point(module),
+ 'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed),
'out_shape': cls.quant_output_shape(module)}
return linear_symbolic_kwargs
@@ -100,7 +101,7 @@ def prepare_for_export(self, module: QuantLinear):
def op_symbolic_execution(self, inp):
linear_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs']
- out = QLinearMatMulFunction.apply(inp, *linear_symbolic_kwargs.values())
+ out = QLinearMatMulFn.apply(inp, *linear_symbolic_kwargs.values())
return out
def output_symbolic_execution(self, out: Tensor):
@@ -108,15 +109,15 @@ def output_symbolic_execution(self, out: Tensor):
output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs']
bias = self.symbolic_kwargs['bias']
if output_dequant_symbolic_kwargs is not None:
- out = DequantizeLinearFunction.apply(out, *output_dequant_symbolic_kwargs.values())
+ out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values())
if bias is not None:
out = out.add(bias)
if output_quant_symbolic_kwargs is not None:
- out = QuantizeLinearFunction.apply(out, *output_quant_symbolic_kwargs.values())
+ out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values())
return out
-class StdONNXQuantConvNdHandler(StdONNXQuantWBIOLHandler, ABC):
+class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC):
def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
conv_symbolic_kwargs = {
@@ -127,6 +128,7 @@ def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
'weight_zero_point': self.quant_weight_zero_point(module),
'output_scale': module.quant_output_scale(),
'output_zero_point': self.quant_output_zero_point(module),
+ 'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed),
'int_bias': self.int_bias(module),
'out_shape': self.quant_output_shape(module),
'kernel_size': list(module.kernel_size),
@@ -158,19 +160,19 @@ def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
def op_symbolic_execution(self, inp: Tensor):
conv_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs']
- out = QLinearConvFunction.apply(inp, *conv_symbolic_kwargs.values())
+ out = QLinearConvFn.apply(inp, *conv_symbolic_kwargs.values())
return out
def output_symbolic_execution(self, out: Tensor):
output_dequant_symbolic_kwargs = self.symbolic_kwargs['output_dequant_symbolic_kwargs']
if output_dequant_symbolic_kwargs is not None:
- out = DequantizeLinearFunction.apply(out, *output_dequant_symbolic_kwargs.values())
+ out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values())
return out
-class StdONNXQuantConv2dHandler(StdONNXQuantConvNdHandler, Kernel2dApplHandler):
+class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d
-class StdONNXQuantConv1dHandler(StdONNXQuantConvNdHandler, Kernel1dApplHandler):
+class StdQOpONNXQuantConv1dHandler(StdQOpONNXQuantConvNdHandler, Kernel1dApplHandlerMixin):
handled_layer = QuantConv1d
diff --git a/src/brevitas/export/onnx/standard/handler/pool.py b/src/brevitas/export/onnx/standard/qoperator/handler/pool.py
similarity index 82%
rename from src/brevitas/export/onnx/standard/handler/pool.py
rename to src/brevitas/export/onnx/standard/qoperator/handler/pool.py
index b4cddbeed..257e9837b 100644
--- a/src/brevitas/export/onnx/standard/handler/pool.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/pool.py
@@ -8,10 +8,10 @@
from torch.nn.functional import max_pool1d, max_pool2d
from brevitas.nn import QuantMaxPool1d, QuantMaxPool2d
-from .base import StdONNXQuantWrapperHandler
+from .base import StdQOpONNXQuantWrapperHandler
-class StdONNXQuantMaxPoolNd(StdONNXQuantWrapperHandler, ABC):
+class StdQOpONNXQuantMaxPoolNd(StdQOpONNXQuantWrapperHandler, ABC):
@classmethod
@@ -25,7 +25,7 @@ def op_symbolic_kwargs(cls, module: Union[QuantMaxPool1d, QuantMaxPool2d]):
'return_indices': module.return_indices}
-class StdONNXQuantMaxPool1d(StdONNXQuantMaxPoolNd):
+class StdQOpONNXQuantMaxPool1d(StdQOpONNXQuantMaxPoolNd):
handled_layer = QuantMaxPool1d
def op_symbolic_execution(self, inp: Tensor):
@@ -33,7 +33,7 @@ def op_symbolic_execution(self, inp: Tensor):
return max_pool1d(inp, *op_symbolic_kwargs.values())
-class StdONNXQuantMaxPool2d(StdONNXQuantMaxPoolNd):
+class StdQOpONNXQuantMaxPool2d(StdQOpONNXQuantMaxPoolNd):
handled_layer = QuantMaxPool2d
def op_symbolic_execution(self, inp: Tensor):
diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py
new file mode 100644
index 000000000..2a9be4529
--- /dev/null
+++ b/src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -0,0 +1,102 @@
+from typing import Tuple, Optional, Union
+from packaging import version
+
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn import Module
+
+from brevitas import torch_version
+from brevitas.quant_tensor import QuantTensor
+from brevitas.export.onnx.manager import ONNXBaseManager
+from brevitas.export.manager import _set_layer_export_handler, _set_layer_export_mode
+
+from .handler.base import StdQOpONNXQuantLayerHandler
+from .handler.parameter import StdQOpONNXQuantConv2dHandler
+from .handler.parameter import StdQOpONNXQuantConv1dHandler
+from .handler.parameter import StdQOpONNXQuantLinearHandler
+from .handler.act import StdQOpONNXQuantReLUHandler
+from .handler.act import StdQOpONNXQuantHardTanhHandler
+from .handler.act import StdQOpONNXQuantIdentityHandler
+from .handler.act import StdQOpONNXQuantTanhHandler
+from .handler.act import StdQOpONNXQuantSigmoidHandler
+from .handler.pool import StdQOpONNXQuantMaxPool1d
+from .handler.pool import StdQOpONNXQuantMaxPool2d
+from .. import OPSET
+from ..function import QuantizeLinearFn, DequantizeLinearFn
+
+
+class StdQOpONNXManager(ONNXBaseManager):
+ target_name = 'StdQOpONNX'
+
+ _fn_to_cache = [
+ F.relu,
+ F.relu6,
+ F.hardtanh,
+ F.max_pool1d,
+ F.max_pool2d,
+ F.max_pool3d,
+ F.adaptive_max_pool1d,
+ F.adaptive_max_pool2d,
+ F.adaptive_max_pool3d,
+ ]
+
+ handlers = [
+ StdQOpONNXQuantConv1dHandler,
+ StdQOpONNXQuantConv2dHandler,
+ StdQOpONNXQuantLinearHandler,
+ StdQOpONNXQuantReLUHandler,
+ StdQOpONNXQuantHardTanhHandler,
+ StdQOpONNXQuantIdentityHandler,
+ StdQOpONNXQuantTanhHandler,
+ StdQOpONNXQuantSigmoidHandler,
+ StdQOpONNXQuantMaxPool1d,
+ StdQOpONNXQuantMaxPool2d]
+
+ onnx_passes = [
+ # remove unused graph inputs & initializers
+ "eliminate_unused_initializer"]
+
+ @classmethod
+ def solve_enable_onnx_checker(cls, export_kwargs):
+ if torch_version >= version.parse('1.5.0'):
+ export_kwargs['enable_onnx_checker'] = True
+
+ @classmethod
+ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
+ cached_io = cls._fn_cache.pop(0)
+ if cached_io is not None:
+ cached_inp, cached_out = cached_io
+ if cached_inp is not None:
+ deq_kwargs = StdQOpONNXQuantLayerHandler.dequant_symbolic_kwargs_from_cached_io(
+ cached_inp)
+ input = DequantizeLinearFn.apply(input, *deq_kwargs.values())
+ output = fn(input, *args, **kwargs)
+ if cached_out is not None:
+ q_kwargs = StdQOpONNXQuantLayerHandler.quant_symbolic_kwargs_from_cached_io(
+ cached_out)
+ output = QuantizeLinearFn.apply(output, *q_kwargs.values())
+ else:
+ output = fn(input, *args, **kwargs)
+ return output
+
+ @classmethod
+ def set_export_mode(cls, module: Module, enabled: bool):
+ _set_layer_export_mode(module, enabled)
+
+ @classmethod
+ def set_export_handler(cls, module: Module):
+ _set_layer_export_handler(cls, module)
+
+ @classmethod
+ def export_onnx(
+ cls,
+ module: Module,
+ input_shape: Optional[Tuple[int, ...]] = None,
+ export_path: Optional[str] = None,
+ input_t: Optional[Union[Tensor, QuantTensor]] = None,
+ disable_warnings=True,
+ **kwargs):
+ output = super().export_onnx(
+ module, input_shape, export_path, input_t,
+ disable_warnings=disable_warnings, opset_version=OPSET, **kwargs)
+ return output
\ No newline at end of file
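A usage sketch for the new manager via its export entry point (quantizer choices are illustrative; QOp export requires input, weight and output quantization so that QLinearConv can be emitted):

import torch
from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int8Bias
from brevitas.export import export_standard_qop_onnx

model = QuantConv2d(
    3, 8, kernel_size=3,
    input_quant=Int8ActPerTensorFloat,
    output_quant=Int8ActPerTensorFloat,
    bias_quant=Int8Bias)
model(torch.randn(1, 3, 32, 32))  # collect scale factors
model.eval()
export_standard_qop_onnx(
    model, input_shape=(1, 3, 32, 32), export_path='qop_conv.onnx')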
diff --git a/src/brevitas/export/onnx/transform.py b/src/brevitas/export/onnx/transform.py
index c534d9f00..5a581936c 100644
--- a/src/brevitas/export/onnx/transform.py
+++ b/src/brevitas/export/onnx/transform.py
@@ -2,7 +2,7 @@
from torch.nn import Module
-from .base import onnx
+from .manager import onnx
def move_domain_attributes_into_domain(model: Module):
diff --git a/src/brevitas/export/onnx/vitis_ai/handler.py b/src/brevitas/export/onnx/vitis_ai/handler.py
index 142fd5302..507065800 100644
--- a/src/brevitas/export/onnx/vitis_ai/handler.py
+++ b/src/brevitas/export/onnx/vitis_ai/handler.py
@@ -5,55 +5,36 @@
from brevitas.nn.quant_layer import QuantLayerMixin
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
+from brevitas.export.handler import BitWidthHandlerMixin, ScaleHandlerMixin
from brevitas.export.onnx.handler import ONNXBaseHandler
-class DPUQuantLayerHandler(ONNXBaseHandler, ABC):
-
- @staticmethod
- def neg_scalar_exponent_from_scale(scale: Tensor):
- if scale is None:
- raise RuntimeError("Scale cannot be None")
- if scale.view(-1).shape[0] != 1:
- raise RuntimeError("Only per-tensor scaling is currently supported")
- scale = scale.item()
- neg_exponent = - math.log2(scale)
- if not neg_exponent.is_integer():
- raise RuntimeError("Only power-of-two scale factors are supported")
- neg_exponent = int(neg_exponent)
- return neg_exponent
-
- @staticmethod
- def validate_8b_bit_width(bit_width: Tensor):
- if bit_width is None:
- raise RuntimeError("Bit width cannot be None")
- bit_width = int(bit_width.item())
- return bit_width
-
- @staticmethod
- def quant_input_scale(module: QuantLayerMixin):
+class DPUQuantLayerHandler(ONNXBaseHandler, BitWidthHandlerMixin, ScaleHandlerMixin, ABC):
+
+ @classmethod
+ def quant_input_scale(cls, module: QuantLayerMixin):
scale = module.quant_input_scale()
- return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(scale)
+ return cls.validate_neg_scalar_int_exponent(scale)
- @staticmethod
- def quant_output_scale(module: QuantLayerMixin):
+ @classmethod
+ def quant_output_scale(cls, module: QuantLayerMixin):
scale = module.quant_output_scale()
- return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(scale)
+ return cls.validate_neg_scalar_int_exponent(scale)
- @staticmethod
- def quant_input_bit_width(module: QuantLayerMixin):
+ @classmethod
+ def quant_input_bit_width(cls, module: QuantLayerMixin):
bit_width = module.quant_input_bit_width()
- return DPUQuantLayerHandler.validate_8b_bit_width(bit_width)
+ return cls.validate_8b_bit_width(bit_width)
- @staticmethod
- def quant_output_bit_width(module: QuantLayerMixin):
+ @classmethod
+ def quant_output_bit_width(cls, module: QuantLayerMixin):
bit_width = module.quant_output_bit_width()
- return DPUQuantLayerHandler.validate_8b_bit_width(bit_width)
+ return cls.validate_8b_bit_width(bit_width)
- @staticmethod
- def quant_output_shape(module: QuantLayerMixin):
- cached_out = module._cached_out # TODO add shape property to the module
+ @classmethod
+ def quant_output_shape(cls, module: QuantLayerMixin):
+ cached_out = module._cached_out
if cached_out is None:
raise RuntimeError("Caching of outputs is required")
return cached_out.shape
@@ -63,48 +44,48 @@ def prepare_from_cached_io(self, cached_io):
self.symbolic_kwargs = {
'output_shape': cached_out.shape,
'input_bit_width': self.validate_8b_bit_width(cached_inp.bit_width),
- 'input_scale': self.neg_scalar_exponent_from_scale(cached_inp.scale),
+ 'input_scale': self.validate_neg_scalar_int_exponent(cached_inp.scale),
'output_bit_width': self.validate_8b_bit_width(cached_out.bit_width),
- 'output_scale': self.neg_scalar_exponent_from_scale(cached_out.scale)
+ 'output_scale': self.validate_neg_scalar_int_exponent(cached_out.scale)
}
-class DPUQuantWeightBiasHandler(ABC):
+class DPUQuantWBIOLHandler(DPUQuantLayerHandler):
- @staticmethod
- def int_weight(module: QuantWBIOL):
+ @classmethod
+ def int_weight(cls, module: QuantWBIOL):
return module.int_weight(float_datatype=False).detach()
- @staticmethod
- def quant_weight_bit_width(module: QuantWBIOL):
+ @classmethod
+ def quant_weight_bit_width(cls, module: QuantWBIOL):
bit_width = module.quant_weight_bit_width()
- return DPUQuantLayerHandler.validate_8b_bit_width(bit_width)
+ return cls.validate_8b_bit_width(bit_width, le_then=True)
- @staticmethod
- def quant_weight_scale(module: QuantWBIOL):
+ @classmethod
+ def quant_weight_scale(cls, module: QuantWBIOL):
quant_weight_scale = module.quant_weight_scale()
- return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(quant_weight_scale)
+ return cls.validate_neg_scalar_int_exponent(quant_weight_scale)
- @staticmethod
- def int_bias(module: QuantWBIOL):
+ @classmethod
+ def int_bias(cls, module: QuantWBIOL):
if module.bias is not None:
return module.int_bias(float_datatype=False).detach()
else:
return None
- @staticmethod
- def quant_bias_bit_width(module: QuantWBIOL):
+ @classmethod
+ def quant_bias_bit_width(cls, module: QuantWBIOL):
if module.bias is not None:
bit_width = module.quant_bias_bit_width()
- return DPUQuantLayerHandler.validate_8b_bit_width(bit_width)
+ return cls.validate_8b_bit_width(bit_width, le_then=True)
else:
return None
- @staticmethod
- def quant_bias_scale(module: QuantWBIOL):
+ @classmethod
+ def quant_bias_scale(cls, module: QuantWBIOL):
if module.bias is not None:
scale = module.quant_bias_scale()
- return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(scale)
+ return cls.validate_neg_scalar_int_exponent(scale)
else:
return None
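The staticmethod-to-classmethod migration here is what lets the shared mixins take over: calls now dispatch through cls, so a subclass (or a mixin earlier in the MRO) can adjust validation, e.g. accepting sub-8-bit weights via le_then=True. A minimal sketch of that dispatch, assuming le_then means "accept bit widths less than or equal to 8" and using plain ints for brevity:

class BitWidthValidatorSketch:
    # Hypothetical mini-version of BitWidthHandlerMixin.validate_8b_bit_width.

    @classmethod
    def validate_8b_bit_width(cls, bit_width, le_then=False):
        ok = bit_width <= 8 if le_then else bit_width == 8
        if not ok:
            raise RuntimeError(f'Unsupported bit width: {bit_width}')
        return int(bit_width)

class WeightHandlerSketch(BitWidthValidatorSketch):

    @classmethod
    def quant_weight_bit_width(cls, bit_width):
        # cls-dispatch: any override earlier in the MRO wins, unlike the old
        # hardcoded DPUQuantLayerHandler.validate_8b_bit_width(...) calls.
        return cls.validate_8b_bit_width(bit_width, le_then=True)

assert WeightHandlerSketch.quant_weight_bit_width(4) == 4  # sub-8b weights pass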
diff --git a/src/brevitas/export/onnx/vitis_ai/manager.py b/src/brevitas/export/onnx/vitis_ai/manager.py
index 207f13ad9..875f46adc 100644
--- a/src/brevitas/export/onnx/vitis_ai/manager.py
+++ b/src/brevitas/export/onnx/vitis_ai/manager.py
@@ -2,9 +2,9 @@
from torch.nn import functional as F, Module
-from brevitas.export.onnx.base import ONNXBaseManager
+from brevitas.export.onnx.manager import ONNXBaseManager
from brevitas.export.onnx.transform import move_domain_attributes_into_domain
-from brevitas.export.base import _set_layer_export_handler, _set_layer_export_mode
+from brevitas.export.manager import _set_layer_export_handler, _set_layer_export_mode
def _handler_wrapper(handler, cached_io):
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/function.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/function.py
deleted file mode 100644
index aee0d0e5f..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/function.py
+++ /dev/null
@@ -1,89 +0,0 @@
-from ..function import DPUQuantMaxPoolPlaceholderFunction
-from ..function import DPUQuantConv2dPlaceholderFunction
-
-
-class DPUv1QuantConv2dPlaceholderFunction(DPUQuantConv2dPlaceholderFunction):
-
- @staticmethod
- def symbolic(
- g, x,
- int_weight,
- int_bias,
- out_shape,
- input_bit_width,
- input_scale,
- output_bit_width,
- output_scale,
- weight_bit_width,
- weight_scale,
- bias_bit_width,
- bias_scale,
- kernel_size,
- padding,
- stride,
- groups,
- dilation):
- vai_quant_s = ['vai_quant_in', 'vai_quant_out', 'vai_quant_weights']
- if int_bias is not None:
- vai_quant_s += ['vai_quant_biases']
- ret = g.op(
- 'Conv', x,
- int_weight,
- int_bias,
- domain_s="pyxir",
- vai_quant_s=vai_quant_s,
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale],
- vai_quant_weights_i=[weight_bit_width, weight_scale],
- vai_quant_biases_i=[bias_bit_width, bias_scale],
- kernel_shape_i=kernel_size,
- pads_i=padding,
- strides_i=stride,
- group_i=groups,
- dilations_i=dilation)
- else:
- ret = g.op(
- 'Conv', x,
- int_weight,
- domain_s="pyxir",
- vai_quant_s=vai_quant_s,
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale],
- vai_quant_weights_i=[weight_bit_width, weight_scale],
- kernel_shape_i=kernel_size,
- pads_i=padding,
- strides_i=stride,
- group_i=groups,
- dilations_i=dilation)
- return ret
-
-
-class DPUv1QuantMaxPoolPlaceholderFunction(DPUQuantMaxPoolPlaceholderFunction):
-
- @staticmethod
- def symbolic(
- g, x,
- kernel_shape,
- pads,
- strides,
- ceil_mode,
- dilations,
- out_shape,
- input_bit_width,
- input_scale,
- output_bit_width,
- output_scale):
- ret = g.op(
- 'MaxPool', x,
- domain_s="pyxir",
- kernel_shape_i=kernel_shape,
- pads_i=pads,
- strides_i=strides,
- dilations_i=dilations,
- ceil_mode_i=ceil_mode,
- vai_quant_s=['vai_quant_in', 'vai_quant_out'],
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale])
- return ret
-
-
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/handler.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/handler.py
deleted file mode 100644
index ac72eb7c4..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/handler.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from torch import Tensor
-
-from ..handler import DPUQuantConv2dHandler
-from ..handler import DPUQuantMaxPool2dHandler
-from .function import DPUv1QuantMaxPoolPlaceholderFunction
-from .function import DPUv1QuantConv2dPlaceholderFunction
-
-
-class DPUv1QuantMaxPool2dHandler(DPUQuantMaxPool2dHandler):
-
- def symbolic_execution(self, inp: Tensor):
- ret = DPUv1QuantMaxPoolPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
- return ret
-
- def cached_symbolic_execution(self, inp: Tensor, *args, **kwargs):
- solved_kwargs = self._solve_max_pool2d_kwargs(inp, args, kwargs)
- return DPUv1QuantMaxPoolPlaceholderFunction.apply(
- *solved_kwargs.values(), *self.symbolic_kwargs.values())
-
-
-class DPUv1QuantConv2dHandler(DPUQuantConv2dHandler):
-
- def symbolic_execution(self, inp: Tensor):
- ret = DPUv1QuantConv2dPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
- return ret
-
-
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/manager.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/manager.py
deleted file mode 100644
index a74092989..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv1/manager.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from torch.nn import functional as F
-from functools import partial
-
-from .handler import DPUv1QuantConv2dHandler, DPUv1QuantMaxPool2dHandler
-from ..handler import DPUQuantReLUHandler, DPUQuantEltwiseAddHandler
-from ..handler import DPUQuantAvgPool2dHandler, DPUQuantLinearHandler
-from ..manager import PyXIRManager, _handler_wrapper
-
-
-class DPUv1Manager(PyXIRManager):
- target_name = 'PyXIR+DPUv1'
-
- handlers = [
- DPUQuantReLUHandler,
- DPUQuantEltwiseAddHandler,
- DPUQuantAvgPool2dHandler,
- DPUQuantLinearHandler,
- DPUv1QuantConv2dHandler,
- DPUv1QuantMaxPool2dHandler]
-
- _cached_io_handler_map = {
- F.relu: partial(_handler_wrapper, DPUQuantReLUHandler),
- F.max_pool2d: partial(_handler_wrapper, DPUv1QuantMaxPool2dHandler)}
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/__init__.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/function.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/function.py
deleted file mode 100644
index 960c3b518..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/function.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from ..function import DPUQuantMaxPoolPlaceholderFunction
-from ..function import DPUQuantConv2dPlaceholderFunction
-
-
-class DPUv2QuantConv2dPlaceholderFunction(DPUQuantConv2dPlaceholderFunction):
-
- @staticmethod
- def symbolic(
- g, x,
- int_weight,
- int_bias,
- out_shape,
- input_bit_width,
- input_scale,
- output_bit_width,
- output_scale,
- weight_bit_width,
- weight_scale,
- bias_bit_width,
- bias_scale,
- kernel_size,
- padding,
- stride,
- groups,
- dilation):
- if ((isinstance(padding, int) and padding != 0)
- or (isinstance(padding, (list, tuple)) and any([p != 0 for p in padding]))):
- x = g.op(
- 'Pad', x,
- domain_s="pyxir",
- vai_quant_s=['vai_quant_in', 'vai_quant_out'],
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[input_bit_width, input_scale],
- pads_i=padding)
- vai_quant_s = ['vai_quant_in', 'vai_quant_out', 'vai_quant_weights']
- if int_bias is not None:
- vai_quant_s += ['vai_quant_biases']
- ret = g.op(
- 'Conv', x,
- int_weight,
- int_bias,
- domain_s="pyxir",
- vai_quant_s=vai_quant_s,
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale],
- vai_quant_weights_i=[weight_bit_width, weight_scale],
- vai_quant_biases_i=[bias_bit_width, bias_scale],
- kernel_shape_i=kernel_size,
- strides_i=stride,
- auto_pad_s='VALID',
- group_i=groups,
- dilations_i=dilation)
- else:
- ret = g.op(
- 'Conv', x,
- int_weight,
- domain_s="pyxir",
- vai_quant_s=vai_quant_s,
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale],
- vai_quant_weights_i=[weight_bit_width, weight_scale],
- kernel_shape_i=kernel_size,
- strides_i=stride,
- auto_pad_s='VALID',
- group_i=groups,
- dilations_i=dilation)
- return ret
-
-
-class DPUv2QuantMaxPoolPlaceholderFunction(DPUQuantMaxPoolPlaceholderFunction):
-
- @staticmethod
- def symbolic(
- g, x,
- kernel_shape,
- pads,
- strides,
- ceil_mode,
- dilations,
- out_shape,
- input_bit_width,
- input_scale,
- output_bit_width,
- output_scale):
- if ((isinstance(pads, int) and pads != 0)
- or (isinstance(pads, (list, tuple)) and any([p != 0 for p in pads]))):
- x = g.op(
- 'Pad', x,
- domain_s="pyxir",
- vai_quant_s=['vai_quant_in', 'vai_quant_out'],
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[input_bit_width, input_scale],
- pads_i=pads)
- ret = g.op(
- 'MaxPool', x,
- domain_s="pyxir",
- kernel_shape_i=kernel_shape,
- strides_i=strides,
- auto_pad_s='VALID',
- dilations_i=dilations,
- ceil_mode_i=ceil_mode,
- vai_quant_s=['vai_quant_in', 'vai_quant_out'],
- vai_quant_in_i=[input_bit_width, input_scale],
- vai_quant_out_i=[output_bit_width, output_scale])
- return ret
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/handler.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/handler.py
deleted file mode 100644
index 61003a5e7..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/handler.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from torch import Tensor
-
-from ..handler import DPUQuantConv2dHandler
-from ..handler import DPUQuantMaxPool2dHandler
-from .function import DPUv2QuantConv2dPlaceholderFunction
-from .function import DPUv2QuantMaxPoolPlaceholderFunction
-
-
-class DPUv2QuantMaxPool2dHandler(DPUQuantMaxPool2dHandler):
-
- def symbolic_execution(self, inp: Tensor):
- ret = DPUv2QuantMaxPoolPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
- return ret
-
- def cached_symbolic_execution(self, inp: Tensor, *args, **kwargs):
- solved_kwargs = self._solve_max_pool2d_kwargs(inp, args, kwargs)
- return DPUv2QuantMaxPoolPlaceholderFunction.apply(
- *solved_kwargs.values(), *self.symbolic_kwargs.values())
-
-
-class DPUv2QuantConv2dHandler(DPUQuantConv2dHandler):
-
- def symbolic_execution(self, inp: Tensor):
- ret = DPUv2QuantConv2dPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
- return ret
-
-
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/manager.py b/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/manager.py
deleted file mode 100644
index 50d1ce0d7..000000000
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/dpuv2/manager.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from functools import partial
-
-from torch.nn import functional as F
-
-from .handler import DPUv2QuantConv2dHandler, DPUv2QuantMaxPool2dHandler
-from ..handler import DPUQuantReLUHandler, DPUQuantEltwiseAddHandler
-from ..handler import DPUQuantAvgPool2dHandler, DPUQuantLinearHandler
-from ..manager import PyXIRManager, _handler_wrapper
-
-
-class DPUv2Manager(PyXIRManager):
- target_name = 'PyXIR+DPUv2'
-
- handlers = [
- DPUQuantReLUHandler,
- DPUQuantEltwiseAddHandler,
- DPUQuantAvgPool2dHandler,
- DPUQuantLinearHandler,
- DPUv2QuantConv2dHandler,
- DPUv2QuantMaxPool2dHandler]
-
- _cached_io_handler_map = {
- F.relu: partial(_handler_wrapper, DPUQuantReLUHandler),
- F.max_pool2d: partial(_handler_wrapper, DPUv2QuantMaxPool2dHandler)}
\ No newline at end of file
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/function.py b/src/brevitas/export/onnx/vitis_ai/pyxir/function.py
index 2f07e5bd7..cd244e8c1 100644
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/function.py
+++ b/src/brevitas/export/onnx/vitis_ai/pyxir/function.py
@@ -7,7 +7,10 @@
from brevitas import torch_version
-class DPUQuantReLUPlaceholderFunction(Function):
+DOMAIN_STRING = 'onnx.pyxir'
+
+
+class DPUQuantReLUFn(Function):
@staticmethod
def symbolic(
@@ -19,7 +22,7 @@ def symbolic(
output_scale):
ret = g.op(
'Relu', x,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
vai_quant_s=['vai_quant_in', 'vai_quant_out'],
vai_quant_in_i=[input_bit_width, input_scale],
vai_quant_out_i=[output_bit_width, output_scale])
@@ -36,7 +39,7 @@ def forward(
return x
-class DPUQuantAvgPoolPlaceholderFunction(Function):
+class DPUQuantAvgPoolFn(Function):
@staticmethod
def symbolic(
@@ -54,14 +57,14 @@ def symbolic(
if list(out_shape[2:]) == [1, 1]:
ret = g.op(
'GlobalAveragePool', x,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
vai_quant_s=['vai_quant_in', 'vai_quant_out'],
vai_quant_in_i=[input_bit_width, input_scale],
vai_quant_out_i=[output_bit_width, output_scale])
else:
ret = g.op(
'AveragePool', x,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
kernel_shape_i=kernel_shape,
strides_i=strides,
pads_i=pads,
@@ -84,7 +87,7 @@ def forward(
return torch.empty(out_shape, dtype=torch.float, device=x.device)
-class DPUQuantEltwiseAddPlaceholderFunction(Function):
+class DPUQuantEltwiseAddFn(Function):
@staticmethod
def symbolic(
@@ -97,7 +100,7 @@ def symbolic(
output_scale):
ret = g.op(
'Add', x, y,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
vai_quant_s=['vai_quant_in', 'vai_quant_out'],
vai_quant_in_i=[input_bit_width, input_scale, other_bit_width, other_scale],
vai_quant_out_i=[output_bit_width, output_scale])
@@ -115,12 +118,11 @@ def forward(
return x
-class DPUQuantMaxPoolPlaceholderFunction(Function):
+class DPUQuantMaxPoolFn(Function):
@staticmethod
- @abstractmethod
def symbolic(
- ctx, x,
+ g, x,
kernel_shape,
pads,
strides,
@@ -131,7 +133,27 @@ def symbolic(
input_scale,
output_bit_width,
output_scale):
- pass
+ if ((isinstance(pads, int) and pads != 0)
+ or (isinstance(pads, (list, tuple)) and any([p != 0 for p in pads]))):
+ x = g.op(
+ 'Pad', x,
+ domain_s=DOMAIN_STRING,
+ vai_quant_s=['vai_quant_in', 'vai_quant_out'],
+ vai_quant_in_i=[input_bit_width, input_scale],
+ vai_quant_out_i=[input_bit_width, input_scale],
+ pads_i=pads)
+ ret = g.op(
+ 'MaxPool', x,
+ domain_s=DOMAIN_STRING,
+ kernel_shape_i=kernel_shape,
+ strides_i=strides,
+ auto_pad_s='VALID',
+ dilations_i=dilations,
+ ceil_mode_i=ceil_mode,
+ vai_quant_s=['vai_quant_in', 'vai_quant_out'],
+ vai_quant_in_i=[input_bit_width, input_scale],
+ vai_quant_out_i=[output_bit_width, output_scale])
+ return ret
@staticmethod
def forward(
@@ -149,10 +171,9 @@ def forward(
return torch.empty(out_shape, dtype=torch.float, device=x.device)
-class DPUQuantConv2dPlaceholderFunction(Function):
+class DPUQuantConv2dFn(Function):
@staticmethod
- @abstractmethod
def symbolic(
g, x,
int_weight,
@@ -171,7 +192,48 @@ def symbolic(
stride,
groups,
dilation):
- pass
+ if ((isinstance(padding, int) and padding != 0)
+ or (isinstance(padding, (list, tuple)) and any([p != 0 for p in padding]))):
+ x = g.op(
+ 'Pad', x,
+ domain_s=DOMAIN_STRING,
+ vai_quant_s=['vai_quant_in', 'vai_quant_out'],
+ vai_quant_in_i=[input_bit_width, input_scale],
+ vai_quant_out_i=[input_bit_width, input_scale],
+ pads_i=padding)
+ vai_quant_s = ['vai_quant_in', 'vai_quant_out', 'vai_quant_weights']
+ if int_bias is not None:
+ vai_quant_s += ['vai_quant_biases']
+ ret = g.op(
+ 'Conv', x,
+ int_weight,
+ int_bias,
+ domain_s=DOMAIN_STRING,
+ vai_quant_s=vai_quant_s,
+ vai_quant_in_i=[input_bit_width, input_scale],
+ vai_quant_out_i=[output_bit_width, output_scale],
+ vai_quant_weights_i=[weight_bit_width, weight_scale],
+ vai_quant_biases_i=[bias_bit_width, bias_scale],
+ kernel_shape_i=kernel_size,
+ strides_i=stride,
+ auto_pad_s='VALID',
+ group_i=groups,
+ dilations_i=dilation)
+ else:
+ ret = g.op(
+ 'Conv', x,
+ int_weight,
+ domain_s=DOMAIN_STRING,
+ vai_quant_s=vai_quant_s,
+ vai_quant_in_i=[input_bit_width, input_scale],
+ vai_quant_out_i=[output_bit_width, output_scale],
+ vai_quant_weights_i=[weight_bit_width, weight_scale],
+ kernel_shape_i=kernel_size,
+ strides_i=stride,
+ auto_pad_s='VALID',
+ group_i=groups,
+ dilations_i=dilation)
+ return ret
@staticmethod
def forward(
@@ -195,7 +257,7 @@ def forward(
return torch.empty(out_shape, dtype=torch.float, device=x.device)
-class DPUQuantLinearPlaceholderFunction(Function):
+class DPUQuantLinearFn(Function):
@staticmethod
def symbolic(
@@ -218,7 +280,7 @@ def symbolic(
'Gemm', x,
int_weight,
int_bias,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
transB_i=1,
vai_quant_s=vai_quant_s,
vai_quant_in_i=[input_bit_width, input_scale],
@@ -230,7 +292,7 @@ def symbolic(
'Gemm', x,
int_weight,
torch.tensor(0), # workaround
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
transB_i=1,
vai_quant_s=vai_quant_s,
vai_quant_in_i=[input_bit_width, input_scale],
@@ -240,7 +302,7 @@ def symbolic(
ret = g.op(
'Gemm', x,
int_weight,
- domain_s="pyxir",
+ domain_s=DOMAIN_STRING,
transB_i=1,
vai_quant_s=vai_quant_s,
vai_quant_in_i=[input_bit_width, input_scale],
@@ -263,4 +325,5 @@ def forward(
weight_scale,
int_bias_bit_width,
int_bias_scale):
- return torch.empty(out_shape, dtype=torch.float, device=x.device)
\ No newline at end of file
+ return torch.empty(out_shape, dtype=torch.float, device=x.device)
+
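The inlined symbolic bodies share one padding convention: any non-zero padding becomes an explicit pyxir Pad node carrying the input quantization annotations, and the Conv/MaxPool itself is then emitted with auto_pad='VALID'. A small sketch of the predicate that gates the Pad node, as used in DPUQuantConv2dFn and DPUQuantMaxPoolFn:

def has_nonzero_padding(padding):
    # padding may arrive as a scalar or as a list/tuple of per-edge pads
    if isinstance(padding, int):
        return padding != 0
    if isinstance(padding, (list, tuple)):
        return any(p != 0 for p in padding)
    return False

assert has_nonzero_padding([0, 0, 1, 1])  # Pad node emitted, then VALID conv
assert not has_nonzero_padding(0)         # conv emitted directly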
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/handler.py b/src/brevitas/export/onnx/vitis_ai/pyxir/handler.py
index b162197e1..fa0253767 100644
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/handler.py
+++ b/src/brevitas/export/onnx/vitis_ai/pyxir/handler.py
@@ -7,10 +7,11 @@
from brevitas.nn import QuantConv2d, QuantReLU, QuantEltwiseAdd, QuantMaxPool2d, QuantLinear
from brevitas.nn import QuantAdaptiveAvgPool2d, QuantAvgPool2d
-from brevitas.export.onnx.handler import Kernel2dApplHandler
-from brevitas.export.onnx.vitis_ai.handler import DPUQuantLayerHandler, DPUQuantWeightBiasHandler
-from .function import DPUQuantReLUPlaceholderFunction, DPUQuantEltwiseAddPlaceholderFunction
-from .function import DPUQuantAvgPoolPlaceholderFunction, DPUQuantLinearPlaceholderFunction
+from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
+from brevitas.export.onnx.vitis_ai.handler import DPUQuantLayerHandler, DPUQuantWBIOLHandler
+from .function import DPUQuantReLUFn, DPUQuantEltwiseAddFn
+from .function import DPUQuantAvgPoolFn, DPUQuantLinearFn
+from .function import DPUQuantConv2dFn, DPUQuantMaxPoolFn
class DPUQuantReLUHandler(DPUQuantLayerHandler):
@@ -25,14 +26,14 @@ def prepare_for_export(self, module: QuantReLU):
'output_scale': self.quant_output_scale(module)}
def symbolic_execution(self, inp: Tensor):
- ret = DPUQuantReLUPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = DPUQuantReLUFn.apply(inp, *self.symbolic_kwargs.values())
return ret
def cached_symbolic_execution(self, inp: Tensor, *args, **kwargs):
kwargs.update(self.symbolic_kwargs)
if 'inplace' in kwargs:
del kwargs['inplace']
- return DPUQuantReLUPlaceholderFunction.apply(inp, *args, *kwargs.values())
+ return DPUQuantReLUFn.apply(inp, *args, *kwargs.values())
class DPUQuantEltwiseAddHandler(DPUQuantLayerHandler):
@@ -48,12 +49,11 @@ def prepare_for_export(self, module: QuantEltwiseAdd):
'output_scale': self.quant_output_scale(module)}
def symbolic_execution(self, inp: Tensor, other: Tensor):
- ret = DPUQuantEltwiseAddPlaceholderFunction.apply(
- inp, other, *self.symbolic_kwargs.values())
+ ret = DPUQuantEltwiseAddFn.apply(inp, other, *self.symbolic_kwargs.values())
return ret
-class DPUQuantMaxPool2dHandler(DPUQuantLayerHandler, Kernel2dApplHandler, ABC):
+class DPUQuantMaxPool2dHandler(DPUQuantLayerHandler, Kernel2dApplHandlerMixin, ABC):
handled_layer = QuantMaxPool2d
@staticmethod
@@ -79,8 +79,16 @@ def prepare_for_export(self, module: QuantMaxPool2d):
'output_bit_width': self.quant_output_bit_width(module),
'output_scale': self.quant_output_scale(module)}
+ def symbolic_execution(self, inp: Tensor):
+ ret = DPUQuantMaxPoolFn.apply(inp, *self.symbolic_kwargs.values())
+ return ret
+
+ def cached_symbolic_execution(self, inp: Tensor, *args, **kwargs):
+ solved_kwargs = self._solve_max_pool2d_kwargs(inp, args, kwargs)
+ return DPUQuantMaxPoolFn.apply(*solved_kwargs.values(), *self.symbolic_kwargs.values())
+
-class DPUQuantAvgPool2dHandler(DPUQuantLayerHandler, Kernel2dApplHandler):
+class DPUQuantAvgPool2dHandler(DPUQuantLayerHandler, Kernel2dApplHandlerMixin):
handled_layer = (QuantAvgPool2d, QuantAdaptiveAvgPool2d)
def prepare_for_export(self, module: Union[QuantAvgPool2d, QuantAdaptiveAvgPool2d]):
@@ -95,14 +103,14 @@ def prepare_for_export(self, module: Union[QuantAvgPool2d, QuantAdaptiveAvgPool2
'output_scale': self.quant_output_scale(module)}
def symbolic_execution(self, inp: Tensor):
- ret = DPUQuantAvgPoolPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = DPUQuantAvgPoolFn.apply(inp, *self.symbolic_kwargs.values())
return ret
-class DPUQuantLinearHandler(DPUQuantLayerHandler, DPUQuantWeightBiasHandler):
+class DPUQuantLinearHandler(DPUQuantWBIOLHandler):
handled_layer = QuantLinear
- def prepare_for_export(self, module: QuantAdaptiveAvgPool2d):
+ def prepare_for_export(self, module: QuantLinear):
self.symbolic_kwargs = {
'int_weight': self.int_weight(module),
'int_bias': self.int_bias(module),
@@ -117,15 +125,11 @@ def prepare_for_export(self, module: QuantAdaptiveAvgPool2d):
'bias_scale': self.quant_bias_scale(module)}
def symbolic_execution(self, inp: Tensor):
- ret = DPUQuantLinearPlaceholderFunction.apply(inp, *self.symbolic_kwargs.values())
+ ret = DPUQuantLinearFn.apply(inp, *self.symbolic_kwargs.values())
return ret
-class DPUQuantConv2dHandler(
- DPUQuantLayerHandler,
- DPUQuantWeightBiasHandler,
- Kernel2dApplHandler,
- ABC):
+class DPUQuantConv2dHandler(DPUQuantWBIOLHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d
def prepare_for_export(self, module):
@@ -147,5 +151,11 @@ def prepare_for_export(self, module):
'groups': module.groups,
'dilation': self.dilation(module)}
+ def symbolic_execution(self, inp: Tensor):
+ ret = DPUQuantConv2dFn.apply(inp, *self.symbolic_kwargs.values())
+ return ret
+
+
+
diff --git a/src/brevitas/export/onnx/vitis_ai/pyxir/manager.py b/src/brevitas/export/onnx/vitis_ai/pyxir/manager.py
index ca4734a52..e3868033e 100644
--- a/src/brevitas/export/onnx/vitis_ai/pyxir/manager.py
+++ b/src/brevitas/export/onnx/vitis_ai/pyxir/manager.py
@@ -1,8 +1,13 @@
from abc import ABC
+from functools import partial
+import torch.nn.functional as F
from brevitas.export.onnx.vitis_ai import VitisAIManager
from brevitas.export.onnx.transform import move_domain_attributes_into_domain
+from .handler import DPUQuantConv2dHandler, DPUQuantMaxPool2dHandler
+from .handler import DPUQuantReLUHandler, DPUQuantEltwiseAddHandler
+from .handler import DPUQuantAvgPool2dHandler, DPUQuantLinearHandler
def _handler_wrapper(handler, cached_io):
@@ -21,4 +26,18 @@ class PyXIRManager(VitisAIManager, ABC):
# use initializers instead of Constant nodes for fixed params
"extract_constant_to_initializer",
# remove unused graph inputs & initializers
- "eliminate_unused_initializer"]
\ No newline at end of file
+ "eliminate_unused_initializer"]
+
+ handlers = [
+ DPUQuantReLUHandler,
+ DPUQuantEltwiseAddHandler,
+ DPUQuantAvgPool2dHandler,
+ DPUQuantLinearHandler,
+ DPUQuantConv2dHandler,
+ DPUQuantMaxPool2dHandler]
+
+ _cached_io_handler_map = {
+ F.relu: partial(_handler_wrapper, DPUQuantReLUHandler),
+ F.max_pool2d: partial(_handler_wrapper, DPUQuantMaxPool2dHandler)}
+
+
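With the DPUv1/DPUv2 split removed, PyXIRManager now carries the handlers directly; a usage sketch in the spirit of the updated tests further down (quantizer choices are illustrative, and DPU export still requires power-of-two, i.e. fixed-point, scales):

import torch
from brevitas.nn import QuantConv2d, QuantIdentity
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.export import PyXIRManager

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.quant_inp = QuantIdentity(
            act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True)
        self.conv = QuantConv2d(
            50, 40, kernel_size=3, bias=False,
            weight_quant=Int8WeightPerTensorFixedPoint,
            output_quant=Int8ActPerTensorFixedPoint)

    def forward(self, x):
        return self.conv(self.quant_inp(x))

model = Model()
model(torch.randn(1, 50, 7, 7))  # accumulate scale factors
model.eval()
PyXIRManager.export(model, input_shape=(1, 50, 7, 7), export_path='pyxir_conv.onnx')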
diff --git a/src/brevitas/export/onnx/vitis_ai/xir/function.py b/src/brevitas/export/onnx/vitis_ai/xir/function.py
index 62da8d191..c24c585b1 100644
--- a/src/brevitas/export/onnx/vitis_ai/xir/function.py
+++ b/src/brevitas/export/onnx/vitis_ai/xir/function.py
@@ -6,7 +6,7 @@
from brevitas import torch_version
-class XIRFixPlaceholderFunction(Function):
+class XIRFixFn(Function):
@staticmethod
def symbolic(g, x, bit_width, fix_point, signed):
@@ -23,7 +23,7 @@ def forward(ctx, x, bit_width, fix_point, signed):
return x
-class XIRGemmPlaceholderFunction(Function):
+class XIRGemmFn(Function):
@staticmethod
def symbolic(g, x, weight, bias):
@@ -40,7 +40,7 @@ def forward(ctx, x, weight, bias):
return torch.nn.functional.linear(x, weight, bias)
-class XIRConv2dPlaceholderFunction(Function):
+class XIRConv2dFn(Function):
@staticmethod
def symbolic(
@@ -91,7 +91,7 @@ def forward(
return torch.empty(output_shape, dtype=x.dtype, device=x.device)
-class XIRConvTranpose2dPlaceholderFunction(Function):
+class XIRConvTranpose2dFn(Function):
@staticmethod
def symbolic(
diff --git a/src/brevitas/export/onnx/vitis_ai/xir/handler.py b/src/brevitas/export/onnx/vitis_ai/xir/handler.py
index 5e9dbfe10..8a61b5774 100644
--- a/src/brevitas/export/onnx/vitis_ai/xir/handler.py
+++ b/src/brevitas/export/onnx/vitis_ai/xir/handler.py
@@ -7,10 +7,10 @@
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL
from brevitas.nn import QuantIdentity, QuantReLU
from brevitas.nn import QuantConvTranspose2d, QuantConv2d, QuantLinear
-from brevitas.export.onnx.handler import Kernel2dApplHandler
-from ..handler import DPUQuantWeightBiasHandler, DPUQuantLayerHandler
-from .function import XIRFixPlaceholderFunction, XIRGemmPlaceholderFunction
-from .function import XIRConv2dPlaceholderFunction, XIRConvTranpose2dPlaceholderFunction
+from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin
+from ..handler import DPUQuantWBIOLHandler, DPUQuantLayerHandler
+from .function import XIRFixFn, XIRGemmFn
+from .function import XIRConv2dFn, XIRConvTranpose2dFn
class XIRQuantActHandler(DPUQuantLayerHandler, ABC):
@@ -48,7 +48,7 @@ def symbolic_execution(self, x: Tensor):
if self.act_impl is not None:
x = self.act_impl(x, *act_kwargs.values())
if act_quant_kwargs is not None:
- x = XIRFixPlaceholderFunction.apply(x, *act_quant_kwargs.values())
+ x = XIRFixFn.apply(x, *act_quant_kwargs.values())
return x
@@ -74,7 +74,7 @@ def act_symbolic_kwargs(cls, module: QuantReLU):
return {}
-class XIRQuantWBIOLHandler(DPUQuantLayerHandler, DPUQuantWeightBiasHandler, ABC):
+class XIRQuantWBIOLHandler(DPUQuantWBIOLHandler, ABC):
@property
@abstractmethod
@@ -159,23 +159,23 @@ def symbolic_execution(self, x):
output_quant_kwargs = self.symbolic_kwargs['output_quant']
op_kwargs = self.symbolic_kwargs['op']
if input_quant_kwargs is not None:
- x = XIRFixPlaceholderFunction.apply(x, *input_quant_kwargs.values())
+ x = XIRFixFn.apply(x, *input_quant_kwargs.values())
if weight_quant_kwargs is not None:
- weight = XIRFixPlaceholderFunction.apply(weight, *weight_quant_kwargs.values())
+ weight = XIRFixFn.apply(weight, *weight_quant_kwargs.values())
if bias is not None and bias_quant_kwargs is not None:
- bias = XIRFixPlaceholderFunction.apply(bias, *bias_quant_kwargs.values())
+ bias = XIRFixFn.apply(bias, *bias_quant_kwargs.values())
out = self.op_impl(x, weight, bias, *op_kwargs.values())
if output_quant_kwargs is not None:
- out = XIRFixPlaceholderFunction.apply(out, *output_quant_kwargs.values())
+ out = XIRFixFn.apply(out, *output_quant_kwargs.values())
return out
-class XIRQuantConv2dHandler(XIRQuantWBIOLHandler, Kernel2dApplHandler):
+class XIRQuantConv2dHandler(XIRQuantWBIOLHandler, Kernel2dApplHandlerMixin):
handled_layer = QuantConv2d
@property
def op_impl(self):
- return XIRConv2dPlaceholderFunction.apply
+ return XIRConv2dFn.apply
@classmethod
def op_symbolic_kwargs(cls, module: QuantConv2d):
@@ -195,7 +195,7 @@ class XIRQuantConvTranspose2dHandler(XIRQuantWBIOLHandler):
@property
def op_impl(self):
- return XIRConvTranpose2dPlaceholderFunction.apply
+ return XIRConvTranpose2dFn.apply
@classmethod
def op_symbolic_kwargs(cls, module: QuantConvTranspose2d):
@@ -214,7 +214,7 @@ class XIRQuantLinearHandler(XIRQuantWBIOLHandler):
@property
def op_impl(self):
- return XIRGemmPlaceholderFunction.apply
+ return XIRGemmFn.apply
@classmethod
def op_symbolic_kwargs(cls, module: QuantLinear):
diff --git a/src/brevitas/export/pytorch/handler/base.py b/src/brevitas/export/pytorch/handler/base.py
index f445b19a8..66cd6d138 100644
--- a/src/brevitas/export/pytorch/handler/base.py
+++ b/src/brevitas/export/pytorch/handler/base.py
@@ -2,7 +2,7 @@
import torch
from torch import Tensor
-from brevitas.export.common.handler import Validate8BitHandler, TypedZeroPointHandler
+from brevitas.export.handler import BaseHandler, BitWidthHandlerMixin, ZeroPointHandlerMixin
SCALAR_SHAPE = ()
@@ -11,7 +11,7 @@ def _is_scalar(x: Tensor):
return x.shape == SCALAR_SHAPE
-class PytorchQuantLayerHandler(Validate8BitHandler, TypedZeroPointHandler, ABC):
+class PytorchQuantLayerHandler(BaseHandler, BitWidthHandlerMixin, ZeroPointHandlerMixin, ABC):
@classmethod
@abstractmethod
diff --git a/src/brevitas/export/pytorch/handler/parameter.py b/src/brevitas/export/pytorch/handler/parameter.py
index bf1da5cae..5ae8b48d8 100644
--- a/src/brevitas/export/pytorch/handler/parameter.py
+++ b/src/brevitas/export/pytorch/handler/parameter.py
@@ -9,7 +9,6 @@
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from .base import PytorchQuantLayerHandler
-from . import qF
class PytorchQuantWBIOLHandler(PytorchQuantLayerHandler):
@@ -47,7 +46,7 @@ def prepare_bias(cls, module: QuantWBIOL):
@classmethod
def prepare_weight_quant(cls, module: QuantWBIOL):
- cls.validate_8b_bit_width(module.quant_weight_bit_width())
+ cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
scale = module.quant_weight_scale()
zero_point = cls.quant_weight_zero_point(module)
signed = module.is_quant_weight_signed
diff --git a/src/brevitas/export/pytorch/manager.py b/src/brevitas/export/pytorch/manager.py
index dfdc22f7c..78e68a5f5 100644
--- a/src/brevitas/export/pytorch/manager.py
+++ b/src/brevitas/export/pytorch/manager.py
@@ -5,8 +5,8 @@
from torch.nn import Module
from brevitas.quant_tensor import QuantTensor
-from brevitas.export.base import BaseManager, ExportContext
-from brevitas.export.base import _set_layer_export_handler, _set_layer_export_mode
+from brevitas.export.manager import BaseManager, ExportContext
+from brevitas.export.manager import _set_layer_export_handler, _set_layer_export_mode
from .handler.parameter import PytorchQuantConv2dHandler
from .handler.parameter import PytorchQuantConv1dHandler
from .handler.parameter import PytorchQuantLinearHandler
diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index f22a194b3..0bdde9b9a 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -37,7 +37,7 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
-
+import torch.jit
from torch import Tensor
from warnings import warn
@@ -223,6 +223,13 @@ def quant_output_bit_width(self):
def unpack_input(self, inp: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(True)
+ # Hack to recognize a QuantTensor that has decayed to a tuple
+ # when used as input to tracing (e.g. during ONNX export)
+ if (torch._C._get_tracing_state() is not None
+ and isinstance(inp, tuple)
+ and len(inp) == len(QuantTensor._fields)
+ and all([isinstance(t, Tensor) for t in inp])):
+ inp = QuantTensor(*inp)
if isinstance(inp, QuantTensor):
# don't cache values during export pass
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
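Context for the unpack_input guard above: QuantTensor is a NamedTuple, and jit tracing (which ONNX export runs on) flattens NamedTuple inputs into plain tuples of Tensors; the guard detects that decay by arity and element type and rebuilds the QuantTensor. A small sketch of the decay-and-rebuild, reproduced outside of tracing:

import torch
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor

q = QuantIdentity(return_quant_tensor=True)
qt = q(torch.randn(4))            # a QuantTensor (a NamedTuple)
decayed = tuple(qt)               # what tracing hands to the next layer
assert not isinstance(decayed, QuantTensor)
assert len(decayed) == len(QuantTensor._fields)
rebuilt = QuantTensor(*decayed)   # the guard's reconstruction
assert isinstance(rebuilt, QuantTensor)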
diff --git a/src/brevitas/onnx/__init__.py b/src/brevitas/onnx/__init__.py
index 1b0bbe55b..7ac536daf 100644
--- a/src/brevitas/onnx/__init__.py
+++ b/src/brevitas/onnx/__init__.py
@@ -1,4 +1,7 @@
# for retrocompatibility
from brevitas.export import export_finn_onnx # noqa
+from brevitas.export import export_brevitas_onnx # noqa
+from brevitas.export import export_pyxir_onnx # noqa
+from brevitas.export import export_standard_qop_onnx # noqa
from brevitas.export import enable_debug # noqa
\ No newline at end of file
diff --git a/src/brevitas/proxy/__init__.py b/src/brevitas/proxy/__init__.py
index 20cbef212..11c4412b5 100644
--- a/src/brevitas/proxy/__init__.py
+++ b/src/brevitas/proxy/__init__.py
@@ -1,4 +1,5 @@
-from .parameter_quant import WeightQuantProxyFromInjector, DecoupledWeightQuantProxyFromInjector
+from .parameter_quant import WeightQuantProxyFromInjector
+from .parameter_quant import DecoupledWeightQuantProxyFromInjector
from .parameter_quant import BiasQuantProxyFromInjector
from .runtime_quant import ActQuantProxyFromInjector
from .runtime_quant import ClampQuantProxyFromInjector
diff --git a/tests/brevitas/export/test_generic_export.py b/tests/brevitas/export/test_generic_export.py
index 7530b33d8..9aecffb09 100644
--- a/tests/brevitas/export/test_generic_export.py
+++ b/tests/brevitas/export/test_generic_export.py
@@ -37,7 +37,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
BrevitasONNXManager.export(
- model, input_t=inp, export_path='./generic_quant_linear.onnx')
+ model, input_t=inp, export_path='generic_quant_linear.onnx')
def test_generic_decoupled_quant_linear_export():
@@ -66,7 +66,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
BrevitasONNXManager.export(
- model, input_t=inp, export_path='./generic_decoupled_quant_linear.onnx')
+ model, input_t=inp, export_path='generic_decoupled_quant_linear.onnx')
def test_generic_quant_conv_export():
@@ -95,7 +95,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
BrevitasONNXManager.export(
- model, input_t=inp, export_path='./generic_quant_conv.onnx')
+ model, input_t=inp, export_path='generic_quant_conv.onnx')
def test_generic_quant_tensor_export():
@@ -123,7 +123,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
BrevitasONNXManager.export(
- model, input_t=inp, export_path='./generic_quant_tensor.onnx')
+ model, input_t=inp, export_path='generic_quant_tensor.onnx')
def test_generic_quant_avgpool_export():
@@ -134,7 +134,7 @@ class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.inp_quant = QuantIdentity(return_quant_tensor=True)
- self.pool = QuantAvgPool2d(kernel_size=2)
+ self.pool = QuantAvgPool2d(kernel_size=2, return_quant_tensor=False)
def forward(self, x):
return self.pool(self.inp_quant(x))
@@ -145,4 +145,16 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
BrevitasONNXManager.export(
- model, input_t=inp, export_path='./generic_quant_avgpool.onnx')
+ model, input_t=inp, export_path='generic_quant_avgpool.onnx')
+
+
+def test_generic_quant_avgpool_export_quant_input():
+ IN_SIZE = (2, OUT_CH, IN_CH, IN_CH)
+ inp = torch.randn(IN_SIZE)
+ inp_quant = QuantIdentity(return_quant_tensor=True)
+ model = QuantAvgPool2d(kernel_size=2, return_quant_tensor=False)
+ inp_quant(inp) # collect scale factors
+ inp_quant.eval()
+ model.eval()
+ BrevitasONNXManager.export(
+ model, input_t=inp_quant(inp), export_path='generic_quant_avgpool_quant_input.onnx')
diff --git a/tests/brevitas_ort/test_onnx_standard.py b/tests/brevitas_ort/test_onnx_standard.py
index 576665db5..ce116401a 100644
--- a/tests/brevitas_ort/test_onnx_standard.py
+++ b/tests/brevitas_ort/test_onnx_standard.py
@@ -8,7 +8,7 @@
from brevitas.nn import QuantConv2d, QuantLinear, QuantIdentity, QuantMaxPool2d
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
-from brevitas.export import export_standard_onnx
+from brevitas.export import export_standard_qop_onnx
from tests.marker import requires_pt_ge
@@ -30,7 +30,7 @@ def compute_ort(export_name, np_input):
def is_brevitas_ort_close(model, np_input, export_name, atol=None):
- export_standard_onnx(model, input_shape=np_input.shape, export_path=export_name)
+ export_standard_qop_onnx(model, input_shape=np_input.shape, export_path=export_name)
brevitas_output = model(torch.from_numpy(np_input))
ort_output = compute_ort(export_name, np_input)
ort_output = torch.from_numpy(ort_output)
diff --git a/tests/brevitas_pyxir/dpuv1/test_quantizer_export.py b/tests/brevitas_pyxir/dpuv1/test_quantizer_export.py
index fd3cacdd3..50f6de3b3 100644
--- a/tests/brevitas_pyxir/dpuv1/test_quantizer_export.py
+++ b/tests/brevitas_pyxir/dpuv1/test_quantizer_export.py
@@ -4,7 +4,7 @@
import torch
from torchvision import models
-from brevitas.export import export_dpuv2_onnx
+from brevitas.export import export_pyxir_onnx
from brevitas.graph.quantizer import quantize, BatchNormHandling
from brevitas.quant.fixed_point import *
from brevitas import config
@@ -36,4 +36,4 @@ def test_rewriter_export(model_name: str):
bias_quant=Int8BiasPerTensorFixedPointInternalScaling,
bn_handling=BatchNormHandling.MERGE_AND_QUANTIZE)
out_file = f'{model_name}.onnx'
- export_dpuv2_onnx(gen_model, input_t=input, export_path=out_file)
+ export_pyxir_onnx(gen_model, input_t=input, export_path=out_file)
diff --git a/tests/brevitas_pyxir/test_dpu_export.py b/tests/brevitas_pyxir/test_dpu_export.py
index ff77a0eaa..5f1a73d33 100644
--- a/tests/brevitas_pyxir/test_dpu_export.py
+++ b/tests/brevitas_pyxir/test_dpu_export.py
@@ -8,7 +8,7 @@
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU, QuantMaxPool2d, QuantEltwiseAdd
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint, Int8BiasPerTensorFixedPointInternalScaling
-from brevitas.export import DPUv1Manager, DPUv2Manager
+from brevitas.export import PyXIRManager
from tests.marker import requires_pt_ge
@@ -16,16 +16,13 @@
OUT_CH = 40
IN_CH = 50
TOLERANCE = 1.1
-DPUS = ['DPUv1', 'DPUv2']
-MANAGERS_MAP = {'DPUv1': DPUv1Manager, 'DPUv2': DPUv2Manager}
def gen_linspaced_data(num_samples, min_val=-1.0, max_val=1.0):
return np.linspace(min_val, max_val, num_samples).astype(dtype=np.float32)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv(dpu):
+def test_dpu_export_onnx_quant_conv():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -52,11 +49,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_linear(dpu):
+def test_dpu_export_onnx_quant_linear():
IN_SIZE = (IN_CH, IN_CH)
class Model(torch.nn.Module):
@@ -80,11 +76,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_linear.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_bias(dpu):
+def test_dpu_export_onnx_quant_conv_bias():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -112,11 +107,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_bias.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_standard_onnx_quant_linear_bias_export(dpu):
+def test_standard_onnx_quant_linear_bias_export():
IN_SIZE = (IN_CH, IN_CH)
class Model(torch.nn.Module):
@@ -141,11 +135,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_linear_bias.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_max_pool(dpu):
+def test_dpu_export_onnx_quant_conv_max_pool():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -173,12 +166,11 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_maxpool.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
@requires_pt_ge('1.5.0')
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_f_max_pool(dpu):
+def test_dpu_export_onnx_quant_conv_f_max_pool():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -205,11 +197,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_f_maxpool.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_relu(dpu):
+def test_dpu_export_onnx_quant_conv_relu():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -237,12 +228,11 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_relu.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
@requires_pt_ge('1.5.0')
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_f_relu(dpu):
+def test_dpu_export_onnx_quant_conv_f_relu():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -269,11 +259,10 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_f_relu.onnx')
+ PyXIRManager.export(model, input_shape=IN_SIZE)
-@pytest.mark.parametrize('dpu', DPUS)
-def test_dpu_export_onnx_quant_conv_add(dpu):
+def test_dpu_export_onnx_quant_conv_add():
FEATURES = 7
IN_SIZE = (1, IN_CH, FEATURES, FEATURES)
KERNEL_SIZE = 3
@@ -311,4 +300,4 @@ def forward(self, x):
model = Model()
model(torch.from_numpy(inp)) # accumulate scale factors
model.eval()
- MANAGERS_MAP[dpu].export(model, input_shape=IN_SIZE, export_path=f'{dpu}_conv_add.onnx')
\ No newline at end of file
+ PyXIRManager.export(model, input_shape=IN_SIZE)
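For reference, a minimal sketch of the updated calling convention used throughout this file: with export_path omitted, PyXIRManager.export is taken here to return the exported model in memory rather than writing a file, which is an assumption based on how the tests now invoke it. Layer sizes and quantizers are illustrative.

# Sketch only: the in-memory return value and the model configuration
# are assumptions, mirroring the shape of the tests above.
import torch

from brevitas.nn import QuantConv2d, QuantIdentity
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.export import PyXIRManager

IN_SIZE = (1, 8, 7, 7)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_quant = QuantIdentity(
            act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True)
        self.conv = QuantConv2d(
            8, 16, kernel_size=3,
            weight_quant=Int8WeightPerTensorFixedPoint,
            output_quant=Int8ActPerTensorFixedPoint)

    def forward(self, x):
        return self.conv(self.inp_quant(x))

model = Model()
model(torch.randn(IN_SIZE))  # accumulate scale factors
model.eval()
exported = PyXIRManager.export(model, input_shape=IN_SIZE)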