Feat (quantization): torch_function based quantization
Giuseppe5 committed Jan 9, 2025
1 parent 5546035 commit c9a42ea
Showing 3 changed files with 135 additions and 130 deletions.
36 changes: 13 additions & 23 deletions src/brevitas/graph/quantize.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from typing import Dict

from torch import nn

from brevitas import config
@@ -370,42 +371,31 @@ def layerwise_quantize(
    config.IGNORE_MISSING_KEYS = ignore_missing_keys_state
    return model

-from torch.overrides import TorchFunctionMode
-
import torch
+from torch.overrides import TorchFunctionMode


class functional_quantization_mode(TorchFunctionMode):

-    def __init__(self, model: torch.nn.Module, quant_map: Dict):
+    def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True):
        super().__init__()
        self.quant_map = quant_map
        self.model = model
-        self.add_module = True
-        self.map_dict = dict()
-        self.count = 0
-        self.model.register_forward_pre_hook(self.pre_hook)
-        for f in quant_map:
-            if not hasattr(model, str(f)):
+        self.enabled = enabled
+        for stateless_function, stateless_module in quant_map.items():
+            if not hasattr(model, str(stateless_function)):
                print("Setting")
-                setattr(model, str(f), torch.nn.ModuleList())
-
-    def pre_hook(self, *args, **kwargs):
-        self.count = 0
+                setattr(model, str(stateless_function), stateless_module())

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = dict()

-        if hasattr(self.model, str(func)):
-            func_dict = getattr(self.model, str(func))
-            if len(func_dict) <= self.count:
-                module = self.quant_map[func]()  # initialize new module
-                func_dict.append(module)
-            else:
-                module = func_dict[self.count]
-            self.count += 1
+        if hasattr(self.model, str(func)) and self.enabled:
+            module = getattr(self.model, str(func))
            out = module(*args, **kwargs)
        else:
            out = func(*args, **kwargs)
        return out
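For orientation, the following is a minimal, self-contained sketch of how the mode introduced above can be exercised. It assumes only the behaviour visible in this diff: the mode instantiates each mapped stateless module, registers it on the model under `str(function)`, and, while `enabled`, routes intercepted calls to that module. `ToySDPA` is a hypothetical stand-in for `brevitas.nn.quant_sdpa.ScaledDotProductAttention`, not Brevitas code.

```python
import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode


class ToySDPA(torch.nn.Module):
    """Hypothetical stand-in for a quantizable SDPA module."""

    def forward(self, query, key, value, **kwargs):
        # A real quantizable module would insert quantizers around the inputs and
        # output; here we simply fall back to the functional op.
        return F.scaled_dot_product_attention(query, key, value, **kwargs)


# Any nn.Module can act as the anchor: the mode attaches a ToySDPA instance to it
# under str(F.scaled_dot_product_attention) at construction time.
model = torch.nn.Identity()
q = k = v = torch.randn(1, 4, 8, 16)

with functional_quantization_mode(model, {F.scaled_dot_product_attention: ToySDPA}):
    # Intercepted by __torch_function__ and routed through the registered ToySDPA.
    out = F.scaled_dot_product_attention(q, k, v)

# Outside the mode (or with enabled=False) the call falls through to the original op.
reference = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, reference)
```

In `main.py` the same pattern is used with the real `ScaledDotProductAttention` and a calibration batch, gated by the new `--functional-sdpa-quant` flag.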
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
               [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
               [--no-quantize] [--no-float16]
               [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
-              [--replace-mha] [--weight-equalization]
-              [--rotation {fx,layerwise,fused_no_fx}]
+              [--functional-sdpa-quant] [--replace-mha]
+              [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
               [--rotation-mode {had,ort}] [--rotation-orphan-sink]
               [--act-equalization {None,layerwise,fx}]
               [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
@@ -170,6 +170,9 @@ options:
                        fp16 quantization.
  --quant-sdpa          Quantize `F.scaled_dot_product_attention` (default:
                        False)
+  --functional-sdpa-quant
+                        Quantize `F.scaled_dot_product_attention` with
+                        stateless module and torch_function (default: False)
  --replace-mha         Replace HuggingFace Attention with a quantizable
                        version
  --weight-equalization
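For a concrete invocation (the model name is only an illustration), the new flag is added to the usual command line, for example `python main.py --model facebook/opt-125m --functional-sdpa-quant`. It is an alternative to `--quant-sdpa`, which instead replaces `F.scaled_dot_product_attention` with QuantSDPA layers up front.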
222 changes: 117 additions & 105 deletions src/brevitas_examples/llm/main.py
@@ -1,16 +1,12 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

-from brevitas.nn.quant_sdpa import ScaledDotProductAttention
-from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention

import argparse
from copy import deepcopy
import functools
import sys
from warnings import warn

-from brevitas.graph.quantize import functional_quantization_mode
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
import numpy as np
@@ -26,8 +22,10 @@
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
+from brevitas.graph.quantize import functional_quantization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
+from brevitas.nn.quant_sdpa import ScaledDotProductAttention
from brevitas.utils.python_utils import hooked_on_a_function
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -56,6 +54,7 @@
from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
from brevitas_examples.llm.llm_quant.run_utils import get_fx
+from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention


def filter_results(results, tasks):
@@ -319,10 +318,12 @@ def quantize_llm(args):
print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
model = replace_sdpa_with_quantizable_layers(model)
print("Replacing done.")
model = offload_model(model)
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)
elif args.functional_sdpa_quant:
print("Inserting SDPA quantizable module")
model = offload_model(model)
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)
if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
@@ -423,107 +424,112 @@ def quantize_llm(args):
new_funct = functools.partial(update_internal_dict, m)
m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

with torch.no_grad():
model(**calibration_loader[0])

# We restore the original behaviour of the post-forward.
for k, v in dict_hooks.items():
k._hf_hook.post_forward = v

if args.act_calibration:
with functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}):
print("Apply act calibration...")
apply_calibration(model, calibration_loader)
print("Act calibration applied.")
with functional_quantization_mode(
model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention},
enabled=args.functional_sdpa_quant):
with torch.no_grad():
model(**calibration_loader[0])

if args.learned_round:
print("Applying learned round...")
# We restore the original behaviour of the post-forward.
for k, v in dict_hooks.items():
k._hf_hook.post_forward = v

if args.act_calibration:
with functional_quantization_mode(
model, {torch.nn.functional.scaled_dot_product_attention: None}):
print("Apply act calibration...")
apply_calibration(model, calibration_loader)
print("Act calibration applied.")

if args.learned_round:
print("Applying learned round...")
remove_hooks(model)
apply_learned_round(
model,
calibration_loader,
iters=args.learned_round_iters,
block_name_attribute=args.gpxq_block_name,
learn_scale=args.learned_round_scale,
scale_optimizer_class='sgd',
optimizer_kwargs={'lr': args.learned_round_lr},
scale_optimizer_kwargs={
'lr': args.learned_round_scale_lr,
'momentum': args.learned_round_scale_momentum},
fast_update=args.learned_round_fast_update)
print("Learned round applied.")

model = offload_model(model)

if args.gptq:
print("Applying GPTQ...")
apply_gptq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
use_quant_activations=args.gpxq_use_quant_activations,
create_weight_orig=args.gpxq_create_weight_orig,
block_name=args.gpxq_block_name,
max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
print("GPTQ applied.")

if args.gpfq:
print("Applying GPFQ...")
apply_gpfq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
block_name=args.gpxq_block_name,
max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
print("GPFQ applied.")

if args.bias_corr:
print("Applying bias correction...")
apply_bias_correction(model, calibration_loader)
print("Bias correction applied.")

if args.eval and not args.no_quantize:
print("Model eval...")
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
quant_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

if args.few_shot_eval:
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
if args.few_shot_compile:
remove_hooks(model)
model.cuda()
model = torch.compile(model)

wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval
results = evaluator.simple_evaluate(
model=wrapped_model,
model_args=None,
tasks=list(args.few_shot_tasks),
device='cuda:0',
limit=args.few_shot_limit,
num_fewshot=0 if args.few_shot_zeroshot else None,
log_samples=False,
batch_size=None,
verbosity="ERROR")
results = filter_results(results, args.few_shot_tasks)
print("Few shot eval results")
print(results)
remove_hooks(model)
apply_learned_round(
model,
calibration_loader,
iters=args.learned_round_iters,
block_name_attribute=args.gpxq_block_name,
learn_scale=args.learned_round_scale,
scale_optimizer_class='sgd',
optimizer_kwargs={'lr': args.learned_round_lr},
scale_optimizer_kwargs={
'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
fast_update=args.learned_round_fast_update)
print("Learned round applied.")

model = offload_model(model)

if args.gptq:
print("Applying GPTQ...")
apply_gptq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
use_quant_activations=args.gpxq_use_quant_activations,
create_weight_orig=args.gpxq_create_weight_orig,
block_name=args.gpxq_block_name,
max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
print("GPTQ applied.")

if args.gpfq:
print("Applying GPFQ...")
apply_gpfq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
block_name=args.gpxq_block_name,
max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
print("GPFQ applied.")

if args.bias_corr:
print("Applying bias correction...")
apply_bias_correction(model, calibration_loader)
print("Bias correction applied.")

if args.eval and not args.no_quantize:
print("Model eval...")
with torch.no_grad(), quant_inference_mode(model), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}):
model(**calibration_loader[0])
quant_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
if args.checkpoint_name is not None:
print(f"Saving checkpoint to {args.checkpoint_name}")
torch.save(model.state_dict(), args.checkpoint_name)

if args.few_shot_eval:
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
if args.few_shot_compile:
remove_hooks(model)
model.cuda()
model = torch.compile(model)

wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval
results = evaluator.simple_evaluate(
model=wrapped_model,
model_args=None,
tasks=list(args.few_shot_tasks),
device='cuda:0',
limit=args.few_shot_limit,
num_fewshot=0 if args.few_shot_zeroshot else None,
log_samples=False,
batch_size=None,
verbosity="ERROR")
results = filter_results(results, args.few_shot_tasks)
print("Few shot eval results")
print(results)
remove_hooks(model)

if args.checkpoint_name is not None:
print(f"Saving checkpoint to {args.checkpoint_name}")
torch.save(model.state_dict(), args.checkpoint_name)

if args.export_target:
print(f"Export to {args.export_target}")
# Currently we always export on CPU with a float32 container to avoid float16 CPU errors
model = model.to(dtype=torch.float32)
model_export(model, calibration_loader[0], args)
if args.export_target:
print(f"Export to {args.export_target}")
# Currently we always export on CPU with a float32 container to avoid float16 CPU errors
model = model.to(dtype=torch.float32)
model_export(model, calibration_loader[0], args)

return float_ppl, quant_ppl, model

@@ -753,6 +759,12 @@ def parse_args(args, override_defaults={}):
        '--quant-sdpa',
        action='store_true',
        help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
+    parser.add_argument(
+        '--functional-sdpa-quant',
+        action='store_true',
+        help=
+        'Quantize `F.scaled_dot_product_attention` with stateless module and torch_function (default: %(default)s)'
+    )
    parser.add_argument(
        '--replace-mha',
        action='store_true',
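To make the new calibration flow concrete, here is a hedged, self-contained sketch of the pattern this diff wires into `quantize_llm`: the mode attaches a stateless `ScaledDotProductAttention` to the model at construction and, while enabled, routes `F.scaled_dot_product_attention` calls through it during a single `no_grad` forward. The tiny model, its inputs, and the `functional_sdpa_quant` boolean are placeholders for the HuggingFace model, `calibration_loader[0]`, and `args.functional_sdpa_quant` used in `main.py`.

```python
import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode
from brevitas.nn.quant_sdpa import ScaledDotProductAttention


class TinyAttention(torch.nn.Module):
    """Placeholder model whose forward goes through F.scaled_dot_product_attention,
    which is the call the mode intercepts."""

    def forward(self, query, key, value):
        return F.scaled_dot_product_attention(query, key, value)


model = TinyAttention()
calibration_batch = {
    'query': torch.randn(1, 4, 8, 16),
    'key': torch.randn(1, 4, 8, 16),
    'value': torch.randn(1, 4, 8, 16)}
functional_sdpa_quant = True  # stands in for args.functional_sdpa_quant

# Mirrors the calibration forward added in this diff: SDPA calls made inside the
# model are routed through the ScaledDotProductAttention registered on it.
with functional_quantization_mode(
        model, {F.scaled_dot_product_attention: ScaledDotProductAttention},
        enabled=functional_sdpa_quant):
    with torch.no_grad():
        model(**calibration_batch)
```

Because the module is registered on the model via `setattr`, it shows up as a regular submodule for the later quantization steps.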
