Skip to content

Commit

Permalink
Feat (export): extend quantized ONNX, remove PyXIR DPUv1, rename StdONNX
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Pappalardo <[email protected]>
  • Loading branch information
volcacius committed Oct 20, 2021
1 parent 7c3a392 commit c1ec48e
Show file tree
Hide file tree
Showing 58 changed files with 1,013 additions and 1,078 deletions.
32 changes: 16 additions & 16 deletions docs/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -1119,10 +1119,10 @@ <h2 id="D">D</h2>
</li>
<li><a href="brevitas.html#brevitas.quant_tensor.QuantTensor.detach_">detach_() (brevitas.quant_tensor.QuantTensor method)</a>
</li>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.dilation">dilation() (brevitas.export.onnx.handler.Kernel1dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.dilation">dilation() (brevitas.export.onnx.handler.Kernel1dApplHandlerMixin static method)</a>

<ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.dilation">(brevitas.export.onnx.handler.Kernel2dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.dilation">(brevitas.export.onnx.handler.Kernel2dApplHandlerMixin static method)</a>
</li>
</ul></li>
<li><a href="brevitas.graph.html#brevitas.graph.rewriter.DisableBreakingReturnQuantTensor">DisableBreakingReturnQuantTensor (class in brevitas.graph.rewriter)</a>
Expand Down Expand Up @@ -1729,16 +1729,16 @@ <h2 id="J">J</h2>
<h2 id="K">K</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler">Kernel1dApplHandler (class in brevitas.export.onnx.handler)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler">Kernel1dApplHandlerMixin (class in brevitas.export.onnx.handler)</a>
</li>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler">Kernel2dApplHandler (class in brevitas.export.onnx.handler)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler">Kernel2dApplHandlerMixin (class in brevitas.export.onnx.handler)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.kernel_shape">kernel_shape() (brevitas.export.onnx.handler.Kernel1dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.kernel_shape">kernel_shape() (brevitas.export.onnx.handler.Kernel1dApplHandlerMixin static method)</a>

<ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.kernel_shape">(brevitas.export.onnx.handler.Kernel2dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.kernel_shape">(brevitas.export.onnx.handler.Kernel2dApplHandlerMixin static method)</a>
</li>
</ul></li>
</ul></td>
Expand Down Expand Up @@ -2246,10 +2246,10 @@ <h2 id="O">O</h2>
<h2 id="P">P</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.padding">padding() (brevitas.export.onnx.handler.Kernel1dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.padding">padding() (brevitas.export.onnx.handler.Kernel1dApplHandlerMixin static method)</a>

<ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.padding">(brevitas.export.onnx.handler.Kernel2dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.padding">(brevitas.export.onnx.handler.Kernel2dApplHandlerMixin static method)</a>
</li>
</ul></li>
<li><a href="brevitas.core.html#brevitas.core.zero_point.ParameterFromRuntimeMinZeroPoint">ParameterFromRuntimeMinZeroPoint (class in brevitas.core.zero_point)</a>
Expand Down Expand Up @@ -2436,7 +2436,7 @@ <h2 id="Q">Q</h2>
</ul></li>
<li><a href="brevitas.export.onnx.finn.handler.html#brevitas.export.onnx.finn.handler.base.FINNQuantInputHandler.quant_input_type">quant_input_type() (brevitas.export.onnx.finn.handler.base.FINNQuantInputHandler static method)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_input_zero_point">quant_input_zero_point() (brevitas.export.common.handler.TypedZeroPointHandler class method)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_input_zero_point">quant_input_zero_point() (brevitas.export.common.handler.ZeroPointHandlerMixin class method)</a>
</li>
<li><a href="brevitas.export.onnx.finn.handler.html#brevitas.export.onnx.finn.handler.acc.FINNQuantAvgPool2dHandler.quant_output_bit_width">quant_output_bit_width() (brevitas.export.onnx.finn.handler.acc.FINNQuantAvgPool2dHandler static method)</a>

Expand Down Expand Up @@ -2470,7 +2470,7 @@ <h2 id="Q">Q</h2>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="brevitas.export.onnx.finn.handler.html#brevitas.export.onnx.finn.handler.base.FINNQuantIOHandler.quant_output_signed">quant_output_signed() (brevitas.export.onnx.finn.handler.base.FINNQuantIOHandler static method)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_output_zero_point">quant_output_zero_point() (brevitas.export.common.handler.TypedZeroPointHandler class method)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_output_zero_point">quant_output_zero_point() (brevitas.export.common.handler.ZeroPointHandlerMixin class method)</a>
</li>
<li><a href="brevitas.export.onnx.finn.handler.html#brevitas.export.onnx.finn.handler.act.FINNQuantHardTanhHandler.quant_type">quant_type() (brevitas.export.onnx.finn.handler.act.FINNQuantHardTanhHandler static method)</a>

