Feat (brevitas_examples/llm): support for lighteval
Giuseppe5 committed Jan 16, 2025
1 parent 41ace8a commit 62b4284
Showing 1 changed file with 54 additions and 6 deletions.
src/brevitas_examples/llm/main.py: 54 additions & 6 deletions
@@ -4,12 +4,11 @@
 import argparse
 from contextlib import nullcontext
 from copy import deepcopy
+from datetime import timedelta
 import functools
 import sys
 from warnings import warn
 
-from lm_eval import evaluator
-from lm_eval.models.huggingface import HFLM
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
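
Note that the lm_eval imports move from module scope into the branch that actually uses them, so neither evaluation harness has to be installed unless --few-shot-eval selects it; the new timedelta import exists only to feed the process-group timeout used by the lighteval branch further down. A minimal sketch of the same lazy-import guard, where the helper name and the error message are illustrative and not part of this commit:

def _import_few_shot_backend(name):
    # Hypothetical helper: main.py inlines these imports in its if/elif branches
    # instead of wrapping them in a function.
    if name == 'lm_eval':
        from lm_eval import evaluator
        from lm_eval.models.huggingface import HFLM
        return evaluator, HFLM
    if name == 'lighteval':
        from lighteval.pipeline import Pipeline
        return (Pipeline,)
    raise ValueError(f"Unsupported few-shot backend: {name}")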
@@ -530,7 +529,9 @@ def quantize_llm(args):
                 model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
 
-    if args.few_shot_eval:
+    if args.few_shot_eval == 'lm_eval':
+        from lm_eval import evaluator
+        from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model):
             model(**calibration_loader[0])
             if args.few_shot_compile:
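
The lines that actually hand the quantized model to the harness are collapsed in this view. For orientation, a typical lm_eval flow with the two imports above looks roughly like the sketch below; the wrapper arguments, batch size, and few-shot handling are assumptions about the folded code, not a copy of it:

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

# Sketch only: wrap the already-quantized model so the harness can drive it.
wrapped_model = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
results = evaluator.simple_evaluate(
    model=wrapped_model,
    tasks=list(args.few_shot_tasks),
    num_fewshot=0 if args.few_shot_zeroshot else None)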
@@ -552,6 +553,51 @@ def quantize_llm(args):
         results = filter_results(results, args.few_shot_tasks)
         print("Few shot eval results")
         print(results)
+    elif args.few_shot_eval == 'lighteval':
+        from accelerate import Accelerator
+        from accelerate import InitProcessGroupKwargs
+        from lighteval.logging.evaluation_tracker import EvaluationTracker
+        from lighteval.models.transformers.transformers_model import TransformersModelConfig
+        from lighteval.pipeline import ParallelismManager
+        from lighteval.pipeline import Pipeline
+        from lighteval.pipeline import PipelineParameters
+        from lighteval.utils.utils import EnvConfig
+
+        accelerator = Accelerator(
+            kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+        evaluation_tracker = EvaluationTracker(
+            output_dir="./results",
+            save_details=True,
+        )
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.ACCELERATE,
+            env_config=EnvConfig(cache_dir="/scratch/hf_models/"),
+            # Remove the 2 parameters below once your configuration is tested
+            override_batch_size=0,  # max_samples=10
+        )
+        model_config = TransformersModelConfig(
+            pretrained=args.model,
+            dtype="float16",
+            use_chat_template=True,
+            model_parallel=True,
+            accelerator=accelerator,
+            compile=False)
+
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            if args.few_shot_compile:
+                remove_hooks(model)
+                model.cuda()
+                model.forward = torch.compile(model.forward, fullgraph=True)
+            pipeline = Pipeline(
+                tasks=args.few_shot_tasks,
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=model,
+                config=model_config)
+
+            pipeline.evaluate()
+            pipeline.show_results()
     remove_hooks(model)
 
     if args.checkpoint_name is not None and not args.load_checkpoint:
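
Two details of the lighteval wiring are worth calling out: the Accelerator is created with a 3000-second process-group timeout (the reason timedelta is now imported at the top of the file), and the already-quantized model object is passed to Pipeline directly via model=, while TransformersModelConfig mainly carries metadata such as the pretrained name and dtype. The same wiring, folded into a reusable function, is sketched below; the output and cache directories are placeholders, and the pipe-separated default task string is an assumption about lighteval's task syntax rather than something this commit specifies:

from datetime import timedelta

from accelerate import Accelerator, InitProcessGroupKwargs
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.transformers.transformers_model import TransformersModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.utils import EnvConfig


def run_lighteval(model, pretrained_name, tasks="leaderboard|arc:challenge|0|0"):
    """Evaluate an already-quantized HF model with lighteval, mirroring main.py."""
    accelerator = Accelerator(
        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
    tracker = EvaluationTracker(output_dir="./results", save_details=True)
    params = PipelineParameters(
        launcher_type=ParallelismManager.ACCELERATE,
        env_config=EnvConfig(cache_dir="./hf_cache"),  # placeholder cache dir
        override_batch_size=0)
    config = TransformersModelConfig(
        pretrained=pretrained_name,
        dtype="float16",
        use_chat_template=True,
        model_parallel=True,
        accelerator=accelerator,
        compile=False)
    pipeline = Pipeline(
        tasks=tasks,
        pipeline_parameters=params,
        evaluation_tracker=tracker,
        model=model,
        config=config)
    pipeline.evaluate()
    pipeline.show_results()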
@@ -888,12 +934,14 @@ def parse_args(args, override_defaults={}):
         help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
     parser.add_argument(
         '--few-shot-eval',
-        action="store_true",
-        help='Perform zero_shot evaluation with lm_eval. Default %(default)s)')
+        type=str,
+        default=None,
+        choices=['lm_eval', 'lighteval'],
+        help='Perform zero_shot evaluation with lm_eval or lighteval. Default %(default)s)')
     parser.add_argument(
         '--few-shot-compile',
         action="store_true",
-        help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)')
+        help='Compile during zero_shot evaluation. Default %(default)s)')
     parser.add_argument(
         '--few-shot-zeroshot',
         action="store_true",
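Since --few-shot-eval now takes a value instead of acting as a boolean switch, a programmatic run of the example could look like the snippet below; the checkpoint name is a placeholder, and it assumes parse_args returns the parsed namespace:

from brevitas_examples.llm.main import parse_args, quantize_llm

# 'facebook/opt-125m' is a placeholder checkpoint; any value other than
# 'lm_eval' or 'lighteval' for --few-shot-eval is now rejected by argparse.
args = parse_args([
    '--model', 'facebook/opt-125m',
    '--few-shot-eval', 'lighteval',
    '--few-shot-compile'])
quantize_llm(args)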
