Tiny bug: oml_id must be an int in load_openml_task (#686)
* oml_id must be an int

* Catch error earlier to report better error messages

* Also give an early warning if a task id is defined as a str in a benchmark file

---------

Co-authored-by: PGijsbers <[email protected]>
SubhadityaMukherjee and PGijsbers authored Jan 2, 2025
1 parent b719142 commit 2c2a93d
Showing 4 changed files with 31 additions and 17 deletions.
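With this change, a malformed OpenML id fails while the benchmark name is being parsed, before any request is made to the OpenML server. A minimal repro of the new error (the id here is deliberately non-numeric):

from amlb.benchmarks.openml import load_oml_benchmark

load_oml_benchmark("openml/t/iris")
# ValueError: Could not convert OpenML id 'iris' in 'openml/t/iris' to integer.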
5 changes: 5 additions & 0 deletions amlb/benchmarks/file.py
@@ -35,4 +35,9 @@ def load_file_benchmark(
     log.info("Loading benchmark definitions from %s.", benchmark_file)
     tasks = config_load(benchmark_file)
     benchmark_name, _ = os.path.splitext(os.path.basename(benchmark_file))
+    for task in tasks:
+        if task["openml_task_id"] is not None and not isinstance(
+            task["openml_task_id"], int
+        ):
+            raise TypeError(f"OpenML task id for task {task.name!r} must be an integer.")
     return benchmark_name, benchmark_file, tasks
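A self-contained sketch of the check above, with plain dicts standing in for the Namespace objects that config_load actually returns (the task definition is made up for illustration):

# A parsed task definition must carry an int openml_task_id, or none at all.
tasks = [{"name": "iris", "openml_task_id": "59"}]  # id mistakenly quoted as a str

for task in tasks:
    task_id = task["openml_task_id"]
    if task_id is not None and not isinstance(task_id, int):
        raise TypeError(f"OpenML task id for task {task['name']!r} must be an integer.")
# TypeError: OpenML task id for task 'iris' must be an integer.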
33 changes: 22 additions & 11 deletions amlb/benchmarks/openml.py
@@ -6,6 +6,7 @@
 
 import openml
 import pandas as pd
+from openml import OpenMLTask, OpenMLDataset
 
 from amlb.utils import Namespace, str_sanitize
@@ -20,7 +21,13 @@ def is_openml_benchmark(benchmark: str) -> bool:
 
 def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
     """Loads benchmark defined by openml suite or task, from openml/s/X or openml/t/Y."""
-    domain, oml_type, oml_id = benchmark.split("/")
+    domain, oml_type, oml_id_str = benchmark.split("/")
+    try:
+        oml_id = int(oml_id_str)
+    except ValueError:
+        raise ValueError(
+            f"Could not convert OpenML id {oml_id_str!r} in {benchmark!r} to integer."
+        )
 
     if domain == "test.openml":
         log.debug("Setting openml server to the test server.")
@@ -34,7 +41,7 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
     openml.config.set_retry_policy("robot")
 
     if oml_type == "t":
-        tasks = load_openml_task(domain, oml_id)
+        tasks = load_openml_task_as_definition(domain, oml_id)
     elif oml_type == "s":
         tasks = load_openml_tasks_from_suite(domain, oml_id)
     else:
@@ -44,7 +51,7 @@ def load_oml_benchmark(benchmark: str) -> tuple[str, str | None, list[Namespace]]:
     return benchmark, None, tasks
 
 
-def load_openml_tasks_from_suite(domain: str, oml_id: str) -> list[Namespace]:
+def load_openml_tasks_from_suite(domain: str, oml_id: int) -> list[Namespace]:
     log.info("Loading openml suite %s.", oml_id)
     suite = openml.study.get_suite(oml_id)
     # Here we know the (task, dataset) pairs, so downloading only the dataset meta-data is sufficient
@@ -66,18 +73,22 @@ def load_openml_tasks_from_suite(domain: str, oml_id: str) -> list[Namespace]:
     return tasks
 
 
-def load_openml_task(domain: str, oml_id: str) -> list[Namespace]:
+def load_openml_task_as_definition(domain: str, oml_id: int) -> list[Namespace]:
     log.info("Loading openml task %s.", oml_id)
-    # We first have the retrieve the task because we don't know the dataset id
-    t = openml.tasks.get_task(oml_id, download_data=False, download_qualities=False)
-    data = openml.datasets.get_dataset(
-        t.dataset_id, download_data=False, download_qualities=False
-    )
+    task, data = load_openml_task_and_data(oml_id)
     return [
         Namespace(
             name=str_sanitize(data.name),
             description=data.description,
-            openml_task_id=t.id,
-            id="{}.org/t/{}".format(domain, t.id),
+            openml_task_id=task.id,
+            id="{}.org/t/{}".format(domain, task.id),
         )
     ]
+
+
+def load_openml_task_and_data(task_id: int) -> tuple[OpenMLTask, OpenMLDataset]:
+    task = openml.tasks.get_task(task_id, download_data=False, download_qualities=False)
+    data = openml.datasets.get_dataset(
+        task.dataset_id, download_data=False, download_qualities=False
+    )
+    return task, data
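The new load_openml_task_and_data helper centralizes the metadata-only fetch (download_data=False, download_qualities=False) that was previously duplicated across modules. A usage sketch, assuming the openml package is installed and the network is reachable; task 59 is just an example id:

from amlb.benchmarks.openml import load_openml_task_and_data

task, data = load_openml_task_and_data(59)
print(task.id, task.dataset_id, data.name)  # ids and name depend on the chosen task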
6 changes: 2 additions & 4 deletions amlb/datasets/openml.py
@@ -21,6 +21,7 @@
 import openml as oml
 import xmltodict
 
+from ..benchmarks.openml import load_openml_task_and_data
 from ..data import AM, DF, Dataset, DatasetType, Datasplit, Feature
 from ..datautils import impute_array
 from ..resources import config as rconfig, get as rget
@@ -71,10 +72,7 @@ def load(self, task_id=None, dataset_id=None, fold=0):
                     dataset_id, task_id
                 )
             )
-        task = oml.tasks.get_task(task_id, download_qualities=False)
-        dataset = oml.datasets.get_dataset(
-            task.dataset_id, download_qualities=False
-        )
+        task, dataset = load_openml_task_and_data(task_id)
         _, nfolds, _ = task.get_split_dimensions()
         if fold >= nfolds:
             raise ValueError(
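One behavioral detail worth noting: the shared helper passes download_data=False, which the replaced oml.tasks.get_task / oml.datasets.get_dataset calls here did not, so this loader now also starts from metadata only.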
4 changes: 2 additions & 2 deletions tests/unit/amlb/benchmarks/test_openml.py
@@ -3,7 +3,7 @@
 
 from amlb.benchmarks.openml import (
     is_openml_benchmark,
-    load_openml_task,
+    load_openml_task_as_definition,
     load_oml_benchmark,
 )
 from amlb.utils import Namespace
@@ -35,7 +35,7 @@ def test_load_openml_task(mocker, oml_task, oml_dataset):
     mocker.patch(
         "openml.datasets.get_dataset", new=mocker.Mock(return_value=oml_dataset)
     )
-    [task] = load_openml_task("openml", oml_task.id)
+    [task] = load_openml_task_as_definition("openml", oml_task.id)
     assert task.name == oml_dataset.name
     assert task.description == oml_dataset.description
     assert task.openml_task_id == oml_task.id
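A hypothetical companion test, not part of this commit, that exercises the new fail-fast path; no mocking is needed because the id is rejected before any OpenML call:

import pytest

from amlb.benchmarks.openml import load_oml_benchmark


def test_load_oml_benchmark_rejects_non_integer_id():
    with pytest.raises(ValueError, match="Could not convert OpenML id"):
        load_oml_benchmark("openml/t/iris")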
