Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LongBench validation #1220

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions tests/python_tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ sacremoses
librosa
soundfile
datasets
rouge
104 changes: 104 additions & 0 deletions tests/python_tests/test_kv_cache_eviction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import sys
from typing import Dict, List, Optional

import datasets
import pytest
from tqdm import tqdm

from optimum.intel.openvino import OVModelForCausalLM

Expand All @@ -16,6 +18,7 @@
from transformers import AutoTokenizer

from common import TESTS_ROOT, run_cb_pipeline_with_ref, get_default_properties
from utils_longbench import dataset2maxlen, evaluate, preprocess_prompt, post_process_pred


def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]:
Expand Down Expand Up @@ -69,6 +72,7 @@ class CacheOptTestStruct:


SHORT_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=32, max_cache_size=96, aggregation_mode=AggregationMode.NORM_SUM)
LONGBENCH_CACHE_EVICTION_CONFIG = CacheEvictionConfig(start_size=32, recent_size=128, max_cache_size=672, aggregation_mode=AggregationMode.NORM_SUM)


@pytest.mark.precommit
Expand Down Expand Up @@ -190,3 +194,103 @@ def get_beam_search_seq_len_300() -> GenerationConfig:
def test_dynamic_memory_allocation(tmp_path, params):
run_cb_pipeline_with_ref(tmp_path, "facebook/opt-125m", scheduler_params=params[0], generation_config=params[1])


@pytest.fixture(scope='module')
def qwen2_converted_model(tmp_path_factory):
model_id = "Qwen/Qwen2-0.5B-Instruct"
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id
model.save_pretrained(models_path)
ov_tokenizer, ov_detokenizer = convert_tokenizer(tokenizer, with_detokenizer=True, skip_special_tokens=True)
serialize(ov_tokenizer, models_path / "openvino_tokenizer.xml")
serialize(ov_detokenizer, models_path / "openvino_detokenizer.xml")
qwen2_converted_model = ConvertedModel(model, tokenizer, models_path)
yield qwen2_converted_model
del qwen2_converted_model
del model


@dataclass
class LongBenchTestData:
subset: str
threshold: float
max_cache_usage_optimization_ratio: float
avg_cache_usage_optimization_ratio: float


@pytest.mark.nightly
@pytest.mark.parametrize("device", ["CPU", "GPU"])
@pytest.mark.parametrize("test_struct", [
LongBenchTestData("samsum", 4, 1.6, 3.3),
LongBenchTestData("trec", 3.2, 2.0, 3.3),
LongBenchTestData("qasper", 5.8, 1.7, 3.6),
])
def test_optimized_generation_longbench(qwen2_converted_model, device, test_struct):
seqs_per_request = 32
num_kv_blocks = 1000 if device == "CPU" else 500
models_path = qwen2_converted_model.models_path
scheduler_config = get_scheduler_config(num_kv_blocks)

scheduler_config_opt = get_scheduler_config(num_kv_blocks)
scheduler_config_opt.use_cache_eviction = True
if scheduler_config_opt.use_cache_eviction:
scheduler_config_opt.cache_eviction_config = LONGBENCH_CACHE_EVICTION_CONFIG

model_cb_noopt = ContinuousBatchingPipeline(models_path, scheduler_config, device, {}, get_default_properties())
model_cb_opt = ContinuousBatchingPipeline(models_path, scheduler_config_opt, device, {}, get_default_properties())

model_name = "/".join(models_path.parts[-2:])
subset = test_struct.subset
max_new_tokens = dataset2maxlen[subset]

generation_config = GenerationConfig() # expecting default greedy sampling
generation_config.num_return_sequences = 1
generation_config.max_new_tokens = max_new_tokens

data = datasets.load_dataset('THUDM/LongBench', subset, split='test[:32]')
with tqdm(total=len(data)) as progress_bar:
batch = []
answers = []
ref_answers = []
for p_idx, data_sample in enumerate(data):
prompt = preprocess_prompt(data_sample, subset, model_name)
progress_bar.update(1)
batch.append(prompt)
answers.append({"answers": data_sample["answers"], "all_classes": data_sample["all_classes"]})
ref_answers.append({"answers": data_sample["answers"], "all_classes": data_sample["all_classes"]})

