Skip to content

Commit

Permalink
Add POST endpoint to create study
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Dec 19, 2023
1 parent 486173e commit 37a25c8
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 5 deletions.
57 changes: 56 additions & 1 deletion src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from datetime import datetime
from typing import cast

from schemas.study import StudyType
from schemas.study import CreateStudy, StudyType
from sqlalchemy import Connection, Row, text

from database.users import User


def get_study_by_id(study_id: int, connection: Connection) -> Row:
return connection.execute(
Expand Down Expand Up @@ -66,3 +69,55 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
parameters={"study_id": study.id},
).fetchall(),
)


def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
    """Insert a new study row and return the identifier the database assigned to it.

    The row is populated from the fields of `study`; `user` is stored as the
    creator, the `legacy` column is set to 'n', and the creation date is the
    current time as reported by `datetime.now()`.
    """
    insert_statement = text(
        """
        INSERT INTO study (
            name, alias, benchmark_suite, main_entity_type, description,
            creator, legacy, creation_date
        )
        VALUES (
            :name, :alias, :benchmark_suite, :main_entity_type, :description,
            :creator, 'n', :creation_date
        )
        """,
    )
    row_values = {
        "name": study.name,
        "alias": study.alias,
        "benchmark_suite": study.benchmark_suite,
        "main_entity_type": study.main_entity_type,
        "description": study.description,
        "creator": user.user_id,
        "creation_date": datetime.now(),
    }
    expdb.execute(insert_statement, parameters=row_values)

    # Ask the database which id it generated for the row inserted above.
    id_row = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).fetchone()
    (new_study_id,) = id_row
    return cast(int, new_study_id)


