Skip to content

Commit

Permalink
Add/31 (#167) GET only
Browse files Browse the repository at this point in the history
* First implementation of flow exists

* Rename function and favor verbose usage

This makes it more clear where the data is actually coming from.

* Make return value depend on database response

* Test database layer is called correctly

* Add database level test

* Separate out database tests

* Separate out api vs function tests

* Add happy path migration test

* Add migration test for flow_exists

* Only rollback the transaction if one is active

* Add migration test for flow_exist but no match

* Fix type issue

* Add pytest-mock dependency
  • Loading branch information
PGijsbers authored Jul 18, 2024
1 parent 4c0965d commit db6f73b
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 16 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
dev = [
"pre-commit",
"pytest",
"pytest-mock",
"httpx",
"hypothesis",
"deepdiff",
Expand Down
24 changes: 19 additions & 5 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy import Connection, Row, text


def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
Expand All @@ -14,12 +14,12 @@ def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
WHERE parent = :flow_id
""",
),
parameters={"flow_id": flow_id},
parameters={"flow_id": for_flow},
),
)


def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
def get_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
text(
"""
Expand All @@ -33,7 +33,7 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
return [tag.tag for tag in tag_rows]


def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
Expand All @@ -49,7 +49,21 @@ def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
)


def get_flow(flow_id: int, expdb: Connection) -> Row | None:
def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None:
"""Gets flow by name and external version."""
return expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
FROM implementation
WHERE name = :name AND external_version = :external_version
""",
),
parameters={"name": name, "external_version": external_version},
).one_or_none()


def get_by_id(flow_id: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
Expand Down
30 changes: 22 additions & 8 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import http.client
from typing import Annotated
from typing import Annotated, Literal

import database.flows
from core.conversions import _str_to_num
from database.flows import get_flow as db_get_flow
from database.flows import get_flow_parameters, get_flow_subflows, get_flow_tags
from fastapi import APIRouter, Depends, HTTPException
from schemas.flows import Flow, Parameter
from sqlalchemy import Connection
Expand All @@ -13,13 +12,29 @@
router = APIRouter(prefix="/flows", tags=["flows"])


@router.get("/exists/{name}/{external_version}")
def flow_exists(
name: str,
external_version: str,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Check if a Flow with the name and version exists, if so, return the flow id."""
flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
if flow is None:
raise HTTPException(
status_code=http.client.NOT_FOUND,
detail="Flow not found.",
)
return {"flow_id": flow.id}


@router.get("/{flow_id}")
def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
flow = db_get_flow(flow_id, expdb)
flow = database.flows.get_by_id(flow_id, expdb)
if not flow:
raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found")

parameter_rows = get_flow_parameters(flow_id, expdb)
parameter_rows = database.flows.get_parameters(flow_id, expdb)
parameters = [
Parameter(
name=parameter.name,
Expand All @@ -33,9 +48,8 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection
for parameter in parameter_rows
]

tags = get_flow_tags(flow_id, expdb)

flow_rows = get_flow_subflows(flow_id, expdb)
tags = database.flows.get_tags(flow_id, expdb)
flow_rows = database.flows.get_subflows(for_flow=flow_id, expdb=expdb)
subflows = [
{
"identifier": flow.identifier,
Expand Down
49 changes: 46 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import json
from enum import StrEnum
from pathlib import Path
from typing import Any, Iterator
from typing import Any, Iterator, NamedTuple

import httpx
import pytest
from database.setup import expdb_database, user_database
from fastapi.testclient import TestClient
from main import create_api
from routers.dependencies import expdb_connection, userdb_connection
from sqlalchemy import Connection, Engine
from sqlalchemy import Connection, Engine, text


class ApiKey(StrEnum):
Expand All @@ -25,7 +25,8 @@ def automatic_rollback(engine: Engine) -> Iterator[Connection]:
with engine.connect() as connection:
transaction = connection.begin()
yield connection
transaction.rollback()
if transaction.is_active:
transaction.rollback()


@pytest.fixture()
Expand Down Expand Up @@ -65,3 +66,45 @@ def dataset_130() -> Iterator[dict[str, Any]]:
@pytest.fixture()
def default_configuration_file() -> Path:
return Path().parent.parent / "src" / "config.toml"


class Flow(NamedTuple):
"""To be replaced by an actual ORM class."""

id: int
name: str
external_version: str


@pytest.fixture()
def flow(expdb_test: Connection) -> Flow:
expdb_test.execute(
text(
"""
INSERT INTO implementation(fullname,name,version,external_version,uploadDate)
VALUES ('a','name',2,'external_version','2024-02-02 02:23:23');
""",
),
)
(flow_id,) = expdb_test.execute(text("""SELECT LAST_INSERT_ID();""")).one()
return Flow(id=flow_id, name="name", external_version="external_version")


@pytest.fixture()
def persisted_flow(flow: Flow, expdb_test: Connection) -> Iterator[Flow]:
expdb_test.commit()
yield flow
# We want to ensure the commit below does not accidentally persist new
# data to the database.
expdb_test.rollback()

expdb_test.execute(
text(
"""
DELETE FROM implementation
WHERE id = :flow_id
""",
),
parameters={"flow_id": flow.id},
)
expdb_test.commit()
20 changes: 20 additions & 0 deletions tests/database/flows_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import database.flows
from sqlalchemy import Connection

from tests.conftest import Flow


def test_database_flow_exists(flow: Flow, expdb_test: Connection) -> None:
retrieved_flow = database.flows.get_by_name(flow.name, flow.external_version, expdb_test)
assert retrieved_flow is not None
assert retrieved_flow.id == flow.id
# when using actual ORM, can instead ensure _all_ fields match.


def test_database_flow_exists_returns_none_if_no_match(expdb_test: Connection) -> None:
retrieved_flow = database.flows.get_by_name(
name="foo",
external_version="bar",
expdb=expdb_test,
)
assert retrieved_flow is None
69 changes: 69 additions & 0 deletions tests/routers/openml/flows_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,75 @@
import http.client

import deepdiff.diff
import pytest
from fastapi import HTTPException
from pytest_mock import MockerFixture
from routers.openml.flows import flow_exists
from sqlalchemy import Connection
from starlette.testclient import TestClient

from tests.conftest import Flow


@pytest.mark.parametrize(
("name", "external_version"),
[
("a", "b"),
("c", "d"),
],
)
def test_flow_exists_calls_db_correctly(
name: str,
external_version: str,
expdb_test: Connection,
mocker: MockerFixture,
) -> None:
mocked_db = mocker.patch("database.flows.get_by_name")
flow_exists(name, external_version, expdb_test)
mocked_db.assert_called_once_with(
name=name,
external_version=external_version,
expdb=mocker.ANY,
)


@pytest.mark.parametrize(
"flow_id",
[1, 2],
)
def test_flow_exists_processes_found(
flow_id: int,
mocker: MockerFixture,
expdb_test: Connection,
) -> None:
fake_flow = mocker.MagicMock(id=flow_id)
mocker.patch(
"database.flows.get_by_name",
return_value=fake_flow,
)
response = flow_exists("name", "external_version", expdb_test)
assert response == {"flow_id": fake_flow.id}


def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None:
mocker.patch("database.flows.get_by_name", return_value=None)
with pytest.raises(HTTPException) as error:
flow_exists("foo", "bar", expdb_test)
assert error.value.status_code == http.client.NOT_FOUND
assert error.value.detail == "Flow not found."


def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
assert response.status_code == http.client.OK
assert response.json() == {"flow_id": flow.id}


def test_flow_exists_not_exists(py_api: TestClient) -> None:
response = py_api.get("/flows/exists/foo/bar")
assert response.status_code == http.client.NOT_FOUND
assert response.json()["detail"] == "Flow not found."


def test_get_flow_no_subflow(py_api: TestClient) -> None:
response = py_api.get("/flows/1")
Expand Down
39 changes: 39 additions & 0 deletions tests/routers/openml/migration/flows_migration_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import http.client
from typing import Any

import deepdiff
Expand All @@ -9,6 +10,44 @@
)
from starlette.testclient import TestClient

from tests.conftest import Flow


@pytest.mark.mut()
@pytest.mark.php()
def test_flow_exists_not(
py_api: TestClient,
php_api: TestClient,
) -> None:
path = "exists/foo/bar"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")

assert py_response.status_code == http.client.NOT_FOUND
assert php_response.status_code == http.client.OK

expect_php = {"flow_exists": {"exists": "false", "id": str(-1)}}
assert php_response.json() == expect_php
assert py_response.json() == {"detail": "Flow not found."}


@pytest.mark.mut()
@pytest.mark.php()
def test_flow_exists(
persisted_flow: Flow,
py_api: TestClient,
php_api: TestClient,
) -> None:
path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")

assert py_response.status_code == php_response.status_code, php_response.content

expect_php = {"flow_exists": {"exists": "true", "id": str(persisted_flow.id)}}
assert php_response.json() == expect_php
assert py_response.json() == {"flow_id": persisted_flow.id}


@pytest.mark.php()
@pytest.mark.parametrize(
Expand Down

0 comments on commit db6f73b

Please sign in to comment.