From fb74bde226434b66cc79fae213ca43d0e0c5ec37 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Sun, 1 Dec 2024 14:36:12 +0200
Subject: [PATCH] Start adding more tests for benchmark task
---
.../resources/test_benchmark_definition.py | 24 +++++++++++++++----
1 file changed, 20 insertions(+), 4 deletions(-)
diff --git a/tests/unit/amlb/resources/test_benchmark_definition.py b/tests/unit/amlb/resources/test_benchmark_definition.py
index c61d01e25..0cf913169 100644
--- a/tests/unit/amlb/resources/test_benchmark_definition.py
+++ b/tests/unit/amlb/resources/test_benchmark_definition.py
@@ -112,17 +112,33 @@ def test_add_task_defaults_looks_up_instance_type(amlb_dummy_configuration: Name
), "Should not overwrite explicit configuration"
-def test_benchmark_task(load_default_resources: Resources):
+def create_benchmark_task(resources: Resources, task: Namespace):
benchmark = Benchmark(
framework_name="constantpredictor",
benchmark_name="test",
constraint_name="test",
job_history=None,
)
- task = Namespace(name="foo")
- Resources._add_task_defaults(task, load_default_resources.config)
- benchmark_task = BenchmarkTask(
+ Resources._add_task_defaults(task, resources.config)
+ return BenchmarkTask(
benchmark=benchmark,
task_def=task,
fold=0,
)
+
+
+def test_benchmark_task_load_data_raises_if_no_dataset(load_default_resources):
+ task = Namespace(name="foo")
+ benchmark_task = create_benchmark_task(load_default_resources, task)
+
+ with pytest.raises(ValueError) as excinfo:
+ benchmark_task.load_data()
+ assert "should have one property" in excinfo.value.args[0]
+
+
+def test_benchmark_task_load_data(load_default_resources, mocker):
+ task = Namespace(name="foo", openml_task_id=42)
+ benchmark_task = create_benchmark_task(load_default_resources, task)
+
+ mocker.patch("amlb.benchmark.Benchmark.data_loader.load", return_value={})
+ benchmark_task.load_data()