Skip to content

Commit

Permalink
Review
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 committed Jan 12, 2025
1 parent 18137b3 commit 551a4bc
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from contextlib import nullcontext
from copy import deepcopy
import functools
import sys
Expand Down Expand Up @@ -425,9 +426,15 @@ def quantize_llm(args):
new_funct = functools.partial(update_internal_dict, m)
m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

with functional_quantization_mode(
model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention},
enabled=args.functional_sdpa_quant):
# If we are doing functional SDPA quantization, we create the correct context manager,
# otherwise nullcontext. We would love to avoid the extra indentation level but it doesn't seem easy.
if args.functional_sdpa_quant:
quantization_cm = functional_quantization_mode(
model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention})
else:
quantization_cm = nullcontext()

with quantization_cm:
with torch.no_grad():
model(**calibration_loader[0])

Expand Down

0 comments on commit 551a4bc

Please sign in to comment.