From db6f73b49754e8200b2ff8565f3f51910de41153 Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Thu, 18 Jul 2024 11:54:37 +0200
Subject: [PATCH] Add/31 (#167) GET only
* First implementation of flow exists
* Rename function and favor verbose usage
This makes it clearer where the data is actually coming from.
* Make return value depend on database response
* Test database layer is called correctly
* Add database level test
* Separate out database tests
* Separate out api vs function tests
* Add happy path migration test
* Add migration test for flow_exists
* Only rollback the transaction if one is active
* Add migration test for flow_exists with no match
* Fix type issue
* Add pytest-mock dependency
---
pyproject.toml | 1 +
src/database/flows.py | 24 +++++--
src/routers/openml/flows.py | 30 +++++---
tests/conftest.py | 49 ++++++++++++-
tests/database/flows_test.py | 20 ++++++
tests/routers/openml/flows_test.py | 69 +++++++++++++++++++
.../openml/migration/flows_migration_test.py | 39 +++++++++++
7 files changed, 216 insertions(+), 16 deletions(-)
create mode 100644 tests/database/flows_test.py
diff --git a/pyproject.toml b/pyproject.toml
index baed303..7ba4969 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
dev = [
"pre-commit",
"pytest",
+ "pytest-mock",
"httpx",
"hypothesis",
"deepdiff",
diff --git a/src/database/flows.py b/src/database/flows.py
index 4889f13..c6c8807 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -3,7 +3,7 @@
from sqlalchemy import Connection, Row, text
-def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
+def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
@@ -14,12 +14,12 @@ def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
WHERE parent = :flow_id
""",
),
- parameters={"flow_id": flow_id},
+ parameters={"flow_id": for_flow},
),
)
-def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
+def get_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
text(
"""
@@ -33,7 +33,7 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
return [tag.tag for tag in tag_rows]
-def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
+def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
@@ -49,7 +49,21 @@ def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
)
-def get_flow(flow_id: int, expdb: Connection) -> Row | None:
+def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None:
+ """Gets flow by name and external version."""
+ return expdb.execute(
+ text(
+ """
+ SELECT *, uploadDate as upload_date
+ FROM implementation
+ WHERE name = :name AND external_version = :external_version
+ """,
+ ),
+ parameters={"name": name, "external_version": external_version},
+ ).one_or_none()
+
+
+def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py
index ca95102..9b73084 100644
--- a/src/routers/openml/flows.py
+++ b/src/routers/openml/flows.py
@@ -1,9 +1,8 @@
import http.client
-from typing import Annotated
+from typing import Annotated, Literal
+import database.flows
from core.conversions import _str_to_num
-from database.flows import get_flow as db_get_flow
-from database.flows import get_flow_parameters, get_flow_subflows, get_flow_tags
from fastapi import APIRouter, Depends, HTTPException
from schemas.flows import Flow, Parameter
from sqlalchemy import Connection
@@ -13,13 +12,29 @@
router = APIRouter(prefix="/flows", tags=["flows"])
+@router.get("/exists/{name}/{external_version}")
+def flow_exists(
+ name: str,
+ external_version: str,
+ expdb: Annotated[Connection, Depends(expdb_connection)],
+) -> dict[Literal["flow_id"], int]:
+ """Check if a Flow with the name and version exists, if so, return the flow id."""
+ flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
+ if flow is None:
+ raise HTTPException(
+ status_code=http.client.NOT_FOUND,
+ detail="Flow not found.",
+ )
+ return {"flow_id": flow.id}
+
+
@router.get("/{flow_id}")
def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
- flow = db_get_flow(flow_id, expdb)
+ flow = database.flows.get_by_id(flow_id, expdb)
if not flow:
raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found")
- parameter_rows = get_flow_parameters(flow_id, expdb)
+ parameter_rows = database.flows.get_parameters(flow_id, expdb)
parameters = [
Parameter(
name=parameter.name,
@@ -33,9 +48,8 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection
for parameter in parameter_rows
]
- tags = get_flow_tags(flow_id, expdb)
-
- flow_rows = get_flow_subflows(flow_id, expdb)
+ tags = database.flows.get_tags(flow_id, expdb)
+ flow_rows = database.flows.get_subflows(for_flow=flow_id, expdb=expdb)
subflows = [
{
"identifier": flow.identifier,
diff --git a/tests/conftest.py b/tests/conftest.py
index 6a80b4f..47f0f39 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
import json
from enum import StrEnum
from pathlib import Path
-from typing import Any, Iterator
+from typing import Any, Iterator, NamedTuple
import httpx
import pytest
@@ -10,7 +10,7 @@
from fastapi.testclient import TestClient
from main import create_api
from routers.dependencies import expdb_connection, userdb_connection
-from sqlalchemy import Connection, Engine
+from sqlalchemy import Connection, Engine, text
class ApiKey(StrEnum):
@@ -25,7 +25,8 @@ def automatic_rollback(engine: Engine) -> Iterator[Connection]:
with engine.connect() as connection:
transaction = connection.begin()
yield connection
- transaction.rollback()
+ if transaction.is_active:
+ transaction.rollback()
@pytest.fixture()
@@ -65,3 +66,45 @@ def dataset_130() -> Iterator[dict[str, Any]]:
@pytest.fixture()
def default_configuration_file() -> Path:
return Path().parent.parent / "src" / "config.toml"
+
+
+class Flow(NamedTuple):
+ """To be replaced by an actual ORM class."""
+
+ id: int
+ name: str
+ external_version: str
+
+
+@pytest.fixture()
+def flow(expdb_test: Connection) -> Flow:
+ expdb_test.execute(
+ text(
+ """
+ INSERT INTO implementation(fullname,name,version,external_version,uploadDate)
+ VALUES ('a','name',2,'external_version','2024-02-02 02:23:23');
+ """,
+ ),
+ )
+ (flow_id,) = expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")).one()
+ return Flow(id=flow_id, name="name", external_version="external_version")
+
+
+@pytest.fixture()
+def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]:
+ expdb_test.commit()
+ yield flow
+ # We want to ensure the commit below does not accidentally persist new
+ # data to the database.
+ expdb_test.rollback()
+
+ expdb_test.execute(
+ text(
+ """
+ DELETE FROM implementation
+ WHERE id = :flow_id
+ """,
+ ),
+ parameters={"flow_id": flow.id},
+ )
+ expdb_test.commit()
diff --git a/tests/database/flows_test.py b/tests/database/flows_test.py
new file mode 100644
index 0000000..7438430
--- /dev/null
+++ b/tests/database/flows_test.py
@@ -0,0 +1,20 @@
+import database.flows
+from sqlalchemy import Connection
+
+from tests.conftest import Flow
+
+
+def test_database_flow_exists(flow: Flow, expdb_test: Connection) -> None:
+ retrieved_flow = database.flows.get_by_name(flow.name, flow.external_version, expdb_test)
+ assert retrieved_flow is not None
+ assert retrieved_flow.id == flow.id
+ # when using actual ORM, can instead ensure _all_ fields match.
+
+
+def test_database_flow_exists_returns_none_if_no_match(expdb_test: Connection) -> None:
+ retrieved_flow = database.flows.get_by_name(
+ name="foo",
+ external_version="bar",
+ expdb=expdb_test,
+ )
+ assert retrieved_flow is None
diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py
index b957060..60fe559 100644
--- a/tests/routers/openml/flows_test.py
+++ b/tests/routers/openml/flows_test.py
@@ -1,6 +1,75 @@
+import http.client
+
import deepdiff.diff
+import pytest
+from fastapi import HTTPException
+from pytest_mock import MockerFixture
+from routers.openml.flows import flow_exists
+from sqlalchemy import Connection
from starlette.testclient import TestClient
+from tests.conftest import Flow
+
+
+@pytest.mark.parametrize(
+ ("name", "external_version"),
+ [
+ ("a", "b"),
+ ("c", "d"),
+ ],
+)
+def test_flow_exists_calls_db_correctly(
+ name: str,
+ external_version: str,
+ expdb_test: Connection,
+ mocker: MockerFixture,
+) -> None:
+ mocked_db = mocker.patch("database.flows.get_by_name")
+ flow_exists(name, external_version, expdb_test)
+ mocked_db.assert_called_once_with(
+ name=name,
+ external_version=external_version,
+ expdb=mocker.ANY,
+ )
+
+
+@pytest.mark.parametrize(
+ "flow_id",
+ [1, 2],
+)
+def test_flow_exists_processes_found(
+ flow_id: int,
+ mocker: MockerFixture,
+ expdb_test: Connection,
+) -> None:
+ fake_flow = mocker.MagicMock(id=flow_id)
+ mocker.patch(
+ "database.flows.get_by_name",
+ return_value=fake_flow,
+ )
+ response = flow_exists("name", "external_version", expdb_test)
+ assert response == {"flow_id": fake_flow.id}
+
+
+def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None:
+ mocker.patch("database.flows.get_by_name", return_value=None)
+ with pytest.raises(HTTPException) as error:
+ flow_exists("foo", "bar", expdb_test)
+ assert error.value.status_code == http.client.NOT_FOUND
+ assert error.value.detail == "Flow not found."
+
+
+def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
+ response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
+ assert response.status_code == http.client.OK
+ assert response.json() == {"flow_id": flow.id}
+
+
+def test_flow_exists_not_exists(py_api: TestClient) -> None:
+ response = py_api.get("/flows/exists/foo/bar")
+ assert response.status_code == http.client.NOT_FOUND
+ assert response.json()["detail"] == "Flow not found."
+
def test_get_flow_no_subflow(py_api: TestClient) -> None:
response = py_api.get("/flows/1")
diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py
index e1cb5e8..70f80d8 100644
--- a/tests/routers/openml/migration/flows_migration_test.py
+++ b/tests/routers/openml/migration/flows_migration_test.py
@@ -1,3 +1,4 @@
+import http.client
from typing import Any
import deepdiff
@@ -9,6 +10,44 @@
)
from starlette.testclient import TestClient
+from tests.conftest import Flow
+
+
+@pytest.mark.mut()
+@pytest.mark.php()
+def test_flow_exists_not(
+ py_api: TestClient,
+ php_api: TestClient,
+) -> None:
+ path = "exists/foo/bar"
+ py_response = py_api.get(f"/flows/{path}")
+ php_response = php_api.get(f"/flow/{path}")
+
+ assert py_response.status_code == http.client.NOT_FOUND
+ assert php_response.status_code == http.client.OK
+
+ expect_php = {"flow_exists": {"exists": "false", "id": str(-1)}}
+ assert php_response.json() == expect_php
+ assert py_response.json() == {"detail": "Flow not found."}
+
+
+@pytest.mark.mut()
+@pytest.mark.php()
+def test_flow_exists(
+ persisted_flow: Flow,
+ py_api: TestClient,
+ php_api: TestClient,
+) -> None:
+ path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
+ py_response = py_api.get(f"/flows/{path}")
+ php_response = php_api.get(f"/flow/{path}")
+
+ assert py_response.status_code == php_response.status_code, php_response.content
+
+ expect_php = {"flow_exists": {"exists": "true", "id": str(persisted_flow.id)}}
+ assert php_response.json() == expect_php
+ assert py_response.json() == {"flow_id": persisted_flow.id}
+
@pytest.mark.php()
@pytest.mark.parametrize(