Skip to content

Commit

Permalink
Add/study attach detach (#135)
Browse files Browse the repository at this point in the history
* Towards attach/detach

* Add study_attach endpoint

* Add study_attach function that supports attaching tasks
  • Loading branch information
PGijsbers authored Jan 3, 2024
1 parent bac0823 commit 78073cd
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 3 deletions.
41 changes: 41 additions & 0 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from datetime import datetime
from typing import cast

Expand Down Expand Up @@ -121,3 +122,43 @@ def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connectio
),
parameters={"study_id": study_id, "run_id": run_id, "user_id": user.user_id},
)


def attach_tasks_to_study(
study_id: int,
task_ids: list[int],
user: User,
connection: Connection,
) -> None:
to_link = [(study_id, task_id, user.user_id) for task_id in task_ids]
try:
connection.execute(
text(
"""
INSERT INTO task_study (study_id, task_id, uploader)
VALUES (:study_id, :task_id, :user_id)
""",
),
parameters=[{"study_id": s, "task_id": t, "user_id": u} for s, t, u in to_link],
)
except Exception as e:
(msg,) = e.args
if match := re.search(r"Duplicate entry '(\d+)-(\d+)' for key 'task_study.PRIMARY'", msg):
msg = f"Task {match.group(2)} is already attached to study {match.group(1)}."
elif "a foreign key constraint fails" in msg:
# The message and exception have no information about which task is invalid.
msg = "One or more of the tasks do not exist."
elif "Out of range value for column 'task_id'" in msg:
msg = "One specified ids is not in the valid range of task ids."
else:
raise
raise ValueError(msg) from e


def attach_runs_to_study(
study_id: int, # noqa: ARG001
task_ids: list[int], # noqa: ARG001
user: User, # noqa: ARG001
connection: Connection, # noqa: ARG001
) -> None:
raise NotImplementedError
53 changes: 50 additions & 3 deletions src/routers/openml/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from core.formatting import _str_to_bool
from database.studies import (
attach_run_to_study,
attach_runs_to_study,
attach_task_to_study,
attach_tasks_to_study,
get_study_by_alias,
get_study_by_id,
get_study_data,
)
from database.studies import create_study as db_create_study
from database.users import User, UserGroup
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel
from schemas.core import Visibility
from schemas.study import CreateStudy, Study, StudyType
from schemas.study import CreateStudy, Study, StudyStatus, StudyType
from sqlalchemy import Connection, Row

from routers.dependencies import expdb_connection, fetch_user
Expand All @@ -31,7 +34,10 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb:
raise HTTPException(status_code=http.client.NOT_FOUND, detail="Study not found.")
if study.visibility == Visibility.PRIVATE:
if user is None:
raise HTTPException(status_code=http.client.UNAUTHORIZED, detail="Study is private.")
raise HTTPException(
status_code=http.client.UNAUTHORIZED,
detail="Must authenticate for private study.",
)
if study.creator != user.user_id and UserGroup.ADMIN not in user.groups:
raise HTTPException(status_code=http.client.FORBIDDEN, detail="Study is private.")
if _str_to_bool(study.legacy):
Expand All @@ -42,6 +48,46 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb:
return study


class AttachDetachResponse(BaseModel):
study_id: int
main_entity_type: StudyType


