diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 7038fe1a9..f74fffae8 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -344,33 +344,3 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: # pre-zero centering before rounding and clipping z = self.get_zero_center(x) / scale # need to scale the norm by s return z - - -class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule): - - def __init__( - self, - group_size: int, - group_dim: int, - input_view_impl: Module, - zero_point_stats_impl: Module, - int_quant, - quantize_zero_point) -> None: - super(RuntimeDynamicGroupZeroScaling, self).__init__() - - self.group_size = group_size - self.group_dim = group_dim - self.zero_point_stats_impl = zero_point_stats_impl - self.input_view_impl = input_view_impl - self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) - - @brevitas.jit.script_method - def forward( - self, - stats_input: torch.Tensor, - scale, - bit_width) -> torch.Tensor: - - stats_input_reshaped = self.input_view_impl(stats_input) - out = self.zero_point_stats_impl(stats_input_reshaped) - return self.scale_shift_zero_point(-out, scale, bit_width) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 40e6063f4..0e85917b2 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,8 +4,6 @@ """ import re -from brevitas.core.stats import NegativeMinOrZero -from brevitas.quant.base import ParameterFromRuntimeZeroPoint from dependencies import this import torch from torch import nn @@ -14,8 +12,11 @@ from brevitas.core.function_wrapper import CeilSte from brevitas.core.function_wrapper import FloorSte from brevitas.core.restrict_val import RoundSte -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling +from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.core.zero_point import RuntimeDynamicGroupZeroScaling from brevitas.graph.quantize import layerwise_quantize +from brevitas.quant.base import ParameterFromRuntimeZeroPoint from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -60,7 +61,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear -from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE @@ -71,6 +72,7 @@ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat @@ -152,15 +154,6 @@ 'per_channel': { 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}} -class Test(Int8DynamicActPerGroupFloat): - # zero_point_impl = RuntimeDynamicStatsZeroPoint - zero_point_impl = RuntimeDynamicGroupZeroScaling - zero_point_stats_impl = NegativeMinOrZero - scaling_stats_op = 'min_max' - signed = False - # zero_point_shape = this.scaling_shape - # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl - INPUT_QUANT_MAP = { 'int': { 'static': { @@ -189,8 +182,7 @@ class Test(Int8DynamicActPerGroupFloat): 'sym': Int8DynamicActPerRowFloat, 'asym': ShiftedUint8DynamicActPerRowFloat}, 'per_group': { - 'sym': Int8DynamicActPerGroupFloat, - 'asym': Test}}}, + 'sym': Int8DynamicActPerGroupFloat}}}, 'po2_scale': { 'stats': { 'per_row': {