From 4a7f104f1a300cf5d9a64bba8d3908718301ab88 Mon Sep 17 00:00:00 2001 From: Pieter Gijsbers Date: Fri, 1 Dec 2023 09:50:47 +0100 Subject: [PATCH] Add data/status/update (#119) * Add marker to indicate database mutation in tests * Do not use separate query to identify fieldnames * Add `datasets/status/update` endpoint There is currently no migration test because the test infrastructure makes this hard. Since the PHP calls directly affect the database, we need to ensure we can recover the old database state after the test. We cannot generally do this for the dataset status update call, as for that we need to know the initial state of the database. While we do know that technically, we would need to hard-code that information into the tests. A better approach would be to start up a new database container for PHP. A second issue that arises is that if we call an update on the status table from both Python and PHP then the transaction from the first will block the call of the other. This too is mitigated by introducing a second database. The only potential risk this introduces is that Python and PHP work on different databases. Overall though, the effect of this should be rather minimal as the database would be effectively reset after every individual test. * Update tests to reflect new database state --- pyproject.toml | 6 +- src/database/datasets.py | 40 ++++++++- src/routers/openml/datasets.py | 59 ++++++++++++- tests/constants.py | 19 +++-- tests/routers/openml/dataset_tag_test.py | 6 +- .../openml/datasets_list_datasets_test.py | 4 +- tests/routers/openml/datasets_test.py | 85 +++++++++++++++++++ 7 files changed, 201 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 767ff89..7517ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,5 +59,9 @@ pythonpath = [ ] markers = [ "php: tests that compare directly to an old PHP endpoint.", - "slow: test or sets of tests which take more than a few seconds to run." 
+ "slow: test or sets of tests which take more than a few seconds to run.", + # While the `mut`ation marker below is not strictly necessary as every change is + # executed within a transaction that is rolled back, it can halt other unit tests + # whose queries may depend on the execution or rollback of the transaction. + "mut: executes a mutation on the database (in a transaction which is rolled back)", ] diff --git a/src/database/datasets.py b/src/database/datasets.py index 02360b8..fc40dc3 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,4 +1,5 @@ """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707""" +import datetime from collections import defaultdict from typing import Any, Iterable @@ -139,7 +140,6 @@ def get_latest_dataset_description( def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None: - columns = get_column_names(connection, "dataset_status") row = connection.execute( text( """ @@ -151,9 +151,7 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st ), parameters={"dataset_id": dataset_id}, ) - return ( - dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None - ) + return next(row.mappings(), None) def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None: @@ -201,3 +199,37 @@ def get_feature_values(dataset_id: int, feature_index: int, connection: Connecti parameters={"dataset_id": dataset_id, "feature_index": feature_index}, ) return [row.value for row in rows] + + +def insert_status_for_dataset( + dataset_id: int, + user_id: int, + status: str, + connection: Connection, +) -> None: + connection.execute( + text( + """ + INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`) + VALUES (:dataset, :status, :date, :user) + """, + ), + parameters={ + "dataset": dataset_id, + "status": 
status, + "date": datetime.datetime.now(), + "user": user_id, + }, + ) + + +def remove_deactivated_status(dataset_id: int, connection: Connection) -> None: + connection.execute( + text( + """ + DELETE FROM dataset_status + WHERE `did` = :data AND `status`='deactivated' + """, + ), + parameters={"data": dataset_id}, + ) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index b10a841..5903a28 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -6,7 +6,7 @@ import re from datetime import datetime from enum import StrEnum -from typing import Annotated, Any, NamedTuple +from typing import Annotated, Any, Literal, NamedTuple from core.access import _user_has_access from core.errors import DatasetError @@ -27,6 +27,8 @@ get_latest_status_update, get_qualities_for_datasets, get_tags, + insert_status_for_dataset, + remove_deactivated_status, ) from database.datasets import tag_dataset as db_tag_dataset from database.users import User, UserGroup @@ -317,6 +319,61 @@ def get_dataset_features( return features +@router.post( + path="/status/update", +) +def update_dataset_status( + dataset_id: Annotated[int, Body()], + status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()], + user: Annotated[User | None, Depends(fetch_user)], + expdb: Annotated[Connection, Depends(expdb_connection)], +) -> dict[str, str | int]: + if user is None: + raise HTTPException( + status_code=http.client.UNAUTHORIZED, + detail="Updating dataset status required authorization", + ) + + dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb) + + can_deactivate = dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups + if status == DatasetStatus.DEACTIVATED and not can_deactivate: + raise HTTPException( + status_code=http.client.FORBIDDEN, + detail={"code": 693, "message": "Dataset is not owned by you"}, + ) + if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups: + 
raise HTTPException( + status_code=http.client.FORBIDDEN, + detail={"code": 696, "message": "Only administrators can activate datasets."}, + ) + + current_status = get_latest_status_update(dataset_id, expdb) + if current_status and current_status["status"] == status: + raise HTTPException( + status_code=http.client.PRECONDITION_FAILED, + detail={"code": 694, "message": "Illegal status transition."}, + ) + + # If current status is unknown, it is effectively "in preparation", + # So the following transitions are allowed (first 3 transitions are first clause) + # - in preparation => active (add a row) + # - in preparation => deactivated (add a row) + # - active => deactivated (add a row) + # - deactivated => active (delete a row) + if current_status is None or status == DatasetStatus.DEACTIVATED: + insert_status_for_dataset(dataset_id, user.user_id, status, expdb) + elif current_status["status"] == DatasetStatus.DEACTIVATED: + remove_deactivated_status(dataset_id, expdb) + else: + raise HTTPException( + status_code=http.client.INTERNAL_SERVER_ERROR, + detail={"message": f"Unknown status transition: {current_status} -> {status}"}, + ) + + return {"dataset_id": dataset_id, "status": status} + + @router.get( path="/{dataset_id}", description="Get meta-data for dataset with ID `dataset_id`.", diff --git a/tests/constants.py b/tests/constants.py index 709e255..e471fd5 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,8 +1,13 @@ -NUMBER_OF_DATASETS = 131 -NUMBER_OF_DEACTIVATED_DATASETS = 1 -NUMBER_OF_DATASETS_IN_PREPARATION = 1 -NUMBER_OF_PRIVATE_DATASETS = 1 -NUMBER_OF_ACTIVE_DATASETS = 128 +PRIVATE_DATASET_ID = {130} +IN_PREPARATION_ID = {1, 33} +DEACTIVATED_DATASETS = {2, 131} +DATASETS = set(range(1, 132)) -PRIVATE_DATASET_ID = 130 -IN_PREPARATION_ID = 1 +NUMBER_OF_DATASETS = len(DATASETS) +NUMBER_OF_DEACTIVATED_DATASETS = len(DEACTIVATED_DATASETS) +NUMBER_OF_DATASETS_IN_PREPARATION = len(IN_PREPARATION_ID) +NUMBER_OF_PRIVATE_DATASETS = 
len(PRIVATE_DATASET_ID) +NUMBER_OF_ACTIVE_DATASETS = ( + NUMBER_OF_DATASETS - NUMBER_OF_DEACTIVATED_DATASETS - NUMBER_OF_DATASETS_IN_PREPARATION +) +NUMBER_OF_PUBLIC_ACTIVE_DATASETS = NUMBER_OF_ACTIVE_DATASETS - NUMBER_OF_PRIVATE_DATASETS diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index f315ffb..3196953 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -22,7 +22,7 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No httpx.Response, py_api.post( f"/datasets/tag{apikey}", - json={"data_id": constants.PRIVATE_DATASET_ID, "tag": "test"}, + json={"data_id": list(constants.PRIVATE_DATASET_ID)[0], "tag": "test"}, ), ) assert response.status_code == http.client.PRECONDITION_FAILED @@ -35,7 +35,7 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No ids=["administrator", "non-owner", "owner"], ) def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None: - dataset_id, tag = constants.PRIVATE_DATASET_ID, "test" + dataset_id, tag = list(constants.PRIVATE_DATASET_ID)[0], "test" response = cast( httpx.Response, py_api.post( @@ -46,7 +46,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> assert response.status_code == http.client.OK assert {"data_tag": {"id": str(dataset_id), "tag": tag}} == response.json() - tags = get_tags(dataset_id=constants.PRIVATE_DATASET_ID, connection=expdb_test) + tags = get_tags(dataset_id=dataset_id, connection=expdb_test) assert tag in tags diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index c91960a..c05e9f2 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -28,7 +28,7 @@ def test_list(py_api: TestClient) -> None: @pytest.mark.parametrize( ("status", "amount"), [ - ("active", 
constants.NUMBER_OF_ACTIVE_DATASETS), + ("active", constants.NUMBER_OF_PUBLIC_ACTIVE_DATASETS), ("deactivated", constants.NUMBER_OF_DEACTIVATED_DATASETS), ("in_preparation", constants.NUMBER_OF_DATASETS_IN_PREPARATION), ("all", constants.NUMBER_OF_DATASETS - constants.NUMBER_OF_PRIVATE_DATASETS), @@ -96,7 +96,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie all_ids = [ did for did in range(1, 1 + constants.NUMBER_OF_DATASETS) - if did not in [constants.PRIVATE_DATASET_ID] + if did not in constants.PRIVATE_DATASET_ID ] start = 0 if offset is None else offset diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index 52a8e71..37e8242 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -3,6 +3,7 @@ import httpx import pytest +from schemas.datasets.openml import DatasetStatus from starlette.testclient import TestClient from tests.conftest import ApiKey @@ -144,3 +145,87 @@ def test_dataset_features_with_processing_error(py_api: TestClient) -> None: def test_dataset_features_dataset_does_not_exist(py_api: TestClient) -> None: resource = py_api.get("/datasets/features/1000") assert resource.status_code == http.client.NOT_FOUND + + +def _assert_status_update_is_successful( + apikey: ApiKey, + dataset_id: int, + status: str, + py_api: TestClient, +) -> None: + response = py_api.post( + f"/datasets/status/update?api_key={apikey}", + json={"dataset_id": dataset_id, "status": status}, + ) + assert response.status_code == http.client.OK + assert response.json() == { + "dataset_id": dataset_id, + "status": status, + } + + +@pytest.mark.mut() +@pytest.mark.parametrize( + "dataset_id", + [3, 4], +) +def test_dataset_status_update_active_to_deactivated(dataset_id: int, py_api: TestClient) -> None: + _assert_status_update_is_successful( + apikey=ApiKey.ADMIN, + dataset_id=dataset_id, + status=DatasetStatus.DEACTIVATED, + py_api=py_api, + ) + + +@pytest.mark.mut() 
+def test_dataset_status_update_in_preparation_to_active(py_api: TestClient) -> None: + _assert_status_update_is_successful( + apikey=ApiKey.ADMIN, + dataset_id=1, + status=DatasetStatus.ACTIVE, + py_api=py_api, + ) + + +@pytest.mark.mut() +def test_dataset_status_update_in_preparation_to_deactivated(py_api: TestClient) -> None: + _assert_status_update_is_successful( + apikey=ApiKey.ADMIN, + dataset_id=1, + status=DatasetStatus.DEACTIVATED, + py_api=py_api, + ) + + +@pytest.mark.mut() +def test_dataset_status_update_deactivated_to_active(py_api: TestClient) -> None: + _assert_status_update_is_successful( + apikey=ApiKey.ADMIN, + dataset_id=131, + status=DatasetStatus.ACTIVE, + py_api=py_api, + ) + + +@pytest.mark.parametrize( + ("dataset_id", "api_key", "status"), + [ + (1, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE), + (1, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED), + (2, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED), + (33, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE), + (131, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE), + ], +) +def test_dataset_status_unauthorized( + dataset_id: int, + api_key: ApiKey, + status: str, + py_api: TestClient, +) -> None: + response = py_api.post( + f"/datasets/status/update?api_key={api_key}", + json={"dataset_id": dataset_id, "status": status}, + ) + assert response.status_code == http.client.FORBIDDEN