Skip to content

Commit

Permalink
Fix requirements widget and speed up dealing with requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed Jan 20, 2025
1 parent f09e230 commit c739f74
Show file tree
Hide file tree
Showing 25 changed files with 246 additions and 233 deletions.
19 changes: 14 additions & 5 deletions amt/api/forms/measure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from collections.abc import Sequence
from enum import StrEnum
from gettext import NullTranslations

from amt.models import User
from amt.schema.webform import WebForm, WebFormField, WebFormFieldType, WebFormOption, WebFormTextCloneableField


class MeasureStatusOptions(StrEnum):
TODO = "to do"
IN_PROGRESS = "in progress"
IN_REVIEW = "in review"
DONE = "done"
NOT_IMPLEMENTED = "not implemented"


async def get_measure_form(
id: str,
current_values: dict[str, str | list[str] | list[tuple[str, str]]],
Expand Down Expand Up @@ -47,11 +56,11 @@ async def get_measure_form(
name="measure_state",
label=_("Status"),
options=[
WebFormOption(value="to do", display_value="to do"),
WebFormOption(value="in progress", display_value="in progress"),
WebFormOption(value="in review", display_value="in review"),
WebFormOption(value="done", display_value="done"),
WebFormOption(value="not implemented", display_value="not implemented"),
WebFormOption(value=MeasureStatusOptions.TODO, display_value="to do"),
WebFormOption(value=MeasureStatusOptions.IN_PROGRESS, display_value="in progress"),
WebFormOption(value=MeasureStatusOptions.IN_REVIEW, display_value="in review"),
WebFormOption(value=MeasureStatusOptions.DONE, display_value="done"),
WebFormOption(value=MeasureStatusOptions.NOT_IMPLEMENTED, display_value="not implemented"),
],
default_value=current_values.get("measure_state"),
group="1",
Expand Down
75 changes: 50 additions & 25 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import datetime
import logging
import urllib.parse
from collections import defaultdict
from collections.abc import Sequence
from typing import Annotated, Any

import yaml
from fastapi import APIRouter, Depends, File, Form, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from ulid import ULID

from amt.api.decorators import permission
Expand All @@ -19,7 +20,7 @@
get_resolved_editables,
save_editable,
)
from amt.api.forms.measure import get_measure_form
from amt.api.forms.measure import MeasureStatusOptions, get_measure_form
from amt.api.navigation import (
BaseNavigationItem,
Navigation,
Expand All @@ -42,10 +43,10 @@
from amt.schema.task import MovedTask
from amt.services.algorithms import AlgorithmsService
from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService
from amt.services.measures import MeasuresService, create_measures_service
from amt.services.measures import measures_service
from amt.services.object_storage import ObjectStorageService, create_object_storage_service
from amt.services.organizations import OrganizationsService
from amt.services.requirements import RequirementsService, create_requirements_service
from amt.services.requirements import requirements_service
from amt.services.tasks import TasksService

router = APIRouter()
Expand All @@ -63,7 +64,6 @@ async def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:


async def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements_service = create_requirements_service()
requirements = await requirements_service.fetch_requirements(
[requirement.urn for requirement in system_card.requirements]
)
Expand Down Expand Up @@ -411,8 +411,6 @@ async def get_system_card_requirements(
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = await get_instrument_state(algorithm.system_card)
Expand Down Expand Up @@ -497,7 +495,7 @@ async def get_measure_task_functions(
measure_task_functions: dict[str, list[User]] = defaultdict(list)

for measure_task in measure_tasks:
person_types = ["accountable_persons", "reviewer_persons", "responsible_persons"]
person_types = ["responsible_persons", "reviewer_persons", "accountable_persons"]
for person_type in person_types:
person_list = getattr(measure_task, person_type)
if person_list:
Expand Down Expand Up @@ -527,8 +525,6 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure
requirement_mapper[requirement_task.urn] = requirement_task

requirement_tasks: list[RequirementTask] = []
measures_service = create_measures_service()
requirements_service = create_requirements_service()
measure = await measures_service.fetch_measures(measure_urn)
for requirement_urn in measure[0].links:
# TODO: This is because measure are linked to too many requirement not applicable in our use case
Expand Down Expand Up @@ -561,9 +557,9 @@ async def get_measure(
algorithm_id: int,
measure_urn: str,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
search: str = Query(""),
requirement_urn: str = "",
) -> HTMLResponse:
filters, _, _, sort_by = get_filters_and_sort_by(request)
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
Expand Down Expand Up @@ -598,7 +594,12 @@ async def get_measure(
translations=get_current_translation(request),
)

context = {"measure": measure[0], "algorithm_id": algorithm_id, "form": measure_form}
context = {
"measure": measure[0],
"algorithm_id": algorithm_id,
"form": measure_form,
"requirement_urn": requirement_urn,
}

return templates.TemplateResponse(request, "algorithms/details_measure_modal.html.j2", context)

Expand Down Expand Up @@ -633,7 +634,6 @@ async def update_measure_value(
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
measure_state: Annotated[str, Form()],
measure_responsible: Annotated[str | None, Form()] = None,
Expand All @@ -642,6 +642,7 @@ async def update_measure_value(
measure_value: Annotated[str | None, Form()] = None,
measure_links: Annotated[list[str] | None, Form()] = None,
measure_files: Annotated[list[UploadFile] | None, File()] = None,
requirement_urn: str = "",
) -> HTMLResponse:
filters, _, _, sort_by = get_filters_and_sort_by(request)
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
Expand Down Expand Up @@ -670,24 +671,48 @@ async def update_measure_value(
requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks]
requirements = await requirements_service.fetch_requirements(requirement_urns)

state_order_list = set(MeasureStatusOptions)
for requirement in requirements:
count_completed = 0
state_count: dict[str, int] = {}
for link_measure_urn in requirement.links:
link_measure_task = find_measure_task(algorithm.system_card, link_measure_urn)
if link_measure_task: # noqa: SIM102
if link_measure_task.state == "done":
count_completed += 1
if link_measure_task:
state_count[link_measure_task.state] = state_count.get(link_measure_task.state, 0) + 1
requirement_task = find_requirement_task(algorithm.system_card, requirement.urn)
if count_completed == len(requirement.links):
requirement_task.state = "done" # pyright: ignore [reportOptionalMemberAccess]
elif count_completed == 0 and len(requirement.links) > 0:
requirement_task.state = "to do" # pyright: ignore [reportOptionalMemberAccess]
else:
requirement_task.state = "in progress" # pyright: ignore [reportOptionalMemberAccess]
full_match = False
for state in state_order_list:
# if all measures are in the same state, the requirement is set to that state
if requirement_task and state_count.get(state, 0) == len(requirement.links):
requirement_task.state = state
full_match = True
break
# a requirement is considered 'in progress' if any measure is of any state other than todo
if requirement_task and not full_match and len([key for key in state_count if key != MeasureStatusOptions.TODO]) > 0:
requirement_task.state = MeasureStatusOptions.IN_PROGRESS

await algorithms_service.update(algorithm)
# TODO: FIX THIS!! The page now reloads at the top, which is annoying
return templates.Redirect(request, f"/algorithm/{algorithm_id}/details/system_card/compliance")

# the redirect 'to same page' does not trigger a javascript reload, so we let us redirect by a different server URL
encoded_url = urllib.parse.quote_plus(
f"/algorithm/{algorithm_id}/details/system_card/compliance#{requirement_urn.replace(":","_")}"
)
return templates.Redirect(
request,
f"/algorithm/{algorithm_id}/redirect?to={encoded_url}",
)


@router.get("/{algorithm_id}/redirect")
@permission({AuthorizationResource.ALGORITHM: [AuthorizationVerb.READ]})
async def redirect_to(request: Request, algorithm_id: str, to: str) -> RedirectResponse:
"""
Redirects to the requested URL. We only have and use this because HTMX and javascript redirects do
not work when redirecting to the same URL, even if query params are changed.
"""
return RedirectResponse(
status_code=302,
url=to,
)


@router.get("/{algorithm_id}/members")
Expand Down
6 changes: 1 addition & 5 deletions amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from amt.schema.algorithm import AlgorithmNew
from amt.schema.webform import WebForm
from amt.services.algorithms import AlgorithmsService, get_template_files
from amt.services.instruments import InstrumentsService, create_instrument_service
from amt.services.organizations import OrganizationsService

router = APIRouter()
Expand Down Expand Up @@ -128,7 +127,6 @@ async def get_algorithms(
@permission({AuthorizationResource.ALGORITHMS: [AuthorizationVerb.CREATE]})
async def get_new(
request: Request,
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
organizations_service: Annotated[OrganizationsService, Depends(OrganizationsService)],
organization_id: int = Query(None),
) -> HTMLResponse:
Expand All @@ -151,10 +149,8 @@ async def get_new(

template_files = get_template_files()

instruments = await instrument_service.fetch_instruments()

context: dict[str, Any] = {
"instruments": instruments,
"instruments": [],
"ai_act_profile": ai_act_profile,
"breadcrumbs": breadcrumbs,
"sub_menu_items": {}, # sub_menu_items disabled for now,
Expand Down
7 changes: 6 additions & 1 deletion amt/api/routes/organizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ async def root(
search=search, sort=sort_by, filters=filters, user_id=user["sub"] if user else None
)

# we only can show organization you belong to, so the all organizations option is disabled
organization_filters = [
f for f in get_localized_organization_filters(request) if f and f.value != OrganizationFilterOptions.ALL.value
]

context: dict[str, Any] = {
"breadcrumbs": breadcrumbs,
"organizations": organizations,
Expand All @@ -123,7 +128,7 @@ async def root(
"organizations_length": len(organizations),
"filters": localized_filters,
"include_filters": False,
"organization_filters": get_localized_organization_filters(request),
"organization_filters": organization_filters,
}

if request.state.htmx:
Expand Down
3 changes: 1 addition & 2 deletions amt/cli/check_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from amt.schema.instrument import Instrument
from amt.schema.system_card import SystemCard
from amt.services.instruments import create_instrument_service
from amt.services.instruments import instruments_service
from amt.services.instruments_and_requirements_state import all_lifecycles, get_all_next_tasks
from amt.services.storage import StorageFactory

Expand All @@ -30,7 +30,6 @@ def get_requested_instruments(all_instruments: list[Instrument], urns: list[str]
def get_tasks_by_priority(urns: list[str], system_card_path: Path) -> None:
try:
system_card = get_system_card(system_card_path)
instruments_service = create_instrument_service()
all_instruments = asyncio.run(instruments_service.fetch_instruments())
instruments = get_requested_instruments(all_instruments, urns)
next_tasks = get_all_next_tasks(instruments, system_card)
Expand Down
Loading

0 comments on commit c739f74

Please sign in to comment.