Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add role and rule model
Browse files Browse the repository at this point in the history
berrydenhartog committed Jan 3, 2025
1 parent 8015d9a commit eea0702
Showing 23 changed files with 624 additions and 7 deletions.
37 changes: 37 additions & 0 deletions amt/api/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from collections.abc import Callable
from functools import wraps
from typing import Any

from fastapi import HTTPException, Request

from amt.core.exceptions import AMTPermissionDenied


# note: propably needs to change to understand all the api context to be usefull
def add_permissions(permissions: dict[str, list[str]]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
request = kwargs.get("request")
organization_id = kwargs.get("organization_id")
algoritme_id = kwargs.get("algoritme_id")
if not isinstance(request, Request): # todo: change exception to custom exception
raise HTTPException(status_code=400, detail="Request object is missing")

for permission, verbs in permissions.items():
permission = permission.format(organization_id=organization_id)
permission = permission.format(algoritme_id=algoritme_id)
request_permissions: dict[str, list[str]] = (
request.state.permissions if hasattr(request.state, "permissions") else {}
)
if permission not in request_permissions:
raise AMTPermissionDenied()
for verb in verbs:
if verb not in request.state.permissions[permission]:
raise AMTPermissionDenied()

return await func(*args, **kwargs)

return wrapper

return decorator
15 changes: 14 additions & 1 deletion amt/api/deps.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from amt.api.http_browser_caching import url_for_cache
from amt.api.localizable import LocalizableEnum
from amt.api.navigation import NavigationItem, get_main_menu
from amt.core.authorization import get_user
from amt.core.authorization import AuthorizationVerb, get_user
from amt.core.config import VERSION, get_settings
from amt.core.internationalization import (
format_datetime,
@@ -43,13 +43,16 @@ def custom_context_processor(
) -> dict[str, str | None | list[str] | dict[str, str] | list[NavigationItem] | type[WebFormFieldType]]:
lang = get_requested_language(request)
translations = get_current_translation(request)
permissions = getattr(request.state, "permissions", {})

return {
"version": VERSION,
"available_translations": list(supported_translations),
"language": lang,
"translations": get_dynamic_field_translations(lang),
"main_menu_items": get_main_menu(request, translations),
"user": get_user(request),
"permissions": permissions,
"WebFormFieldType": WebFormFieldType,
}

@@ -95,6 +98,15 @@ def nested_enum_value(obj: Any, attr_path: str, language: str) -> Any: # noqa:
return get_nested(obj, attr_path).localize(language)


def permission(permission: str, verb: AuthorizationVerb, permissions: dict[str, list[AuthorizationVerb]]) -> bool:
authorized = False

if permission in permissions and verb in permissions[permission]:
authorized = True

return authorized


# we use a custom override so we can add the translation per request, which is parsed in the Request object in kwargs
class LocaleJinja2Templates(Jinja2Templates):
def _create_env(
@@ -166,4 +178,5 @@ def instance(obj: Class, type_string: str) -> bool:
templates.env.globals.update(nested_enum=nested_enum) # pyright: ignore [reportUnknownMemberType]
templates.env.globals.update(nested_enum_value=nested_enum_value) # pyright: ignore [reportUnknownMemberType]
templates.env.globals.update(isinstance=instance) # pyright: ignore [reportUnknownMemberType]
templates.env.globals.update(permission=permission) # pyright: ignore [reportUnknownMemberType]
templates.env.add_extension("jinja2_base64_filters.Base64Filters") # pyright: ignore [reportUnknownMemberType]
23 changes: 23 additions & 0 deletions amt/core/authorization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
from collections.abc import Iterable
from enum import StrEnum
from typing import Any

from starlette.requests import Request

from amt.core.internationalization import get_requested_language


class AuthorizationVerb(StrEnum):
LIST = "List"
READ = "Read"
CREATE = "Create"
UPDATE = "Update"
DELETE = "Delete"


class AuthorizationType(StrEnum):
ALGORITHM = "Algorithm"
ORGANIZATION = "Organization"


class AuthorizationResource(StrEnum):
ORGANIZATION_INFO = "organization/{organization_id}"
ORGANIZATION_ALGORITHM = "organization/{organization_id}/algorithm"
ORGANIZATION_MEMBER = "organization/{organization_id}/member"
ALGORITHM = "algoritme/{algoritme_id}"
ALGORITHM_SYSTEMCARD = "algoritme/{algoritme_id}/systemcard"
ALGORITHM_MEMBER = "algoritme/{algoritme_id}/user"


def get_user(request: Request) -> dict[str, Any] | None:
user = None
if isinstance(request.scope, Iterable) and "session" in request.scope:
6 changes: 6 additions & 0 deletions amt/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -79,6 +79,12 @@ def __init__(self) -> None:
super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail)


class AMTPermissionDenied(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("You do not have the correct permissions to access this resource.")
super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail)


class AMTStorageError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("Something went wrong storing your file. PLease try again later.")
6 changes: 6 additions & 0 deletions amt/middleware/authorization.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from starlette.responses import RedirectResponse, Response

from amt.core.authorization import get_user
from amt.services.authorization import AuthorizationService

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

@@ -18,7 +19,12 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
if request.url.path.startswith("/static/"):
return await call_next(request)

authorization_service = AuthorizationService()

user = get_user(request)

request.state.permissions = await authorization_service.find_by_user(user)

if user: # pragma: no cover
return await call_next(request)

113 changes: 113 additions & 0 deletions amt/migrations/versions/e16bb3d53cd6_authorization_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""authorization system
Revision ID: e16bb3d53cd6
Revises: 5de977ad946f
Create Date: 2024-12-23 08:32:15.194858
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from amt.core.authorization import AuthorizationResource, AuthorizationVerb, AuthorizationType
from sqlalchemy.orm.session import Session
from amt.models import User, Organization

# revision identifiers, used by Alembic.
revision: str = "e16bb3d53cd6"
down_revision: str | None = "5de977ad946f"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None



def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
role_table = op.create_table(
"role",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_role")),
)
rule_table = op.create_table(
"rule",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("resource", sa.String(), nullable=False),
sa.Column("verbs", sa.JSON(), nullable=False),
sa.Column("role_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_rule_role_id_role")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_rule")),
)

authorization_table = op.create_table(
"authorization",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("role_id", sa.Integer(), nullable=False),
sa.Column("type", sa.String(), nullable=False),
sa.Column("type_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_authorization_role_id_role")),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_authorization_user_id_user")),
sa.PrimaryKeyConstraint("id", name=op.f("pk_authorization")),
)

op.bulk_insert(
role_table,
[
{'id': 1, 'name': 'Organization Maintainer'},
{'id': 2, 'name': 'Organization Member'},
{'id': 3, 'name': 'Organization Viewer'},
{'id': 4, 'name': 'Algorithm Maintainer'},
{'id': 5, 'name': 'Algorithm Member'},
{'id': 6, 'name': 'Algorithm Viewer'},
]
)

op.bulk_insert(
rule_table,
[
{'id': 1, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE], 'role_id': 1},
{'id': 2, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 1},
{'id': 3, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 1},
{'id': 4, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.READ], 'role_id': 2},
{'id': 5, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE], 'role_id': 2},
{'id': 6, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST], 'role_id': 2},
{'id': 7, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.READ], 'role_id': 3},
{'id': 8, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST], 'role_id': 3},
{'id': 9, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST], 'role_id': 3},
{'id': 10, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.DELETE], 'role_id': 4},
{'id': 11, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE], 'role_id': 4},
{'id': 12, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 4},
{'id': 13, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE], 'role_id': 5},
{'id': 14, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE], 'role_id': 5},
{'id': 15, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.READ], 'role_id': 5},
{'id': 16, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.READ], 'role_id': 6},
{'id': 17, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ], 'role_id': 6},
{'id': 18, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.READ], 'role_id': 6},
]
)