Expand All @@ -2492,7 +2492,7 @@ <h2 id="Q">Q</h2>
</ul></li>
<li><a href="brevitas.export.onnx.finn.handler.html#brevitas.export.onnx.finn.handler.parameter.FINNQuantWBIOLHandler.quant_weight_type">quant_weight_type() (brevitas.export.onnx.finn.handler.parameter.FINNQuantWBIOLHandler static method)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_weight_zero_point">quant_weight_zero_point() (brevitas.export.common.handler.TypedZeroPointHandler class method)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.quant_weight_zero_point">quant_weight_zero_point() (brevitas.export.common.handler.ZeroPointHandlerMixin class method)</a>
</li>
<li><a href="brevitas.export.onnx.finn.function.html#brevitas.export.onnx.finn.function.acc.QuantAvgPool2dPlaceholderFunction">QuantAvgPool2dPlaceholderFunction (class in brevitas.export.onnx.finn.function.acc)</a>
</li>
Expand Down Expand Up @@ -2738,10 +2738,10 @@ <h2 id="S">S</h2>
</li>
<li><a href="brevitas.export.onnx.standard.handler.html#brevitas.export.onnx.standard.handler.base.StdONNXQuantWrapperHandler">StdONNXQuantWrapperHandler (class in brevitas.export.onnx.standard.handler.base)</a>
</li>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.stride">stride() (brevitas.export.onnx.handler.Kernel1dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel1dApplHandler.stride">stride() (brevitas.export.onnx.handler.Kernel1dApplHandlerMixin static method)</a>

<ul>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.stride">(brevitas.export.onnx.handler.Kernel2dApplHandler static method)</a>
<li><a href="brevitas.export.onnx.html#brevitas.export.onnx.handler.Kernel2dApplHandler.stride">(brevitas.export.onnx.handler.Kernel2dApplHandlerMixin static method)</a>
</li>
</ul></li>
<li><a href="brevitas.graph.tracer.wrapper.html#brevitas.graph.tracer.wrapper.builtin.StrWrapper">StrWrapper (class in brevitas.graph.tracer.wrapper.builtin)</a>
Expand Down Expand Up @@ -2997,7 +2997,7 @@ <h2 id="T">T</h2>
</li>
<li><a href="brevitas.export.html#brevitas.export.base.BaseHandler.training">(brevitas.export.base.BaseHandler attribute)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.training">(brevitas.export.common.handler.TypedZeroPointHandler attribute)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.training">(brevitas.export.common.handler.ZeroPointHandlerMixin attribute)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.Validate8BitHandler.training">(brevitas.export.common.handler.Validate8BitHandler attribute)</a>
</li>
Expand Down Expand Up @@ -3140,7 +3140,7 @@ <h2 id="T">T</h2>
</li>
<li><a href="brevitas.utils.html#brevitas.utils.torch_utils.TupleSequential">TupleSequential (class in brevitas.utils.torch_utils)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler">TypedZeroPointHandler (class in brevitas.export.common.handler)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler">ZeroPointHandlerMixin (class in brevitas.export.common.handler)</a>
</li>
</ul></td>
</tr></table>
Expand Down Expand Up @@ -3228,7 +3228,7 @@ <h2 id="Z">Z</h2>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="brevitas.html#brevitas.quant_tensor.QuantTensor.zero_point">zero_point (brevitas.quant_tensor.QuantTensor attribute)</a>
</li>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.zero_point_with_dtype">zero_point_with_dtype() (brevitas.export.common.handler.TypedZeroPointHandler class method)</a>
<li><a href="brevitas.export.common.html#brevitas.export.common.handler.TypedZeroPointHandler.zero_point_with_dtype">zero_point_with_dtype() (brevitas.export.common.handler.ZeroPointHandlerMixin class method)</a>
</li>
<li><a href="brevitas.core.html#brevitas.core.zero_point.ZeroZeroPoint">ZeroZeroPoint (class in brevitas.core.zero_point)</a>
</li>
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/core/bit_width/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class BitWidthConst(brevitas.jit.ScriptModule):
def __init__(self, bit_width: int) -> None:
super(BitWidthConst, self).__init__()
assert isinstance(bit_width, int)
self.bit_width = StatelessBuffer(torch.tensor(float(int(bit_width))))
self.bit_width = StatelessBuffer(torch.tensor(float(bit_width)))

@brevitas.jit.script_method
def forward(self) -> Tensor:
Expand Down
53 changes: 20 additions & 33 deletions src/brevitas/export/__init__.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,39 @@
from brevitas import config
from functools import wraps

