From 7664bfdd578db872f37e4b24e0504da3adb224cd Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Sun, 13 Oct 2024 20:29:12 +0200
Subject: [PATCH] Add verified user replacements
---
tests/conftest.py | 8 ----
tests/routers/openml/dataset_tag_test.py | 2 +-
.../openml/datasets_list_datasets_test.py | 2 +-
tests/routers/openml/datasets_test.py | 47 +++++++++++--------
.../migration/datasets_migration_test.py | 2 +-
tests/routers/openml/users_test.py | 39 +++++++++++++++
6 files changed, 69 insertions(+), 31 deletions(-)
create mode 100644 tests/routers/openml/users_test.py
diff --git a/tests/conftest.py b/tests/conftest.py
index 14e027a..4d2c2c9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,6 @@
import contextlib
import json
from collections.abc import Iterator
-from enum import StrEnum
from pathlib import Path
from typing import Any, NamedTuple
@@ -18,13 +17,6 @@
from routers.dependencies import expdb_connection, userdb_connection
-class ApiKey(StrEnum):
- ADMIN: str = "AD000000000000000000000000000000"
- REGULAR_USER: str = "00000000000000000000000000000000"
- OWNER_USER: str = "DA1A0000000000000000000000000000"
- INVALID: str = "11111111111111111111111111111111"
-
-
@contextlib.contextmanager
def automatic_rollback(engine: Engine) -> Iterator[Connection]:
with engine.connect() as connection:
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index c23aa3a..6334469 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -6,7 +6,7 @@
from database.datasets import get_tags_for
from tests import constants
-from tests.conftest import ApiKey
+from tests.routers.openml.users_test import ApiKey
@pytest.mark.parametrize(
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index c5b1a96..a6c4ac6 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -9,7 +9,7 @@
from starlette.testclient import TestClient
from tests import constants
-from tests.conftest import ApiKey
+from tests.routers.openml.users_test import ApiKey
def _assert_empty_result(
diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py
index 820d52a..99a30e3 100644
--- a/tests/routers/openml/datasets_test.py
+++ b/tests/routers/openml/datasets_test.py
@@ -1,11 +1,14 @@
from http import HTTPStatus
-from typing import Any
import pytest
+from fastapi import HTTPException
+from sqlalchemy import Connection
from starlette.testclient import TestClient
+from database.users import User
+from routers.openml.datasets import get_dataset
from schemas.datasets.openml import DatasetStatus
-from tests.conftest import ApiKey
+from tests.routers.openml.users_test import NO_USER, SOME_USER, ApiKey
@pytest.mark.parametrize(
@@ -66,32 +69,36 @@ def test_get_dataset(py_api: TestClient) -> None:
@pytest.mark.parametrize(
- ("api_key", "response_code"),
+ "user",
[
- (None, HTTPStatus.FORBIDDEN),
- ("a" * 32, HTTPStatus.FORBIDDEN),
+ NO_USER,
+ SOME_USER,
],
)
-def test_private_dataset_no_user_no_access(
- py_api: TestClient,
- api_key: str | None,
- response_code: int,
+def test_private_dataset_no_owner_no_access(
+ user: User | None,
+ expdb_test: Connection,
) -> None:
- query = f"?api_key={api_key}" if api_key else ""
- response = py_api.get(f"/datasets/130{query}")
-
- assert response.status_code == response_code
- assert response.json()["detail"] == {"code": "112", "message": "No access granted"}
+ with pytest.raises(HTTPException) as e:
+ get_dataset(
+ dataset_id=130,
+ user=user,
+ user_db=None,
+ expdb_db=expdb_test,
+ )
+ assert e.value.status_code == HTTPStatus.FORBIDDEN
+ assert e.value.detail == {"code": "112", "message": "No access granted"} # type: ignore[comparison-overlap]
-@pytest.mark.skip("Not sure how to include apikey in test yet.")
def test_private_dataset_owner_access(
- py_api: TestClient,
- dataset_130: dict[str, Any],
+ owner: User, expdb_test: Connection, user_test: Connection
) -> None:
- response = py_api.get("/v2/datasets/130?api_key=...")
- assert response.status_code == HTTPStatus.OK
- assert dataset_130 == response.json()
+ get_dataset(
+ dataset_id=130,
+ user=owner,
+ user_db=user_test,
+ expdb_db=expdb_test,
+ )
@pytest.mark.skip("Not sure how to include apikey in test yet.")
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index 5a67105..69c45ef 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -7,7 +7,7 @@
from starlette.testclient import TestClient
from core.conversions import nested_remove_single_element_list
-from tests.conftest import ApiKey
+from tests.routers.openml.users_test import ApiKey
@pytest.mark.parametrize(
diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py
new file mode 100644
index 0000000..3f8bb4d
--- /dev/null
+++ b/tests/routers/openml/users_test.py
@@ -0,0 +1,39 @@
+from enum import StrEnum
+
+import pytest
+from sqlalchemy import Connection
+
+from database.users import User, UserGroup
+from routers.dependencies import fetch_user
+
+NO_USER = None
+SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE])
+OWNER_USER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE])
+ADMIN_USER = User(user_id=1, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE])
+
+
+class ApiKey(StrEnum):
+ ADMIN: str = "AD000000000000000000000000000000"
+ REGULAR_USER: str = "00000000000000000000000000000000"
+ OWNER_USER: str = "DA1A0000000000000000000000000000"
+ INVALID: str = "11111111111111111111111111111111"
+
+
+@pytest.mark.parametrize(
+ ("api_key", "user"),
+ [
+ (ApiKey.ADMIN, ADMIN_USER),
+ (ApiKey.OWNER_USER, OWNER_USER),
+ (ApiKey.REGULAR_USER, SOME_USER),
+ ],
+)
+def test_fetch_user(api_key: str, user: User, user_test: Connection) -> None:
+ db_user = fetch_user(api_key, user_data=user_test)
+ assert user.user_id == db_user.user_id
+ assert user.groups == db_user.groups
+
+
+def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None:
+ assert fetch_user(api_key=None, user_data=user_test) is None
+ invalid_key = "f" * 32
+ assert fetch_user(api_key=invalid_key, user_data=user_test) is None