Skip to content

Commit

Permalink
151 add pytest mock for amt (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurensWe authored Oct 23, 2024
2 parents 93543a5 + 0d7fcbc commit dba3cf7
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 170 deletions.
7 changes: 3 additions & 4 deletions amt/cli/check_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import sys
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -44,8 +43,8 @@ def get_tasks_by_priority(urns: list[str], system_card_path: Path) -> None:
click.echo(task.urn)
except yaml.YAMLError as error:
click.echo(f"Sorry, an error occurred; yaml could not be parsed: {error}", err=True)
sys.exit(1)
click.Abort()
except Exception as error:
click.echo(f"Sorry, an error occurred. {error}", err=True)
sys.exit(1)
sys.exit(0)
click.Abort()
return
184 changes: 108 additions & 76 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ sqlalchemy = "^2.0.36"
sqlalchemy-utils = "^0.41.2"
liccheck = "^0.9.2"
authlib = "^1.3.2"
pytest-mock = "^3.14.0"


[tool.poetry.group.test.dependencies]
Expand Down
6 changes: 3 additions & 3 deletions tests/api/routes/test_deps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Callable
from enum import Enum
from unittest.mock import Mock

import pytest
from amt.api.deps import (
Expand All @@ -16,10 +15,11 @@
from amt.core.internationalization import supported_translations
from amt.schema.localized_value_item import LocalizedValueItem
from fastapi import Request
from pytest_mock import MockerFixture


def test_custom_context_processor():
request: Request = Mock()
def test_custom_context_processor(mocker: MockerFixture):
request: Request = mocker.Mock()
request.cookies.get.return_value = "nl"
request.headers.get.return_value = "nl"
result = custom_context_processor(request)
Expand Down
5 changes: 3 additions & 2 deletions tests/api/routes/test_project.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections.abc import Generator
from typing import Any

import pytest
from amt.api.routes.project import set_path
from amt.models import Project
from fastapi.testclient import TestClient
from pytest_mock import MockFixture

from tests.constants import default_project, default_task
from tests.database_test_utils import DatabaseTestUtils
Expand Down Expand Up @@ -232,10 +232,11 @@ def test_get_project_cancel(client: TestClient, db: DatabaseTestUtils) -> None:
assert b"lifecycle" in response.content


def test_get_project_update(client: TestClient, mock_csrf: Generator[None, None, None], db: DatabaseTestUtils) -> None:
def test_get_project_update(client: TestClient, mocker: MockFixture, db: DatabaseTestUtils) -> None:
# given
db.given([default_project("testproject1")])
client.cookies["fastapi-csrf-token"] = "1"
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)

# when
response = client.put("/algorithm-system/1/update/name", json={"value": "Test Name"}, headers={"X-CSRF-Token": "1"})
Expand Down
47 changes: 25 additions & 22 deletions tests/api/routes/test_projects.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,20 @@
from collections.abc import Generator
from typing import cast
from unittest.mock import Mock

import pytest
from amt.api.routes.projects import get_localized_value
from amt.models import Project
from amt.models.base import Base
from amt.schema.ai_act_profile import AiActProfile
from amt.schema.project import ProjectNew
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.task_registry import get_requirements_and_measures
from fastapi.requests import Request
from fastapi.testclient import TestClient
from fastapi_csrf_protect import CsrfProtect # type: ignore # noqa
from pytest_mock import MockFixture

from tests.constants import default_instrument
from tests.database_test_utils import DatabaseTestUtils


@pytest.fixture
def init_instruments() -> Generator[None, None, None]: # noqa: PT004
origin = InstrumentsService.fetch_instruments
InstrumentsService.fetch_instruments = Mock(
return_value=[default_instrument(urn="urn1", name="name1"), default_instrument(urn="urn2", name="name2")]
)
yield
InstrumentsService.fetch_instruments = origin


def test_projects_get_root(client: TestClient) -> None:
response = client.get("/algorithm-systems/")

Expand All @@ -50,7 +36,13 @@ def test_projects_get_root_htmx(client: TestClient) -> None:
assert b'<table id="search-results-table" class="rvo-table margin-top-large">' not in response.content


def test_get_new_projects(client: TestClient, init_instruments: Generator[None, None, None]) -> None:
def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None:
# given
mocker.patch(
"amt.services.instruments.InstrumentsService.fetch_instruments",
return_value=[default_instrument(urn="urn1", name="name1"), default_instrument(urn="urn2", name="name2")],
)

# when
response = client.get("/algorithm-systems/new")
assert response.status_code == 200
Expand All @@ -67,7 +59,10 @@ def test_get_new_projects(client: TestClient, init_instruments: Generator[None,
)


def test_post_new_projects_bad_request(client: TestClient, mock_csrf: Generator[None, None, None]) -> None:
def test_post_new_projects_bad_request(client: TestClient, mocker: MockFixture) -> None:
# given
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)

# when
client.cookies["fastapi-csrf-token"] = "1"
response = client.post("/algorithm-systems/new", json={}, headers={"X-CSRF-Token": "1"})
Expand All @@ -78,9 +73,7 @@ def test_post_new_projects_bad_request(client: TestClient, mock_csrf: Generator[
assert b"name: Field required" in response.content


def test_post_new_projects(
client: TestClient, mock_csrf: Generator[None, None, None], init_instruments: Generator[None, None, None]
) -> None:
def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None:
client.cookies["fastapi-csrf-token"] = "1"
new_project = ProjectNew(
name="default project",
Expand All @@ -92,6 +85,12 @@ def test_post_new_projects(
transparency_obligations="geen transparantieverplichtingen",
role="gebruiksverantwoordelijke",
)
# given
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)
mocker.patch(
"amt.services.instruments.InstrumentsService.fetch_instruments",
return_value=[default_instrument(urn="urn1", name="name1"), default_instrument(urn="urn2", name="name2")],
)

# when
response = client.post("/algorithm-systems/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"})
Expand All @@ -104,12 +103,16 @@ def test_post_new_projects(

def test_post_new_projects_write_system_card(
client: TestClient,
mock_csrf: Generator[None, None, None],
init_instruments: Generator[None, None, None],
mocker: MockFixture,
db: DatabaseTestUtils,
) -> None:
# Given
client.cookies["fastapi-csrf-token"] = "1"
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)
mocker.patch(
"amt.services.instruments.InstrumentsService.fetch_instruments",
return_value=[default_instrument(urn="urn1", name="name1"), default_instrument(urn="urn2", name="name2")],
)

name = "name1"
project_new = ProjectNew(
Expand Down
9 changes: 3 additions & 6 deletions tests/api/test_http_browser_caching.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from pathlib import Path
from typing import NamedTuple
from unittest.mock import Mock

import pytest
from amt.api import http_browser_caching
from amt.core.exceptions import AMTNotFound, AMTOnlyStatic
from pytest_mock import MockerFixture
from starlette.responses import Response


Expand All @@ -24,23 +24,20 @@ def test_url_for_cache_file_not_found():
http_browser_caching.url_for_cache("static", path="this/does/not/exist")


def test_url_for_cache_file_happy_flow(tmp_path: Path):
def test_url_for_cache_file_happy_flow(tmp_path: Path, mocker: MockerFixture):
class MockStatResult(NamedTuple):
st_mtime: int
st_size: int

lookup_path_orig = http_browser_caching.static_files.lookup_path
(tmp_path / "testfile").write_text("This is a test", encoding="utf-8")
http_browser_caching.static_files = http_browser_caching.StaticFilesCache(directory=Path(tmp_path))
http_browser_caching.static_files.lookup_path = Mock(os.stat_result, return_value=(None, MockStatResult(1, 2)))
mocker.patch("amt.api.http_browser_caching.static_files.lookup_path", return_value=(None, MockStatResult(1, 2)))
result = http_browser_caching.url_for_cache("static", path="testfile")
assert result == "/static/testfile?etag=98c6f2c2287f4c73cea3d40ae7ec3ff2"
# also test with a query param
result = http_browser_caching.url_for_cache("static", path="testfile?queryparam1=true")
assert result == "/static/testfile?queryparam1=true&etag=98c6f2c2287f4c73cea3d40ae7ec3ff2"

http_browser_caching.static_files.lookup_path = lookup_path_orig


def test_static_files_class_immutable(tmp_path: Path):
testfile = tmp_path / "testfile"
Expand Down
43 changes: 14 additions & 29 deletions tests/cli/test_check_state.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportAttributeAccessIssue=false
from pathlib import Path
from typing import Any
from unittest.mock import Mock

import amt.services.instruments_and_requirements_state
import pytest
from amt.cli.check_state import get_requested_instruments, get_tasks_by_priority
from amt.core.exceptions import AMTInstrumentError
from amt.schema.instrument import InstrumentTask
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.storage import FileSystemStorageService, StorageFactory
from click.testing import CliRunner
from pytest_mock import MockerFixture
from tests.constants import default_instrument
from yaml import YAMLError

Expand All @@ -37,17 +35,13 @@ def system_card(system_card_data: dict[str, Any]) -> SystemCard:
return system_card


def test_get_system_card(system_card: SystemCard, system_card_data: dict[str, Any]):
init_orig = StorageFactory.init

storage_mock = Mock()
storage_mock.read.return_value = system_card_data
StorageFactory.init = Mock(return_value=storage_mock)

def test_get_system_card(system_card: SystemCard, system_card_data: dict[str, Any], mocker: MockerFixture):
mocker.patch(
"amt.services.storage.StorageFactory.init",
return_value=mocker.Mock(read=mocker.Mock(return_value=system_card_data)),
)
assert system_card == amt.cli.check_state.get_system_card(Path("dummy"))

StorageFactory.init = init_orig


def test_get_requested_instruments():
instrument0 = default_instrument(urn="instrument0")
Expand All @@ -58,9 +52,7 @@ def test_get_requested_instruments():
assert expected == get_requested_instruments(all_instruments_cards, ["instrument0", "instrument2"])


def test_cli(capsys: pytest.CaptureFixture[str], system_card: SystemCard):
fetch_instruments_orig = InstrumentsService.fetch_instruments
get_system_card_orig = amt.cli.check_state.get_system_card
def test_cli(capsys: pytest.CaptureFixture[str], system_card: SystemCard, mocker: MockerFixture):
instrument = default_instrument(
urn="urn:instrument:assessment",
tasks=[
Expand All @@ -69,35 +61,28 @@ def test_cli(capsys: pytest.CaptureFixture[str], system_card: SystemCard):
],
)

return_value = [instrument]
InstrumentsService.fetch_instruments = Mock(return_value=return_value)
amt.cli.check_state.get_system_card = Mock(return_value=system_card)
mocker.patch("amt.services.instruments.InstrumentsService.fetch_instruments", return_value=[instrument])
mocker.patch("amt.cli.check_state.get_system_card", return_value=system_card)
runner = CliRunner()
# workaround for https://github.com/pallets/click/issues/824
with capsys.disabled() as _:
result = runner.invoke(get_tasks_by_priority, ["urn:instrument:assessment", "example/system_test_card.yaml"]) # type: ignore
assert "urn:instrument:assessment:task2" in result.output
InstrumentsService.fetch_instruments = fetch_instruments_orig
amt.cli.check_state.get_system_card = get_system_card_orig


def test_cli_with_exception(capsys: pytest.CaptureFixture[str], system_card: SystemCard):
fetch_instruments_orig = InstrumentsService.fetch_instruments
InstrumentsService.fetch_instruments = Mock(side_effect=AMTInstrumentError())
def test_cli_with_exception(capsys: pytest.CaptureFixture[str], system_card: SystemCard, mocker: MockerFixture):
mocker.patch("amt.services.instruments.InstrumentsService.fetch_instruments", side_effect=AMTInstrumentError())
runner = CliRunner()
# workaround for https://github.com/pallets/click/issues/824
with capsys.disabled() as _:
result = runner.invoke(get_tasks_by_priority, ["urn:instrument:assessment", "example/system_test_card.yaml"]) # type: ignore
assert "Sorry, an error occurre" in result.output
InstrumentsService.fetch_instruments = fetch_instruments_orig
assert "Sorry, an error occurred" in result.output


def test_cli_with_exception_yaml(capsys: pytest.CaptureFixture[str], system_card: SystemCard):
read_orig = FileSystemStorageService.read
FileSystemStorageService.read = Mock(side_effect=YAMLError("Test error message"))
def test_cli_with_exception_yaml(capsys: pytest.CaptureFixture[str], system_card: SystemCard, mocker: MockerFixture):
mocker.patch("amt.services.storage.FileSystemStorageService.read", side_effect=YAMLError("Test error message"))
runner = CliRunner()
# workaround for https://github.com/pallets/click/issues/824
with capsys.disabled() as _:
result = runner.invoke(get_tasks_by_priority, ["urn:instrument:assessment", "example/system_test_card.yaml"]) # type: ignore
assert "Sorry, an error occurred; yaml could not be parsed:" in result.output
FileSystemStorageService.read = read_orig
10 changes: 0 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from multiprocessing import Process
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock

import httpx
import pytest
import uvicorn
from amt.models.base import Base
from amt.server import create_app
from fastapi.testclient import TestClient
from fastapi_csrf_protect import CsrfProtect # type: ignore
from playwright.sync_api import Browser
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -191,11 +189,3 @@ def db(

with Session(engine, expire_on_commit=False) as session:
yield DatabaseTestUtils(session, database_file)


@pytest.fixture
def mock_csrf() -> Generator[None, None, None]: # noqa: PT004
original = CsrfProtect.validate_csrf
CsrfProtect.validate_csrf = AsyncMock()
yield
CsrfProtect.validate_csrf = original
6 changes: 3 additions & 3 deletions tests/core/test_db.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import logging
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from amt.core.db import (
check_db,
)
from pytest_mock import MockFixture
from sqlalchemy import select
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, mocker: MockFixture):
database_file = tmp_path / "database.sqlite3"
monkeypatch.setenv("APP_DATABASE_FILE", str(database_file))
org_exec = Session.execute
Session.execute = MagicMock()
Session.execute = mocker.MagicMock()
check_db()

assert Session.execute.call_args is not None
Expand Down
Loading

0 comments on commit dba3cf7

Please sign in to comment.