def attach_task_to_study(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
    """Add a `task_study` row linking `task_id` to `study_id`, with `user` as the uploader."""
    link_statement = text(
        """
        INSERT INTO task_study (study_id, task_id, uploader)
        VALUES (:study_id, :task_id, :user_id)
        """,
    )
    link_values = {
        "study_id": study_id,
        "task_id": task_id,
        "user_id": user.user_id,
    }
    expdb.execute(link_statement, parameters=link_values)


def attach_run_to_study(run_id: int, study_id: int, user: User, expdb: Connection) -> None:
    """Add a `run_study` row linking `run_id` to `study_id`, with `user` as the uploader."""
    link_statement = text(
        """
        INSERT INTO run_study (study_id, run_id, uploader)
        VALUES (:study_id, :run_id, :user_id)
        """,
    )
    link_values = {
        "study_id": study_id,
        "run_id": run_id,
        "user_id": user.user_id,
    }
    expdb.execute(link_statement, parameters=link_values)
51 changes: 47 additions & 4 deletions src/routers/openml/study.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import http.client
from typing import Annotated
from typing import Annotated, Literal

from core.formatting import _str_to_bool
from database.studies import get_study_by_alias, get_study_by_id, get_study_data
from database.studies import (
attach_run_to_study,
attach_task_to_study,
get_study_by_alias,
get_study_by_id,
get_study_data,
)
from database.studies import create_study as db_create_study
from database.users import User, UserGroup
from fastapi import APIRouter, Depends, HTTPException
from schemas.core import Visibility
from schemas.study import Study, StudyType
from schemas.study import CreateStudy, Study, StudyType
from sqlalchemy import Connection, Row

from routers.dependencies import expdb_connection, fetch_user
Expand All @@ -32,10 +39,46 @@ def _get_study_raise_otherwise(id_or_alias: int | str, user: User | None, expdb:
status_code=http.client.GONE,
detail="Legacy studies are no longer supported",
)

return study


@router.post("/")
def create_study(
    study: CreateStudy,
    user: Annotated[User | None, Depends(fetch_user)] = None,
    expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> dict[Literal["study_id"], int]:
    """Create a study on behalf of the authenticated user and return its id.

    The request is rejected with an `HTTPException` carrying status
      - UNAUTHORIZED when no user could be resolved,
      - BAD_REQUEST when a run study lists tasks, or a task study lists runs,
      - CONFLICT when the requested alias is already used by another study.

    On success the study is stored and the listed tasks (for a task study) or
    runs (for a run study) are attached to it.
    """
    if user is None:
        raise HTTPException(
            status_code=http.client.UNAUTHORIZED,
            detail="Creating a study requires authentication.",
        )

    entity_type = study.main_entity_type
    if entity_type == StudyType.RUN and study.tasks:
        raise HTTPException(
            status_code=http.client.BAD_REQUEST,
            detail="Cannot create a run study with tasks.",
        )
    if entity_type == StudyType.TASK and study.runs:
        raise HTTPException(
            status_code=http.client.BAD_REQUEST,
            detail="Cannot create a task study with runs.",
        )

    # Only look up the alias when one was actually provided.
    if study.alias:
        existing_study = get_study_by_alias(study.alias, expdb)
        if existing_study:
            raise HTTPException(
                status_code=http.client.CONFLICT,
                detail="Study alias already exists.",
            )

    study_id = db_create_study(study, user, expdb)

    # The entity type decides which kind of entity is attached to the new study.
    if entity_type == StudyType.TASK:
        for task_id in study.tasks:
            attach_task_to_study(task_id, study_id, user, expdb)
    elif entity_type == StudyType.RUN:
        for run_id in study.runs:
            attach_run_to_study(run_id, study_id, user, expdb)

    # Make sure that invalid fields raise an error (e.g., "task_ids")
    return {"study_id": study_id}


@router.get("/{alias_or_id}")
def get_study(
alias_or_id: int | str,
Expand Down
59 changes: 59 additions & 0 deletions src/schemas/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,62 @@ class Study(BaseModel):
data_ids: list[int]
setup_ids: list[int]
flow_ids: list[int]


class CreateStudy(BaseModel):
    """Study, exposing only those fields that should be provided by the user on creation.

    The `description` and `examples` texts below are built from adjacent string
    literals, which Python joins without inserting whitespace; each fragment
    therefore carries its own trailing space where a word break is needed.
    """

    # Required, at most 256 characters.
    name: str = Field(
        description="Full name of the study.",
        examples=["The OpenML 100 Benchmarking Suite"],
        max_length=256,
    )
    # Optional, at most 32 characters.
    alias: str | None = Field(
        default=None,
        description="Short alternative name for the study, which may be used to fetch it.",
        examples=["OpenML100"],
        max_length=32,
    )
    # Defaults to a task collection.
    main_entity_type: StudyType = Field(
        default=StudyType.TASK,
        description="Whether it is a collection of runs (study) or tasks (benchmarking suite).",
        examples=[StudyType.TASK],
    )
    benchmark_suite: int | None = Field(
        # For study, refers to the benchmarking suite
        default=None,
        description="The benchmarking suite this study is based on, if any.",
    )
    # Required, between 1 and 4096 characters.
    description: str = Field(
        description=(
            "A good study description specifies why the study was created, what it is about, and "
            "how it should be used. It may include information about a related publication or "
            "website."
        ),
        examples=[
            (
                "A collection of tasks with simple datasets to benchmark machine learning methods. "
                "Selected tasks are small classification problems that are not too imbalanced. "
                "We advise the use of OpenML-CC18 instead, because OpenML100 suffers from some "
                "issues. If you do use OpenML100, please cite ..."
            ),
        ],
        min_length=1,
        max_length=4096,
    )
    # Defaults to an empty list.
    tasks: list[int] = Field(
        default_factory=list,
        description=(
            "Tasks to include in the study, can only be specified if `runs` is empty. "
            "Can be modified later with `studies/{id}/attach` and `studies/{id}/detach`."
        ),
        examples=[[1, 2, 3]],
    )
    # Defaults to an empty list.
    runs: list[int] = Field(
        default_factory=list,
        description=(
            "Runs to include in the study, can only be specified if `tasks` is empty. "
            "Can be modified later with `studies/{id}/attach` and `studies/{id}/detach`."
        ),
        examples=[[]],
    )
14 changes: 14 additions & 0 deletions tests/routers/openml/study_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,17 @@ def test_get_task_study_by_alias(py_api: TestClient) -> None:
"setup_ids": [],
}
assert response.json() == expected


def test_create_task_study(py_api: TestClient) -> None:
    """Posting a valid task study with an API key succeeds and returns the new study's id."""
    # Imported locally so the test does not rely on the module-level import block.
    import http.client

    response = py_api.post(
        "/studies?api_key=00000000000000000000000000000000",
        json={
            "name": "Test Study",
            "alias": "test-study",
            "main_entity_type": "task",
            "description": "A test study",
            "tasks": [1, 2, 3],
            "runs": [],
        },
    )
    # Without these checks the test would pass even if the endpoint rejected the request.
    assert response.status_code == http.client.OK
    created = response.json()
    assert isinstance(created.get("study_id"), int)

0 comments on commit 37a25c8

Please sign in to comment.