Feat (quantization): torch_function based quantization #1147

Merged · 4 commits · Jan 13, 2025

33 changes: 33 additions & 0 deletions src/brevitas/graph/quantize.py
@@ -1,9 +1,14 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Dict

from packaging import version
import torch
from torch import nn

from brevitas import config
from brevitas import torch_version
from brevitas.core.scaling.standalone import ConstScaling
from brevitas.core.scaling.standalone import ParameterScaling
from brevitas.fx.brevitas_tracer import symbolic_trace
@@ -33,6 +38,34 @@
from brevitas.quant import Uint8ActPerTensorFloatMaxInit
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat

if torch_version >= version.parse('1.12'):
    from torch.overrides import TorchFunctionMode

    class functional_quantization_mode(TorchFunctionMode):

        def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True):
            super().__init__()
            self.quant_map = quant_map
            self.model = model
            self.enabled = enabled
            for stateless_function, stateless_module in quant_map.items():
                if not hasattr(model, str(stateless_function)):
                    setattr(model, str(stateless_function), stateless_module())

        def __torch_function__(self, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = dict()

            if hasattr(self.model, str(func)) and self.enabled:
                module = getattr(self.model, str(func))
                out = module(*args, **kwargs)
            else:
                out = func(*args, **kwargs)

            return out
else:
    functional_quantization_mode = object()

COMPUTE_LAYER_MAP = {
    nn.AvgPool2d:
        None,
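As context for the diff above, the snippet below is a minimal sketch (not part of the PR) of how the new mode is meant to be used, assuming torch >= 1.12: constructing it registers a stateless module on the model for each functional op in `quant_map`, and any call to that op made under the mode is dispatched to the registered module. `TinyAttention`, the tensor shapes, and the variable names are made up for illustration; the quant map mirrors the one used in `brevitas_examples/llm/main.py`.

```python
import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode
from brevitas.nn.quant_sdpa import ScaledDotProductAttention


class TinyAttention(torch.nn.Module):
    """Toy module that calls the functional SDPA op directly."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


model = TinyAttention()
q = k = v = torch.randn(1, 4, 8, 16)

# Constructing the mode attaches a ScaledDotProductAttention submodule to `model`
# (keyed by the functional op); __torch_function__ then reroutes every
# F.scaled_dot_product_attention call made inside the context to that submodule.
with torch.no_grad(), functional_quantization_mode(
        model, {F.scaled_dot_product_attention: ScaledDotProductAttention}):
    out = model(q, k, v)
```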
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
               [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
               [--no-quantize] [--no-float16]
               [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
               [--functional-sdpa-quant] [--replace-mha]
               [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
               [--rotation-mode {had,ort}] [--rotation-orphan-sink]
               [--act-equalization {None,layerwise,fx}]
               [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
@@ -170,6 +170,9 @@ options:
                        fp16 quantization.
  --quant-sdpa          Quantize `F.scaled_dot_product_attention` (default:
                        False)
  --functional-sdpa-quant
                        Quantize `F.scaled_dot_product_attention` with
                        stateless module and torch_function (default: False)
  --replace-mha         Replace HuggingFace Attention with a quantizable
                        version
  --weight-equalization
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -17,6 +17,7 @@ few_shot_tasks:
- winogrande
- piqa
few_shot_zeroshot: false
functional_sdpa_quant: false
fuse_sequences: false
gpfq: false
gptq: false
239 changes: 134 additions & 105 deletions src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from contextlib import nullcontext
from copy import deepcopy
import functools
import sys
@@ -23,8 +24,10 @@
from brevitas.graph import load_quant_model_mode
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import functional_quantization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
from brevitas.nn.quant_sdpa import ScaledDotProductAttention
from brevitas.utils.python_utils import hooked_on_a_function
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -132,6 +135,8 @@ def model_export(model, ref_input, args):


def validate(args):
    if args.functional_sdpa_quant:
        assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
    if args.rotation == 'fx':
        assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
        assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
@@ -316,7 +321,12 @@ def quantize_llm(args):
print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
model = replace_sdpa_with_quantizable_layers(model)
print("Replacing done.")

elif args.functional_sdpa_quant:
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
print("Inserting SDPA quantizable module")
model = offload_model(model)
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)
if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
@@ -421,118 +431,131 @@ def quantize_llm(args):
                new_funct = functools.partial(update_internal_dict, m)
                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

    # If we are doing functional SDPA quantization, we create the correct context manager,
    # otherwise nullcontext. We would love to avoid the extra indentation level but it doesn't seem easy.
    if args.functional_sdpa_quant:
        quantization_cm = functional_quantization_mode(
            model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention})
    else:
        quantization_cm = nullcontext()

    with quantization_cm:
        with torch.no_grad():
            model(**calibration_loader[0])

        # We restore the original behaviour of the post-forward.
        for k, v in dict_hooks.items():
            k._hf_hook.post_forward = v

        if args.act_calibration and not args.load_checkpoint:
            print("Apply act calibration...")
            apply_calibration(model, calibration_loader)
            print("Act calibration applied.")

        if args.learned_round:
            print("Applying learned round...")
            if args.load_checkpoint:
                iters = 1
                loader = [calibration_loader[0]]
            else:
                iters = args.learned_round_iters
                loader = calibration_loader
            remove_hooks(model)
            apply_learned_round(
                model,
                loader,
                iters=iters,
                block_name_attribute=args.gpxq_block_name,
                learn_scale=args.learned_round_scale,
                scale_optimizer_class='sgd',
                optimizer_kwargs={'lr': args.learned_round_lr},
                scale_optimizer_kwargs={
                    'lr': args.learned_round_scale_lr,
                    'momentum': args.learned_round_scale_momentum},
                fast_update=args.learned_round_fast_update)
            print("Learned round applied.")
        model = offload_model(model)

        if args.load_checkpoint:
            remove_hooks(model)
            with load_quant_model_mode(model):
                model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu'))
            model = offload_model(model)

        if args.gptq and not args.load_checkpoint:
            print("Applying GPTQ...")
            apply_gptq(
                model,
                calibration_loader,
                act_order=args.gpxq_act_order,
                use_quant_activations=args.gpxq_use_quant_activations,
                create_weight_orig=args.gpxq_create_weight_orig,
                block_name=args.gpxq_block_name,
                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
            print("GPTQ applied.")

        if args.gpfq and not args.load_checkpoint:
            print("Applying GPFQ...")
            apply_gpfq(
                model,
                calibration_loader,
                act_order=args.gpxq_act_order,
                block_name=args.gpxq_block_name,
                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
            print("GPFQ applied.")

        if args.bias_corr and not args.load_checkpoint:
            print("Applying bias correction...")
            apply_bias_correction(model, calibration_loader)
            print("Bias correction applied.")

        if args.eval and not args.no_quantize:
            print("Model eval...")
            with torch.no_grad(), quant_inference_mode(model):
                model(**calibration_loader[0])
                quant_ppl = compute_perplexity(
                    model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
            print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

        if args.checkpoint_name is not None and not args.load_checkpoint:
            print(f"Saving checkpoint to {args.checkpoint_name}")
            torch.save(model.state_dict(), args.checkpoint_name)

        if args.few_shot_eval:
            with torch.no_grad(), quant_inference_mode(model):
                model(**calibration_loader[0])
                if args.few_shot_compile:
                    remove_hooks(model)
                    model.cuda()
                    model = torch.compile(model)

                wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
                results = evaluator.simple_evaluate(
                    model=wrapped_model,
                    model_args=None,
                    tasks=list(args.few_shot_tasks),
                    device='cuda:0',
                    limit=args.few_shot_limit,
                    num_fewshot=0 if args.few_shot_zeroshot else None,
                    log_samples=False,
                    batch_size=None,
                    verbosity="ERROR")
                results = filter_results(results, args.few_shot_tasks)
                print("Few shot eval results")
                print(results)
        remove_hooks(model)

        if args.checkpoint_name is not None and not args.load_checkpoint:
            print(f"Saving checkpoint to {args.checkpoint_name}")
            torch.save(model.state_dict(), args.checkpoint_name)

        if args.export_target:
            print(f"Export to {args.export_target}")
            # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
            model = model.to(dtype=torch.float32)
            model_export(model, calibration_loader[0], args)

    return float_ppl, quant_ppl, model

@@ -762,6 +785,12 @@ def parse_args(args, override_defaults={}):
        '--quant-sdpa',
        action='store_true',
        help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
    parser.add_argument(
        '--functional-sdpa-quant',
        action='store_true',
        help=
        'Quantize `F.scaled_dot_product_attention` with stateless module and torch_function (default: %(default)s)'
    )
    parser.add_argument(
        '--replace-mha',
        action='store_true',
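Taken together, the `main.py` changes give `--functional-sdpa-quant` two jobs: a first calibration forward under `functional_quantization_mode` attaches a quantizable `ScaledDotProductAttention` module to the model, and the later calibration/GPxQ/evaluation phases run inside the same mode so functional SDPA calls keep hitting that module. The sketch below condenses that control flow; `model` and `calibration_loader` stand in for the objects `main.py` builds, and `run_pipeline` is a hypothetical helper, not a function from the example.

```python
from contextlib import nullcontext

import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode
from brevitas.nn.quant_sdpa import ScaledDotProductAttention

SDPA_QUANT_MAP = {F.scaled_dot_product_attention: ScaledDotProductAttention}


def run_pipeline(model, calibration_loader, functional_sdpa_quant: bool):
    if functional_sdpa_quant:
        # First pass: constructing the mode attaches the quantizable SDPA module
        # to `model`, and the forward exercises it once (mirrors the elif branch
        # added to quantize_llm).
        with torch.no_grad(), functional_quantization_mode(model, SDPA_QUANT_MAP):
            model(**calibration_loader[0])
        quantization_cm = functional_quantization_mode(model, SDPA_QUANT_MAP)
    else:
        quantization_cm = nullcontext()

    # Everything that runs the model afterwards (calibration, GPxQ, eval) goes
    # inside the context manager, so functional SDPA stays routed to the module.
    with quantization_cm:
        with torch.no_grad():
            model(**calibration_loader[0])
```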