From c9a42eaabee0b5583fbfeeed60d699ec70f28dc3 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco <giuseppefranco4@gmail.com>
Date: Thu, 9 Jan 2025 11:41:18 +0000
Subject: [PATCH] Feat (quantization): torch_function based quantization

---
 src/brevitas/graph/quantize.py      |  36 ++---
 src/brevitas_examples/llm/README.md |   7 +-
 src/brevitas_examples/llm/main.py   | 222 +++++++++++++++-------------
 3 files changed, 135 insertions(+), 130 deletions(-)

diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py
index c2f2227cc..da7eaa6ba 100644
--- a/src/brevitas/graph/quantize.py
+++ b/src/brevitas/graph/quantize.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from typing import Dict
+
 from torch import nn
 
 from brevitas import config
@@ -370,42 +371,31 @@ def layerwise_quantize(
     config.IGNORE_MISSING_KEYS = ignore_missing_keys_state
     return model
 
-from torch.overrides import TorchFunctionMode
+
 import torch
+from torch.overrides import TorchFunctionMode
 
 
 class functional_quantization_mode(TorchFunctionMode):
-    def __init__(self, model: torch.nn.Module, quant_map: Dict):
+
+    def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True):
         super().__init__()
         self.quant_map = quant_map
         self.model = model
-        self.add_module = True
-        self.map_dict = dict()
-        self.count = 0
-        self.model.register_forward_pre_hook(self.pre_hook)
-        for f in quant_map:
-            if not hasattr(model, str(f)):
+        self.enabled = enabled
+        for stateless_function, stateless_module in quant_map.items():
+            if not hasattr(model, str(stateless_function)):
                 print("Setting")
-                setattr(model, str(f), torch.nn.ModuleList())
-        
-
-    def pre_hook(self, *args, **kwargs):
-        self.count = 0
+                setattr(model, str(stateless_function), stateless_module())
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = dict()
 
-        if hasattr(self.model, str(func)):
-            func_dict = getattr(self.model, str(func))
-            if len(func_dict) <= self.count:
-                module = self.quant_map[func]() # initialize new module
-                func_dict.append(module)
-            else:
-                module = func_dict[self.count]
-            self.count += 1
+        if hasattr(self.model, str(func)) and self.enabled:
+            module = getattr(self.model, str(func))
             out = module(*args, **kwargs)
         else:
             out = func(*args, **kwargs)
-                
-        return out
\ No newline at end of file
+
+        return out
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 0d6fb5f42..6d0ea5dd6 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
                [--no-quantize] [--no-float16]
                [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
-               [--replace-mha] [--weight-equalization]
-               [--rotation {fx,layerwise,fused_no_fx}]
+               [--functional-sdpa-quant] [--replace-mha]
+               [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
                [--rotation-mode {had,ort}] [--rotation-orphan-sink]
                [--act-equalization {None,layerwise,fx}]
                [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
@@ -170,6 +170,9 @@ options:
                         fp16 quantization.
   --quant-sdpa          Quantize `F.scaled_dot_product_attention` (default:
                         False)
+  --functional-sdpa-quant
+                        Quantize `F.scaled_dot_product_attention` with
+                        stateless module and torch_function (default: False)
   --replace-mha         Replace HuggingFace Attention with a quantizable
                         version
   --weight-equalization
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 5c80bc2f1..3926802fa 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -1,16 +1,12 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from brevitas.nn.quant_sdpa import ScaledDotProductAttention
-from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention
-
 import argparse
 from copy import deepcopy
 import functools
 import sys
 from warnings import warn
 
-from brevitas.graph.quantize import functional_quantization_mode
 from lm_eval import evaluator
 from lm_eval.models.huggingface import HFLM
 import numpy as np
@@ -26,8 +22,10 @@
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
+from brevitas.graph.quantize import functional_quantization_mode
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
+from brevitas.nn.quant_sdpa import ScaledDotProductAttention
 from brevitas.utils.python_utils import hooked_on_a_function
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -56,6 +54,7 @@
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
 from brevitas_examples.llm.llm_quant.run_utils import get_fx
+from brevitas_examples.stable_diffusion.sd_quant.nn import QuantizableAttention
 
 
 def filter_results(results, tasks):
@@ -319,10 +318,12 @@ def quantize_llm(args):
         print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
         model = replace_sdpa_with_quantizable_layers(model)
         print("Replacing done.")
-    model = offload_model(model)
-    with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
-        model(**calibration_loader[0])
-    remove_hooks(model)
+    elif args.functional_sdpa_quant:
+        print("Inserting SDPA quantizable module")
+        model = offload_model(model)
+        with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
+            model(**calibration_loader[0])
+        remove_hooks(model)
     if args.weight_equalization:
         print("Apply weight equalization...")
         # In case of float16 model, we need to offload to account for missing ops
@@ -423,107 +424,112 @@ def quantize_llm(args):
                 new_funct = functools.partial(update_internal_dict, m)
                 m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
 
-    with torch.no_grad():
-        model(**calibration_loader[0])
-
-    # We restore the original behaviour of the post-forward.
-    for k, v in dict_hooks.items():
-        k._hf_hook.post_forward = v
-
-    if args.act_calibration:
-        with functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}):
-            print("Apply act calibration...")
-            apply_calibration(model, calibration_loader)
-            print("Act calibration applied.")
+    with functional_quantization_mode(
+            model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention},
+            enabled=args.functional_sdpa_quant):
+        with torch.no_grad():
+            model(**calibration_loader[0])
 