if len(batch) == seqs_per_request or p_idx == len(data) - 1:
ans_batch = model_cb_opt.generate(
batch, [generation_config] * len(batch)
)
ref_ans_batch = model_cb_noopt.generate(
batch, [generation_config] * len(batch)
)
for i, (opt_output, ref_output) in enumerate(zip(ans_batch, ref_ans_batch), start=p_idx-len(batch)+1):
answers[i]["pred"] = post_process_pred(opt_output.m_generation_ids[0], subset, model_name)
ref_answers[i]["pred"] = post_process_pred(ref_output.m_generation_ids[0], subset, model_name)
batch.clear()

score = evaluate(answers, subset)
print(f"Score: {score}")

ref_score = evaluate(ref_answers, subset)
print(f"Reference score: {ref_score}")
pipeline_opt_metrics = model_cb_opt.get_metrics()
pipeline_noopt_metrics = model_cb_noopt.get_metrics()

print(f"No-opt cache usage: max {pipeline_noopt_metrics.max_cache_usage:.3f}, avg {pipeline_noopt_metrics.avg_cache_usage:.3f}")
print(f"Opt cache usage: max {pipeline_opt_metrics.max_cache_usage:.3f}, avg {pipeline_opt_metrics.avg_cache_usage:.3f}")
max_optimization_ratio = (pipeline_noopt_metrics.max_cache_usage / pipeline_opt_metrics.max_cache_usage)
avg_optimization_ratio = (pipeline_noopt_metrics.avg_cache_usage / pipeline_opt_metrics.avg_cache_usage)
print(f"Optimization ratios: max {max_optimization_ratio:.3f}x, avg {avg_optimization_ratio:.3f}x")

del model_cb_opt
del model_cb_noopt
import gc
gc.collect()

assert ref_score - score <= test_struct.threshold
assert max_optimization_ratio >= test_struct.max_cache_usage_optimization_ratio
assert avg_optimization_ratio >= test_struct.avg_cache_usage_optimization_ratio
254 changes: 254 additions & 0 deletions tests/python_tests/utils_longbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# This file includes utility functions copied from the LongBench repository:
# https://github.com/THUDM/LongBench
#
# Copyright (c) 2023 THU-KEG & Zhipu AI
# Licensed under the MIT License
Comment on lines +1 to +8
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@moslex please take a look at the license. Could you confirm if it is sufficient for code adapted from a 3rd party repo?

import re
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is probably taken from 3rd party repo. We need to keep the license in this case.

import string

from collections import Counter
from rouge import Rouge


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))

def normalize_zh_answer(s):
"""Lower text and remove punctuation, extra whitespace."""

def white_space_fix(text):
return "".join(text.split())

def remove_punc(text):
cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
all_punctuation = set(string.punctuation + cn_punctuation)
return "".join(ch for ch in text if ch not in all_punctuation)

def lower(text):
return text.lower()

return white_space_fix(remove_punc(lower(s)))

def count_score(prediction, ground_truth, **kwargs):
numbers = re.findall(r"\d+", prediction)
right_num = 0
for number in numbers:
if str(number) == str(ground_truth):
right_num += 1
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
pattern = r'Paragraph (\d+)'
matches = re.findall(pattern, ground_truth)
ground_truth_id = matches[0]
numbers = re.findall(r"\d+", prediction)
right_num = 0
for number in numbers:
if str(number) == str(ground_truth_id):
right_num += 1
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
return float(final_score)

def retrieval_zh_score(prediction, ground_truth, **kwargs):
pattern = r'段落(\d+)'
matches = re.findall(pattern, ground_truth)
ground_truth_id = matches[0]
numbers = re.findall(r"\d+", prediction)
right_num = 0
for number in numbers:
if str(number) == str(ground_truth_id):
right_num += 1
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
return float(final_score)

def code_sim_score(prediction, ground_truth, **kwargs):
from fuzzywuzzy import fuzz
all_lines = prediction.lstrip('\n').split('\n')
prediction = ""
for line in all_lines:
if ('`' not in line) and ('#' not in line) and ('//' not in line):
prediction = line
break
return (fuzz.ratio(prediction, ground_truth) / 100)

