From 319fe5e6ab7612a8f34840cffc10ee9051747ac3 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Wed, 29 Nov 2023 13:55:15 +0100
Subject: [PATCH] Improve re-usability of subcomponents (#116)
* Move auxiliary functions and types to separate files
* Refactor the shared logic for checking user access to datasets
* Move some queries to the db module
* Remove a level of nesting in the API response (illustrated below)
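
The last point flattens the list endpoint's JSON body; schematically, with the
field set abbreviated (a sketch, not the exact payload):

    # Before: the result was nested under two constant keys
    {"data": {"dataset": [{"did": 1, "name": "anneal"}]}}
    # After: the endpoint returns the list directly
    [{"did": 1, "name": "anneal"}]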
---
src/core/access.py | 15 ++
src/core/errors.py | 7 +
src/core/formatting.py | 40 ++++++
src/database/datasets.py | 40 +++++-
src/routers/openml/datasets.py | 129 +++++-------------
src/routers/openml/qualities.py | 33 ++---
.../openml/datasets_list_datasets_test.py | 45 +++---
7 files changed, 164 insertions(+), 145 deletions(-)
create mode 100644 src/core/access.py
create mode 100644 src/core/errors.py
create mode 100644 src/core/formatting.py
diff --git a/src/core/access.py b/src/core/access.py
new file mode 100644
index 0000000..90f1676
--- /dev/null
+++ b/src/core/access.py
@@ -0,0 +1,15 @@
+from typing import Any
+
+from database.users import User, UserGroup
+from schemas.datasets.openml import Visibility
+
+
+def _user_has_access(
+ dataset: dict[str, Any],
+ user: User | None = None,
+) -> bool:
+ """Determine if `user` has the right to view `dataset`."""
+ is_public = dataset["visibility"] == Visibility.PUBLIC
+ return is_public or (
+ user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
+ )
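
A minimal usage sketch of the relocated access check; the `User` constructor
arguments and the `Visibility.PRIVATE` member are assumed here for illustration:

    from core.access import _user_has_access
    from database.users import User
    from schemas.datasets.openml import Visibility

    # Public datasets are visible without authentication
    assert _user_has_access({"visibility": Visibility.PUBLIC, "uploader": 2})
    # Private datasets require the uploader (or an admin)
    private = {"visibility": Visibility.PRIVATE, "uploader": 2}
    assert not _user_has_access(private)
    assert _user_has_access(private, user=User(user_id=2, groups=[]))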
diff --git a/src/core/errors.py b/src/core/errors.py
new file mode 100644
index 0000000..840cd75
--- /dev/null
+++ b/src/core/errors.py
@@ -0,0 +1,7 @@
+from enum import IntEnum
+
+
+class DatasetError(IntEnum):
+ NOT_FOUND = 111
+ NO_ACCESS = 112
+ NO_DATA_FILE = 113
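
The integer values mirror the PHP API's error numbering. Combined with
`_format_error` from the new formatting module, they produce bodies such as the
following (assuming Python 3.11+, where `str()` on an `IntEnum` member yields
the bare number):

    from core.errors import DatasetError
    from core.formatting import _format_error

    _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
    # -> {"code": "111", "message": "Unknown dataset"}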
diff --git a/src/core/formatting.py b/src/core/formatting.py
new file mode 100644
index 0000000..819080c
--- /dev/null
+++ b/src/core/formatting.py
@@ -0,0 +1,40 @@
+import html
+from typing import Any
+
+from schemas.datasets.openml import DatasetFileFormat
+
+from core.errors import DatasetError
+
+
+def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
+ """Formatter for JSON bodies of OpenML error codes."""
+ return {"code": str(code), "message": message}
+
+
+def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
+ if dataset["format"].lower() != DatasetFileFormat.ARFF:
+ return None
+
+ minio_base_url = "https://openml1.win.tue.nl"
+ return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
+
+
+def _format_dataset_url(dataset: dict[str, Any]) -> str:
+ base_url = "https://test.openml.org"
+ filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
+ return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
+
+
+def _safe_unquote(text: str | None) -> str | None:
+ """Remove any open and closing quotes and return the remainder if non-empty."""
+ if not text:
+ return None
+ return text.strip("'\"") or None
+
+
+def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
+ """Return comma-separated values in `text` as list, optionally remove quotes."""
+ if not text:
+ return []
+ chars_to_strip = "'\"\t " if unquote_items else "\t "
+ return [item.strip(chars_to_strip) for item in text.split(",")]
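
For reference, the intended behaviour of the string helpers; the input values
are illustrative:

    from core.formatting import _csv_as_list, _safe_unquote

    assert _safe_unquote("'anneal'") == "anneal"
    assert _safe_unquote("") is None
    assert _csv_as_list("'a', 'b' ,c") == ["a", "b", "c"]
    assert _csv_as_list("'a',b", unquote_items=False) == ["'a'", "b"]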
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 62e7df9..ffabb7b 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,11 +1,49 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
-from typing import Any
+from collections import defaultdict
+from typing import Any, Iterable
+from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text
from database.meta import get_column_names
+def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
+ rows = connection.execute(
+ text(
+ """
+ SELECT `quality`,`value`
+ FROM data_quality
+ WHERE `data`=:dataset_id
+ """,
+ ),
+ parameters={"dataset_id": dataset_id},
+ )
+ return [Quality(name=row.quality, value=row.value) for row in rows]
+
+
+def get_qualities_for_datasets(
+ dataset_ids: Iterable[int],
+ qualities: Iterable[str],
+ connection: Connection,
+) -> dict[int, list[Quality]]:
+ qualities_filter = ",".join(f"'{q}'" for q in qualities)
+ dids = ",".join(str(did) for did in dataset_ids)
+ qualities_query = text(
+ f"""
+ SELECT `data`, `quality`, `value`
+ FROM data_quality
+ WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
+ """, # nosec - similar to above, no user input
+ )
+ rows = connection.execute(qualities_query)
+ qualities_by_id = defaultdict(list)
+ for did, quality, value in rows:
+ if value is not None:
+ qualities_by_id[did].append(Quality(name=quality, value=value))
+ return dict(qualities_by_id)
+
+
def list_all_qualities(connection: Connection) -> list[str]:
# The current implementation only fetches *used* qualities, otherwise you should
# query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
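
Both helpers return typed `Quality` records; a rough usage sketch, with the
connection setup elided:

    from database.datasets import (
        get_qualities_for_dataset,
        get_qualities_for_datasets,
    )

    qualities = get_qualities_for_dataset(1, connection)
    # Batched variant, keyed by dataset id; datasets without rows are absent
    by_id = get_qualities_for_datasets([1, 2], ["NumberOfInstances"], connection)
    for did, dataset_qualities in by_id.items():
        print(did, [(q.name, q.value) for q in dataset_qualities])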
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 65c18a6..08db9a1 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -2,25 +2,35 @@
We add separate endpoints for old-style JSON responses, so they don't clutter the schema of the
new API, and are easily removed later.
"""
-import html
import http.client
-from collections import namedtuple
-from enum import IntEnum, StrEnum
-from typing import Annotated, Any, Literal
-
+import re
+from datetime import datetime
+from enum import StrEnum
+from typing import Annotated, Any, NamedTuple
+
+from core.access import _user_has_access
+from core.errors import DatasetError
+from core.formatting import (
+ _csv_as_list,
+ _format_dataset_url,
+ _format_error,
+ _format_parquet_url,
+ _safe_unquote,
+)
from database.datasets import get_dataset as db_get_dataset
from database.datasets import (
get_file,
get_latest_dataset_description,
get_latest_processing_update,
get_latest_status_update,
+ get_qualities_for_datasets,
get_tags,
)
from database.datasets import tag_dataset as db_tag_dataset
-from database.users import APIKey, User, UserGroup, get_user_groups_for, get_user_id_for
+from database.users import User, UserGroup
from fastapi import APIRouter, Body, Depends, HTTPException
-from schemas.datasets.openml import DatasetFileFormat, DatasetMetadata, DatasetStatus, Visibility
-from sqlalchemy import Connection
+from schemas.datasets.openml import DatasetMetadata, DatasetStatus
+from sqlalchemy import Connection, text
from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
@@ -97,9 +107,7 @@ def list_datasets(
status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
user: Annotated[User | None, Depends(fetch_user)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
- from sqlalchemy import text
-
+) -> list[dict[str, Any]]:
current_status = text(
"""
SELECT ds1.`did`, ds1.`status`
@@ -150,8 +158,6 @@ def list_datasets(
else ""
)
- import re
-
def quality_clause(quality: str, range_: str | None) -> str:
if not range_:
return ""
@@ -230,96 +236,31 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
- qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
- dids = ",".join(str(did) for did in datasets)
- qualities = text(
- f"""
- SELECT `data`, `quality`, `value`
- FROM data_quality
- WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
- """, # nosec - similar to above, no user input
+ qualities_by_dataset = get_qualities_for_datasets(
+ dataset_ids=datasets.keys(),
+ qualities=qualities_to_show,
+ connection=expdb_db,
)
- qualities = expdb_db.execute(qualities)
- for did, quality, value in qualities:
- if value is not None:
- datasets[did]["quality"].append({"name": quality, "value": str(value)})
- return {"data": {"dataset": list(datasets.values())}}
-
+ for did, qualities in qualities_by_dataset.items():
+ datasets[did]["quality"] = qualities
+ return list(datasets.values())
-class DatasetError(IntEnum):
- NOT_FOUND = 111
- NO_ACCESS = 112
- NO_DATA_FILE = 113
+class ProcessingInformation(NamedTuple):
+ date: datetime | None
+ warning: str | None
+ error: str | None
-processing_info = namedtuple("processing_info", ["date", "warning", "error"])
-
-def _get_processing_information(dataset_id: int, connection: Connection) -> processing_info:
+def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation:
"""Return processing information, if any. Otherwise, all fields `None`."""
if not (data_processed := get_latest_processing_update(dataset_id, connection)):
- return processing_info(date=None, warning=None, error=None)
+ return ProcessingInformation(date=None, warning=None, error=None)
date_processed = data_processed["processing_date"]
warning = data_processed["warning"].strip() if data_processed["warning"] else None
error = data_processed["error"].strip() if data_processed["error"] else None
- return processing_info(date=date_processed, warning=warning, error=error)
-
-
-def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
- """Formatter for JSON bodies of OpenML error codes."""
- return {"code": str(code), "message": message}
-
-
-def _user_has_access(
- dataset: dict[str, Any],
- connection: Connection,
- api_key: APIKey | None = None,
-) -> bool:
- """Determine if user of `api_key` has the right to view `dataset`."""
- if dataset["visibility"] == Visibility.PUBLIC:
- return True
- if not api_key:
- return False
-
- if not (user_id := get_user_id_for(api_key=api_key, connection=connection)):
- return False
-
- if user_id == dataset["uploader"]:
- return True
-
- user_groups = get_user_groups_for(user_id=user_id, connection=connection)
- ADMIN_GROUP = 1
- return ADMIN_GROUP in user_groups
-
-
-def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
- if dataset["format"].lower() != DatasetFileFormat.ARFF:
- return None
-
- minio_base_url = "https://openml1.win.tue.nl"
- return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
-
-
-def _format_dataset_url(dataset: dict[str, Any]) -> str:
- base_url = "https://test.openml.org"
- filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
- return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
-
-
-def _safe_unquote(text: str | None) -> str | None:
- """Remove any open and closing quotes and return the remainder if non-empty."""
- if not text:
- return None
- return text.strip("'\"") or None
-
-
-def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
- """Return comma-separated values in `text` as list, optionally remove quotes."""
- if not text:
- return []
- chars_to_strip = "'\"\t " if unquote_items else "\t "
- return [item.strip(chars_to_strip) for item in text.split(",")]
+ return ProcessingInformation(date=date_processed, warning=warning, error=error)
@router.get(
@@ -328,7 +269,7 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
)
def get_dataset(
dataset_id: int,
- api_key: APIKey | None = None,
+ user: Annotated[User | None, Depends(fetch_user)] = None,
user_db: Annotated[Connection, Depends(userdb_connection)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> DatasetMetadata:
@@ -336,7 +277,7 @@ def get_dataset(
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
- if not _user_has_access(dataset=dataset, connection=user_db, api_key=api_key):
+ if not _user_has_access(dataset=dataset, user=user):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
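
Client code now consumes the flattened response directly; a sketch against the
test client used elsewhere in this patch (`py_api` is the test fixture):

    datasets = py_api.get("/datasets/list/").json()
    assert isinstance(datasets, list)
    # Each entry carries its qualities under the "quality" key
    quality_names = {q["name"] for d in datasets for q in d["quality"]}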
diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py
index 569f5d0..4231d5a 100644
--- a/src/routers/openml/qualities.py
+++ b/src/routers/openml/qualities.py
@@ -1,14 +1,15 @@
import http.client
-from typing import Annotated, Any, Literal
+from typing import Annotated, Literal
-from database.datasets import get_dataset, list_all_qualities
-from database.users import User, UserGroup
+from core.access import _user_has_access
+from core.errors import DatasetError
+from database.datasets import get_dataset, get_qualities_for_dataset, list_all_qualities
+from database.users import User
from fastapi import APIRouter, Depends, HTTPException
from schemas.datasets.openml import Quality
-from sqlalchemy import Connection, text
+from sqlalchemy import Connection
from routers.dependencies import expdb_connection, fetch_user
-from routers.openml.datasets import DatasetError
router = APIRouter(prefix="/datasets", tags=["datasets"])
@@ -25,14 +26,6 @@ def list_qualities(
}
-def _user_can_see_dataset(dataset: dict[str, Any], user: User) -> bool:
- if dataset["visibility"] == "public":
- return True
- return user is not None and (
- dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups
- )
-
-
@router.get("/qualities/{dataset_id}")
def get_qualities(
dataset_id: int,
@@ -40,22 +33,12 @@ def get_qualities(
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> list[Quality]:
dataset = get_dataset(dataset_id, expdb)
- if not dataset or not _user_can_see_dataset(dataset, user):
+ if not dataset or not _user_has_access(dataset, user):
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"},
) from None
- rows = expdb.execute(
- text(
- """
- SELECT `quality`,`value`
- FROM data_quality
- WHERE `data`=:dataset_id
- """,
- ),
- parameters={"dataset_id": dataset_id},
- )
- return [Quality(name=row.quality, value=row.value) for row in rows]
+ return get_qualities_for_dataset(dataset_id, expdb)
# The PHP API provided (sometimes) helpful error messages
# if not qualities:
# check if dataset exists: error 360
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index 92f7055..c91960a 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -22,11 +22,7 @@ def _assert_empty_result(
def test_list(py_api: TestClient) -> None:
response = py_api.get("/datasets/list/")
assert response.status_code == http.client.OK
- assert "data" in response.json()
- assert "dataset" in response.json()["data"]
-
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) >= 1
+ assert len(response.json()) >= 1
@pytest.mark.parametrize(
@@ -44,8 +40,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
)
assert response.status_code == http.client.OK, response.json()
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == amount
+ assert len(response.json()) == amount
@pytest.mark.parametrize(
@@ -64,8 +59,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
json={"status": "all", "pagination": {"limit": 1000}},
)
assert response.status_code == http.client.OK, response.json()
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == amount
+ assert len(response.json()) == amount
@pytest.mark.parametrize(
@@ -79,7 +73,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
json={"status": "all", "data_name": name},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
+ datasets = response.json()
assert len(datasets) == count
assert all(dataset["name"] == name for dataset in datasets)
@@ -96,10 +90,6 @@ def test_list_data_name_absent(name: str, py_api: TestClient) -> None:
_assert_empty_result(response)
-def test_list_quality_filers() -> None:
- pytest.skip("Not implemented")
-
-
@pytest.mark.parametrize("limit", [None, 5, 10, 200])
@pytest.mark.parametrize("offset", [None, 0, 5, 129, 130, 200])
def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClient) -> None:
@@ -123,7 +113,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
return
assert response.status_code == http.client.OK
- reported_ids = {dataset["did"] for dataset in response.json()["data"]["dataset"]}
+ reported_ids = {dataset["did"] for dataset in response.json()}
assert reported_ids == set(expected_ids)
@@ -137,7 +127,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
json={"status": "all", "data_version": version},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
+ datasets = response.json()
assert len(datasets) == count
assert {dataset["version"] for dataset in datasets} == {version}
@@ -169,8 +159,7 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
return
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == count
+ assert len(response.json()) == count
@pytest.mark.parametrize(
@@ -184,9 +173,8 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
private_or_not_exist = {130, 3000}
- assert len(datasets) == len(set(data_id) - private_or_not_exist)
+ assert len(response.json()) == len(set(data_id) - private_or_not_exist)
@pytest.mark.parametrize(
@@ -201,8 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
)
assert response.status_code == http.client.OK
- datasets = response.json()["data"]["dataset"]
- assert len(datasets) == count
+ assert len(response.json()) == count
def test_list_data_tag_empty(py_api: TestClient) -> None:
@@ -232,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
json={"status": "all", quality: range_},
)
assert response.status_code == http.client.OK, response.json()
- assert len(response.json()["data"]["dataset"]) == count
+ assert len(response.json()) == count
@pytest.mark.php()
@@ -297,6 +284,14 @@ def test_list_data_identical(
if original.status_code == http.client.PRECONDITION_FAILED:
assert original.json()["error"] == response.json()["detail"]
return None
- assert len(original.json()["data"]["dataset"]) == len(response.json()["data"]["dataset"])
- assert original.json()["data"]["dataset"] == response.json()["data"]["dataset"]
+ new_json = response.json()
+ # Quality values in the new response are typed; stringify them to match the PHP payload
+ for dataset in new_json:
+ for quality in dataset["quality"]:
+ quality["value"] = str(quality["value"])
+
+ # The PHP API wraps the list in a doubly nested dictionary that never has other entries
+ php_json = original.json()["data"]["dataset"]
+ assert len(php_json) == len(new_json)
+ assert php_json == new_json
return None