session = Session(bind=op.get_bind())

first_user = session.query(User).first() # first user is always present due to other migration
organizations = session.query(Organization).all()

authorizations = []
# lets add user 1 to all organizations bij default
for organization in organizations:
authorizations.append(
{'id': 1, 'user_id': first_user.id, 'role_id': 1, 'type': AuthorizationType.ORGANIZATION, 'type_id': organization.id},
)

op.bulk_insert(
authorization_table,
authorizations
)



def downgrade() -> None:
op.drop_table("rule")
op.drop_table("authorization")
op.drop_table("role")
5 changes: 4 additions & 1 deletion amt/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .algorithm import Algorithm
from .authorization import Authorization
from .organization import Organization
from .role import Role
from .rule import Rule
from .task import Task
from .user import User

__all__ = ["Algorithm", "Organization", "Task", "User"]
__all__ = ["Algorithm", "Authorization", "Organization", "Role", "Rule", "Task", "User"]
15 changes: 15 additions & 0 deletions amt/models/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship

from amt.models.base import Base


class Authorization(Base):
__tablename__ = "authorization"

id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
user: Mapped["User"] = relationship(back_populates="authorizations") # pyright: ignore [reportUndefinedVariable, reportUnknownVariableType] #noqa
role_id: Mapped[int] = mapped_column(ForeignKey("role.id"))
type: Mapped[str] # type [Organization or Algorithm]
type_id: Mapped[int] # ID of the organization or algorithm
15 changes: 15 additions & 0 deletions amt/models/role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from amt.models import Authorization
from amt.models.base import Base
from amt.models.rule import Rule


