Feature/enhance harness report to include detailed score counts and grouped results #1132

Merged
152 changes: 124 additions & 28 deletions langtest/metrics/llm_eval.py
@@ -1,6 +1,7 @@
import re
import string
from typing import List, Optional, Tuple
from textwrap import dedent
from typing import List, Mapping, Optional, Tuple
from ..utils.custom_types.helpers import HashableDict

template = """You are a teacher grading a quiz.
@@ -22,10 +23,82 @@
input_variables = ["query", "result", "answer"]


class EvalTemplate:
"""
The EvalTemplate class provides a method to build a prompt for evaluating student answers
based on a given rubric. The prompt is designed for a teacher to grade a quiz by comparing
the student's answer with the true answer and scoring it according to specified criteria.

Methods
-------
build_prompt(rubic_score: Mapping[str, str] = {"CORRECT": None, "INCORRECT": None}) -> str
Constructs and returns a grading prompt based on the provided rubric scores.

"""

@staticmethod
def build_prompt(
rubic_score: Mapping[str, str] = {
"CORRECT": None,
"INCORRECT": None,
}
):
""" """
grade_list = list(rubic_score.keys())
grade_list = ", ".join(grade_list[:-1]) + f" or {grade_list[-1]}"

eval_criteria = [
f"{grade_name}: {criteria}\n"
for grade_name, criteria in rubic_score.items()
if criteria
]
prompt = (
"You are a teacher grading a quiz. You are given a question, the student's "
"answer, and the true answer, and are asked to score the student answer as either "
f"{grade_list}."
)

if eval_criteria:
eval_criteria = "".join(eval_criteria)
prompt += dedent(
f"""\n\nScore the student answer based on the following criteria:\n{eval_criteria}"""
)

prompt += dedent(
f"""
Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: {grade_list} here

{
("Grade the student answers based ONLY on their factual accuracy. Ignore differences"
" in punctuation and phrasing between the student answer and true answer. It is OK "
"if the student answer contains more or relevant information than the true answer, as"
" long as it does not contain any conflicting statements. Begin!")
}

QUESTION: {{query}}
STUDENT ANSWER: {{result}}
TRUE ANSWER: {{answer}}
GRADE:"""
)
return prompt


class LlmEval:
"""llm_eval for evaluating question answering."""

def __init__(self, llm, template=template, input_variables=input_variables):
grade_list = None

def __init__(
self,
llm,
template=template,
input_variables=input_variables,
grade_list=None,
):
"""
Initializes the LlmEval object.

@@ -42,6 +115,7 @@ def __init__(self, template=template, input_variables=input_variables):
self.template = template
self.input_variables = input_variables
self.server_prompt = server_prompt
LlmEval.grade_list = grade_list

expected_input_vars = {"query", "answer", "result"}
if expected_input_vars != set(self.input_variables):
@@ -52,33 +126,55 @@ def __init__(self, template=template, input_variables=input_variables):

@staticmethod
def _get_score(text: str) -> Optional[Tuple[str, int]]:
match = re.search(r"grade:\s*(correct|incorrect)", text.strip(), re.IGNORECASE)
if LlmEval.grade_list is None:
default_grades = ["CORRECT", "INCORRECT"]
grade_list_pattern = f"grade:\\s*({'|'.join(default_grades).lower()})"
else:
grade_list_pattern = f"(?:grade\\s*)?({'|'.join(LlmEval.grade_list).lower()})"

match = re.search(grade_list_pattern, text.strip(), re.IGNORECASE)
if match:
if match.group(1).upper() == "CORRECT":
return "CORRECT", 1
elif match.group(1).upper() == "INCORRECT":
return "INCORRECT", 0
try:
first_word = (
text.strip()
.split()[0]
.translate(str.maketrans("", "", string.punctuation))
)
if first_word.upper() == "CORRECT":
return "CORRECT", 1
elif first_word.upper() == "INCORRECT":
return "INCORRECT", 0
last_word = (
text.strip()
.split()[-1]
.translate(str.maketrans("", "", string.punctuation))
)
if last_word.upper() == "CORRECT":
return "CORRECT", 1
elif last_word.upper() == "INCORRECT":
return "INCORRECT", 0
except IndexError:
pass
grade = match.group(1).upper()
if LlmEval.grade_list is None:
if grade == "CORRECT":
return "CORRECT", 1
elif grade == "INCORRECT":
return "INCORRECT", 0
elif grade in LlmEval.grade_list:
return grade, LlmEval.grade_list.index(grade)
else:
try:
# Check for first word
first_word = (
text.strip()
.split()[0]
.translate(str.maketrans("", "", string.punctuation))
)
if LlmEval.grade_list is None:
if first_word.upper() == "CORRECT":
return "CORRECT", 1
elif first_word.upper() == "INCORRECT":
return "INCORRECT", 0
elif first_word.upper() in LlmEval.grade_list:
return first_word.upper(), LlmEval.grade_list.index(
first_word.upper()
)

# Check for last word
last_word = (
text.strip()
.split()[-1]
.translate(str.maketrans("", "", string.punctuation))
)
if LlmEval.grade_list is None:
if last_word.upper() == "CORRECT":
return "CORRECT", 1
elif last_word.upper() == "INCORRECT":
return "INCORRECT", 0
elif last_word.upper() in LlmEval.grade_list:
return last_word.upper(), LlmEval.grade_list.index(last_word.upper())
except IndexError:
pass
return None

@staticmethod
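A quick usage sketch of the new EvalTemplate / grade_list path introduced above (not part of the diff): the rubric labels, criteria, and the eval model below are illustrative placeholders, not values taken from the PR.

from langtest.metrics.llm_eval import EvalTemplate, LlmEval

# Hypothetical rubric: each key becomes an allowed grade label; a value of None
# would keep the label without adding criteria text to the prompt.
rubric = {
    "CORRECT": "Factually matches the true answer.",
    "PARTIAL": "Accurate but missing key details.",
    "INCORRECT": "Conflicts with the true answer.",
}

# build_prompt lists the labels (and any criteria) and leaves {query}, {result}
# and {answer} placeholders for the evaluation chain to fill in.
prompt = EvalTemplate.build_prompt(rubric)

eval_llm = None  # placeholder: replace with a loaded evaluation model

# grade_list tells _get_score which labels to extract; the numeric score is the
# label's index in this list (CORRECT -> 0, PARTIAL -> 1, INCORRECT -> 2).
evaluator = LlmEval(
    llm=eval_llm,
    template=prompt,
    input_variables=["query", "result", "answer"],
    grade_list=list(rubric.keys()),
)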
79 changes: 63 additions & 16 deletions langtest/utils/custom_types/helpers.py
@@ -2,7 +2,8 @@
from pydantic import BaseModel
from collections.abc import Hashable
import importlib
from typing import List, Tuple
from typing import List, Tuple, Union

from ...errors import Errors

default_user_prompt = {
@@ -350,6 +351,7 @@ def is_pass_llm_eval(
answer: str,
perturbed_question: str,
prediction: str,
eval_template: Union[str, dict] = None,
):
"""
Determines whether the model's prediction passes the Language Model Metric (LLM) evaluation.
@@ -367,22 +369,47 @@

"""

if prediction.lower().strip() == answer.lower().strip():
return True
if eval_template is None:
if prediction.lower().strip() == answer.lower().strip():
return True

inputs, predictions = prepare_llm_evaluation_data(
original_question, answer, perturbed_question, prediction
)

grades = None
if eval_template is None:
# from ...transform.constants import qa_prompt_template as template
from ...metrics.llm_eval import template

eval_template = template
elif isinstance(eval_template, dict):
from ...metrics.llm_eval import EvalTemplate

rubic_score_dict = eval_template.get("rubic_score", None)
grades = list(rubic_score_dict.keys())

eval_template = EvalTemplate.build_prompt(rubic_score_dict)

if "llm" in str(type(eval_model)):
result = llm_prompt_eval(eval_model, dataset_name, inputs, predictions)
result = llm_prompt_eval(
eval_model, dataset_name, inputs, predictions, eval_template, grades
)
else:
result = transformer_prompt_eval(eval_model, inputs, predictions)
result = transformer_prompt_eval(
eval_model, inputs, predictions, eval_template, grades
)

return result


def llm_prompt_eval(
eval_model, dataset_name: str, inputs: List[dict], predictions: List[dict]
eval_model,
dataset_name: str,
inputs: List[dict],
predictions: List[dict],
template: str = None,
grades: List[str] = None,
) -> bool:
"""
Evaluates model predictions using the Language Model Metric (LLM) with prompt-based evaluation.
@@ -400,9 +427,6 @@ def llm_prompt_eval(
from langchain.evaluation.qa import QAEvalChain
from langchain.prompts import PromptTemplate

# from ...transform.constants import qa_prompt_template as template
from ...metrics.llm_eval import template

PROMPT = PromptTemplate(
input_variables=["query", "answer", "result"],
template=template,
@@ -436,17 +460,31 @@
answer_key="answer",
prediction_key="text",
)
result = bool(
re.match(
r"CORRECT|TRUE",
if grades:
# Extract the grade from the result by matching the pattern
result = re.sub(
r"GRADE: ",
"",
list(graded_outputs[0].values())[0].replace("\n", "").strip(),
)
)
match = re.search(f"({'|'.join(grades)})", result, re.IGNORECASE).group(0)
return match
else:
result = bool(
re.match(
r"CORRECT|TRUE",
list(graded_outputs[0].values())[0].replace("\n", "").strip(),
)
)
return result


def transformer_prompt_eval(
eval_model, inputs: List[dict], predictions: List[dict]
eval_model,
inputs: List[dict],
predictions: List[dict],
template: str = None,
grades: List[str] = None,
) -> bool:
"""
Evaluates model predictions using a transformer-based language model.
@@ -461,15 +499,24 @@
"""
from ...metrics.llm_eval import LlmEval

eval_chain = LlmEval(llm=eval_model)
eval_chain = LlmEval(llm=eval_model, template=template, grade_list=grades)
graded_outputs = eval_chain.evaluate(
inputs,
predictions,
question_key="question",
answer_key="answer",
prediction_key="result",
)
result = list(graded_outputs[0].values())[0].replace("\n", "").strip() == "CORRECT"
if grades is None:
result = (
list(graded_outputs[0].values())[0].replace("\n", "").strip() == "CORRECT"
)
else:
result = re.sub(
r"GRADE: ",
"",
list(graded_outputs[0].values())[0].replace("\n", "").strip(),
)
return result


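A sketch of how the new eval_template argument flows through is_pass_llm_eval (illustrative only; the eval model, dataset name, and question/answer strings are assumptions, not taken from the PR):

from langtest.utils.custom_types.helpers import is_pass_llm_eval

# "rubic_score" is the key read from a dict template; its keys become the grade
# labels handed down to llm_prompt_eval / transformer_prompt_eval.
eval_prompt = {
    "rubic_score": {
        "CORRECT": "Factually matches the true answer.",
        "INCORRECT": "Conflicts with the true answer.",
    }
}

eval_model = None  # placeholder: replace with a loaded LLM or transformer eval model

result = is_pass_llm_eval(
    eval_model=eval_model,
    dataset_name="BoolQ",               # assumed dataset name
    original_question="Is the sky blue?",
    answer="Yes",
    perturbed_question="IS THE SKY BLUE?",
    prediction="Yes, the sky is blue.",
    eval_template=eval_prompt,          # dict -> EvalTemplate.build_prompt branch
)
# With a dict template the helper returns the matched grade (e.g. "CORRECT")
# rather than the plain boolean used by the default CORRECT/INCORRECT path.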
11 changes: 10 additions & 1 deletion langtest/utils/custom_types/sample.py
@@ -399,7 +399,7 @@ class BaseQASample(BaseModel):
state: str = None
task: str = Field(default="question-answering", const=True)
test_case: str = None
config: str = None
config: Mapping[str, Mapping] = None
distance_result: float = None
eval_model: Union[str, tuple] = None
ran_pass: bool = None
@@ -553,6 +558,8 @@ def __update_params(self):
self.eval_model = load_eval_model.model(
model, hub, **harness_config.get("model_parameters", {})
)
else:
self.eval_model = EVAL_MODEL

else:
self.eval_model = EVAL_MODEL
@@ -656,13 +658,20 @@ def is_pass(self) -> bool:
elif self.metric_name == "llm_eval":
if isinstance(self.eval_model, dict):
self.eval_model = list(self.eval_model.values())[-1]

# get the template for evaluation

template = self.config.get("evaluation", {}).get("eval_prompt", None)

# run the metric function with the evaluation template
result = metric_function(
eval_model=self.eval_model,
dataset_name=self.dataset_name,
original_question=self.original_question,
answer=self.expected_results,
perturbed_question=self.perturbed_question,
prediction=self.actual_results,
eval_template=template,
)

self.ran_pass = result
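Finally, a hypothetical harness config showing where is_pass() picks the template up: it reads config["evaluation"]["eval_prompt"] and forwards it as eval_template. Only the "evaluation" -> "eval_prompt" -> "rubic_score" nesting is dictated by this diff; the model/hub names and pass rates are illustrative assumptions.

# Sketch of a harness config dict (the same structure could live in a YAML file).
config = {
    "evaluation": {
        "metric": "llm_eval",          # assumed: selects the llm_eval branch in is_pass()
        "model": "gpt-3.5-turbo",      # assumed evaluation model
        "hub": "openai",               # assumed hub
        "eval_prompt": {
            "rubic_score": {
                "CORRECT": "Matches the true answer.",
                "INCORRECT": "Contradicts the true answer.",
            }
        },
    },
    "tests": {
        "defaults": {"min_pass_rate": 0.65},
        "robustness": {"uppercase": {"min_pass_rate": 0.70}},
    },
}

Samples evaluated with the llm_eval metric would then be graded against the rubric labels, while configs without an eval_prompt keep the existing CORRECT/INCORRECT behaviour.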