diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 8c4fe1968..170d2d735 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -4,12 +4,11 @@ import argparse from contextlib import nullcontext from copy import deepcopy +from datetime import timedelta import functools import sys from warnings import warn -from lm_eval import evaluator -from lm_eval.models.huggingface import HFLM import numpy as np from optimum.exporters.onnx import onnx_export_from_model import torch @@ -530,7 +529,9 @@ def quantize_llm(args): model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") - if args.few_shot_eval: + if args.few_shot_eval == 'lm_eval': + from lm_eval import evaluator + from lm_eval.models.huggingface import HFLM with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0]) if args.few_shot_compile: @@ -552,6 +553,51 @@ def quantize_llm(args): results = filter_results(results, args.few_shot_tasks) print("Few shot eval results") print(results) + elif args.few_shot_eval == 'lighteval': + from accelerate import Accelerator + from accelerate import InitProcessGroupKwargs + from lighteval.logging.evaluation_tracker import EvaluationTracker + from lighteval.models.transformers.transformers_model import TransformersModelConfig + from lighteval.pipeline import ParallelismManager + from lighteval.pipeline import Pipeline + from lighteval.pipeline import PipelineParameters + from lighteval.utils.utils import EnvConfig + + accelerator = Accelerator( + kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]) + evaluation_tracker = EvaluationTracker( + output_dir="./results", + save_details=True, + ) + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.ACCELERATE, + env_config=EnvConfig(cache_dir="/scratch/hf_models/"), + # Remove the 2 parameters below once your configuration is tested + override_batch_size=0, # max_samples=10 + ) + model_config = TransformersModelConfig( + pretrained=args.model, + dtype="float16", + use_chat_template=True, + model_parallel=True, + accelerator=accelerator, + compile=False) + + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + if args.few_shot_compile: + remove_hooks(model) + model.cuda() + model.forward = torch.compile(model.forward, fullgraph=True) + pipeline = Pipeline( + tasks=args.few_shot_tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + config=model_config) + + pipeline.evaluate() + pipeline.show_results() remove_hooks(model) if args.checkpoint_name is not None and not args.load_checkpoint: @@ -888,12 +934,14 @@ def parse_args(args, override_defaults={}): help='Whether to use fast update with learned round. Prototype (default: %(default)s)') parser.add_argument( '--few-shot-eval', - action="store_true", - help='Perform zero_shot evaluation with lm_eval. Default %(default)s)') + type=str, + default=None, + choices=['lm_eval', 'lighteval'], + help='Perform zero_shot evaluation with lm_eval or lighteval. Default %(default)s)') parser.add_argument( '--few-shot-compile', action="store_true", - help='Compile during zero_shot evaluation with lm_eval. Default %(default)s)') + help='Compile during zero_shot evaluation. Default %(default)s)') parser.add_argument( '--few-shot-zeroshot', action="store_true",