-    if args.learned_round:
-        print("Applying learned round...")
+        # We restore the original behaviour of the post-forward.
+        for k, v in dict_hooks.items():
+            k._hf_hook.post_forward = v
+
+        if args.act_calibration:
+            with functional_quantization_mode(
+                    model, {torch.nn.functional.scaled_dot_product_attention: None}):
+                print("Apply act calibration...")
+                apply_calibration(model, calibration_loader)
+                print("Act calibration applied.")
+
+        if args.learned_round:
+            print("Applying learned round...")
+            remove_hooks(model)
+            apply_learned_round(
+                model,
+                calibration_loader,
+                iters=args.learned_round_iters,
+                block_name_attribute=args.gpxq_block_name,
+                learn_scale=args.learned_round_scale,
+                scale_optimizer_class='sgd',
+                optimizer_kwargs={'lr': args.learned_round_lr},
+                scale_optimizer_kwargs={
+                    'lr': args.learned_round_scale_lr,
+                    'momentum': args.learned_round_scale_momentum},
+                fast_update=args.learned_round_fast_update)
+            print("Learned round applied.")
+
+            model = offload_model(model)
+
+        if args.gptq:
+            print("Applying GPTQ...")
+            apply_gptq(
+                model,
+                calibration_loader,
+                act_order=args.gpxq_act_order,
+                use_quant_activations=args.gpxq_use_quant_activations,
+                create_weight_orig=args.gpxq_create_weight_orig,
+                block_name=args.gpxq_block_name,
+                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
+                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
+            print("GPTQ applied.")
+
+        if args.gpfq:
+            print("Applying GPFQ...")
+            apply_gpfq(
+                model,
+                calibration_loader,
+                act_order=args.gpxq_act_order,
+                block_name=args.gpxq_block_name,
+                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
+                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
+            print("GPFQ applied.")
+
+        if args.bias_corr:
+            print("Applying bias correction...")
+            apply_bias_correction(model, calibration_loader)
+            print("Bias correction applied.")
+
+        if args.eval and not args.no_quantize:
+            print("Model eval...")
+            with torch.no_grad(), quant_inference_mode(model):
+                model(**calibration_loader[0])
+                quant_ppl = compute_perplexity(
+                    model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+            print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
+
+        if args.few_shot_eval:
+            with torch.no_grad(), quant_inference_mode(model):
+                model(**calibration_loader[0])
+                if args.few_shot_compile:
+                    remove_hooks(model)
+                    model.cuda()
+                    model = torch.compile(model)
+
+                wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
+                results = evaluator.simple_evaluate(
+                    model=wrapped_model,
+                    model_args=None,
+                    tasks=list(args.few_shot_tasks),
+                    device='cuda:0',
+                    limit=args.few_shot_limit,
+                    num_fewshot=0 if args.few_shot_zeroshot else None,
+                    log_samples=False,
+                    batch_size=None,
+                    verbosity="ERROR")
+            results = filter_results(results, args.few_shot_tasks)
+            print("Few shot eval results")
+            print(results)
         remove_hooks(model)
-        apply_learned_round(
-            model,
-            calibration_loader,
-            iters=args.learned_round_iters,
-            block_name_attribute=args.gpxq_block_name,
-            learn_scale=args.learned_round_scale,
-            scale_optimizer_class='sgd',
-            optimizer_kwargs={'lr': args.learned_round_lr},
-            scale_optimizer_kwargs={
-                'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
-            fast_update=args.learned_round_fast_update)
-        print("Learned round applied.")
-
-        model = offload_model(model)
-
-    if args.gptq:
-        print("Applying GPTQ...")
-        apply_gptq(
-            model,
-            calibration_loader,
-            act_order=args.gpxq_act_order,
-            use_quant_activations=args.gpxq_use_quant_activations,
-            create_weight_orig=args.gpxq_create_weight_orig,
-            block_name=args.gpxq_block_name,
-            max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
-            max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
-        print("GPTQ applied.")
-
-    if args.gpfq:
-        print("Applying GPFQ...")
-        apply_gpfq(
-            model,
-            calibration_loader,
-            act_order=args.gpxq_act_order,
-            block_name=args.gpxq_block_name,
-            max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
-            max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
-        print("GPFQ applied.")
-
-    if args.bias_corr:
-        print("Applying bias correction...")
-        apply_bias_correction(model, calibration_loader)
-        print("Bias correction applied.")
 
-    if args.eval and not args.no_quantize:
-        print("Model eval...")
-        with torch.no_grad(), quant_inference_mode(model), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: None}):
-            model(**calibration_loader[0])
-            quant_ppl = compute_perplexity(
-                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
-        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
+        if args.checkpoint_name is not None:
+            print(f"Saving checkpoint to {args.checkpoint_name}")
+            torch.save(model.state_dict(), args.checkpoint_name)
 
-    if args.few_shot_eval:
-        with torch.no_grad(), quant_inference_mode(model):
-            model(**calibration_loader[0])
-            if args.few_shot_compile:
-                remove_hooks(model)
-                model.cuda()
-                model = torch.compile(model)
-
-            wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
-            results = evaluator.simple_evaluate(
-                model=wrapped_model,
-                model_args=None,
-                tasks=list(args.few_shot_tasks),
-                device='cuda:0',
-                limit=args.few_shot_limit,
-                num_fewshot=0 if args.few_shot_zeroshot else None,
-                log_samples=False,
-                batch_size=None,
-                verbosity="ERROR")
-        results = filter_results(results, args.few_shot_tasks)
-        print("Few shot eval results")
-        print(results)
-    remove_hooks(model)
-
-    if args.checkpoint_name is not None:
-        print(f"Saving checkpoint to {args.checkpoint_name}")
-        torch.save(model.state_dict(), args.checkpoint_name)
-
-    if args.export_target:
-        print(f"Export to {args.export_target}")
-        # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
-        model = model.to(dtype=torch.float32)
-        model_export(model, calibration_loader[0], args)
+        if args.export_target:
+            print(f"Export to {args.export_target}")
+            # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
+            model = model.to(dtype=torch.float32)
+            model_export(model, calibration_loader[0], args)
 
     return float_ppl, quant_ppl, model
 
@@ -753,6 +759,12 @@ def parse_args(args, override_defaults={}):
         '--quant-sdpa',
         action='store_true',
         help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
+    parser.add_argument(
+        '--functional-sdpa-quant',
+        action='store_true',
+        help=
+        'Quantize `F.scaled_dot_product_attention` with stateless module and torch_function (default: %(default)s)'
+    )
     parser.add_argument(
         '--replace-mha',
         action='store_true',