Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (delay): move delay to proxies #1141

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions src/brevitas/core/quant/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import brevitas
from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import binary_sign_ste

Expand All @@ -22,7 +21,6 @@ class BinaryQuant(brevitas.jit.ScriptModule):

Args:
scaling_impl (Module): Module that returns a scale factor.
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
Expand All @@ -48,19 +46,17 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, signed: bool = True):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = binary_sign_ste(x) * scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()


Expand All @@ -74,7 +70,6 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
Args:
scaling_impl (Module): Module that returns a scale factor.
tensor_clamp_impl (Module): Module that performs tensor-wise clamping. Default TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
Expand Down Expand Up @@ -104,22 +99,16 @@ class ClampedBinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(
self,
scaling_impl: Module,
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, tensor_clamp_impl: Module = TensorClamp()):
super(ClampedBinaryQuant, self).__init__()
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
self.delay_wrapper = DelayWrapper(quant_delay_steps)
self.tensor_clamp_impl = tensor_clamp_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = self.tensor_clamp_impl(x, -scale, scale)
y = binary_sign_ste(y) * scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()
6 changes: 1 addition & 5 deletions src/brevitas/core/quant/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.nn import Module

import brevitas
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import round_ste

Expand Down Expand Up @@ -201,12 +200,10 @@ class TruncIntQuant(brevitas.jit.ScriptModule):
"""
"""

def __init__(
self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0):
def __init__(self, float_to_int_impl: Module, bit_width_impl: Module):
super(TruncIntQuant, self).__init__()
self.msb_clamp_bit_width_impl = bit_width_impl
self.float_to_int_impl = float_to_int_impl
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
Expand All @@ -221,7 +218,6 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
y = self.float_to_int_impl(y)
y = y - zero_point
y = y * scale
y = self.delay_wrapper(x, y)
return y, scale, zero_point, output_bit_width


Expand Down
33 changes: 12 additions & 21 deletions src/brevitas/core/quant/int_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import brevitas
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.quant.delay import DelayWrapper
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int

Expand All @@ -24,7 +23,6 @@ class IntQuant(brevitas.jit.ScriptModule):
float_to_int_impl (Module): Module that performs the conversion from floating point to
integer representation. Default: RoundSte()
tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

Returns:
Tensor: Quantized output in de-quantized format.
Expand All @@ -48,19 +46,17 @@ class IntQuant(brevitas.jit.ScriptModule):
__constants__ = ['signed', 'narrow_range']

def __init__(
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp()):
super(IntQuant, self).__init__()
self.float_to_int_impl = float_to_int_impl
self.tensor_clamp_impl = tensor_clamp_impl
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
Expand All @@ -87,7 +83,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso
y_int = self.to_int(scale, zero_point, bit_width, x)
y = y_int - zero_point
y = y * scale
y = self.delay_wrapper(x, y)
return y


Expand All @@ -102,7 +97,6 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
float_to_int_impl (Module): Module that performs the conversion from floating point to
integer representation. Default: RoundSte()
tensor_clamp_impl (Module): Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

Returns:
Tensor: Quantized output in de-quantized format.
Expand All @@ -124,19 +118,17 @@ class DecoupledIntQuant(brevitas.jit.ScriptModule):
__constants__ = ['signed', 'narrow_range']

