From 9aa7479112998356349870533f0a6ea001f63304 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Wed, 29 Nov 2023 12:11:30 +0100
Subject: [PATCH] Move get qualities query to db module, remove a level of
 nesting
---
src/database/datasets.py | 25 ++++++++++-
src/routers/openml/datasets.py | 25 +++++------
.../openml/datasets_list_datasets_test.py | 45 +++++++++----------
3 files changed, 54 insertions(+), 41 deletions(-)
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 110fc70..ffabb7b 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,5 +1,6 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
-from typing import Any
+from collections import defaultdict
+from typing import Any, Iterable
from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text
@@ -21,6 +22,28 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
return [Quality(name=row.quality, value=row.value) for row in rows]
+def get_qualities_for_datasets(
+ dataset_ids: Iterable[int],
+ qualities: Iterable[str],
+ connection: Connection,
+) -> dict[int, list[Quality]]:
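+    """Return the requested qualities grouped by dataset id, skipping NULL values."""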
+ qualities_filter = ",".join(f"'{q}'" for q in qualities)
+ dids = ",".join(str(did) for did in dataset_ids)
+ qualities_query = text(
+ f"""
+ SELECT `data`, `quality`, `value`
+ FROM data_quality
+ WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
+ """, # nosec - similar to above, no user input
+ )
+ rows = connection.execute(qualities_query)
+ qualities_by_id = defaultdict(list)
+ for did, quality, value in rows:
+ if value is not None:
+ qualities_by_id[did].append(Quality(name=quality, value=value))
+ return dict(qualities_by_id)
+
+
def list_all_qualities(connection: Connection) -> list[str]:
# The current implementation only fetches *used* qualities, otherwise you should
# query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
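
A minimal usage sketch of the new helper (names taken from the hunk above; an open SQLAlchemy `connection` is assumed):

    from database.datasets import get_qualities_for_datasets

    # Group the selected qualities per dataset id; datasets without any
    # non-NULL value for these qualities are simply absent from the result.
    qualities_by_id = get_qualities_for_datasets(
        dataset_ids=[1, 2],
        qualities=["NumberOfNumericFeatures", "NumberOfSymbolicFeatures"],
        connection=connection,
    )
    # e.g. {1: [Quality(name="NumberOfNumericFeatures", value=...), ...], 2: [...]}
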
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index e435b67..08db9a1 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -6,7 +6,7 @@
import re
from datetime import datetime
from enum import StrEnum
-from typing import Annotated, Any, Literal, NamedTuple
+from typing import Annotated, Any, NamedTuple
from core.access import _user_has_access
from core.errors import DatasetError
@@ -23,6 +23,7 @@
get_latest_dataset_description,
get_latest_processing_update,
get_latest_status_update,
+ get_qualities_for_datasets,
get_tags,
)
from database.datasets import tag_dataset as db_tag_dataset
@@ -106,7 +107,7 @@ def list_datasets(
status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
user: Annotated[User | None, Depends(fetch_user)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
+) -> list[dict[str, Any]]:
current_status = text(
"""
SELECT ds1.`did`, ds1.`status`
@@ -235,20 +236,14 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
- qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
- dids = ",".join(str(did) for did in datasets)
- qualities = text(
- f"""
- SELECT `data`, `quality`, `value`
- FROM data_quality
- WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
- """, # nosec - similar to above, no user input
+ qualities_by_dataset = get_qualities_for_datasets(
+ dataset_ids=datasets.keys(),
+ qualities=qualities_to_show,
+ connection=expdb_db,
)
- qualities = expdb_db.execute(qualities)
- for did, quality, value in qualities:
- if value is not None:
- datasets[did]["quality"].append({"name": quality, "value": str(value)})
- return {"data": {"dataset": list(datasets.values())}}
+ for did, qualities in qualities_by_dataset.items():
+ datasets[did]["quality"] = qualities
+ return list(datasets.values())
class ProcessingInformation(NamedTuple):
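
A short client-side sketch of the flattened response, assuming the `py_api` TestClient fixture used in the tests below; the endpoint now returns the dataset list directly instead of the PHP-style `{"data": {"dataset": [...]}}` wrapper:

    # Old shape: response.json()["data"]["dataset"]  -> list of dataset dicts
    # New shape: response.json()                     -> list of dataset dicts
    datasets = py_api.get("/datasets/list/").json()
    assert isinstance(datasets, list)
    for dataset in datasets:
        # "quality" holds the serialized Quality entries merged in by the router
        print(dataset["did"], dataset["quality"])
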
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index 92f7055..c91960a 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -22,11 +22,7 @@ def _assert_empty_result(
def test_list(py_api: TestClient) -> None:
response = py_api.get("/datasets/list/")
assert response.status_code == http.client.OK
- assert "data" in response.json()
- assert "dataset" in response.json()["data"]
-
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) >= 1
+ assert len(response.json()) >= 1
@pytest.mark.parametrize(
@@ -44,8 +40,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
)
assert response.status_code == http.client.OK, response.json()
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == amount
+ assert len(response.json()) == amount
@pytest.mark.parametrize(
@@ -64,8 +59,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
json={"status": "all", "pagination": {"limit": 1000}},
)
assert response.status_code == http.client.OK, response.json()
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == amount
+ assert len(response.json()) == amount
@pytest.mark.parametrize(
@@ -79,7 +73,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
json={"status": "all", "data_name": name},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
+ datasets = response.json()
assert len(datasets) == count
assert all(dataset["name"] == name for dataset in datasets)
@@ -96,10 +90,6 @@ def test_list_data_name_absent(name: str, py_api: TestClient) -> None:
_assert_empty_result(response)
-def test_list_quality_filers() -> None:
- pytest.skip("Not implemented")
-
-
@pytest.mark.parametrize("limit", [None, 5, 10, 200])
@pytest.mark.parametrize("offset", [None, 0, 5, 129, 130, 200])
def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClient) -> None:
@@ -123,7 +113,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
return
assert response.status_code == http.client.OK
- reported_ids = {dataset["did"] for dataset in response.json()["data"]["dataset"]}
+ reported_ids = {dataset["did"] for dataset in response.json()}
assert reported_ids == set(expected_ids)
@@ -137,7 +127,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
json={"status": "all", "data_version": version},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
+ datasets = response.json()
assert len(datasets) == count
assert {dataset["version"] for dataset in datasets} == {version}
@@ -169,8 +159,7 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
return
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == count
+ assert len(response.json()) == count
@pytest.mark.parametrize(
@@ -184,9 +173,8 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
private_or_not_exist = {130, 3000}
- assert len(datasets) == len(set(data_id) - private_or_not_exist)
+ assert len(response.json()) == len(set(data_id) - private_or_not_exist)
@pytest.mark.parametrize(
@@ -201,8 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == count
+ assert len(response.json()) == count
def test_list_data_tag_empty(py_api: TestClient) -> None:
@@ -232,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
json={"status": "all", quality: range_},
)
assert response.status_code == http.client.OK, response.json()
- assert len(response.json()["data"]["dataset"]) == count
+ assert len(response.json()) == count
@pytest.mark.php()
@@ -297,6 +284,14 @@ def test_list_data_identical(
if original.status_code == http.client.PRECONDITION_FAILED:
assert original.json()["error"] == response.json()["detail"]
return None
- assert len(original.json()["data"]["dataset"]) == len(response.json()["data"]["dataset"])
- assert original.json()["data"]["dataset"] == response.json()["data"]["dataset"]
+ new_json = response.json()
+    # Quality values in the new response are typed, not strings
+ for dataset in new_json:
+ for quality in dataset["quality"]:
+ quality["value"] = str(quality["value"])
+
+    # The PHP API wraps the list in a doubly nested dictionary that never has other entries
+ php_json = original.json()["data"]["dataset"]
+ assert len(php_json) == len(new_json)
+ assert php_json == new_json
return None