From db6f73b49754e8200b2ff8565f3f51910de41153 Mon Sep 17 00:00:00 2001 From: Pieter Gijsbers Date: Thu, 18 Jul 2024 11:54:37 +0200 Subject: [PATCH] Add/31 (#167) GET only * First implementation of flow exists * Rename function and favor verbose usage This makes it more clear where the data is actually coming from. * Make return value depend on database response * Test database layer is called correctly * Add database level test * Separate out database tests * Separate out api vs function tests * Add happy path migration test * Add migration test for flow_exists * Only rollback the transaction if one is active * Add migration test for flow_exist but no match * Fix type issue * Add pytest-mock dependency --- pyproject.toml | 1 + src/database/flows.py | 24 +++++-- src/routers/openml/flows.py | 30 +++++--- tests/conftest.py | 49 ++++++++++++- tests/database/flows_test.py | 20 ++++++ tests/routers/openml/flows_test.py | 69 +++++++++++++++++++ .../openml/migration/flows_migration_test.py | 39 +++++++++++ 7 files changed, 216 insertions(+), 16 deletions(-) create mode 100644 tests/database/flows_test.py diff --git a/pyproject.toml b/pyproject.toml index baed303..7ba4969 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ dev = [ "pre-commit", "pytest", + "pytest-mock", "httpx", "hypothesis", "deepdiff", diff --git a/src/database/flows.py b/src/database/flows.py index 4889f13..c6c8807 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -3,7 +3,7 @@ from sqlalchemy import Connection, Row, text -def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]: +def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]: return cast( Sequence[Row], expdb.execute( @@ -14,12 +14,12 @@ def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]: WHERE parent = :flow_id """, ), - parameters={"flow_id": flow_id}, + parameters={"flow_id": for_flow}, ), ) -def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]: +def get_tags(flow_id: int, expdb: Connection) -> list[str]: tag_rows = expdb.execute( text( """ @@ -33,7 +33,7 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]: return [tag.tag for tag in tag_rows] -def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: +def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: return cast( Sequence[Row], expdb.execute( @@ -49,7 +49,21 @@ def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: ) -def get_flow(flow_id: int, expdb: Connection) -> Row | None: +def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None: + """Gets flow by name and external version.""" + return expdb.execute( + text( + """ + SELECT *, uploadDate as upload_date + FROM implementation + WHERE name = :name AND external_version = :external_version + """, + ), + parameters={"name": name, "external_version": external_version}, + ).one_or_none() + + +def get_by_id(flow_id: int, expdb: Connection) -> Row | None: return expdb.execute( text( """ diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index ca95102..9b73084 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,9 +1,8 @@ import http.client -from typing import Annotated +from typing import Annotated, Literal +import database.flows from core.conversions import _str_to_num -from database.flows import get_flow as db_get_flow -from database.flows import get_flow_parameters, get_flow_subflows, get_flow_tags from fastapi import APIRouter, Depends, HTTPException from schemas.flows import Flow, Parameter from sqlalchemy import Connection @@ -13,13 +12,29 @@ router = APIRouter(prefix="/flows", tags=["flows"]) +@router.get("/exists/{name}/{external_version}") +def flow_exists( + name: str, + external_version: str, + expdb: Annotated[Connection, Depends(expdb_connection)], +) -> dict[Literal["flow_id"], int]: + """Check if a Flow with the name and version exists, if so, return the flow id.""" + flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb) + if flow is None: + raise HTTPException( + status_code=http.client.NOT_FOUND, + detail="Flow not found.", + ) + return {"flow_id": flow.id} + + @router.get("/{flow_id}") def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow: - flow = db_get_flow(flow_id, expdb) + flow = database.flows.get_by_id(flow_id, expdb) if not flow: raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found") - parameter_rows = get_flow_parameters(flow_id, expdb) + parameter_rows = database.flows.get_parameters(flow_id, expdb) parameters = [ Parameter( name=parameter.name, @@ -33,9 +48,8 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection for parameter in parameter_rows ] - tags = get_flow_tags(flow_id, expdb) - - flow_rows = get_flow_subflows(flow_id, expdb) + tags = database.flows.get_tags(flow_id, expdb) + flow_rows = database.flows.get_subflows(for_flow=flow_id, expdb=expdb) subflows = [ { "identifier": flow.identifier, diff --git a/tests/conftest.py b/tests/conftest.py index 6a80b4f..47f0f39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import json from enum import StrEnum from pathlib import Path -from typing import Any, Iterator +from typing import Any, Iterator, NamedTuple import httpx import pytest @@ -10,7 +10,7 @@ from fastapi.testclient import TestClient from main import create_api from routers.dependencies import expdb_connection, userdb_connection -from sqlalchemy import Connection, Engine +from sqlalchemy import Connection, Engine, text class ApiKey(StrEnum): @@ -25,7 +25,8 @@ def automatic_rollback(engine: Engine) -> Iterator[Connection]: with engine.connect() as connection: transaction = connection.begin() yield connection - transaction.rollback() + if transaction.is_active: + transaction.rollback() @pytest.fixture() @@ -65,3 +66,45 @@ def dataset_130() -> Iterator[dict[str, Any]]: @pytest.fixture() def default_configuration_file() -> Path: return Path().parent.parent / "src" / "config.toml" + + +class Flow(NamedTuple): + """To be replaced by an actual ORM class.""" + + id: int + name: str + external_version: str + + +@pytest.fixture() +def flow(expdb_test: Connection) -> Flow: + expdb_test.execute( + text( + """ + INSERT INTO implementation(fullname,name,version,external_version,uploadDate) + VALUES ('a','name',2,'external_version','2024-02-02 02:23:23'); + """, + ), + ) + (flow_id,) = expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")).one() + return Flow(id=flow_id, name="name", external_version="external_version") + + +@pytest.fixture() +def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]: + expdb_test.commit() + yield flow + # We want to ensure the commit below does not accidentally persist new + # data to the database. + expdb_test.rollback() + + expdb_test.execute( + text( + """ + DELETE FROM implementation + WHERE id = :flow_id + """, + ), + parameters={"flow_id": flow.id}, + ) + expdb_test.commit() diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py new file mode 100644 index 0000000..7438430 --- /dev/null +++ b/tests/database/flows_test.py @@ -0,0 +1,20 @@ +import database.flows +from sqlalchemy import Connection + +from tests.conftest import Flow + + +def test_database_flow_exists(flow: Flow, expdb_test: Connection) -> None: + retrieved_flow = database.flows.get_by_name(flow.name, flow.external_version, expdb_test) + assert retrieved_flow is not None + assert retrieved_flow.id == flow.id + # when using actual ORM, can instead ensure _all_ fields match. + + +def test_database_flow_exists_returns_none_if_no_match(expdb_test: Connection) -> None: + retrieved_flow = database.flows.get_by_name( + name="foo", + external_version="bar", + expdb=expdb_test, + ) + assert retrieved_flow is None diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py index b957060..60fe559 100644 --- a/tests/routers/openml/flows_test.py +++ b/tests/routers/openml/flows_test.py @@ -1,6 +1,75 @@ +import http.client + import deepdiff.diff +import pytest +from fastapi import HTTPException +from pytest_mock import MockerFixture +from routers.openml.flows import flow_exists +from sqlalchemy import Connection from starlette.testclient import TestClient +from tests.conftest import Flow + + +@pytest.mark.parametrize( + ("name", "external_version"), + [ + ("a", "b"), + ("c", "d"), + ], +) +def test_flow_exists_calls_db_correctly( + name: str, + external_version: str, + expdb_test: Connection, + mocker: MockerFixture, +) -> None: + mocked_db = mocker.patch("database.flows.get_by_name") + flow_exists(name, external_version, expdb_test) + mocked_db.assert_called_once_with( + name=name, + external_version=external_version, + expdb=mocker.ANY, + ) + + +@pytest.mark.parametrize( + "flow_id", + [1, 2], +) +def test_flow_exists_processes_found( + flow_id: int, + mocker: MockerFixture, + expdb_test: Connection, +) -> None: + fake_flow = mocker.MagicMock(id=flow_id) + mocker.patch( + "database.flows.get_by_name", + return_value=fake_flow, + ) + response = flow_exists("name", "external_version", expdb_test) + assert response == {"flow_id": fake_flow.id} + + +def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None: + mocker.patch("database.flows.get_by_name", return_value=None) + with pytest.raises(HTTPException) as error: + flow_exists("foo", "bar", expdb_test) + assert error.value.status_code == http.client.NOT_FOUND + assert error.value.detail == "Flow not found." + + +def test_flow_exists(flow: Flow, py_api: TestClient) -> None: + response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}") + assert response.status_code == http.client.OK + assert response.json() == {"flow_id": flow.id} + + +def test_flow_exists_not_exists(py_api: TestClient) -> None: + response = py_api.get("/flows/exists/foo/bar") + assert response.status_code == http.client.NOT_FOUND + assert response.json()["detail"] == "Flow not found." + def test_get_flow_no_subflow(py_api: TestClient) -> None: response = py_api.get("/flows/1") diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py index e1cb5e8..70f80d8 100644 --- a/tests/routers/openml/migration/flows_migration_test.py +++ b/tests/routers/openml/migration/flows_migration_test.py @@ -1,3 +1,4 @@ +import http.client from typing import Any import deepdiff @@ -9,6 +10,44 @@ ) from starlette.testclient import TestClient +from tests.conftest import Flow + + +@pytest.mark.mut() +@pytest.mark.php() +def test_flow_exists_not( + py_api: TestClient, + php_api: TestClient, +) -> None: + path = "exists/foo/bar" + py_response = py_api.get(f"/flows/{path}") + php_response = php_api.get(f"/flow/{path}") + + assert py_response.status_code == http.client.NOT_FOUND + assert php_response.status_code == http.client.OK + + expect_php = {"flow_exists": {"exists": "false", "id": str(-1)}} + assert php_response.json() == expect_php + assert py_response.json() == {"detail": "Flow not found."} + + +@pytest.mark.mut() +@pytest.mark.php() +def test_flow_exists( + persisted_flow: Flow, + py_api: TestClient, + php_api: TestClient, +) -> None: + path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}" + py_response = py_api.get(f"/flows/{path}") + php_response = php_api.get(f"/flow/{path}") + + assert py_response.status_code == php_response.status_code, php_response.content + + expect_php = {"flow_exists": {"exists": "true", "id": str(persisted_flow.id)}} + assert php_response.json() == expect_php + assert py_response.json() == {"flow_id": persisted_flow.id} + @pytest.mark.php() @pytest.mark.parametrize(