diff --git a/lm_eval/filters/__init__.py b/lm_eval/filters/__init__.py index 271f8c1ee8..42d8e9b040 100644 --- a/lm_eval/filters/__init__.py +++ b/lm_eval/filters/__init__.py @@ -1,10 +1,9 @@ -from typing import List, Union from functools import partial +from typing import List, Union from lm_eval.api.filter import FilterEnsemble -from . import selection -from . import extraction -from . import transformation + +from . import extraction, selection, transformation FILTER_REGISTRY = { diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py index 23dace2f44..a2f9715d9d 100644 --- a/lm_eval/models/__init__.py +++ b/lm_eval/models/__init__.py @@ -1,13 +1,17 @@ -from . import huggingface -from . import openai_completions -from . import textsynth -from . import dummy -from . import anthropic_llms -from . import gguf -from . import vllm_causallms -from . import mamba_lm -from . import optimum_lm -from . import neuron_optimum +from . import ( + anthropic_llms, + dummy, + gguf, + huggingface, + mamba_lm, + neuron_optimum, + openai_completions, + optimum_lm, + textsynth, + vllm_causallms, +) + + # TODO: implement __all__ diff --git a/lm_eval/prompts/__init__.py b/lm_eval/prompts/__init__.py index c505113a3d..1f814214de 100644 --- a/lm_eval/prompts/__init__.py +++ b/lm_eval/prompts/__init__.py @@ -1,10 +1,11 @@ -import os import ast - +import os from typing import Dict + from lm_eval import utils from lm_eval.utils import eval_logger + # Prompt library. # Stores prompts in a dictionary indexed by 2 levels: # prompt category name, and prompt name. diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index 20d87c082e..a336b7f91c 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -1,14 +1,12 @@ -import os import abc import collections - +import logging +import os from functools import partial -from typing import List, Union, Dict +from typing import Dict, List, Union from lm_eval import utils -from lm_eval.api.task import Task, ConfigurableTask - -import logging +from lm_eval.api.task import ConfigurableTask, Task class TaskManager: @@ -16,20 +14,14 @@ class TaskManager: and an optional directory if provided. """ - def __init__( - self, - verbosity="INFO", - include_path=None - ) -> None: + def __init__(self, verbosity="INFO", include_path=None) -> None: self.verbosity = verbosity self.include_path = include_path self.logger = utils.eval_logger self.logger.setLevel(getattr(logging, f"{verbosity}")) - self._task_index = self.initialize_tasks( - include_path=include_path - ) + self._task_index = self.initialize_tasks(include_path=include_path) self._all_tasks = sorted(list(self._task_index.keys())) self.task_group_map = collections.defaultdict(list) @@ -65,27 +57,29 @@ def task_index(self): return self._task_index def match_tasks(self, task_list): - return utils.pattern_match( - task_list, self.all_tasks - ) + return utils.pattern_match(task_list, self.all_tasks) def _name_is_registered(self, name): if name in self.all_tasks: return True return False - def _name_is_task(self, name): + def _name_is_task(self, name) -> bool: if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]): return True return False def _name_is_group(self, name): - if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"): + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "group" + ): return True return False def _name_is_python_task(self, name): - if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"): + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "python_task" + ): return True return False @@ -117,7 +111,7 @@ def _get_config(self, name): return utils.load_yaml_config(yaml_path, mode="full") def _get_tasklist(self, name): - assert self._name_is_task(name) == False + assert self._name_is_task(name) is False return self.task_index[name]["task"] def _process_alias(self, config, group=None): @@ -130,12 +124,12 @@ def _process_alias(self, config, group=None): return config def _load_individual_task_or_group( - self, - name_or_config: Union[str, dict] = None, - parent_name: str = None, - update_config: dict = None, - yaml_path: str = None, - ) -> ConfigurableTask: + self, + name_or_config: Union[str, dict] = None, + parent_name: str = None, + update_config: dict = None, + yaml_path: str = None, + ) -> ConfigurableTask: def load_task(config, task, group=None, yaml_path=None): if "include" in config: assert yaml_path is not None @@ -174,7 +168,9 @@ def load_task(config, task, group=None, yaml_path=None): group_config = self._get_config(name_or_config) if set(group_config.keys()) > set(["task", "group"]): update_config = { - k:v for k,v in group_config.items() if k not in ["task", "group"] + k: v + for k, v in group_config.items() + if k not in ["task", "group"] } yaml_path = self._get_yaml_path(group_name) @@ -183,9 +179,8 @@ def load_task(config, task, group=None, yaml_path=None): update_config.pop("group_alias") if isinstance(name_or_config, dict): - if update_config is not None: - name_or_config={ + name_or_config = { **name_or_config, **update_config, } @@ -196,7 +191,9 @@ def load_task(config, task, group=None, yaml_path=None): # if self._name_is_task(name) is False: if self._name_is_group(name): group_name = name - update_config = {k:v for k,v in name_or_config.items() if k != "task"} + update_config = { + k: v for k, v in name_or_config.items() if k != "task" + } subtask_list = self._get_tasklist(name) if subtask_list == -1: subtask_list = self._get_config(name)["task"] @@ -207,36 +204,53 @@ def load_task(config, task, group=None, yaml_path=None): # Check if this is a duplicate. if parent_name is not None: name_or_config["group"] = parent_name - num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name]))) + num_duplicate = len( + list( + filter( + lambda x: x.startswith(name), + self.task_group_map[parent_name], + ) + ) + ) if num_duplicate > 0: name = f"{name}-{num_duplicate}" self.task_group_map[parent_name].append(name) - task_config={ - **base_task_config, - **name_or_config, - } + task_config = { + **base_task_config, + **name_or_config, + } else: task_config = name_or_config - return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path) + return load_task( + task_config, task=name, group=parent_name, yaml_path=yaml_path + ) else: group_name = name_or_config["group"] subtask_list = name_or_config["task"] - # update_config = {k:v for k,v in name_or_config.items() if k != "task"} if set(name_or_config.keys()) > set(["task", "group"]): update_config = { - k:v for k,v in name_or_config.items() if k not in ["task", "group"] + k: v + for k, v in name_or_config.items() + if k not in ["task", "group"] } all_subtasks = {} - if (parent_name is not None): + if parent_name is not None: all_subtasks = {group_name: (parent_name, None)} - fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path) - all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))} + fn = partial( + self._load_individual_task_or_group, + parent_name=group_name, + update_config=update_config, + yaml_path=yaml_path, + ) + all_subtasks = { + **all_subtasks, + **dict(collections.ChainMap(*map(fn, subtask_list))), + } return all_subtasks - def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: """Loads a dictionary of task objects from a list @@ -250,12 +264,7 @@ def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: task_list = [task_list] all_loaded_tasks = dict( - collections.ChainMap( - *map( - self._load_individual_task_or_group, - task_list - ) - ) + collections.ChainMap(*map(self._load_individual_task_or_group, task_list)) ) return all_loaded_tasks @@ -299,11 +308,11 @@ def _get_task_and_group(self, task_dir: str): # This is a group config tasks_and_groups[config["group"]] = { "type": "group", - "task": -1, # This signals that - # we don't need to know - # the task list for indexing - # as it can be loaded - # when called. + "task": -1, # This signals that + # we don't need to know + # the task list for indexing + # as it can be loaded + # when called. "yaml_path": yaml_path, } @@ -322,7 +331,7 @@ def _get_task_and_group(self, task_dir: str): tasks_and_groups[task] = { "type": "task", "yaml_path": yaml_path, - } + } if "group" in config: groups = config["group"] @@ -343,6 +352,7 @@ def _get_task_and_group(self, task_dir: str): return tasks_and_groups + def include_path(task_dir): logger = utils.eval_logger logger.setLevel(getattr(logging, "INFO")) @@ -352,6 +362,7 @@ def include_path(task_dir): ) return 0 + def initialize_tasks(verbosity="INFO"): logger = utils.eval_logger logger.setLevel(getattr(logging, f"{verbosity}")) @@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"): ) return 0 + def get_task_name_from_config(task_config: Dict[str, str]) -> str: if "task" in task_config: return task_config["task"] @@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: else: return "{dataset_path}".format(**task_config) + def get_task_name_from_object(task_object): if hasattr(task_object, "config"): return task_object._config["task"] @@ -382,7 +395,10 @@ def get_task_name_from_object(task_object): else type(task_object).__name__ ) -def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None): + +def get_task_dict( + task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None +): """Creates a dictionary of task objects from either a name of task, config, or prepared Task object. :param task_name_list: List[Union[str, Dict, Task]] @@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta if task_manager is None: task_manager = TaskManager() - task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list) + task_name_from_string_dict = task_manager.load_task_or_group( + string_task_name_list + ) for task_element in others_task_name_list: if isinstance(task_element, dict): diff --git a/lm_eval/tasks/bbh/_generate_configs.py b/lm_eval/tasks/bbh/_generate_configs.py index 0d085a1d0a..febee5fcd4 100644 --- a/lm_eval/tasks/bbh/_generate_configs.py +++ b/lm_eval/tasks/bbh/_generate_configs.py @@ -1,13 +1,13 @@ """ Take in a YAML, and output all other splits with this YAML """ +import argparse import os import re -import yaml -import requests -import argparse import datasets +import requests +import yaml from tqdm import tqdm from lm_eval import utils diff --git a/lm_eval/tasks/bbh/cot_zeroshot/utils.py b/lm_eval/tasks/bbh/cot_zeroshot/utils.py index ca411033fe..a3c63df468 100644 --- a/lm_eval/tasks/bbh/cot_zeroshot/utils.py +++ b/lm_eval/tasks/bbh/cot_zeroshot/utils.py @@ -1,19 +1,24 @@ import collections import re import sys - import unicodedata -from lm_eval.filters.extraction import RegexFilter, Filter +from lm_eval.filters.extraction import Filter, RegexFilter class ExtendedRegexFilter(RegexFilter): - punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) - if unicodedata.category(chr(i)).startswith('P')) + punct_tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") + ) def __init__( - self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: super().__init__(regex_pattern, group_select, fallback) self.ignore_case = ignore_case @@ -47,8 +52,13 @@ def find_match(self, regex, resp, convert_dict={}): class MapRegexFilter(ExtendedRegexFilter): def __init__( - self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern_to_value: dict = {}, + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: """ regex_pattern_to_value: Match the regex pattern and change the result into the value @@ -57,8 +67,17 @@ def __init__( ignore_punctuation: Remove the punctuation before matching with the given regex regexes_to_ignore: Remove these regexes before matching with the given regex """ - super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore) - self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()} + super().__init__( + "|".join(list(regex_pattern_to_value.keys())), + group_select, + fallback, + ignore_case, + ignore_punctuation, + regexes_to_ignore, + ) + self.regex_to_value = { + re.compile(r): v for r, v in regex_pattern_to_value.items() + } def apply(self, resps, docs): filtered_resps = [] @@ -66,10 +85,15 @@ def apply(self, resps, docs): for r in resps: filtered = [] for resp in r: - whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp)) + whole_match_considering_group_select = self.find_match( + self.regex, self.filter_ignores(resp) + ) if whole_match_considering_group_select: for regex, mapped_value in self.regex_to_value.items(): - match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select)) + match = self.find_match( + regex, + self.filter_ignores(whole_match_considering_group_select), + ) if match: match = mapped_value break @@ -91,9 +115,11 @@ def apply(self, resps, docs): filtered_resps = [] import regex from word2number import w2n + # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words english_number_regex = regex.compile( - "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))") + "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))" + ) for r in resps: filtered = [] @@ -118,21 +144,22 @@ def apply(self, resps, docs): filtered_resps = [] for r, doc in zip(resps, docs): - words = doc['input'].split("List:")[1].strip().split() - regex = re.compile('|'.join([f"\\b{w}\\b" for w in words])) + words = doc["input"].split("List:")[1].strip().split() + regex = re.compile("|".join([f"\\b{w}\\b" for w in words])) filtered = [] for resp in r: match = regex.findall(resp) match.reverse() - ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match)))) - filtered.append(' '.join(ordered_words)) + ordered_words = reversed( + collections.OrderedDict(zip(match, [None] * len(match))) + ) + filtered.append(" ".join(ordered_words)) filtered_resps.append(filtered) return filtered_resps class MultiChoiceRegexFilter(ExtendedRegexFilter): - def __init__(self, *args, **kwargs): """ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure @@ -156,13 +183,13 @@ def apply(self, resps, docs): for r, doc in zip(resps, docs): fallback_regexes = [] choice_to_alpha = {} - next_alpha = 'A' + next_alpha = "A" without_paren_fallback_regexes = [] without_paren_to_target = {} multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)") - match = multiple_choices_regex.findall(doc['input']) + match = multiple_choices_regex.findall(doc["input"]) for m in match: m = self.filter_ignores(m.strip()) fallback_regexes.append(f"{re.escape(m)}") @@ -172,17 +199,23 @@ def apply(self, resps, docs): without_paren_to_target[next_alpha] = f"({next_alpha})" next_alpha = chr(ord(next_alpha) + 1) - fallback_regex = re.compile('|'.join(fallback_regexes)) - without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) - without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) filtered = [] for resp in r: match = self.find_match(self.regex, resp) if not match: - match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha) + match = self.find_match( + fallback_regex, self.filter_ignores(resp), choice_to_alpha + ) if not match: - match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target) + match = self.find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) if not match: match = self.fallback filtered.append(match) diff --git a/lm_eval/tasks/bbh/zeroshot/utils.py b/lm_eval/tasks/bbh/zeroshot/utils.py index ca411033fe..a3c63df468 100644 --- a/lm_eval/tasks/bbh/zeroshot/utils.py +++ b/lm_eval/tasks/bbh/zeroshot/utils.py @@ -1,19 +1,24 @@ import collections import re import sys - import unicodedata -from lm_eval.filters.extraction import RegexFilter, Filter +from lm_eval.filters.extraction import Filter, RegexFilter class ExtendedRegexFilter(RegexFilter): - punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) - if unicodedata.category(chr(i)).startswith('P')) + punct_tbl = dict.fromkeys( + i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") + ) def __init__( - self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: super().__init__(regex_pattern, group_select, fallback) self.ignore_case = ignore_case @@ -47,8 +52,13 @@ def find_match(self, regex, resp, convert_dict={}): class MapRegexFilter(ExtendedRegexFilter): def __init__( - self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern_to_value: dict = {}, + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: """ regex_pattern_to_value: Match the regex pattern and change the result into the value @@ -57,8 +67,17 @@ def __init__( ignore_punctuation: Remove the punctuation before matching with the given regex regexes_to_ignore: Remove these regexes before matching with the given regex """ - super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore) - self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()} + super().__init__( + "|".join(list(regex_pattern_to_value.keys())), + group_select, + fallback, + ignore_case, + ignore_punctuation, + regexes_to_ignore, + ) + self.regex_to_value = { + re.compile(r): v for r, v in regex_pattern_to_value.items() + } def apply(self, resps, docs): filtered_resps = [] @@ -66,10 +85,15 @@ def apply(self, resps, docs): for r in resps: filtered = [] for resp in r: - whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp)) + whole_match_considering_group_select = self.find_match( + self.regex, self.filter_ignores(resp) + ) if whole_match_considering_group_select: for regex, mapped_value in self.regex_to_value.items(): - match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select)) + match = self.find_match( + regex, + self.filter_ignores(whole_match_considering_group_select), + ) if match: match = mapped_value break @@ -91,9 +115,11 @@ def apply(self, resps, docs): filtered_resps = [] import regex from word2number import w2n + # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words english_number_regex = regex.compile( - "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))") + "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))" + ) for r in resps: filtered = [] @@ -118,21 +144,22 @@ def apply(self, resps, docs): filtered_resps = [] for r, doc in zip(resps, docs): - words = doc['input'].split("List:")[1].strip().split() - regex = re.compile('|'.join([f"\\b{w}\\b" for w in words])) + words = doc["input"].split("List:")[1].strip().split() + regex = re.compile("|".join([f"\\b{w}\\b" for w in words])) filtered = [] for resp in r: match = regex.findall(resp) match.reverse() - ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match)))) - filtered.append(' '.join(ordered_words)) + ordered_words = reversed( + collections.OrderedDict(zip(match, [None] * len(match))) + ) + filtered.append(" ".join(ordered_words)) filtered_resps.append(filtered) return filtered_resps class MultiChoiceRegexFilter(ExtendedRegexFilter): - def __init__(self, *args, **kwargs): """ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure @@ -156,13 +183,13 @@ def apply(self, resps, docs): for r, doc in zip(resps, docs): fallback_regexes = [] choice_to_alpha = {} - next_alpha = 'A' + next_alpha = "A" without_paren_fallback_regexes = [] without_paren_to_target = {} multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)") - match = multiple_choices_regex.findall(doc['input']) + match = multiple_choices_regex.findall(doc["input"]) for m in match: m = self.filter_ignores(m.strip()) fallback_regexes.append(f"{re.escape(m)}") @@ -172,17 +199,23 @@ def apply(self, resps, docs): without_paren_to_target[next_alpha] = f"({next_alpha})" next_alpha = chr(ord(next_alpha) + 1) - fallback_regex = re.compile('|'.join(fallback_regexes)) - without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) - without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) filtered = [] for resp in r: match = self.find_match(self.regex, resp) if not match: - match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha) + match = self.find_match( + fallback_regex, self.filter_ignores(resp), choice_to_alpha + ) if not match: - match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target) + match = self.find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) if not match: match = self.fallback filtered.append(match) diff --git a/lm_eval/tasks/belebele/_generate_configs.py b/lm_eval/tasks/belebele/_generate_configs.py index fd96034afb..af6aa824a6 100644 --- a/lm_eval/tasks/belebele/_generate_configs.py +++ b/lm_eval/tasks/belebele/_generate_configs.py @@ -1,15 +1,16 @@ """ Take in a YAML, and output all other splits with this YAML """ -import os -import yaml import argparse -import requests +import os +import requests +import yaml from tqdm import tqdm from lm_eval.utils import logging + API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele" @@ -39,6 +40,7 @@ def parse_args(): def query(): response = requests.get(API_URL) return response.json()["splits"] + print(query()) languages = [split["split"] for split in query()] @@ -49,7 +51,7 @@ def query(): if args.task_prefix != "" else f"belebele_{lang}", "test_split": lang, - "fewshot_split":lang, + "fewshot_split": lang, } file_save_path = args.save_prefix_path + f"_{lang}.yaml" diff --git a/lm_eval/tasks/bigbench/generate_tasks.py b/lm_eval/tasks/bigbench/generate_tasks.py index 08fd0c0a59..169c664655 100644 --- a/lm_eval/tasks/bigbench/generate_tasks.py +++ b/lm_eval/tasks/bigbench/generate_tasks.py @@ -1,6 +1,8 @@ import os + import yaml + all_subtasks = [ "abstract_narrative_understanding", "anachronisms", diff --git a/lm_eval/tasks/bigbench/push_bigbench_dataset.py b/lm_eval/tasks/bigbench/push_bigbench_dataset.py index 7566a66441..44577fa5d4 100644 --- a/lm_eval/tasks/bigbench/push_bigbench_dataset.py +++ b/lm_eval/tasks/bigbench/push_bigbench_dataset.py @@ -8,10 +8,9 @@ `pip install "bigbench @ https://storage.googleapis.com/public_research_data/bigbench/bigbench-0.0.1.tar.gz"` and is included so that the bigbench dependency can be avoided. """ -from tqdm import tqdm -import datasets - import bigbench.api.util as bb_utils +import datasets +from tqdm import tqdm all_task_names = bb_utils.get_all_json_task_names() diff --git a/lm_eval/tasks/blimp/generate_configs.py b/lm_eval/tasks/blimp/generate_configs.py index a768196172..a32c366834 100644 --- a/lm_eval/tasks/blimp/generate_configs.py +++ b/lm_eval/tasks/blimp/generate_configs.py @@ -1,5 +1,6 @@ import yaml + all_subtasks = [ "adjunct_island", "anaphor_gender_agreement", diff --git a/lm_eval/tasks/ceval/_generate_configs.py b/lm_eval/tasks/ceval/_generate_configs.py index 2df8ca31e4..1c6e4fc78a 100644 --- a/lm_eval/tasks/ceval/_generate_configs.py +++ b/lm_eval/tasks/ceval/_generate_configs.py @@ -1,14 +1,15 @@ """ Take in a YAML, and output all other splits with this YAML """ -import os -import yaml import argparse +import os +import yaml from tqdm import tqdm from lm_eval.logger import eval_logger + SUBJECTS = { "computer_network": "计算机网络", "operating_system": "操作系统", diff --git a/lm_eval/tasks/cmmlu/_generate_configs.py b/lm_eval/tasks/cmmlu/_generate_configs.py index 3afb15bf84..81dc4d7d7b 100644 --- a/lm_eval/tasks/cmmlu/_generate_configs.py +++ b/lm_eval/tasks/cmmlu/_generate_configs.py @@ -1,14 +1,15 @@ """ Take in a YAML, and output all other splits with this YAML """ -import os -import yaml import argparse +import os +import yaml from tqdm import tqdm from lm_eval.logger import eval_logger + SUBJECTS = { "agronomy": "农学", "anatomy": "解剖学", diff --git a/lm_eval/tasks/code_x_glue/code-text/bleu.py b/lm_eval/tasks/code_x_glue/code-text/bleu.py index 7f89404649..654a0ae06a 100644 --- a/lm_eval/tasks/code_x_glue/code-text/bleu.py +++ b/lm_eval/tasks/code_x_glue/code-text/bleu.py @@ -1,10 +1,10 @@ #!/usr/bin/python +import math import re import sys -import math import xml.sax.saxutils +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union -from typing import List, Pattern, Tuple, Union, Dict, Any, Optional """ This script was adapted from the original version by hieuhoang1972 which is part of MOSES. @@ -60,7 +60,7 @@ def normalize(s): # Added to bypass NIST-style pre-processing of hyp and ref files -- wade if nonorm: return s.split() - if type(s) is not str: + if not isinstance(s, str): s = " ".join(s) # language-independent part: for pattern, replace in normalize1: diff --git a/lm_eval/tasks/csatqa/_generate_configs.py b/lm_eval/tasks/csatqa/_generate_configs.py index bd849c0ae6..a74b890490 100644 --- a/lm_eval/tasks/csatqa/_generate_configs.py +++ b/lm_eval/tasks/csatqa/_generate_configs.py @@ -1,14 +1,15 @@ """ Take in a YAML, and output all other splits with this YAML """ -import os -import yaml import argparse +import os +import yaml from tqdm import tqdm from lm_eval.logger import eval_logger + SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"] diff --git a/lm_eval/tasks/drop/utils.py b/lm_eval/tasks/drop/utils.py index 03f7218c90..54093bb4d2 100644 --- a/lm_eval/tasks/drop/utils.py +++ b/lm_eval/tasks/drop/utils.py @@ -4,6 +4,7 @@ import numpy as np from scipy.optimize import linear_sum_assignment + _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) diff --git a/lm_eval/tasks/gpqa/n_shot/_generate_configs.py b/lm_eval/tasks/gpqa/n_shot/_generate_configs.py index 977759f159..401fa9413d 100644 --- a/lm_eval/tasks/gpqa/n_shot/_generate_configs.py +++ b/lm_eval/tasks/gpqa/n_shot/_generate_configs.py @@ -1,5 +1,4 @@ import yaml - from tqdm import tqdm @@ -22,5 +21,6 @@ def main() -> None: except FileExistsError: pass + if __name__ == "__main__": main() diff --git a/lm_eval/tasks/gpqa/n_shot/utils.py b/lm_eval/tasks/gpqa/n_shot/utils.py index c1d9d1a5aa..e0b886d287 100644 --- a/lm_eval/tasks/gpqa/n_shot/utils.py +++ b/lm_eval/tasks/gpqa/n_shot/utils.py @@ -1,6 +1,8 @@ -import datasets -import re import random +import re + +import datasets + def preprocess(text): if text is None: @@ -11,8 +13,10 @@ def preprocess(text): text = text.replace(" ", " ") return text + rng = random.Random(42) + def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: def _process_doc(doc): choices = [ @@ -30,7 +34,7 @@ def _process_doc(doc): "choice2": choices[1], "choice3": choices[2], "choice4": choices[3], - "answer": f"({chr(65 + correct_answer_index)})" + "answer": f"({chr(65 + correct_answer_index)})", } return out_doc diff --git a/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py b/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py index f91c9f454f..64929f1b78 100644 --- a/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py +++ b/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py @@ -1,5 +1,4 @@ import yaml - from tqdm import tqdm @@ -22,5 +21,6 @@ def main() -> None: except FileExistsError: pass + if __name__ == "__main__": main() diff --git a/lm_eval/tasks/gpqa/zeroshot/utils.py b/lm_eval/tasks/gpqa/zeroshot/utils.py index f941abf06f..c2317e02ef 100644 --- a/lm_eval/tasks/gpqa/zeroshot/utils.py +++ b/lm_eval/tasks/gpqa/zeroshot/utils.py @@ -1,6 +1,8 @@ -import datasets -import re import random +import re + +import datasets + def preprocess(text): if text is None: @@ -29,7 +31,7 @@ def _process_doc(doc): "choice2": choices[1], "choice3": choices[2], "choice4": choices[3], - "answer": f"({chr(65 + correct_answer_index)})" + "answer": f"({chr(65 + correct_answer_index)})", } return out_doc diff --git a/lm_eval/tasks/hellaswag/utils.py b/lm_eval/tasks/hellaswag/utils.py index 62c0c23bcd..b526a9e930 100644 --- a/lm_eval/tasks/hellaswag/utils.py +++ b/lm_eval/tasks/hellaswag/utils.py @@ -1,6 +1,7 @@ -import datasets import re +import datasets + def preprocess(text): text = text.strip() diff --git a/lm_eval/tasks/ifeval/instructions.py b/lm_eval/tasks/ifeval/instructions.py index a0da474006..31436834b7 100644 --- a/lm_eval/tasks/ifeval/instructions.py +++ b/lm_eval/tasks/ifeval/instructions.py @@ -22,8 +22,10 @@ from typing import Dict, Optional, Sequence, Union import langdetect + from lm_eval.tasks.ifeval import instructions_util + logger = logging.getLogger(__name__) _InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] diff --git a/lm_eval/tasks/ifeval/instructions_registry.py b/lm_eval/tasks/ifeval/instructions_registry.py index ecb20e9b23..30a092c379 100644 --- a/lm_eval/tasks/ifeval/instructions_registry.py +++ b/lm_eval/tasks/ifeval/instructions_registry.py @@ -15,6 +15,7 @@ """Registry of all instructions.""" from lm_eval.tasks.ifeval import instructions + _KEYWORD = "keywords:" _LANGUAGE = "language:" diff --git a/lm_eval/tasks/kobest/utils.py b/lm_eval/tasks/kobest/utils.py index 0a1789c33e..9799ef038c 100644 --- a/lm_eval/tasks/kobest/utils.py +++ b/lm_eval/tasks/kobest/utils.py @@ -6,32 +6,43 @@ def copa_doc_to_text(doc: dict) -> str: connector = {"원인": " 왜냐하면", "결과": " 그래서"}[doc["question"].strip()] return f"""{doc["premise"]} {connector}""" + def copa_doc_to_target(doc: dict) -> str: correct_choice = doc["alternative_1"] if doc["label"] == 0 else doc["alternative_2"] return f"""{correct_choice}""" + def copa_doc_to_choice(doc: dict) -> list: return [f"""{doc["alternative_1"]}""", f"""{doc["alternative_2"]}"""] + def sentineg_doc_to_text(doc: dict): return f"""문장: {doc["sentence"]} 긍부정:""" + def wic_doc_to_text(doc: dict) -> str: return f"""문장1: {doc["context_1"]} 문장2: {doc["context_2"]} 두 문장에서 {doc["word"]}가 같은 뜻으로 쓰였나?""" + def hellaswag_process_doc(doc: Dataset) -> Dataset: def preprocessor(dataset): return { "query": f"""문장: {dataset["context"]}""", - "choices": [dataset["ending_1"], dataset["ending_2"], dataset["ending_3"], dataset["ending_4"]], + "choices": [ + dataset["ending_1"], + dataset["ending_2"], + dataset["ending_3"], + dataset["ending_4"], + ], "gold": int(dataset["label"]), } return doc.map(preprocessor) + def macro_f1_score(items): unzipped_list = list(zip(*items)) golds = unzipped_list[0] preds = unzipped_list[1] - fscore = f1_score(golds, preds, average='macro') + fscore = f1_score(golds, preds, average="macro") return fscore diff --git a/lm_eval/tasks/medmcqa/utils_medmcqa.py b/lm_eval/tasks/medmcqa/utils_medmcqa.py index 7e7792d26b..8ce7e6beec 100644 --- a/lm_eval/tasks/medmcqa/utils_medmcqa.py +++ b/lm_eval/tasks/medmcqa/utils_medmcqa.py @@ -10,7 +10,12 @@ def doc_to_text(doc) -> str: Answer: """ choices = [doc["opa"], doc["opb"], doc["opc"], doc["opd"]] - option_choices = {'A': choices[0], 'B': choices[1], 'C': choices[2], 'D': choices[3]} + option_choices = { + "A": choices[0], + "B": choices[1], + "C": choices[2], + "D": choices[3], + } prompt = "Question: " + doc["question"] + "\nChoices:\n" for choice, option in option_choices.items(): diff --git a/lm_eval/tasks/medqa/preprocess_medqa.py b/lm_eval/tasks/medqa/preprocess_medqa.py index 8745f9481d..6ec3585145 100644 --- a/lm_eval/tasks/medqa/preprocess_medqa.py +++ b/lm_eval/tasks/medqa/preprocess_medqa.py @@ -1,5 +1,10 @@ def doc_to_text(doc) -> str: - option_choices = {'A': doc["ending0"], 'B': doc["ending1"], 'C': doc["ending2"], 'D': doc["ending3"]} + option_choices = { + "A": doc["ending0"], + "B": doc["ending1"], + "C": doc["ending2"], + "D": doc["ending3"], + } answers = "".join((f"{k}. {v}\n") for k, v in option_choices.items()) return f"Question: {doc['sent1']}\n{answers}Answer:" diff --git a/lm_eval/tasks/mgsm/utils.py b/lm_eval/tasks/mgsm/utils.py index 3a6547b2e2..116214f9f4 100644 --- a/lm_eval/tasks/mgsm/utils.py +++ b/lm_eval/tasks/mgsm/utils.py @@ -1,6 +1,7 @@ -import yaml import argparse +import yaml + LANGUAGES = { "bn": { # Bengali @@ -126,6 +127,7 @@ def add_regex_pattern(regex_pattern): ], } + def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: """ Generate a yaml file for each language. @@ -158,7 +160,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: task_name = f"mgsm_en_cot_{lang}" file_name = f"{task_name}.yaml" - ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"])+1 + ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1 with open( f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8" ) as f: @@ -181,7 +183,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: **filter_list, "generation_kwargs": { "until": [QUESTION, "", "<|im_end|>"], - "do_sample": False + "do_sample": False, }, **({"target_delimiter": DELIMITER} if DELIMITER else {}), }, diff --git a/lm_eval/tasks/minerva_math/utils.py b/lm_eval/tasks/minerva_math/utils.py index bde5801c56..0de9bcafa1 100644 --- a/lm_eval/tasks/minerva_math/utils.py +++ b/lm_eval/tasks/minerva_math/utils.py @@ -1,14 +1,17 @@ -import datasets import re import signal +from typing import Dict, List, Optional + +import datasets + from lm_eval.utils import eval_logger -from typing import Optional, List, Dict + try: import sympy from sympy.parsing.latex import parse_latex except ModuleNotFoundError: - raise Exception( + raise ModuleNotFoundError( "`sympy` is required for generating translation task prompt templates. \ please install sympy via pip install lm-eval[math] or pip install -e .[math]", ) diff --git a/lm_eval/tasks/mmlu/_generate_configs.py b/lm_eval/tasks/mmlu/_generate_configs.py index 1424814e7d..05c67e00e4 100644 --- a/lm_eval/tasks/mmlu/_generate_configs.py +++ b/lm_eval/tasks/mmlu/_generate_configs.py @@ -1,14 +1,15 @@ """ Take in a YAML, and output all "other" splits with this YAML """ -import os -import yaml import argparse +import os +import yaml from tqdm import tqdm from lm_eval.logger import eval_logger + SUBJECTS = { "abstract_algebra": "stem", "anatomy": "stem", @@ -124,7 +125,6 @@ def parse_args(): yaml.dump( yaml_dict, yaml_file, - # width=float("inf"), allow_unicode=True, default_style='"', ) diff --git a/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py b/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py index 0ef6b1e8f0..72246935de 100644 --- a/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py +++ b/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py @@ -1,6 +1,5 @@ import re import sys - import unicodedata from lm_eval.filters.extraction import RegexFilter @@ -10,8 +9,13 @@ class MultiChoiceRegexFilter(RegexFilter): """ """ def __init__( - self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: """ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure @@ -44,8 +48,11 @@ def find_match(regex, resp, convert_dict={}): match = convert_dict[match] return match - punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) - if unicodedata.category(chr(i)).startswith('P')) + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) def filter_ignores(st): if self.regexes_to_ignore is not None: @@ -65,12 +72,12 @@ def filter_ignores(st): for r, doc in zip(resps, docs): fallback_regexes = [] choice_to_alpha = {} - next_alpha = 'A' + next_alpha = "A" without_paren_fallback_regexes = [] without_paren_to_target = {} - choices = doc['choices'] + choices = doc["choices"] for c in choices: m = filter_ignores(c.strip()) fallback_regexes.append(f"{re.escape(m)}") @@ -80,17 +87,23 @@ def filter_ignores(st): without_paren_to_target[next_alpha] = f"({next_alpha})" next_alpha = chr(ord(next_alpha) + 1) - fallback_regex = re.compile('|'.join(fallback_regexes)) - without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) - without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) filtered = [] for resp in r: match = find_match(self.regex, resp) if not match: - match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha) + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) if not match: - match = find_match(without_paren_fallback_regex, resp, without_paren_to_target) + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) if not match: match = self.fallback filtered.append(match) diff --git a/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py b/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py index 0ef6b1e8f0..72246935de 100644 --- a/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py +++ b/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py @@ -1,6 +1,5 @@ import re import sys - import unicodedata from lm_eval.filters.extraction import RegexFilter @@ -10,8 +9,13 @@ class MultiChoiceRegexFilter(RegexFilter): """ """ def __init__( - self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", - ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, ) -> None: """ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure @@ -44,8 +48,11 @@ def find_match(regex, resp, convert_dict={}): match = convert_dict[match] return match - punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) - if unicodedata.category(chr(i)).startswith('P')) + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) def filter_ignores(st): if self.regexes_to_ignore is not None: @@ -65,12 +72,12 @@ def filter_ignores(st): for r, doc in zip(resps, docs): fallback_regexes = [] choice_to_alpha = {} - next_alpha = 'A' + next_alpha = "A" without_paren_fallback_regexes = [] without_paren_to_target = {} - choices = doc['choices'] + choices = doc["choices"] for c in choices: m = filter_ignores(c.strip()) fallback_regexes.append(f"{re.escape(m)}") @@ -80,17 +87,23 @@ def filter_ignores(st): without_paren_to_target[next_alpha] = f"({next_alpha})" next_alpha = chr(ord(next_alpha) + 1) - fallback_regex = re.compile('|'.join(fallback_regexes)) - without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) - without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) filtered = [] for resp in r: match = find_match(self.regex, resp) if not match: - match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha) + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) if not match: - match = find_match(without_paren_fallback_regex, resp, without_paren_to_target) + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) if not match: match = self.fallback filtered.append(match) diff --git a/lm_eval/tasks/model_written_evals/advanced_ai_risk/_generate_configs.py b/lm_eval/tasks/model_written_evals/advanced_ai_risk/_generate_configs.py index 3a2bac5923..fa4e30ba16 100644 --- a/lm_eval/tasks/model_written_evals/advanced_ai_risk/_generate_configs.py +++ b/lm_eval/tasks/model_written_evals/advanced_ai_risk/_generate_configs.py @@ -1,6 +1,5 @@ -import yaml import datasets - +import yaml from tqdm import tqdm diff --git a/lm_eval/tasks/model_written_evals/persona/_generate_configs.py b/lm_eval/tasks/model_written_evals/persona/_generate_configs.py index 811e0b1b62..1378dee265 100644 --- a/lm_eval/tasks/model_written_evals/persona/_generate_configs.py +++ b/lm_eval/tasks/model_written_evals/persona/_generate_configs.py @@ -1,6 +1,5 @@ -import yaml import datasets - +import yaml from tqdm import tqdm diff --git a/lm_eval/tasks/okapi/arc_multilingual/utils.py b/lm_eval/tasks/okapi/arc_multilingual/utils.py index 43cccc5672..b47621a760 100644 --- a/lm_eval/tasks/okapi/arc_multilingual/utils.py +++ b/lm_eval/tasks/okapi/arc_multilingual/utils.py @@ -1,6 +1,7 @@ -import datasets import re +import datasets + def preprocess(text): if text is None: @@ -18,7 +19,13 @@ def _process_doc(doc): out_doc = { "id": doc["id"], "query": "Question: " + preprocess(doc["instruction"]) + "\nAnswer:", - "choices": [preprocess(doc['option_a']), preprocess(doc['option_b']), preprocess(doc['option_c']), preprocess(doc['option_d']), preprocess(doc['option_e'])], + "choices": [ + preprocess(doc["option_a"]), + preprocess(doc["option_b"]), + preprocess(doc["option_c"]), + preprocess(doc["option_d"]), + preprocess(doc["option_e"]), + ], "gold": ["A", "B", "C", "D", "E"].index(doc["answer"]), } return out_doc diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py b/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py index 62c0c23bcd..b526a9e930 100644 --- a/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py @@ -1,6 +1,7 @@ -import datasets import re +import datasets + def preprocess(text): text = text.strip() diff --git a/lm_eval/tasks/okapi/mmlu_multilingual/_generate_configs.py b/lm_eval/tasks/okapi/mmlu_multilingual/_generate_configs.py index 241779148c..1114255fa7 100644 --- a/lm_eval/tasks/okapi/mmlu_multilingual/_generate_configs.py +++ b/lm_eval/tasks/okapi/mmlu_multilingual/_generate_configs.py @@ -1,6 +1,5 @@ -import yaml import datasets - +import yaml from tqdm import tqdm @@ -10,8 +9,12 @@ def main() -> None: # Removed hy and sk subdataset because the original dataset is broken # I created this PR https://huggingface.co/datasets/alexandrainst/m_mmlu/discussions/3 # on the dataset for the authors, in case it will be accepeted the filter can be removed - keys_without_hy_sk = list(filter(lambda k: ('hy' not in k and 'sk' not in k), - datasets.get_dataset_infos(dataset_path).keys())) + keys_without_hy_sk = list( # noqa: F841 + filter( + lambda k: ("hy" not in k and "sk" not in k), + datasets.get_dataset_infos(dataset_path).keys(), + ) + ) for task in tqdm(): file_name = f"m_mmlu_{task}.yaml" @@ -29,5 +32,6 @@ def main() -> None: except FileExistsError: pass + if __name__ == "__main__": main() diff --git a/lm_eval/tasks/okapi/truthfulqa_multilingual/utils.py b/lm_eval/tasks/okapi/truthfulqa_multilingual/utils.py index 4e24db9f51..2dd5433f09 100644 --- a/lm_eval/tasks/okapi/truthfulqa_multilingual/utils.py +++ b/lm_eval/tasks/okapi/truthfulqa_multilingual/utils.py @@ -1,7 +1,9 @@ -import datasets import re + +import datasets import numpy as np + QA_PROMPT = ( "Q: What is human life expectancy in the United States?\n" "A: Human life expectancy in the United States is 78 years.\n\n" @@ -17,6 +19,7 @@ "A: The 1992 Olympics were held in Barcelona, Spain." ) + def preprocess(text): if text is None: return " " diff --git a/lm_eval/tasks/qasper/metrics.py b/lm_eval/tasks/qasper/metrics.py index f4de05e0a4..cc83291225 100644 --- a/lm_eval/tasks/qasper/metrics.py +++ b/lm_eval/tasks/qasper/metrics.py @@ -1,6 +1,5 @@ import re import string - from collections import Counter diff --git a/lm_eval/tasks/qasper/utils.py b/lm_eval/tasks/qasper/utils.py index 21f5be2435..fb3d4c55cf 100644 --- a/lm_eval/tasks/qasper/utils.py +++ b/lm_eval/tasks/qasper/utils.py @@ -1,6 +1,7 @@ -from datasets import Dataset from functools import partial +from datasets import Dataset + def process_docs(dataset, set_answer_type="bool"): FEATURES = ["title", "abstract", "question", "answer", "answer_type"] diff --git a/lm_eval/tasks/realtoxicityprompts/metric.py b/lm_eval/tasks/realtoxicityprompts/metric.py index 072f561d74..b92f837608 100644 --- a/lm_eval/tasks/realtoxicityprompts/metric.py +++ b/lm_eval/tasks/realtoxicityprompts/metric.py @@ -1,7 +1,8 @@ -import os import json -import requests +import os + import numpy as np +import requests from lm_eval.utils import eval_logger diff --git a/lm_eval/tasks/scrolls/task.py b/lm_eval/tasks/scrolls/task.py index e403fd5e2d..5b604e15d9 100644 --- a/lm_eval/tasks/scrolls/task.py +++ b/lm_eval/tasks/scrolls/task.py @@ -1,16 +1,16 @@ import re +from abc import abstractmethod +from functools import reduce + import numpy as np import transformers.data.metrics.squad_metrics as squad_metrics - -from abc import abstractmethod from datasets import load_metric from transformers import AutoTokenizer -from functools import reduce -from lm_eval.api.task import Task -from lm_eval.api.metrics import mean from lm_eval.api.instance import Instance -from lm_eval.api.registry import register_task +from lm_eval.api.metrics import mean +from lm_eval.api.task import Task + _CITATION = """ @inproceedings{shaham-etal-2022-scrolls, @@ -44,6 +44,7 @@ def _download_metric(): import os import shutil + from huggingface_hub import hf_hub_download scrolls_metric_path = hf_hub_download( @@ -148,7 +149,7 @@ def download(self, *args, **kwargs): del self.dataset["test"] for split in self.dataset: self.dataset[split] = _drop_duplicates_in_input(self.dataset[split]) - if self.PRUNE_TOKENIZERS is not None and self.PRUNE_TOKENIZERS is not None: + if self.PRUNE_TOKENIZERS is not None: self.prune() def _get_prune_text(self, sample): diff --git a/lm_eval/tasks/squadv2/task.py b/lm_eval/tasks/squadv2/task.py index 8af87e7537..ef6be3e1fe 100644 --- a/lm_eval/tasks/squadv2/task.py +++ b/lm_eval/tasks/squadv2/task.py @@ -13,14 +13,15 @@ Homepage: https://rajpurkar.github.io/SQuAD-explorer/ """ -import datasets - -from math import exp from functools import partial +from math import exp + +import datasets from packaging import version -from lm_eval.api.task import ConfigurableTask from lm_eval.api.instance import Instance +from lm_eval.api.task import ConfigurableTask + _CITATION = """ @misc{rajpurkar2018know, @@ -35,7 +36,6 @@ def _squad_metric(predictions, references): - # squad_metric = load("squad_v2") squad_metric = datasets.load_metric("squad_v2") return squad_metric.compute(predictions=predictions, references=references) @@ -52,7 +52,7 @@ class SQuAD2(ConfigurableTask): DATASET_NAME = None def __init__(self): - super().__init__(config={'metadata': {'version': self.VERSION}}) + super().__init__(config={"metadata": {"version": self.VERSION}}) # HF changed squad on us so we have to make sure we aren't running the old one assert version.parse(datasets.__version__) >= version.parse( diff --git a/lm_eval/tasks/super_glue/cb/aggregate.py b/lm_eval/tasks/super_glue/cb/aggregate.py index ef095dfc68..4b99849f9b 100644 --- a/lm_eval/tasks/super_glue/cb/aggregate.py +++ b/lm_eval/tasks/super_glue/cb/aggregate.py @@ -1,5 +1,5 @@ -import sklearn import numpy as np +import sklearn def cb_multi_fi(items): diff --git a/lm_eval/tasks/super_glue/record/t5_utils.py b/lm_eval/tasks/super_glue/record/t5_utils.py index 68301b18b3..e1a29a9498 100644 --- a/lm_eval/tasks/super_glue/record/t5_utils.py +++ b/lm_eval/tasks/super_glue/record/t5_utils.py @@ -1,8 +1,8 @@ +import collections import re import string -import collections -import numpy as np +import numpy as np from datasets import Dataset from lm_eval.api.metrics import metric_max_over_ground_truths diff --git a/lm_eval/tasks/super_glue/wsc/t5_utils.py b/lm_eval/tasks/super_glue/wsc/t5_utils.py index 6570abc732..2860a2a903 100644 --- a/lm_eval/tasks/super_glue/wsc/t5_utils.py +++ b/lm_eval/tasks/super_glue/wsc/t5_utils.py @@ -1,6 +1,7 @@ import re from typing import List + def doc_to_text(x): text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x)) return "wsc: " + text @@ -23,14 +24,14 @@ def create_input(): [ " ".join(words[:pronoun_index]), "X", - " ".join(words[pronoun_index + 1:]), + " ".join(words[pronoun_index + 1 :]), ] ) # Handle some special cases. if ( - x["text"] - == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. ' + x["text"] + == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. ' ): return ( "The boy continued to whip the pony , and eventually the pony threw " @@ -39,8 +40,8 @@ def create_input(): # Using the span2_index, we get 'use' instead of 'it'. if ( - x["text"] - == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?" + x["text"] + == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?" ): return ( "When they had eventually calmed down a bit , and had gotten home, " diff --git a/lm_eval/tasks/truthfulqa/utils.py b/lm_eval/tasks/truthfulqa/utils.py index 8e2ab43fe8..399969ca5c 100644 --- a/lm_eval/tasks/truthfulqa/utils.py +++ b/lm_eval/tasks/truthfulqa/utils.py @@ -1,7 +1,6 @@ import datasets -import sacrebleu import numpy as np - +import sacrebleu from rouge_score import rouge_scorer, scoring diff --git a/lm_eval/tasks/xwinograd/utils.py b/lm_eval/tasks/xwinograd/utils.py index 97c93c7072..5e350d6e9f 100644 --- a/lm_eval/tasks/xwinograd/utils.py +++ b/lm_eval/tasks/xwinograd/utils.py @@ -51,7 +51,9 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None: for lang in LANGUAGES: file_name = f"xwinograd_{lang}.yaml" try: - with open(f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf-8") as f: + with open( + f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf-8" + ) as f: f.write("# Generated by utils.py\n") yaml.dump( { diff --git a/pyproject.toml b/pyproject.toml index 63fd49be67..0a6db2161a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,9 +90,6 @@ all = [ "lm_eval[wandb]", ] -[tool.ruff] -extend-exclude = ["lm_eval/tasks/*.py"] - [tool.ruff.lint] extend-select = ["I"] @@ -101,5 +98,4 @@ lines-after-imports = 2 known-first-party = ["lm_eval"] [tool.ruff.extend-per-file-ignores] -"__init__.py" = ["F401","F402","F403","I"] -"lm_eval/tasks/*"= ["E721"] +"__init__.py" = ["F401","F402","F403"]