From c9a42eaabee0b5583fbfeeed60d699ec70f28dc3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco <giuseppefranco4@gmail.com> Date: Thu, 9 Jan 2025 11:41:18 +0000 Subject: [PATCH] Feat (quantization): torch_function based quantization --- src/brevitas/graph/quantize.py | 36 ++--- src/brevitas_examples/llm/README.md | 7 +- src/brevitas_examples/llm/main.py | 222 +++++++++++++++------------- 3 files changed, 135 insertions(+), 130 deletions(-) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index c2f2227cc..da7eaa6ba 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from typing import Dict + from torch import nn from brevitas import config @@ -370,42 +371,31 @@ def layerwise_quantize( config.IGNORE_MISSING_KEYS = ignore_missing_keys_state return model -from torch.overrides import TorchFunctionMode + import torch +from torch.overrides import TorchFunctionMode class functional_quantization_mode(TorchFunctionMode): - def __init__(self, model: torch.nn.Module, quant_map: Dict): + + def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True): super().__init__() self.quant_map = quant_map self.model = model - self.add_module = True - self.map_dict = dict() - self.count = 0 - self.model.register_forward_pre_hook(self.pre_hook) - for f in quant_map: - if not hasattr(model, str(f)): + self.enabled = enabled + for stateless_function, stateless_module in quant_map.items(): + if not hasattr(model, str(stateless_function)): print("Setting") - setattr(model, str(f), torch.nn.ModuleList()) - - - def pre_hook(self, *args, **kwargs): - self.count = 0 + setattr(model, str(stateless_function), stateless_module()) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = dict() - if hasattr(self.model, str(func)): - func_dict = getattr(self.model, str(func)) - if len(func_dict) <= self.count: - module = self.quant_map[func]() # initialize new module - func_dict.append(module) - else: - module = func_dict[self.count] - self.count += 1 + if hasattr(self.model, str(func)) and self.enabled: + module = getattr(self.model, str(func)) out = module(*args, **kwargs) else: out = func(*args, **kwargs) - - return out \ No newline at end of file + + return out diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 0d6fb5f42..6d0ea5dd6 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm] [--no-quantize] [--no-float16] [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa] - [--replace-mha] [--weight-equalization] - [--rotation {fx,layerwise,fused_no_fx}] + [--functional-sdpa-quant] [--replace-mha] + [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}] [--rotation-orphan-sink] [--act-equalization {None,layerwise,fx}] [--act-equalization-alpha ACT_EQUALIZATION_ALPHA] @@ -170,6 +170,9 @@ options: fp16 quantization. 
--quant-sdpa Quantize `F.scaled_dot_product_attention` (default: False) + --functional-sdpa-quant + Quantize `F.scaled_dot_product_attention` with + stateless module and torch_function (default: False) --replace-mha Replace HuggingFace Attention with a quantizable version --weight-equalization diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 5c80bc2f1..3926802fa 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -1,16 +1,12 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from brevitas.nn.quant_sdpa import ScaledDotProductAttention -from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention - import argparse from copy import deepcopy import functools import sys from warnings import warn -from brevitas.graph.quantize import functional_quantization_mode from lm_eval import evaluator from lm_eval.models.huggingface import HFLM import numpy as np @@ -26,8 +22,10 @@ from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation +from brevitas.graph.quantize import functional_quantization_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module +from brevitas.nn.quant_sdpa import ScaledDotProductAttention from brevitas.utils.python_utils import hooked_on_a_function from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks @@ -56,6 +54,7 @@ from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention def filter_results(results, tasks): @@ -319,10 +318,12 @@ def quantize_llm(args): print("Replace `F.scaled_dot_product_attention` with QuantSDPA...") model = replace_sdpa_with_quantizable_layers(model) print("Replacing done.") - model = offload_model(model) - with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}): - model(**calibration_loader[0]) - remove_hooks(model) + elif args.functional_sdpa_quant: + print("Inserting SDPA quantizable module") + model = offload_model(model) + with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}): + model(**calibration_loader[0]) + remove_hooks(model) if args.weight_equalization: print("Apply weight equalization...") # In case of float16 model, we need to offload to account for missing ops @@ -423,107 +424,112 @@ def quantize_llm(args): new_funct = functools.partial(update_internal_dict, m) m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) - with torch.no_grad(): - model(**calibration_loader[0]) - - # We restore the original behaviour of the post-forward. 
- for k, v in dict_hooks.items(): - k._hf_hook.post_forward = v - - if args.act_calibration: - with functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}): - print("Apply act calibration...") - apply_calibration(model, calibration_loader) - print("Act calibration applied.") + with functional_quantization_mode( + model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}, + enabled=args.functional_sdpa_quant): + with torch.no_grad(): + model(**calibration_loader[0]) - if args.learned_round: - print("Applying learned round...") + # We restore the original behaviour of the post-forward. + for k, v in dict_hooks.items(): + k._hf_hook.post_forward = v + + if args.act_calibration: + with functional_quantization_mode( + model, {torch.nn.functional.scaled_dot_product_attention: None}): + print("Apply act calibration...") + apply_calibration(model, calibration_loader) + print("Act calibration applied.") + + if args.learned_round: + print("Applying learned round...") + remove_hooks(model) + apply_learned_round( + model, + calibration_loader, + iters=args.learned_round_iters, + block_name_attribute=args.gpxq_block_name, + learn_scale=args.learned_round_scale, + scale_optimizer_class='sgd', + optimizer_kwargs={'lr': args.learned_round_lr}, + scale_optimizer_kwargs={ + 'lr': args.learned_round_scale_lr, + 'momentum': args.learned_round_scale_momentum}, + fast_update=args.learned_round_fast_update) + print("Learned round applied.") + + model = offload_model(model) + + if args.gptq: + print("Applying GPTQ...") + apply_gptq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + use_quant_activations=args.gpxq_use_quant_activations, + create_weight_orig=args.gpxq_create_weight_orig, + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) + print("GPTQ applied.") + + if args.gpfq: + print("Applying GPFQ...") + apply_gpfq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) + print("GPFQ applied.") + + if args.bias_corr: + print("Applying bias correction...") + apply_bias_correction(model, calibration_loader) + print("Bias correction applied.") + + if args.eval and not args.no_quantize: + print("Model eval...") + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + + if args.few_shot_eval: + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + if args.few_shot_compile: + remove_hooks(model) + model.cuda() + model = torch.compile(model) + + wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval + results = evaluator.simple_evaluate( + model=wrapped_model, + model_args=None, + tasks=list(args.few_shot_tasks), + device='cuda:0', + limit=args.few_shot_limit, + num_fewshot=0 if args.few_shot_zeroshot else None, + log_samples=False, + batch_size=None, + verbosity="ERROR") + results = filter_results(results, args.few_shot_tasks) + print("Few shot eval results") + print(results) remove_hooks(model) - apply_learned_round( - model, - calibration_loader, - iters=args.learned_round_iters, - 
block_name_attribute=args.gpxq_block_name, - learn_scale=args.learned_round_scale, - scale_optimizer_class='sgd', - optimizer_kwargs={'lr': args.learned_round_lr}, - scale_optimizer_kwargs={ - 'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum}, - fast_update=args.learned_round_fast_update) - print("Learned round applied.") - - model = offload_model(model) - - if args.gptq: - print("Applying GPTQ...") - apply_gptq( - model, - calibration_loader, - act_order=args.gpxq_act_order, - use_quant_activations=args.gpxq_use_quant_activations, - create_weight_orig=args.gpxq_create_weight_orig, - block_name=args.gpxq_block_name, - max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, - max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) - print("GPTQ applied.") - - if args.gpfq: - print("Applying GPFQ...") - apply_gpfq( - model, - calibration_loader, - act_order=args.gpxq_act_order, - block_name=args.gpxq_block_name, - max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, - max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) - print("GPFQ applied.") - - if args.bias_corr: - print("Applying bias correction...") - apply_bias_correction(model, calibration_loader) - print("Bias correction applied.") - if args.eval and not args.no_quantize: - print("Model eval...") - with torch.no_grad(), quant_inference_mode(model), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}): - model(**calibration_loader[0]) - quant_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + if args.checkpoint_name is not None: + print(f"Saving checkpoint to {args.checkpoint_name}") + torch.save(model.state_dict(), args.checkpoint_name) - if args.few_shot_eval: - with torch.no_grad(), quant_inference_mode(model): - model(**calibration_loader[0]) - if args.few_shot_compile: - remove_hooks(model) - model.cuda() - model = torch.compile(model) - - wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval - results = evaluator.simple_evaluate( - model=wrapped_model, - model_args=None, - tasks=list(args.few_shot_tasks), - device='cuda:0', - limit=args.few_shot_limit, - num_fewshot=0 if args.few_shot_zeroshot else None, - log_samples=False, - batch_size=None, - verbosity="ERROR") - results = filter_results(results, args.few_shot_tasks) - print("Few shot eval results") - print(results) - remove_hooks(model) - - if args.checkpoint_name is not None: - print(f"Saving checkpoint to {args.checkpoint_name}") - torch.save(model.state_dict(), args.checkpoint_name) - - if args.export_target: - print(f"Export to {args.export_target}") - # Currently we always export on CPU with a float32 container to avoid float16 CPU errors - model = model.to(dtype=torch.float32) - model_export(model, calibration_loader[0], args) + if args.export_target: + print(f"Export to {args.export_target}") + # Currently we always export on CPU with a float32 container to avoid float16 CPU errors + model = model.to(dtype=torch.float32) + model_export(model, calibration_loader[0], args) return float_ppl, quant_ppl, model @@ -753,6 +759,12 @@ def parse_args(args, override_defaults={}): '--quant-sdpa', action='store_true', help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)') + parser.add_argument( + '--functional-sdpa-quant', + action='store_true', + help= + 'Quantize `F.scaled_dot_product_attention` with stateless module 
and torch_function (default: %(default)s)' + ) parser.add_argument( '--replace-mha', action='store_true',
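
Usage sketch (illustrative, not part of the patch): a minimal, self-contained example of the torch_function-based quantization path added above, mirroring the calibration call in the main.py hunk. `TinyAttention`, the tensor shapes, and the random inputs are hypothetical placeholders; the sketch assumes `functional_quantization_mode` and `ScaledDotProductAttention` behave as defined in the quantize.py hunk.

    import torch
    import torch.nn.functional as F

    from brevitas.graph.quantize import functional_quantization_mode
    from brevitas.nn.quant_sdpa import ScaledDotProductAttention


    class TinyAttention(torch.nn.Module):
        """Hypothetical toy module that calls the functional SDPA API directly."""

        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)


    model = TinyAttention()
    q = k = v = torch.randn(1, 4, 16, 8)

    # Entering the mode registers one stateless ScaledDotProductAttention module
    # on `model` per mapped function. While the mode is active, every call to
    # F.scaled_dot_product_attention inside model(...) is re-routed through that
    # module, so it can later be picked up by the quantization passes like any
    # other Brevitas layer.
    with torch.no_grad(), functional_quantization_mode(
            model, {F.scaled_dot_product_attention: ScaledDotProductAttention}):
        out = model(q, k, v)

Because a TorchFunctionMode is disabled while its own `__torch_function__` handler runs, the registered module's internal call to `F.scaled_dot_product_attention` should dispatch to the stock implementation rather than recursing back into the mode.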