Feat (groupwise): builder class for quantizer
Giuseppe5 committed Oct 28, 2024
1 parent 7494e2e commit 87358df
Showing 3 changed files with 126 additions and 10 deletions.
125 changes: 125 additions & 0 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -1,13 +1,16 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

from dependencies import this
from dependencies import value

from brevitas.core.function_wrapper.ops_ste import CeilSte
from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.restrict_val import PowerOfTwo
from brevitas.core.restrict_val import PowerOfTwoRestrictValue
from brevitas.core.restrict_val import RoundSte
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
from brevitas.inject import ExtendedInjector
from brevitas.inject.enum import RestrictValueType
@@ -22,10 +25,14 @@
from brevitas.quant.base import MinMaxStatsScaling
from brevitas.quant.base import MSEAsymmetricScale
from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.base import MSESymmetricScaleSubInjector
from brevitas.quant.base import ShiftedMinUintQuant
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
from brevitas.quant.experimental.float_quant_fnuz import FpFNUZMixin
from brevitas.quant.experimental.float_quant_ocp import FpOCPAct
from brevitas.quant.experimental.float_quant_ocp import FpOCPMixin
from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.weight import WeightQuantSolver
@@ -154,3 +161,121 @@ class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight):
    MX Int unsigned weight quantizer with per-channel MSE-based scaling.
    """
    pass


class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat):
    """
    Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
    We inherit from a per-channel quantizer to re-use some underlying machinery.
    """
    proxy_class = GroupwiseWeightFloatQuantProxyFromInjector
    scaling_per_output_type = ScalingPerOutputType.GROUP


def build_options(
        weight_quant,
        bit_width,
        scale_stats_op,
        is_po2_scale,
        scale_computation_type,
        scale_rounding_func_type: Optional[str],
        group_size: int = 32,
        group_dim: Optional[int] = None,
        scaling_min_val: float = 1e-8):

    options = dict()
    scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}

    options['group_size'] = group_size
    options['bit_width'] = bit_width
    options['scaling_min_val'] = scaling_min_val

    if scale_stats_op == 'mse':
        weight_quant = type('MSEWeightQuant', (MSESymmetricScale, weight_quant), {})
    else:
        options['scale_stats_op'] = scale_stats_op

    if group_dim is not None:
        options['group_dim'] = group_dim

    if scale_computation_type == 'param_from_stats':
        options['scaling_impl_type'] = 'parameter_from_stats'
    elif scale_computation_type == 'stats':
        options['scaling_impl_type'] = 'stats'
    else:
        raise RuntimeError(f"Scale computation type '{scale_computation_type}' is not supported")

    if is_po2_scale:
        scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type]
        options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO
        options['restrict_value_float_to_int_impl'] = scale_rounding_func
    else:
        # If not po2, the threshold does not need any restriction and will match the float restriction of the scale
        options['restrict_scaling_type'] = RestrictValueType.FP
        options['restrict_threshold_impl'] = None
        assert scale_rounding_func_type is None, "Rounding for scale not needed when float"
    return options, weight_quant


class GroupwiseIntWeightQuantizerBuilder:

    def __new__(
            cls,
            bit_width,
            scale_stats_op,
            is_po2_scale,
            scale_computation_type,
            scale_rounding_func_type: Optional[str],
            group_size: int = 32,
            group_dim: Optional[int] = None,
            scaling_min_val: float = 1e-8,
    ):

        weight_quant = MXInt8Weight
        options, weight_quant = build_options(
            weight_quant,
            bit_width,
            scale_stats_op,
            is_po2_scale,
            scale_computation_type,
            scale_rounding_func_type,
            group_size,
            group_dim,
            scaling_min_val)
        weight_quant = weight_quant.let(**options)
        return weight_quant


class GroupwiseFloatWeightQuantizerBuilder(GroupwiseIntWeightQuantizerBuilder):

    def __new__(
            cls,
            exponent_bit_width,
            mantissa_bit_width,
            bit_width,
            scale_stats_op,
            is_po2_scale,
            scale_computation_type,
            scale_rounding_func_type: Optional[str],
            group_size: int = 32,
            group_dim: Optional[int] = None,
            scaling_min_val: float = 1e-8,
            format: Optional[str] = None):
        weight_quant = Fp8e4m3WeightSymmetricGroupQuant

        if format == 'ocp':
            weight_quant = type('OCPWeightQuant', (FpOCPMixin, weight_quant), {})
        elif format == 'fnuz':
            weight_quant = type('FNUZWeightQuant', (FpFNUZMixin, weight_quant), {})

        options, weight_quant = build_options(
            weight_quant,
            bit_width,
            scale_stats_op,
            is_po2_scale,
            scale_computation_type,
            scale_rounding_func_type,
            group_size,
            group_dim,
            scaling_min_val)
        options['exponent_bit_width'] = exponent_bit_width
        options['mantissa_bit_width'] = mantissa_bit_width

        weight_quant = weight_quant.let(**options)
        return weight_quant
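
A minimal usage sketch for the new builders (not part of the diff above): it assumes the module path added in this commit (src/brevitas/quant/experimental/mx_quant_ocp.py) and the standard brevitas.nn.QuantLinear layer; the layer sizes and option values below are illustrative only.

# Sketch: build groupwise weight quantizers with the new builder classes
# and attach one to a Brevitas layer. Sizes and option values are illustrative.
import brevitas.nn as qnn

from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder

# Groupwise int weights: MSE-based scaling, po2 scales rounded with ceil.
int_weight_quant = GroupwiseIntWeightQuantizerBuilder(
    bit_width=8,
    scale_stats_op='mse',
    is_po2_scale=True,
    scale_computation_type='stats',
    scale_rounding_func_type='ceil',
    group_size=32)

# Groupwise OCP e4m3 float weights: MSE-based scaling with plain float (non-po2)
# scales, so no scale rounding function is passed.
float_weight_quant = GroupwiseFloatWeightQuantizerBuilder(
    exponent_bit_width=4,
    mantissa_bit_width=3,
    bit_width=8,
    scale_stats_op='mse',
    is_po2_scale=False,
    scale_computation_type='param_from_stats',
    scale_rounding_func_type=None,
    group_size=32,
    format='ocp')

# The builders return injector classes, which plug into quant layers as usual.
layer = qnn.QuantLinear(64, 64, bias=False, weight_quant=int_weight_quant)
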
2 changes: 1 addition & 1 deletion src/brevitas_examples/common/generative/quantize.py
@@ -20,6 +20,7 @@
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
@@ -55,7 +56,6 @@
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
9 changes: 0 additions & 9 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -49,15 +49,6 @@ class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat):
    scaling_per_output_type = ScalingPerOutputType.GROUP


class Fp8e4m3WeightSymmetricGroupQuant(Fp8e4m3WeightPerChannelFloat):
    """
    Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
    We inherit from a per-channel quantizer to re-use some underlying machinery.
    """
    proxy_class = GroupwiseWeightFloatQuantProxyFromInjector
    scaling_per_output_type = ScalingPerOutputType.GROUP


class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
    """
    Symmetric quantizer with per tensor dynamic scale.
