Easier unitxt tasks loading and removal of unitxt library dependency (EleutherAI#1933)

* Updated unitxt loading

Signed-off-by: Elron Bandel <[email protected]>

* Revert change to general Readme

Signed-off-by: Elron Bandel <[email protected]>

* Adjust fda, squadv2, squad_completion and swde to accept config in the constructor

Signed-off-by: Elron Bandel <[email protected]>

* Fix scrolls

Signed-off-by: elronbandel <[email protected]>

* Update documentation

Signed-off-by: elronbandel <[email protected]>

* Enforce backward compatibility

Signed-off-by: elronbandel <[email protected]>

* Format unitxt class

Signed-off-by: elronbandel <[email protected]>

---------

Signed-off-by: Elron Bandel <[email protected]>
Signed-off-by: elronbandel <[email protected]>
Co-authored-by: haileyschoelkopf <[email protected]>
elronbandel and haileyschoelkopf authored Jul 8, 2024
1 parent cb43ad4 commit ad80f55
Showing 36 changed files with 208 additions and 391 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -454,7 +454,6 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | sentencepiece | For using the sentencepiece tokenizer |
 | sparseml | For using NM's SparseML models |
 | testing | For running library test suite |
-| unitxt | For IBM's unitxt dataset tasks |
 | vllm | For loading models with vLLM |
 | zeno | For visualizing results with Zeno |
 |---------------|---------------------------------------|
11 changes: 11 additions & 0 deletions docs/new_task_guide.md
@@ -403,6 +403,17 @@ task:
...
```

You can also pass a custom argument to your class by accepting `config` in the custom class constructor.
Here's how to do it:

```yaml
task: 20_newsgroups
class: !function task.Unitxt
recipe: card=cards.20_newsgroups,template=templates.classification.multi_class.title
```

In this example, `recipe` is the custom argument for the `Unitxt` class.
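
As a rough illustration of the pattern described above, the sketch below shows a class whose constructor accepts `config` and reads a custom key from it. `RecipeTask` is a hypothetical stand-in, not the real `Unitxt` class, and the dict is only an approximation of what the YAML might look like once parsed.

```python
class RecipeTask:
    """Hypothetical stand-in for a custom task class such as `Unitxt`."""

    def __init__(self, config=None):
        # Assumption: the loader hands over the parsed YAML as a dict.
        config = config or {}
        self.task_name = config.get("task")
        # `recipe` is the custom argument taken from the YAML file.
        self.recipe = config.get("recipe")


# Approximation of the parsed YAML; `class` would normally be resolved
# by the `!function` tag rather than written by hand.
yaml_config = {
    "task": "20_newsgroups",
    "class": RecipeTask,
    "recipe": "card=cards.20_newsgroups,template=templates.classification.multi_class.title",
}

task = yaml_config["class"](config=yaml_config)
print(task.task_name, task.recipe)
```

Any extra key placed in the YAML file would reach the constructor the same way, since the whole config dict is passed through.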

## Beautifying Table Display

To avoid conflict, each task needs to be registered with a unique name. Because of this, slight variations of a task are still counted as unique tasks and need to be named uniquely. This can be done by adding a name component that refers to the variation, as in MMLU, where the tasks that use the flan template are differentiated from the default by the prefix `mmlu_flan_*`. Printing the full task names can easily clutter the results table at the end of the evaluation, especially when you have a long list of tasks or are using a benchmark that comprises many tasks. To make it more legible, you can use `task_alias` and `group_alias` to provide an alternative task name and group name that will be printed. For example, in `mmlu_abstract_algebra.yaml` we set `task_alias` to `abstract_algebra`. In group configs, a `group_alias` for a group can also be set.
23 changes: 16 additions & 7 deletions lm_eval/tasks/__init__.py
@@ -1,4 +1,5 @@
 import collections
+import inspect
 import logging
 import os
 from functools import partial
@@ -151,6 +152,14 @@ def _process_alias(self, config, group=None):
                 config["group_alias"] = None
         return config
 
+    def _class_has_config_in_constructor(self, cls):
+        constructor = getattr(cls, "__init__", None)
+        return (
+            "config" in inspect.signature(constructor).parameters
+            if constructor
+            else False
+        )
+
     def _load_individual_task_or_group(
         self,
         name_or_config: Optional[Union[str, dict]] = None,
@@ -168,13 +177,13 @@ def _load_task(config, task):
                     **config,
                 }
             if self._config_is_python_task(config):
-                task_object = (
-                    config["class"](config=config)
-                    if issubclass(config["class"], ConfigurableTask)
-                    else config["class"]()
-                )
-                # very scuffed: set task name here. TODO: fixme?
-                task_object.config.task = config["task"]
+                if self._class_has_config_in_constructor(config["class"]):
+                    task_object = config["class"](config=config)
+                else:
+                    task_object = config["class"]()
+                if isinstance(task_object, ConfigurableTask):
+                    # very scuffed: set task name here. TODO: fixme?
+                    task_object.config.task = config["task"]
             else:
                 task_object = ConfigurableTask(config=config)
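
The dispatch in the hunk above can be tried in isolation. The following is a minimal sketch, assuming two hypothetical task classes (`LegacyTask` and `ConfigAwareTask`) and a free-standing `instantiate` helper that are not part of the commit; only the `inspect.signature` check mirrors the diff.

```python
import inspect


def class_has_config_in_constructor(cls):
    # Same idea as the method in the diff: look for a `config` parameter
    # in the signature of the class's __init__.
    constructor = getattr(cls, "__init__", None)
    if constructor is None:
        return False
    return "config" in inspect.signature(constructor).parameters


class LegacyTask:
    """Hypothetical task whose constructor takes no arguments."""

    def __init__(self):
        self.received = None


class ConfigAwareTask:
    """Hypothetical task whose constructor accepts `config`."""

    def __init__(self, config=None):
        self.received = config


def instantiate(cls, config):
    # Pass the config only to classes that declare they accept it.
    if class_has_config_in_constructor(cls):
        return cls(config=config)
    return cls()


legacy = instantiate(LegacyTask, {"task": "demo"})
aware = instantiate(ConfigAwareTask, {"task": "demo"})
print(legacy.received, aware.received)
```

Checking the signature rather than the base class is what lets classes with a bare `__init__(self)` keep working unchanged.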

2 changes: 1 addition & 1 deletion lm_eval/tasks/scrolls/task.py
@@ -116,7 +116,7 @@ class _SCROLLSTask(ConfigurableTask):
     PRUNE_MAX_TOKENS = None
     PRUNE_NUM_PROC = None
 
-    def __init__(self):
+    def __init__(self, config=None):
         super().__init__(config={"metadata": {"version": self.VERSION}})
         if self.DATASET_NAME is not None:
             self.metric = load_metric(_download_metric(), config_name=self.DATASET_NAME)
2 changes: 1 addition & 1 deletion lm_eval/tasks/squadv2/task.py
@@ -52,7 +52,7 @@ class SQuAD2(ConfigurableTask):
     DATASET_PATH = "squad_v2"
     DATASET_NAME = None
 
-    def __init__(self):
+    def __init__(self, config=None):
         super().__init__(config={"metadata": {"version": self.VERSION}})
 
         # HF changed squad on us so we have to make sure we aren't running the old one
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/20_newsgroups.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: 20_newsgroups
-dataset_name: card=cards.20_newsgroups,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.20_newsgroups,template=templates.classification.multi_class.title
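
The `recipe` value in these files reads as a comma-separated list of `key=value` pairs. The sketch below splits such a string into a dict purely to make that structure visible. `parse_recipe` is a hypothetical helper, not a unitxt or lm_eval function, and it assumes values contain no commas or nested brackets, which real recipes may not guarantee.

```python
def parse_recipe(recipe):
    """Split 'key=value,key=value' into a dict (hypothetical helper)."""
    fields = {}
    for part in recipe.split(","):
        # partition keeps everything after the first '=' as the value.
        key, _, value = part.partition("=")
        fields[key.strip()] = value.strip()
    return fields


recipe = "card=cards.20_newsgroups,template=templates.classification.multi_class.title"
fields = parse_recipe(recipe)
print(fields["card"])
print(fields["template"])
```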
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/ag_news.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: ag_news
-dataset_name: card=cards.ag_news,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.ag_news,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/argument_topic.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: argument_topic
-dataset_name: card=cards.argument_topic,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.argument_topic,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/atis.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.span_labeling.extraction
 task: atis
-dataset_name: card=cards.atis,template=templates.span_labeling.extraction.title
+include: unitxt
+recipe: card=cards.atis,template=templates.span_labeling.extraction.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/banking77.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: banking77
-dataset_name: card=cards.banking77,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.banking77,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/claim_stance_topic.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: claim_stance_topic
-dataset_name: card=cards.claim_stance_topic,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.claim_stance_topic,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/cnn_dailymail.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.summarization.abstractive
 task: cnn_dailymail
-dataset_name: card=cards.cnn_dailymail,template=templates.summarization.abstractive.full
+include: unitxt
+recipe: card=cards.cnn_dailymail,template=templates.summarization.abstractive.full
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/coedit_gec.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.grammatical_error_correction
 task: coedit_gec
-dataset_name: card=cards.coedit_gec,template=templates.grammatical_error_correction.simple
+include: unitxt
+recipe: card=cards.coedit_gec,template=templates.grammatical_error_correction.simple
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/dbpedia_14.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: dbpedia_14
-dataset_name: card=cards.dbpedia_14,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.dbpedia_14,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/ethos_binary.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: ethos_binary
-dataset_name: card=cards.ethos_binary,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.ethos_binary,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/financial_tweets.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: financial_tweets
-dataset_name: card=cards.financial_tweets,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.financial_tweets,template=templates.classification.multi_class.title
135 changes: 0 additions & 135 deletions lm_eval/tasks/unitxt/generate_yamls.py

This file was deleted.

4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/law_stack_exchange.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: law_stack_exchange
-dataset_name: card=cards.law_stack_exchange,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.law_stack_exchange,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/ledgar.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: ledgar
-dataset_name: card=cards.ledgar,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.ledgar,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/medical_abstracts.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.classification.multi_class
 task: medical_abstracts
-dataset_name: card=cards.medical_abstracts,template=templates.classification.multi_class.title
+include: unitxt
+recipe: card=cards.medical_abstracts,template=templates.classification.multi_class.title
4 changes: 2 additions & 2 deletions lm_eval/tasks/unitxt/stsb.yaml
@@ -1,3 +1,3 @@
-include: unitxt_tasks.regression.two_texts
 task: stsb
-dataset_name: card=cards.stsb,template=templates.regression.two_texts.simple
+include: unitxt
+recipe: card=cards.stsb,template=templates.regression.two_texts.simple