@router.post("/attach")
def attach_to_study(
study_id: Annotated[int, Body()],
entity_ids: Annotated[list[int], Body()],
user: Annotated[User | None, Depends(fetch_user)] = None,
expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> AttachDetachResponse:
if user is None:
raise HTTPException(status_code=http.client.UNAUTHORIZED, detail="User not found.")
study = _get_study_raise_otherwise(study_id, user, expdb)
# PHP lets *anyone* edit *any* study. We're not going to do that.
if study.creator != user.user_id and UserGroup.ADMIN not in user.groups:
raise HTTPException(
status_code=http.client.FORBIDDEN,
detail="Study can only be edited by its creator.",
)
if study.status != StudyStatus.IN_PREPARATION:
raise HTTPException(
status_code=http.client.FORBIDDEN,
detail="Study can only be edited while in preparation.",
)

# We let the database handle the constraints on whether
# the entity is already attached or if it even exists.
attach = attach_tasks_to_study if study.type_ == StudyType.TASK else attach_runs_to_study
try:
attach(study_id, entity_ids, user, expdb)
except ValueError as e:
raise HTTPException(
status_code=http.client.CONFLICT,
detail=str(e),
) from None
return AttachDetachResponse(study_id=study_id, main_entity_type=study.type_)


@router.post("/")
def create_study(
study: CreateStudy,
Expand Down Expand Up @@ -88,6 +134,7 @@ def get_study(
study = _get_study_raise_otherwise(alias_or_id, user, expdb)
study_data = get_study_data(study, expdb)
return Study(
_legacy=_str_to_bool(study.legacy),
id_=study.id,
name=study.name,
alias=study.alias,
Expand Down
1 change: 1 addition & 0 deletions src/schemas/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class StudyStatus(StrEnum):


class Study(BaseModel):
legacy: bool = Field(default=False, exclude=True)
id_: int = Field(serialization_alias="id")
name: str
alias: str | None
Expand Down
76 changes: 76 additions & 0 deletions tests/routers/openml/study_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import http.client
from datetime import datetime

import httpx
from schemas.study import StudyType
from sqlalchemy import Connection, text
from starlette.testclient import TestClient


Expand Down Expand Up @@ -494,3 +498,75 @@ def test_create_task_study(py_api: TestClient) -> None:
)
assert creation_date.date() == datetime.now().date()
assert new_study == expected


def _attach_tasks_to_study(
study_id: int,
task_ids: list[int],
api_key: str,
py_api: TestClient,
expdb_test: Connection,
) -> httpx.Response:
# Adding requires the study to be in preparation,
# but the current snapshot has no in-preparation studies.
expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1"))
return py_api.post(
f"/studies/attach?api_key={api_key}",
json={"study_id": study_id, "entity_ids": task_ids},
)


def test_attach_task_to_study(py_api: TestClient, expdb_test: Connection) -> None:
response = _attach_tasks_to_study(
study_id=1,
task_ids=[2, 3, 4],
api_key="AD000000000000000000000000000000",
py_api=py_api,
expdb_test=expdb_test,
)
assert response.status_code == http.client.OK
assert response.json() == {"study_id": 1, "main_entity_type": StudyType.TASK}


def test_attach_task_to_study_needs_owner(py_api: TestClient, expdb_test: Connection) -> None:
expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1"))
response = _attach_tasks_to_study(
study_id=1,
task_ids=[2, 3, 4],
api_key="00000000000000000000000000000000",
py_api=py_api,
expdb_test=expdb_test,
)
assert response.status_code == http.client.FORBIDDEN


def test_attach_task_to_study_already_linked_raises(
py_api: TestClient,
expdb_test: Connection,
) -> None:
expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1"))
response = _attach_tasks_to_study(
study_id=1,
task_ids=[1, 3, 4],
api_key="AD000000000000000000000000000000",
py_api=py_api,
expdb_test=expdb_test,
)
assert response.status_code == http.client.CONFLICT
assert response.json() == {"detail": "Task 1 is already attached to study 1."}


def test_attach_task_to_study_but_task_not_exist_raises(
py_api: TestClient,
expdb_test: Connection,
) -> None:
expdb_test.execute(text("UPDATE study SET status = 'in_preparation' WHERE id = 1"))
response = _attach_tasks_to_study(
study_id=1,
task_ids=[80123, 78914],
api_key="AD000000000000000000000000000000",
py_api=py_api,
expdb_test=expdb_test,
)
assert response.status_code == http.client.CONFLICT
assert response.json() == {"detail": "One or more of the tasks do not exist."}

0 comments on commit 78073cd

Please sign in to comment.