from .onnx.finn.manager import FINNManager
from .onnx.standard.manager import StdONNXManager
from .onnx.vitis_ai.pyxir.dpuv1.manager import DPUv1Manager
from .onnx.vitis_ai.pyxir.dpuv2.manager import DPUv2Manager
from .onnx.generic.manager import BrevitasONNXManager
from .onnx.standard.qoperator.manager import StdQOpONNXManager
from .onnx.vitis_ai.pyxir.manager import PyXIRManager
from .onnx.vitis_ai.xir.manager import XIRManager
from .onnx.debug import enable_debug
from .pytorch.manager import PytorchQuantManager


@wraps(FINNManager.export)
def export_finn_onnx(*args, **kwargs):
return FINNManager.export(*args, **kwargs)


def export_dpuv1_onnx(*args, **kwargs):
return DPUv1Manager.export(*args, **kwargs)


def export_dpuv2_onnx(*args, **kwargs):
return DPUv2Manager.export(*args, **kwargs)


def export_standard_onnx(*args, **kwargs):
return StdONNXManager.export(*args, **kwargs)


def jit_trace_dpuv1(*args, **kwargs):
return DPUv1Manager.jit_inference_trace(*args, **kwargs)


def is_ongoing_export():
return config._ONGOING_EXPORT is not None
@wraps(PyXIRManager.export)
def export_pyxir_onnx(*args, **kwargs):
return PyXIRManager.export(*args, **kwargs)


def is_ongoing_finn_export():
return config._ONGOING_EXPORT == FINNManager.target_name
@wraps(XIRManager.export)
def export_xir(*args, **kwargs):
return XIRManager.export(*args, **kwargs)


def is_ongoing_stdonnx_export():
return config._ONGOING_EXPORT == StdONNXManager.target_name
@wraps(BrevitasONNXManager.export)
def export_brevitas_onnx(*args, **kwargs):
return BrevitasONNXManager.export(*args, **kwargs)


def is_ongoing_pyxir_export():
if config._ONGOING_EXPORT is not None:
return PyXIRManager.target_name in config._ONGOING_EXPORT
else:
return False
@wraps(StdQOpONNXManager.export)
def export_standard_qop_onnx(*args, **kwargs):
return StdQOpONNXManager.export(*args, **kwargs)


def is_ongoing_pytorch_export():
return config._ONGOING_EXPORT == PytorchQuantManager.target_name
@wraps(PytorchQuantManager.export)
def export_pytorch_quant(*args, **kwargs):
return PytorchQuantManager.export(*args, **kwargs)
46 changes: 0 additions & 46 deletions src/brevitas/export/common/handler.py

This file was deleted.

105 changes: 105 additions & 0 deletions src/brevitas/export/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from abc import ABC
import math

import torch
from torch.nn import Module
from torch import Tensor


# Public names exported by a star-import of this module. Every public class
# defined in this file is listed, including ScaleHandlerMixin.
__all__ = [
    'BaseHandler',
    'BitWidthHandlerMixin',
    'ScaleHandlerMixin',
    'ZeroPointHandlerMixin',
]


class BaseHandler(Module, ABC):
    """Base class for export handlers.

    It is both a ``torch.nn.Module`` and an ``ABC``. Every hook below is a
    no-op in this base class and returns ``None``.
    """

    def attach_debug_info(self, module):
        """Hook receiving ``module``; does nothing here."""
        return None

    def prepare_for_export(self, module):
        """Hook receiving ``module``; does nothing here."""
        return None

    def reset(self):
        """Hook taking no arguments; does nothing here."""
        return None


class BitWidthHandlerMixin(object):
    """Class-level helpers that check a bit width against a reference width."""

    @classmethod
    def validate_bit_width(cls, bit_width: Tensor, reference: int, le_then=False):
        """Convert ``bit_width`` to an int and compare it with ``reference``.

        Args:
            bit_width: object exposing ``.item()``; the value is truncated
                with ``int``.
            reference: the bit width to compare against.
            le_then: when true, values smaller than ``reference`` are
                accepted as well as values equal to it.

        Returns:
            The bit width as a Python ``int``.

        Raises:
            RuntimeError: if ``bit_width`` is ``None``, larger than
                ``reference``, or smaller than ``reference`` while
                ``le_then`` is false.
        """
        if bit_width is None:
            raise RuntimeError("Bit width cannot be None")
        bit_width = int(bit_width.item())
        # Anything wider than the reference is rejected unconditionally.
        if bit_width > reference:
            raise RuntimeError(f"Bit width {bit_width} is not supported.")
        # Narrower values are only tolerated when the caller opted in.
        if not le_then and bit_width < reference:
            raise RuntimeError(f"Bit width {bit_width} is not supported, should be {reference}b.")
        return bit_width

    @classmethod
    def validate_8b_bit_width(cls, bit_width: Tensor, le_then=False):
        """Run ``validate_bit_width`` against a reference of 8."""
        reference = 8
        return cls.validate_bit_width(bit_width, reference, le_then)

    @classmethod
    def validate_16b_bit_width(cls, bit_width: Tensor, le_then=False):
        """Run ``validate_bit_width`` against a reference of 16."""
        reference = 16
        return cls.validate_bit_width(bit_width, reference, le_then)

    @classmethod
    def validate_32b_bit_width(cls, bit_width: Tensor, le_then=False):
        """Run ``validate_bit_width`` against a reference of 32."""
        reference = 32
        return cls.validate_bit_width(bit_width, reference, le_then)