class Role(Base):
__tablename__ = "role"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, nullable=False)
rules: Mapped[list["Rule"]] = relationship()
authorizations: Mapped[list["Authorization"]] = relationship()
14 changes: 14 additions & 0 deletions amt/models/rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON

from amt.models.base import Base


class Rule(Base):
__tablename__ = "rule"

id: Mapped[int] = mapped_column(primary_key=True)
resource: Mapped[str] = mapped_column(String, nullable=False)
verbs: Mapped[list[str]] = mapped_column(JSON, default=list)
role_id: Mapped[int] = mapped_column(ForeignKey("role.id"))
3 changes: 2 additions & 1 deletion amt/models/user.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy.orm import Mapped, mapped_column, relationship

from amt.models import Organization
from amt.models import Authorization, Organization
from amt.models.base import Base


@@ -19,3 +19,4 @@ class User(Base):
"Organization", secondary="users_and_organizations", back_populates="users", lazy="selectin"
)
organizations_created: Mapped[list["Organization"]] = relationship(back_populates="created_by", lazy="selectin")
authorizations: Mapped[list["Authorization"]] = relationship("Authorization", back_populates="user")
50 changes: 50 additions & 0 deletions amt/repositories/authorizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from amt.core.authorization import AuthorizationVerb
from amt.models import Authorization, Role, Rule
from amt.repositories.deps import get_session_non_generator

logger = logging.getLogger(__name__)

PermissionTuple = tuple[str, list[AuthorizationVerb], str, int]
PermissionsList = list[PermissionTuple]


class AuthorizationRepository:
"""
The AuthorizationRepository provides access to the repository layer.
"""

def __init__(self, session: AsyncSession | None = None) -> None:
self.session = session

async def init_session(self) -> None:
if self.session is None:
self.session = await get_session_non_generator()

async def find_by_user(self, user: UUID) -> PermissionsList | None:
"""
Returns all authorization for a user.
:return: all authorization for the user
"""
await self.init_session()

statement = (
select(
Rule.resource,
Rule.verbs,
Authorization.type,
Authorization.type_id,
)
.join(Role, Rule.role_id == Role.id)
.join(Authorization, Rule.role_id == Authorization.role_id)
.filter(Authorization.user_id == user)
)

