fix averaging
dlwh committed Jan 21, 2025
1 parent db0b1cd commit 3abf313
Showing 2 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion config/harness/harness_nano.yaml
@@ -1,7 +1,7 @@
 eval_harness:
   # task_spec: ["hellaswag"]
   task_spec:
-    - mmlu
+    # - mmlu
     - task: mmlu
       num_fewshot: 1
       task_alias: mmlu_1shot
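
The config change comments out the bare `- mmlu` entry so only the structured entry remains, which carries its own fewshot count and alias. For context, a task_spec can mix both styles; the sketch below is illustrative only (the extra hellaswag and 5-shot entries are not part of this commit):

    eval_harness:
      task_spec:
        # bare form: run the task with its default settings
        - hellaswag
        # structured form: override the fewshot count and register under an alias
        - task: mmlu
          num_fewshot: 1
          task_alias: mmlu_1shot
        - task: mmlu
          num_fewshot: 5
          task_alias: mmlu_5shot
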
31 changes: 15 additions & 16 deletions src/levanter/eval_harness.py
@@ -263,8 +263,6 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         result_greedy = np.zeros(len(requests))
         covered_points = np.zeros(len(requests), dtype=bool)
 
-        return [ (0.0, False) ] * len(requests)
-
         total_padding = 0
         total_tokens = 0
         pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req")
@@ -294,11 +292,11 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
             result_greedy[out_ids[valid_indices]] = out_correct[valid_indices]
             covered_points[out_ids[valid_indices]] = True
 
+            total_padding += padding_count
+            total_tokens += batch_tokens
+
             pbar.set_postfix(
-                padding=(
-                    f"{total_padding + padding_count}/{total_tokens + batch_tokens} ="
-                    f" {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}"
-                ),
+                padding=f"{total_padding}/{total_tokens} = {(total_padding) / (total_tokens):.2f}",
+                this_padding=f"{padding_count}/{batch_tokens}= {padding_count / batch_tokens:.2f}",
             )
             pbar.update(len(segments_this_batch))
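
This hunk is the progress-bar half of the averaging fix: the running `total_padding`/`total_tokens` counters are now accumulated before the postfix is set, and the padding ratio is reported as a cumulative figure (`padding`) alongside the current batch's figure (`this_padding`) rather than as a single combined number. A minimal, self-contained sketch of that bookkeeping pattern (the batch numbers below are made up):

    from tqdm import tqdm

    # Made-up (padding_count, batch_tokens) pairs standing in for real batches.
    batches = [(100, 1024), (37, 512), (250, 2048)]

    total_padding = 0
    total_tokens = 0
    pbar = tqdm(total=len(batches), desc="Loglikelihood", unit="req")
    for padding_count, batch_tokens in batches:
        total_padding += padding_count
        total_tokens += batch_tokens
        pbar.set_postfix(
            padding=f"{total_padding}/{total_tokens} = {total_padding / total_tokens:.2f}",
            this_padding=f"{padding_count}/{batch_tokens} = {padding_count / batch_tokens:.2f}",
        )
        pbar.update(1)
    pbar.close()
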
@@ -455,6 +453,7 @@ def _get_task_and_rename(self, manager, our_name, task: dict | str):
 
     def _rename_tasks_for_eval_harness(self, this_task, lm_eval_task_name, our_name):
         import lm_eval.tasks as tasks
+
         # hacky, but this allows us to run multiple instances of the same task with different fewshot settings
         if isinstance(this_task, dict):
             out = {}
@@ -467,13 +466,11 @@ def _rename_tasks_for_eval_harness(self, this_task, lm_eval_task_name, our_name)
                 elif isinstance(k, str):
                     k = self._replace_name_with_our_name(k, lm_eval_task_name, our_name)
                     if isinstance(v, dict):
+                        subtask_list = self._get_child_tasks(v)
                         # ok so inexplicably, lm_eval_harness doesn't wrap the key in a ConfigurableGroup when you pass
                         # in a task dict (it seems like a mistake), so we need to do that here
-                        # subtask is the name of all of the child tasks in v
-                        subtask_list = self._get_child_tasks(v)
-                        group = tasks.ConfigurableGroup(
-                            config={"group": k, "task": subtask_list}
-                        )
+                        group = tasks.ConfigurableGroup(config={"group": k, "task": subtask_list})
                         out[group] = v
                     else:
                         out[k] = v
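
The dict branch above wraps each plain string group key in `tasks.ConfigurableGroup` because, per the in-code comment, lm-eval-harness does not do that itself when handed a task dict. A toy illustration of the wrapping, assuming lm-eval is installed; the group key and subtask names are made up, and in the real code `subtask_list` comes from `self._get_child_tasks(v)` and the value is the original child dict `v`:

    import lm_eval.tasks as tasks

    # Hypothetical renamed group key and child-task names.
    k = "mmlu_1shot"
    subtask_list = ["mmlu_1shot_abstract_algebra", "mmlu_1shot_anatomy"]

    # Wrap the bare string key so lm-eval treats it as a proper group.
    group = tasks.ConfigurableGroup(config={"group": k, "task": subtask_list})
    out = {group: {}}  # in the real code this maps the group to the child dict v
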
@@ -483,7 +480,9 @@ def _rename_tasks_for_eval_harness(self, this_task, lm_eval_task_name, our_name)
             return out
 
         elif isinstance(this_task, tasks.ConfigurableTask):
-            this_task.config.task = self._replace_name_with_our_name(this_task.config.task, lm_eval_task_name, our_name)
+            this_task.config.task = self._replace_name_with_our_name(
+                this_task.config.task, lm_eval_task_name, our_name
+            )
             return this_task
         else:
             raise ValueError(f"Unknown task type: {this_task}")
@@ -503,6 +502,7 @@ def _replace_name_with_our_name(self, lm_eval_name, lm_eval_prefix, our_name_pre
 
     def _get_child_tasks(self, task_group):
         import lm_eval.tasks as tasks
+
         out = []
         for k, v in task_group.items():
             if isinstance(k, tasks.ConfigurableGroup):
@@ -642,15 +642,14 @@ def _compute_averages(outputs):
     for task_results in outputs["results"].values():
         metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias")
 
-    examples_per_task = [task_samples["effective"] for task_samples in outputs["n-samples"].values()]
-
     # Compute macro and micro averages
     for metric in metric_keys:
         # Collect valid tasks for this metric
+        # We iterate over the n-samples because real tasks (as opposed to aggregates like "mmlu") have counts
         valid_tasks = [
-            (task_results.get(metric), examples_per_task[i])
-            for i, (task_name, task_results) in enumerate(outputs["results"].items())
-            if metric in task_results
+            (outputs["results"][task_name].get(metric), outputs["n-samples"][task_name]["effective"])
+            for task_name in outputs["n-samples"]
+            if outputs["results"][task_name].get(metric, None) is not None
         ]
 
         if not valid_tasks:
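
This hunk is the averaging fix the commit message refers to. The old code collected `examples_per_task` positionally from `outputs["n-samples"]` but then enumerated `outputs["results"]`, which also contains aggregate entries (group rollups such as "mmlu") that have no sample counts, so the positional `examples_per_task[i]` lookup could pair a metric with the wrong task's count or run past the end of the list. The new code iterates over the `n-samples` keys and looks up both the metric value and the effective count by task name. A small self-contained sketch of how such (value, count) pairs feed macro and micro averages; the task names, metric key, and numbers are made up, and the aggregation shown is illustrative rather than the exact code downstream of this hunk:

    outputs = {
        "results": {
            "mmlu_1shot": {"acc,none": 0.41},  # aggregate rollup: not present in n-samples
            "mmlu_1shot_abstract_algebra": {"acc,none": 0.30},
            "mmlu_1shot_anatomy": {"acc,none": 0.52},
        },
        "n-samples": {
            "mmlu_1shot_abstract_algebra": {"effective": 100},
            "mmlu_1shot_anatomy": {"effective": 135},
        },
    }

    metric = "acc,none"
    valid_tasks = [
        (outputs["results"][task_name].get(metric), outputs["n-samples"][task_name]["effective"])
        for task_name in outputs["n-samples"]
        if outputs["results"][task_name].get(metric, None) is not None
    ]

    macro = sum(v for v, _ in valid_tasks) / len(valid_tasks)                    # unweighted mean over tasks
    micro = sum(v * n for v, n in valid_tasks) / sum(n for _, n in valid_tasks)  # weighted by example count
    print(f"macro={macro:.4f}, micro={micro:.4f}")  # macro=0.4100, micro=0.4264
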