def classification_score(prediction, ground_truth, **kwargs):
em_match_list = []
all_classes = kwargs["all_classes"]
for class_name in all_classes:
if class_name in prediction:
em_match_list.append(class_name)
for match_term in em_match_list:
if match_term in ground_truth and match_term != ground_truth:
em_match_list.remove(match_term)
if ground_truth in em_match_list:
score = (1.0 / len(em_match_list))
else:
score = 0.0
return score

def rouge_score(prediction, ground_truth, **kwargs):
rouge = Rouge()
try:
scores = rouge.get_scores([prediction], [ground_truth], avg=True)
except:
return 0.0
return scores["rouge-l"]["f"]

def f1_score(prediction, ground_truth, **kwargs):
common = Counter(prediction) & Counter(ground_truth)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction)
recall = 1.0 * num_same / len(ground_truth)
f1 = (2 * precision * recall) / (precision + recall)
return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)

prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
return f1_score(prediction_tokens, ground_truth_tokens)


dataset2metric = {
"narrativeqa": qa_f1_score,
"qasper": qa_f1_score,
"multifieldqa_en": qa_f1_score,
"hotpotqa": qa_f1_score,
"2wikimqa": qa_f1_score,
"musique": qa_f1_score,
"gov_report": rouge_score,
"qmsum": rouge_score,
"multi_news": rouge_score,
"trec": classification_score,
"triviaqa": qa_f1_score,
"samsum": rouge_score,
"lsht": classification_score,
"passage_retrieval_en": retrieval_score,
"passage_count": count_score,
"passage_retrieval_zh": retrieval_zh_score,
"lcc": code_sim_score,
"repobench-p": code_sim_score,
}

dataset2maxlen = {
"narrativeqa": 128,
"qasper": 128,
"multifieldqa_en": 64,
"multifieldqa_zh": 64,
"hotpotqa": 32,
"2wikimqa": 32,
"musique": 32,
"dureader": 128,
"gov_report": 512,
"qmsum": 512,
"multi_news": 512,
"vcsum": 512,
"trec": 64,
"triviaqa": 32,
"samsum": 128,
"lsht": 64,
"passage_count": 32,
"passage_retrieval_en": 32,
"passage_retrieval_zh": 32,
"lcc": 64,
"repobench-p": 64
}

dataset2prompt = {
"narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
"qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
"multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
"hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
"gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
"qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
"multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
"vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
"trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
"samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
"lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
"passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
"passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
"passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
"lcc": "Please complete the code given below. \n{context}Next line of code:\n",
"repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
}


def scorer(dataset, predictions, answers, all_classes):
total_score = 0.
for (prediction, ground_truths) in zip(predictions, answers):
score = 0.
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
prediction = prediction.lstrip('\n').split('\n')[0]
for ground_truth in ground_truths:
score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
total_score += score
return round(100 * total_score / len(predictions), 2)


def evaluate(model_output, task):
predictions, answers = [], []
for data in model_output:
predictions.append(data["pred"])
answers.append(data["answers"])
all_classes = data["all_classes"]
score = scorer(task, predictions, answers, all_classes)
return score


def build_chat(prompt, model_name):
if "Llama-2" in model_name:
prompt = f"[INST]{prompt}[/INST]"
elif "Llama" in model_name:
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
elif "Phi-3" in model_name:
prompt = f"<|user|>\n{prompt} <|end|>\n<|assistant|>"
return prompt


def preprocess_prompt(data_sample, subset, model_name):
prompt_format = dataset2prompt[subset]
prompt = prompt_format.format(**data_sample)
if subset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
prompt = build_chat(prompt, model_name)
return prompt


def post_process_pred(pred, subset, model_name):
if subset in ["samsum", "qsum", "hotpotqa", "qasper"] and "Llama-3" in model_name:
pred = pred[:pred.find("assistant")]
elif subset == "samsum":
pred = pred[:pred.find("\nDialogue")]
elif "Phi-3" in model_name and subset == "hotpotqa":
pred = pred.lstrip('\n').split('\n')[0]
elif subset in ["trec", "hotpotqa", "qasper"] and "Qwen" in model_name:
pred = pred[:pred.find("\nQuestion")]
return pred
Loading