result = await self.session.execute(statement) # type: ignore
authorizations = result.all()
return authorizations # type: ignore
9 changes: 9 additions & 0 deletions amt/repositories/deps.py
Original file line number Diff line number Diff line change
@@ -13,3 +13,12 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
)
async with async_session_factory() as async_session:
yield async_session


async def get_session_non_generator() -> AsyncSession:
async_session_factory = async_sessionmaker(
get_engine(),
expire_on_commit=False,
class_=AsyncSession,
)
return async_session_factory()
7 changes: 7 additions & 0 deletions amt/schema/permission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from amt.core.authorization import AuthorizationVerb
from amt.schema.shared import BaseModel


class Permission(BaseModel):
resource: str
verb: list[AuthorizationVerb]
43 changes: 43 additions & 0 deletions amt/services/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import contextlib
from typing import Any
from uuid import UUID

from amt.core.authorization import AuthorizationType, AuthorizationVerb
from amt.repositories.authorizations import AuthorizationRepository
from amt.schema.permission import Permission

PermissionTuple = tuple[str, list[AuthorizationVerb], str, int]
PermissionsList = list[PermissionTuple]


class AuthorizationService:
def __init__(self) -> None:
self.repository = AuthorizationRepository()

async def find_by_user(self, user: dict[str, Any] | None) -> dict[str, list[AuthorizationVerb]]:
if not user:
return {}
else:
permissions: dict[str, list[AuthorizationVerb]] = {}

uuid = UUID(user["sub"])
authorizations: PermissionsList = await self.repository.find_by_user(uuid) # type: ignore
for auth in authorizations:
auth_dict: dict[str, int] = {"organization_id": -1, "algoritme_id": -1}

if auth[2] == AuthorizationType.ORGANIZATION:
auth_dict["organization_id"] = auth[3]

if auth[2] == AuthorizationType.ALGORITHM:
auth_dict["algoritme_id"] = auth[3]

resource: str = auth[0]
verbs: list[AuthorizationVerb] = auth[1]
with contextlib.suppress(Exception):
resource = resource.format(**auth_dict)

permission: Permission = Permission(resource=resource, verb=verbs)

permissions.update({permission.resource: permission.verb})

return permissions
1 change: 1 addition & 0 deletions tests/api/routes/test_deps.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ def test_custom_context_processor(mocker: MockerFixture):
"translations",
"main_menu_items",
"user",
"permissions",
"WebFormFieldType",
]
assert result["version"] == VERSION
91 changes: 91 additions & 0 deletions tests/api/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import json
import typing

from amt.api.decorators import add_permissions
from amt.core.authorization import AuthorizationResource, AuthorizationVerb
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.responses import Response

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

app = FastAPI()


@app.get("/unauthorized")
@add_permissions(permissions={"algoritme/1": [AuthorizationVerb.CREATE]})
async def unauthorized(request: Request):
return {"message": "Hello World"}


@app.get("/authorized")
@add_permissions(permissions={})
async def authorized(request: Request):
return {"message": "Hello World"}


@app.get("/norequest")
@add_permissions(permissions={})
async def norequest():
return {"message": "Hello World"}


@app.get("/authorizedparameters/{organization_id}")
@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO: [AuthorizationVerb.CREATE]})
async def authorizedparameters(request: Request, organization_id: int):
return {"message": "Hello World"}


@app.middleware("http")
async def add_authorizations(request: Request, call_next: RequestResponseEndpoint):
if "X-Permissions" in request.headers:
request.state.permissions = json.loads(request.headers["X-Permissions"])
return await call_next(request)


def test_permission_decorator_norequest():
client = TestClient(app, base_url="https://testserver")
response = client.get("/norequest")
assert response.status_code == 400


def test_permission_decorator_unauthorized():
client = TestClient(app, base_url="https://testserver")
response = client.get("/unauthorized")
assert response.status_code == 401


def test_permission_decorator_authorized():
client = TestClient(app, base_url="https://testserver")
response = client.get(
"/authorized",
)
assert response.status_code == 200


