From 80432dc281a183bb327267cc4ad2c18a54a31cf0 Mon Sep 17 00:00:00 2001
From: Jialiang Xu <48697394+liamjxu@users.noreply.github.com>
Date: Wed, 22 Jan 2025 16:18:25 -0800
Subject: [PATCH] Include multiple annotators for WildBench (#3283)

---
 .../annotation/wildbench_annotator.py         | 71 +++++++++++++------
 .../benchmark/metrics/wildbench_metrics.py    |  5 +-
 2 files changed, 52 insertions(+), 24 deletions(-)

diff --git a/src/helm/benchmark/annotation/wildbench_annotator.py b/src/helm/benchmark/annotation/wildbench_annotator.py
index f74d873eb5..1bd8361e22 100644
--- a/src/helm/benchmark/annotation/wildbench_annotator.py
+++ b/src/helm/benchmark/annotation/wildbench_annotator.py
@@ -1,9 +1,11 @@
 import re
 from typing import Any
 from importlib.resources import files
+from typing import Dict
 
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.annotation.annotator import Annotator
+from helm.benchmark.annotation.model_as_judge import _AnnotatorModelInfo
 from helm.clients.auto_client import AutoClient
 from helm.common.request import Request
 
@@ -38,28 +40,51 @@ def annotate(self, request_state: RequestState) -> Any:
             .replace("{$model_output}", model_output_text)
             .replace("{$checklist}", "\n".join(request_state.instance.extra_data["checklist"]))
         )
-        annotator_request = Request(
-            model="openai/gpt-4o-2024-05-13",
-            model_deployment="openai/gpt-4o-2024-05-13",
-            prompt=annotator_prompt,
-            temperature=0.0,
-            max_tokens=2000,
-        )
-        annotator_response = self._auto_client.make_request(annotator_request)
-        if not annotator_response.success:
-            raise Exception(f"Annotation request failed: {annotator_response.error}")
-        assert len(annotator_response.completions) == 1
-        annotator_response_text = annotator_response.completions[0].text
-        annotator_response_parts = self._pattern.search(annotator_response_text)
-        if not annotator_response_parts:
-            raise ValueError(f"Malformed annotator response: {annotator_response_text}")
-        strengths = annotator_response_parts[1].strip()
-        weaknesses = annotator_response_parts[2].strip()
-        score_text = annotator_response_parts[3].strip().strip('"')
-        try:
-            score = float(score_text)
-        except ValueError:
-            raise ValueError(f"Malformed score '{score_text}' in annotator response: {annotator_response_text}")
+        SHORT_NAME_TO_MODEL_INFO: Dict[str, _AnnotatorModelInfo] = {
+            "gpt": _AnnotatorModelInfo(
+                model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"
+            ),
+            "llama": _AnnotatorModelInfo(
+                model_name="meta/llama-3.1-405b-instruct-turbo",
+                model_deployment="together/llama-3.1-405b-instruct-turbo",
+            ),
+            "claude": _AnnotatorModelInfo(
+                model_name="anthropic/claude-3-5-sonnet-20241022",
+                model_deployment="anthropic/claude-3-5-sonnet-20241022",
+            ),
+        }
+        all_strengths = []
+        all_weaknesses = []
+        all_scores = []
+        for annotator_model in SHORT_NAME_TO_MODEL_INFO:
+            annotator_model_info = SHORT_NAME_TO_MODEL_INFO[annotator_model]
+            annotator_request = Request(
+                model=annotator_model_info.model_name,
+                model_deployment=annotator_model_info.model_deployment,
+                prompt=annotator_prompt,
+                temperature=0.0,
+                max_tokens=2000,
+            )
+            annotator_response = self._auto_client.make_request(annotator_request)
+            if not annotator_response.success:
+                continue  # skip this annotator if the request failed
+            assert len(annotator_response.completions) == 1
+            annotator_response_text = annotator_response.completions[0].text
+            annotator_response_parts = self._pattern.search(annotator_response_text)
+            if not annotator_response_parts:
+                continue  # skip this annotator if the response is malformed
+
+            strengths = annotator_response_parts[1].strip()
+            weaknesses = annotator_response_parts[2].strip()
+            score_text = annotator_response_parts[3].strip().strip('"')
+            try:
+                score = float(score_text)
+            except ValueError:
+                continue  # skip this annotator if the score is not a number
+
+            all_strengths.append(strengths)
+            all_weaknesses.append(weaknesses)
+            all_scores.append(score)
 
-        return {"strengths": strengths, "weaknesses": weaknesses, "score": score}
+        return {"strengths": all_strengths, "weaknesses": all_weaknesses, "score": all_scores}
 
diff --git a/src/helm/benchmark/metrics/wildbench_metrics.py b/src/helm/benchmark/metrics/wildbench_metrics.py
index b3deb766b1..cde95bd689 100644
--- a/src/helm/benchmark/metrics/wildbench_metrics.py
+++ b/src/helm/benchmark/metrics/wildbench_metrics.py
@@ -19,7 +19,10 @@ def evaluate_generation(
         eval_cache_path: str,
     ) -> List[Stat]:
         assert request_state.annotations
-        score = request_state.annotations["wildbench"]["score"]
+        all_scores = request_state.annotations["wildbench"]["score"]
+        if len(all_scores) == 0:
+            raise ValueError("Could not compute WB Score because all annotators failed.")
+        score = sum(all_scores) / len(all_scores)
         score_rescaled = (score - 1) / 9
         return [
             Stat(MetricName("wildbench_score")).add(score),
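
For illustration only (not part of the patch): a minimal sketch of how the list-valued "score" annotation produced by the updated annotator is consumed downstream. The annotation values below are invented; the aggregation simply mirrors the averaging and rescaling added in wildbench_metrics.py above.

# Minimal standalone sketch; the example annotation values are made up.
annotations = {
    "strengths": ["clear and complete", "follows the checklist", "concise"],
    "weaknesses": ["misses one edge case", "slightly verbose", "no citations"],
    "score": [9.0, 8.0, 7.5],  # one entry per judge model that responded successfully
}

all_scores = annotations["score"]
if len(all_scores) == 0:
    raise ValueError("Could not compute WB Score because all annotators failed.")

score = sum(all_scores) / len(all_scores)  # mean over the GPT-4o, Llama, and Claude judges
score_rescaled = (score - 1) / 9           # map the 1-10 judge scale onto [0, 1]
print(round(score, 3), round(score_rescaled, 3))  # 8.167 0.796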