Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic authorization roles and enforce them. #472

Merged
merged 4 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ jobs:

- name: Upload playwright tracing
if: failure()
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.6.0
with:
name: playwright-${{ github.sha }}
path: test-results/
Expand All @@ -173,7 +173,7 @@ jobs:

- name: Upload code coverage report
if: matrix.python-version == '3.12'
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.6.0
with:
name: codecoverage-${{ github.sha }}
path: htmlcov/
Expand Down Expand Up @@ -295,7 +295,7 @@ jobs:
TRIVY_PASSWORD: ${{ secrets.GITHUB_TOKEN }}

- name: Upload SBOM & License
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v4.6.0
with:
name: sbom-licence-${{ github.sha }}.json
path: |
Expand Down
8 changes: 3 additions & 5 deletions amt/api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@

from fastapi import HTTPException, Request

from amt.api.utils import SafeDict
from amt.core.exceptions import AMTPermissionDenied


def add_permissions(permissions: dict[str, list[str]]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def permission(permissions: dict[str, list[str]]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
request = kwargs.get("request")
organization_id = kwargs.get("organization_id")
algoritme_id = kwargs.get("algoritme_id")
if not isinstance(request, Request): # todo: change exception to custom exception
raise HTTPException(status_code=400, detail="Request object is missing")

for permission, verbs in permissions.items():
permission = permission.format(organization_id=organization_id)
permission = permission.format(algoritme_id=algoritme_id)
permission = permission.format_map(SafeDict(kwargs))
request_permissions: dict[str, list[str]] = (
request.state.permissions if hasattr(request.state, "permissions") else {}
)
Expand Down
6 changes: 1 addition & 5 deletions amt/api/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from amt.api.editable_validators import EditableValidator, EditableValidatorMinMaxLength, EditableValidatorSlug
from amt.api.lifecycles import get_localized_lifecycles
from amt.api.routes.shared import UpdateFieldModel, nested_value
from amt.api.utils import SafeDict
from amt.core.exceptions import AMTNotFound
from amt.models import Algorithm, Organization
from amt.models.base import Base
Expand Down Expand Up @@ -198,11 +199,6 @@ def __iter__(self) -> Generator[tuple[str, Any], Any, Any]:
editables = Editables()


class SafeDict(dict[str, str | int]):
def __missing__(self, key: str) -> str:
return "{" + key + "}"


class EditModes(StrEnum):
EDIT = "EDIT"
VIEW = "VIEW"
Expand Down
34 changes: 26 additions & 8 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi.responses import FileResponse, HTMLResponse
from ulid import ULID

from amt.api.decorators import permission
from amt.api.deps import templates
from amt.api.editable import (
EditModes,
Expand All @@ -27,7 +28,7 @@
resolve_navigation_items,
)
from amt.api.routes.shared import UpdateFieldModel, get_filters_and_sort_by, replace_none_with_empty_string_inplace
from amt.core.authorization import get_user
from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user
from amt.core.exceptions import AMTError, AMTNotFound, AMTRepositoryError
from amt.core.internationalization import get_current_translation
from amt.enums.status import Status
Expand Down Expand Up @@ -215,6 +216,7 @@ async def get_algorithm_context(


@router.get("/{algorithm_id}/details")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_algorithm_details(
request: Request, algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)]
) -> HTMLResponse:
Expand All @@ -236,6 +238,7 @@ async def get_algorithm_details(


@router.get("/{algorithm_id}/edit")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_edit(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -271,6 +274,7 @@ async def get_algorithm_edit(


@router.get("/{algorithm_id}/cancel")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_cancel(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -305,6 +309,7 @@ async def get_algorithm_cancel(


@router.put("/{algorithm_id}/update")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.UPDATE]})
async def get_algorithm_update(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -356,6 +361,7 @@ async def get_algorithm_update(


@router.get("/{algorithm_id}/details/system_card")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_system_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -398,6 +404,7 @@ async def get_system_card(


@router.get("/{algorithm_id}/details/system_card/compliance")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_system_card_requirements(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -453,7 +460,9 @@ async def get_system_card_requirements(
extended_linked_measures.append(ext_measure_task)
requirements_and_measures.append((requirement, completed_measures_count, extended_linked_measures)) # pyright: ignore [reportUnknownMemberType]

measure_task_functions = await get_measure_task_functions(measure_tasks, users_repository, sort_by, filters)
measure_task_functions: dict[str, list[User]] = await get_measure_task_functions(
measure_tasks, users_repository, sort_by, filters
)

context = {
"instrument_state": instrument_state,
Expand All @@ -473,7 +482,7 @@ async def _fetch_members(
users_repository: UsersRepository,
search_name: str,
sort_by: dict[str, str],
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
) -> User | None:
members = await users_repository.find_all(search=search_name, sort=sort_by, filters=filters)
return members[0] if members else None
Expand All @@ -483,9 +492,9 @@ async def get_measure_task_functions(
measure_tasks: list[MeasureTask],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
) -> dict[str, list[Any]]:
measure_task_functions: dict[str, list[Any]] = defaultdict(list)
filters: dict[str, str | list[str | int]],
) -> dict[str, list[User]]:
measure_task_functions: dict[str, list[User]] = defaultdict(list)

for measure_task in measure_tasks:
person_types = ["accountable_persons", "reviewer_persons", "responsible_persons"]
Expand Down Expand Up @@ -533,6 +542,7 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure


@router.delete("/{algorithm_id}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.DELETE]})
async def delete_algorithm(
request: Request,
algorithm_id: int,
Expand All @@ -543,6 +553,7 @@ async def delete_algorithm(


@router.get("/{algorithm_id}/measure/{measure_urn}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_measure(
request: Request,
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
Expand Down Expand Up @@ -598,7 +609,7 @@ async def get_users_from_function_name(
measure_responsible: Annotated[str | None, Form()],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
) -> tuple[list[Person], list[Person], list[Person]]:
accountable_persons, reviewer_persons, responsible_persons = [], [], []
if measure_accountable:
Expand All @@ -614,6 +625,7 @@ async def get_users_from_function_name(


@router.post("/{algorithm_id}/measure/{measure_urn}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def update_measure_value(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -679,6 +691,7 @@ async def update_measure_value(


@router.get("/{algorithm_id}/members")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_algorithm_members(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -712,6 +725,7 @@ async def get_algorithm_members(


@router.get("/{algorithm_id}/details/system_card/assessments/{assessment_card}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_assessment_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -765,6 +779,7 @@ async def get_assessment_card(


@router.get("/{algorithm_id}/details/system_card/models/{model_card}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_model_card(
request: Request,
algorithm_id: int,
Expand Down Expand Up @@ -819,20 +834,22 @@ async def get_model_card(


@router.get("/{algorithm_id}/details/system_card/download")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def download_algorithm_system_card_as_yaml(
algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], request: Request
) -> FileResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
filename = algorithm.name + "_" + datetime.datetime.now(datetime.UTC).isoformat() + ".yaml"
with open(filename, "w") as outfile:
yaml.dump(algorithm.system_card.model_dump(), outfile)
yaml.dump(algorithm.system_card.model_dump(), outfile, sort_keys=False)
try:
return FileResponse(filename, filename=filename)
except AMTRepositoryError as e:
raise AMTNotFound from e


@router.get("/{algorithm_id}/file/{ulid}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def get_file(
request: Request,
algorithm_id: int,
Expand All @@ -854,6 +871,7 @@ async def get_file(


@router.delete("/{algorithm_id}/file/{ulid}")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def delete_file(
request: Request,
algorithm_id: int,
Expand Down
23 changes: 18 additions & 5 deletions amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from typing import Annotated, Any
from typing import Annotated, Any, cast

from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import HTMLResponse

from amt.api.ai_act_profile import get_ai_act_profile_selector
from amt.api.decorators import permission
from amt.api.deps import templates
from amt.api.forms.algorithm import get_algorithm_form
from amt.api.group_by_category import get_localized_group_by_categories
Expand All @@ -14,7 +15,7 @@
get_localized_risk_groups,
)
from amt.api.routes.shared import get_filters_and_sort_by
from amt.core.authorization import get_user
from amt.core.authorization import AuthorizationResource, AuthorizationVerb, get_user
from amt.core.exceptions import AMTAuthorizationError
from amt.core.internationalization import get_current_translation
from amt.models import Algorithm
Expand All @@ -29,16 +30,25 @@


@router.get("/")
@permission({AuthorizationResource.ALGORITHMS: [AuthorizationVerb.LIST]})
async def get_root(
request: Request,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)],
skip: int = Query(0, ge=0),
limit: int = Query(5000, ge=1), # todo: fix infinite scroll
search: str = Query(""),
display_type: str = Query(""),
) -> HTMLResponse:
filters, drop_filters, localized_filters, sort_by = get_filters_and_sort_by(request)

session_user = get_user(request)
user_id: str | None = session_user["sub"] if session_user else None # pyright: ignore[reportUnknownVariableType]

filters["organization-id"] = [
organization.id for organization in await organizations_service.get_organizations_for_user(user_id)
]

algorithms, amount_algorithm_systems = await get_algorithms(
algorithms_service, display_type, filters, limit, request, search, skip, sort_by
)
Expand Down Expand Up @@ -76,25 +86,26 @@ async def get_root(
async def get_algorithms(
algorithms_service: AlgorithmsService,
display_type: str,
filters: dict[str, str],
filters: dict[str, str | list[str | int]],
limit: int,
request: Request,
search: str,
skip: int,
sort_by: dict[str, str],
) -> tuple[dict[str, list[Algorithm]], int | Any]:
amount_algorithm_systems: int = 0

if display_type == "LIFECYCLE":
algorithms: dict[str, list[Algorithm]] = {}

# When the lifecycle filter is active, only show these algorithms
if "lifecycle" in filters:
for lifecycle in Lifecycles:
algorithms[lifecycle.name] = []
algorithms[filters["lifecycle"]] = await algorithms_service.paginate(
algorithms[cast(str, filters["lifecycle"])] = await algorithms_service.paginate(
skip=skip, limit=limit, search=search, filters=filters, sort=sort_by
)
amount_algorithm_systems += len(algorithms[filters["lifecycle"]])
amount_algorithm_systems += len(algorithms[cast(str, filters["lifecycle"])])
else:
for lifecycle in Lifecycles:
filters["lifecycle"] = lifecycle.name
Expand All @@ -114,6 +125,7 @@ async def get_algorithms(


@router.get("/new")
@permission({AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]})
async def get_new(
request: Request,
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
Expand Down Expand Up @@ -156,6 +168,7 @@ async def get_new(


@router.post("/new", response_class=HTMLResponse)
@permission({AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]})
async def post_new(
request: Request,
algorithm_new: AlgorithmNew,
Expand Down
Loading
Loading