diff --git a/pyproject.toml b/pyproject.toml index 767ff89..7517ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,5 +59,9 @@ pythonpath = [ ] markers = [ "php: tests that compare directly to an old PHP endpoint.", - "slow: test or sets of tests which take more than a few seconds to run." + "slow: test or sets of tests which take more than a few seconds to run.", + # While the `mut`ation marker below is not strictly necessary as every change is + # executed within a transaction that is rolled back, it can halt other unit tests + # whose queries may depend on the execution or rollback of the transaction. + "mut: executes a mutation on the database (in a transaction which is rolled back)", ] diff --git a/src/database/datasets.py b/src/database/datasets.py index 02360b8..fc40dc3 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,4 +1,5 @@ """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707""" +import datetime from collections import defaultdict from typing import Any, Iterable @@ -139,7 +140,6 @@ def get_latest_dataset_description( def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None: - columns = get_column_names(connection, "dataset_status") row = connection.execute( text( """ @@ -151,9 +151,7 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st ), parameters={"dataset_id": dataset_id}, ) - return ( - dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None - ) + return next(row.mappings(), None) def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None: @@ -201,3 +199,37 @@ def get_feature_values(dataset_id: int, feature_index: int, connection: Connecti parameters={"dataset_id": dataset_id, "feature_index": feature_index}, ) return [row.value for row in rows] + + +def insert_status_for_dataset( + 
# --- src/database/datasets.py -------------------------------------------------


def insert_status_for_dataset(
    dataset_id: int,
    user_id: int,
    status: str,
    connection: Connection,
) -> None:
    """Add a row to `dataset_status` recording `status` for `dataset_id`.

    The row stores the dataset id, the new status, the current time and the
    id of the user who made the change. This function only executes the
    statement on `connection`; it issues no commit itself.
    """
    # NOTE(review): `datetime.datetime.now()` yields a naive, local-time value —
    # confirm this matches what other writers of `status_date` store.
    connection.execute(
        text(
            """
            INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
            VALUES (:dataset, :status, :date, :user)
            """,
        ),
        parameters={
            "dataset": dataset_id,
            "status": status,
            "date": datetime.datetime.now(),
            "user": user_id,
        },
    )


def remove_deactivated_status(dataset_id: int, connection: Connection) -> None:
    """Delete every `dataset_status` row of `dataset_id` whose status is 'deactivated'.

    This function only executes the statement on `connection`; it issues no
    commit itself.
    """
    connection.execute(
        text(
            """
            DELETE FROM dataset_status
            WHERE `did` = :data AND `status`='deactivated'
            """,
        ),
        parameters={"data": dataset_id},
    )


# --- src/routers/openml/datasets.py -------------------------------------------


@router.post(
    path="/status/update",
)
def update_dataset_status(
    dataset_id: Annotated[int, Body()],
    status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()],
    user: Annotated[User | None, Depends(fetch_user)],
    expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[str, str | int]:
    """Change the status of a dataset to active or deactivated.

    Returns `{"dataset_id": ..., "status": ...}` echoing the request on success.

    Raises `HTTPException` with:
      - 401 when no user could be resolved from the request;
      - 403 (code 693) when deactivating and the user is neither the dataset's
        uploader nor an administrator;
      - 403 (code 696) when activating and the user is not an administrator;
      - 412 (code 694) when the dataset's latest recorded status already equals
        the requested one;
      - 500 when the recorded status allows none of the known transitions.
    Whatever `_get_dataset_raise_otherwise` raises (defined outside this view)
    propagates unchanged.
    """
    if user is None:
        raise HTTPException(
            status_code=http.client.UNAUTHORIZED,
            detail="Updating dataset status requires authorization",
        )

    dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb)

    is_admin = UserGroup.ADMIN in user.groups
    can_deactivate = dataset["uploader"] == user.user_id or is_admin
    if status == DatasetStatus.DEACTIVATED and not can_deactivate:
        raise HTTPException(
            status_code=http.client.FORBIDDEN,
            detail={"code": 693, "message": "Dataset is not owned by you"},
        )
    if status == DatasetStatus.ACTIVE and not is_admin:
        raise HTTPException(
            status_code=http.client.FORBIDDEN,
            detail={"code": 696, "message": "Only administrators can activate datasets."},
        )

    current_status = get_latest_status_update(dataset_id, expdb)
    if current_status and current_status["status"] == status:
        raise HTTPException(
            status_code=http.client.PRECONDITION_FAILED,
            detail={"code": 694, "message": "Illegal status transition."},
        )

    # If current status is unknown, it is effectively "in preparation",
    # So the following transitions are allowed (first 3 transitions are first clause)
    #  - in preparation => active  (add a row)
    #  - in preparation => deactivated  (add a row)
    #  - active => deactivated  (add a row)
    #  - deactivated => active  (delete a row)
    if current_status is None or status == DatasetStatus.DEACTIVATED:
        insert_status_for_dataset(dataset_id, user.user_id, status, expdb)
    elif current_status["status"] == DatasetStatus.DEACTIVATED:
        # NOTE(review): removing the 'deactivated' rows re-exposes whichever row
        # preceded them. A dataset that went in preparation => deactivated has no
        # 'active' row, so after this delete it would have no status row at all
        # while the response still reports "active" — confirm this is intended.
        remove_deactivated_status(dataset_id, expdb)
    else:
        # Report the stored status value, not the whole row mapping.
        raise HTTPException(
            status_code=http.client.INTERNAL_SERVER_ERROR,
            detail={
                "message": f"Unknown status transition: {current_status['status']} -> {status}",
            },
        )

    return {"dataset_id": dataset_id, "status": status}
