Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 15, 2025
1 parent 925c3a5 commit 3ea4b03
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 45 deletions.
30 changes: 0 additions & 30 deletions src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,33 +344,3 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
# pre-zero centering before rounding and clipping
z = self.get_zero_center(x) / scale # need to scale the norm by s
return z


class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule):
    """Zero-point module computed from statistics of the runtime input.

    On each call the input is passed through ``input_view_impl``, the
    result is fed to ``zero_point_stats_impl``, and the negated statistic
    is forwarded to a ``_ScaleShiftZeroPoint`` instance together with the
    scale and bit width supplied by the caller.
    """

    def __init__(
            self,
            group_size: int,
            group_dim: int,
            input_view_impl: Module,
            zero_point_stats_impl: Module,
            int_quant,
            quantize_zero_point) -> None:
        """Store the grouping settings and the sub-modules used by ``forward``.

        Args:
            group_size: stored as ``self.group_size``; not read by ``forward``.
            group_dim: stored as ``self.group_dim``; not read by ``forward``.
            input_view_impl: module applied to the input before statistics
                are computed.
            zero_point_stats_impl: module producing the statistic whose
                negation becomes the zero-point input.
            int_quant: forwarded to ``_ScaleShiftZeroPoint``.
            quantize_zero_point: forwarded to ``_ScaleShiftZeroPoint``.
        """
        super(RuntimeDynamicGroupZeroScaling, self).__init__()

        # Assignment order is kept as in the original so that sub-module
        # registration order does not change.
        self.group_size = group_size
        self.group_dim = group_dim
        self.zero_point_stats_impl = zero_point_stats_impl
        self.input_view_impl = input_view_impl
        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def forward(self, stats_input: torch.Tensor, scale, bit_width) -> torch.Tensor:
        """Return the zero point derived from ``stats_input``.

        The statistic computed on the viewed input is negated before it is
        handed to ``scale_shift_zero_point`` with ``scale`` and ``bit_width``.
        """
        viewed_input = self.input_view_impl(stats_input)
        stats = self.zero_point_stats_impl(viewed_input)
        negated_stats = -stats
        return self.scale_shift_zero_point(negated_stats, scale, bit_width)
22 changes: 7 additions & 15 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
"""
import re

from brevitas.core.stats import NegativeMinOrZero
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from dependencies import this
import torch
from torch import nn
Expand All @@ -14,8 +12,11 @@
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.core.zero_point import RuntimeDynamicGroupZeroScaling
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
Expand Down Expand Up @@ -60,7 +61,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE
Expand All @@ -71,6 +72,7 @@
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

Expand Down Expand Up @@ -152,15 +154,6 @@
'per_channel': {
'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}

class Test(Int8DynamicActPerGroupFloat):
    """Variant of ``Int8DynamicActPerGroupFloat`` with overridden zero-point settings.

    Overrides ``signed``, ``scaling_stats_op``, ``zero_point_stats_impl`` and
    ``zero_point_impl`` of the base quantizer; everything else is inherited.
    """
    signed = False
    scaling_stats_op = 'min_max'
    zero_point_stats_impl = NegativeMinOrZero
    zero_point_impl = RuntimeDynamicGroupZeroScaling
    # Alternatives left disabled by the original author:
    #   zero_point_impl = RuntimeDynamicStatsZeroPoint
    #   zero_point_shape = this.scaling_shape
    #   zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl


INPUT_QUANT_MAP = {
'int': {
'static': {
Expand Down Expand Up @@ -189,8 +182,7 @@ class Test(Int8DynamicActPerGroupFloat):
'sym': Int8DynamicActPerRowFloat,
'asym': ShiftedUint8DynamicActPerRowFloat},
'per_group': {
'sym': Int8DynamicActPerGroupFloat,
'asym': Test}}},
'sym': Int8DynamicActPerGroupFloat}}},
'po2_scale': {
'stats': {
'per_row': {
Expand Down

0 comments on commit 3ea4b03

Please sign in to comment.