Skip to content

Commit

Permalink
move fixtures to conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
jfredrickson committed Mar 21, 2023
1 parent 55890be commit 1feb0e2
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 80 deletions.
52 changes: 51 additions & 1 deletion training/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import List
import pytest
from training.database import engine
from sqlalchemy.orm import sessionmaker
from training import models
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy import event
from training.repositories import AgencyRepository


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -33,3 +36,50 @@ def mark_rolled_back(_):
if not already_rolled_back:
transaction.rollback()
connection.close()


@pytest.fixture
def valid_agency_names() -> List[str]:
'''
Provides a list of valid agency names.
'''
return [
"Department of Mysteries",
"Department of Magical Law Enforcement",
"Department of Magical Accidents and Catastrophes",
]


@pytest.fixture
def valid_agency_name(valid_agency_names: List[str]) -> str:
'''
Provides a valid agency name.
'''
return valid_agency_names[0]


@pytest.fixture
def db_with_data(db: Session, valid_agency_names: List[str]):
'''
Provides a populated database.
'''
for name in valid_agency_names:
db.add(models.Agency(name=name))
db.commit()
yield db


@pytest.fixture
def agency_repo(db: Session) -> AgencyRepository:
'''
Provides an AgencyRepository injected with an empty database.
'''
yield AgencyRepository(session=db)


@pytest.fixture
def agency_repo_with_data(db_with_data: Session) -> AgencyRepository:
'''
Provides an AgencyRepository injected with a populated database.
'''
yield AgencyRepository(session=db_with_data)
109 changes: 30 additions & 79 deletions training/tests/test_agency_repository.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,55 @@
from typing import List
import pytest
from training import schemas, models
from training.repositories.agency import AgencyRepository
from sqlalchemy.orm import Session


@pytest.fixture
def valid_names() -> List[str]:
'''
Provides a list of valid agency names.
'''
return [
"Department of Mysteries",
"Department of Magical Law Enforcement",
"Department of Magical Accidents and Catastrophes",
]


@pytest.fixture
def valid_name(valid_names: List[str]) -> str:
'''
Provides a valid agency name.
'''
return valid_names[0]


@pytest.fixture
def db_with_data(db: Session, valid_names: List[str]):
'''
Provides a populated database.
'''
for name in valid_names:
db.add(models.Agency(name=name))
db.commit()
yield db


@pytest.fixture
def repo(db: Session) -> AgencyRepository:
'''
Provides an AgencyRepository injected with an empty database.
'''
yield AgencyRepository(session=db)


@pytest.fixture
def repo_with_data(db_with_data: Session) -> AgencyRepository:
'''
Provides an AgencyRepository injected with a populated database.
'''
yield AgencyRepository(session=db_with_data)


def test_create(repo: AgencyRepository, valid_name):
result = repo.create(schemas.AgencyCreate(name=valid_name))
from training.repositories import AgencyRepository


def test_create(agency_repo: AgencyRepository, valid_agency_name):
result = agency_repo.create(schemas.AgencyCreate(name=valid_agency_name))
assert result.id
assert result.name == valid_name
assert result.name == valid_agency_name


def test_create_duplicate(repo_with_data: AgencyRepository, valid_name):
def test_create_duplicate(agency_repo_with_data: AgencyRepository, valid_agency_name):
with pytest.raises(Exception):
repo_with_data.create(schemas.AgencyCreate(name=valid_name))
agency_repo_with_data.create(schemas.AgencyCreate(name=valid_agency_name))


def test_find_by_name(repo_with_data: AgencyRepository, valid_name):
result = repo_with_data.find_by_name(valid_name)
assert result.name == valid_name
def test_find_by_name(agency_repo_with_data: AgencyRepository, valid_agency_name):
result = agency_repo_with_data.find_by_name(valid_agency_name)
assert result.name == valid_agency_name


def test_find_by_nonexistent_name(repo: AgencyRepository):
result = repo.find_by_name("Nonexistent Agency")
def test_find_by_nonexistent_name(agency_repo: AgencyRepository):
result = agency_repo.find_by_name("Nonexistent Agency")
assert result is None


def test_save(repo: AgencyRepository, valid_name):
result = repo.save(models.Agency(name=valid_name))
def test_save(agency_repo: AgencyRepository, valid_agency_name):
result = agency_repo.save(models.Agency(name=valid_agency_name))
assert result.id
assert repo.find_by_id(result.id).name == valid_name
assert agency_repo.find_by_id(result.id).name == valid_agency_name


def test_find_by_id(repo: AgencyRepository, valid_name):
agency_id = repo.save(models.Agency(name=valid_name)).id
result = repo.find_by_id(agency_id)
assert result.name == valid_name
def test_find_by_id(agency_repo: AgencyRepository, valid_agency_name):
agency_id = agency_repo.save(models.Agency(name=valid_agency_name)).id
result = agency_repo.find_by_id(agency_id)
assert result.name == valid_agency_name


def test_find_by_nonexistent_id(repo: AgencyRepository, valid_name):
agency_id = repo.save(models.Agency(name=valid_name)).id
result = repo.find_by_id(agency_id + 1)
def test_find_by_nonexistent_id(agency_repo: AgencyRepository, valid_agency_name):
agency_id = agency_repo.save(models.Agency(name=valid_agency_name)).id
result = agency_repo.find_by_id(agency_id + 1)
assert result is None


def test_find_all(repo_with_data: AgencyRepository, valid_names):
result = repo_with_data.find_all()
def test_find_all(agency_repo_with_data: AgencyRepository, valid_agency_names):
result = agency_repo_with_data.find_all()
names = list(map(lambda r: r.name, result))
for name in valid_names:
for name in valid_agency_names:
assert name in names


def test_delete_by_id(repo: AgencyRepository, valid_name):
agency_id = repo.save(models.Agency(name=valid_name)).id
repo.delete_by_id(agency_id)
assert repo.find_by_id(agency_id) is None
def test_delete_by_id(agency_repo: AgencyRepository, valid_agency_name):
agency_id = agency_repo.save(models.Agency(name=valid_agency_name)).id
agency_repo.delete_by_id(agency_id)
assert agency_repo.find_by_id(agency_id) is None

0 comments on commit 1feb0e2

Please sign in to comment.