diff --git a/amt/api/decorators.py b/amt/api/decorators.py index c4fbe22e..99870f53 100644 --- a/amt/api/decorators.py +++ b/amt/api/decorators.py @@ -4,6 +4,7 @@ from fastapi import HTTPException, Request +from amt.api.editable import SafeDict from amt.core.exceptions import AMTPermissionDenied @@ -12,14 +13,11 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 request = kwargs.get("request") - organization_id = kwargs.get("organization_id") - algoritme_id = kwargs.get("algoritme_id") if not isinstance(request, Request): # todo: change exception to custom exception raise HTTPException(status_code=400, detail="Request object is missing") for permission, verbs in permissions.items(): - permission = permission.format(organization_id=organization_id) - permission = permission.format(algoritme_id=algoritme_id) + permission = permission.format_map(SafeDict(kwargs)) request_permissions: dict[str, list[str]] = ( request.state.permissions if hasattr(request.state, "permissions") else {} ) diff --git a/amt/api/routes/algorithm.py b/amt/api/routes/algorithm.py index a73bc70b..fdfed271 100644 --- a/amt/api/routes/algorithm.py +++ b/amt/api/routes/algorithm.py @@ -10,6 +10,7 @@ from fastapi.responses import FileResponse, HTMLResponse from ulid import ULID +from amt.api.decorators import add_permissions from amt.api.deps import templates from amt.api.editable import ( EditModes, @@ -27,7 +28,7 @@ resolve_navigation_items, ) from amt.api.routes.shared import UpdateFieldModel, get_filters_and_sort_by, replace_none_with_empty_string_inplace -from amt.core.authorization import get_user +from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user from amt.core.exceptions import AMTError, AMTNotFound, AMTRepositoryError from amt.core.internationalization import get_current_translation from amt.enums.status import Status @@ -215,6 +216,7 @@ async def get_algorithm_context( @router.get("/{algorithm_id}/details") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_algorithm_details( request: Request, algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)] ) -> HTMLResponse: @@ -236,6 +238,7 @@ async def get_algorithm_details( @router.get("/{algorithm_id}/edit") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]}) async def get_algorithm_edit( request: Request, algorithm_id: int, @@ -271,6 +274,7 @@ async def get_algorithm_edit( @router.get("/{algorithm_id}/cancel") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]}) async def get_algorithm_cancel( request: Request, algorithm_id: int, @@ -305,6 +309,7 @@ async def get_algorithm_cancel( @router.put("/{algorithm_id}/update") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]}) async def get_algorithm_update( request: Request, algorithm_id: int, @@ -356,6 +361,7 @@ async def get_algorithm_update( @router.get("/{algorithm_id}/details/system_card") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_system_card( request: Request, algorithm_id: int, @@ -398,6 +404,7 @@ async def get_system_card( @router.get("/{algorithm_id}/details/system_card/compliance") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_system_card_requirements( request: Request, algorithm_id: int, @@ -453,7 +460,9 @@ async def get_system_card_requirements( extended_linked_measures.append(ext_measure_task) requirements_and_measures.append((requirement, completed_measures_count, extended_linked_measures)) # pyright: ignore [reportUnknownMemberType] - measure_task_functions = await get_measure_task_functions(measure_tasks, users_repository, sort_by, filters) + measure_task_functions: dict[str, list[User]] = await get_measure_task_functions( + measure_tasks, users_repository, sort_by, filters + ) context = { "instrument_state": instrument_state, @@ -473,7 +482,7 @@ async def _fetch_members( users_repository: UsersRepository, search_name: str, sort_by: dict[str, str], - filters: dict[str, str], + filters: dict[str, str | list[str | int]], ) -> User | None: members = await users_repository.find_all(search=search_name, sort=sort_by, filters=filters) return members[0] if members else None @@ -483,9 +492,9 @@ async def get_measure_task_functions( measure_tasks: list[MeasureTask], users_repository: Annotated[UsersRepository, Depends(UsersRepository)], sort_by: dict[str, str], - filters: dict[str, str], -) -> dict[str, list[Any]]: - measure_task_functions: dict[str, list[Any]] = defaultdict(list) + filters: dict[str, str | list[str | int]], +) -> dict[str, list[User]]: + measure_task_functions: dict[str, list[User]] = defaultdict(list) for measure_task in measure_tasks: person_types = ["accountable_persons", "reviewer_persons", "responsible_persons"] @@ -533,6 +542,7 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure @router.delete("/{algorithm_id}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.DELETE]}) async def delete_algorithm( request: Request, algorithm_id: int, @@ -543,6 +553,7 @@ async def delete_algorithm( @router.get("/{algorithm_id}/measure/{measure_urn}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_measure( request: Request, organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)], @@ -598,7 +609,7 @@ async def get_users_from_function_name( measure_responsible: Annotated[str | None, Form()], users_repository: Annotated[UsersRepository, Depends(UsersRepository)], sort_by: dict[str, str], - filters: dict[str, str], + filters: dict[str, str | list[str | int]], ) -> tuple[list[Person], list[Person], list[Person]]: accountable_persons, reviewer_persons, responsible_persons = [], [], [] if measure_accountable: @@ -614,6 +625,7 @@ async def get_users_from_function_name( @router.post("/{algorithm_id}/measure/{measure_urn}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def update_measure_value( request: Request, algorithm_id: int, @@ -679,6 +691,7 @@ async def update_measure_value( @router.get("/{algorithm_id}/members") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_algorithm_members( request: Request, algorithm_id: int, @@ -712,6 +725,7 @@ async def get_algorithm_members( @router.get("/{algorithm_id}/details/system_card/assessments/{assessment_card}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_assessment_card( request: Request, algorithm_id: int, @@ -765,6 +779,7 @@ async def get_assessment_card( @router.get("/{algorithm_id}/details/system_card/models/{model_card}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_model_card( request: Request, algorithm_id: int, @@ -819,6 +834,7 @@ async def get_model_card( @router.get("/{algorithm_id}/details/system_card/download") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def download_algorithm_system_card_as_yaml( algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], request: Request ) -> FileResponse: @@ -833,6 +849,7 @@ async def download_algorithm_system_card_as_yaml( @router.get("/{algorithm_id}/file/{ulid}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def get_file( request: Request, algorithm_id: int, @@ -854,6 +871,7 @@ async def get_file( @router.delete("/{algorithm_id}/file/{ulid}") +@add_permissions(permissions={AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]}) async def delete_file( request: Request, algorithm_id: int, diff --git a/amt/api/routes/algorithms.py b/amt/api/routes/algorithms.py index 1a80bf99..7dfe9333 100644 --- a/amt/api/routes/algorithms.py +++ b/amt/api/routes/algorithms.py @@ -1,10 +1,11 @@ import logging -from typing import Annotated, Any +from typing import Annotated, Any, cast from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import HTMLResponse from amt.api.ai_act_profile import get_ai_act_profile_selector +from amt.api.decorators import add_permissions from amt.api.deps import templates from amt.api.forms.algorithm import get_algorithm_form from amt.api.group_by_category import get_localized_group_by_categories @@ -14,7 +15,7 @@ get_localized_risk_groups, ) from amt.api.routes.shared import get_filters_and_sort_by -from amt.core.authorization import get_user +from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user from amt.core.exceptions import AMTAuthorizationError from amt.core.internationalization import get_current_translation from amt.models import Algorithm @@ -29,9 +30,11 @@ @router.get("/") +@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.LIST]}) async def get_root( request: Request, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], + organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], skip: int = Query(0, ge=0), limit: int = Query(5000, ge=1), # todo: fix infinite scroll search: str = Query(""), @@ -39,6 +42,13 @@ async def get_root( ) -> HTMLResponse: filters, drop_filters, localized_filters, sort_by = get_filters_and_sort_by(request) + session_user = get_user(request) + user_id: str | None = session_user["sub"] if session_user else None # pyright: ignore[reportUnknownVariableType] + + filters["organization-id"] = [ + organization.id for organization in await organizations_service.get_organizations_for_user(user_id) + ] + algorithms, amount_algorithm_systems = await get_algorithms( algorithms_service, display_type, filters, limit, request, search, skip, sort_by ) @@ -76,7 +86,7 @@ async def get_root( async def get_algorithms( algorithms_service: AlgorithmsService, display_type: str, - filters: dict[str, str], + filters: dict[str, str | list[str | int]], limit: int, request: Request, search: str, @@ -84,6 +94,7 @@ async def get_algorithms( sort_by: dict[str, str], ) -> tuple[dict[str, list[Algorithm]], int | Any]: amount_algorithm_systems: int = 0 + if display_type == "LIFECYCLE": algorithms: dict[str, list[Algorithm]] = {} @@ -91,10 +102,10 @@ async def get_algorithms( if "lifecycle" in filters: for lifecycle in Lifecycles: algorithms[lifecycle.name] = [] - algorithms[filters["lifecycle"]] = await algorithms_service.paginate( + algorithms[cast(str, filters["lifecycle"])] = await algorithms_service.paginate( skip=skip, limit=limit, search=search, filters=filters, sort=sort_by ) - amount_algorithm_systems += len(algorithms[filters["lifecycle"]]) + amount_algorithm_systems += len(algorithms[cast(str, filters["lifecycle"])]) else: for lifecycle in Lifecycles: filters["lifecycle"] = lifecycle.name @@ -114,6 +125,7 @@ async def get_algorithms( @router.get("/new") +@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]}) async def get_new( request: Request, instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)], @@ -156,6 +168,7 @@ async def get_new( @router.post("/new", response_class=HTMLResponse) +@add_permissions(permissions={AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]}) async def post_new( request: Request, algorithm_new: AlgorithmNew, diff --git a/amt/api/routes/organizations.py b/amt/api/routes/organizations.py index 7ce126ee..70c79c80 100644 --- a/amt/api/routes/organizations.py +++ b/amt/api/routes/organizations.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import HTMLResponse, JSONResponse, Response +from amt.api.decorators import add_permissions from amt.api.deps import templates from amt.api.editable import ( Editables, @@ -24,12 +25,12 @@ resolve_base_navigation_items, resolve_navigation_items, ) -from amt.api.organization_filter_options import get_localized_organization_filters +from amt.api.organization_filter_options import OrganizationFilterOptions, get_localized_organization_filters from amt.api.risk_group import get_localized_risk_groups from amt.api.routes.algorithm import get_user_id_or_error from amt.api.routes.algorithms import get_algorithms from amt.api.routes.shared import UpdateFieldModel, get_filters_and_sort_by -from amt.core.authorization import get_user +from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user from amt.core.exceptions import AMTAuthorizationError, AMTNotFound, AMTRepositoryError from amt.core.internationalization import get_current_translation from amt.models import Organization, User @@ -90,6 +91,7 @@ async def get_new( @router.get("/") +@add_permissions(permissions={AuthorizationResource.ORGANIZATIONS: [AuthorizationVerb.LIST]}) async def root( request: Request, organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)], @@ -104,6 +106,8 @@ async def root( breadcrumbs = resolve_base_navigation_items( [Navigation.ORGANIZATIONS_ROOT, Navigation.ORGANIZATIONS_OVERVIEW], request ) + # TODO: we only show organizations you are a member of (request for the pilots) + filters = {"organization-type": OrganizationFilterOptions.MY_ORGANIZATIONS.value} organizations: Sequence[Organization] = await organizations_repository.find_by( search=search, sort=sort_by, filters=filters, user_id=user["sub"] if user else None ) @@ -132,6 +136,7 @@ async def root( @router.post("/new", response_class=HTMLResponse) +@add_permissions(permissions={AuthorizationResource.ORGANIZATIONS: [AuthorizationVerb.CREATE]}) async def post_new( request: Request, organization_new: OrganizationNew, @@ -150,13 +155,14 @@ async def post_new( return response -@router.get("/{slug}") +@router.get("/{organization_slug}") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.READ]}) async def get_by_slug( request: Request, - slug: str, + organization_slug: str, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) breadcrumbs = resolve_base_navigation_items( [ Navigation.ORGANIZATIONS_ROOT, @@ -165,9 +171,9 @@ async def get_by_slug( request, ) - tab_items = get_organization_tabs(request, organization_slug=slug) + tab_items = get_organization_tabs(request, organization_slug=organization_slug) context = { - "base_href": f"/organizations/{ slug }", + "base_href": f"/organizations/{ organization_slug }", "organization": organization, "organization_id": organization.id, "tab_items": tab_items, @@ -177,24 +183,25 @@ async def get_by_slug( async def get_organization_or_error( - organizations_service: OrganizationsService, request: Request, slug: str + organizations_service: OrganizationsService, request: Request, organization_slug: str ) -> Organization: try: - organization = await organizations_service.find_by_slug(slug) + organization = await organizations_service.find_by_slug(organization_slug) request.state.path_variables = {"organization_slug": organization.slug} except AMTRepositoryError as e: raise AMTNotFound from e return organization -@router.get("/{slug}/edit") +@router.get("/{organization_slug}/edit") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def get_organization_edit( request: Request, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], - slug: str, + organization_slug: str, full_resource_path: str, ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) editable: ResolvedEditable = await get_enriched_resolved_editable( context_variables={"organization_id": organization.id}, @@ -213,7 +220,7 @@ async def get_organization_edit( "relative_resource_path": editable.relative_resource_path.replace("/", ".") if editable.relative_resource_path else "", - "base_href": f"/organizations/{ slug }", + "base_href": f"/organizations/{ organization_slug }", "resource_object": editable.resource_object, "full_resource_path": full_resource_path, "editable_object": editable, @@ -222,14 +229,15 @@ async def get_organization_edit( return templates.TemplateResponse(request, "parts/edit_cell.html.j2", context) -@router.get("/{slug}/cancel") +@router.get("/{organization_slug}/cancel") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def get_organization_cancel( request: Request, - slug: str, + organization_slug: str, full_resource_path: str, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) editable: ResolvedEditable = await get_enriched_resolved_editable( context_variables={"organization_id": organization.id}, @@ -247,7 +255,7 @@ async def get_organization_cancel( "relative_resource_path": editable.relative_resource_path.replace("/", ".") if editable.relative_resource_path else "", - "base_href": f"/organizations/{ slug }", + "base_href": f"/organizations/{ organization_slug }", "resource_object": None, # TODO: this should become an optional parameter in the Jinja template "full_resource_path": full_resource_path, "editable_object": editable, @@ -256,15 +264,16 @@ async def get_organization_cancel( return templates.TemplateResponse(request, "parts/view_cell.html.j2", context) -@router.put("/{slug}/update") +@router.put("/{organization_slug}/update") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def get_organization_update( request: Request, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], update_data: UpdateFieldModel, - slug: str, + organization_slug: str, full_resource_path: str, ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) user_id = get_user_id_or_error(request) @@ -299,7 +308,7 @@ async def get_organization_update( "relative_resource_path": editable.relative_resource_path.replace("/", ".") if editable.relative_resource_path else "", - "base_href": f"/organizations/{ slug }", + "base_href": f"/organizations/{ organization_slug }", "resource_object": None, "full_resource_path": full_resource_path, "editable_object": editable, @@ -312,18 +321,19 @@ async def get_organization_update( return templates.TemplateResponse(request, "parts/view_cell.html.j2", context) -@router.get("/{slug}/algorithms") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.LIST]}) +@router.get("/{organization_slug}/algorithms") async def show_algorithms( request: Request, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], - slug: str, + organization_slug: str, skip: int = Query(0, ge=0), limit: int = Query(5000, ge=1), # todo: fix infinite scroll search: str = Query(""), display_type: str = Query(""), ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) filters, drop_filters, localized_filters, sort_by = get_filters_and_sort_by(request) filters["organization-id"] = str(organization.id) @@ -332,7 +342,7 @@ async def show_algorithms( ) next = skip + limit - tab_items = get_organization_tabs(request, organization_slug=slug) + tab_items = get_organization_tabs(request, organization_slug=organization_slug) breadcrumbs = resolve_base_navigation_items( [ @@ -359,7 +369,7 @@ async def show_algorithms( "filters": localized_filters, "sort_by": sort_by, "display_type": display_type, - "base_href": f"/organizations/{slug}/algorithms", + "base_href": f"/organizations/{organization_slug}/algorithms", "organization_id": organization.id, } @@ -371,58 +381,62 @@ async def show_algorithms( return templates.TemplateResponse(request, "organizations/algorithms.html.j2", context) -@router.delete("/{slug}/members/{user_id}") +@router.delete("/{organization_slug}/members/{user_id}") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def remove_member( request: Request, - slug: str, + organization_slug: str, user_id: UUID, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], users_repository: Annotated[UsersRepository, Depends(UsersRepository)], ) -> HTMLResponse: # TODO (Robbert): add authorization and check if user and organization exist? - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) user: User | None = await users_repository.find_by_id(user_id) if user: await organizations_service.remove_user(organization, user) - return templates.Redirect(request, f"/organizations/{slug}/members") + return templates.Redirect(request, f"/organizations/{organization_slug}/members") raise AMTAuthorizationError -@router.get("/{slug}/members/form") +@router.get("/{organization_slug}/members/form") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def get_members_form( request: Request, - slug: str, + organization_slug: str, ) -> HTMLResponse: form = get_organization_form(id="organization", translations=get_current_translation(request), user=None) - context: dict[str, Any] = {"form": form, "slug": slug} + context: dict[str, Any] = {"form": form, "slug": organization_slug} return templates.TemplateResponse(request, "organizations/parts/add_members_modal.html.j2", context) -@router.put("/{slug}/members", response_class=HTMLResponse) +@router.put("/{organization_slug}/members", response_class=HTMLResponse) +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.UPDATE]}) async def add_new_members( request: Request, - slug: str, + organization_slug: str, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], organization_users: OrganizationUsers, ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) await organizations_service.add_users(organization, organization_users.user_ids) - return templates.Redirect(request, f"/organizations/{slug}/members") + return templates.Redirect(request, f"/organizations/{organization_slug}/members") -@router.get("/{slug}/members") +@router.get("/{organization_slug}/members") +@add_permissions(permissions={AuthorizationResource.ORGANIZATION_INFO_SLUG: [AuthorizationVerb.LIST]}) async def get_members( request: Request, - slug: str, + organization_slug: str, organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)], users_repository: Annotated[UsersRepository, Depends(UsersRepository)], skip: int = Query(0, ge=0), limit: int = Query(5000, ge=1), # todo: fix infinite scroll search: str = Query(""), ) -> HTMLResponse: - organization = await get_organization_or_error(organizations_service, request, slug) + organization = await get_organization_or_error(organizations_service, request, organization_slug) filters, drop_filters, localized_filters, sort_by = get_filters_and_sort_by(request) - tab_items = get_organization_tabs(request, organization_slug=slug) + tab_items = get_organization_tabs(request, organization_slug=organization_slug) breadcrumbs = resolve_base_navigation_items( [ Navigation.ORGANIZATIONS_ROOT, diff --git a/amt/api/routes/shared.py b/amt/api/routes/shared.py index 576225c8..9a87d8c0 100644 --- a/amt/api/routes/shared.py +++ b/amt/api/routes/shared.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any +from typing import Any, cast from pydantic import BaseModel from starlette.requests import Request @@ -14,7 +14,7 @@ def get_filters_and_sort_by( request: Request, -) -> tuple[dict[str, str], list[str], dict[str, LocalizedValueItem], dict[str, str]]: +) -> tuple[dict[str, str | list[str | int]], list[str], dict[str, LocalizedValueItem], dict[str, str]]: active_filters: dict[str, str] = { k.removeprefix("active-filter-"): v for k, v in request.query_params.items() @@ -29,9 +29,11 @@ def get_filters_and_sort_by( if "organization-type" in add_filters and add_filters["organization-type"] == OrganizationFilterOptions.ALL.value: del add_filters["organization-type"] drop_filters: list[str] = [v for k, v in request.query_params.items() if k.startswith("drop-filter") and v != ""] - filters: dict[str, str] = {k: v for k, v in (active_filters | add_filters).items() if k not in drop_filters} + filters: dict[str, str | list[str | int]] = { + k: v for k, v in (active_filters | add_filters).items() if k not in drop_filters + } localized_filters: dict[str, LocalizedValueItem] = { - k: get_localized_value(k, v, request) for k, v in filters.items() + k: get_localized_value(k, cast(str, v), request) for k, v in filters.items() } sort_by: dict[str, str] = { k.removeprefix("sort-by-"): v for k, v in request.query_params.items() if k.startswith("sort-by-") and v != "" diff --git a/amt/core/authorization.py b/amt/core/authorization.py index 9756361f..9a9bc283 100644 --- a/amt/core/authorization.py +++ b/amt/core/authorization.py @@ -21,12 +21,17 @@ class AuthorizationType(StrEnum): class AuthorizationResource(StrEnum): + ORGANIZATIONS = "organizations/" ORGANIZATION_INFO = "organization/{organization_id}" ORGANIZATION_ALGORITHM = "organization/{organization_id}/algorithm" ORGANIZATION_MEMBER = "organization/{organization_id}/member" - ALGORITHM = "algoritme/{algoritme_id}" - ALGORITHM_SYSTEMCARD = "algoritme/{algoritme_id}/systemcard" - ALGORITHM_MEMBER = "algoritme/{algoritme_id}/user" + ORGANIZATION_INFO_SLUG = "organization/{organization_slug}" + ORGANIZATION_ALGORITHM_SLUG = "organization/{organization_slug}/algorithm" + ORGANIZATION_MEMBER_SLUG = "organization/{organization_slug}/member" + ALGORITHMS = "algorithms/" + ALGORITHM = "algorithm/{algorithm_id}" + ALGORITHM_SYSTEMCARD = "algorithm/{algorithm_id}/systemcard" + ALGORITHM_MEMBER = "algorithm/{algorithm_id}/user" def get_user(request: Request) -> dict[str, Any] | None: diff --git a/amt/core/exceptions.py b/amt/core/exceptions.py index 1a39cae9..5bb41e9f 100644 --- a/amt/core/exceptions.py +++ b/amt/core/exceptions.py @@ -38,7 +38,9 @@ def __init__(self) -> None: class AMTNotFound(AMTHTTPException): def __init__(self) -> None: self.detail: str = _( - "The requested page or resource could not be found. Please check the URL or query and try again." + "The requested page or resource could not be found, " + "or you do not have the correct permissions to access it. Please check the " + "URL or query and try again." ) super().__init__(status.HTTP_404_NOT_FOUND, self.detail) @@ -81,8 +83,12 @@ def __init__(self) -> None: class AMTPermissionDenied(AMTHTTPException): def __init__(self) -> None: - self.detail: str = _("You do not have the correct permissions to access this resource.") - super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail) + self.detail: str = _( + "The requested page or resource could not be found, " + "or you do not have the correct permissions to access it. Please check the " + "URL or query and try again." + ) + super().__init__(status.HTTP_404_NOT_FOUND, self.detail) class AMTStorageError(AMTHTTPException): diff --git a/amt/locale/base.pot b/amt/locale/base.pot index 12482b99..bd333727 100644 --- a/amt/locale/base.pot +++ b/amt/locale/base.pot @@ -314,43 +314,40 @@ msgstr "" msgid "An error occurred while processing the instrument. Please try again later." msgstr "" -#: amt/core/exceptions.py:40 +#: amt/core/exceptions.py:40 amt/core/exceptions.py:86 msgid "" -"The requested page or resource could not be found. Please check the URL " -"or query and try again." +"The requested page or resource could not be found, or you do not have the" +" correct permissions to access it. Please check the URL or query and try " +"again." msgstr "" -#: amt/core/exceptions.py:48 +#: amt/core/exceptions.py:50 msgid "CSRF check failed." msgstr "" -#: amt/core/exceptions.py:54 +#: amt/core/exceptions.py:56 msgid "Only static files are supported." msgstr "" -#: amt/core/exceptions.py:60 +#: amt/core/exceptions.py:62 msgid "Key not correct: {field}" msgstr "" -#: amt/core/exceptions.py:66 +#: amt/core/exceptions.py:68 msgid "Value not correct: {field}" msgstr "" -#: amt/core/exceptions.py:72 +#: amt/core/exceptions.py:74 msgid "Failed to authorize, please login and try again." msgstr "" -#: amt/core/exceptions.py:78 +#: amt/core/exceptions.py:80 msgid "" "Something went wrong during the authorization flow. Please try again " "later." msgstr "" -#: amt/core/exceptions.py:84 -msgid "You do not have the correct permissions to access this resource." -msgstr "" - -#: amt/core/exceptions.py:90 +#: amt/core/exceptions.py:96 msgid "Something went wrong storing your file. PLease try again later." msgstr "" @@ -573,9 +570,42 @@ msgstr "" #: amt/site/templates/errors/Exception.html.j2:5 #: amt/site/templates/errors/RequestValidationError_400.html.j2:5 +#: amt/site/templates/errors/_404_Exception.html.j2:1 msgid "An error occurred" msgstr "" +#: amt/site/templates/errors/_404_Exception.html.j2:3 +msgid "We couldn't find what you were looking for. This might be because:" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:6 +msgid "The link isn't correct (anymore)" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:7 +msgid "The page has moved or been removed" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:8 +msgid "You don't have access to this page" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:10 +msgid "What now?" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:12 +msgid "Double-check if you typed the URL correctly" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:14 +msgid "Head back to the overview page" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:16 +msgid "Contact your admin" +msgstr "" + #: amt/site/templates/errors/_AMTCSRFProtectError_401.html.j2:11 #: amt/site/templates/errors/_Exception.html.j2:5 msgid "An error occurred. Please try again later" diff --git a/amt/locale/en_US/LC_MESSAGES/messages.mo b/amt/locale/en_US/LC_MESSAGES/messages.mo index 249ef0d7..738f59d7 100644 Binary files a/amt/locale/en_US/LC_MESSAGES/messages.mo and b/amt/locale/en_US/LC_MESSAGES/messages.mo differ diff --git a/amt/locale/en_US/LC_MESSAGES/messages.po b/amt/locale/en_US/LC_MESSAGES/messages.po index 5cfec7e3..e56c0c42 100644 --- a/amt/locale/en_US/LC_MESSAGES/messages.po +++ b/amt/locale/en_US/LC_MESSAGES/messages.po @@ -315,43 +315,40 @@ msgstr "" msgid "An error occurred while processing the instrument. Please try again later." msgstr "" -#: amt/core/exceptions.py:40 +#: amt/core/exceptions.py:40 amt/core/exceptions.py:86 msgid "" -"The requested page or resource could not be found. Please check the URL " -"or query and try again." +"The requested page or resource could not be found, or you do not have the" +" correct permissions to access it. Please check the URL or query and try " +"again." msgstr "" -#: amt/core/exceptions.py:48 +#: amt/core/exceptions.py:50 msgid "CSRF check failed." msgstr "" -#: amt/core/exceptions.py:54 +#: amt/core/exceptions.py:56 msgid "Only static files are supported." msgstr "" -#: amt/core/exceptions.py:60 +#: amt/core/exceptions.py:62 msgid "Key not correct: {field}" msgstr "" -#: amt/core/exceptions.py:66 +#: amt/core/exceptions.py:68 msgid "Value not correct: {field}" msgstr "" -#: amt/core/exceptions.py:72 +#: amt/core/exceptions.py:74 msgid "Failed to authorize, please login and try again." msgstr "" -#: amt/core/exceptions.py:78 +#: amt/core/exceptions.py:80 msgid "" "Something went wrong during the authorization flow. Please try again " "later." msgstr "" -#: amt/core/exceptions.py:84 -msgid "You do not have the correct permissions to access this resource." -msgstr "" - -#: amt/core/exceptions.py:90 +#: amt/core/exceptions.py:96 msgid "Something went wrong storing your file. PLease try again later." msgstr "" @@ -574,9 +571,42 @@ msgstr "" #: amt/site/templates/errors/Exception.html.j2:5 #: amt/site/templates/errors/RequestValidationError_400.html.j2:5 +#: amt/site/templates/errors/_404_Exception.html.j2:1 msgid "An error occurred" msgstr "" +#: amt/site/templates/errors/_404_Exception.html.j2:3 +msgid "We couldn't find what you were looking for. This might be because:" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:6 +msgid "The link isn't correct (anymore)" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:7 +msgid "The page has moved or been removed" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:8 +msgid "You don't have access to this page" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:10 +msgid "What now?" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:12 +msgid "Double-check if you typed the URL correctly" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:14 +msgid "Head back to the overview page" +msgstr "" + +#: amt/site/templates/errors/_404_Exception.html.j2:16 +msgid "Contact your admin" +msgstr "" + #: amt/site/templates/errors/_AMTCSRFProtectError_401.html.j2:11 #: amt/site/templates/errors/_Exception.html.j2:5 msgid "An error occurred. Please try again later" diff --git a/amt/locale/nl_NL/LC_MESSAGES/messages.mo b/amt/locale/nl_NL/LC_MESSAGES/messages.mo index ec2e0270..38eee241 100644 Binary files a/amt/locale/nl_NL/LC_MESSAGES/messages.mo and b/amt/locale/nl_NL/LC_MESSAGES/messages.mo differ diff --git a/amt/locale/nl_NL/LC_MESSAGES/messages.po b/amt/locale/nl_NL/LC_MESSAGES/messages.po index 0a4f0415..a7a6defa 100644 --- a/amt/locale/nl_NL/LC_MESSAGES/messages.po +++ b/amt/locale/nl_NL/LC_MESSAGES/messages.po @@ -325,35 +325,37 @@ msgstr "" "Er is een fout opgetreden tijdens het verwerken van het instrument. " "Probeer het later opnieuw." -#: amt/core/exceptions.py:40 +#: amt/core/exceptions.py:40 amt/core/exceptions.py:86 msgid "" -"The requested page or resource could not be found. Please check the URL " -"or query and try again." +"The requested page or resource could not be found, or you do not have the" +" correct permissions to access it. Please check the URL or query and try " +"again." msgstr "" -"De gevraagde pagina of bron kon niet worden gevonden. Controleer de URL " -"of query en probeer het opnieuw." +"De gevraagde pagina of bron kon niet worden gevonden of u beschikt niet " +"over de juiste machtigingen om deze bron te openen. Controleer de URL of " +"query en probeer het opnieuw." -#: amt/core/exceptions.py:48 +#: amt/core/exceptions.py:50 msgid "CSRF check failed." msgstr "CSRF-controle mislukt." -#: amt/core/exceptions.py:54 +#: amt/core/exceptions.py:56 msgid "Only static files are supported." msgstr "Alleen statische bestanden worden ondersteund." -#: amt/core/exceptions.py:60 +#: amt/core/exceptions.py:62 msgid "Key not correct: {field}" msgstr "Sleutel niet correct: {field}" -#: amt/core/exceptions.py:66 +#: amt/core/exceptions.py:68 msgid "Value not correct: {field}" msgstr "Waarde is niet correct: {field}" -#: amt/core/exceptions.py:72 +#: amt/core/exceptions.py:74 msgid "Failed to authorize, please login and try again." msgstr "Autoriseren is mislukt. Meld u aan en probeer het opnieuw." -#: amt/core/exceptions.py:78 +#: amt/core/exceptions.py:80 msgid "" "Something went wrong during the authorization flow. Please try again " "later." @@ -361,11 +363,7 @@ msgstr "" "Er is iets fout gegaan tijdens de autorisatiestroom. Probeer het later " "opnieuw" -#: amt/core/exceptions.py:84 -msgid "You do not have the correct permissions to access this resource." -msgstr "U beschikt niet over de juiste machtigingen om deze bron te openen." - -#: amt/core/exceptions.py:90 +#: amt/core/exceptions.py:96 msgid "Something went wrong storing your file. PLease try again later." msgstr "" "Er is iets fout gegaan tijdens het opslaan van uw bestand. Probeer het " @@ -597,9 +595,42 @@ msgstr "Inloggen" #: amt/site/templates/errors/Exception.html.j2:5 #: amt/site/templates/errors/RequestValidationError_400.html.j2:5 +#: amt/site/templates/errors/_404_Exception.html.j2:1 msgid "An error occurred" msgstr "Er is een fout opgetreden" +#: amt/site/templates/errors/_404_Exception.html.j2:3 +msgid "We couldn't find what you were looking for. This might be because:" +msgstr "We konden niet vinden wat u zocht. Dit kan komen doordat:" + +#: amt/site/templates/errors/_404_Exception.html.j2:6 +msgid "The link isn't correct (anymore)" +msgstr "De link niet (meer) correct is" + +#: amt/site/templates/errors/_404_Exception.html.j2:7 +msgid "The page has moved or been removed" +msgstr "De pagina is verplaatst of verwijderd" + +#: amt/site/templates/errors/_404_Exception.html.j2:8 +msgid "You don't have access to this page" +msgstr "U heeft geen toegang tot deze pagina" + +#: amt/site/templates/errors/_404_Exception.html.j2:10 +msgid "What now?" +msgstr "Wat nu?" + +#: amt/site/templates/errors/_404_Exception.html.j2:12 +msgid "Double-check if you typed the URL correctly" +msgstr "Controleer nogmaals of u de URL correct hebt getypt" + +#: amt/site/templates/errors/_404_Exception.html.j2:14 +msgid "Head back to the overview page" +msgstr "Ga terug naar de overzichtspagina" + +#: amt/site/templates/errors/_404_Exception.html.j2:16 +msgid "Contact your admin" +msgstr "Neem contact op met uw beheerder" + #: amt/site/templates/errors/_AMTCSRFProtectError_401.html.j2:11 #: amt/site/templates/errors/_Exception.html.j2:5 msgid "An error occurred. Please try again later" diff --git a/amt/middleware/authorization.py b/amt/middleware/authorization.py index a5ec48fe..f3e64400 100644 --- a/amt/middleware/authorization.py +++ b/amt/middleware/authorization.py @@ -1,11 +1,13 @@ import os import typing +from uuid import UUID from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import RedirectResponse, Response from amt.core.authorization import get_user +from amt.models import User from amt.services.authorization import AuthorizationService RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] @@ -19,14 +21,24 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - if request.url.path.startswith("/static/"): return await call_next(request) + authorization_service = AuthorizationService() + disable_auth_str = os.environ.get("DISABLE_AUTH") auth_disable = False if disable_auth_str is None else disable_auth_str.lower() == "true" if auth_disable: auto_login_uuid: str | None = os.environ.get("AUTO_LOGIN_UUID", None) if auto_login_uuid: - request.session["user"] = {"sub": auto_login_uuid} - - authorization_service = AuthorizationService() + user_object: User | None = await authorization_service.get_user(UUID(auto_login_uuid)) + if user_object: + request.session["user"] = { + "sub": str(user_object.id), + "email": user_object.email, + "name": user_object.name, + "email_hash": user_object.email_hash, + "name_encoded": user_object.name_encoded, + } + else: + request.session["user"] = {"sub": auto_login_uuid} user = get_user(request) diff --git a/amt/repositories/algorithms.py b/amt/repositories/algorithms.py index e4edc874..e92e7d11 100644 --- a/amt/repositories/algorithms.py +++ b/amt/repositories/algorithms.py @@ -1,6 +1,7 @@ import logging from collections.abc import Sequence -from typing import Annotated +from typing import Annotated, cast +from uuid import UUID from fastapi import Depends from sqlalchemy import func, select @@ -10,7 +11,7 @@ from amt.api.risk_group import RiskGroup from amt.core.exceptions import AMTRepositoryError -from amt.models import Algorithm +from amt.models import Algorithm, Organization, User from amt.repositories.deps import get_session logger = logging.getLogger(__name__) @@ -74,7 +75,7 @@ async def find_by_id(self, algorithm_id: int) -> Algorithm: raise AMTRepositoryError from e async def paginate( # noqa - self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str] + self, skip: int, limit: int, search: str, filters: dict[str, str | list[str | int]], sort: dict[str, str] ) -> list[Algorithm]: try: statement = select(Algorithm) @@ -83,15 +84,18 @@ async def paginate( # noqa if filters: for key, value in filters.items(): match key: + case "id": + statement = statement.where(Algorithm.id == int(cast(str, value))) case "lifecycle": statement = statement.filter(Algorithm.lifecycle == value) case "risk-group": statement = statement.filter( Algorithm.system_card_json["ai_act_profile"]["risk_group"].as_string() - == RiskGroup[value].value + == RiskGroup[cast(str, value)].value ) case "organization-id": - statement = statement.filter(Algorithm.organization_id == int(value)) + value = [int(value)] if not isinstance(value, list) else [int(v) for v in value] + statement = statement.filter(Algorithm.organization_id.in_(value)) case _: raise TypeError(f"Unknown filter type with key: {key}") # noqa if sort: @@ -120,3 +124,12 @@ async def paginate( # noqa except Exception as e: logger.exception("Error paginating algorithms") raise AMTRepositoryError from e + + async def get_by_user(self, user_id: UUID) -> Sequence[Algorithm]: + statement = ( + select(Algorithm) + .join(Organization, Organization.id == Algorithm.organization_id) + .where(Organization.users.any(User.id == user_id)) # pyright: ignore[reportUnknownMemberType] + ) + + return (await self.session.execute(statement)).scalars().all() diff --git a/amt/repositories/authorizations.py b/amt/repositories/authorizations.py index e39711ab..a55e1cba 100644 --- a/amt/repositories/authorizations.py +++ b/amt/repositories/authorizations.py @@ -4,13 +4,16 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from amt.core.authorization import AuthorizationVerb -from amt.models import Authorization, Role, Rule +from amt.core.authorization import AuthorizationResource, AuthorizationType, AuthorizationVerb +from amt.models import Authorization, Role, Rule, User +from amt.repositories.algorithms import AlgorithmsRepository from amt.repositories.deps import get_session_non_generator +from amt.repositories.organizations import OrganizationsRepository +from amt.repositories.users import UsersRepository logger = logging.getLogger(__name__) -PermissionTuple = tuple[str, list[AuthorizationVerb], str, int] +PermissionTuple = tuple[AuthorizationResource, list[AuthorizationVerb], AuthorizationType, str | int] PermissionsList = list[PermissionTuple] @@ -26,7 +29,53 @@ async def init_session(self) -> None: if self.session is None: self.session = await get_session_non_generator() + async def get_user(self, user_id: UUID) -> User | None: + try: + await self.init_session() + return await UsersRepository(session=self.session).find_by_id(user_id) # pyright: ignore[reportArgumentType] + finally: + if self.session is not None: + await self.session.close() + async def find_by_user(self, user: UUID) -> PermissionsList | None: + """ + Returns all authorization for a user. + :return: all authorization for the user + """ + try: + await self.init_session() + authorization_verbs: list[AuthorizationVerb] = [ + AuthorizationVerb.READ, + AuthorizationVerb.UPDATE, + AuthorizationVerb.CREATE, + AuthorizationVerb.LIST, + AuthorizationVerb.DELETE, + ] + my_algorithms: PermissionsList = [ + (AuthorizationResource.ALGORITHMS, authorization_verbs, AuthorizationType.ALGORITHM, "*"), + ] + my_algorithms += [ + (AuthorizationResource.ALGORITHM, authorization_verbs, AuthorizationType.ALGORITHM, algorithm.id) + for algorithm in await AlgorithmsRepository(session=self.session).get_by_user(user) # pyright: ignore[reportArgumentType] + ] + my_organizations: PermissionsList = [ + (AuthorizationResource.ORGANIZATIONS, authorization_verbs, AuthorizationType.ORGANIZATION, "*"), + ] + my_organizations += [ + ( + AuthorizationResource.ORGANIZATION_INFO_SLUG, + authorization_verbs, + AuthorizationType.ORGANIZATION, + organization.slug, + ) + for organization in await OrganizationsRepository(session=self.session).get_by_user(user) # pyright: ignore[reportArgumentType] + ] + return my_algorithms + my_organizations + finally: + if self.session is not None: + await self.session.close() + + async def find_by_user_original(self, user: UUID) -> PermissionsList | None: """ Returns all authorization for a user. :return: all authorization for the user diff --git a/amt/repositories/organizations.py b/amt/repositories/organizations.py index afe9b8c1..6941908a 100644 --- a/amt/repositories/organizations.py +++ b/amt/repositories/organizations.py @@ -141,3 +141,7 @@ async def remove_user(self, organization: Organization, user: User) -> Organizat await self.session.commit() await self.session.refresh(organization) return organization + + async def get_by_user(self, user_id: UUID) -> Sequence[Organization]: + statement = select(Organization).where(Organization.users.any(User.id == user_id)) # pyright: ignore[reportUnknownMemberType] + return (await self.session.execute(statement)).scalars().all() diff --git a/amt/repositories/users.py b/amt/repositories/users.py index 09c38a51..8ef78a63 100644 --- a/amt/repositories/users.py +++ b/amt/repositories/users.py @@ -1,6 +1,6 @@ import logging from collections.abc import Sequence -from typing import Annotated +from typing import Annotated, cast from uuid import UUID from fastapi import Depends @@ -29,7 +29,7 @@ async def find_all( self, search: str | None = None, sort: dict[str, str] | None = None, - filters: dict[str, str] | None = None, + filters: dict[str, str | list[str | int]] | None = None, skip: int | None = None, limit: int | None = None, ) -> Sequence[User]: @@ -37,7 +37,9 @@ async def find_all( if search: statement = statement.filter(User.name.ilike(f"%{escape_like(search)}%")) if filters and "organization-id" in filters: - statement = statement.where(User.organizations.any(Organization.id == int(filters["organization-id"]))) + statement = statement.where( + User.organizations.any(Organization.id == int(cast(str, filters["organization-id"]))) + ) if sort: if "name" in sort and sort["name"] == "ascending": statement = statement.order_by(func.lower(User.name).asc()) diff --git a/amt/services/algorithms.py b/amt/services/algorithms.py index 88cd2b71..965004eb 100644 --- a/amt/services/algorithms.py +++ b/amt/services/algorithms.py @@ -117,15 +117,16 @@ async def create(self, algorithm_new: AlgorithmNew, user_id: UUID | str) -> Algo return algorithm async def paginate( - self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str] + self, skip: int, limit: int, search: str, filters: dict[str, str | list[str | int]], sort: dict[str, str] ) -> list[Algorithm]: algorithms = await self.repository.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort) return algorithms async def update(self, algorithm: Algorithm) -> Algorithm: # TODO: Is this the right place to sync system cards: system_card and system_card_json? + algorithm.sync_system_card() # TODO: when system card is missing things break, so we call it here to make sure it exists?? - dummy = algorithm.system_card # noqa: F841 # pyright: ignore[reportUnusedVariable] + dummy = algorithm.system_card # noqa: F841 pyright: ignore[reportUnusedVariable] algorithm = await self.repository.save(algorithm) return algorithm diff --git a/amt/services/authorization.py b/amt/services/authorization.py index 7a835cbb..1aa7a9f3 100644 --- a/amt/services/authorization.py +++ b/amt/services/authorization.py @@ -1,8 +1,9 @@ -import contextlib from typing import Any from uuid import UUID +from amt.api.editable import SafeDict from amt.core.authorization import AuthorizationType, AuthorizationVerb +from amt.models import User from amt.repositories.authorizations import AuthorizationRepository from amt.schema.permission import Permission @@ -14,6 +15,9 @@ class AuthorizationService: def __init__(self) -> None: self.repository = AuthorizationRepository() + async def get_user(self, user_id: UUID) -> User | None: + return await self.repository.get_user(user_id) + async def find_by_user(self, user: dict[str, Any] | None) -> dict[str, list[AuthorizationVerb]]: if not user: return {} @@ -23,18 +27,20 @@ async def find_by_user(self, user: dict[str, Any] | None) -> dict[str, list[Auth uuid = UUID(user["sub"]) authorizations: PermissionsList = await self.repository.find_by_user(uuid) # type: ignore for auth in authorizations: - auth_dict: dict[str, int] = {"organization_id": -1, "algoritme_id": -1} + auth_dict: dict[str, int | str] = {} if auth[2] == AuthorizationType.ORGANIZATION: + # TODO: check the path if we need the slug or the id? auth_dict["organization_id"] = auth[3] + auth_dict["organization_slug"] = auth[3] if auth[2] == AuthorizationType.ALGORITHM: - auth_dict["algoritme_id"] = auth[3] + auth_dict["algorithm_id"] = auth[3] resource: str = auth[0] verbs: list[AuthorizationVerb] = auth[1] - with contextlib.suppress(Exception): - resource = resource.format(**auth_dict) + + resource = resource.format_map(SafeDict(auth_dict)) permission: Permission = Permission(resource=resource, verb=verbs) diff --git a/amt/site/templates/errors/AMTNotFound_404.html.j2 b/amt/site/templates/errors/AMTNotFound_404.html.j2 new file mode 100644 index 00000000..d84d7538 --- /dev/null +++ b/amt/site/templates/errors/AMTNotFound_404.html.j2 @@ -0,0 +1,6 @@ +{% extends 'layouts/base.html.j2' %} +{% block content %} +
+ {% trans %}We couldn't find what you were looking for. This might be because:{% endtrans %} +
+ +diff --git a/tests/api/routes/test_algorithm.py b/tests/api/routes/test_algorithm.py index 0d8d23c9..de436eee 100644 --- a/tests/api/routes/test_algorithm.py +++ b/tests/api/routes/test_algorithm.py @@ -28,6 +28,7 @@ from tests.constants import ( default_algorithm, default_algorithm_with_system_card, + default_not_found_no_permission_msg, default_task, default_user, ) @@ -42,7 +43,7 @@ async def test_get_unknown_algorithm(client: AsyncClient) -> None: # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -117,7 +118,7 @@ async def test_get_algorithm_non_existing_algorithm(client: AsyncClient, db: Dat # then assert response.status_code == 404 - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -169,7 +170,7 @@ async def test_get_system_card_unknown_algorithm(client: AsyncClient) -> None: # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -198,7 +199,7 @@ async def test_get_assessment_card_unknown_algorithm(client: AsyncClient, db: Da # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -212,7 +213,7 @@ async def test_get_assessment_card_unknown_assessment(client: AsyncClient, db: D # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -237,7 +238,7 @@ async def test_get_model_card_unknown_algorithm(client: AsyncClient) -> None: # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio @@ -251,7 +252,7 @@ async def test_get_assessment_card_unknown_model_card(client: AsyncClient, db: D # then assert response.status_code == 404 assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"The requested page or resource could not be found." in response.content + assert default_not_found_no_permission_msg() in response.content @pytest.mark.asyncio diff --git a/tests/api/test_decorator.py b/tests/api/test_decorator.py index a5032566..84f794fd 100644 --- a/tests/api/test_decorator.py +++ b/tests/api/test_decorator.py @@ -52,7 +52,7 @@ def test_permission_decorator_norequest(): def test_permission_decorator_unauthorized(): client = TestClient(app, base_url="https://testserver") response = client.get("/unauthorized") - assert response.status_code == 401 + assert response.status_code == 404 def test_permission_decorator_authorized(): @@ -74,7 +74,7 @@ def test_permission_decorator_authorized_permission_missing(): client = TestClient(app, base_url="https://testserver") response = client.get("/unauthorized", headers={"X-Permissions": '{"algoritme/1": ["Read"]}'}) - assert response.status_code == 401 + assert response.status_code == 404 def test_permission_decorator_authorized_permission_variable(): @@ -88,4 +88,4 @@ def test_permission_decorator_unauthorized_permission_variable(): client = TestClient(app, base_url="https://testserver") response = client.get("/authorizedparameters/4453546", headers={"X-Permissions": '{"organization/1": ["Create"]}'}) - assert response.status_code == 401 + assert response.status_code == 404 diff --git a/tests/conftest.py b/tests/conftest.py index 83a495e7..cd12a7fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ from amt.models.base import Base from amt.server import create_app from httpx import ASGITransport, AsyncClient -from playwright.sync_api import Browser, Page +from playwright.sync_api import Browser, BrowserContext, Page from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio.session import async_sessionmaker @@ -146,6 +146,18 @@ def browser_context_args(browser_context_args: dict[str, Any]) -> dict[str, Any] return {**browser_context_args, "base_url": "http://127.0.0.1:3462"} +@pytest.fixture +def page(context: BrowserContext) -> Page: + page = context.new_page() + do_e2e_login(page) + return page + + +@pytest.fixture +def page_no_login(context: BrowserContext) -> Page: + return context.new_page() + + @pytest.fixture(scope="session") def browser( launch_browser: Callable[[], Browser], diff --git a/tests/constants.py b/tests/constants.py index b423daff..3ba745ab 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -339,3 +339,7 @@ def default_systemcard_dic() -> dict[str, str | list[Any] | None]: "references": [], "models": [], } + + +def default_not_found_no_permission_msg() -> bytes: + return b"We couldn't find what you were looking for. This might be because:" diff --git a/tests/core/test_exceptions.py b/tests/core/test_exceptions.py index 8a101960..795ab03f 100644 --- a/tests/core/test_exceptions.py +++ b/tests/core/test_exceptions.py @@ -34,8 +34,9 @@ def test_RepositoryNoResultFound(): raise AMTNotFound() assert ( - exc_info.value.detail - == "The requested page or resource could not be found. Please check the URL or query and try again." + exc_info.value.detail == "The requested page or resource could not be found, " + "or you do not have the correct permissions to access it. " + "Please check the URL or query and try again." ) @@ -58,7 +59,11 @@ def test_AMTPermissionDenied(): with pytest.raises(AMTPermissionDenied) as exc_info: raise AMTPermissionDenied() - assert exc_info.value.detail == "You do not have the correct permissions to access this resource." + assert ( + exc_info.value.detail == "The requested page or resource could not be found, " + "or you do not have the correct permissions to access it. " + "Please check the URL or query and try again." + ) def test_AMTAuthorizationFlowError(): diff --git a/tests/database_e2e_setup.py b/tests/database_e2e_setup.py index 42619556..9942078e 100644 --- a/tests/database_e2e_setup.py +++ b/tests/database_e2e_setup.py @@ -1,6 +1,6 @@ from amt.api.lifecycles import Lifecycles from amt.enums.status import Status -from amt.models import Algorithm +from amt.models import Algorithm, Organization from sqlalchemy.ext.asyncio.session import AsyncSession from tests.constants import default_algorithm_with_lifecycle, default_task, default_user @@ -11,7 +11,10 @@ async def setup_database_e2e(session: AsyncSession) -> None: db_e2e = DatabaseTestUtils(session) await db_e2e.given([default_user()]) - await db_e2e.given([default_user(id="4738b1e151dc46219556a5662b26517c", name="Test User", organizations=[])]) + default_organization_db = (await db_e2e.get(Organization, "id", 1))[0] + await db_e2e.given( + [default_user(id="4738b1e151dc46219556a5662b26517c", name="Test User", organizations=[default_organization_db])] + ) algorithms: list[Algorithm] = [] for idx in range(120): diff --git a/tests/e2e/test_change_lang.py b/tests/e2e/test_change_lang.py index fcca6111..f63d62b2 100644 --- a/tests/e2e/test_change_lang.py +++ b/tests/e2e/test_change_lang.py @@ -3,7 +3,9 @@ @pytest.mark.slow -def test_e2e_change_language(page: Page): +def test_e2e_change_language(page_no_login: Page): + page = page_no_login + def get_lang_cookie(page: Page) -> Cookie | None: for cookie in page.context.cookies(): if "name" in cookie and cookie["name"] == "lang": diff --git a/tests/e2e/test_create_algorithm.py b/tests/e2e/test_create_algorithm.py index 0dd6dc28..a2b0602a 100644 --- a/tests/e2e/test_create_algorithm.py +++ b/tests/e2e/test_create_algorithm.py @@ -1,13 +1,9 @@ import pytest from playwright.sync_api import Page, expect -from tests.conftest import do_e2e_login - @pytest.mark.slow def test_e2e_create_algorithm(page: Page) -> None: - do_e2e_login(page) - page.goto("/algorithms/new") page.fill("#name", "My new algorithm") @@ -40,8 +36,6 @@ def test_e2e_create_algorithm(page: Page) -> None: @pytest.mark.slow def test_e2e_create_algorithm_invalid(page: Page): - do_e2e_login(page) - page.goto("/algorithms/new") page.locator("#transparency_obligations").select_option("geen transparantieverplichting") diff --git a/tests/e2e/test_create_organization.py b/tests/e2e/test_create_organization.py index 76d4a421..8f1b7222 100644 --- a/tests/e2e/test_create_organization.py +++ b/tests/e2e/test_create_organization.py @@ -1,13 +1,9 @@ import pytest from playwright.sync_api import Page, expect -from tests.conftest import do_e2e_login - @pytest.mark.slow def test_e2e_create_organization(page: Page) -> None: - do_e2e_login(page) - page.goto("/organizations/new") page.get_by_placeholder("Name of the organization").click() @@ -23,8 +19,6 @@ def test_e2e_create_organization(page: Page) -> None: @pytest.mark.slow def test_e2e_create_organization_error(page: Page) -> None: - do_e2e_login(page) - page.goto("/organizations/new") page.get_by_placeholder("Name of the organization").click() diff --git a/tests/repositories/test_authorizations.py b/tests/repositories/test_authorizations.py index 7488fef7..9a394713 100644 --- a/tests/repositories/test_authorizations.py +++ b/tests/repositories/test_authorizations.py @@ -28,12 +28,24 @@ async def test_authorization_basic(db: DatabaseTestUtils): authorization_repository = AuthorizationRepository(session=db.session) results = await authorization_repository.find_by_user(UUID(default_auth_user()["sub"])) + all_authorization_verbs: list[AuthorizationVerb] = [ + AuthorizationVerb.READ, + AuthorizationVerb.UPDATE, + AuthorizationVerb.CREATE, + AuthorizationVerb.LIST, + AuthorizationVerb.DELETE, + ] + # REMINDER: authorizations are generated, not taken from the database, + # see amt.repositories.authorizations.AuthorizationRepository.find_by_user assert results == [ + (AuthorizationResource.ALGORITHMS, all_authorization_verbs, AuthorizationType.ALGORITHM, "*"), + (AuthorizationResource.ALGORITHM, all_authorization_verbs, AuthorizationType.ALGORITHM, 1), + (AuthorizationResource.ORGANIZATIONS, all_authorization_verbs, AuthorizationType.ORGANIZATION, "*"), ( - AuthorizationResource.ORGANIZATION_INFO, - [AuthorizationVerb.CREATE, AuthorizationVerb.READ], + AuthorizationResource.ORGANIZATION_INFO_SLUG, + all_authorization_verbs, AuthorizationType.ORGANIZATION, - 1, - ) + "default-organization", + ), ] |