From ecfb3d77dd0c7cbdec283b918f50c3ec817a98f0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 12:33:47 +0000 Subject: [PATCH 01/14] Feat (delay): move delay to proxies --- src/brevitas/core/quant/binary.py | 15 ++------- src/brevitas/core/quant/int.py | 6 +--- src/brevitas/core/quant/int_base.py | 33 +++++++------------- src/brevitas/core/quant/ternary.py | 3 -- src/brevitas/proxy/parameter_quant.py | 12 ++++--- src/brevitas/proxy/runtime_quant.py | 32 ++++++++++++------- tests/brevitas/core/binary_quant_fixture.py | 20 ------------ tests/brevitas/core/shared_quant_fixture.py | 10 ------ tests/brevitas/core/ternary_quant_fixture.py | 13 +------- tests/brevitas/core/test_binary_quant.py | 11 ------- tests/brevitas/core/test_ternary_quant.py | 11 ------- tests/brevitas/proxy/test_proxy.py | 9 ++++++ 12 files changed, 53 insertions(+), 122 deletions(-) diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 3a4b7346e..a00645b3f 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -10,7 +10,6 @@ import brevitas from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper import TensorClamp -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import binary_sign_ste @@ -22,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule): Args: scaling_impl (Module): Module that returns a scale factor. - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width. @@ -48,19 +46,17 @@ class BinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, signed: bool = True): super(BinaryQuant, self).__init__() assert signed, "Unsigned binary quant not supported" self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = binary_sign_ste(x) * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() @@ -74,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule): Args: scaling_impl (Module): Module that returns a scale factor. tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width. @@ -104,16 +99,11 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__( - self, - scaling_impl: Module, - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()): super(ClampedBinaryQuant, self).__init__() self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.tensor_clamp_impl = tensor_clamp_impl @brevitas.jit.script_method @@ -121,5 +111,4 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = self.tensor_clamp_impl(x, -scale, scale) y = binary_sign_ste(y) * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 328ad63b3..248931d68 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -8,7 +8,6 @@ from torch.nn import Module import brevitas -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import round_ste @@ -201,12 +200,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule): """ """ - def __init__( - self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0): + def __init__(self, float_to_int_impl: Module, bit_width_impl: Module): super(TruncIntQuant, self).__init__() self.msb_clamp_bit_width_impl = bit_width_impl self.float_to_int_impl = float_to_int_impl - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, @@ -221,7 +218,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, y = self.float_to_int_impl(y) y = y - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y, scale, zero_point, output_bit_width diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 338e5a433..8e94465c9 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -8,7 +8,6 @@ import brevitas from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp -from brevitas.core.quant.delay import DelayWrapper from brevitas.function.ops import max_int from brevitas.function.ops import min_int @@ -24,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule): float_to_int_impl (Module): Module that performs the conversion from floating point to integer representation. Default: RoundSte() tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tensor: Quantized output in de-quantized format. @@ -48,19 +46,17 @@ class IntQuant(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] def __init__( - self, - narrow_range: bool, - signed: bool, - input_view_impl: Module, - float_to_int_impl: Module = RoundSte(), - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + self, + narrow_range: bool, + signed: bool, + input_view_impl: Module, + float_to_int_impl: Module = RoundSte(), + tensor_clamp_impl: Module = TensorClamp()): super(IntQuant, self).__init__() self.float_to_int_impl = float_to_int_impl self.tensor_clamp_impl = tensor_clamp_impl self.signed = signed self.narrow_range = narrow_range - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.input_view_impl = input_view_impl @brevitas.jit.script_method @@ -87,7 +83,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso y_int = self.to_int(scale, zero_point, bit_width, x) y = y_int - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y @@ -102,7 +97,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule): float_to_int_impl (Module): Module that performs the conversion from floating point to integer representation. Default: RoundSte() tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp() - quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0 Returns: Tensor: Quantized output in de-quantized format. @@ -124,19 +118,17 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule): __constants__ = ['signed', 'narrow_range'] def __init__( - self, - narrow_range: bool, - signed: bool, - input_view_impl: Module, - float_to_int_impl: Module = RoundSte(), - tensor_clamp_impl: Module = TensorClamp(), - quant_delay_steps: int = 0): + self, + narrow_range: bool, + signed: bool, + input_view_impl: Module, + float_to_int_impl: Module = RoundSte(), + tensor_clamp_impl: Module = TensorClamp()): super(DecoupledIntQuant, self).__init__() self.float_to_int_impl = float_to_int_impl self.tensor_clamp_impl = tensor_clamp_impl self.signed = signed self.narrow_range = narrow_range - self.delay_wrapper = DelayWrapper(quant_delay_steps) self.input_view_impl = input_view_impl @brevitas.jit.script_method @@ -172,5 +164,4 @@ def forward( y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x) y = y_int - zero_point y = y * scale - y = self.delay_wrapper(x, y) return y diff --git a/src/brevitas/core/quant/ternary.py b/src/brevitas/core/quant/ternary.py index ffaa873de..9fd8f78ce 100644 --- a/src/brevitas/core/quant/ternary.py +++ b/src/brevitas/core/quant/ternary.py @@ -9,7 +9,6 @@ import brevitas from brevitas.core.bit_width import BitWidthConst -from brevitas.core.quant.delay import DelayWrapper from brevitas.core.utils import StatelessBuffer from brevitas.function.ops_ste import ternary_sign_ste @@ -57,7 +56,6 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in self.threshold = threshold self.bit_width = BitWidthConst(2) self.zero_point = StatelessBuffer(torch.tensor(0.0)) - self.delay_wrapper = DelayWrapper(quant_delay_steps) @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: @@ -65,5 +63,4 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: mask = x.abs().gt(self.threshold * scale) y = mask.float() * ternary_sign_ste(x) y = y * scale - y = self.delay_wrapper(x, y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5c4e447d4..b29962059 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -16,9 +16,9 @@ from brevitas import config from brevitas import is_dynamo_compiling from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.quant.delay import DelayWrapper from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector -from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -96,6 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None: self.cache_class = None # To be redefined by each class self.quant_tensor_class = None # To be redefined by each class self.skip_create_quant_tensor = False + quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) @property def input_view_impl(self): @@ -138,11 +140,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: else: out = self.create_quant_tensor(out) else: - out = self.tensor_quant(x) + quant_value, *quant_args = self.tensor_quant(x) + quant_args = tuple(quant_args) + quant_value = self.delay_wrapper(x, quant_value) if self.skip_create_quant_tensor: - out = out[0] + out = quant_value else: - out = self.create_quant_tensor(out) + out = self.create_quant_tensor((quant_value,) + quant_args) if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: self._cached_weight = self.cache_class( out.detach(), diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index cff192490..64a8faefe 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -14,6 +14,7 @@ import brevitas from brevitas import is_dynamo_compiling +from brevitas.core.quant.delay import DelayWrapper from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -99,6 +100,8 @@ def __init__(self, quant_layer, quant_injector): self.cache_quant_io_metadata_only = True self.cache_class = None self.skip_create_quant_tensor = False + quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) @property def input_view_impl(self): @@ -176,31 +179,33 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: y = y.value if self.export_mode: - y = self.fused_activation_quant_proxy.activation_impl(y) - y = self.export_handler(y) + out = self.fused_activation_quant_proxy.activation_impl(y) + out = self.export_handler(out) elif not self.is_quant_enabled: # A tuple helps later with control flows # The second None value is used later # If quant is not enabled, we still apply input_view in the case of groupwise + padding - y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) - y = (y, None) + out = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y)) + out = (out, None) else: - y = self.fused_activation_quant_proxy(y) + out = self.fused_activation_quant_proxy(y) # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - + quant_value, *quant_args = out + quant_args = tuple(quant_args) + quant_value = self.delay_wrapper(y, quant_value) if self.skip_create_quant_tensor: - out = y[0] + out = quant_value else: # If the second value (i.e., scale) is None, then quant is disabled - if y[1] is not None: - out = self.create_quant_tensor(y) + if out[1] is not None: + out = self.create_quant_tensor((quant_value,) + quant_args) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve scale/zp/bit/sign even without output quant - y = y[0] - out = self.create_quant_tensor(y, x=x) + out = quant_value + out = self.create_quant_tensor(out, x=x) else: - out = y[0] + out = quant_value if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) @@ -267,6 +272,8 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.skip_create_quant_tensor = False + quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) def bit_width(self): if not self.is_quant_enabled: @@ -285,6 +292,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple + out_value = self.delay_wrapper(x, out_value) if self.skip_create_quant_tensor: return out_value return IntQuantTensor( diff --git a/tests/brevitas/core/binary_quant_fixture.py b/tests/brevitas/core/binary_quant_fixture.py index 32937c2c4..f7fab6a1c 100644 --- a/tests/brevitas/core/binary_quant_fixture.py +++ b/tests/brevitas/core/binary_quant_fixture.py @@ -10,11 +10,8 @@ __all__ = [ 'binary_quant', 'clamped_binary_quant', - 'delayed_binary_quant', - 'delayed_clamped_binary_quant', 'binary_quant_impl_all', 'binary_quant_all', # noqa - 'delayed_binary_quant_all', # noqa ] @@ -43,21 +40,4 @@ def clamped_binary_quant(scaling_impl_all): return ClampedBinaryQuant(scaling_impl=scaling_impl_all) -@pytest_cases.fixture() -def delayed_binary_quant(scaling_impl_all, quant_delay_steps): - """ - Delayed BinaryQuant with all variants of scaling - """ - return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps) - - -@pytest_cases.fixture() -def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps): - """ - ClampedBinaryQuant with all variants of scaling - """ - return ClampedBinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps) - - fixture_union('binary_quant_all', ['binary_quant', 'clamped_binary_quant']) -fixture_union('delayed_binary_quant_all', ['delayed_binary_quant', 'delayed_clamped_binary_quant']) diff --git a/tests/brevitas/core/shared_quant_fixture.py b/tests/brevitas/core/shared_quant_fixture.py index baf565d06..f2a440823 100644 --- a/tests/brevitas/core/shared_quant_fixture.py +++ b/tests/brevitas/core/shared_quant_fixture.py @@ -9,7 +9,6 @@ from brevitas.core.scaling import ParameterScaling __all__ = [ - 'quant_delay_steps', 'const_scaling_impl', 'parameter_scaling_impl', 'standalone_scaling_init', @@ -18,15 +17,6 @@ ] -@pytest_cases.fixture() -@pytest_cases.parametrize('steps', [1, 10]) -def quant_delay_steps(steps): - """ - Non-zero steps to delay quantization - """ - return steps - - @pytest_cases.fixture() def const_scaling_impl(standalone_scaling_init): """ diff --git a/tests/brevitas/core/ternary_quant_fixture.py b/tests/brevitas/core/ternary_quant_fixture.py index 2cb7ade78..782631817 100644 --- a/tests/brevitas/core/ternary_quant_fixture.py +++ b/tests/brevitas/core/ternary_quant_fixture.py @@ -5,7 +5,7 @@ from brevitas.core.quant import TernaryQuant -__all__ = ['threshold_init', 'ternary_quant', 'delayed_ternary_quant'] +__all__ = ['threshold_init', 'ternary_quant'] @pytest_cases.fixture() @@ -22,14 +22,3 @@ def ternary_quant(scaling_impl_all, threshold_init): Ternary quant with all variants of scaling """ return TernaryQuant(scaling_impl=scaling_impl_all, threshold=threshold_init) - - -@pytest_cases.fixture() -def delayed_ternary_quant(scaling_impl_all, quant_delay_steps, threshold_init): - """ - Delayed TernaryQuant with all variants of scaling - """ - return TernaryQuant( - scaling_impl=scaling_impl_all, - quant_delay_steps=quant_delay_steps, - threshold=threshold_init) diff --git a/tests/brevitas/core/test_binary_quant.py b/tests/brevitas/core/test_binary_quant.py index 4f82e4815..bef166053 100644 --- a/tests/brevitas/core/test_binary_quant.py +++ b/tests/brevitas/core/test_binary_quant.py @@ -57,17 +57,6 @@ def test_output_value(self, binary_quant_all, inp): output, scale, _, _ = binary_quant_all(inp) assert is_binary_output_value_correct(scale, output) - def test_delayed_output_value(self, delayed_binary_quant_all, quant_delay_steps, randn_inp): - """ - Test delayed quantization by a certain number of steps. Because delayed quantization is - stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture. - """ - for i in range(quant_delay_steps): - output, _, _, _ = delayed_binary_quant_all(randn_inp) - assert (output == randn_inp).all() - output, scale, _, _ = delayed_binary_quant_all(randn_inp) - assert is_binary_output_value_correct(scale, output) - @given(inp=float_tensor_random_shape_st()) def test_output_bit_width(self, binary_quant_all, inp): _, _, _, bit_width = binary_quant_all(inp) diff --git a/tests/brevitas/core/test_ternary_quant.py b/tests/brevitas/core/test_ternary_quant.py index d2b1817fc..2d6d66c1f 100644 --- a/tests/brevitas/core/test_ternary_quant.py +++ b/tests/brevitas/core/test_ternary_quant.py @@ -67,17 +67,6 @@ def test_output_value(self, ternary_quant, inp): output, scale, _, _ = ternary_quant(inp) assert is_ternary_output_value_correct(scale, output) - def test_delayed_output_value(self, delayed_ternary_quant, quant_delay_steps, randn_inp): - """ - Test delayed quantization by a certain number of steps. Because delayed quantization is - stateful, we can't use Hypothesis to generate the input, so we resort to a basic fixture. - """ - for i in range(quant_delay_steps): - output, _, _, _ = delayed_ternary_quant(randn_inp) - assert (output == randn_inp).all() - output, scale, _, _ = delayed_ternary_quant(randn_inp) - assert is_ternary_output_value_correct(scale, output) - @given(inp=float_tensor_random_shape_st()) def test_output_bit_width(self, ternary_quant, inp): _, _, _, bit_width = ternary_quant(inp) diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py index 28c3eed9e..cb4bedbf9 100644 --- a/tests/brevitas/proxy/test_proxy.py +++ b/tests/brevitas/proxy/test_proxy.py @@ -1,4 +1,5 @@ import pytest +import torch from brevitas.nn import QuantLinear from brevitas.nn.quant_activation import QuantReLU @@ -80,3 +81,11 @@ def test_dynamic_act_proxy(self): model.act_quant.disable_quant = True assert model.act_quant.bit_width() is None + + def test_delay_act_proxy(self): + model = QuantReLU(quant_delay_steps=1) + inp = torch.randn(1, 5) + o = model(inp) + assert torch.allclose(inp, o) + o = model(inp) + assert not torch.allclose(inp, o) From 1852621466ccbff880020e9056519c033695e0fc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 17:40:15 +0000 Subject: [PATCH 02/14] Feat (proxy): later dequantization --- src/brevitas/core/quant/binary.py | 2 +- src/brevitas/core/quant/float.py | 5 ----- src/brevitas/core/quant/int_base.py | 8 ++------ src/brevitas/core/quant/ternary.py | 1 - src/brevitas/proxy/parameter_quant.py | 11 +++++++++-- src/brevitas/proxy/quant_proxy.py | 6 ++++++ src/brevitas/proxy/runtime_quant.py | 23 ++++++++++++++--------- 7 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index a00645b3f..8ba8016ee 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -110,5 +110,5 @@ def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) y = self.tensor_clamp_impl(x, -scale, scale) - y = binary_sign_ste(y) * scale + y = binary_sign_ste(y) return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 09dcc248a..4f3b42346 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -75,10 +75,6 @@ def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale) return val_fp_quant, scale - @brevitas.jit.script_method - def dequantize(self, y, scale): - return y * scale - @brevitas.jit.script_method def forward(self, x): if self.float_scaling_impl is not None: @@ -95,6 +91,5 @@ def forward(self, x): # after quantizing, clamp to special cases like NaN/inf if they are set y, saturating, inf_values, nan_values = self.float_clamp_impl( y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - y = self.dequantize(y, scale) # This is to respect the current interface of proxies return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py index 8e94465c9..745fb7ef9 100644 --- a/src/brevitas/core/quant/int_base.py +++ b/src/brevitas/core/quant/int_base.py @@ -81,9 +81,7 @@ def max_int(self, bit_width): @brevitas.jit.script_method def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: y_int = self.to_int(scale, zero_point, bit_width, x) - y = y_int - zero_point - y = y * scale - return y + return y_int class DecoupledIntQuant(brevitas.jit.ScriptModule): @@ -161,7 +159,5 @@ def forward( zero_point: Tensor, bit_width: Tensor, x: Tensor) -> Tensor: - y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x) - y = y_int - zero_point - y = y * scale + y = self.to_int(pre_scale, pre_zero_point, bit_width, x) return y diff --git a/src/brevitas/core/quant/ternary.py b/src/brevitas/core/quant/ternary.py index 9fd8f78ce..59424a10a 100644 --- a/src/brevitas/core/quant/ternary.py +++ b/src/brevitas/core/quant/ternary.py @@ -62,5 +62,4 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) mask = x.abs().gt(self.threshold * scale) y = mask.float() * ternary_sign_ste(x) - y = y * scale return y, scale, self.zero_point(), self.bit_width() diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index b29962059..27fd360a6 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -142,6 +142,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: else: quant_value, *quant_args = self.tensor_quant(x) quant_args = tuple(quant_args) + quant_value = self.dequantize((quant_value,) + quant_args) quant_value = self.delay_wrapper(x, quant_value) if self.skip_create_quant_tensor: out = quant_value @@ -278,8 +279,12 @@ def forward( input_bit_width = quant_input.bit_width input_is_signed = quant_input.signed - impl = self.export_handler if self.export_mode else self.tensor_quant - out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) + if self.export_mode: + out, scale, zero_point, bit_width, pre_scale, pre_zero_point = self.export_handler(x, input_bit_width, input_is_signed) + else: + out, scale, zero_point, bit_width, pre_scale, pre_zero_point = self.tensor_quant(x, input_bit_width, input_is_signed) + out = self.dequantize(out, scale, zero_point) + if self.skip_create_quant_tensor: return out return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) @@ -364,6 +369,8 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) + if not self.export_mode: + out = self.dequantize(out, out_scale, out_zp) if not self.skip_create_quant_tensor: out = IntQuantTensor( out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 845bfd515..49978eb72 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -113,6 +113,12 @@ def is_narrow_range(self): def rounding_mode(self): return _rounding_mode(self.quant_injector) + def dequantize(self, quant_args): + x, scale, zero_point, *_ = quant_args + out = x - zero_point + out = out * scale + return out + def add_tracked_module(self, module: nn.Module) -> None: if module is not None: self.tracked_module_list.append(module) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 64a8faefe..b48767422 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -189,23 +189,25 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: out = (out, None) else: out = self.fused_activation_quant_proxy(y) + quant_value, *quant_args = out + quant_args = tuple(quant_args) + quant_value = self.dequantize((quant_value,) + quant_args) + quant_value = self.delay_wrapper(y, quant_value) + out = (quant_value,) + quant_args # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - quant_value, *quant_args = out - quant_args = tuple(quant_args) - quant_value = self.delay_wrapper(y, quant_value) if self.skip_create_quant_tensor: - out = quant_value + out = out[0] else: # If the second value (i.e., scale) is None, then quant is disabled if out[1] is not None: - out = self.create_quant_tensor((quant_value,) + quant_args) + out = self.create_quant_tensor(out) elif self.is_passthrough_act and isinstance(x, QuantTensor): # preserve scale/zp/bit/sign even without output quant - out = quant_value + out = out[0] out = self.create_quant_tensor(out, x=x) else: - out = quant_value + out = out[0] if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) @@ -260,6 +262,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple + out_value = self.dequantize(out_value, out_scale, out_zp) if self.skip_create_quant_tensor: return out_value return IntQuantTensor( @@ -291,8 +294,10 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: x.value, x.scale, x.zero_point, x.bit_width, x.signed) else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) - out_value, out_scale, out_zp, out_bit_width = out_tuple - out_value = self.delay_wrapper(x, out_value) + out_value, out_scale, out_zp, out_bit_width = out_tuple + out_value = self.dequantize(out_value, out_scale, out_zp) + out_value = self.delay_wrapper(x, out_value) + if self.skip_create_quant_tensor: return out_value return IntQuantTensor( From d115f5ff1c14c8c97e5dda68af55790c031e173e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 17:44:29 +0000 Subject: [PATCH 03/14] Variable args --- src/brevitas/proxy/quant_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 49978eb72..3ed444bd4 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -113,7 +113,7 @@ def is_narrow_range(self): def rounding_mode(self): return _rounding_mode(self.quant_injector) - def dequantize(self, quant_args): + def dequantize(self, *quant_args): x, scale, zero_point, *_ = quant_args out = x - zero_point out = out * scale From fb7672b006432d68309a309768496fd25ce583ac Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 17:55:18 +0000 Subject: [PATCH 04/14] Fix --- src/brevitas/proxy/parameter_quant.py | 2 +- src/brevitas/proxy/runtime_quant.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 27fd360a6..ace2414d6 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -142,7 +142,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: else: quant_value, *quant_args = self.tensor_quant(x) quant_args = tuple(quant_args) - quant_value = self.dequantize((quant_value,) + quant_args) + quant_value = self.dequantize(*((quant_value,) + quant_args)) quant_value = self.delay_wrapper(x, quant_value) if self.skip_create_quant_tensor: out = quant_value diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index b48767422..31c5224be 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -191,7 +191,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: out = self.fused_activation_quant_proxy(y) quant_value, *quant_args = out quant_args = tuple(quant_args) - quant_value = self.dequantize((quant_value,) + quant_args) + quant_value = self.dequantize(*((quant_value,) + quant_args)) quant_value = self.delay_wrapper(y, quant_value) out = (quant_value,) + quant_args # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, From 4013d38010e5034271a30948d0c92f2f77b7a39d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 18:05:30 +0000 Subject: [PATCH 05/14] fix --- src/brevitas/proxy/quant_proxy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 3ed444bd4..b0056e6d5 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -116,6 +116,7 @@ def rounding_mode(self): def dequantize(self, *quant_args): x, scale, zero_point, *_ = quant_args out = x - zero_point + print(out, scale) out = out * scale return out From dcb6af6edf208699fe1a46a0460b3beda47fc79c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 30 Dec 2024 18:16:53 +0000 Subject: [PATCH 06/14] Binary fix --- src/brevitas/core/quant/binary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 8ba8016ee..7af2d1618 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -56,7 +56,7 @@ def __init__(self, scaling_impl: Module, signed: bool = True): @brevitas.jit.script_method def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: scale = self.scaling_impl(x) - y = binary_sign_ste(x) * scale + y = binary_sign_ste(x) return y, scale, self.zero_point(), self.bit_width() From bfaad2c232bbe71d71c5685421641e1875f39d4d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 11:53:54 +0000 Subject: [PATCH 07/14] remove print --- src/brevitas/proxy/quant_proxy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index b0056e6d5..3ed444bd4 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -116,7 +116,6 @@ def rounding_mode(self): def dequantize(self, *quant_args): x, scale, zero_point, *_ = quant_args out = x - zero_point - print(out, scale) out = out * scale return out From 2ec8416630df2efa5bfc04feaf7aff957cb62765 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 11:55:44 +0000 Subject: [PATCH 08/14] Trunc fix --- src/brevitas/core/quant/int.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 248931d68..3d60fee80 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -216,8 +216,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, trunc_scale = 2.0 ** trunc_bit_width y = y / trunc_scale y = self.float_to_int_impl(y) - y = y - zero_point - y = y * scale return y, scale, zero_point, output_bit_width From 18fee55eff3fa63831258cb55a3fbfa44f38cd35 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 12:01:34 +0000 Subject: [PATCH 09/14] fix avgpool export --- src/brevitas/proxy/runtime_quant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 31c5224be..5de901de3 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -292,6 +292,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.export_mode: out_tuple = self.export_handler( x.value, x.scale, x.zero_point, x.bit_width, x.signed) + out_value, out_scale, out_zp, out_bit_width = out_tuple else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple From 6169c832e9c8fe9eca4bdaa0b58c6a311a562622 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 16:57:06 +0000 Subject: [PATCH 10/14] Fix more tests --- tests/brevitas/core/test_binary_quant.py | 4 +++- tests/brevitas/core/test_int_quant.py | 1 + tests/brevitas/core/test_ternary_quant.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_binary_quant.py b/tests/brevitas/core/test_binary_quant.py index bef166053..27b43dbde 100644 --- a/tests/brevitas/core/test_binary_quant.py +++ b/tests/brevitas/core/test_binary_quant.py @@ -31,6 +31,7 @@ def test_binary_quant(self, binary_quant_impl_all, inp, scale_init): scaling_impl = mock.Mock(return_value=scale_init) binary_quant = binary_quant_impl_all(scaling_impl) output, scale, zp, bit_width = binary_quant(inp) + output = (output - zp) * scale scaling_impl.assert_called_once_with(inp) assert is_binary_output_value_correct(scale, output) assert is_binary_output_sign_correct(inp, output) @@ -54,7 +55,8 @@ def test_output_sign(self, binary_quant_all, inp): @given(inp=float_tensor_random_shape_st()) def test_output_value(self, binary_quant_all, inp): - output, scale, _, _ = binary_quant_all(inp) + output, scale, zp, _ = binary_quant_all(inp) + output = (output - zp) * scale assert is_binary_output_value_correct(scale, output) @given(inp=float_tensor_random_shape_st()) diff --git a/tests/brevitas/core/test_int_quant.py b/tests/brevitas/core/test_int_quant.py index 5e106dc4c..754cecb75 100644 --- a/tests/brevitas/core/test_int_quant.py +++ b/tests/brevitas/core/test_int_quant.py @@ -60,4 +60,5 @@ def test_int_quant_arange( # apply scale and zero point to the input distribution inp = scale * (arange_int_tensor - zero_point).float() output = int_quant(scale, zero_point, bit_width, inp) + output = (output - zero_point) * scale assert torch.isclose(inp, output).all() diff --git a/tests/brevitas/core/test_ternary_quant.py b/tests/brevitas/core/test_ternary_quant.py index 2d6d66c1f..fd8f4f075 100644 --- a/tests/brevitas/core/test_ternary_quant.py +++ b/tests/brevitas/core/test_ternary_quant.py @@ -37,6 +37,7 @@ def test_ternary_quant(self, inp, scale_init, threshold): scaling_impl = mock.Mock(return_value=scale_init) ternary_quant = TernaryQuant(scaling_impl, threshold) output, scale, zp, bit_width = ternary_quant(inp) + output = (output - zp) * scale scaling_impl.assert_called_once_with(inp) assert is_ternary_output_value_correct(scale, output) assert is_ternary_output_sign_correct(inp, scale * threshold, output) @@ -64,7 +65,8 @@ def test_output_sign(self, ternary_quant, inp): @given(inp=float_tensor_random_shape_st()) def test_output_value(self, ternary_quant, inp): - output, scale, _, _ = ternary_quant(inp) + output, scale, zp, _ = ternary_quant(inp) + output = (output - zp) * scale assert is_ternary_output_value_correct(scale, output) @given(inp=float_tensor_random_shape_st()) From 6e984c765a8c6e9fff26c4b917cc6f84add03b7e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 17:44:29 +0000 Subject: [PATCH 11/14] fix float test --- tests/brevitas/core/test_float_quant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 148507ada..b503cd3eb 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -221,6 +221,7 @@ def test_inner_scale(inp, minifloat_format, scale): out = val_fp_quant * scale expected_out, expected_scale, *_ = float_quant(inp) + expected_out = expected_out * expected_scale assert scale == expected_scale if scale == 0.0: From 5063dc693016d1ef44ed0cfa78d8ac17da8dbe0a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 19:17:40 +0000 Subject: [PATCH 12/14] Fix for dynamic inference mode --- src/brevitas/export/inference/handler.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index c7fc21790..7cadc0ede 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -103,7 +103,11 @@ def prepare_for_export(self, module: nn.Module): self.module_forward = module.fused_activation_quant_proxy def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: - return self.module_forward(x) + x, scale, zp, *quant_args = self.module_forward(x) + x = self.dequantize(x, scale, zp) + quant_args = tuple(quant_args) + out = (x, scale, zp) + quant_args + return out class GroupwiseIntInferenceHandler(IntInferencetHandler): @@ -119,14 +123,16 @@ def prepare_for_export(self, module): self.group_dim = module.group_dim def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]: - x, *other = self.module_forward(x) + x, scale, zp, *quant_args = self.module_forward(x) + x = self.dequantize(x, scale, zp) + quant_args = tuple(quant_args) # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) - output_args = tuple([x] + list(other)) - return output_args + out = (x, scale, zp) + quant_args + return out class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler): @@ -263,14 +269,15 @@ def prepare_for_export(self, module: nn.Module): self.group_dim = module.group_dim def forward(self, x: Tensor) -> Tuple[Tensor]: - x, *other = self.module_forward(x) - + x, scale, zp, *quant_args = self.module_forward(x) + x = self.dequantize(x, scale, zp) + quant_args = tuple(quant_args) # If we skip quant tensor, we return the flattened version of the groupwise tensor if self.skip_create_quant_tensor: start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) - output_args = tuple([x] + list(other)) - return output_args + out = (x, scale, zp) + quant_args + return out class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler): From ee0beba06d619ff1c08bf1b7b25665d353abca20 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 6 Jan 2025 13:00:47 +0000 Subject: [PATCH 13/14] Fix more tests --- tests/brevitas/core/test_scaling_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/core/test_scaling_quant.py b/tests/brevitas/core/test_scaling_quant.py index ba3d8ef7c..d57ffbe42 100644 --- a/tests/brevitas/core/test_scaling_quant.py +++ b/tests/brevitas/core/test_scaling_quant.py @@ -112,13 +112,13 @@ def hook_scale(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp) assert bit_width == SCALE_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) def hook_zp(module, inp): inp = inp[0] quant_scale, scale, zp, bit_width = module.zp_int_quant(inp) assert bit_width == ZP_BIT_WIDTH - assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale)) + assert torch.allclose(quant_scale, torch.round(quant_scale)) linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat) for module in linear.modules(): From d345ca5f90a58493f9deb390b4424f8da2660391 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 6 Jan 2025 16:48:23 +0000 Subject: [PATCH 14/14] Last fix --- src/brevitas/proxy/runtime_quant.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 5de901de3..da7db062c 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -102,6 +102,7 @@ def __init__(self, quant_layer, quant_injector): self.skip_create_quant_tensor = False quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.observer_only = False @property def input_view_impl(self): @@ -191,7 +192,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: out = self.fused_activation_quant_proxy(y) quant_value, *quant_args = out quant_args = tuple(quant_args) - quant_value = self.dequantize(*((quant_value,) + quant_args)) + if not self.observer_only: + quant_value = self.dequantize(*((quant_value,) + quant_args)) quant_value = self.delay_wrapper(y, quant_value) out = (quant_value,) + quant_args # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,