Feat (proxy): later dequantization #1142

Open · wants to merge 14 commits into base: dev
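Taken together, the diffs below move the affine dequantization step, (y_int - zero_point) * scale, out of the quantization cores (binary, ternary, int, float) and into the proxies and inference handlers, which also take over the quant-delay logic. A minimal sketch of the resulting split, using hypothetical names rather than the actual Brevitas classes:

```python
import torch

def core_quant(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
    # After this PR: the core stops at the integer grid and returns metadata.
    y_int = torch.round(x / scale + zero_point).clamp(-128, 127)
    return y_int, scale, zero_point

def proxy_forward(x: torch.Tensor) -> torch.Tensor:
    # The proxy now owns dequantization (previously done inside the core).
    y_int, scale, zero_point = core_quant(x, torch.tensor(0.1), torch.tensor(0.))
    return (y_int - zero_point) * scale
```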
src/brevitas/core/quant/binary.py (4 additions & 15 deletions)

@@ -10,7 +10,6 @@
 import brevitas
 from brevitas.core.bit_width import BitWidthConst
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import binary_sign_ste
 
@@ -22,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule):
 
     Args:
         scaling_impl (Module): Module that returns a scale factor.
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
 
     Returns:
         Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
@@ -48,19 +46,17 @@ class BinaryQuant(brevitas.jit.ScriptModule):
     Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
     """
 
-    def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
+    def __init__(self, scaling_impl: Module, signed: bool = True):
         super(BinaryQuant, self).__init__()
         assert signed, "Unsigned binary quant not supported"
         self.scaling_impl = scaling_impl
         self.bit_width = BitWidthConst(1)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
-        y = binary_sign_ste(x) * scale
-        y = self.delay_wrapper(x, y)
+        y = binary_sign_ste(x)
         return y, scale, self.zero_point(), self.bit_width()
 
 
@@ -74,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
     Args:
         scaling_impl (Module): Module that returns a scale factor.
         tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
 
     Returns:
         Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
@@ -104,22 +99,16 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
     Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
     """
 
-    def __init__(
-            self,
-            scaling_impl: Module,
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+    def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()):
         super(ClampedBinaryQuant, self).__init__()
         self.scaling_impl = scaling_impl
         self.bit_width = BitWidthConst(1)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.tensor_clamp_impl = tensor_clamp_impl
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
         y = self.tensor_clamp_impl(x, -scale, scale)
-        y = binary_sign_ste(y) * scale
-        y = self.delay_wrapper(x, y)
+        y = binary_sign_ste(y)
         return y, scale, self.zero_point(), self.bit_width()
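Since BinaryQuant pins zero_point at 0, the proxy-side dequantize reduces exactly to the old in-core y * scale. A quick numeric check of the equivalence, with an explicit sign standing in for binary_sign_ste:

```python
import torch

x = torch.tensor([-0.3, 0.7, 1.2])
scale = torch.tensor(0.5)

sign = torch.where(x >= 0, 1.0, -1.0)  # stand-in for binary_sign_ste
old_out = sign * scale                 # core output before this PR
new_out = (sign - 0.0) * scale         # new core output + proxy dequantize
assert torch.equal(old_out, new_out)
```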
src/brevitas/core/quant/float.py (0 additions & 5 deletions)

@@ -75,10 +75,6 @@ def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor,
         val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
         return val_fp_quant, scale
 
-    @brevitas.jit.script_method
-    def dequantize(self, y, scale):
-        return y * scale
-
     @brevitas.jit.script_method
     def forward(self, x):
         if self.float_scaling_impl is not None:
@@ -95,6 +91,5 @@ def forward(self, x):
         # after quantizing, clamp to special cases like NaN/inf if they are set
         y, saturating, inf_values, nan_values = self.float_clamp_impl(
             y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
src/brevitas/core/quant/int.py (1 addition & 7 deletions)

@@ -8,7 +8,6 @@
 from torch.nn import Module
 
 import brevitas
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import round_ste
 
@@ -201,12 +200,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule):
     """
     """
 
-    def __init__(
-            self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0):
+    def __init__(self, float_to_int_impl: Module, bit_width_impl: Module):
         super(TruncIntQuant, self).__init__()
         self.msb_clamp_bit_width_impl = bit_width_impl
         self.float_to_int_impl = float_to_int_impl
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
@@ -219,9 +216,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
         trunc_scale = 2.0 ** trunc_bit_width
         y = y / trunc_scale
         y = self.float_to_int_impl(y)
-        y = y - zero_point
-        y = y * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, zero_point, output_bit_width
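TruncIntQuant now stops at the truncated integer value; the affine rescaling is left to the caller. A worked sketch of the truncation arithmetic, assuming trunc_bit_width is the difference between input and output bit widths and a floor rounding impl:

```python
import torch

x_int = torch.tensor(200.)  # value already on the 8-bit integer grid
input_bit_width, output_bit_width = 8, 4

trunc_scale = 2.0 ** (input_bit_width - output_bit_width)  # 16.0
y = torch.floor(x_int / trunc_scale)                       # 12.0, on the 4-bit grid
# The core used to also compute (y - zero_point) * scale here;
# that affine step now happens in the calling proxy instead.
```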
src/brevitas/core/quant/int_base.py (14 additions & 27 deletions)

@@ -8,7 +8,6 @@
 import brevitas
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
 
@@ -24,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule):
         float_to_int_impl (Module): Module that performs the conversion from floating point to
             integer representation. Default: RoundSte()
         tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
 
     Returns:
         Tensor: Quantized output in de-quantized format.
@@ -48,19 +46,17 @@ class IntQuant(brevitas.jit.ScriptModule):
     __constants__ = ['signed', 'narrow_range']
 
     def __init__(
-            self,
-            narrow_range: bool,
-            signed: bool,
-            input_view_impl: Module,
-            float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            self,
+            narrow_range: bool,
+            signed: bool,
+            input_view_impl: Module,
+            float_to_int_impl: Module = RoundSte(),
+            tensor_clamp_impl: Module = TensorClamp()):
         super(IntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl
 
     @brevitas.jit.script_method
@@ -85,10 +81,7 @@ def max_int(self, bit_width):
     @brevitas.jit.script_method
     def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor:
         y_int = self.to_int(scale, zero_point, bit_width, x)
-        y = y_int - zero_point
-        y = y * scale
-        y = self.delay_wrapper(x, y)
-        return y
+        return y_int
 
 
 class DecoupledIntQuant(brevitas.jit.ScriptModule):
@@ -102,7 +95,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
         float_to_int_impl (Module): Module that performs the conversion from floating point to
             integer representation. Default: RoundSte()
         tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
-        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
 
     Returns:
         Tensor: Quantized output in de-quantized format.
@@ -124,19 +116,17 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
     __constants__ = ['signed', 'narrow_range']
 
     def __init__(
-            self,
-            narrow_range: bool,
-            signed: bool,
-            input_view_impl: Module,
-            float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            self,
+            narrow_range: bool,
+            signed: bool,
+            input_view_impl: Module,
+            float_to_int_impl: Module = RoundSte(),
+            tensor_clamp_impl: Module = TensorClamp()):
         super(DecoupledIntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl
 
     @brevitas.jit.script_method
@@ -169,8 +159,5 @@ def forward(
             zero_point: Tensor,
             bit_width: Tensor,
             x: Tensor) -> Tensor:
-        y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
-        y = y_int - zero_point
-        y = y * scale
-        y = self.delay_wrapper(x, y)
+        y = self.to_int(pre_scale, pre_zero_point, bit_width, x)
         return y
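Both IntQuant and DecoupledIntQuant now return the result of to_int directly. A rough sketch of the integer-domain math they keep, assuming a standard affine round-and-clamp (the real to_int also applies input_view_impl and the configurable float_to_int/tensor_clamp impls):

```python
import torch

def to_int_sketch(x, scale, zero_point, bit_width=8, signed=True, narrow_range=False):
    # Map to the integer grid, then clamp to the representable range.
    min_int = -2 ** (bit_width - 1) + int(narrow_range) if signed else 0
    max_int = 2 ** (bit_width - 1) - 1 if signed else 2 ** bit_width - 1
    return torch.round(x / scale + zero_point).clamp(min_int, max_int)

y_int = to_int_sketch(torch.randn(4), scale=0.05, zero_point=0.0)
# Callers are now responsible for (y_int - zero_point) * scale.
```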
src/brevitas/core/quant/ternary.py (0 additions & 4 deletions)

@@ -9,7 +9,6 @@
 
 import brevitas
 from brevitas.core.bit_width import BitWidthConst
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops_ste import ternary_sign_ste
 
@@ -57,13 +56,10 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in
         self.threshold = threshold
         self.bit_width = BitWidthConst(2)
         self.zero_point = StatelessBuffer(torch.tensor(0.0))
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         scale = self.scaling_impl(x)
         mask = x.abs().gt(self.threshold * scale)
         y = mask.float() * ternary_sign_ste(x)
-        y = y * scale
-        y = self.delay_wrapper(x, y)
         return y, scale, self.zero_point(), self.bit_width()
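TernaryQuant likewise returns bare {-1, 0, 1} codes, with the scale now used only to build the threshold mask. A small numeric sketch, assuming threshold = 0.5:

```python
import torch

x = torch.tensor([-0.9, 0.1, 0.6])
scale, threshold = torch.tensor(1.0), 0.5

mask = x.abs().gt(threshold * scale)  # tensor([True, False, True])
y = mask.float() * torch.sign(x)      # tensor([-1., 0., 1.]), integer-domain codes
dequant = (y - 0.0) * scale           # proxy-side dequantize; zero_point is 0
```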
src/brevitas/export/inference/handler.py (15 additions & 8 deletions)

@@ -103,7 +103,11 @@ def prepare_for_export(self, module: nn.Module):
         self.module_forward = module.fused_activation_quant_proxy
 
     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
-        return self.module_forward(x)
+        x, scale, zp, *quant_args = self.module_forward(x)
+        x = self.dequantize(x, scale, zp)
+        quant_args = tuple(quant_args)
+        out = (x, scale, zp) + quant_args
+        return out
 
 
 class GroupwiseIntInferenceHandler(IntInferencetHandler):
@@ -119,14 +123,16 @@ def prepare_for_export(self, module):
         self.group_dim = module.group_dim
 
     def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
+        x, scale, zp, *quant_args = self.module_forward(x)
+        x = self.dequantize(x, scale, zp)
+        quant_args = tuple(quant_args)
 
         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
-        output_args = tuple([x] + list(other))
-        return output_args
+        out = (x, scale, zp) + quant_args
+        return out
 
 
 class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
@@ -263,14 +269,15 @@ def prepare_for_export(self, module: nn.Module):
         self.group_dim = module.group_dim
 
     def forward(self, x: Tensor) -> Tuple[Tensor]:
-        x, *other = self.module_forward(x)
-
+        x, scale, zp, *quant_args = self.module_forward(x)
+        x = self.dequantize(x, scale, zp)
+        quant_args = tuple(quant_args)
         # If we skip quant tensor, we return the flattened version of the groupwise tensor
        if self.skip_create_quant_tensor:
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
-        output_args = tuple([x] + list(other))
-        return output_args
+        out = (x, scale, zp) + quant_args
+        return out
 
 
 class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
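All the reworked handlers follow one shared pattern: unpack the core's tuple, dequantize the leading tensor with the scale and zero point, and repack. A condensed sketch of that shape:

```python
from typing import Callable, Tuple

import torch

def dequantize(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    return (x - zero_point) * scale

def handler_forward(module_forward: Callable, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    x, scale, zp, *quant_args = module_forward(x)
    x = dequantize(x, scale, zp)
    return (x, scale, zp) + tuple(quant_args)
```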
src/brevitas/proxy/parameter_quant.py (17 additions & 6 deletions)

@@ -16,9 +16,9 @@
 from brevitas import config
 from brevitas import is_dynamo_compiling
 from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
-from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
@@ -96,6 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
         self.skip_create_quant_tensor = False
+        quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
+        self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @property
     def input_view_impl(self):
@@ -138,11 +140,14 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                 else:
                     out = self.create_quant_tensor(out)
             else:
-                out = self.tensor_quant(x)
+                quant_value, *quant_args = self.tensor_quant(x)
+                quant_args = tuple(quant_args)
+                quant_value = self.dequantize(*((quant_value,) + quant_args))
+                quant_value = self.delay_wrapper(x, quant_value)
                 if self.skip_create_quant_tensor:
-                    out = out[0]
+                    out = quant_value
                 else:
-                    out = self.create_quant_tensor(out)
+                    out = self.create_quant_tensor((quant_value,) + quant_args)
             if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
                 self._cached_weight = self.cache_class(
                     out.detach(),
@@ -274,8 +279,12 @@ def forward(
         input_bit_width = quant_input.bit_width
         input_is_signed = quant_input.signed
 
-        impl = self.export_handler if self.export_mode else self.tensor_quant
-        out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
+        if self.export_mode:
+            out, scale, zero_point, bit_width, pre_scale, pre_zero_point = self.export_handler(x, input_bit_width, input_is_signed)
+        else:
+            out, scale, zero_point, bit_width, pre_scale, pre_zero_point = self.tensor_quant(x, input_bit_width, input_is_signed)
+            out = self.dequantize(out, scale, zero_point)
+
         if self.skip_create_quant_tensor:
             return out
         return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
@@ -360,6 +369,8 @@ def forward(
             out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
         else:
             out, out_scale, out_zp, out_bit_width = impl(x)
+        if not self.export_mode:
+            out = self.dequantize(out, out_scale, out_zp)
         if not self.skip_create_quant_tensor:
             out = IntQuantTensor(
                 out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
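The quant-delay logic moves up along with dequantization: the proxy reads quant_delay_steps from its injector (None when absent) and applies DelayWrapper to the dequantized value. A sketch of the observable behaviour, assuming DelayWrapper passes the float input through for the first N steps:

```python
import torch

class DelayWrapperSketch:
    """Hypothetical stand-in for brevitas.core.quant.delay.DelayWrapper."""

    def __init__(self, quant_delay_steps=None):
        self.steps_left = quant_delay_steps or 0

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if self.steps_left > 0:
            self.steps_left -= 1
            return x  # delay active: keep the unquantized value
        return y      # delay elapsed: use the (de)quantized value

delay = DelayWrapperSketch(quant_delay_steps=2)
x, y = torch.randn(3), torch.zeros(3)
outs = [delay(x, y) for _ in range(3)]
assert outs[0] is x and outs[1] is x and outs[2] is y
```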
src/brevitas/proxy/quant_proxy.py (6 additions & 0 deletions)

@@ -113,6 +113,12 @@ def is_narrow_range(self):
     def rounding_mode(self):
         return _rounding_mode(self.quant_injector)
 
+    def dequantize(self, *quant_args):
+        x, scale, zero_point, *_ = quant_args
+        out = x - zero_point
+        out = out * scale
+        return out
+
     def add_tracked_module(self, module: nn.Module) -> None:
         if module is not None:
             self.tracked_module_list.append(module)
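The new dequantize helper accepts the whole tuple produced by a tensor_quant call and consumes only its first three entries, so trailing metadata (bit width, exponent info, and so on) is ignored harmlessly. Usage sketch:

```python
import torch

quant_args = (
    torch.tensor([10., -3.]),  # integer-domain values
    torch.tensor(0.25),        # scale
    torch.tensor(2.0),         # zero point
    torch.tensor(8.))          # bit width, ignored by dequantize

x, scale, zero_point, *_ = quant_args
out = (x - zero_point) * scale  # tensor([ 2.0000, -1.2500])
```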