Use fastapi.Depends for dependency injection of the database (#90)
* Add engine as parameter to allow dependency injection

* Add engine parameter to allow dependency injection

* Move database initialization to shared setup

* Use fastapi.Depends for dependency injection at the endpoint

* Move old and new dataset endpoint tests to separate files

* Add database dependencies

* Define auto-injected parameters last

There is a quirk where `None` is a valid `Engine`, which allows
us to put it behind other optional parameters. In principle, I
do not like that it is technically not optional (but provided by
FastAPI), but I do prefer having these parameters last instead of
first.
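
A minimal sketch of the resulting pattern (the names and the SQLite URL are illustrative, not from this commit). FastAPI calls the provider at request time and injects its return value, so the `= None` default exists only to satisfy the type checker and to keep the injected parameter behind truly optional ones:

```python
from typing import Annotated

from fastapi import Depends, FastAPI
from sqlalchemy import Engine, create_engine

app = FastAPI()


def expdb_database() -> Engine:
    # Illustrative stand-in for the real provider in src/database/setup.py.
    return create_engine("sqlite://")


@app.get("/datasets/{dataset_id}")
def get_dataset(
    dataset_id: int,
    api_key: str | None = None,  # a genuinely optional parameter, listed first
    # `= None` type-checks because of the quirk described above; at runtime
    # FastAPI always supplies the engine returned by `expdb_database`.
    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
) -> dict[str, int | str]:
    return {"dataset_id": dataset_id, "dialect": expdb_db.dialect.name}
```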
PGijsbers authored Nov 3, 2023
1 parent bbdae4e commit 3247a11
Showing 9 changed files with 175 additions and 97 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -56,5 +56,5 @@ pythonpath = [
     "src"
 ]
 markers = [
-    "web: uses an internet connection"
+    "php: tests that compare directly to an old PHP endpoint"
 ]
57 changes: 22 additions & 35 deletions src/database/datasets.py
@@ -1,30 +1,14 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
 from typing import Any
 
-from config import load_database_configuration
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
 
 from database.meta import get_column_names
 
-_database_configuration = load_database_configuration()
-expdb_url = URL.create(**_database_configuration["expdb"])
-expdb = create_engine(
-    expdb_url,
-    echo=True,
-    pool_recycle=3600,
-)
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
-    openml_url,
-    echo=True,
-    pool_recycle=3600,
-)
-
 
-def get_dataset(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset")
-    with expdb.connect() as conn:
+def get_dataset(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -38,9 +22,9 @@ def get_dataset(dataset_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_file(file_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(openml, "file")
-    with openml.connect() as conn:
+def get_file(file_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "file")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -54,9 +38,9 @@ def get_file(file_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_tags(dataset_id: int) -> list[str]:
-    columns = get_column_names(expdb, "dataset_tag")
-    with expdb.connect() as conn:
+def get_tags(dataset_id: int, engine: Engine) -> list[str]:
+    columns = get_column_names(engine, "dataset_tag")
+    with engine.connect() as conn:
         rows = conn.execute(
             text(
                 """
@@ -70,9 +54,12 @@ def get_tags(dataset_id: int) -> list[str]:
     return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
 
 
-def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset_description")
-    with expdb.connect() as conn:
+def get_latest_dataset_description(
+    dataset_id: int,
+    engine: Engine,
+) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset_description")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -87,9 +74,9 @@ def get_latest_dataset_description(dataset_id: int) -> dict[str, Any] | None:
     return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
 
 
-def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "dataset_status")
-    with expdb.connect() as conn:
+def get_latest_status_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "dataset_status")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -106,9 +93,9 @@ def get_latest_status_update(dataset_id: int) -> dict[str, Any] | None:
     )
 
 
-def get_latest_processing_update(dataset_id: int) -> dict[str, Any] | None:
-    columns = get_column_names(expdb, "data_processed")
-    with expdb.connect() as conn:
+def get_latest_processing_update(dataset_id: int, engine: Engine) -> dict[str, Any] | None:
+    columns = get_column_names(engine, "data_processed")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
30 changes: 30 additions & 0 deletions src/database/setup.py
@@ -0,0 +1,30 @@
+from config import load_database_configuration
+from sqlalchemy import Engine, create_engine
+from sqlalchemy.engine import URL
+
+_user_engine = None
+_expdb_engine = None
+
+
+def _create_engine(database_name: str) -> Engine:
+    database_configuration = load_database_configuration()
+    db_url = URL.create(**database_configuration[database_name])
+    return create_engine(
+        db_url,
+        echo=True,
+        pool_recycle=3600,
+    )
+
+
+def user_database() -> Engine:
+    global _user_engine
+    if _user_engine is None:
+        _user_engine = _create_engine("openml")
+    return _user_engine
+
+
+def expdb_database() -> Engine:
+    global _expdb_engine
+    if _expdb_engine is None:
+        _expdb_engine = _create_engine("expdb")
+    return _expdb_engine
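
Because the engines now sit behind plain provider functions, the query helpers in src/database/datasets.py can also be driven directly, outside FastAPI. A minimal sketch, assuming the database configuration is in place:

```python
from database.datasets import get_dataset, get_tags
from database.setup import expdb_database

# The first call creates the engine; later calls reuse the module-level singleton.
engine = expdb_database()

dataset = get_dataset(1, engine)  # dict of column name -> value, or None
tags = get_tags(1, engine)  # list of tag strings
```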
23 changes: 6 additions & 17 deletions src/database/users.py
@@ -1,28 +1,17 @@
 from typing import Annotated
 
-from config import load_database_configuration
 from pydantic import StringConstraints
-from sqlalchemy import create_engine, text
-from sqlalchemy.engine import URL
+from sqlalchemy import Engine, text
 
 from database.meta import get_column_names
 
-_database_configuration = load_database_configuration()
-
-openml_url = URL.create(**_database_configuration["openml"])
-openml = create_engine(
-    openml_url,
-    echo=True,
-    pool_recycle=3600,
-)
-
 # Enforces str is 32 hexadecimal characters, does not check validity.
 APIKey = Annotated[str, StringConstraints(pattern=r"^[0-9a-fA-F]{32}$")]
 
 
-def get_user_id_for(*, api_key: APIKey) -> int | None:
-    columns = get_column_names(openml, "users")
-    with openml.connect() as conn:
+def get_user_id_for(*, api_key: APIKey, engine: Engine) -> int | None:
+    columns = get_column_names(engine, "users")
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
@@ -38,8 +27,8 @@ def get_user_id_for(*, api_key: APIKey) -> int | None:
     return int(dict(zip(columns, user, strict=True))["id"])
 
 
-def get_user_groups_for(*, user_id: int) -> list[int]:
-    with openml.connect() as conn:
+def get_user_groups_for(*, user_id: int, engine: Engine) -> list[int]:
+    with engine.connect() as conn:
         row = conn.execute(
             text(
                 """
31 changes: 18 additions & 13 deletions src/routers/datasets.py
@@ -2,7 +2,7 @@
 import http.client
 from collections import namedtuple
 from enum import IntEnum
-from typing import Any
+from typing import Annotated, Any
 
 from database.datasets import get_dataset as db_get_dataset
 from database.datasets import (
@@ -12,14 +12,16 @@
     get_latest_status_update,
     get_tags,
 )
+from database.setup import expdb_database, user_database
 from database.users import APIKey, get_user_groups_for, get_user_id_for
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
 from schemas.datasets.openml import (
     DatasetFileFormat,
     DatasetMetadata,
     DatasetStatus,
     Visibility,
 )
+from sqlalchemy import Engine
 
 router = APIRouter(prefix="/datasets", tags=["datasets"])
 
@@ -33,9 +35,9 @@ class DatasetError(IntEnum):
 processing_info = namedtuple("processing_info", ["date", "warning", "error"])
 
 
-def _get_processing_information(dataset_id: int) -> processing_info:
+def _get_processing_information(dataset_id: int, engine: Engine) -> processing_info:
     """Return processing information, if any. Otherwise, all fields `None`."""
-    if not (data_processed := get_latest_processing_update(dataset_id)):
+    if not (data_processed := get_latest_processing_update(dataset_id, engine)):
         return processing_info(date=None, warning=None, error=None)
 
     date_processed = data_processed["processing_date"]
@@ -51,6 +53,7 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
 
 def _user_has_access(
     dataset: dict[str, Any],
+    engine: Engine,
     api_key: APIKey | None = None,
 ) -> bool:
     """Determine if user of `api_key` has the right to view `dataset`."""
@@ -59,13 +62,13 @@ def _user_has_access(
     if not api_key:
         return False
 
-    if not (user_id := get_user_id_for(api_key=api_key)):
+    if not (user_id := get_user_id_for(api_key=api_key, engine=engine)):
         return False
 
     if user_id == dataset["uploader"]:
         return True
 
-    user_groups = get_user_groups_for(user_id=user_id)
+    user_groups = get_user_groups_for(user_id=user_id, engine=engine)
     ADMIN_GROUP = 1
     return ADMIN_GROUP in user_groups
 
@@ -106,26 +109,28 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
 def get_dataset(
     dataset_id: int,
     api_key: APIKey | None = None,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
 ) -> DatasetMetadata:
-    if not (dataset := db_get_dataset(dataset_id)):
+    if not (dataset := db_get_dataset(dataset_id, expdb_db)):
         error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
         raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)
 
-    if not _user_has_access(dataset, api_key):
+    if not _user_has_access(dataset=dataset, api_key=api_key, engine=user_db):
         error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
         raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)
 
-    if not (dataset_file := get_file(dataset["file_id"])):
+    if not (dataset_file := get_file(dataset["file_id"], user_db)):
         error = _format_error(
             code=DatasetError.NO_DATA_FILE,
             message="No data file found",
         )
         raise HTTPException(status_code=http.client.PRECONDITION_FAILED, detail=error)
 
-    tags = get_tags(dataset_id)
-    description = get_latest_dataset_description(dataset_id)
-    processing_result = _get_processing_information(dataset_id)
-    status = get_latest_status_update(dataset_id)
+    tags = get_tags(dataset_id, expdb_db)
+    description = get_latest_dataset_description(dataset_id, expdb_db)
+    processing_result = _get_processing_information(dataset_id, expdb_db)
+    status = get_latest_status_update(dataset_id, expdb_db)
 
     status_ = DatasetStatus(status["status"]) if status else DatasetStatus.IN_PREPARATION
 
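
Because the endpoints now receive their engines via `Depends`, tests can substitute a disposable database without patching module globals. A sketch using FastAPI's `dependency_overrides`, with the application module and test engine as assumptions:

```python
from fastapi.testclient import TestClient
from sqlalchemy import create_engine

from database.setup import expdb_database, user_database
from main import app  # hypothetical: wherever the FastAPI app is constructed

# A real test would point this at a schema-loaded test database.
test_engine = create_engine("sqlite://")

app.dependency_overrides[expdb_database] = lambda: test_engine
app.dependency_overrides[user_database] = lambda: test_engine

client = TestClient(app)
response = client.get("/datasets/1")
```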
18 changes: 15 additions & 3 deletions src/routers/mldcat_ap/dataset.py
@@ -1,5 +1,9 @@
-from fastapi import APIRouter
+from typing import Annotated
+
+from database.setup import expdb_database, user_database
+from fastapi import APIRouter, Depends
 from schemas.datasets.mldcat_ap import JsonLDGraph, convert_to_mldcat_ap
+from sqlalchemy import Engine
 
 from routers.datasets import get_dataset
 
@@ -10,6 +14,14 @@
     path="/{dataset_id}",
     description="Get meta-data for dataset with ID `dataset_id`.",
 )
-def get_mldcat_ap_dataset(dataset_id: int) -> JsonLDGraph:
-    openml_dataset = get_dataset(dataset_id)
+def get_mldcat_ap_dataset(
+    dataset_id: int,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
+) -> JsonLDGraph:
+    openml_dataset = get_dataset(
+        dataset_id=dataset_id,
+        user_db=user_db,
+        expdb_db=expdb_db,
+    )
     return convert_to_mldcat_ap(openml_dataset)
15 changes: 12 additions & 3 deletions src/routers/old/datasets.py
@@ -3,10 +3,12 @@
 new API, and are easily removed later.
 """
 import http.client
-from typing import Any
+from typing import Annotated, Any
 
+from database.setup import expdb_database, user_database
 from database.users import APIKey
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import Engine
 
 from routers.datasets import get_dataset
 
@@ -20,9 +22,16 @@
 def get_dataset_wrapped(
     dataset_id: int,
     api_key: APIKey | None = None,
+    user_db: Annotated[Engine, Depends(user_database)] = None,
+    expdb_db: Annotated[Engine, Depends(expdb_database)] = None,
 ) -> dict[str, dict[str, Any]]:
     try:
-        dataset = get_dataset(dataset_id, api_key).model_dump(by_alias=True)
+        dataset = get_dataset(
+            user_db=user_db,
+            expdb_db=expdb_db,
+            dataset_id=dataset_id,
+            api_key=api_key,
+        ).model_dump(by_alias=True)
     except HTTPException as e:
         raise HTTPException(
             status_code=http.client.PRECONDITION_FAILED,