From 3247a111174633b1bc696a0ba8087a2bdad91c83 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Fri, 3 Nov 2023 14:56:30 +0100
Subject: [PATCH] Use `fastapi.Depends` for dependency injection of the
database (#90)
* Add engine as a parameter to allow dependency injection
* Move database initialization to shared setup
* Use fastapi.Depends for dependency injection at the endpoint
* Move old and new dataset endpoint tests to separate files
* Add database dependencies
* Define auto-injected parameters last
There is a quirk where `None` is a valid default for an `Engine`
parameter, which allows us to put it behind other optional
parameters. In principle, I do not like that these parameters are
technically not optional (they are provided by FastAPI), but I do
prefer having them last instead of first.
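
For reference, a minimal self-contained sketch of the pattern this
patch adopts. The provider and endpoint bodies below are stand-ins,
not code from this repository; see src/database/setup.py and
src/routers/datasets.py for the real versions:

    from typing import Annotated, Any

    from fastapi import Depends, FastAPI
    from sqlalchemy import Engine, create_engine

    app = FastAPI()

    def expdb_database() -> Engine:
        # Stand-in provider; the real one lazily builds a singleton
        # engine from the loaded database configuration.
        return create_engine("sqlite://")

    @app.get("/datasets/{dataset_id}")
    def get_dataset(
        dataset_id: int,
        api_key: str | None = None,
        # The quirk described above: typed `Engine` but defaulted to
        # `None`, so it can follow the optional `api_key`. FastAPI
        # always injects the real engine at request time.
        expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
    ) -> dict[str, Any]:
        return {"dataset_id": dataset_id, "engine": str(expdb_db.url)}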
---
pyproject.toml | 2 +-
src/database/datasets.py | 57 +++++++----------
src/database/setup.py | 30 +++++++++
src/database/users.py | 23 ++-----
src/routers/datasets.py | 31 ++++++----
src/routers/mldcat_ap/dataset.py | 18 +++++-
src/routers/old/datasets.py | 15 ++++-
tests/routers/datasets_test.py | 62 +++++++++++++++++++
.../old/datasets_old_test.py} | 34 +++-------
9 files changed, 175 insertions(+), 97 deletions(-)
create mode 100644 src/database/setup.py
create mode 100644 tests/routers/datasets_test.py
rename tests/{identical_test.py => routers/old/datasets_old_test.py} (72%)
diff --git a/pyproject.toml b/pyproject.toml
index a750eb0..36e506a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,5 +56,5 @@ pythonpath = [
"src"
]
markers = [
- "web: uses an internet connection"
+ "php: tests that compare directly to an old PHP endpoint"
]
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 1e53439..efc4649 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,30 +1,14 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
from typing import Any
-from config import load_database_configuration
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
from database.meta import get_column_names
-_database_configuration = load_database_configuration()
-expdb_url = URL.create(**_database_configuration["expdb"])
-expdb = create_engine(
- expdb_url,
- echo=True,
- pool_recycle=3600,
-)
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
- openml_url,
- echo=True,
- pool_recycle=3600,
-)
-
-def get_dataset(dataset_id: int) -> dict[str, Any] | None:
- columns = get_column_names(expdb, "dataset")
- with expdb.connect() as conn:
+def get_dataset(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+ columns = get_column_names(engine, "dataset")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
@@ -38,9 +22,9 @@ def get_dataset(dataset_id: int) -> dict[str, Any] | None:
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
-def get_file(file_id: int) -> dict[str, Any] | None:
- columns = get_column_names(openml, "file")
- with openml.connect() as conn:
+def get_file(file_id: int, engine: Engine) -> dict[str, Any] | None:
+ columns = get_column_names(engine, "file")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
@@ -54,9 +38,9 @@ def get_file(file_id: int) -> dict[str, Any] | None:
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
-def get_tags(dataset_id: int) -> list[str]:
- columns = get_column_names(expdb, "dataset_tag")
- with expdb.connect() as conn:
+def get_tags(dataset_id: int, engine: Engine) -> list[str]:
+ columns = get_column_names(engine, "dataset_tag")
+ with engine.connect() as conn:
rows = conn.execute(
text(
"""
@@ -70,9 +54,12 @@ def get_tags(dataset_id: int) -> list[str]:
return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
-def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
- columns = get_column_names(expdb, "dataset_description")
- with expdb.connect() as conn:
+def get_latest_dataset_description(
+ dataset_id: int,
+ engine: Engine,
+) -> dict[str, Any] | None:
+ columns = get_column_names(engine, "dataset_description")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
@@ -87,9 +74,9 @@ def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
-def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
- columns = get_column_names(expdb, "dataset_status")
- with expdb.connect() as conn:
+def get_latest_status_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+ columns = get_column_names(engine, "dataset_status")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
@@ -106,9 +93,9 @@ def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
)
-def get_latest_processing_update(dataset_id: int) -> dict[str, Any] | None:
- columns = get_column_names(expdb, "data_processed")
- with expdb.connect() as conn:
+def get_latest_processing_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+ columns = get_column_names(engine, "data_processed")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
diff --git a/src/database/setup.py b/src/database/setup.py
new file mode 100644
index 0000000..6f6a1e9
--- /dev/null
+++ b/src/database/setup.py
@@ -0,0 +1,30 @@
+from config import load_database_configuration
+from sqlalchemy import Engine, create_engine
+from sqlalchemy.engine import URL
+
+_user_engine = None
+_expdb_engine = None
+
+
+def _create_engine(database_name: str) -> Engine:
+ database_configuration = load_database_configuration()
+ db_url = URL.create(**database_configuration[database_name])
+ return create_engine(
+ db_url,
+ echo=True,
+ pool_recycle=3600,
+ )
+
+
+def user_database() -> Engine:
+ global _user_engine
+ if _user_engine is None:
+ _user_engine = _create_engine("openml")
+ return _user_engine
+
+
+def expdb_database() -> Engine:
+ global _expdb_engine
+ if _expdb_engine is None:
+ _expdb_engine = _create_engine("expdb")
+ return _expdb_engine
diff --git a/src/database/users.py b/src/database/users.py
index f8f58d5..4b1fc30 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -1,28 +1,17 @@
from typing import Annotated
-from config import load_database_configuration
from pydantic import StringConstraints
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
from database.meta import get_column_names
-_database_configuration = load_database_configuration()
-
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
- openml_url,
- echo=True,
- pool_recycle=3600,
-)
-
# Enforces str is 32 hexadecimal characters, does not check validity.
APIKey = Annotated[str, StringConstraints(pattern=r"^[0-9a-fA-F]{32}$")]
-def get_user_id_for(*, api_key: APIKey) -> int | None:
- columns = get_column_names(openml, "users")
- with openml.connect() as conn:
+def get_user_id_for(*, api_key: APIKey, engine: Engine) -> int | None:
+ columns = get_column_names(engine, "users")
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
@@ -38,8 +27,8 @@ def get_user_id_for(*, api_key: APIKey) -> int | None:
return int(dict(zip(columns, user, strict=True))["id"])
-def get_user_groups_for(*, user_id: int) -> list[int]:
- with openml.connect() as conn:
+def get_user_groups_for(*, user_id: int, engine: Engine) -> list[int]:
+ with engine.connect() as conn:
row = conn.execute(
text(
"""
diff --git a/src/routers/datasets.py b/src/routers/datasets.py
index 860412f..486f639 100644
--- a/src/routers/datasets.py
+++ b/src/routers/datasets.py
@@ -2,7 +2,7 @@
import http.client
from collections import namedtuple
from enum import IntEnum
-from typing import Any
+from typing import Annotated, Any
from database.datasets import get_dataset as db_get_dataset
from database.datasets import (
@@ -12,14 +12,16 @@
get_latest_status_update,
get_tags,
)
+from database.setup import expdb_database, user_database
from database.users import APIKey, get_user_groups_for, get_user_id_for
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
from schemas.datasets.openml import (
DatasetFileFormat,
DatasetMetadata,
DatasetStatus,
Visibility,
)
+from sqlalchemy import Engine
router = APIRouter(prefix="/datasets", tags=["datasets"])
@@ -33,9 +35,9 @@ class DatasetError(IntEnum):
processing_info = namedtuple("processing_info", ["date", "warning", "error"])
-def _get_processing_information(dataset_id: int) -> processing_info:
+def _get_processing_information(dataset_id: int, engine: Engine) -> processing_info:
"""Return processing information, if any. Otherwise, all fields `None`."""
- if not (data_processed := get_latest_processing_update(dataset_id)):
+ if not (data_processed := get_latest_processing_update(dataset_id, engine)):
return processing_info(date=None, warning=None, error=None)
date_processed = data_processed["processing_date"]
@@ -51,6 +53,7 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
def _user_has_access(
dataset: dict[str, Any],
+ engine: Engine,
api_key: APIKey | None = None,
) -> bool:
"""Determine if user of `api_key` has the right to view `dataset`."""
@@ -59,13 +62,13 @@ def _user_has_access(
if not api_key:
return False
- if not (user_id := get_user_id_for(api_key=api_key)):
+ if not (user_id := get_user_id_for(api_key=api_key, engine=engine)):
return False
if user_id == dataset["uploader"]:
return True
- user_groups = get_user_groups_for(user_id=user_id)
+ user_groups = get_user_groups_for(user_id=user_id, engine=engine)
ADMIN_GROUP = 1
return ADMIN_GROUP in user_groups
@@ -106,26 +109,28 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
def get_dataset(
dataset_id: int,
api_key: APIKey | None = None,
+ user_db: Annotated[Engine, Depends(user_database)] = None,
+ expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
) -> DatasetMetadata:
- if not (dataset := db_get_dataset(dataset_id)):
+ if not (dataset := db_get_dataset(dataset_id, expdb_db)):
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
- if not _user_has_access(dataset, api_key):
+ if not _user_has_access(dataset=dataset, api_key=api_key, engine=user_db):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
- if not (dataset_file := get_file(dataset["file_id"])):
+ if not (dataset_file := get_file(dataset["file_id"], user_db)):
error = _format_error(
code=DatasetError.NO_DATA_FILE,
message="No data file found",
)
raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error)
- tags = get_tags(dataset_id)
- description = get_latest_dataset_description(dataset_id)
- processing_result = _get_processing_information(dataset_id)
- status = get_latest_status_update(dataset_id)
+ tags = get_tags(dataset_id, expdb_db)
+ description = get_latest_dataset_description(dataset_id, expdb_db)
+ processing_result = _get_processing_information(dataset_id, expdb_db)
+ status = get_latest_status_update(dataset_id, expdb_db)
status_ = DatasetStatus(status["status"]) if status else DatasetStatus.IN_PREPARATION
diff --git a/src/routers/mldcat_ap/dataset.py b/src/routers/mldcat_ap/dataset.py
index 95bfebe..1e1be28 100644
--- a/src/routers/mldcat_ap/dataset.py
+++ b/src/routers/mldcat_ap/dataset.py
@@ -1,5 +1,9 @@
-from fastapi import APIRouter
+from typing import Annotated
+
+from database.setup import expdb_database, user_database
+from fastapi import APIRouter, Depends
from schemas.datasets.mldcat_ap import JsonLDGraph, convert_to_mldcat_ap
+from sqlalchemy import Engine
from routers.datasets import get_dataset
@@ -10,6 +14,14 @@
path="/{dataset_id}",
description="Get meta-data for dataset with ID `dataset_id`.",
)
-def get_mldcat_ap_dataset(dataset_id: int) -> JsonLDGraph:
- openml_dataset = get_dataset(dataset_id)
+def get_mldcat_ap_dataset(
+ dataset_id: int,
+ user_db: Annotated[Engine, Depends(user_database)] = None,
+ expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
+) -> JsonLDGraph:
+ openml_dataset = get_dataset(
+ dataset_id=dataset_id,
+ user_db=user_db,
+ expdb_db=expdb_db,
+ )
return convert_to_mldcat_ap(openml_dataset)
diff --git a/src/routers/old/datasets.py b/src/routers/old/datasets.py
index df58cf3..efc7575 100644
--- a/src/routers/old/datasets.py
+++ b/src/routers/old/datasets.py
@@ -3,10 +3,12 @@
new API, and are easily removed later.
"""
import http.client
-from typing import Any
+from typing import Annotated, Any
+from database.setup import expdb_database, user_database
from database.users import APIKey
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import Engine
from routers.datasets import get_dataset
@@ -20,9 +22,16 @@
def get_dataset_wrapped(
dataset_id: int,
api_key: APIKey | None = None,
+ user_db: Annotated[Engine, Depends(user_database)] = None,
+ expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
) -> dict[str, dict[str, Any]]:
try:
- dataset = get_dataset(dataset_id, api_key).model_dump(by_alias=True)
+ dataset = get_dataset(
+ user_db=user_db,
+ expdb_db=expdb_db,
+ dataset_id=dataset_id,
+ api_key=api_key,
+ ).model_dump(by_alias=True)
except HTTPException as e:
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
diff --git a/tests/routers/datasets_test.py b/tests/routers/datasets_test.py
new file mode 100644
index 0000000..987b6e9
--- /dev/null
+++ b/tests/routers/datasets_test.py
@@ -0,0 +1,62 @@
+import http.client
+from typing import Any, cast
+
+import httpx
+import pytest
+from fastapi import FastAPI
+
+
+@pytest.mark.parametrize(
+ ("endpoint", "dataset_id", "response_code"),
+ [
+ ("datasets/", -1, http.client.NOT_FOUND),
+ ("datasets/", 138, http.client.NOT_FOUND),
+ ("datasets/", 100_000, http.client.NOT_FOUND),
+ ],
+)
+def test_error_unknown_dataset(
+ endpoint: str,
+ dataset_id: int,
+ response_code: int,
+ api_client: FastAPI,
+) -> None:
+ response = cast(httpx.Response, api_client.get(f"{endpoint}/{dataset_id}"))
+
+ assert response.status_code == response_code
+ assert {"code": "111", "message": "Unknown dataset"} == response.json()["detail"]
+
+
+@pytest.mark.parametrize(
+ ("endpoint", "api_key", "response_code"),
+ [
+ ("datasets", None, http.client.FORBIDDEN),
+ ("datasets", "a" * 32, http.client.FORBIDDEN),
+ ],
+)
+def test_private_dataset_no_user_no_access(
+ api_client: FastAPI,
+ endpoint: str,
+ api_key: str | None,
+ response_code: int,
+) -> None:
+ query = f"?api_key={api_key}" if api_key else ""
+ response = cast(httpx.Response, api_client.get(f"{endpoint}/130{query}"))
+
+ assert response.status_code == response_code
+ assert {"code": "112", "message": "No access granted"} == response.json()["detail"]
+
+
+@pytest.mark.skip("Not sure how to include apikey in test yet.")
+def test_private_dataset_owner_access(
+ api_client: FastAPI,
+ dataset_130: dict[str, Any],
+) -> None:
+ response = cast(httpx.Response, api_client.get("/datasets/130?api_key=..."))
+ assert response.status_code == http.client.OK
+ assert dataset_130 == response.json()
+
+
+@pytest.mark.skip("Not sure how to include apikey in test yet.")
+def test_private_dataset_admin_access(api_client: FastAPI) -> None:
+ cast(httpx.Response, api_client.get("/datasets/130?api_key=..."))
+ # test against cached response
diff --git a/tests/identical_test.py b/tests/routers/old/datasets_old_test.py
similarity index 72%
rename from tests/identical_test.py
rename to tests/routers/old/datasets_old_test.py
index 1ad4935..8cfef7b 100644
--- a/tests/identical_test.py
+++ b/tests/routers/old/datasets_old_test.py
@@ -7,7 +7,7 @@
from fastapi import FastAPI
-@pytest.mark.web()
+@pytest.mark.php()
@pytest.mark.parametrize(
"dataset_id",
range(1, 132),
@@ -54,47 +54,31 @@ def test_dataset_response_is_identical(dataset_id: int, api_client: FastAPI) ->
@pytest.mark.parametrize(
- ("endpoint", "dataset_id", "response_code"),
- [
- ("old/datasets/", -1, http.client.PRECONDITION_FAILED),
- ("old/datasets/", 138, http.client.PRECONDITION_FAILED),
- ("old/datasets/", 100_000, http.client.PRECONDITION_FAILED),
- ("datasets/", -1, http.client.NOT_FOUND),
- ("datasets/", 138, http.client.NOT_FOUND),
- ("datasets/", 100_000, http.client.NOT_FOUND),
- ],
+ "dataset_id",
+ [-1, 138, 100_000],
)
def test_error_unknown_dataset(
- endpoint: str,
dataset_id: int,
- response_code: int,
api_client: FastAPI,
) -> None:
- response = cast(httpx.Response, api_client.get(f"{endpoint}/{dataset_id}"))
+ response = cast(httpx.Response, api_client.get(f"old/datasets/{dataset_id}"))
- assert response.status_code == response_code
+ assert response.status_code == http.client.PRECONDITION_FAILED
assert {"code": "111", "message": "Unknown dataset"} == response.json()["detail"]
@pytest.mark.parametrize(
- ("endpoint", "api_key", "response_code"),
- [
- ("old/datasets", None, http.client.PRECONDITION_FAILED),
- ("old/datasets", "a" * 32, http.client.PRECONDITION_FAILED),
- ("datasets", None, http.client.FORBIDDEN),
- ("datasets", "a" * 32, http.client.FORBIDDEN),
- ],
+ "api_key",
+ [None, "a" * 32],
)
def test_private_dataset_no_user_no_access(
api_client: FastAPI,
- endpoint: str,
api_key: str | None,
- response_code: int,
) -> None:
query = f"?api_key={api_key}" if api_key else ""
- response = cast(httpx.Response, api_client.get(f"{endpoint}/130{query}"))
+ response = cast(httpx.Response, api_client.get(f"old/datasets/130{query}"))
- assert response.status_code == response_code
+ assert response.status_code == http.client.PRECONDITION_FAILED
assert {"code": "112", "message": "No access granted"} == response.json()["detail"]