# Dataset IDs grouped by the state the tests expect them to be in.
# NOTE(review): presumably these mirror the contents of the test database —
# verify against the fixture data when it changes.
PRIVATE_DATASET_ID = {130}
IN_PREPARATION_ID = {1, 33}
DEACTIVATED_DATASETS = {2, 131}
DATASETS = set(range(1, 132))

# Derive the active sets with set algebra rather than subtracting counts, so an
# ID listed under two states (or outside DATASETS) cannot be subtracted twice.
_ACTIVE_DATASETS = DATASETS - DEACTIVATED_DATASETS - IN_PREPARATION_ID
_PUBLIC_ACTIVE_DATASETS = _ACTIVE_DATASETS - PRIVATE_DATASET_ID

NUMBER_OF_DATASETS = len(DATASETS)
NUMBER_OF_DEACTIVATED_DATASETS = len(DEACTIVATED_DATASETS)
NUMBER_OF_DATASETS_IN_PREPARATION = len(IN_PREPARATION_ID)
NUMBER_OF_PRIVATE_DATASETS = len(PRIVATE_DATASET_ID)
NUMBER_OF_ACTIVE_DATASETS = len(_ACTIVE_DATASETS)
NUMBER_OF_PUBLIC_ACTIVE_DATASETS = len(_PUBLIC_ACTIVE_DATASETS)
def _assert_status_update_is_successful(
    apikey: ApiKey,
    dataset_id: int,
    status: str,
    py_api: TestClient,
) -> None:
    """POST a status update and check it is accepted and echoed back verbatim."""
    payload = {"dataset_id": dataset_id, "status": status}
    response = py_api.post(f"/datasets/status/update?api_key={apikey}", json=payload)
    assert response.status_code == http.client.OK
    # The endpoint is expected to answer with exactly the request body.
    assert response.json() == payload


@pytest.mark.mut()
@pytest.mark.parametrize("dataset_id", [3, 4])
def test_dataset_status_update_active_to_deactivated(dataset_id: int, py_api: TestClient) -> None:
    """An administrator can deactivate datasets 3 and 4."""
    _assert_status_update_is_successful(
        ApiKey.ADMIN,
        dataset_id,
        DatasetStatus.DEACTIVATED,
        py_api,
    )


@pytest.mark.mut()
def test_dataset_status_update_in_preparation_to_active(py_api: TestClient) -> None:
    """An administrator can activate dataset 1."""
    _assert_status_update_is_successful(ApiKey.ADMIN, 1, DatasetStatus.ACTIVE, py_api)


@pytest.mark.mut()
def test_dataset_status_update_in_preparation_to_deactivated(py_api: TestClient) -> None:
    """An administrator can deactivate dataset 1."""
    _assert_status_update_is_successful(ApiKey.ADMIN, 1, DatasetStatus.DEACTIVATED, py_api)


@pytest.mark.mut()
def test_dataset_status_update_deactivated_to_active(py_api: TestClient) -> None:
    """An administrator can activate dataset 131."""
    _assert_status_update_is_successful(ApiKey.ADMIN, 131, DatasetStatus.ACTIVE, py_api)


@pytest.mark.parametrize(
    ("dataset_id", "api_key", "status"),
    [
        (1, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
        (1, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED),
        (2, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED),
        (33, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
        (131, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
    ],
)
def test_dataset_status_unauthorized(
    dataset_id: int,
    api_key: ApiKey,
    status: str,
    py_api: TestClient,
) -> None:
    """Each listed (dataset, key, status) request is rejected with 403 FORBIDDEN."""
    url = f"/datasets/status/update?api_key={api_key}"
    body = {"dataset_id": dataset_id, "status": status}
    response = py_api.post(url, json=body)
    assert response.status_code == http.client.FORBIDDEN