From 319fe5e6ab7612a8f34840cffc10ee9051747ac3 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Wed, 29 Nov 2023 13:55:15 +0100
Subject: [PATCH] Improve re-usability of subcomponents (#116)

* Move auxiliary functions and types to separate files
* Refactor shared logic for dataset access for users
* Move some queries to db module
* Remove a level of nesting in the API response
---
 src/core/access.py                            |  15 ++
 src/core/errors.py                            |   7 +
 src/core/formatting.py                        |  40 ++++++
 src/database/datasets.py                      |  40 +++++-
 src/routers/openml/datasets.py                | 129 +++++-------------
 src/routers/openml/qualities.py               |  33 ++---
 .../openml/datasets_list_datasets_test.py     |  45 +++---
 7 files changed, 164 insertions(+), 145 deletions(-)
 create mode 100644 src/core/access.py
 create mode 100644 src/core/errors.py
 create mode 100644 src/core/formatting.py

diff --git a/src/core/access.py b/src/core/access.py
new file mode 100644
index 0000000..90f1676
--- /dev/null
+++ b/src/core/access.py
@@ -0,0 +1,15 @@
+from typing import Any
+
+from database.users import User, UserGroup
+from schemas.datasets.openml import Visibility
+
+
+def _user_has_access(
+    dataset: dict[str, Any],
+    user: User | None = None,
+) -> bool:
+    """Determine if `user` has the right to view `dataset`."""
+    is_public = dataset["visibility"] == Visibility.PUBLIC
+    return is_public or (
+        user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
+    )

diff --git a/src/core/errors.py b/src/core/errors.py
new file mode 100644
index 0000000..840cd75
--- /dev/null
+++ b/src/core/errors.py
@@ -0,0 +1,7 @@
+from enum import IntEnum
+
+
+class DatasetError(IntEnum):
+    NOT_FOUND = 111
+    NO_ACCESS = 112
+    NO_DATA_FILE = 113

diff --git a/src/core/formatting.py b/src/core/formatting.py
new file mode 100644
index 0000000..819080c
--- /dev/null
+++ b/src/core/formatting.py
@@ -0,0 +1,40 @@
+import html
+from typing import Any
+
+from schemas.datasets.openml import DatasetFileFormat
+
+from core.errors import DatasetError
+
+
+def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
+    """Formatter for JSON bodies of OpenML error codes."""
+    return {"code": str(code), "message": message}
+
+
+def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
+    if dataset["format"].lower() != DatasetFileFormat.ARFF:
+        return None
+
+    minio_base_url = "https://openml1.win.tue.nl"
+    return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
+
+
+def _format_dataset_url(dataset: dict[str, Any]) -> str:
+    base_url = "https://test.openml.org"
+    filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
+    return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
+
+
+def _safe_unquote(text: str | None) -> str | None:
+    """Remove any opening and closing quotes and return the remainder, if non-empty."""
+    if not text:
+        return None
+    return text.strip("'\"") or None
+
+
+def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
+    """Return the comma-separated values in `text` as a list, optionally removing quotes."""
+    if not text:
+        return []
+    chars_to_strip = "'\"\t " if unquote_items else "\t "
+    return [item.strip(chars_to_strip) for item in text.split(",")]

diff --git a/src/database/datasets.py b/src/database/datasets.py
index 62e7df9..ffabb7b 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,11 +1,49 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
-from typing import Any
+from collections import defaultdict
+from typing import Any, Iterable
 
+from schemas.datasets.openml import Quality
 from sqlalchemy import Connection, text
 
 from database.meta import get_column_names
 
 
+def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
+    rows = connection.execute(
+        text(
+            """
+            SELECT `quality`,`value`
+            FROM data_quality
+            WHERE `data`=:dataset_id
+            """,
+        ),
+        parameters={"dataset_id": dataset_id},
+    )
+    return [Quality(name=row.quality, value=row.value) for row in rows]
+
+
+def get_qualities_for_datasets(
+    dataset_ids: Iterable[int],
+    qualities: Iterable[str],
+    connection: Connection,
+) -> dict[int, list[Quality]]:
+    qualities_filter = ",".join(f"'{q}'" for q in qualities)
+    dids = ",".join(str(did) for did in dataset_ids)
+    qualities_query = text(
+        f"""
+        SELECT `data`, `quality`, `value`
+        FROM data_quality
+        WHERE `data` IN ({dids}) AND `quality` IN ({qualities_filter})
+        """,  # nosec - ids are integers and quality names are server-defined, no user input
+    )
+    rows = connection.execute(qualities_query)
+    qualities_by_id = defaultdict(list)
+    for did, quality, value in rows:
+        if value is not None:
+            qualities_by_id[did].append(Quality(name=quality, value=value))
+    return dict(qualities_by_id)
+
+
 def list_all_qualities(connection: Connection) -> list[str]:
     # The current implementation only fetches *used* qualities, otherwise you should
     # query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'

diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 65c18a6..08db9a1 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -2,25 +2,35 @@
 We add separate endpoints for old-style JSON responses,
 so they don't clutter the schema of the new API, and are easily removed later.
 """
-import html
 import http.client
-from collections import namedtuple
-from enum import IntEnum, StrEnum
-from typing import Annotated, Any, Literal
-
+import re
+from datetime import datetime
+from enum import StrEnum
+from typing import Annotated, Any, NamedTuple
+
+from core.access import _user_has_access
+from core.errors import DatasetError
+from core.formatting import (
+    _csv_as_list,
+    _format_dataset_url,
+    _format_error,
+    _format_parquet_url,
+    _safe_unquote,
+)
 from database.datasets import get_dataset as db_get_dataset
 from database.datasets import (
     get_file,
     get_latest_dataset_description,
     get_latest_processing_update,
     get_latest_status_update,
+    get_qualities_for_datasets,
     get_tags,
 )
 from database.datasets import tag_dataset as db_tag_dataset
-from database.users import APIKey, User, UserGroup, get_user_groups_for, get_user_id_for
+from database.users import User, UserGroup
 from fastapi import APIRouter, Body, Depends, HTTPException
-from schemas.datasets.openml import DatasetFileFormat, DatasetMetadata, DatasetStatus, Visibility
-from sqlalchemy import Connection
+from schemas.datasets.openml import DatasetMetadata, DatasetStatus
+from sqlalchemy import Connection, text
 
 from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
 from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
@@ -97,9 +107,7 @@ def list_datasets(
     status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
     user: Annotated[User | None, Depends(fetch_user)] = None,
     expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
-    from sqlalchemy import text
-
+) -> list[dict[str, Any]]:
     current_status = text(
         """
         SELECT ds1.`did`, ds1.`status`
@@ -150,8 +158,6 @@ def list_datasets(
         else ""
     )
 
-    import re
-
     def quality_clause(quality: str, range_: str | None) -> str:
         if not range_:
             return ""
@@ -230,96 +236,31 @@ def quality_clause(quality: str, range_: str | None) -> str:
         "NumberOfNumericFeatures",
         "NumberOfSymbolicFeatures",
     ]
-    qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
-    dids = ",".join(str(did) for did in datasets)
-    qualities = text(
-        f"""
-        SELECT `data`, `quality`, `value`
-        FROM data_quality
-        WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
-        """,  # nosec - similar to above, no user input
-    )
-    qualities = expdb_db.execute(qualities)
-    for did, quality, value in qualities:
-        if value is not None:
-            datasets[did]["quality"].append({"name": quality, "value": str(value)})
-    return {"data": {"dataset": list(datasets.values())}}
-
+    qualities_by_dataset = get_qualities_for_datasets(
+        dataset_ids=datasets.keys(),
+        qualities=qualities_to_show,
+        connection=expdb_db,
+    )
+    for did, qualities in qualities_by_dataset.items():
+        datasets[did]["quality"] = qualities
+    return list(datasets.values())
 
 
-class DatasetError(IntEnum):
-    NOT_FOUND = 111
-    NO_ACCESS = 112
-    NO_DATA_FILE = 113
+class ProcessingInformation(NamedTuple):
+    date: datetime | None
+    warning: str | None
+    error: str | None
 
 
-processing_info = namedtuple("processing_info", ["date", "warning", "error"])
-
-def _get_processing_information(dataset_id: int, connection: Connection) -> processing_info:
+def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation:
     """Return processing information, if any. Otherwise, all fields `None`."""
     if not (data_processed := get_latest_processing_update(dataset_id, connection)):
-        return processing_info(date=None, warning=None, error=None)
+        return ProcessingInformation(date=None, warning=None, error=None)
     date_processed = data_processed["processing_date"]
     warning = data_processed["warning"].strip() if data_processed["warning"] else None
     error = data_processed["error"].strip() if data_processed["error"] else None
-    return processing_info(date=date_processed, warning=warning, error=error)
-
-
-def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
-    """Formatter for JSON bodies of OpenML error codes."""
-    return {"code": str(code), "message": message}
-
-
-def _user_has_access(
-    dataset: dict[str, Any],
-    connection: Connection,
-    api_key: APIKey | None = None,
-) -> bool:
-    """Determine if user of `api_key` has the right to view `dataset`."""
-    if dataset["visibility"] == Visibility.PUBLIC:
-        return True
-    if not api_key:
-        return False
-
-    if not (user_id := get_user_id_for(api_key=api_key, connection=connection)):
-        return False
-
-    if user_id == dataset["uploader"]:
-        return True
-
-    user_groups = get_user_groups_for(user_id=user_id, connection=connection)
-    ADMIN_GROUP = 1
-    return ADMIN_GROUP in user_groups
-
-
-def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
-    if dataset["format"].lower() != DatasetFileFormat.ARFF:
-        return None
-
-    minio_base_url = "https://openml1.win.tue.nl"
-    return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
-
-
-def _format_dataset_url(dataset: dict[str, Any]) -> str:
-    base_url = "https://test.openml.org"
-    filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
-    return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
-
-
-def _safe_unquote(text: str | None) -> str | None:
-    """Remove any open and closing quotes and return the remainder if non-empty."""
-    if not text:
-        return None
-    return text.strip("'\"") or None
-
-
-def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
-    """Return comma-separated values in `text` as list, optionally remove quotes."""
-    if not text:
-        return []
-    chars_to_strip = "'\"\t " if unquote_items else "\t "
-    return [item.strip(chars_to_strip) for item in text.split(",")]
+    return ProcessingInformation(date=date_processed, warning=warning, error=error)
@@ -328,7 +269,7 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
 )
 def get_dataset(
     dataset_id: int,
-    api_key: APIKey | None = None,
+    user: Annotated[User | None, Depends(fetch_user)] = None,
     user_db: Annotated[Connection, Depends(userdb_connection)] = None,
     expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
 ) -> DatasetMetadata:
@@ -336,7 +277,7 @@ def get_dataset(
         error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
         raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
 
-    if not _user_has_access(dataset=dataset, connection=user_db, api_key=api_key):
+    if not _user_has_access(dataset=dataset, user=user):
         error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
         raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)

diff --git a/src/routers/openml/qualities.py b/src/routers/openml/qualities.py
index 569f5d0..4231d5a 100644
--- a/src/routers/openml/qualities.py
+++ b/src/routers/openml/qualities.py
@@ -1,14 +1,15 @@
 import http.client
-from typing import Annotated, Any, Literal
+from typing import Annotated, Literal
 
-from database.datasets import get_dataset, list_all_qualities
-from database.users import User, UserGroup
+from core.access import _user_has_access
+from core.errors import DatasetError
+from database.datasets import get_dataset, get_qualities_for_dataset, list_all_qualities
+from database.users import User
 from fastapi import APIRouter, Depends, HTTPException
 from schemas.datasets.openml import Quality
-from sqlalchemy import Connection, text
+from sqlalchemy import Connection
 
 from routers.dependencies import expdb_connection, fetch_user
-from routers.openml.datasets import DatasetError
 
 router = APIRouter(prefix="/datasets", tags=["datasets"])
@@ -25,14 +26,6 @@ def list_qualities(
     }
 
 
-def _user_can_see_dataset(dataset: dict[str, Any], user: User) -> bool:
-    if dataset["visibility"] == "public":
-        return True
-    return user is not None and (
-        dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups
-    )
-
-
 @router.get("/qualities/{dataset_id}")
 def get_qualities(
     dataset_id: int,
@@ -40,22 +33,12 @@ def get_qualities(
     expdb: Annotated[Connection, Depends(expdb_connection)],
 ) -> list[Quality]:
     dataset = get_dataset(dataset_id, expdb)
-    if not dataset or not _user_can_see_dataset(dataset, user):
+    if not dataset or not _user_has_access(dataset, user):
         raise HTTPException(
             status_code=http.client.PRECONDITION_FAILED,
             detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"},
         ) from None
-    rows = expdb.execute(
-        text(
-            """
-            SELECT `quality`,`value`
-            FROM data_quality
-            WHERE `data`=:dataset_id
-            """,
-        ),
-        parameters={"dataset_id": dataset_id},
-    )
-    return [Quality(name=row.quality, value=row.value) for row in rows]
+    return get_qualities_for_dataset(dataset_id, expdb)
     # The PHP API provided (sometime) helpful error messages
     # if not qualities:
     # check if dataset exists: error 360

diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index 92f7055..c91960a 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -22,11 +22,7 @@ def _assert_empty_result(
 def test_list(py_api: TestClient) -> None:
     response = py_api.get("/datasets/list/")
     assert response.status_code == http.client.OK
-    assert "data" in response.json()
-    assert "dataset" in response.json()["data"]
-
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) >= 1
+    assert len(response.json()) >= 1
 
 
 @pytest.mark.parametrize(
@@ -44,8 +40,7 @@ def test_list_filter_active(status: str, amount: int, py_api: TestClient) -> Non
         json={"status": status, "pagination": {"limit": constants.NUMBER_OF_DATASETS}},
     )
     assert response.status_code == http.client.OK, response.json()
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == amount
+    assert len(response.json()) == amount
 
 
 @pytest.mark.parametrize(
@@ -64,8 +59,7 @@ def test_list_accounts_privacy(api_key: ApiKey | None, amount: int, py_api: Test
         json={"status": "all", "pagination": {"limit": 1000}},
     )
     assert response.status_code == http.client.OK, response.json()
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == amount
+    assert len(response.json()) == amount
 
 
 @pytest.mark.parametrize(
@@ -79,7 +73,7 @@ def test_list_data_name_present(name: str, count: int, py_api: TestClient) -> No
         json={"status": "all", "data_name": name},
     )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
+    datasets = response.json()
     assert len(datasets) == count
     assert all(dataset["name"] == name for dataset in datasets)
 
@@ -96,10 +90,6 @@ def test_list_data_name_absent(name: str, py_api: TestClient) -> None:
     _assert_empty_result(response)
 
 
-def test_list_quality_filers() -> None:
-    pytest.skip("Not implemented")
-
-
 @pytest.mark.parametrize("limit", [None, 5, 10, 200])
 @pytest.mark.parametrize("offset", [None, 0, 5, 129, 130, 200])
 def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClient) -> None:
@@ -123,7 +113,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
         return
 
     assert response.status_code == http.client.OK
-    reported_ids = {dataset["did"] for dataset in response.json()["data"]["dataset"]}
+    reported_ids = {dataset["did"] for dataset in response.json()}
     assert reported_ids == set(expected_ids)
 
 
@@ -137,7 +127,7 @@ def test_list_data_version(version: int, count: int, py_api: TestClient) -> None
         json={"status": "all", "data_version": version},
     )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
+    datasets = response.json()
     assert len(datasets) == count
     assert {dataset["version"] for dataset in datasets} == {version}
 
 
@@ -169,8 +159,7 @@ def test_list_uploader(user_id: int, count: int, key: str, py_api: TestClient) -
         return
 
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == count
+    assert len(response.json()) == count
 
 
 @pytest.mark.parametrize(
@@ -184,9 +173,8 @@ def test_list_data_id(data_id: list[int], py_api: TestClient) -> None:
     )
 
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
     private_or_not_exist = {130, 3000}
-    assert len(datasets) == len(set(data_id) - private_or_not_exist)
+    assert len(response.json()) == len(set(data_id) - private_or_not_exist)
 
 
 @pytest.mark.parametrize(
@@ -201,8 +189,7 @@ def test_list_data_tag(tag: str, count: int, py_api: TestClient) -> None:
         json={"status": "all", "tag": tag, "pagination": {"limit": 101}},
    )
     assert response.status_code == http.client.OK
-    datasets = response.json()["data"]["dataset"]
-    assert len(datasets) == count
+    assert len(response.json()) == count
 
 
 def test_list_data_tag_empty(py_api: TestClient) -> None:
@@ -232,7 +219,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
         json={"status": "all", quality: range_},
     )
     assert response.status_code == http.client.OK, response.json()
-    assert len(response.json()["data"]["dataset"]) == count
+    assert len(response.json()) == count
 
 
 @pytest.mark.php()
@@ -297,6 +284,14 @@ def test_list_data_identical(
     if original.status_code == http.client.PRECONDITION_FAILED:
         assert original.json()["error"] == response.json()["detail"]
         return None
-    assert len(original.json()["data"]["dataset"]) == len(response.json()["data"]["dataset"])
-    assert original.json()["data"]["dataset"] == response.json()["data"]["dataset"]
+    new_json = response.json()
+    # Qualities in the new response are typed
+    for dataset in new_json:
+        for quality in dataset["quality"]:
+            quality["value"] = str(quality["value"])
+
+    # The PHP API wraps the list in a doubly nested dictionary that never has other entries
+    php_json = original.json()["data"]["dataset"]
+    assert len(php_json) == len(new_json)
+    assert php_json == new_json
     return None
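
Note on the flattened response shape: `/datasets/list` now returns the JSON
list of datasets directly, where it previously wrapped it as
`{"data": {"dataset": [...]}}`. A minimal client-side sketch of the difference
(the localhost base URL and the use of `requests` are illustrative
assumptions, not part of this patch):

    import requests

    # The trailing slash matches the route exercised in test_list above.
    response = requests.get("http://localhost:8000/datasets/list/", timeout=10)
    response.raise_for_status()

    # Before this patch: datasets = response.json()["data"]["dataset"]
    # After this patch:  the JSON body is the list of datasets itself.
    for dataset in response.json():
        print(dataset["did"], dataset["name"])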