-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add boilerplate for `studies` endpoints * Support modern-style task studies * Add run study support, but it is untested It is untested because the test database currently does not have any runs or run studies. * Add get study by alias * Move SQL queries to database submodule * Add migration test * Add information on the lack of support for legacy studies * Add flows and setups are now also always returned
- Loading branch information
Showing
9 changed files
with
698 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import cast | ||
|
||
from schemas.study import StudyType | ||
from sqlalchemy import Connection, Row, text | ||
|
||
|
||
def get_study_by_id(study_id: int, connection: Connection) -> Row: | ||
return connection.execute( | ||
text( | ||
""" | ||
SELECT *, main_entity_type as type_ | ||
FROM study | ||
WHERE id = :study_id | ||
""", | ||
), | ||
parameters={"study_id": study_id}, | ||
).fetchone() | ||
|
||
|
||
def get_study_by_alias(alias: str, connection: Connection) -> Row: | ||
return connection.execute( | ||
text( | ||
""" | ||
SELECT *, main_entity_type as type_ | ||
FROM study | ||
WHERE alias = :study_id | ||
""", | ||
), | ||
parameters={"study_id": alias}, | ||
).fetchone() | ||
|
||
|
||
def get_study_data(study: Row, expdb: Connection) -> list[Row]: | ||
if study.type_ == StudyType.TASK: | ||
return cast( | ||
list[Row], | ||
expdb.execute( | ||
text( | ||
""" | ||
SELECT ts.task_id as task_id, ti.value as data_id | ||
FROM task_study as ts LEFT JOIN task_inputs ti ON ts.task_id = ti.task_id | ||
WHERE ts.study_id = :study_id AND ti.input = 'source_data' | ||
""", | ||
), | ||
parameters={"study_id": study.id}, | ||
).fetchall(), | ||
) | ||
return cast( | ||
list[Row], | ||
expdb.execute( | ||
text( | ||
""" | ||
SELECT | ||
rs.run_id as run_id, | ||
run.task_id as task_id, | ||
run.setup as setup_id, | ||
ti.value as data_id, | ||
setup.implementation_id as flow_id | ||
FROM run_study as rs | ||
JOIN run ON run.rid = rs.run_id | ||
JOIN algorithm_setup as setup ON setup.sid = run.setup | ||
JOIN task_inputs as ti ON ti.task_id = run.task_id | ||
WHERE rs.study_id = :study_id AND ti.input = 'source_data' | ||
""", | ||
), | ||
parameters={"study_id": study.id}, | ||
).fetchall(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import http.client | ||
from typing import Annotated | ||
|
||
from core.formatting import _str_to_bool | ||
from database.studies import get_study_by_alias, get_study_by_id, get_study_data | ||
from database.users import User, UserGroup | ||
from fastapi import APIRouter, Depends, HTTPException | ||
from schemas.core import Visibility | ||
from schemas.study import Study, StudyType | ||
from sqlalchemy import Connection, Row | ||
|
||
from routers.dependencies import expdb_connection, fetch_user | ||
|
||
router = APIRouter(prefix="/studies", tags=["studies"]) | ||
|
||
|
||
def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb: Connection) -> Row: | ||
if isinstance(id_or_alias, int) or id_or_alias.isdigit(): | ||
study = get_study_by_id(int(id_or_alias), expdb) | ||
else: | ||
study = get_study_by_alias(id_or_alias, expdb) | ||
|
||
if study is None: | ||
raise HTTPException(status_code=http.client.NOT_FOUND, detail="Study not found.") | ||
if study.visibility == Visibility.PRIVATE: | ||
if user is None: | ||
raise HTTPException(status_code=http.client.UNAUTHORIZED, detail="Study is private.") | ||
if study.creator != user.user_id and UserGroup.ADMIN not in user.groups: | ||
raise HTTPException(status_code=http.client.FORBIDDEN, detail="Study is private.") | ||
if _str_to_bool(study.legacy): | ||
raise HTTPException( | ||
status_code=http.client.GONE, | ||
detail="Legacy studies are no longer supported", | ||
) | ||
|
||
return study | ||
|
||
|
||
@router.get("/{alias_or_id}") | ||
def get_study( | ||
alias_or_id: int | str, | ||
user: Annotated[User | None, Depends(fetch_user)] = None, | ||
expdb: Annotated[Connection, Depends(expdb_connection)] = None, | ||
) -> Study: | ||
study = _get_study_raise_otherwise(alias_or_id, user, expdb) | ||
study_data = get_study_data(study, expdb) | ||
return Study( | ||
id_=study.id, | ||
name=study.name, | ||
alias=study.alias, | ||
main_entity_type=study.type_, | ||
description=study.description, | ||
visibility=study.visibility, | ||
status=study.status, | ||
creation_date=study.creation_date, | ||
creator=study.creator, | ||
data_ids=[row.data_id for row in study_data], | ||
task_ids=[row.task_id for row in study_data], | ||
run_ids=[row.run_id for row in study_data] if study.type_ == StudyType.RUN else [], | ||
flow_ids=[row.flow_id for row in study_data] if study.type_ == StudyType.RUN else [], | ||
setup_ids=[row.setup_id for row in study_data] if study.type_ == StudyType.RUN else [], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from enum import StrEnum, auto | ||
|
||
|
||
class Visibility(StrEnum): | ||
PUBLIC = auto() | ||
PRIVATE = auto() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from datetime import datetime | ||
from enum import StrEnum, auto | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from schemas.core import Visibility | ||
|
||
|
||
class StudyType(StrEnum): | ||
RUN = auto() | ||
TASK = auto() | ||
|
||
|
||
class StudyStatus(StrEnum): | ||
ACTIVE = auto() | ||
DEACTIVATED = auto() | ||
IN_PREPARATION = auto() | ||
|
||
|
||
class Study(BaseModel): | ||
id_: int = Field(serialization_alias="id") | ||
name: str | ||
alias: str | None | ||
main_entity_type: StudyType | ||
description: str | ||
visibility: Visibility | ||
status: StudyStatus | ||
creation_date: datetime | ||
creator: int | ||
task_ids: list[int] | ||
run_ids: list[int] | ||
data_ids: list[int] | ||
setup_ids: list[int] | ||
flow_ids: list[int] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import deepdiff | ||
import httpx | ||
import pytest | ||
from core.conversions import nested_num_to_str, nested_remove_nones | ||
from starlette.testclient import TestClient | ||
|
||
|
||
@pytest.mark.php() | ||
def test_get_study_equal(py_api: TestClient, php_api: httpx.Client) -> None: | ||
new = py_api.get("/studies/1") | ||
old = php_api.get("/study/1") | ||
assert new.status_code == old.status_code | ||
|
||
new = new.json() | ||
# New implementation is typed | ||
new = nested_num_to_str(new) | ||
# New implementation has same fields even if empty | ||
new = nested_remove_nones(new) | ||
new["tasks"] = {"task_id": new.pop("task_ids")} | ||
new["data"] = {"data_id": new.pop("data_ids")} | ||
if runs := new.pop("run_ids", None): | ||
new["runs"] = {"run_id": runs} | ||
if flows := new.pop("flow_ids", None): | ||
new["flows"] = {"flow_id": flows} | ||
if setups := new.pop("setup_ids", None): | ||
new["setup"] = {"setup_id": setups} | ||
|
||
# New implementation is not nested | ||
new = {"study": new} | ||
difference = deepdiff.diff.DeepDiff( | ||
new, | ||
old.json(), | ||
ignore_order=True, | ||
ignore_numeric_type_changes=True, | ||
) | ||
assert not difference |
Oops, something went wrong.