Skip to content

Commit

Permalink
Apply code autoformatting with Ruff to tasks/*.py and *__init__.py (El…
Browse files Browse the repository at this point in the history
  • Loading branch information
LSinev authored Feb 26, 2024
1 parent f78e2da commit d27c0c0
Show file tree
Hide file tree
Showing 48 changed files with 404 additions and 233 deletions.
7 changes: 3 additions & 4 deletions lm_eval/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Union
from functools import partial
from typing import List, Union

from lm_eval.api.filter import FilterEnsemble
from . import selection
from . import extraction
from . import transformation

from . import extraction, selection, transformation


FILTER_REGISTRY = {
Expand Down
24 changes: 14 additions & 10 deletions lm_eval/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from . import huggingface
from . import openai_completions
from . import textsynth
from . import dummy
from . import anthropic_llms
from . import gguf
from . import vllm_causallms
from . import mamba_lm
from . import optimum_lm
from . import neuron_optimum
from . import (
anthropic_llms,
dummy,
gguf,
huggingface,
mamba_lm,
neuron_optimum,
openai_completions,
optimum_lm,
textsynth,
vllm_causallms,
)


# TODO: implement __all__


Expand Down
5 changes: 3 additions & 2 deletions lm_eval/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import ast

import os
from typing import Dict

from lm_eval import utils
from lm_eval.utils import eval_logger


# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
Expand Down
132 changes: 75 additions & 57 deletions lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
import os
import abc
import collections

import logging
import os
from functools import partial
from typing import List, Union, Dict
from typing import Dict, List, Union

from lm_eval import utils
from lm_eval.api.task import Task, ConfigurableTask

import logging
from lm_eval.api.task import ConfigurableTask, Task


class TaskManager:
"""TaskManager indexes all tasks from the default `lm_eval/tasks/`
and an optional directory if provided.
"""
def __init__(
self,
verbosity="INFO",
include_path=None
) -> None:

def __init__(self, verbosity="INFO", include_path=None) -> None:
self.verbosity = verbosity
self.include_path = include_path
self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}"))

self._task_index = self.initialize_tasks(
include_path=include_path
)
self._task_index = self.initialize_tasks(include_path=include_path)
self._all_tasks = sorted(list(self._task_index.keys()))

self.task_group_map = collections.defaultdict(list)
Expand Down Expand Up @@ -65,27 +57,29 @@ def task_index(self):
return self._task_index

def match_tasks(self, task_list):
return utils.pattern_match(
task_list, self.all_tasks
)
return utils.pattern_match(task_list, self.all_tasks)

def _name_is_registered(self, name):
if name in self.all_tasks:
return True
return False

def _name_is_task(self, name):
def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True
return False

def _name_is_group(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False

def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False

Expand Down Expand Up @@ -117,7 +111,7 @@ def _get_config(self, name):
return utils.load_yaml_config(yaml_path, mode="full")

def _get_tasklist(self, name):
assert self._name_is_task(name) == False
assert self._name_is_task(name) is False
return self.task_index[name]["task"]

def _process_alias(self, config, group=None):
Expand All @@ -130,12 +124,12 @@ def _process_alias(self, config, group=None):
return config

def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
def load_task(config, task, group=None, yaml_path=None):
if "include" in config:
assert yaml_path is not None
Expand Down Expand Up @@ -174,7 +168,9 @@ def load_task(config, task, group=None, yaml_path=None):
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]):
update_config = {
k:v for k,v in group_config.items() if k not in ["task", "group"]
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)

Expand All @@ -183,9 +179,8 @@ def load_task(config, task, group=None, yaml_path=None):
update_config.pop("group_alias")

if isinstance(name_or_config, dict):

if update_config is not None:
name_or_config={
name_or_config = {
**name_or_config,
**update_config,
}
Expand All @@ -196,7 +191,9 @@ def load_task(config, task, group=None, yaml_path=None):
# if self._name_is_task(name) is False:
if self._name_is_group(name):
group_name = name
update_config = {k:v for k,v in name_or_config.items() if k != "task"}
update_config = {
k: v for k, v in name_or_config.items() if k != "task"
}
subtask_list = self._get_tasklist(name)
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
Expand All @@ -207,36 +204,53 @@ def load_task(config, task, group=None, yaml_path=None):
# Check if this is a duplicate.
if parent_name is not None:
name_or_config["group"] = parent_name
num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name])))
num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0:
name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name)

task_config={
**base_task_config,
**name_or_config,
}
task_config = {
**base_task_config,
**name_or_config,
}
else:
task_config = name_or_config
return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path)
return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
# update_config = {k:v for k,v in name_or_config.items() if k != "task"}
if set(name_or_config.keys()) > set(["task", "group"]):
update_config = {
k:v for k,v in name_or_config.items() if k not in ["task", "group"]
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
}

all_subtasks = {}
if (parent_name is not None):
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}

fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path)
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks


def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
"""Loads a dictionary of task objects from a list
Expand All @@ -250,12 +264,7 @@ def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
task_list = [task_list]

all_loaded_tasks = dict(
collections.ChainMap(
*map(
self._load_individual_task_or_group,
task_list
)
)
collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
)
return all_loaded_tasks

Expand Down Expand Up @@ -299,11 +308,11 @@ def _get_task_and_group(self, task_dir: str):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}

Expand All @@ -322,7 +331,7 @@ def _get_task_and_group(self, task_dir: str):
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
}

if "group" in config:
groups = config["group"]
Expand All @@ -343,6 +352,7 @@ def _get_task_and_group(self, task_dir: str):

return tasks_and_groups


def include_path(task_dir):
logger = utils.eval_logger
logger.setLevel(getattr(logging, "INFO"))
Expand All @@ -352,6 +362,7 @@ def include_path(task_dir):
)
return 0


def initialize_tasks(verbosity="INFO"):
logger = utils.eval_logger
logger.setLevel(getattr(logging, f"{verbosity}"))
Expand All @@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
)
return 0


def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "task" in task_config:
return task_config["task"]
Expand All @@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
else:
return "{dataset_path}".format(**task_config)


def get_task_name_from_object(task_object):
if hasattr(task_object, "config"):
return task_object._config["task"]
Expand All @@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
else type(task_object).__name__
)

def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):

def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
Expand All @@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
if task_manager is None:
task_manager = TaskManager()

task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list)
task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
)

for task_element in others_task_name_list:
if isinstance(task_element, dict):
Expand Down
6 changes: 3 additions & 3 deletions lm_eval/tasks/bbh/_generate_configs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
Take in a YAML, and output all other splits with this YAML
"""
import argparse
import os
import re
import yaml
import requests
import argparse

import datasets
import requests
import yaml
from tqdm import tqdm

from lm_eval import utils
Expand Down
Loading

0 comments on commit d27c0c0

Please sign in to comment.