From 3247a111174633b1bc696a0ba8087a2bdad91c83 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Fri, 3 Nov 2023 14:56:30 +0100
Subject: [PATCH] Use `fastapi.Depends` for dependency injection of the
 database (#90)

* Add engine as parameter to allow dependency injection

* Add engine parameter to allow dependency injection

* Move database initialization to shared setup

* Use fastapi.Depends for dependency injection at the endpoint

* Move old and new dataset endpoint tests to separate files

* Add database dependencies

* Define auto-injected parameters last

There is a quirk where `None` is a valid `Engine`, which allows us to put it
behind other optional parameters. In principle, I do not like that it is
technically not optional (but provided by FastAPI), but I do prefer having
these parameters last instead of first.
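For reference, the pattern this commit adopts, in one minimal sketch (names mirror the diff below; the stand-in engine and route body are illustrative only). The `= None` default is the quirk described above: FastAPI resolves `Depends(expdb_database)` on every request, so the default is never actually used at runtime.

```python
from typing import Annotated

from fastapi import Depends, FastAPI
from sqlalchemy import Engine, create_engine

app = FastAPI()


def expdb_database() -> Engine:
    # Stand-in provider; the real one lives in src/database/setup.py.
    return create_engine("sqlite://")


@app.get("/datasets/{dataset_id}")
def get_dataset(
    dataset_id: int,
    api_key: str | None = None,  # a genuinely optional query parameter
    # Auto-injected parameter last: FastAPI always supplies it via Depends,
    # so `= None` only satisfies Python's default-argument ordering.
    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
) -> dict[str, int]:
    return {"dataset_id": dataset_id}
```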
---
 pyproject.toml                   |  2 +-
 src/database/datasets.py         | 57 +++++++----------
 src/database/setup.py            | 30 +++++++++
 src/database/users.py            | 23 ++-----
 src/routers/datasets.py          | 31 ++++++----
 src/routers/mldcat_ap/dataset.py | 18 +++++-
 src/routers/old/datasets.py      | 15 ++++-
 tests/routers/datasets_test.py   | 62 +++++++++++++++++++
 .../old/datasets_old_test.py}    | 34 +++-------
 9 files changed, 175 insertions(+), 97 deletions(-)
 create mode 100644 src/database/setup.py
 create mode 100644 tests/routers/datasets_test.py
 rename tests/{identical_test.py => routers/old/datasets_old_test.py} (72%)

diff --git a/pyproject.toml b/pyproject.toml
index a750eb0..36e506a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,5 +56,5 @@ pythonpath = [
     "src"
 ]
 markers = [
-    "web: uses an internet connection"
+    "php: tests that compare directly to an old PHP endpoint"
 ]
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 1e53439..efc4649 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,30 +1,14 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
 from typing import Any
 
-from config import load_database_configuration
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
 
 from database.meta import get_column_names
 
-_database_configuration = load_database_configuration()
-expdb_url = URL.create(**_database_configuration["expdb"])
-expdb = create_engine(
-    expdb_url,
-    echo=True,
-    pool_recycle=3600,
-)
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
-    openml_url,
-    echo=True,
-    pool_recycle=3600,
-)
-
 
-def get_dataset(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset")
-    with expdb.connect() as conn:
+def get_dataset(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -38,9 +22,9 @@ def get_dataset(dataset_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_file(file_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(openml, "file")
-    with openml.connect() as conn:
+def get_file(file_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "file")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -54,9 +38,9 @@ def get_file(file_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_tags(dataset_id: int) -> list[str]:
-    columns = get_column_names(expdb, "dataset_tag")
-    with expdb.connect() as conn:
+def get_tags(dataset_id: int, engine: Engine) -> list[str]:
+    columns = get_column_names(engine, "dataset_tag")
+    with engine.connect() as conn:
         rows = conn.execute(
             text(
                 """
@@ -70,9 +54,12 @@ def get_tags(dataset_id: int) -> list[str]:
     return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
 
 
-def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset_description")
-    with expdb.connect() as conn:
+def get_latest_dataset_description(
+    dataset_id: int,
+    engine: Engine,
+) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset_description")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -87,9 +74,9 @@ def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset_status")
-    with expdb.connect() as conn:
+def get_latest_status_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset_status")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -106,9 +93,9 @@ def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
     )
 
 
-def get_latest_processing_update(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "data_processed")
-    with expdb.connect() as conn:
+def get_latest_processing_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "data_processed")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
diff --git a/src/database/setup.py b/src/database/setup.py
new file mode 100644
index 0000000..6f6a1e9
--- /dev/null
+++ b/src/database/setup.py
@@ -0,0 +1,30 @@
+from config import load_database_configuration
+from sqlalchemy import Engine, create_engine
+from sqlalchemy.engine import URL
+
+_user_engine = None
+_expdb_engine = None
+
+
+def _create_engine(database_name: str) -> Engine:
+    database_configuration = load_database_configuration()
+    db_url = URL.create(**database_configuration[database_name])
+    return create_engine(
+        db_url,
+        echo=True,
+        pool_recycle=3600,
+    )
+
+
+def user_database() -> Engine:
+    global _user_engine
+    if _user_engine is None:
+        _user_engine = _create_engine("openml")
+    return _user_engine
+
+
+def expdb_database() -> Engine:
+    global _expdb_engine
+    if _expdb_engine is None:
+        _expdb_engine = _create_engine("expdb")
+    return _expdb_engine
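The module-level globals in `setup.py` give each database a lazily created singleton engine: built on first request, reused afterwards. A functionally equivalent sketch with `functools.lru_cache`, shown only as an alternative phrasing of the same design choice, not as what the patch does:

```python
from functools import lru_cache

from sqlalchemy import Engine, create_engine
from sqlalchemy.engine import URL

from config import load_database_configuration


@lru_cache(maxsize=None)
def _engine_for(database_name: str) -> Engine:
    # The first call builds the engine; later calls return the cached
    # instance, mirroring the `global`/`is None` check in setup.py.
    database_configuration = load_database_configuration()
    db_url = URL.create(**database_configuration[database_name])
    return create_engine(db_url, echo=True, pool_recycle=3600)


def user_database() -> Engine:
    return _engine_for("openml")


def expdb_database() -> Engine:
    return _engine_for("expdb")
```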
diff --git a/src/database/users.py b/src/database/users.py
index f8f58d5..4b1fc30 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -1,28 +1,17 @@
 from typing import Annotated
 
-from config import load_database_configuration
 from pydantic import StringConstraints
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
 
 from database.meta import get_column_names
 
-_database_configuration = load_database_configuration()
-
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
-    openml_url,
-    echo=True,
-    pool_recycle=3600,
-)
-
 # Enforces str is 32 hexadecimal characters, does not check validity.
 APIKey = Annotated[str, StringConstraints(pattern=r"^[0-9a-fA-F]{32}$")]
 
 
-def get_user_id_for(*, api_key: APIKey) -> int | None:
-    columns = get_column_names(openml, "users")
-    with openml.connect() as conn:
+def get_user_id_for(*, api_key: APIKey, engine: Engine) -> int | None:
+    columns = get_column_names(engine, "users")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -38,8 +27,8 @@ def get_user_id_for(*, api_key: APIKey) -> int | None:
     return int(dict(zip(columns, user, strict=True))["id"])
 
 
-def get_user_groups_for(*, user_id: int) -> list[int]:
-    with openml.connect() as conn:
+def get_user_groups_for(*, user_id: int, engine: Engine) -> list[int]:
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
diff --git a/src/routers/datasets.py b/src/routers/datasets.py
index 860412f..486f639 100644
--- a/src/routers/datasets.py
+++ b/src/routers/datasets.py
@@ -2,7 +2,7 @@
 import http.client
 from collections import namedtuple
 from enum import IntEnum
-from typing import Any
+from typing import Annotated, Any
 
 from database.datasets import get_dataset as db_get_dataset
 from database.datasets import (
@@ -12,14 +12,16 @@
     get_latest_status_update,
     get_tags,
 )
+from database.setup import expdb_database, user_database
 from database.users import APIKey, get_user_groups_for, get_user_id_for
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
 from schemas.datasets.openml import (
     DatasetFileFormat,
     DatasetMetadata,
     DatasetStatus,
     Visibility,
 )
+from sqlalchemy import Engine
 
 
 router = APIRouter(prefix="/datasets", tags=["datasets"])
@@ -33,9 +35,9 @@ class DatasetError(IntEnum):
 processing_info = namedtuple("processing_info", ["date", "warning", "error"])
 
 
-def _get_processing_information(dataset_id: int) -> processing_info:
+def _get_processing_information(dataset_id: int, engine: Engine) -> processing_info:
     """Return processing information, if any.
     Otherwise, all fields `None`."""
-    if not (data_processed := get_latest_processing_update(dataset_id)):
+    if not (data_processed := get_latest_processing_update(dataset_id, engine)):
         return processing_info(date=None, warning=None, error=None)
     date_processed = data_processed["processing_date"]
@@ -51,6 +53,7 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
 
 def _user_has_access(
     dataset: dict[str, Any],
+    engine: Engine,
     api_key: APIKey | None = None,
 ) -> bool:
     """Determine if user of `api_key` has the right to view `dataset`."""
@@ -59,13 +62,13 @@ def _user_has_access(
     if not api_key:
         return False
 
-    if not (user_id := get_user_id_for(api_key=api_key)):
+    if not (user_id := get_user_id_for(api_key=api_key, engine=engine)):
         return False
 
     if user_id == dataset["uploader"]:
         return True
 
-    user_groups = get_user_groups_for(user_id=user_id)
+    user_groups = get_user_groups_for(user_id=user_id, engine=engine)
     ADMIN_GROUP = 1
     return ADMIN_GROUP in user_groups
 
@@ -106,26 +109,28 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
 def get_dataset(
     dataset_id: int,
     api_key: APIKey | None = None,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
 ) -> DatasetMetadata:
-    if not (dataset := db_get_dataset(dataset_id)):
+    if not (dataset := db_get_dataset(dataset_id, expdb_db)):
         error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
         raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
 
-    if not _user_has_access(dataset, api_key):
+    if not _user_has_access(dataset=dataset, api_key=api_key, engine=user_db):
         error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
         raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
 
-    if not (dataset_file := get_file(dataset["file_id"])):
+    if not (dataset_file := get_file(dataset["file_id"], user_db)):
         error = _format_error(
             code=DatasetError.NO_DATA_FILE,
             message="No data file found",
         )
         raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error)
 
-    tags = get_tags(dataset_id)
-    description = get_latest_dataset_description(dataset_id)
-    processing_result = _get_processing_information(dataset_id)
-    status = get_latest_status_update(dataset_id)
+    tags = get_tags(dataset_id, expdb_db)
+    description = get_latest_dataset_description(dataset_id, expdb_db)
+    processing_result = _get_processing_information(dataset_id, expdb_db)
+    status = get_latest_status_update(dataset_id, expdb_db)
 
     status_ = DatasetStatus(status["status"]) if status else DatasetStatus.IN_PREPARATION
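From a client's perspective nothing changes: the engine parameters are filled in server-side and never appear in the HTTP interface. A hypothetical request against a locally running instance (base URL assumed):

```python
import httpx

# `api_key` remains an ordinary query parameter; the Engine parameters are
# invisible to callers because FastAPI injects them per request.
response = httpx.get(
    "http://localhost:8000/datasets/130",
    params={"api_key": "a" * 32},
)
print(response.status_code, response.json())
```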
diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py
index 95bfebe..1e1be28 100644
--- a/src/routers/mldcat_ap/dataset.py
+++ b/src/routers/mldcat_ap/dataset.py
@@ -1,5 +1,9 @@
-from fastapi import APIRouter
+from typing import Annotated
+
+from database.setup import expdb_database, user_database
+from fastapi import APIRouter, Depends
 from schemas.datasets.mldcat_ap import JsonLDGraph, convert_to_mldcat_ap
+from sqlalchemy import Engine
 
 from routers.datasets import get_dataset
 
@@ -10,6 +14,14 @@
     path="/{dataset_id}",
     description="Get meta-data for dataset with ID `dataset_id`.",
 )
-def get_mldcat_ap_dataset(dataset_id: int) -> JsonLDGraph:
-    openml_dataset = get_dataset(dataset_id)
+def get_mldcat_ap_dataset(
+    dataset_id: int,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
+) -> JsonLDGraph:
+    openml_dataset = get_dataset(
+        dataset_id=dataset_id,
+        user_db=user_db,
+        expdb_db=expdb_db,
+    )
     return convert_to_mldcat_ap(openml_dataset)
diff --git a/src/routers/old/datasets.py b/src/routers/old/datasets.py
index df58cf3..efc7575 100644
--- a/src/routers/old/datasets.py
+++ b/src/routers/old/datasets.py
@@ -3,10 +3,12 @@
 new API, and are easily removed later.
 """
 import http.client
-from typing import Any
+from typing import Annotated, Any
 
+from database.setup import expdb_database, user_database
 from database.users import APIKey
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import Engine
 
 from routers.datasets import get_dataset
 
@@ -20,9 +22,16 @@
 def get_dataset_wrapped(
     dataset_id: int,
     api_key: APIKey | None = None,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
 ) -> dict[str, dict[str, Any]]:
     try:
-        dataset = get_dataset(dataset_id, api_key).model_dump(by_alias=True)
+        dataset = get_dataset(
+            user_db=user_db,
+            expdb_db=expdb_db,
+            dataset_id=dataset_id,
+            api_key=api_key,
+        ).model_dump(by_alias=True)
     except HTTPException as e:
         raise HTTPException(
             status_code=http.client.PRECONDITION_FAILED,
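Both wrappers above re-declare the engine dependencies rather than relying on `get_dataset`'s own `Depends` defaults. This is deliberate: `Depends` is only resolved when FastAPI itself invokes an endpoint, so a plain Python call from one endpoint function into another would leave those parameters at their `None` defaults. An illustrative sketch of what any direct caller must therefore do:

```python
from database.setup import expdb_database, user_database
from routers.datasets import get_dataset

# Calling the endpoint function directly bypasses FastAPI's injection, so
# the engines must be resolved by hand; the wrapper endpoints achieve the
# same effect by re-declaring the Depends parameters and passing them on.
metadata = get_dataset(
    dataset_id=1,
    user_db=user_database(),
    expdb_db=expdb_database(),
)
```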
diff --git a/tests/routers/datasets_test.py b/tests/routers/datasets_test.py
new file mode 100644
index 0000000..987b6e9
--- /dev/null
+++ b/tests/routers/datasets_test.py
@@ -0,0 +1,62 @@
+import http.client
+from typing import Any, cast
+
+import httpx
+import pytest
+from fastapi import FastAPI
+
+
+@pytest.mark.parametrize(
+    ("endpoint", "dataset_id", "response_code"),
+    [
+        ("datasets/", -1, http.client.NOT_FOUND),
+        ("datasets/", 138, http.client.NOT_FOUND),
+        ("datasets/", 100_000, http.client.NOT_FOUND),
+    ],
+)
+def test_error_unknown_dataset(
+    endpoint: str,
+    dataset_id: int,
+    response_code: int,
+    api_client: FastAPI,
+) -> None:
+    response = cast(httpx.Response, api_client.get(f"{endpoint}/{dataset_id}"))
+
+    assert response.status_code == response_code
+    assert {"code": "111", "message": "Unknown dataset"} == response.json()["detail"]
+
+
+@pytest.mark.parametrize(
+    ("endpoint", "api_key", "response_code"),
+    [
+        ("datasets", None, http.client.FORBIDDEN),
+        ("datasets", "a" * 32, http.client.FORBIDDEN),
+    ],
+)
+def test_private_dataset_no_user_no_access(
+    api_client: FastAPI,
+    endpoint: str,
+    api_key: str | None,
+    response_code: int,
+) -> None:
+    query = f"?api_key={api_key}" if api_key else ""
+    response = cast(httpx.Response, api_client.get(f"{endpoint}/130{query}"))
+
+    assert response.status_code == response_code
+    assert {"code": "112", "message": "No access granted"} == response.json()["detail"]
+
+
+@pytest.mark.skip("Not sure how to include apikey in test yet.")
+def test_private_dataset_owner_access(
+    api_client: FastAPI,
+    dataset_130: dict[str, Any],
+) -> None:
+    response = cast(httpx.Response, api_client.get("/datasets/130?api_key=..."))
+    assert response.status_code == http.client.OK
+    assert dataset_130 == response.json()
+
+
+@pytest.mark.skip("Not sure how to include apikey in test yet.")
+def test_private_dataset_admin_access(api_client: FastAPI) -> None:
+    cast(httpx.Response, api_client.get("/datasets/130?api_key=..."))
+    # test against cached response
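The two skipped tests hint at what the refactor enables: since the engines are now FastAPI dependencies, a test can point the app at a throwaway database through `dependency_overrides` instead of monkeypatching module globals. A sketch, assuming a hypothetical app instance and a fixture-provided test engine (both assumptions, not part of this patch):

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import Engine

from database.setup import expdb_database


def client_with_test_expdb(app: FastAPI, test_engine: Engine) -> TestClient:
    # Every parameter declared as Depends(expdb_database) now receives
    # `test_engine` instead of the production engine.
    app.dependency_overrides[expdb_database] = lambda: test_engine
    return TestClient(app)
```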
diff --git a/tests/identical_test.py b/tests/routers/old/datasets_old_test.py
similarity index 72%
rename from tests/identical_test.py
rename to tests/routers/old/datasets_old_test.py
index 1ad4935..8cfef7b 100644
--- a/tests/identical_test.py
+++ b/tests/routers/old/datasets_old_test.py
@@ -7,7 +7,7 @@
 from fastapi import FastAPI
 
 
-@pytest.mark.web()
+@pytest.mark.php()
 @pytest.mark.parametrize(
     "dataset_id",
     range(1, 132),
@@ -54,47 +54,31 @@ def test_dataset_response_is_identical(dataset_id: int, api_client: FastAPI) ->
 
 
 @pytest.mark.parametrize(
-    ("endpoint", "dataset_id", "response_code"),
-    [
-        ("old/datasets/", -1, http.client.PRECONDITION_FAILED),
-        ("old/datasets/", 138, http.client.PRECONDITION_FAILED),
-        ("old/datasets/", 100_000, http.client.PRECONDITION_FAILED),
-        ("datasets/", -1, http.client.NOT_FOUND),
-        ("datasets/", 138, http.client.NOT_FOUND),
-        ("datasets/", 100_000, http.client.NOT_FOUND),
-    ],
+    "dataset_id",
+    [-1, 138, 100_000],
 )
 def test_error_unknown_dataset(
-    endpoint: str,
     dataset_id: int,
-    response_code: int,
     api_client: FastAPI,
 ) -> None:
-    response = cast(httpx.Response, api_client.get(f"{endpoint}/{dataset_id}"))
+    response = cast(httpx.Response, api_client.get(f"old/datasets/{dataset_id}"))
 
-    assert response.status_code == response_code
+    assert response.status_code == http.client.PRECONDITION_FAILED
     assert {"code": "111", "message": "Unknown dataset"} == response.json()["detail"]
 
 
 @pytest.mark.parametrize(
-    ("endpoint", "api_key", "response_code"),
-    [
-        ("old/datasets", None, http.client.PRECONDITION_FAILED),
-        ("old/datasets", "a" * 32, http.client.PRECONDITION_FAILED),
-        ("datasets", None, http.client.FORBIDDEN),
-        ("datasets", "a" * 32, http.client.FORBIDDEN),
-    ],
+    "api_key",
+    [None, "a" * 32],
 )
 def test_private_dataset_no_user_no_access(
     api_client: FastAPI,
-    endpoint: str,
     api_key: str | None,
-    response_code: int,
 ) -> None:
     query = f"?api_key={api_key}" if api_key else ""
-    response = cast(httpx.Response, api_client.get(f"{endpoint}/130{query}"))
+    response = cast(httpx.Response, api_client.get(f"old/datasets/130{query}"))
 
-    assert response.status_code == response_code
+    assert response.status_code == http.client.PRECONDITION_FAILED
    assert {"code": "112", "message": "No access granted"} == response.json()["detail"]
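With the marker renamed from `web` to `php`, the comparison suite against the legacy PHP endpoint can be deselected when that endpoint is unavailable, e.g. `pytest -m "not php"`.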