Skip to content

Commit

Permalink
Add a new task GPQA (the part CoT and generative) (EleutherAI#1482)
Browse files Browse the repository at this point in the history
* Add new tasks of GPQA

* Add README

* Remove unused functions

* Remove unused functions

* Linters

* Add flexible match

* update

* Remove duplicate function

* Linter

* update

* Update lm_eval/filters/extraction.py

Co-authored-by: Hailey Schoelkopf <[email protected]>

* register multi_choice_regex

* Update

* run precommit

---------

Co-authored-by: Hailey Schoelkopf <[email protected]>
Co-authored-by: haileyschoelkopf <[email protected]>
  • Loading branch information
3 people authored Mar 5, 2024
1 parent 8a875e9 commit 01108ac
Show file tree
Hide file tree
Showing 23 changed files with 466 additions and 2 deletions.
1 change: 1 addition & 0 deletions lm_eval/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
Expand Down
114 changes: 114 additions & 0 deletions lm_eval/filters/extraction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import sys
import unicodedata

from lm_eval.api.filter import Filter

Expand Down Expand Up @@ -67,3 +69,115 @@ def filter_set(inst):
filtered_resps = [filter_set(resp) for resp in resps]

return filtered_resps


class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. Assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    # Translation table mapping every Unicode punctuation codepoint to None.
    # Built lazily and shared across instances because constructing it scans
    # the full Unicode range (~1.1M codepoints) — the original rebuilt it on
    # every apply() call.
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        r"""
        regex_pattern: The basic regex pattern to use. If it fails to match, we will use the customized match procedure
                       - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
                       - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        ignore_case: Ignores the case during step 1 matching.
        ignore_punctuation: Removes punctuation during step 1 matching.
        regexes_to_ignore: Removes these regexes during step 1 matching.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @classmethod
    def _get_punct_tbl(cls):
        """Build (once) and return the punctuation-stripping translate table."""
        if cls._punct_tbl is None:
            cls._punct_tbl = dict.fromkeys(
                i
                for i in range(sys.maxunicode)
                if unicodedata.category(chr(i)).startswith("P")
            )
        return cls._punct_tbl

    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)

        def find_match(regex, resp, convert_dict=None):
            # `convert_dict=None` avoids the mutable-default-argument pitfall
            # of the original `convert_dict={}` (never mutated here, but the
            # idiom is a latent hazard).
            if convert_dict is None:
                convert_dict = {}
            match = regex.findall(resp)
            if match:
                match = match[self.group_select]
                if isinstance(match, tuple):
                    # Take the first non-empty capture group; yield "" (falsy)
                    # when every group is empty so the caller falls through to
                    # the next extraction strategy instead of raising
                    # IndexError (the original `[m for m in match if m][0]`
                    # crashed in that case).
                    match = next((m for m in match if m), "")
                match = match.strip()
                if match and match in convert_dict:
                    match = convert_dict[match]
            return match

        punct_tbl = self._get_punct_tbl()

        def filter_ignores(st):
            if self.regexes_to_ignore is not None:
                for s in self.regexes_to_ignore:
                    st = re.sub(s, "", st)

            if self.ignore_case:
                st = st.lower()

            if self.ignore_punctuation:
                # https://stackoverflow.com/a/266162
                st = st.translate(punct_tbl)
            return st

        filtered_resps = []

        for r, doc in zip(resps, docs):
            fallback_regexes = []
            choice_to_alpha = {}
            next_alpha = "A"

            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            choices = doc["choices"]
            for c in choices:
                m = filter_ignores(c.strip())
                fallback_regexes.append(f"{re.escape(m)}")
                choice_to_alpha[m] = f"({next_alpha})"

                without_paren_fallback_regexes.append(next_alpha)
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)
            fallback_regex = re.compile("|".join(fallback_regexes))
            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
            # rf-string: the original plain f-string contained `\s`, an
            # invalid escape sequence (DeprecationWarning, future SyntaxError).
            without_paren_fallback_regex = re.compile(
                rf":[\s]*({without_paren_fallback_regex})"
            )

            filtered = []
            for resp in r:
                # Strategy 1: the configured regex on the raw response.
                match = find_match(self.regex, resp)
                if not match:
                    # Strategy 2: look for a normalized choice string and map
                    # it back to its letter label.
                    match = find_match(
                        fallback_regex, filter_ignores(resp), choice_to_alpha
                    )
                    if not match:
                        # Strategy 3: a bare letter after a colon, e.g. ": B".
                        match = find_match(
                            without_paren_fallback_regex, resp, without_paren_to_target
                        )
                        if not match:
                            match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
3 changes: 3 additions & 0 deletions lm_eval/tasks/gpqa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ This dataset is gated, so you will have to accept the terms of use at https://hu

* `gpqa_{main, diamond, extended}_zeroshot`
* `gpqa_{main, diamond, extended}_n_shot`
* `gpqa_{main, diamond, extended}_generative_n_shot`
* `gpqa_{main, diamond, extended}_cot_zeroshot`
* `gpqa_{main, diamond, extended}_cot_n_shot`

### Checklist

Expand Down
26 changes: 26 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/_generate_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import yaml
from tqdm import tqdm


def main() -> None:
    """Generate one task-config YAML per GPQA subset for the cot_n_shot setting.

    Each generated file includes the shared ``_gpqa_cot_n_shot_yaml`` template
    and overrides only the task and dataset names.
    """
    subset = ["extended", "diamond", "main"]
    setting = "cot_n_shot"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        # Mode "w" always creates/truncates the file, so the original
        # `try ... except FileExistsError: pass` was dead code and is removed.
        with open(file_name, "w") as f:
            f.write("# Generated by _generate_configs.py\n")
            yaml.dump(
                {
                    "include": f"_gpqa_{setting}_yaml",
                    "task": f"gpqa_{task}_{setting}",
                    "dataset_name": f"gpqa_{task}",
                },
                f,
            )


if __name__ == "__main__":
    main()
38 changes: 38 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/_gpqa_cot_n_shot_yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Shared template for the GPQA chain-of-thought n-shot tasks; the per-subset
# gpqa_{main,diamond,extended}_cot_n_shot.yaml files include this and override
# dataset_name / task.
dataset_path: Idavidrein/gpqa
group: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n"
doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: "
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        # NOTE(review): the trailing lookahead `(?=.)` matches ANY character,
        # so this effectively strips the final character of the captured
        # answer rather than requiring a period; presumably `(?=\\.)` was
        # intended — confirm before changing, as it alters extraction.
        regex_pattern: "(?<=The answer is )(.*)(?=.)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      # MultiChoiceRegexFilter: falls back to matching choice text / bare
      # letters when the "(A)"-style pattern below finds nothing.
      - function: "multi_choice_regex"
        group_select: -1
        ignore_case: true
        ignore_punctuation: true
        regex_pattern: "(\\([A-Z]\\))"
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
4 changes: 4 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/gpqa_diamond_cot_n_shot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_cot_n_shot_yaml
task: gpqa_diamond_cot_n_shot
4 changes: 4 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/gpqa_extended_cot_n_shot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_cot_n_shot_yaml
task: gpqa_extended_cot_n_shot
4 changes: 4 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/gpqa_main_cot_n_shot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_cot_n_shot_yaml
task: gpqa_main_cot_n_shot
39 changes: 39 additions & 0 deletions lm_eval/tasks/gpqa/cot_n_shot/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import random
import re

import datasets


def preprocess(text):
    """Normalize an answer-choice string from the GPQA dataset.

    Returns a single space for missing (None) values so downstream template
    formatting never receives None. Otherwise strips surrounding whitespace,
    turns " [title]" markers into sentence breaks, removes bracketed
    annotations (e.g. "[1]"), and collapses the double spaces those removals
    can leave behind.
    """
    if text is None:
        return " "
    text = text.strip()
    text = text.replace(" [title]", ". ")
    # Remove bracketed annotations such as "[1]" or "[note]".
    text = re.sub(r"\[.*?\]", "", text)
    # Collapse double spaces left by the replacements above.
    # (The original `text.replace(" ", " ")` was a no-op.)
    text = text.replace("  ", " ")
    return text


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Shuffle each doc's answer choices and attach letter-label answer fields.

    Adds per document:
      choice1..choice4 : the four preprocessed answer options, in shuffled order
      choices          : the same four options as a list
      answer           : the correct option's label, e.g. "(B)"
    """

    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]

        # Seed the shuffle from the question text so the choice order (and
        # hence the answer key) is reproducible across runs and dataset.map
        # worker processes; the original unseeded global random.shuffle made
        # every run a different eval.
        rng = random.Random(doc["Question"])
        rng.shuffle(choices)

        # NOTE(review): if an incorrect answer preprocesses to the same string
        # as the correct one, index() returns the first occurrence and may
        # mislabel the answer — assumed not to occur in GPQA; verify.
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))

        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"({chr(65 + correct_answer_index)})",
        }
        return out_doc

    return dataset.map(_process_doc)
