From 7664bfdd578db872f37e4b24e0504da3adb224cd Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Sun, 13 Oct 2024 20:29:12 +0200 Subject: [PATCH] Add verified user replacements --- tests/conftest.py | 8 ---- tests/routers/openml/dataset_tag_test.py | 2 +- .../openml/datasets_list_datasets_test.py | 2 +- tests/routers/openml/datasets_test.py | 47 +++++++++++-------- .../migration/datasets_migration_test.py | 2 +- tests/routers/openml/users_test.py | 39 +++++++++++++++ 6 files changed, 69 insertions(+), 31 deletions(-) create mode 100644 tests/routers/openml/users_test.py diff --git a/tests/conftest.py b/tests/conftest.py index 14e027a..4d2c2c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import contextlib import json from collections.abc import Iterator -from enum import StrEnum from pathlib import Path from typing import Any, NamedTuple @@ -18,13 +17,6 @@ from routers.dependencies import expdb_connection, userdb_connection -class ApiKey(StrEnum): - ADMIN: str = "AD000000000000000000000000000000" - REGULAR_USER: str = "00000000000000000000000000000000" - OWNER_USER: str = "DA1A0000000000000000000000000000" - INVALID: str = "11111111111111111111111111111111" - - @contextlib.contextmanager def automatic_rollback(engine: Engine) -> Iterator[Connection]: with engine.connect() as connection: diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index c23aa3a..6334469 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -6,7 +6,7 @@ from database.datasets import get_tags_for from tests import constants -from tests.conftest import ApiKey +from tests.routers.openml.users_test import ApiKey @pytest.mark.parametrize( diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index c5b1a96..a6c4ac6 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -9,7 +9,7 @@ from starlette.testclient import TestClient from tests import constants -from tests.conftest import ApiKey +from tests.routers.openml.users_test import ApiKey def _assert_empty_result( diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py index 820d52a..99a30e3 100644 --- a/tests/routers/openml/datasets_test.py +++ b/tests/routers/openml/datasets_test.py @@ -1,11 +1,14 @@ from http import HTTPStatus -from typing import Any import pytest +from fastapi import HTTPException +from sqlalchemy import Connection from starlette.testclient import TestClient +from database.users import User +from routers.openml.datasets import get_dataset from schemas.datasets.openml import DatasetStatus -from tests.conftest import ApiKey +from tests.routers.openml.users_test import NO_USER, SOME_USER, ApiKey @pytest.mark.parametrize( @@ -66,32 +69,36 @@ def test_get_dataset(py_api: TestClient) -> None: @pytest.mark.parametrize( - ("api_key", "response_code"), + "user", [ - (None, HTTPStatus.FORBIDDEN), - ("a" * 32, HTTPStatus.FORBIDDEN), + NO_USER, + SOME_USER, ], ) -def test_private_dataset_no_user_no_access( - py_api: TestClient, - api_key: str | None, - response_code: int, +def test_private_dataset_no_owner_no_access( + user: User | None, + expdb_test: Connection, ) -> None: - query = f"?api_key={api_key}" if api_key else "" - response = py_api.get(f"/datasets/130{query}") - - assert response.status_code == response_code - assert response.json()["detail"] == {"code": "112", "message": "No access granted"} + with pytest.raises(HTTPException) as e: + get_dataset( + dataset_id=130, + user=user, + user_db=None, + expdb_db=expdb_test, + ) + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert e.value.detail == {"code": "112", "message": "No access granted"} # type: ignore[comparison-overlap] -@pytest.mark.skip("Not sure how to include apikey in test yet.") def test_private_dataset_owner_access( - py_api: TestClient, - dataset_130: dict[str, Any], + owner: User, expdb_test: Connection, user_test: Connection ) -> None: - response = py_api.get("/v2/datasets/130?api_key=...") - assert response.status_code == HTTPStatus.OK - assert dataset_130 == response.json() + get_dataset( + dataset_id=130, + user=owner, + user_db=user_test, + expdb_db=expdb_test, + ) @pytest.mark.skip("Not sure how to include apikey in test yet.") diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py index 5a67105..69c45ef 100644 --- a/tests/routers/openml/migration/datasets_migration_test.py +++ b/tests/routers/openml/migration/datasets_migration_test.py @@ -7,7 +7,7 @@ from starlette.testclient import TestClient from core.conversions import nested_remove_single_element_list -from tests.conftest import ApiKey +from tests.routers.openml.users_test import ApiKey @pytest.mark.parametrize( diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py new file mode 100644 index 0000000..3f8bb4d --- /dev/null +++ b/tests/routers/openml/users_test.py @@ -0,0 +1,39 @@ +from enum import StrEnum + +import pytest +from sqlalchemy import Connection + +from database.users import User, UserGroup +from routers.dependencies import fetch_user + +NO_USER = None +SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE]) +OWNER_USER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE]) +ADMIN_USER = User(user_id=1, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE]) + + +class ApiKey(StrEnum): + ADMIN: str = "AD000000000000000000000000000000" + REGULAR_USER: str = "00000000000000000000000000000000" + OWNER_USER: str = "DA1A0000000000000000000000000000" + INVALID: str = "11111111111111111111111111111111" + + +@pytest.mark.parametrize( + ("api_key", "user"), + [ + (ApiKey.ADMIN, ADMIN_USER), + (ApiKey.OWNER_USER, OWNER_USER), + (ApiKey.REGULAR_USER, SOME_USER), + ], +) +def test_fetch_user(api_key: str, user: User, user_test: Connection) -> None: + db_user = fetch_user(api_key, user_data=user_test) + assert user.user_id == db_user.user_id + assert user.groups == db_user.groups + + +def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None: + assert fetch_user(api_key=None, user_data=user_test) is None + invalid_key = "f" * 32 + assert fetch_user(api_key=invalid_key, user_data=user_test) is None