From 96013d42f838d04f25c9ecba808bdba107d294a0 Mon Sep 17 00:00:00 2001 From: Berry den Hartog <38954346+berrydenhartog@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:21:39 +0100 Subject: [PATCH] Add role and rule model --- amt/api/decorators.py | 36 ++++++ amt/api/deps.py | 16 ++- amt/core/authorization.py | 23 ++++ amt/core/exceptions.py | 6 + amt/middleware/authorization.py | 6 + .../e16bb3d53cd6_authorization_system.py | 113 ++++++++++++++++++ amt/models/__init__.py | 5 +- amt/models/authorization.py | 15 +++ amt/models/role.py | 15 +++ amt/models/rule.py | 14 +++ amt/models/user.py | 3 +- amt/repositories/authorizations.py | 50 ++++++++ amt/repositories/deps.py | 9 ++ amt/schema/permission.py | 7 ++ amt/services/authorization.py | 43 +++++++ tests/api/routes/test_deps.py | 1 + tests/api/test_decorator.py | 91 ++++++++++++++ tests/api/test_deps.py | 45 ++++++- tests/constants.py | 24 +++- tests/core/test_exceptions.py | 24 +++- tests/repositories/test_authorizations.py | 39 ++++++ tests/repositories/test_deps.py | 8 +- tests/services/test_authorization_service.py | 38 ++++++ .../templates/permission_example.html.j2 | 5 + .../templates/test_template_permission.py | 38 ++++++ 25 files changed, 667 insertions(+), 7 deletions(-) create mode 100644 amt/api/decorators.py create mode 100644 amt/migrations/versions/e16bb3d53cd6_authorization_system.py create mode 100644 amt/models/authorization.py create mode 100644 amt/models/role.py create mode 100644 amt/models/rule.py create mode 100644 amt/repositories/authorizations.py create mode 100644 amt/schema/permission.py create mode 100644 amt/services/authorization.py create mode 100644 tests/api/test_decorator.py create mode 100644 tests/repositories/test_authorizations.py create mode 100644 tests/services/test_authorization_service.py create mode 100644 tests/site/static/templates/permission_example.html.j2 create mode 100644 tests/site/static/templates/test_template_permission.py diff --git a/amt/api/decorators.py b/amt/api/decorators.py new file mode 100644 index 00000000..c4fbe22e --- /dev/null +++ b/amt/api/decorators.py @@ -0,0 +1,36 @@ +from collections.abc import Callable +from functools import wraps +from typing import Any + +from fastapi import HTTPException, Request + +from amt.core.exceptions import AMTPermissionDenied + + +def add_permissions(permissions: dict[str, list[str]]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + request = kwargs.get("request") + organization_id = kwargs.get("organization_id") + algoritme_id = kwargs.get("algoritme_id") + if not isinstance(request, Request): # todo: change exception to custom exception + raise HTTPException(status_code=400, detail="Request object is missing") + + for permission, verbs in permissions.items(): + permission = permission.format(organization_id=organization_id) + permission = permission.format(algoritme_id=algoritme_id) + request_permissions: dict[str, list[str]] = ( + request.state.permissions if hasattr(request.state, "permissions") else {} + ) + if permission not in request_permissions: + raise AMTPermissionDenied() + for verb in verbs: + if verb not in request.state.permissions[permission]: + raise AMTPermissionDenied() + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/amt/api/deps.py b/amt/api/deps.py index 02076c84..67cdae7c 100644 --- a/amt/api/deps.py +++ b/amt/api/deps.py @@ -16,7 +16,7 @@ from amt.api.http_browser_caching import url_for_cache from amt.api.localizable import LocalizableEnum from amt.api.navigation import NavigationItem, get_main_menu -from amt.core.authorization import get_user +from amt.core.authorization import AuthorizationVerb, get_user from amt.core.config import VERSION, get_settings from amt.core.internationalization import ( format_datetime, @@ -43,6 +43,8 @@ def custom_context_processor( ) -> dict[str, str | None | list[str] | dict[str, str] | list[NavigationItem] | type[WebFormFieldType]]: lang = get_requested_language(request) translations = get_current_translation(request) + permissions = getattr(request.state, "permissions", {}) + return { "version": VERSION, "available_translations": list(supported_translations), @@ -50,6 +52,7 @@ def custom_context_processor( "translations": get_dynamic_field_translations(lang), "main_menu_items": get_main_menu(request, translations), "user": get_user(request), + "permissions": permissions, "WebFormFieldType": WebFormFieldType, } @@ -95,6 +98,15 @@ def nested_enum_value(obj: Any, attr_path: str, language: str) -> Any: # noqa: return get_nested(obj, attr_path).localize(language) +def permission(permission: str, verb: AuthorizationVerb, permissions: dict[str, list[AuthorizationVerb]]) -> bool: + authorized = False + + if permission in permissions and verb in permissions[permission]: + authorized = True + + return authorized + + # we use a custom override so we can add the translation per request, which is parsed in the Request object in kwargs class LocaleJinja2Templates(Jinja2Templates): def _create_env( @@ -166,4 +178,6 @@ def instance(obj: Class, type_string: str) -> bool: templates.env.globals.update(nested_enum=nested_enum) # pyright: ignore [reportUnknownMemberType] templates.env.globals.update(nested_enum_value=nested_enum_value) # pyright: ignore [reportUnknownMemberType] templates.env.globals.update(isinstance=instance) # pyright: ignore [reportUnknownMemberType] +templates.env.globals.update(permission=permission) # pyright: ignore [reportUnknownMemberType] +templates.env.tests["permission"] = permission # pyright: ignore [reportUnknownMemberType] templates.env.add_extension("jinja2_base64_filters.Base64Filters") # pyright: ignore [reportUnknownMemberType] diff --git a/amt/core/authorization.py b/amt/core/authorization.py index 15956894..9756361f 100644 --- a/amt/core/authorization.py +++ b/amt/core/authorization.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from enum import StrEnum from typing import Any from starlette.requests import Request @@ -6,6 +7,28 @@ from amt.core.internationalization import get_requested_language +class AuthorizationVerb(StrEnum): + LIST = "List" + READ = "Read" + CREATE = "Create" + UPDATE = "Update" + DELETE = "Delete" + + +class AuthorizationType(StrEnum): + ALGORITHM = "Algorithm" + ORGANIZATION = "Organization" + + +class AuthorizationResource(StrEnum): + ORGANIZATION_INFO = "organization/{organization_id}" + ORGANIZATION_ALGORITHM = "organization/{organization_id}/algorithm" + ORGANIZATION_MEMBER = "organization/{organization_id}/member" + ALGORITHM = "algoritme/{algoritme_id}" + ALGORITHM_SYSTEMCARD = "algoritme/{algoritme_id}/systemcard" + ALGORITHM_MEMBER = "algoritme/{algoritme_id}/user" + + def get_user(request: Request) -> dict[str, Any] | None: user = None if isinstance(request.scope, Iterable) and "session" in request.scope: diff --git a/amt/core/exceptions.py b/amt/core/exceptions.py index 9f5e18fe..1a39cae9 100644 --- a/amt/core/exceptions.py +++ b/amt/core/exceptions.py @@ -79,6 +79,12 @@ def __init__(self) -> None: super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail) +class AMTPermissionDenied(AMTHTTPException): + def __init__(self) -> None: + self.detail: str = _("You do not have the correct permissions to access this resource.") + super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail) + + class AMTStorageError(AMTHTTPException): def __init__(self) -> None: self.detail: str = _("Something went wrong storing your file. PLease try again later.") diff --git a/amt/middleware/authorization.py b/amt/middleware/authorization.py index 58fb06c4..6b666396 100644 --- a/amt/middleware/authorization.py +++ b/amt/middleware/authorization.py @@ -6,6 +6,7 @@ from starlette.responses import RedirectResponse, Response from amt.core.authorization import get_user +from amt.services.authorization import AuthorizationService RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] @@ -18,7 +19,12 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - if request.url.path.startswith("/static/"): return await call_next(request) + authorization_service = AuthorizationService() + user = get_user(request) + + request.state.permissions = await authorization_service.find_by_user(user) + if user: # pragma: no cover return await call_next(request) diff --git a/amt/migrations/versions/e16bb3d53cd6_authorization_system.py b/amt/migrations/versions/e16bb3d53cd6_authorization_system.py new file mode 100644 index 00000000..a32cea96 --- /dev/null +++ b/amt/migrations/versions/e16bb3d53cd6_authorization_system.py @@ -0,0 +1,113 @@ +"""authorization system + +Revision ID: e16bb3d53cd6 +Revises: 5de977ad946f +Create Date: 2024-12-23 08:32:15.194858 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from amt.core.authorization import AuthorizationResource, AuthorizationVerb, AuthorizationType +from sqlalchemy.orm.session import Session +from amt.models import User, Organization + +# revision identifiers, used by Alembic. +revision: str = "e16bb3d53cd6" +down_revision: str | None = "5de977ad946f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + role_table = op.create_table( + "role", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_role")), + ) + rule_table = op.create_table( + "rule", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("resource", sa.String(), nullable=False), + sa.Column("verbs", sa.JSON(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_rule_role_id_role")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_rule")), + ) + + authorization_table = op.create_table( + "authorization", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column("type", sa.String(), nullable=False), + sa.Column("type_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["role_id"], ["role.id"], name=op.f("fk_authorization_role_id_role")), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], name=op.f("fk_authorization_user_id_user")), + sa.PrimaryKeyConstraint("id", name=op.f("pk_authorization")), + ) + + op.bulk_insert( + role_table, + [ + {'id': 1, 'name': 'Organization Maintainer'}, + {'id': 2, 'name': 'Organization Member'}, + {'id': 3, 'name': 'Organization Viewer'}, + {'id': 4, 'name': 'Algorithm Maintainer'}, + {'id': 5, 'name': 'Algorithm Member'}, + {'id': 6, 'name': 'Algorithm Viewer'}, + ] + ) + + op.bulk_insert( + rule_table, + [ + {'id': 1, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE], 'role_id': 1}, + {'id': 2, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 1}, + {'id': 3, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 1}, + {'id': 4, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.READ], 'role_id': 2}, + {'id': 5, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST, AuthorizationVerb.CREATE], 'role_id': 2}, + {'id': 6, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST], 'role_id': 2}, + {'id': 7, 'resource': AuthorizationResource.ORGANIZATION_INFO, 'verbs': [AuthorizationVerb.READ], 'role_id': 3}, + {'id': 8, 'resource': AuthorizationResource.ORGANIZATION_ALGORITHM, 'verbs': [AuthorizationVerb.LIST], 'role_id': 3}, + {'id': 9, 'resource': AuthorizationResource.ORGANIZATION_MEMBER, 'verbs': [AuthorizationVerb.LIST], 'role_id': 3}, + {'id': 10, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.DELETE], 'role_id': 4}, + {'id': 11, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE], 'role_id': 4}, + {'id': 12, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE, AuthorizationVerb.DELETE], 'role_id': 4}, + {'id': 13, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE], 'role_id': 5}, + {'id': 14, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE], 'role_id': 5}, + {'id': 15, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.READ], 'role_id': 5}, + {'id': 16, 'resource': AuthorizationResource.ALGORITHM, 'verbs': [AuthorizationVerb.READ], 'role_id': 6}, + {'id': 17, 'resource': AuthorizationResource.ALGORITHM_SYSTEMCARD, 'verbs': [AuthorizationVerb.READ], 'role_id': 6}, + {'id': 18, 'resource': AuthorizationResource.ALGORITHM_MEMBER, 'verbs': [AuthorizationVerb.READ], 'role_id': 6}, + ] + ) + + session = Session(bind=op.get_bind()) + + first_user = session.query(User).first() # first user is always present due to other migration + organizations = session.query(Organization).all() + + authorizations = [] + # lets add user 1 to all organizations bij default + for organization in organizations: + authorizations.append( + {'id': 1, 'user_id': first_user.id, 'role_id': 1, 'type': AuthorizationType.ORGANIZATION, 'type_id': organization.id}, + ) + + op.bulk_insert( + authorization_table, + authorizations + ) + + + +def downgrade() -> None: + op.drop_table("rule") + op.drop_table("authorization") + op.drop_table("role") diff --git a/amt/models/__init__.py b/amt/models/__init__.py index 37777d68..254f22c2 100644 --- a/amt/models/__init__.py +++ b/amt/models/__init__.py @@ -1,6 +1,9 @@ from .algorithm import Algorithm +from .authorization import Authorization from .organization import Organization +from .role import Role +from .rule import Rule from .task import Task from .user import User -__all__ = ["Algorithm", "Organization", "Task", "User"] +__all__ = ["Algorithm", "Authorization", "Organization", "Role", "Rule", "Task", "User"] diff --git a/amt/models/authorization.py b/amt/models/authorization.py new file mode 100644 index 00000000..80ae8dc9 --- /dev/null +++ b/amt/models/authorization.py @@ -0,0 +1,15 @@ +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from amt.models.base import Base + + +class Authorization(Base): + __tablename__ = "authorization" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + user: Mapped["User"] = relationship(back_populates="authorizations") # pyright: ignore [reportUndefinedVariable, reportUnknownVariableType] #noqa + role_id: Mapped[int] = mapped_column(ForeignKey("role.id")) + type: Mapped[str] # type [Organization or Algorithm] + type_id: Mapped[int] # ID of the organization or algorithm diff --git a/amt/models/role.py b/amt/models/role.py new file mode 100644 index 00000000..40733b6e --- /dev/null +++ b/amt/models/role.py @@ -0,0 +1,15 @@ +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from amt.models import Authorization +from amt.models.base import Base +from amt.models.rule import Rule + + +class Role(Base): + __tablename__ = "role" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, nullable=False) + rules: Mapped[list["Rule"]] = relationship() + authorizations: Mapped[list["Authorization"]] = relationship() diff --git a/amt/models/rule.py b/amt/models/rule.py new file mode 100644 index 00000000..bf28922a --- /dev/null +++ b/amt/models/rule.py @@ -0,0 +1,14 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.types import JSON + +from amt.models.base import Base + + +class Rule(Base): + __tablename__ = "rule" + + id: Mapped[int] = mapped_column(primary_key=True) + resource: Mapped[str] = mapped_column(String, nullable=False) + verbs: Mapped[list[str]] = mapped_column(JSON, default=list) + role_id: Mapped[int] = mapped_column(ForeignKey("role.id")) diff --git a/amt/models/user.py b/amt/models/user.py index 51612419..ae3c53ec 100644 --- a/amt/models/user.py +++ b/amt/models/user.py @@ -3,7 +3,7 @@ from sqlalchemy import UUID as SQLAlchemyUUID from sqlalchemy.orm import Mapped, mapped_column, relationship -from amt.models import Organization +from amt.models import Authorization, Organization from amt.models.base import Base @@ -19,3 +19,4 @@ class User(Base): "Organization", secondary="users_and_organizations", back_populates="users", lazy="selectin" ) organizations_created: Mapped[list["Organization"]] = relationship(back_populates="created_by", lazy="selectin") + authorizations: Mapped[list["Authorization"]] = relationship("Authorization", back_populates="user") diff --git a/amt/repositories/authorizations.py b/amt/repositories/authorizations.py new file mode 100644 index 00000000..3225a4ca --- /dev/null +++ b/amt/repositories/authorizations.py @@ -0,0 +1,50 @@ +import logging +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from amt.core.authorization import AuthorizationVerb +from amt.models import Authorization, Role, Rule +from amt.repositories.deps import get_session_non_generator + +logger = logging.getLogger(__name__) + +PermissionTuple = tuple[str, list[AuthorizationVerb], str, int] +PermissionsList = list[PermissionTuple] + + +class AuthorizationRepository: + """ + The AuthorizationRepository provides access to the repository layer. + """ + + def __init__(self, session: AsyncSession | None = None) -> None: + self.session = session + + async def init_session(self) -> None: + if self.session is None: + self.session = await get_session_non_generator() + + async def find_by_user(self, user: UUID) -> PermissionsList | None: + """ + Returns all authorization for a user. + :return: all authorization for the user + """ + await self.init_session() + + statement = ( + select( + Rule.resource, + Rule.verbs, + Authorization.type, + Authorization.type_id, + ) + .join(Role, Rule.role_id == Role.id) + .join(Authorization, Rule.role_id == Authorization.role_id) + .filter(Authorization.user_id == user) + ) + + result = await self.session.execute(statement) # type: ignore + authorizations = result.all() + return authorizations # type: ignore diff --git a/amt/repositories/deps.py b/amt/repositories/deps.py index f7bf0db2..90a2f334 100644 --- a/amt/repositories/deps.py +++ b/amt/repositories/deps.py @@ -13,3 +13,12 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: ) async with async_session_factory() as async_session: yield async_session + + +async def get_session_non_generator() -> AsyncSession: + async_session_factory = async_sessionmaker( + get_engine(), + expire_on_commit=False, + class_=AsyncSession, + ) + return async_session_factory() diff --git a/amt/schema/permission.py b/amt/schema/permission.py new file mode 100644 index 00000000..936d8005 --- /dev/null +++ b/amt/schema/permission.py @@ -0,0 +1,7 @@ +from amt.core.authorization import AuthorizationVerb +from amt.schema.shared import BaseModel + + +class Permission(BaseModel): + resource: str + verb: list[AuthorizationVerb] diff --git a/amt/services/authorization.py b/amt/services/authorization.py new file mode 100644 index 00000000..7a835cbb --- /dev/null +++ b/amt/services/authorization.py @@ -0,0 +1,43 @@ +import contextlib +from typing import Any +from uuid import UUID + +from amt.core.authorization import AuthorizationType, AuthorizationVerb +from amt.repositories.authorizations import AuthorizationRepository +from amt.schema.permission import Permission + +PermissionTuple = tuple[str, list[AuthorizationVerb], str, int] +PermissionsList = list[PermissionTuple] + + +class AuthorizationService: + def __init__(self) -> None: + self.repository = AuthorizationRepository() + + async def find_by_user(self, user: dict[str, Any] | None) -> dict[str, list[AuthorizationVerb]]: + if not user: + return {} + else: + permissions: dict[str, list[AuthorizationVerb]] = {} + + uuid = UUID(user["sub"]) + authorizations: PermissionsList = await self.repository.find_by_user(uuid) # type: ignore + for auth in authorizations: + auth_dict: dict[str, int] = {"organization_id": -1, "algoritme_id": -1} + + if auth[2] == AuthorizationType.ORGANIZATION: + auth_dict["organization_id"] = auth[3] + + if auth[2] == AuthorizationType.ALGORITHM: + auth_dict["algoritme_id"] = auth[3] + + resource: str = auth[0] + verbs: list[AuthorizationVerb] = auth[1] + with contextlib.suppress(Exception): + resource = resource.format(**auth_dict) + + permission: Permission = Permission(resource=resource, verb=verbs) + + permissions.update({permission.resource: permission.verb}) + + return permissions diff --git a/tests/api/routes/test_deps.py b/tests/api/routes/test_deps.py index eb714abe..337328d6 100644 --- a/tests/api/routes/test_deps.py +++ b/tests/api/routes/test_deps.py @@ -30,6 +30,7 @@ def test_custom_context_processor(mocker: MockerFixture): "translations", "main_menu_items", "user", + "permissions", "WebFormFieldType", ] assert result["version"] == VERSION diff --git a/tests/api/test_decorator.py b/tests/api/test_decorator.py new file mode 100644 index 00000000..a5032566 --- /dev/null +++ b/tests/api/test_decorator.py @@ -0,0 +1,91 @@ +import json +import typing + +from amt.api.decorators import add_permissions +from amt.core.authorization import AuthorizationResource, AuthorizationVerb +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient +from starlette.responses import Response + +RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] + +app = FastAPI() + + +@app.get("/unauthorized") +@add_permissions(permissions={"algoritme/1": [AuthorizationVerb.CREATE]}) +async def unauthorized(request: Request): + return {"message": "Hello World"} + + +@app.get("/authorized") +@add_permissions(permissions={}) +async def authorized(request: Request): + return {"message": "Hello World"} + + +@app.get("/norequest") +@add_permissions(permissions={}) +async def norequest(): + return {"message": "Hello World"} + + +@app.get("/authorizedparameters/{organization_id}") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO: [AuthorizationVerb.CREATE]}) +async def authorizedparameters(request: Request, organization_id: int): + return {"message": "Hello World"} + + +@app.middleware("http") +async def add_authorizations(request: Request, call_next: RequestResponseEndpoint): + if "X-Permissions" in request.headers: + request.state.permissions = json.loads(request.headers["X-Permissions"]) + return await call_next(request) + + +def test_permission_decorator_norequest(): + client = TestClient(app, base_url="https://testserver") + response = client.get("/norequest") + assert response.status_code == 400 + + +def test_permission_decorator_unauthorized(): + client = TestClient(app, base_url="https://testserver") + response = client.get("/unauthorized") + assert response.status_code == 401 + + +def test_permission_decorator_authorized(): + client = TestClient(app, base_url="https://testserver") + response = client.get( + "/authorized", + ) + assert response.status_code == 200 + + +def test_permission_decorator_authorized_permission(): + client = TestClient(app, base_url="https://testserver") + + response = client.get("/unauthorized", headers={"X-Permissions": '{"algoritme/1": ["Create"]}'}) + assert response.status_code == 200 + + +def test_permission_decorator_authorized_permission_missing(): + client = TestClient(app, base_url="https://testserver") + + response = client.get("/unauthorized", headers={"X-Permissions": '{"algoritme/1": ["Read"]}'}) + assert response.status_code == 401 + + +def test_permission_decorator_authorized_permission_variable(): + client = TestClient(app, base_url="https://testserver") + + response = client.get("/authorizedparameters/1", headers={"X-Permissions": '{"organization/1": ["Create"]}'}) + assert response.status_code == 200 + + +def test_permission_decorator_unauthorized_permission_variable(): + client = TestClient(app, base_url="https://testserver") + + response = client.get("/authorizedparameters/4453546", headers={"X-Permissions": '{"organization/1": ["Create"]}'}) + assert response.status_code == 401 diff --git a/tests/api/test_deps.py b/tests/api/test_deps.py index e46688a9..9ad5a69e 100644 --- a/tests/api/test_deps.py +++ b/tests/api/test_deps.py @@ -1,7 +1,32 @@ -from amt.api.deps import custom_context_processor +from amt.api.deps import custom_context_processor, permission +from amt.core.authorization import AuthorizationVerb from tests.constants import default_fastapi_request +example_permissions = { + "organization/1": [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE], + "organization/1/algorithm": [ + AuthorizationVerb.LIST, + AuthorizationVerb.CREATE, + AuthorizationVerb.UPDATE, + AuthorizationVerb.DELETE, + ], + "organization/1/member": [ + AuthorizationVerb.LIST, + AuthorizationVerb.CREATE, + AuthorizationVerb.UPDATE, + AuthorizationVerb.DELETE, + ], + "algoritme/1": [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.DELETE], + "algoritme/1/systemcard": [AuthorizationVerb.READ, AuthorizationVerb.CREATE, AuthorizationVerb.UPDATE], + "algoritme/1/user": [ + AuthorizationVerb.CREATE, + AuthorizationVerb.READ, + AuthorizationVerb.UPDATE, + AuthorizationVerb.DELETE, + ], +} + def test_custom_context_processor(): result = custom_context_processor(default_fastapi_request()) @@ -10,3 +35,21 @@ def test_custom_context_processor(): assert result["available_translations"] == ["en", "nl"] assert result["language"] == "en" assert result["translations"] is None + + +def test_permissions_false(): + result = permission("organization/1", AuthorizationVerb.LIST, example_permissions) + + assert result is False + + +def test_permissions_true(): + result = permission("organization/1", AuthorizationVerb.READ, example_permissions) + + assert result is True + + +def test_permissions_none_existing_resource(): + result = permission("badfadfb/1", AuthorizationVerb.READ, example_permissions) + + assert result is False diff --git a/tests/constants.py b/tests/constants.py index 894a486d..8367013e 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -6,7 +6,8 @@ from amt.api.lifecycles import Lifecycles from amt.api.navigation import BaseNavigationItem, DisplayText -from amt.models import Algorithm, Organization, Task, User +from amt.core.authorization import AuthorizationResource, AuthorizationVerb +from amt.models import Algorithm, Authorization, Organization, Role, Rule, Task, User from amt.schema.instrument import Instrument, InstrumentTask, Owner from amt.schema.system_card import SystemCard from fastapi import Request @@ -56,6 +57,27 @@ def default_organization(name: str = "default organization", slug: str = "defaul return Organization(name=name, slug=slug, created_by_id=UUID(default_auth_user()["sub"])) +def default_rule() -> Rule: + return Rule( + resource=AuthorizationResource.ORGANIZATION_INFO, + verbs=[AuthorizationVerb.CREATE, AuthorizationVerb.READ], + role_id=1, + ) + + +def default_role() -> Role: + return Role(name="default role") + + +def default_authorization() -> Authorization: + return Authorization( + user_id=UUID(default_auth_user()["sub"]), + role_id=1, + type="Organization", + type_id=1, + ) + + def default_user( id: str | UUID | None = None, name: str | None = None, diff --git a/tests/core/test_exceptions.py b/tests/core/test_exceptions.py index 3db035b9..8a101960 100644 --- a/tests/core/test_exceptions.py +++ b/tests/core/test_exceptions.py @@ -1,7 +1,15 @@ from gettext import NullTranslations import pytest -from amt.core.exceptions import AMTCSRFProtectError, AMTError, AMTInstrumentError, AMTNotFound, AMTSettingsError +from amt.core.exceptions import ( + AMTAuthorizationFlowError, + AMTCSRFProtectError, + AMTError, + AMTInstrumentError, + AMTNotFound, + AMTPermissionDenied, + AMTSettingsError, +) def test_settings_error(): @@ -44,3 +52,17 @@ def test_AMTCSRFProtectError(): raise AMTCSRFProtectError() assert exc_info.value.detail == "CSRF check failed." + + +def test_AMTPermissionDenied(): + with pytest.raises(AMTPermissionDenied) as exc_info: + raise AMTPermissionDenied() + + assert exc_info.value.detail == "You do not have the correct permissions to access this resource." + + +def test_AMTAuthorizationFlowError(): + with pytest.raises(AMTAuthorizationFlowError) as exc_info: + raise AMTAuthorizationFlowError() + + assert exc_info.value.detail == "Something went wrong during the authorization flow. Please try again later." diff --git a/tests/repositories/test_authorizations.py b/tests/repositories/test_authorizations.py new file mode 100644 index 00000000..7488fef7 --- /dev/null +++ b/tests/repositories/test_authorizations.py @@ -0,0 +1,39 @@ +from uuid import UUID + +import pytest +from amt.core.authorization import AuthorizationResource, AuthorizationType, AuthorizationVerb +from amt.repositories.authorizations import AuthorizationRepository +from tests.constants import ( + default_algorithm, + default_auth_user, + default_authorization, + default_role, + default_rule, + default_user, +) +from tests.database_test_utils import DatabaseTestUtils + + +@pytest.mark.asyncio +async def test_authorization_basic(db: DatabaseTestUtils): + await db.given( + [ + default_user(), + default_algorithm(), + default_role(), + default_rule(), + default_authorization(), + ] + ) + + authorization_repository = AuthorizationRepository(session=db.session) + results = await authorization_repository.find_by_user(UUID(default_auth_user()["sub"])) + + assert results == [ + ( + AuthorizationResource.ORGANIZATION_INFO, + [AuthorizationVerb.CREATE, AuthorizationVerb.READ], + AuthorizationType.ORGANIZATION, + 1, + ) + ] diff --git a/tests/repositories/test_deps.py b/tests/repositories/test_deps.py index f4ff4c85..3524ed96 100644 --- a/tests/repositories/test_deps.py +++ b/tests/repositories/test_deps.py @@ -1,5 +1,5 @@ import pytest -from amt.repositories.deps import get_session +from amt.repositories.deps import get_session, get_session_non_generator from sqlalchemy.ext.asyncio import AsyncSession @@ -14,3 +14,9 @@ async def test_get_session(): await session_generator.aclose() except StopAsyncIteration: pass + + +@pytest.mark.asyncio +async def test_get_session_non_generator(): + session = await get_session_non_generator() + assert isinstance(session, AsyncSession) diff --git a/tests/services/test_authorization_service.py b/tests/services/test_authorization_service.py new file mode 100644 index 00000000..8fa6c49c --- /dev/null +++ b/tests/services/test_authorization_service.py @@ -0,0 +1,38 @@ +import pytest +from amt.core.authorization import AuthorizationResource, AuthorizationType, AuthorizationVerb +from amt.repositories.authorizations import AuthorizationRepository +from amt.services.authorization import AuthorizationService +from pytest_mock import MockFixture +from tests.constants import default_auth_user + + +@pytest.mark.asyncio +async def test_authorization_get_auth_read(mocker: MockFixture): + # Given + + authorization_service = AuthorizationService() + authorization_service.repository = mocker.AsyncMock(spec=AuthorizationRepository) + authorization_service.repository.find_by_user.return_value = [ + (AuthorizationResource.ORGANIZATION_INFO, [AuthorizationVerb.READ], AuthorizationType.ORGANIZATION, 1) + ] + + # When + authorizations = await authorization_service.find_by_user(default_auth_user()) + + # Then + assert authorizations == {"organization/1": [AuthorizationVerb.READ]} + + +@pytest.mark.asyncio +async def test_authorization_get_auth_none(mocker: MockFixture): + # Given + + authorization_service = AuthorizationService() + authorization_service.repository = mocker.AsyncMock(spec=AuthorizationRepository) + authorization_service.repository.find_by_user.return_value = [] + + # When + authorizations = await authorization_service.find_by_user(default_auth_user()) + + # Then + assert authorizations == {} diff --git a/tests/site/static/templates/permission_example.html.j2 b/tests/site/static/templates/permission_example.html.j2 new file mode 100644 index 00000000..ce99f435 --- /dev/null +++ b/tests/site/static/templates/permission_example.html.j2 @@ -0,0 +1,5 @@ +{% if permission('organization/1', 'Create', permissions) == True %} + User Authorized +{% else %} + User UnAuthorized +{% endif %} diff --git a/tests/site/static/templates/test_template_permission.py b/tests/site/static/templates/test_template_permission.py new file mode 100644 index 00000000..c914ca9f --- /dev/null +++ b/tests/site/static/templates/test_template_permission.py @@ -0,0 +1,38 @@ +from amt.api.deps import LocaleJinja2Templates, custom_context_processor, permission +from amt.core.authorization import AuthorizationVerb +from tests.constants import default_fastapi_request + + +def test_template_permission_unauthorized(): + # given + request = default_fastapi_request() + templates = LocaleJinja2Templates( + directory="tests/site/static/templates", context_processors=[custom_context_processor] + ) + templates.env.globals.update(permission=permission) # pyright: ignore [reportUnknownMemberType] + templates.env.tests["permission"] = permission # pyright: ignore [reportUnknownMemberType] + + # when + response = templates.TemplateResponse(request, "permission_example.html.j2") + + # then + assert b"User UnAuthorized" in response.body + + +def test_template_permission_authorized(): + # given + request = default_fastapi_request() + request.state.permissions = { + "organization/1": [AuthorizationVerb.CREATE, AuthorizationVerb.READ, AuthorizationVerb.UPDATE] + } + templates = LocaleJinja2Templates( + directory="tests/site/static/templates", context_processors=[custom_context_processor] + ) + templates.env.globals.update(permission=permission) # pyright: ignore [reportUnknownMemberType] + templates.env.tests["permission"] = permission # pyright: ignore [reportUnknownMemberType] + + # when + response = templates.TemplateResponse(request, "permission_example.html.j2") + + # then + assert b"User Authorized" in response.body