From 9aa7479112998356349870533f0a6ea001f63304 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Wed, 29 Nov 2023 12:11:30 +0100
Subject: [PATCH] Move get qualities query to db module, remove level of nestedness

---
 src/database/datasets.py                   | 25 ++++++++++-
 src/routers/openml/datasets.py             | 25 +++++------
 .../openml/datasets_list_datasets_test.py  | 45 +++++++++----------
 3 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/src/database/datasets.py b/src/database/datasets.py
index 110fc70..ffabb7b 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,5 +1,6 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
-from typing import Any
+from collections import defaultdict
+from typing import Any, Iterable
 
 from schemas.datasets.openml import Quality
 from sqlalchemy import Connection, text
@@ -21,6 +22,28 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
     return [Quality(name=row.quality, value=row.value) for row in rows]
 
 
+def get_qualities_for_datasets(
+    dataset_ids: Iterable[int],
+    qualities: Iterable[str],
+    connection: Connection,
+) -> dict[int, list[Quality]]:
+    qualities_filter = ",".join(f"'{q}'" for q in qualities)
+    dids = ",".join(str(did) for did in dataset_ids)
+    qualities_query = text(
+        f"""
+        SELECT `data`, `quality`, `value`
+        FROM data_quality
+        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
+        """,  # nosec - similar to above, no user input
+    )
+    rows = connection.execute(qualities_query)
+    qualities_by_id = defaultdict(list)
+    for did, quality, value in rows:
+        if value is not None:
+            qualities_by_id[did].append(Quality(name=quality, value=value))
+    return dict(qualities_by_id)
+
+
 def list_all_qualities(connection: Connection) -> list[str]:
     # The current implementation only fetches *used* qualities, otherwise you should
     # query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index e435b67..08db9a1 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -6,7 +6,7 @@
 import re
 from datetime import datetime
 from enum import StrEnum
-from typing import Annotated, Any, Literal, NamedTuple
+from typing import Annotated, Any, NamedTuple
 
 from core.access import _user_has_access
 from core.errors import DatasetError
@@ -23,6 +23,7 @@
     get_latest_dataset_description,
     get_latest_processing_update,
     get_latest_status_update,
+    get_qualities_for_datasets,
     get_tags,
 )
 from database.datasets import tag_dataset as db_tag_dataset
@@ -106,7 +107,7 @@ def list_datasets(
     status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
     user: Annotated[User | None, Depends(fetch_user)] = None,
     expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
+) -> list[dict[str, Any]]:
     current_status = text(
         """
         SELECT ds1.`did`, ds1.`status`
@@ -235,20 +236,14 @@ def quality_clause(quality: str, range_: str | None) -> str:
         "NumberOfNumericFeatures",
         "NumberOfSymbolicFeatures",
     ]
-    qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
-    dids = ",".join(str(did) for did in datasets)
-    qualities = text(
-        f"""
-        SELECT `data`, `quality`, `value`
-        FROM data_quality
-        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
-        """,  # nosec - similar to above, no user input
+    qualities_by_dataset = get_qualities_for_datasets(
+        dataset_ids=datasets.keys(),
+        qualities=qualities_to_show,
+        connection=expdb_db,
     )
-    qualities = expdb_db.execute(qualities)
-    for did, quality, value in qualities:
-        if value is not None:
-            datasets[did]["quality"].append({"name": quality, "value": str(value)})
-    return {"data": {"dataset": list(datasets.values())}}
+    for did, qualities in qualities_by_dataset.items():
+        datasets[did]["quality"] = qualities
+    return list(datasets.values())
 
 
 class ProcessingInformation(NamedTuple):
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index 92f7055..c91960a 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -22,11 +22,7 @@ def _assert_empty_result(
 def test_list(py_api: TestClient) -> None:
     response = py_api.get("/datasets/list/")
     assert response.status_code == http.client.OK
-    assert "data" in response.json()
-    assert "dataset" in response.json()["data"]
-
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) >= 1
+    assert len(response.json()) >= 1
 
 
 @pytest.mark.parametrize(
@@ -44,8 +40,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
         json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
     )
     assert response.status_code == http.client.OK, response.json()
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == amount
+    assert len(response.json()) == amount
 
 
 @pytest.mark.parametrize(
@@ -64,8 +59,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
         json={"status": "all", "pagination": {"limit": 1000}},
    )
     assert response.status_code == http.client.OK, response.json()
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == amount
+    assert len(response.json()) == amount
 
 
 @pytest.mark.parametrize(
@@ -79,7 +73,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
         json={"status": "all", "data_name": name},
     )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
+    datasets = response.json()
     assert len(datasets) == count
     assert all(dataset["name"] == name for dataset in datasets)
 
@@ -96,10 +90,6 @@ def test_list_data_name_absent(name: str, py_api: TestClient) -> None:
     _assert_empty_result(response)
 
 
-def test_list_quality_filers() -> None:
-    pytest.skip("Not implemented")
-
-
 @pytest.mark.parametrize("limit", [None, 5, 10, 200])
 @pytest.mark.parametrize("offset", [None, 0, 5, 129, 130, 200])
 def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClient) -> None:
@@ -123,7 +113,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
         return
 
     assert response.status_code == http.client.OK
-    reported_ids = {dataset["did"] for dataset in response.json()["data"]["dataset"]}
+    reported_ids = {dataset["did"] for dataset in response.json()}
     assert reported_ids == set(expected_ids)
 
 
@@ -137,7 +127,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
         json={"status": "all", "data_version": version},
     )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
+    datasets = response.json()
     assert len(datasets) == count
     assert {dataset["version"] for dataset in datasets} == {version}
 
@@ -169,8 +159,7 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
         return
 
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == count
+    assert len(response.json()) == count
 
 
 @pytest.mark.parametrize(
@@ -184,9 +173,8 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
     )
     assert response.status_code == http.client.OK
 
-    datasets = response.json()["data"]["dataset"]
     private_or_not_exist = {130, 3000}
-    assert len(datasets) == len(set(data_id) - private_or_not_exist)
+    assert len(response.json()) == len(set(data_id) - private_or_not_exist)
 
 
 @pytest.mark.parametrize(
@@ -201,8 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
         json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
     )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == count
+    assert len(response.json()) == count
 
 
 def test_list_data_tag_empty(py_api: TestClient) -> None:
@@ -232,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
         json={"status": "all", quality: range_},
     )
     assert response.status_code == http.client.OK, response.json()
-    assert len(response.json()["data"]["dataset"]) == count
+    assert len(response.json()) == count
 
 
 @pytest.mark.php()
@@ -297,6 +284,14 @@ def test_list_data_identical(
     if original.status_code == http.client.PRECONDITION_FAILED:
         assert original.json()["error"] == response.json()["detail"]
         return None
-    assert len(original.json()["data"]["dataset"]) == len(response.json()["data"]["dataset"])
-    assert original.json()["data"]["dataset"] == response.json()["data"]["dataset"]
+    new_json = response.json()
+    # Qualities in new response are typed
+    for dataset in new_json:
+        for quality in dataset["quality"]:
+            quality["value"] = str(quality["value"])
+
+    # PHP API has a double nested dictionary that never has other entries
+    php_json = original.json()["data"]["dataset"]
+    assert len(php_json) == len(new_json)
+    assert php_json == new_json
     return None
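
Note (not part of the patch): a minimal, self-contained sketch of the grouping behaviour introduced by get_qualities_for_datasets, for reviewers who want to see the returned shape without a database. The Quality dataclass and the literal rows below are illustrative stand-ins for schemas.datasets.openml.Quality and the rows produced by the data_quality query.

# Sketch: group (data, quality, value) rows into {dataset id: [Quality, ...]},
# mirroring the helper added in src/database/datasets.py. Quality here is a
# stand-in dataclass; the hard-coded rows replace the SQL result set.
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Quality:
    name: str
    value: float


rows = [
    (1, "NumberOfInstances", 150.0),
    (1, "NumberOfFeatures", 5.0),
    (2, "NumberOfInstances", None),  # NULL values are skipped
]

qualities_by_id: defaultdict[int, list[Quality]] = defaultdict(list)
for did, quality, value in rows:
    if value is not None:
        qualities_by_id[did].append(Quality(name=quality, value=value))

print(dict(qualities_by_id))
# {1: [Quality(name='NumberOfInstances', value=150.0),
#      Quality(name='NumberOfFeatures', value=5.0)]}

The endpoint assigns these per-dataset lists directly onto each record and returns the flat list, which is why the tests now read response.json() instead of response.json()["data"]["dataset"].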