From 4a7f104f1a300cf5d9a64bba8d3908718301ab88 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Fri, 1 Dec 2023 09:50:47 +0100
Subject: [PATCH] Add data/status/update (#119)
* Add marker to indicate database mutation in tests
* Do not use separate query to identify fieldnames
* Add `datasets/status/update` endpoint
There is currently no migration test because the test infra-
structure makes this hard. Since the PHP calls directly
affect the database, we need to ensure we can recover the old
database state after the test. We can not generally do this
for the dataset status update call, as for that we need to know
the initial state of the database. While we do know that tech-
nically, we would need to hard-code that information into the
tests. A better approach would be to start up a new database
container for PHP. A second issue that arises it that if we
call an update on the status table from both Python and PHP
then the transaction from the first will block the call of
the other. This too is mitigated by introducing a second
database. The only potential risk you introduce is to
be working on different databases. Overall though, the effect
of this should be rather minimal as the database would be
effectively reset after every individual test.
* Update tests to reflect new database state
---
pyproject.toml | 6 +-
src/database/datasets.py | 40 ++++++++-
src/routers/openml/datasets.py | 59 ++++++++++++-
tests/constants.py | 19 +++--
tests/routers/openml/dataset_tag_test.py | 6 +-
.../openml/datasets_list_datasets_test.py | 4 +-
tests/routers/openml/datasets_test.py | 85 +++++++++++++++++++
7 files changed, 201 insertions(+), 18 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 767ff89..7517ef2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,5 +59,9 @@ pythonpath = [
]
markers = [
"php: tests that compare directly to an old PHP endpoint.",
- "slow: test or sets of tests which take more than a few seconds to run."
+ "slow: test or sets of tests which take more than a few seconds to run.",
+ # While the `mut`ation marker below is not strictly necessary as every change is
+ # executed within transaction that is rolled back, it can halt other unit tests which
+ # whose queries may depend on the execution or rollback of the transaction.
+ "mut: executes a mutation on the database (in a transaction which is rolled back)",
]
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 02360b8..fc40dc3 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,4 +1,5 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
+import datetime
from collections import defaultdict
from typing import Any, Iterable
@@ -139,7 +140,6 @@ def get_latest_dataset_description(
def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
- columns = get_column_names(connection, "dataset_status")
row = connection.execute(
text(
"""
@@ -151,9 +151,7 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st
),
parameters={"dataset_id": dataset_id},
)
- return (
- dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None
- )
+ return next(row.mappings(), None)
def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
@@ -201,3 +199,37 @@ def get_feature_values(dataset_id: int, feature_index: int, connection: Connecti
parameters={"dataset_id": dataset_id, "feature_index": feature_index},
)
return [row.value for row in rows]
+
+
+def insert_status_for_dataset(
+ dataset_id: int,
+ user_id: int,
+ status: str,
+ connection: Connection,
+) -> None:
+ connection.execute(
+ text(
+ """
+ INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
+ VALUES (:dataset, :status, :date, :user)
+ """,
+ ),
+ parameters={
+ "dataset": dataset_id,
+ "status": status,
+ "date": datetime.datetime.now(),
+ "user": user_id,
+ },
+ )
+
+
+def remove_deactivated_status(dataset_id: int, connection: Connection) -> None:
+ connection.execute(
+ text(
+ """
+ DELETE FROM dataset_status
+ WHERE `did` = :data AND `status`='deactivated'
+ """,
+ ),
+ parameters={"data": dataset_id},
+ )
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index b10a841..5903a28 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -6,7 +6,7 @@
import re
from datetime import datetime
from enum import StrEnum
-from typing import Annotated, Any, NamedTuple
+from typing import Annotated, Any, Literal, NamedTuple
from core.access import _user_has_access
from core.errors import DatasetError
@@ -27,6 +27,8 @@
get_latest_status_update,
get_qualities_for_datasets,
get_tags,
+ insert_status_for_dataset,
+ remove_deactivated_status,
)
from database.datasets import tag_dataset as db_tag_dataset
from database.users import User, UserGroup
@@ -317,6 +319,61 @@ def get_dataset_features(
return features
+@router.post(
+ path="/status/update",
+)
+def update_dataset_status(
+ dataset_id: Annotated[int, Body()],
+ status: Annotated[Literal[DatasetStatus.ACTIVE] | Literal[DatasetStatus.DEACTIVATED], Body()],
+ user: Annotated[User | None, Depends(fetch_user)],
+ expdb: Annotated[Connection, Depends(expdb_connection)],
+) -> dict[str, str | int]:
+ if user is None:
+ raise HTTPException(
+ status_code=http.client.UNAUTHORIZED,
+ detail="Updating dataset status required authorization",
+ )
+
+ dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb)
+
+ can_deactivate = dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups
+ if status == DatasetStatus.DEACTIVATED and not can_deactivate:
+ raise HTTPException(
+ status_code=http.client.FORBIDDEN,
+ detail={"code": 693, "message": "Dataset is not owned by you"},
+ )
+ if status == DatasetStatus.ACTIVE and UserGroup.ADMIN not in user.groups:
+ raise HTTPException(
+ status_code=http.client.FORBIDDEN,
+ detail={"code": 696, "message": "Only administrators can activate datasets."},
+ )
+
+ current_status = get_latest_status_update(dataset_id, expdb)
+ if current_status and current_status["status"] == status:
+ raise HTTPException(
+ status_code=http.client.PRECONDITION_FAILED,
+ detail={"code": 694, "message": "Illegal status transition."},
+ )
+
+ # If current status is unknown, it is effectively "in preparation",
+ # So the following transitions are allowed (first 3 transitions are first clause)
+ # - in preparation => active (add a row)
+ # - in preparation => deactivated (add a row)
+ # - active => deactivated (add a row)
+ # - deactivated => active (delete a row)
+ if current_status is None or status == DatasetStatus.DEACTIVATED:
+ insert_status_for_dataset(dataset_id, user.user_id, status, expdb)
+ elif current_status["status"] == DatasetStatus.DEACTIVATED:
+ remove_deactivated_status(dataset_id, expdb)
+ else:
+ raise HTTPException(
+ status_code=http.client.INTERNAL_SERVER_ERROR,
+ detail={"message": f"Unknown status transition: {current_status} -> {status}"},
+ )
+
+ return {"dataset_id": dataset_id, "status": status}
+
+
@router.get(
path="/{dataset_id}",
description="Get meta-data for dataset with ID `dataset_id`.",
diff --git a/tests/constants.py b/tests/constants.py
index 709e255..e471fd5 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -1,8 +1,13 @@
-NUMBER_OF_DATASETS = 131
-NUMBER_OF_DEACTIVATED_DATASETS = 1
-NUMBER_OF_DATASETS_IN_PREPARATION = 1
-NUMBER_OF_PRIVATE_DATASETS = 1
-NUMBER_OF_ACTIVE_DATASETS = 128
+PRIVATE_DATASET_ID = {130}
+IN_PREPARATION_ID = {1, 33}
+DEACTIVATED_DATASETS = {2, 131}
+DATASETS = set(range(1, 132))
-PRIVATE_DATASET_ID = 130
-IN_PREPARATION_ID = 1
+NUMBER_OF_DATASETS = len(DATASETS)
+NUMBER_OF_DEACTIVATED_DATASETS = len(DEACTIVATED_DATASETS)
+NUMBER_OF_DATASETS_IN_PREPARATION = len(IN_PREPARATION_ID)
+NUMBER_OF_PRIVATE_DATASETS = len(PRIVATE_DATASET_ID)
+NUMBER_OF_ACTIVE_DATASETS = (
+ NUMBER_OF_DATASETS - NUMBER_OF_DEACTIVATED_DATASETS - NUMBER_OF_DATASETS_IN_PREPARATION
+)
+NUMBER_OF_PUBLIC_ACTIVE_DATASETS = NUMBER_OF_ACTIVE_DATASETS - NUMBER_OF_PRIVATE_DATASETS
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index f315ffb..3196953 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -22,7 +22,7 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No
httpx.Response,
py_api.post(
f"/datasets/tag{apikey}",
- json={"data_id": constants.PRIVATE_DATASET_ID, "tag": "test"},
+ json={"data_id": list(constants.PRIVATE_DATASET_ID)[0], "tag": "test"},
),
)
assert response.status_code == http.client.PRECONDITION_FAILED
@@ -35,7 +35,7 @@ def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: TestClient) -> No
ids=["administrator", "non-owner", "owner"],
)
def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) -> None:
- dataset_id, tag = constants.PRIVATE_DATASET_ID, "test"
+ dataset_id, tag = list(constants.PRIVATE_DATASET_ID)[0], "test"
response = cast(
httpx.Response,
py_api.post(
@@ -46,7 +46,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) ->
assert response.status_code == http.client.OK
assert {"data_tag": {"id": str(dataset_id), "tag": tag}} == response.json()
- tags = get_tags(dataset_id=constants.PRIVATE_DATASET_ID, connection=expdb_test)
+ tags = get_tags(dataset_id=dataset_id, connection=expdb_test)
assert tag in tags
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index c91960a..c05e9f2 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -28,7 +28,7 @@ def test_list(py_api: TestClient) -> None:
@pytest.mark.parametrize(
("status", "amount"),
[
- ("active", constants.NUMBER_OF_ACTIVE_DATASETS),
+ ("active", constants.NUMBER_OF_PUBLIC_ACTIVE_DATASETS),
("deactivated", constants.NUMBER_OF_DEACTIVATED_DATASETS),
("in_preparation", constants.NUMBER_OF_DATASETS_IN_PREPARATION),
("all", constants.NUMBER_OF_DATASETS - constants.NUMBER_OF_PRIVATE_DATASETS),
@@ -96,7 +96,7 @@ def test_list_pagination(limit: int | None, offset: int | None, py_api: TestClie
all_ids = [
did
for did in range(1, 1 + constants.NUMBER_OF_DATASETS)
- if did not in [constants.PRIVATE_DATASET_ID]
+ if did not in constants.PRIVATE_DATASET_ID
]
start = 0 if offset is None else offset
diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py
index 52a8e71..37e8242 100644
--- a/tests/routers/openml/datasets_test.py
+++ b/tests/routers/openml/datasets_test.py
@@ -3,6 +3,7 @@
import httpx
import pytest
+from schemas.datasets.openml import DatasetStatus
from starlette.testclient import TestClient
from tests.conftest import ApiKey
@@ -144,3 +145,87 @@ def test_dataset_features_with_processing_error(py_api: TestClient) -> None:
def test_dataset_features_dataset_does_not_exist(py_api: TestClient) -> None:
resource = py_api.get("/datasets/features/1000")
assert resource.status_code == http.client.NOT_FOUND
+
+
+def _assert_status_update_is_successful(
+ apikey: ApiKey,
+ dataset_id: int,
+ status: str,
+ py_api: TestClient,
+) -> None:
+ response = py_api.post(
+ f"/datasets/status/update?api_key={apikey}",
+ json={"dataset_id": dataset_id, "status": status},
+ )
+ assert response.status_code == http.client.OK
+ assert response.json() == {
+ "dataset_id": dataset_id,
+ "status": status,
+ }
+
+
+@pytest.mark.mut()
+@pytest.mark.parametrize(
+ "dataset_id",
+ [3, 4],
+)
+def test_dataset_status_update_active_to_deactivated(dataset_id: int, py_api: TestClient) -> None:
+ _assert_status_update_is_successful(
+ apikey=ApiKey.ADMIN,
+ dataset_id=dataset_id,
+ status=DatasetStatus.DEACTIVATED,
+ py_api=py_api,
+ )
+
+
+@pytest.mark.mut()
+def test_dataset_status_update_in_preparation_to_active(py_api: TestClient) -> None:
+ _assert_status_update_is_successful(
+ apikey=ApiKey.ADMIN,
+ dataset_id=1,
+ status=DatasetStatus.ACTIVE,
+ py_api=py_api,
+ )
+
+
+@pytest.mark.mut()
+def test_dataset_status_update_in_preparation_to_deactivated(py_api: TestClient) -> None:
+ _assert_status_update_is_successful(
+ apikey=ApiKey.ADMIN,
+ dataset_id=1,
+ status=DatasetStatus.DEACTIVATED,
+ py_api=py_api,
+ )
+
+
+@pytest.mark.mut()
+def test_dataset_status_update_deactivated_to_active(py_api: TestClient) -> None:
+ _assert_status_update_is_successful(
+ apikey=ApiKey.ADMIN,
+ dataset_id=131,
+ status=DatasetStatus.ACTIVE,
+ py_api=py_api,
+ )
+
+
+@pytest.mark.parametrize(
+ ("dataset_id", "api_key", "status"),
+ [
+ (1, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
+ (1, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED),
+ (2, ApiKey.REGULAR_USER, DatasetStatus.DEACTIVATED),
+ (33, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
+ (131, ApiKey.REGULAR_USER, DatasetStatus.ACTIVE),
+ ],
+)
+def test_dataset_status_unauthorized(
+ dataset_id: int,
+ api_key: ApiKey,
+ status: str,
+ py_api: TestClient,
+) -> None:
+ response = py_api.post(
+ f"/datasets/status/update?api_key={api_key}",
+ json={"dataset_id": dataset_id, "status": status},
+ )
+ assert response.status_code == http.client.FORBIDDEN