def __init__(
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp(),
quant_delay_steps: int = 0):
self,
narrow_range: bool,
signed: bool,
input_view_impl: Module,
float_to_int_impl: Module = RoundSte(),
tensor_clamp_impl: Module = TensorClamp()):
super(DecoupledIntQuant, self).__init__()
self.float_to_int_impl = float_to_int_impl
self.tensor_clamp_impl = tensor_clamp_impl
self.signed = signed
self.narrow_range = narrow_range
self.delay_wrapper = DelayWrapper(quant_delay_steps)
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
Expand Down Expand Up @@ -172,5 +164,4 @@ def forward(
y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
y = y_int - zero_point
y = y * scale
y = self.delay_wrapper(x, y)
return y
3 changes: 0 additions & 3 deletions src/brevitas/core/quant/ternary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import brevitas
from brevitas.core.bit_width import BitWidthConst
from brevitas.core.quant.delay import DelayWrapper
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import ternary_sign_ste

Expand Down Expand Up @@ -57,13 +56,11 @@ def __init__(self, scaling_impl: Module, threshold: float, quant_delay_steps: in
self.threshold = threshold
self.bit_width = BitWidthConst(2)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
mask = x.abs().gt(self.threshold * scale)
y = mask.float() * ternary_sign_ste(x)
y = y * scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()
12 changes: 8 additions & 4 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from brevitas import config
from brevitas import is_dynamo_compiling
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant.delay import DelayWrapper
from brevitas.function import max_int
from brevitas.inject import BaseInjector as Injector
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO
Expand Down Expand Up @@ -96,6 +96,8 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
self.cache_class = None # To be redefined by each class
self.quant_tensor_class = None # To be redefined by each class
self.skip_create_quant_tensor = False
quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@property
def input_view_impl(self):
Expand Down Expand Up @@ -138,11 +140,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
else:
out = self.create_quant_tensor(out)
else:
out = self.tensor_quant(x)
quant_value, *quant_args = self.tensor_quant(x)
quant_args = tuple(quant_args)
quant_value = self.delay_wrapper(x, quant_value)
if self.skip_create_quant_tensor:
out = out[0]
out = quant_value
else:
out = self.create_quant_tensor(out)
out = self.create_quant_tensor((quant_value,) + quant_args)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = self.cache_class(
out.detach(),
Expand Down
35 changes: 23 additions & 12 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import brevitas
from brevitas import is_dynamo_compiling
from brevitas.core.quant.delay import DelayWrapper
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO
Expand Down Expand Up @@ -99,6 +100,8 @@ def __init__(self, quant_layer, quant_injector):
self.cache_quant_io_metadata_only = True
self.cache_class = None
self.skip_create_quant_tensor = False
quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

@property
def input_view_impl(self):
Expand Down Expand Up @@ -176,31 +179,33 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
out = self.fused_activation_quant_proxy.activation_impl(y)
out = self.export_handler(out)
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
# If quant is not enabled, we still apply input_view in the case of groupwise + padding
y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
y = (y, None)
out = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
out = (out, None)
else:
y = self.fused_activation_quant_proxy(y)
out = self.fused_activation_quant_proxy(y)
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor

quant_value, *quant_args = out
quant_args = tuple(quant_args)
quant_value = self.delay_wrapper(y, quant_value)
if self.skip_create_quant_tensor:
out = y[0]
out = quant_value
else:
# If the second value (i.e., scale) is None, then quant is disabled
if y[1] is not None:
out = self.create_quant_tensor(y)
if out[1] is not None:
out = self.create_quant_tensor((quant_value,) + quant_args)
elif self.is_passthrough_act and isinstance(x, QuantTensor):
# preserve scale/zp/bit/sign even without output quant
y = y[0]
out = self.create_quant_tensor(y, x=x)
out = quant_value
out = self.create_quant_tensor(out, x=x)
else:
out = y[0]
out = quant_value

if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
Expand Down Expand Up @@ -250,11 +255,14 @@ class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol)
def __init__(self):
super().__init__()
self.skip_create_quant_tensor = False
quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
out_value = self.delay_wrapper(x, out_value)
if self.skip_create_quant_tensor:
return out_value
return IntQuantTensor(
Expand All @@ -267,6 +275,8 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.skip_create_quant_tensor = False
quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None
self.delay_wrapper = DelayWrapper(quant_delay_steps)

def bit_width(self):
if not self.is_quant_enabled:
Expand All @@ -285,6 +295,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
else:
out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
out_value = self.delay_wrapper(x, out_value)
if self.skip_create_quant_tensor:
return out_value
return IntQuantTensor(
Expand Down
20 changes: 0 additions & 20 deletions tests/brevitas/core/binary_quant_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
__all__ = [
'binary_quant',
'clamped_binary_quant',
'delayed_binary_quant',
'delayed_clamped_binary_quant',
'binary_quant_impl_all',
'binary_quant_all', # noqa
'delayed_binary_quant_all', # noqa
]


Expand Down Expand Up @@ -43,21 +40,4 @@ def clamped_binary_quant(scaling_impl_all):
return ClampedBinaryQuant(scaling_impl=scaling_impl_all)


@pytest_cases.fixture()
def delayed_binary_quant(scaling_impl_all, quant_delay_steps):
"""
Delayed BinaryQuant with all variants of scaling
"""
return BinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)


@pytest_cases.fixture()
def delayed_clamped_binary_quant(scaling_impl_all, quant_delay_steps):
"""
Delayed ClampedBinaryQuant with all variants of scaling
"""
return ClampedBinaryQuant(scaling_impl=scaling_impl_all, quant_delay_steps=quant_delay_steps)


fixture_union('binary_quant_all', ['binary_quant', 'clamped_binary_quant'])
fixture_union('delayed_binary_quant_all', ['delayed_binary_quant', 'delayed_clamped_binary_quant'])
Loading
Loading