def test_permission_decorator_authorized_permission():
client = TestClient(app, base_url="https://testserver")

response = client.get("/unauthorized", headers={"X-Permissions": '{"algoritme/1": ["Create"]}'})
assert response.status_code == 200


def test_permission_decorator_authorized_permission_missing():
client = TestClient(app, base_url="https://testserver")

response = client.get("/unauthorized", headers={"X-Permissions": '{"algoritme/1": ["Read"]}'})
assert response.status_code == 401


def test_permission_decorator_authorized_permission_variable():
client = TestClient(app, base_url="https://testserver")

response = client.get("/authorizedparameters/1", headers={"X-Permissions": '{"organization/1": ["Create"]}'})
assert response.status_code == 200


def test_permission_decorator_unauthorized_permission_variable():
client = TestClient(app, base_url="https://testserver")

response = client.get("/authorizedparameters/4453546", headers={"X-Permissions": '{"organization/1": ["Create"]}'})
assert response.status_code == 401
45 changes: 44 additions & 1 deletion tests/api/test_deps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
from amt.api.deps import custom_context_processor
from amt.api.deps import custom_context_processor, permission
from amt.core.authorization import AuthorizationVerb

from tests.constants import default_fastapi_request

example_permissions = {
"organization/1": [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE],
"organization/1/algorithm": [
AuthorizationVerb.LIST,
AuthorizationVerb.CREATE,
AuthorizationVerb.UPDATE,
AuthorizationVerb.DELETE,
],
"organization/1/member": [
AuthorizationVerb.LIST,
AuthorizationVerb.CREATE,
AuthorizationVerb.UPDATE,
AuthorizationVerb.DELETE,
],
"algoritme/1": [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.DELETE],
"algoritme/1/systemcard": [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE],
"algoritme/1/user": [
AuthorizationVerb.CREATE,
AuthorizationVerb.READ,
AuthorizationVerb.UPDATE,
AuthorizationVerb.DELETE,
],
}


def test_custom_context_processor():
result = custom_context_processor(default_fastapi_request())
@@ -10,3 +35,21 @@ def test_custom_context_processor():
assert result["available_translations"] == ["en", "nl"]
assert result["language"] == "en"
assert result["translations"] is None


def test_permissions_false():
result = permission("organization/1", AuthorizationVerb.LIST, example_permissions)

assert result is False


def test_permissions_true():
result = permission("organization/1", AuthorizationVerb.READ, example_permissions)

assert result is True


def test_permissions_none_existing_resource():
result = permission("badfadfb/1", AuthorizationVerb.READ, example_permissions)

assert result is False
24 changes: 23 additions & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,8 @@

