Skip to content

Commit

Permalink
Move get qualities query to db module, remove a level of nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Nov 29, 2023
1 parent 15f1615 commit 9aa7479
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 41 deletions.
25 changes: 24 additions & 1 deletion src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
from typing import Any
from collections import defaultdict
from typing import Any, Iterable

from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text
Expand All @@ -21,6 +22,28 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
return [Quality(name=row.quality, value=row.value) for row in rows]


def get_qualities_for_datasets(
    dataset_ids: Iterable[int],
    qualities: Iterable[str],
    connection: Connection,
) -> dict[int, list[Quality]]:
    """Fetch the requested qualities for several datasets with a single query.

    Args:
        dataset_ids: Identifiers matched against the `data` column of
            `data_quality`.
        qualities: Quality names matched against the `quality` column.
        connection: Connection on which the query is executed.

    Returns:
        A mapping from dataset id to the `Quality` entries found for it.
        Rows whose value is NULL are skipped, and datasets without any
        remaining row do not appear in the mapping. An empty mapping is
        returned, without touching the database, when either `dataset_ids`
        or `qualities` is empty.
    """
    # Join eagerly: the arguments may be one-shot iterators, and we need to know
    # whether either filter is empty before building the statement.
    dids = ",".join(str(did) for did in dataset_ids)
    qualities_filter = ",".join(f"'{q}'" for q in qualities)
    if not dids or not qualities_filter:
        # An empty `IN ()` list is not valid SQL, and an empty filter could not
        # match any row anyway.
        return {}

    qualities_query = text(
        f"""
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
        """,  # nosec - values are interpolated; only call with trusted (non-user) input
    )
    rows = connection.execute(qualities_query)

    qualities_by_id: defaultdict[int, list[Quality]] = defaultdict(list)
    for did, quality, value in rows:
        if value is not None:
            qualities_by_id[did].append(Quality(name=quality, value=value))
    # Hand back a plain dict so lookups by callers never create entries implicitly.
    return dict(qualities_by_id)


def list_all_qualities(connection: Connection) -> list[str]:
# The current implementation only fetches *used* qualities, otherwise you should
# query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
Expand Down
25 changes: 10 additions & 15 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, NamedTuple
from typing import Annotated, Any, NamedTuple

from core.access import _user_has_access
from core.errors import DatasetError
Expand All @@ -23,6 +23,7 @@
get_latest_dataset_description,
get_latest_processing_update,
get_latest_status_update,
get_qualities_for_datasets,
get_tags,
)
from database.datasets import tag_dataset as db_tag_dataset
Expand Down Expand Up @@ -106,7 +107,7 @@ def list_datasets(
status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
user: Annotated[User | None, Depends(fetch_user)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
) -> list[dict[str, Any]]:
current_status = text(
"""
SELECT ds1.`did`, ds1.`status`
Expand Down Expand Up @@ -235,20 +236,14 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
dids = ",".join(str(did) for did in datasets)
qualities = text(
f"""
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - similar to above, no user input
qualities_by_dataset = get_qualities_for_datasets(
dataset_ids=datasets.keys(),
qualities=qualities_to_show,
connection=expdb_db,
)
qualities = expdb_db.execute(qualities)
for did, quality, value in qualities:
if value is not None:
datasets[did]["quality"].append({"name": quality, "value": str(value)})
return {"data": {"dataset": list(datasets.values())}}
for did, qualities in qualities_by_dataset.items():
datasets[did]["quality"] = qualities
return list(datasets.values())


class ProcessingInformation(NamedTuple):
Expand Down
45 changes: 20 additions & 25 deletions tests/routers/openml/datasets_list_datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ def _assert_empty_result(
def test_list(py_api: TestClient) -> None:
response = py_api.get("/datasets/list/")
assert response.status_code == http.client.OK
assert "data" in response.json()
assert "dataset" in response.json()["data"]

datasets = response.json()["data"]["dataset"]
assert len(datasets) >= 1
assert len(response.json()) >= 1


@pytest.mark.parametrize(
Expand All @@ -44,8 +40,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
)
assert response.status_code == http.client.OK, response.json()
datasets = response.json()["data"]["dataset"]
assert len(datasets) == amount
assert len(response.json()) == amount


@pytest.mark.parametrize(
Expand All @@ -64,8 +59,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
json={"status": "all", "pagination": {"limit": 1000}},
)
assert response.status_code == http.client.OK, response.json()
datasets = response.json()["data"]["dataset"]
assert len(datasets) == amount
assert len(response.json()) == amount


@pytest.mark.parametrize(
Expand All @@ -79,7 +73,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
json={"status": "all", "data_name": name},
)
assert response.status_code == http.client.OK
datasets = response.json()["data"]["dataset"]
datasets = response.json()
assert len(datasets) == count
assert all(dataset["name"] == name for dataset in datasets)

Expand All @@ -96,10 +90,6 @@ def test_list_data_name_absent(name: str, py_api: TestClient) -> None:
_assert_empty_result(response)


def test_list_quality_filers() -> None:
pytest.skip("Not implemented")


@pytest.mark.parametrize("limit", [None, 5, 10, 200])
@pytest.mark.parametrize("offset", [None, 0, 5, 129, 130, 200])
def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClient) -> None:
Expand All @@ -123,7 +113,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
return

assert response.status_code == http.client.OK
reported_ids = {dataset["did"] for dataset in response.json()["data"]["dataset"]}
reported_ids = {dataset["did"] for dataset in response.json()}
assert reported_ids == set(expected_ids)


Expand All @@ -137,7 +127,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
json={"status": "all", "data_version": version},
)
assert response.status_code == http.client.OK
datasets = response.json()["data"]["dataset"]
datasets = response.json()
assert len(datasets) == count
assert {dataset["version"] for dataset in datasets} == {version}

Expand Down Expand Up @@ -169,8 +159,7 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
return

assert response.status_code == http.client.OK
datasets = response.json()["data"]["dataset"]
assert len(datasets) == count
assert len(response.json()) == count


@pytest.mark.parametrize(
Expand All @@ -184,9 +173,8 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
)

assert response.status_code == http.client.OK
datasets = response.json()["data"]["dataset"]
private_or_not_exist = {130, 3000}
assert len(datasets) == len(set(data_id) - private_or_not_exist)
assert len(response.json()) == len(set(data_id) - private_or_not_exist)


@pytest.mark.parametrize(
Expand All @@ -201,8 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
)
assert response.status_code == http.client.OK
datasets = response.json()["data"]["dataset"]
assert len(datasets) == count
assert len(response.json()) == count


def test_list_data_tag_empty(py_api: TestClient) -> None:
Expand Down Expand Up @@ -232,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
json={"status": "all", quality: range_},
)
assert response.status_code == http.client.OK, response.json()
assert len(response.json()["data"]["dataset"]) == count
assert len(response.json()) == count


@pytest.mark.php()
Expand Down Expand Up @@ -297,6 +284,14 @@ def test_list_data_identical(
if original.status_code == http.client.PRECONDITION_FAILED:
assert original.json()["error"] == response.json()["detail"]
return None
assert len(original.json()["data"]["dataset"]) == len(response.json()["data"]["dataset"])
assert original.json()["data"]["dataset"] == response.json()["data"]["dataset"]
new_json = response.json()
# Qualities in new response are typed
for dataset in new_json:
for quality in dataset["quality"]:
quality["value"] = str(quality["value"])

# PHP API has a doubly nested dictionary that never has other entries
php_json = original.json()["data"]["dataset"]
assert len(php_json) == len(new_json)
assert php_json == new_json
return None

0 comments on commit 9aa7479

Please sign in to comment.