Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Builder groupwise quant #1078

Draft
wants to merge 14 commits into
base: dev
Choose a base branch
from
Prev Previous commit
Integration with llm entrypoint
Giuseppe5 committed Oct 28, 2024
commit dca94faa61afd0cc5b58622ff69258293116c5f6
1 change: 1 addition & 0 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
Original file line number Diff line number Diff line change
@@ -206,6 +206,7 @@ def build_options(
raise RuntimeError("Not supported")

if is_po2_scale:
assert scale_rounding_func_type is not None
scale_rounding_func = scale_rounding_func_dict[scale_rounding_func_type]
options['restrict_scaling_type'] = RestrictValueType.POWER_OF_TWO
options['restrict_value_float_to_int_impl'] = scale_rounding_func
33 changes: 30 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import Fp8e4m3WeightSymmetricGroupQuant
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseFloatWeightQuantizerBuilder
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseIntWeightQuantizerBuilder
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
@@ -222,7 +224,8 @@ def generate_quantizers(
quantize_input_zero_point=False,
device=None,
weight_kwargs=None,
input_kwargs=None):
input_kwargs=None,
weight_scale_rounding_func_type=None):
"""
Replace float layers with quant layers in the target model
"""
@@ -243,8 +246,32 @@ def generate_quantizers(
else:
input_float_format = {}

weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]
if weight_quant_granularity == 'per_group':
if weight_quant_format == 'int':
weight_quant = GroupwiseIntWeightQuantizerBuilder(
bit_width=weight_bit_width,
scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
is_po2_scale=weight_scale_precision == 'po2_scale',
scale_computation_type='parameter_from_stats',
scale_rounding_func_type=weight_scale_rounding_func_type,
group_dim=weight_group_dim,
group_size=weight_group_size,
scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
else:
weight_quant = GroupwiseFloatWeightQuantizerBuilder(
exponent_bit_width=weight_float_format['exponent_bit_width'],
mantissa_bit_width=weight_float_format['mantissa_bit_width'],
bit_width=weight_bit_width,
scale_stats_op='max' if weight_param_method != 'mse' else weight_param_method,
is_po2_scale=weight_scale_precision == 'po2_scale',
scale_computation_type='parameter_from_stats',
scale_rounding_func_type=weight_scale_rounding_func_type,
group_dim=weight_group_dim,
group_size=weight_group_size,
scaling_min_val=1e-4 if dtype == torch.float16 else 1e-8)
else:
weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]

if input_bit_width is not None and input_scale_type == 'no_scale':
input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][
10 changes: 9 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
@@ -253,7 +253,9 @@ def main(args):
input_quant_granularity=args.input_quant_granularity,
input_group_size=args.input_group_size,
quantize_input_zero_point=args.quantize_input_zero_point,
device=device)
device=device,
weight_scale_rounding_func_type=args.weight_scale_rounding_func_type
)
layer_map = generate_quant_maps(
linear_input_quant=linear_input_quant,
weight_quant=weight_quant,
@@ -400,6 +402,12 @@ def parse_args(args):
default='per_group',
choices=['per_channel', 'per_tensor', 'per_group'],
help='Granularity for scales/zero-point of weights. Default: per_group.')
parser.add_argument(
'--weight-scale-rounding-func-type',
type=str,
default=None,
choices=['round', 'ceil', 'floor'],
help='Rounding function to use with Po2 scale. Default: None.')
parser.add_argument(
'--weight-group-dim',
type=int,