class ScaleHandlerMixin(object):
    """Class-level helpers that validate a per-tensor scale factor."""

    @classmethod
    def validate_scalar_scale(cls, scale: Tensor):
        """Check that ``scale`` holds exactly one element and return it.

        Args:
            scale: tensor holding the scale factor.

        Returns:
            The single element of ``scale`` as a Python number.

        Raises:
            RuntimeError: if ``scale`` is ``None`` or does not hold exactly
                one element.
        """
        if scale is None:
            raise RuntimeError("Scale cannot be None.")
        # numel() counts elements without reshaping, so a non-contiguous
        # multi-element tensor is reported with the message below instead of
        # failing inside a view() call.
        if scale.numel() != 1:
            raise RuntimeError("Only per-tensor scaling is supported.")
        return scale.item()

    @classmethod
    def validate_scalar_int_exponent(cls, scale: Tensor):
        """Return the integer ``e`` such that ``scale == 2 ** e``.

        Args:
            scale: tensor holding a single, strictly positive scale factor.

        Returns:
            The base-2 exponent of the scale as a Python ``int``.

        Raises:
            RuntimeError: if ``scale`` is ``None``, is not a single element,
                is not strictly positive, or is not a power of two.
        """
        scalar_scale = cls.validate_scalar_scale(scale)
        # math.log2 raises a bare ValueError on zero or negative input;
        # report it as a validation failure like every other check here.
        if scalar_scale <= 0:
            raise RuntimeError("Only strictly positive scale factors are supported.")
        exponent = math.log2(scalar_scale)
        if not exponent.is_integer():
            raise RuntimeError("Only power-of-two scale factors are supported.")
        return int(exponent)

    @classmethod
    def validate_neg_scalar_int_exponent(cls, scale: Tensor):
        """Return the negated result of ``validate_scalar_int_exponent``."""
        return - cls.validate_scalar_int_exponent(scale)


class ZeroPointHandlerMixin(object):
    """Class-level helpers that cast zero points to an 8-bit integer dtype."""

    @classmethod
    def zero_point_with_dtype(cls, signed, zero_point):
        """Cast ``zero_point`` to ``int8`` or ``uint8`` depending on ``signed``.

        Args:
            signed: truthy for ``torch.int8``, falsy for ``torch.uint8``.
            zero_point: tensor of zero points.

        Returns:
            ``zero_point`` converted with ``Tensor.type``.

        Raises:
            RuntimeError: if ``signed`` is falsy and any element of
                ``zero_point`` is negative.
        """
        if signed:
            return zero_point.type(torch.int8)
        # Unsigned path: a negative zero point cannot be represented.
        if (zero_point < 0).any():
            raise RuntimeError("Zero points have to be positive under unsigned quantization")
        return zero_point.type(torch.uint8)

    @classmethod
    def quant_input_zero_point(cls, module):
        """Return the input zero point of ``module`` with its dtype applied.

        Reads ``module.is_quant_input_signed`` and calls
        ``module.quant_input_zero_point()``.
        """
        return cls.zero_point_with_dtype(
            module.is_quant_input_signed,
            module.quant_input_zero_point())

    @classmethod
    def quant_weight_zero_point(cls, module):
        """Return the weight zero point of ``module`` with its dtype applied.

        Reads ``module.is_quant_weight_signed`` and calls
        ``module.quant_weight_zero_point()``.
        """
        return cls.zero_point_with_dtype(
            module.is_quant_weight_signed,
            module.quant_weight_zero_point())

    @classmethod
    def quant_output_zero_point(cls, module):
        """Return the output zero point of ``module`` with its dtype applied.

        Reads ``module.is_quant_output_signed`` and calls
        ``module.quant_output_zero_point()``.
        """
        return cls.zero_point_with_dtype(
            module.is_quant_output_signed,
            module.quant_output_zero_point())
12 changes: 0 additions & 12 deletions src/brevitas/export/base.py → src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,6 @@ def _restore_requires_grad(m: Module, previous_state):
p.requires_grad_(previous_state[n])


class BaseHandler(Module, ABC):

def attach_debug_info(self, module):
pass

def prepare_for_export(self, module):
pass

def reset(self):
pass


class BaseManager(ABC):

target_name = None
Expand Down
Loading

0 comments on commit c1ec48e

Please sign in to comment.