26 changes: 26 additions & 0 deletions lm_eval/tasks/gpqa/cot_zeroshot/_generate_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import yaml
from tqdm import tqdm


def main() -> None:
    """Generate one task-config YAML per GPQA subset for the cot_zeroshot setting.

    Each generated file includes the shared ``_gpqa_cot_zeroshot_yaml`` template
    and overrides only the task and dataset names.
    """
    subset = ["extended", "diamond", "main"]
    setting = "cot_zeroshot"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        # Mode "w" always creates/truncates the file, so the original
        # `try ... except FileExistsError: pass` was dead code and is removed.
        with open(file_name, "w") as f:
            f.write("# Generated by _generate_configs.py\n")
            yaml.dump(
                {
                    "include": f"_gpqa_{setting}_yaml",
                    "task": f"gpqa_{task}_{setting}",
                    "dataset_name": f"gpqa_{task}",
                },
                f,
            )


if __name__ == "__main__":
    main()
38 changes: 38 additions & 0 deletions lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Shared template for the GPQA chain-of-thought zero-shot tasks; the per-subset
# gpqa_{main,diamond,extended}_cot_zeroshot.yaml files include this and
# override dataset_name / task.
dataset_path: Idavidrein/gpqa
group: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
doc_to_text: "What is the correct answer to this question:{{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: "
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        # NOTE(review): the trailing lookahead `(?=.)` matches ANY character,
        # so this effectively strips the final character of the captured
        # answer rather than requiring a period; presumably `(?=\\.)` was
        # intended — confirm before changing, as it alters extraction.
        regex_pattern: "(?<=The answer is )(.*)(?=.)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      # MultiChoiceRegexFilter: falls back to matching choice text / bare
      # letters when the "(A)"-style pattern below finds nothing.
      - function: "multi_choice_regex"
        group_select: -1
        ignore_case: true
        ignore_punctuation: true
        regex_pattern: "(\\([A-Z]\\))"
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_cot_zeroshot_yaml
task: gpqa_diamond_cot_zeroshot
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_cot_zeroshot_yaml
task: gpqa_extended_cot_zeroshot
4 changes: 4 additions & 0 deletions lm_eval/tasks/gpqa/cot_zeroshot/gpqa_main_cot_zeroshot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_cot_zeroshot_yaml
task: gpqa_main_cot_zeroshot
39 changes: 39 additions & 0 deletions lm_eval/tasks/gpqa/cot_zeroshot/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import random
import re

