Skip to content

Commit

Permalink
Fix SuperGlue's ReCoRD task following regression in v0.4 refactoring (E…
Browse files Browse the repository at this point in the history
  • Loading branch information
orsharir authored Mar 28, 2024
1 parent 0dffdbb commit ab7cc6b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lm_eval/tasks/super_glue/record/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: !function util.doc_to_text
doc_to_target: "{{answers}}"
doc_to_choice: "{{entities}}"
doc_to_target: !function util.doc_to_target
doc_to_choice: !function util.doc_to_choice
process_docs: !function util.process_docs
process_results: !function util.process_results
metric_list:
- metric: f1
Expand All @@ -17,4 +18,4 @@ metric_list:
higher_is_better: True
aggregation: mean
metadata:
version: 1.0
version: 2.0
16 changes: 16 additions & 0 deletions lm_eval/tasks/super_glue/record/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datasets
import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics

Expand All @@ -21,6 +22,21 @@ def doc_to_target(doc):
return format_answer(query=doc["query"], entity=doc["answers"][0])


def doc_to_choice(doc):
return [format_answer(query=doc["query"], entity=ans) for ans in doc["entities"]]


def process_docs(dataset: datasets.Dataset):
def _process_doc(doc):
return {
"passage": doc["passage"],
"query": doc["query"],
"entities": sorted(list(set(doc["entities"]))),
"answers": sorted(list(set(doc["answers"]))),
}
return dataset.map(_process_doc)


def process_results(doc, results):
# ReCoRD's evaluation is actually deceptively simple:
# - Pick the maximum likelihood prediction entity
Expand Down

0 comments on commit ab7cc6b

Please sign in to comment.