Skip to content

Commit

Permalink
Add basic authorization roles and enforce them.
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed Jan 16, 2025
1 parent 747edf4 commit 7072012
Show file tree
Hide file tree
Showing 37 changed files with 468 additions and 172 deletions.
6 changes: 2 additions & 4 deletions amt/api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import HTTPException, Request

from amt.api.editable import SafeDict
from amt.core.exceptions import AMTPermissionDenied


Expand All @@ -12,14 +13,11 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
request = kwargs.get("request")
organization_id = kwargs.get("organization_id")
algoritme_id = kwargs.get("algoritme_id")
if not isinstance(request, Request): # todo: change exception to custom exception
raise HTTPException(status_code=400, detail="Request object is missing")

for permission, verbs in permissions.items():
permission = permission.format(organization_id=organization_id)
permission = permission.format(algoritme_id=algoritme_id)
permission = permission.format_map(SafeDict(kwargs))
request_permissions: dict[str, list[str]] = (
request.state.permissions if hasattr(request.state, "permissions") else {}
)
Expand Down
32 changes: 25 additions & 7 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi.responses import FileResponse, HTMLResponse
from ulid import ULID

from amt.api.decorators import add_permissions
from amt.api.deps import templates
from amt.api.editable import (
EditModes,
Expand All @@ -27,7 +28,7 @@
resolve_navigation_items,
)
from amt.api.routes.shared import UpdateFieldModel, get_filters_and_sort_by, replace_none_with_empty_string_inplace
from amt.core.authorization import get_user
from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user
from amt.core.exceptions import AMTError, AMTNotFound, AMTRepositoryError
from amt.core.internationalization import get_current_translation
from amt.enums.status import Status
Expand Down Expand Up @@ -215,6 +216,7 @@ async def get_algorithm_context(


@router.get("/{algorithm_id}/details")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_algorithm_details(
request: Request, algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)]
) -> HTMLResponse:
Expand All @@ -236,6 +238,7 @@ async def get_algorithm_details(


@router.get("/{algorithm_id}/edit")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_edit(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -271,6 +274,7 @@ async def get_algorithm_edit(


@router.get("/{algorithm_id}/cancel")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_cancel(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -305,6 +309,7 @@ async def get_algorithm_cancel(


@router.put("/{algorithm_id}/update")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_update(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -356,6 +361,7 @@ async def get_algorithm_update(


@router.get("/{algorithm_id}/details/system_card")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_system_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -398,6 +404,7 @@ async def get_system_card(


@router.get("/{algorithm_id}/details/system_card/compliance")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_system_card_requirements(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -453,7 +460,9 @@ async def get_system_card_requirements(
extended_linked_measures.append(ext_measure_task)
requirements_and_measures.append((requirement, completed_measures_count, extended_linked_measures)) # pyright: ignore [reportUnknownMemberType]

measure_task_functions = await get_measure_task_functions(measure_tasks, users_repository, sort_by, filters)
measure_task_functions: dict[str, list[User]] = await get_measure_task_functions(
measure_tasks, users_repository, sort_by, filters
)

context = {
"instrument_state": instrument_state,
Expand All @@ -473,7 +482,7 @@ async def _fetch_members(
users_repository: UsersRepository,
search_name: str,
sort_by: dict[str, str],
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
) -> User | None:
members = await users_repository.find_all(search=search_name, sort=sort_by, filters=filters)
return members[0] if members else None
Expand All @@ -483,9 +492,9 @@ async def get_measure_task_functions(
measure_tasks: list[MeasureTask],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
) -> dict[str, list[Any]]:
measure_task_functions: dict[str, list[Any]] = defaultdict(list)
filters: dict[str, str | list[str | int]],
) -> dict[str, list[User]]:
measure_task_functions: dict[str, list[User]] = defaultdict(list)

for measure_task in measure_tasks:
person_types = ["accountable_persons", "reviewer_persons", "responsible_persons"]
Expand Down Expand Up @@ -533,6 +542,7 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure


@router.delete("/{algorithm_id}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.DELETE]})
async def delete_algorithm(
request: Request,
algorithm_id: int,
Expand All @@ -543,6 +553,7 @@ async def delete_algorithm(


@router.get("/{algorithm_id}/measure/{measure_urn}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_measure(
request: Request,
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
Expand Down Expand Up @@ -598,7 +609,7 @@ async def get_users_from_function_name(
measure_responsible: Annotated[str | None, Form()],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
) -> tuple[list[Person], list[Person], list[Person]]:
accountable_persons, reviewer_persons, responsible_persons = [], [], []
if measure_accountable:
Expand All @@ -614,6 +625,7 @@ async def get_users_from_function_name(


@router.post("/{algorithm_id}/measure/{measure_urn}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def update_measure_value(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -679,6 +691,7 @@ async def update_measure_value(


@router.get("/{algorithm_id}/members")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_algorithm_members(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -712,6 +725,7 @@ async def get_algorithm_members(


@router.get("/{algorithm_id}/details/system_card/assessments/{assessment_card}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_assessment_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -765,6 +779,7 @@ async def get_assessment_card(


@router.get("/{algorithm_id}/details/system_card/models/{model_card}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_model_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -819,6 +834,7 @@ async def get_model_card(


@router.get("/{algorithm_id}/details/system_card/download")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def download_algorithm_system_card_as_yaml(
algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], request: Request
) -> FileResponse:
Expand All @@ -833,6 +849,7 @@ async def download_algorithm_system_card_as_yaml(


@router.get("/{algorithm_id}/file/{ulid}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_file(
request: Request,
algorithm_id: int,
Expand All @@ -854,6 +871,7 @@ async def get_file(


@router.delete("/{algorithm_id}/file/{ulid}")
@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def delete_file(
request: Request,
algorithm_id: int,
Expand Down
23 changes: 18 additions & 5 deletions amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from typing import Annotated, Any
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import HTMLResponse

from amt.api.ai_act_profile import get_ai_act_profile_selector
from amt.api.decorators import add_permissions
from amt.api.deps import templates
from amt.api.forms.algorithm import get_algorithm_form
from amt.api.group_by_category import get_localized_group_by_categories
Expand All @@ -14,7 +15,7 @@
get_localized_risk_groups,
)
from amt.api.routes.shared import get_filters_and_sort_by
from amt.core.authorization import get_user
from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user
from amt.core.exceptions import AMTAuthorizationError
from amt.core.internationalization import get_current_translation
from amt.models import Algorithm
Expand All @@ -29,16 +30,25 @@


@router.get("/")
@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.LIST]})
async def get_root(
request: Request,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)],
skip: int = Query(0, ge=0),
limit: int = Query(5000, ge=1), # todo: fix infinite scroll
search: str = Query(""),
display_type: str = Query(""),
) -> HTMLResponse:
filters, drop_filters, localized_filters, sort_by = get_filters_and_sort_by(request)

session_user = get_user(request)
user_id: str | None = session_user["sub"] if session_user else None # pyright: ignore[reportUnknownVariableType]

filters["organization-id"] = [
organization.id for organization in await organizations_service.get_organizations_for_user(user_id)
]

algorithms, amount_algorithm_systems = await get_algorithms(
algorithms_service, display_type, filters, limit, request, search, skip, sort_by
)
Expand Down Expand Up @@ -76,25 +86,26 @@ async def get_root(
async def get_algorithms(
algorithms_service: AlgorithmsService,
display_type: str,
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
limit: int,
request: Request,
search: str,
skip: int,
sort_by: dict[str, str],
) -> tuple[dict[str, list[Algorithm]], int | Any]:
amount_algorithm_systems: int = 0

if display_type == "LIFECYCLE":
algorithms: dict[str, list[Algorithm]] = {}

# When the lifecycle filter is active, only show these algorithms
if "lifecycle" in filters:
for lifecycle in Lifecycles:
algorithms[lifecycle.name] = []
algorithms[filters["lifecycle"]] = await algorithms_service.paginate(
algorithms[cast(str, filters["lifecycle"])] = await algorithms_service.paginate(
skip=skip, limit=limit, search=search, filters=filters, sort=sort_by
)
amount_algorithm_systems += len(algorithms[filters["lifecycle"]])
amount_algorithm_systems += len(algorithms[cast(str, filters["lifecycle"])])
else:
for lifecycle in Lifecycles:
filters["lifecycle"] = lifecycle.name
Expand All @@ -114,6 +125,7 @@ async def get_algorithms(


@router.get("/new")
@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]})
async def get_new(
request: Request,
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
Expand Down Expand Up @@ -156,6 +168,7 @@ async def get_new(


@router.post("/new", response_class=HTMLResponse)
@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]})
async def post_new(
request: Request,
algorithm_new: AlgorithmNew,
Expand Down
Loading

0 comments on commit 7072012

Please sign in to comment.