import datasets


def preprocess(text):
    """Normalize an answer-choice string from the GPQA dataset.

    Returns a single space for missing (None) values so downstream template
    formatting never receives None. Otherwise strips surrounding whitespace,
    turns " [title]" markers into sentence breaks, removes bracketed
    annotations (e.g. "[1]"), and collapses the double spaces those removals
    can leave behind.
    """
    if text is None:
        return " "
    text = text.strip()
    text = text.replace(" [title]", ". ")
    # Remove bracketed annotations such as "[1]" or "[note]".
    text = re.sub(r"\[.*?\]", "", text)
    # Collapse double spaces left by the replacements above.
    # (The original `text.replace(" ", " ")` was a no-op.)
    text = text.replace("  ", " ")
    return text


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Shuffle each doc's answer choices and attach letter-label answer fields.

    Adds per document:
      choice1..choice4 : the four preprocessed answer options, in shuffled order
      choices          : the same four options as a list
      answer           : the correct option's label, e.g. "(B)"
    """

    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]

        # Seed the shuffle from the question text so the choice order (and
        # hence the answer key) is reproducible across runs and dataset.map
        # worker processes; the original unseeded global random.shuffle made
        # every run a different eval.
        rng = random.Random(doc["Question"])
        rng.shuffle(choices)

        # NOTE(review): if an incorrect answer preprocesses to the same string
        # as the correct one, index() returns the first occurrence and may
        # mislabel the answer — assumed not to occur in GPQA; verify.
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))

        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"({chr(65 + correct_answer_index)})",
        }
        return out_doc

    return dataset.map(_process_doc)
Loading

0 comments on commit 01108ac

Please sign in to comment.