from amt.api.lifecycles import Lifecycles
from amt.api.navigation import BaseNavigationItem, DisplayText
from amt.models import Algorithm, Organization, Task, User
from amt.core.authorization import AuthorizationResource, AuthorizationVerb
from amt.models import Algorithm, Authorization, Organization, Role, Rule, Task, User
from amt.schema.instrument import Instrument, InstrumentTask, Owner
from amt.schema.system_card import SystemCard
from fastapi import Request
@@ -56,6 +57,27 @@ def default_organization(name: str = "default organization", slug: str = "defaul
return Organization(name=name, slug=slug, created_by_id=UUID(default_auth_user()["sub"]))


def default_rule() -> Rule:
return Rule(
resource=AuthorizationResource.ORGANIZATION_INFO,
verbs=[AuthorizationVerb.CREATE, AuthorizationVerb.READ],
role_id=1,
)


def default_role() -> Role:
return Role(name="default role")


def default_authorization() -> Authorization:
return Authorization(
user_id=UUID(default_auth_user()["sub"]),
role_id=1,
type="Organization",
type_id=1,
)


def default_user(
id: str | UUID | None = None,
name: str | None = None,
24 changes: 23 additions & 1 deletion tests/core/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from gettext import NullTranslations

import pytest
from amt.core.exceptions import AMTCSRFProtectError, AMTError, AMTInstrumentError, AMTNotFound, AMTSettingsError
from amt.core.exceptions import (
AMTAuthorizationFlowError,
AMTCSRFProtectError,
AMTError,
AMTInstrumentError,
AMTNotFound,
AMTPermissionDenied,
AMTSettingsError,
)


def test_settings_error():
@@ -44,3 +52,17 @@ def test_AMTCSRFProtectError():
raise AMTCSRFProtectError()

assert exc_info.value.detail == "CSRF check failed."


def test_AMTPermissionDenied():
with pytest.raises(AMTPermissionDenied) as exc_info:
raise AMTPermissionDenied()

assert exc_info.value.detail == "You do not have the correct permissions to access this resource."


def test_AMTAuthorizationFlowError():
with pytest.raises(AMTAuthorizationFlowError) as exc_info:
raise AMTAuthorizationFlowError()

assert exc_info.value.detail == "Something went wrong during the authorization flow. Please try again later."
39 changes: 39 additions & 0 deletions tests/repositories/test_authorizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from uuid import UUID

import pytest
from amt.core.authorization import AuthorizationResource, AuthorizationType, AuthorizationVerb
from amt.repositories.authorizations import AuthorizationRepository
from tests.constants import (
default_algorithm,
default_auth_user,
default_authorization,
default_role,
default_rule,
default_user,
)
from tests.database_test_utils import DatabaseTestUtils


@pytest.mark.asyncio
async def test_authorization_basic(db: DatabaseTestUtils):
await db.given(
[
default_user(),
default_algorithm(),
default_role(),
default_rule(),
default_authorization(),
]
)

authorization_repository = AuthorizationRepository(session=db.session)
results = await authorization_repository.find_by_user(UUID(default_auth_user()["sub"]))

assert results == [
(
AuthorizationResource.ORGANIZATION_INFO,
[AuthorizationVerb.CREATE, AuthorizationVerb.READ],
AuthorizationType.ORGANIZATION,
1,
)
]
8 changes: 7 additions & 1 deletion tests/repositories/test_deps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from amt.repositories.deps import get_session
from amt.repositories.deps import get_session, get_session_non_generator
from sqlalchemy.ext.asyncio import AsyncSession


@@ -14,3 +14,9 @@ async def test_get_session():
await session_generator.aclose()
except StopAsyncIteration:
pass


@pytest.mark.asyncio
async def test_get_session_non_generator():
session = await get_session_non_generator()
assert isinstance(session, AsyncSession)
38 changes: 38 additions & 0 deletions tests/services/test_authorization_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from amt.core.authorization import AuthorizationResource, AuthorizationType, AuthorizationVerb
from amt.repositories.authorizations import AuthorizationRepository
from amt.services.authorization import AuthorizationService
from pytest_mock import MockFixture
from tests.constants import default_auth_user


@pytest.mark.asyncio
async def test_authorization_get_auth_read(mocker: MockFixture):
# Given

authorization_service = AuthorizationService()
authorization_service.repository = mocker.AsyncMock(spec=AuthorizationRepository)
authorization_service.repository.find_by_user.return_value = [
(AuthorizationResource.ORGANIZATION_INFO, [AuthorizationVerb.READ], AuthorizationType.ORGANIZATION, 1)
]

# When
authorizations = await authorization_service.find_by_user(default_auth_user())

# Then
assert authorizations == {"organization/1": [AuthorizationVerb.READ]}


@pytest.mark.asyncio
async def test_authorization_get_auth_none(mocker: MockFixture):
# Given

authorization_service = AuthorizationService()
authorization_service.repository = mocker.AsyncMock(spec=AuthorizationRepository)
authorization_service.repository.find_by_user.return_value = []

# When
authorizations = await authorization_service.find_by_user(default_auth_user())

# Then
assert authorizations == {}

0 comments on commit eea0702

Please sign in to comment.