From fb74bde226434b66cc79fae213ca43d0e0c5ec37 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 1 Dec 2024 14:36:12 +0200 Subject: [PATCH] Start adding more tests for benchmark task --- .../resources/test_benchmark_definition.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/unit/amlb/resources/test_benchmark_definition.py b/tests/unit/amlb/resources/test_benchmark_definition.py index c61d01e25..0cf913169 100644 --- a/tests/unit/amlb/resources/test_benchmark_definition.py +++ b/tests/unit/amlb/resources/test_benchmark_definition.py @@ -112,17 +112,33 @@ def test_add_task_defaults_looks_up_instance_type(amlb_dummy_configuration: Name ), "Should not overwrite explicit configuration" -def test_benchmark_task(load_default_resources: Resources): +def create_benchmark_task(resources: Resources, task: Namespace): benchmark = Benchmark( framework_name="constantpredictor", benchmark_name="test", constraint_name="test", job_history=None, ) - task = Namespace(name="foo") - Resources._add_task_defaults(task, load_default_resources.config) - benchmark_task = BenchmarkTask( + Resources._add_task_defaults(task, resources.config) + return BenchmarkTask( benchmark=benchmark, task_def=task, fold=0, ) + + +def test_benchmark_task_load_data_raises_if_no_dataset(load_default_resources): + task = Namespace(name="foo") + benchmark_task = create_benchmark_task(load_default_resources, task) + + with pytest.raises(ValueError) as excinfo: + benchmark_task.load_data() + assert "should have one property" in excinfo.value.args[0] + + +def test_benchmark_task_load_data(load_default_resources, mocker): + task = Namespace(name="foo", openml_task_id=42) + benchmark_task = create_benchmark_task(load_default_resources, task) + + mocker.patch("amlb.benchmark.Benchmark.data_loader.load", return_value={}) + benchmark_task.load_data()