Improve re-usability of subcomponents (#116)
* Move auxiliary functions and types to separate files
* Refactor shared logic for user access to datasets
* Move some queries to the db module
* Remove a level of nesting in the API response
PGijsbers authored Nov 29, 2023
1 parent 64d0be5 commit 319fe5e
Showing 7 changed files with 164 additions and 145 deletions.
15 changes: 15 additions & 0 deletions src/core/access.py
@@ -0,0 +1,15 @@
from typing import Any

from database.users import User, UserGroup
from schemas.datasets.openml import Visibility


def _user_has_access(
dataset: dict[str, Any],
user: User | None = None,
) -> bool:
"""Determine if `user` has the right to view `dataset`."""
is_public = dataset["visibility"] == Visibility.PUBLIC
return is_public or (
user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
)
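
A minimal usage sketch of the new helper (not part of the diff). It assumes `Visibility.PUBLIC` compares equal to the stored string "public", as the replaced check in `qualities.py` further below suggests:

from core.access import _user_has_access

public_dataset = {"visibility": "public", "uploader": 2}
private_dataset = {"visibility": "private", "uploader": 2}

assert _user_has_access(public_dataset)       # public data needs no user
assert not _user_has_access(private_dataset)  # anonymous users see no private data
# With a user, access additionally requires user.user_id == dataset["uploader"]
# or UserGroup.ADMIN in user.groups.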
7 changes: 7 additions & 0 deletions src/core/errors.py
@@ -0,0 +1,7 @@
from enum import IntEnum


class DatasetError(IntEnum):
NOT_FOUND = 111
NO_ACCESS = 112
NO_DATA_FILE = 113
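
Because `DatasetError` is an `IntEnum` and the codebase targets Python 3.11+ (this diff uses `StrEnum` elsewhere), `str()` on a member yields the bare number, which is what `_format_error` below relies on. A quick check, assuming Python 3.11 semantics:

from core.errors import DatasetError

str(DatasetError.NOT_FOUND)  # "111" on Python >= 3.11, where IntEnum.__str__ is int's
int(DatasetError.NO_ACCESS)  # 112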
40 changes: 40 additions & 0 deletions src/core/formatting.py
@@ -0,0 +1,40 @@
import html
from typing import Any

from schemas.datasets.openml import DatasetFileFormat

from core.errors import DatasetError


def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
"""Formatter for JSON bodies of OpenML error codes."""
return {"code": str(code), "message": message}


def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
if dataset["format"].lower() != DatasetFileFormat.ARFF:
return None

minio_base_url = "https://openml1.win.tue.nl"
return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"


def _format_dataset_url(dataset: dict[str, Any]) -> str:
base_url = "https://test.openml.org"
filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"


def _safe_unquote(text: str | None) -> str | None:
"""Remove any open and closing quotes and return the remainder if non-empty."""
if not text:
return None
return text.strip("'\"") or None


def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
"""Return comma-separated values in `text` as list, optionally remove quotes."""
if not text:
return []
chars_to_strip = "'\"\t " if unquote_items else "\t "
return [item.strip(chars_to_strip) for item in text.split(",")]
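
Illustrative inputs and outputs for the two string helpers (not part of the diff):

from core.formatting import _csv_as_list, _safe_unquote

_safe_unquote("'scikit-learn'")            # -> "scikit-learn"
_safe_unquote("''")                        # -> None (nothing left after unquoting)
_csv_as_list("'a', 'b', c")                # -> ["a", "b", "c"]
_csv_as_list("a, b", unquote_items=False)  # -> ["a", "b"] (tabs/spaces still stripped)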
40 changes: 39 additions & 1 deletion src/database/datasets.py
@@ -1,11 +1,49 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
from typing import Any
from collections import defaultdict
from typing import Any, Iterable

from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text

from database.meta import get_column_names


def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
rows = connection.execute(
text(
"""
SELECT `quality`,`value`
FROM data_quality
WHERE `data`=:dataset_id
""",
),
parameters={"dataset_id": dataset_id},
)
return [Quality(name=row.quality, value=row.value) for row in rows]


def get_qualities_for_datasets(
dataset_ids: Iterable[int],
qualities: Iterable[str],
connection: Connection,
) -> dict[int, list[Quality]]:
qualities_filter = ",".join(f"'{q}'" for q in qualities)
dids = ",".join(str(did) for did in dataset_ids)
qualities_query = text(
f"""
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - similar to above, no user input
)
rows = connection.execute(qualities_query)
qualities_by_id = defaultdict(list)
for did, quality, value in rows:
if value is not None:
qualities_by_id[did].append(Quality(name=quality, value=value))
return dict(qualities_by_id)


def list_all_qualities(connection: Connection) -> list[str]:
# The current implementation only fetches *used* qualities, otherwise you should
# query: SELECT `name` FROM `quality` WHERE `type`='DataQuality'
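
A usage sketch for the relocated query helpers; the engine URL is a placeholder, and only the `Connection` handling is taken from the code above:

from sqlalchemy import create_engine

from database.datasets import get_qualities_for_dataset, get_qualities_for_datasets

engine = create_engine("mysql+pymysql://user:pass@localhost/expdb")  # placeholder DSN
with engine.connect() as connection:
    # Every stored quality for one dataset:
    single = get_qualities_for_dataset(61, connection)
    # Selected qualities for several datasets, keyed by dataset id:
    many = get_qualities_for_datasets(
        dataset_ids=[61, 62],
        qualities=["NumberOfInstances", "NumberOfFeatures"],
        connection=connection,
    )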
129 changes: 35 additions & 94 deletions src/routers/openml/datasets.py
@@ -2,25 +2,35 @@
We add separate endpoints for old-style JSON responses, so they don't clutter the schema of the
new API, and are easily removed later.
"""
import html
import http.client
from collections import namedtuple
from enum import IntEnum, StrEnum
from typing import Annotated, Any, Literal

import re
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, NamedTuple

from core.access import _user_has_access
from core.errors import DatasetError
from core.formatting import (
_csv_as_list,
_format_dataset_url,
_format_error,
_format_parquet_url,
_safe_unquote,
)
from database.datasets import get_dataset as db_get_dataset
from database.datasets import (
get_file,
get_latest_dataset_description,
get_latest_processing_update,
get_latest_status_update,
get_qualities_for_datasets,
get_tags,
)
from database.datasets import tag_dataset as db_tag_dataset
from database.users import APIKey, User, UserGroup, get_user_groups_for, get_user_id_for
from database.users import User, UserGroup
from fastapi import APIRouter, Body, Depends, HTTPException
from schemas.datasets.openml import DatasetFileFormat, DatasetMetadata, DatasetStatus, Visibility
from sqlalchemy import Connection
from schemas.datasets.openml import DatasetMetadata, DatasetStatus
from sqlalchemy import Connection, text

from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
@@ -97,9 +107,7 @@ def list_datasets(
status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
user: Annotated[User | None, Depends(fetch_user)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> dict[Literal["data"], dict[Literal["dataset"], list[dict[str, Any]]]]:
from sqlalchemy import text

) -> list[dict[str, Any]]:
current_status = text(
"""
SELECT ds1.`did`, ds1.`status`
@@ -150,8 +158,6 @@ def list_datasets(
else ""
)

import re

def quality_clause(quality: str, range_: str | None) -> str:
if not range_:
return ""
@@ -230,96 +236,31 @@ def quality_clause(quality: str, range_: str | None) -> str:
"NumberOfNumericFeatures",
"NumberOfSymbolicFeatures",
]
qualities_filter = ",".join(f"'{q}'" for q in qualities_to_show)
dids = ",".join(str(did) for did in datasets)
qualities = text(
f"""
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - similar to above, no user input
qualities_by_dataset = get_qualities_for_datasets(
dataset_ids=datasets.keys(),
qualities=qualities_to_show,
connection=expdb_db,
)
qualities = expdb_db.execute(qualities)
for did, quality, value in qualities:
if value is not None:
datasets[did]["quality"].append({"name": quality, "value": str(value)})
return {"data": {"dataset": list(datasets.values())}}

for did, qualities in qualities_by_dataset.items():
datasets[did]["quality"] = qualities
return list(datasets.values())

class DatasetError(IntEnum):
NOT_FOUND = 111
NO_ACCESS = 112
NO_DATA_FILE = 113

class ProcessingInformation(NamedTuple):
date: datetime | None
warning: str | None
error: str | None

processing_info = namedtuple("processing_info", ["date", "warning", "error"])


def _get_processing_information(dataset_id: int, connection: Connection) -> processing_info:
def _get_processing_information(dataset_id: int, connection: Connection) -> ProcessingInformation:
"""Return processing information, if any. Otherwise, all fields `None`."""
if not (data_processed := get_latest_processing_update(dataset_id, connection)):
return processing_info(date=None, warning=None, error=None)
return ProcessingInformation(date=None, warning=None, error=None)

date_processed = data_processed["processing_date"]
warning = data_processed["warning"].strip() if data_processed["warning"] else None
error = data_processed["error"].strip() if data_processed["error"] else None
return processing_info(date=date_processed, warning=warning, error=error)


def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
"""Formatter for JSON bodies of OpenML error codes."""
return {"code": str(code), "message": message}


def _user_has_access(
dataset: dict[str, Any],
connection: Connection,
api_key: APIKey | None = None,
) -> bool:
"""Determine if user of `api_key` has the right to view `dataset`."""
if dataset["visibility"] == Visibility.PUBLIC:
return True
if not api_key:
return False

if not (user_id := get_user_id_for(api_key=api_key, connection=connection)):
return False

if user_id == dataset["uploader"]:
return True

user_groups = get_user_groups_for(user_id=user_id, connection=connection)
ADMIN_GROUP = 1
return ADMIN_GROUP in user_groups


def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
if dataset["format"].lower() != DatasetFileFormat.ARFF:
return None

minio_base_url = "https://openml1.win.tue.nl"
return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"


def _format_dataset_url(dataset: dict[str, Any]) -> str:
base_url = "https://test.openml.org"
filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"


def _safe_unquote(text: str | None) -> str | None:
"""Remove any open and closing quotes and return the remainder if non-empty."""
if not text:
return None
return text.strip("'\"") or None


def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
"""Return comma-separated values in `text` as list, optionally remove quotes."""
if not text:
return []
chars_to_strip = "'\"\t " if unquote_items else "\t "
return [item.strip(chars_to_strip) for item in text.split(",")]
return ProcessingInformation(date=date_processed, warning=warning, error=error)


@router.get(
@@ -328,15 +269,15 @@ def _csv_as_list(text: str | None, *, unquote_items: bool = True) -> list[str]:
)
def get_dataset(
dataset_id: int,
api_key: APIKey | None = None,
user: Annotated[User | None, Depends(fetch_user)] = None,
user_db: Annotated[Connection, Depends(userdb_connection)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> DatasetMetadata:
if not (dataset := db_get_dataset(dataset_id, expdb_db)):
error = _format_error(code=DatasetError.NOT_FOUND, message="Unknown dataset")
raise HTTPException(status_code=http.client.NOT_FOUND, detail=error)

if not _user_has_access(dataset=dataset, connection=user_db, api_key=api_key):
if not _user_has_access(dataset=dataset, user=user):
error = _format_error(code=DatasetError.NO_ACCESS, message="No access granted")
raise HTTPException(status_code=http.client.FORBIDDEN, detail=error)

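
One visible effect of this file's changes is the flatter list response; the field values below are illustrative, not taken from the diff:

# before: {"data": {"dataset": [{"did": 1, "name": "anneal", "quality": [...], ...}]}}
# after:  [{"did": 1, "name": "anneal", "quality": [...], ...}]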
33 changes: 8 additions & 25 deletions src/routers/openml/qualities.py
@@ -1,14 +1,15 @@
import http.client
from typing import Annotated, Any, Literal
from typing import Annotated, Literal

from database.datasets import get_dataset, list_all_qualities
from database.users import User, UserGroup
from core.access import _user_has_access
from core.errors import DatasetError
from database.datasets import get_dataset, get_qualities_for_dataset, list_all_qualities
from database.users import User
from fastapi import APIRouter, Depends, HTTPException
from schemas.datasets.openml import Quality
from sqlalchemy import Connection, text
from sqlalchemy import Connection

from routers.dependencies import expdb_connection, fetch_user
from routers.openml.datasets import DatasetError

router = APIRouter(prefix="/datasets", tags=["datasets"])

@@ -25,37 +26,19 @@ def list_qualities(
}


def _user_can_see_dataset(dataset: dict[str, Any], user: User) -> bool:
if dataset["visibility"] == "public":
return True
return user is not None and (
dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups
)


@router.get("/qualities/{dataset_id}")
def get_qualities(
dataset_id: int,
user: Annotated[User, Depends(fetch_user)],
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> list[Quality]:
dataset = get_dataset(dataset_id, expdb)
if not dataset or not _user_can_see_dataset(dataset, user):
if not dataset or not _user_has_access(dataset, user):
raise HTTPException(
status_code=http.client.PRECONDITION_FAILED,
detail={"code": DatasetError.NO_DATA_FILE, "message": "Unknown dataset"},
) from None
rows = expdb.execute(
text(
"""
SELECT `quality`,`value`
FROM data_quality
WHERE `data`=:dataset_id
""",
),
parameters={"dataset_id": dataset_id},
)
return [Quality(name=row.quality, value=row.value) for row in rows]
return get_qualities_for_dataset(dataset_id, expdb)
# The PHP API provided (sometimes) helpful error messages
# if not qualities:
# check if dataset exists: error 360
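
With the helper in place, the endpoint returns the `Quality` list directly. A hypothetical response for `GET /datasets/qualities/61` (values illustrative):

[
  {"name": "NumberOfInstances", "value": "150"},
  {"name": "NumberOfFeatures", "value": "5"}
]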