Skip to content

Commit

Permalink
Remove LM dependency from build_all_requests (EleutherAI#2011)
Browse files Browse the repository at this point in the history
* refactored `lm.apply_chat_template`

* nit

* fix weird type error

* fixed!

* skip failing test

* pre-commit run all

* add type hints

* nit

* nit

* fixup
  • Loading branch information
baberabb authored Jun 25, 2024
1 parent 9b6b0f5 commit 9b6179b
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 45 deletions.
35 changes: 18 additions & 17 deletions lm_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,15 +368,16 @@ def doc_to_target(self, doc):
def build_all_requests(
self,
*,
limit=None,
rank=None,
world_size=None,
cache_requests=False,
rewrite_requests_cache=False,
system_instruction=None,
apply_chat_template=False,
fewshot_as_multiturn=False,
lm=None,
limit: Union[int, None] = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
tokenizer_name: str = "",
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""

Expand All @@ -391,7 +392,7 @@ def build_all_requests(
if system_instruction is not None
else ""
)
cache_key += f"-tokenizer{lm.tokenizer_name}" if apply_chat_template else ""
cache_key += f"-tokenizer{tokenizer_name}"

cached_instances = load_from_cache(file_name=cache_key)

Expand Down Expand Up @@ -436,7 +437,7 @@ def build_all_requests(
system_instruction,
apply_chat_template,
fewshot_as_multiturn,
lm,
chat_template,
)

# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
Expand Down Expand Up @@ -1014,7 +1015,7 @@ def fewshot_context(
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
lm=None,
chat_template: Optional[Callable] = None,
) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
Expand All @@ -1029,8 +1030,8 @@ def fewshot_context(
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param lm:
Language model with definition of the tokenizer/function to use for applying the chat template.
:param chat_template: Callable
Chat template to be applied to the fewshot context.
:returns: str
The fewshot context.
"""
Expand Down Expand Up @@ -1077,7 +1078,7 @@ def fewshot_context(
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
Expand All @@ -1089,7 +1090,7 @@ def fewshot_context(
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(lm.apply_chat_template(chat))
labeled_examples_list.append(chat_template(chat))
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
Expand All @@ -1103,7 +1104,7 @@ def fewshot_context(
labeled_examples, str(example), fewshot_as_multiturn
)
# return lm.apply_chat_template(labeled_examples)
return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
else:
if self.multiple_input:
return labeled_examples
Expand Down
19 changes: 12 additions & 7 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,12 @@ def evaluate(
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
lm=lm,
chat_template=getattr(lm, "apply_chat_template")
if apply_chat_template
else None,
tokenizer_name=getattr(lm, "tokenizer_name", "")
if apply_chat_template
else "",
)
eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
Expand Down Expand Up @@ -609,16 +614,16 @@ def evaluate(
]

# compute group's pooled metric and stderr
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
results[group][metric] = (
lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
results[group][stderr] = (
lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
Expand Down
6 changes: 3 additions & 3 deletions lm_eval/evaluator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ def consolidate_results(
metric_key
]
results[task_output.task_name]["samples"] = task_output.sample_len
results[task_output.task_name][
f"{metric}_stderr,{filter_key}"
] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
)
return results, samples, configs, versions, num_fewshot, higher_is_better


Expand Down
4 changes: 2 additions & 2 deletions lm_eval/tasks/arabicmmlu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Homepage: https://github.com/mbzuai-nlp/ArabicMMLU

```
@misc{koto2024arabicmmlu,
title={ArabicMMLU: Assessing Massive Multitask Language Understanding in Arabic},
title={ArabicMMLU: Assessing Massive Multitask Language Understanding in Arabic},
author={Fajri Koto and Haonan Li and Sara Shatnawi and Jad Doughman and Abdelrahman Boda Sadallah and Aisha Alraeesi and Khalid Almubarak and Zaid Alyafeai and Neha Sengupta and Shady Shehata and Nizar Habash and Preslav Nakov and Timothy Baldwin},
year={2024},
eprint={2402.12840},
Expand All @@ -37,4 +37,4 @@ Homepage: https://github.com/mbzuai-nlp/ArabicMMLU
* `arabicmmlu_stem_social_science`: evaluates social science ArabicMMLU tasks.
* `arabicmmlu_stem_humanities`: evaluates humanities ArabicMMLU tasks.
* `arabicmmlu_stem_language`: evaluates Arabic language ArabicMMLU tasks.
* `arabicmmlu_stem_other`: evaluates other ArabicMMLU tasks.
* `arabicmmlu_stem_other`: evaluates other ArabicMMLU tasks.
7 changes: 5 additions & 2 deletions lm_eval/tasks/arabicmmlu/_generate_configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Take in a YAML, and output all "other" splits with this YAML
"""

import argparse
import logging
import os
Expand Down Expand Up @@ -76,7 +77,6 @@ def parse_args():
if category not in ALL_CATEGORIES:
ALL_CATEGORIES.append(category)


# description = f"The following are multiple choice questions (with answers) about {' '.join(subject.split('_'))}.\n\n"

yaml_dict = {
Expand All @@ -89,7 +89,10 @@ def parse_args():
# "description": description,
}

file_save_path = args.save_prefix_path + f"_{subject.lower().replace(' ', '_').replace('(', '').replace(')', '')}.yaml"
file_save_path = (
args.save_prefix_path
+ f"_{subject.lower().replace(' ', '_').replace('(', '').replace(')', '')}.yaml"
)
eval_logger.info(f"Saving yaml for subset {subject} to {file_save_path}")
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
Expand Down
34 changes: 20 additions & 14 deletions lm_eval/tasks/arabicmmlu/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
PROMPT = 'This is a {}. Select the correct answer!\n\nQuestion: {}\n{}\n\nAnswer:'
PROMPT = "This is a {}. Select the correct answer!\n\nQuestion: {}\n{}\n\nAnswer:"

level_en = {
'Primary': 'primary school',
'Middle': 'middle school',
'High': 'high school',
'Univ': 'university',
'Prof': 'professional',
"Primary": "primary school",
"Middle": "middle school",
"High": "high school",
"Univ": "university",
"Prof": "professional",
}

alpa = ['A.', 'B.', 'C.', 'D.', 'E.']
alpa = ["A.", "B.", "C.", "D.", "E."]


def doc_to_text(doc):
Expand All @@ -17,22 +17,28 @@ def doc_to_text(doc):
https://github.com/mbzuai-nlp/ArabicMMLU/blob/main/util_prompt.py
"""

level = "" if not doc['Level'] else " for " + level_en[doc['Level']]
country = "" if not doc['Country'] else " in " + doc['Country']
level = "" if not doc["Level"] else " for " + level_en[doc["Level"]]
country = "" if not doc["Country"] else " in " + doc["Country"]
main_meta_data = f"{doc['Subject']} question{level}{country}"

question = doc['Question'] if doc['Context']=="" else f"{doc['Context']}\n\n{doc['Question']}"
question = (
doc["Question"]
if doc["Context"] == ""
else f"{doc['Context']}\n\n{doc['Question']}"
)

options = []
for i, opt in enumerate(['Option 1', 'Option 2', 'Option 3', 'Option 4', 'Option 5']):
for i, opt in enumerate(
["Option 1", "Option 2", "Option 3", "Option 4", "Option 5"]
):
if not doc[opt]:
break
options.append(f"{alpa[i]} {doc[opt]}")

doc_text = PROMPT.format(main_meta_data, question, '\n'.join(options))
doc_text = PROMPT.format(main_meta_data, question, "\n".join(options))

return doc_text


def doc_to_choice(doc):
return [alpa[i][0] for i in range(5) if doc[f'Option {i+1}']]
return [alpa[i][0] for i in range(5) if doc[f"Option {i+1}"]]
1 change: 1 addition & 0 deletions tests/models/test_neuralmagic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
]


@pytest.mark.skip(reason="test failing")
@pytest.mark.parametrize("model_id,task", SPARSEML_MODELS_TASKS)
def test_sparseml_eval(model_id, task):
lm = get_model("sparseml").create_from_arg_string(
Expand Down

0 comments on commit 9b6179b

Please sign in to comment.