Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor db layer to access results in a consistent way #137

Merged
merged 8 commits into the base branch from the pull-request branch
Jan 3, 2024
9 changes: 4 additions & 5 deletions src/core/access.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Any

from database.users import User, UserGroup
from schemas.datasets.openml import Visibility
from sqlalchemy.engine import Row


def _user_has_access(
dataset: dict[str, Any],
dataset: Row,
user: User | None = None,
) -> bool:
"""Determine if `user` has the right to view `dataset`."""
is_public = dataset["visibility"] == Visibility.PUBLIC
is_public = dataset.visibility == Visibility.PUBLIC
return is_public or (
user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
)
14 changes: 7 additions & 7 deletions src/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import html
from typing import Any

from schemas.datasets.openml import DatasetFileFormat
from sqlalchemy.engine import Row

from core.errors import DatasetError

Expand All @@ -20,18 +20,18 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
return {"code": str(code), "message": message}


def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
if dataset["format"].lower() != DatasetFileFormat.ARFF:
def _format_parquet_url(dataset: Row) -> str | None:
if dataset.format.lower() != DatasetFileFormat.ARFF:
return None

minio_base_url = "https://openml1.win.tue.nl"
return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
return f"{minio_base_url}/dataset{dataset.did}/dataset_{dataset.did}.pq"


def _format_dataset_url(dataset: dict[str, Any]) -> str:
def _format_dataset_url(dataset: Row) -> str:
base_url = "https://test.openml.org"
filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
return f"{base_url}/data/v1/download/{dataset.file_id}/{filename}"


def _safe_unquote(text: str | None) -> str | None:
Expand Down
39 changes: 16 additions & 23 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
import datetime
from collections import defaultdict
from typing import Any, Iterable
from typing import Iterable

from schemas.datasets.openml import Feature, Quality
from sqlalchemy import Connection, text

from database.meta import get_column_names
from sqlalchemy.engine import Row


def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
Expand All @@ -23,19 +22,20 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
return [Quality(name=row.quality, value=row.value) for row in rows]


def get_qualities_for_datasets(
def _get_qualities_for_datasets(
dataset_ids: Iterable[int],
qualities: Iterable[str],
connection: Connection,
) -> dict[int, list[Quality]]:
"""Don't call with user-provided input, as query is not parameterized."""
qualities_filter = ",".join(f"'{q}'" for q in qualities)
dids = ",".join(str(did) for did in dataset_ids)
qualities_query = text(
f"""
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - similar to above, no user input
""", # nosec - dids and qualities are not user-provided
)
rows = connection.execute(qualities_query)
qualities_by_id = defaultdict(list)
Expand All @@ -59,8 +59,7 @@ def list_all_qualities(connection: Connection) -> list[str]:
return [quality.quality for quality in qualities]


def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "dataset")
def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -71,11 +70,10 @@ def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | Non
),
parameters={"dataset_id": dataset_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "file")
def get_file(file_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -86,11 +84,10 @@ def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
),
parameters={"file_id": file_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_tags(dataset_id: int, connection: Connection) -> list[str]:
columns = get_column_names(connection, "dataset_tag")
rows = connection.execute(
text(
"""
Expand All @@ -101,7 +98,7 @@ def get_tags(dataset_id: int, connection: Connection) -> list[str]:
),
parameters={"dataset_id": dataset_id},
)
return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
return [row.tag for row in rows]


def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection) -> None:
Expand All @@ -123,8 +120,7 @@ def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection)
def get_latest_dataset_description(
dataset_id: int,
connection: Connection,
) -> dict[str, Any] | None:
columns = get_column_names(connection, "dataset_description")
) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -136,10 +132,10 @@ def get_latest_dataset_description(
),
parameters={"dataset_id": dataset_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -151,11 +147,10 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st
),
parameters={"dataset_id": dataset_id},
)
return next(row.mappings(), None)
return row.first()


def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "data_processed")
def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -167,9 +162,7 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> dic
),
parameters={"dataset_id": dataset_id},
)
return (
dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None
)
return row.one_or_none()


def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
Expand Down
21 changes: 12 additions & 9 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from typing import Any, Iterable
from typing import Sequence, cast

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure
from sqlalchemy import Connection, CursorResult, text
from sqlalchemy import Connection, Row, text


def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]:
return connection.execute(
text(
"""
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
connection.execute(
text(
"""
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
),
parameters={"function_type": function_type},
),
parameters={"function_type": function_type},
).all(),
)


def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]:
def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
text(
"""
Expand Down
36 changes: 21 additions & 15 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from typing import Any
from typing import Sequence, cast

from sqlalchemy import Connection, CursorResult, text
from sqlalchemy import Connection, Row, text


def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
text(
"""
Expand All @@ -30,20 +33,23 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return [tag.tag for tag in tag_rows]


def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
def get_flow(flow_id: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
Expand All @@ -53,4 +59,4 @@ def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
""",
),
parameters={"flow_id": flow_id},
)
).one_or_none()
17 changes: 0 additions & 17 deletions src/database/meta.py

This file was deleted.

22 changes: 11 additions & 11 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import re
from datetime import datetime
from typing import cast
from typing import Sequence, cast

from schemas.study import CreateStudy, StudyType
from sqlalchemy import Connection, Row, text

from database.users import User


def get_study_by_id(study_id: int, connection: Connection) -> Row:
def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -18,10 +18,10 @@ def get_study_by_id(study_id: int, connection: Connection) -> Row:
""",
),
parameters={"study_id": study_id},
).fetchone()
).one_or_none()


def get_study_by_alias(alias: str, connection: Connection) -> Row:
def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -31,13 +31,13 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row:
""",
),
parameters={"study_id": alias},
).fetchone()
).one_or_none()


def get_study_data(study: Row, expdb: Connection) -> list[Row]:
def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
if study.type_ == StudyType.TASK:
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand All @@ -47,10 +47,10 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
""",
),
parameters={"study_id": study.id},
).fetchall(),
).all(),
)
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand All @@ -68,7 +68,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
""",
),
parameters={"study_id": study.id},
).fetchall(),
).all(),
)


Expand Down Expand Up @@ -96,7 +96,7 @@ def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
"benchmark_suite": study.benchmark_suite,
},
)
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).fetchone()
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
return cast(int, study_id)


Expand Down
Loading