Skip to content

Commit

Permalink
Add 'confirmation' screen and update AI Act profile on confirmation
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed Mar 5, 2025
1 parent 2de9e3f commit df9112c
Show file tree
Hide file tree
Showing 43 changed files with 1,445 additions and 793 deletions.
68 changes: 10 additions & 58 deletions amt/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import logging
from collections.abc import Sequence
from enum import Enum
from os import PathLike
from pyclbr import Class
from typing import Any, AnyStr, TypeVar
from typing import TypeVar

from fastapi import Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from jinja2 import Environment, StrictUndefined, Undefined
from starlette.background import BackgroundTask
from starlette.templating import _TemplateResponse # pyright: ignore [reportPrivateUsage]
from jinja2 import StrictUndefined, Undefined
from starlette.requests import Request

from amt.api.editable import is_editable_resource, is_parent_editable
from amt.api.editable_util import replace_digits_in_brackets, resolve_resource_list_path
from amt.api.editable_util import (
is_editable_resource,
is_parent_editable,
replace_digits_in_brackets,
resolve_resource_list_path,
)
from amt.api.http_browser_caching import url_for_cache
from amt.api.localizable import LocalizableEnum
from amt.api.navigation import NavigationItem, get_main_menu
Expand All @@ -25,6 +23,7 @@
nested_enum_value,
nested_value,
)
from amt.api.template_classes import LocaleJinja2Templates
from amt.core.authorization import AuthorizationVerb, get_user
from amt.core.config import VERSION, get_settings
from amt.core.internationalization import (
Expand All @@ -33,8 +32,6 @@
get_current_translation,
get_dynamic_field_translations,
get_requested_language,
get_supported_translation,
get_translation,
supported_translations,
time_ago,
)
Expand Down Expand Up @@ -78,51 +75,6 @@ def permission(permission: str, verb: AuthorizationVerb, permissions: dict[str,
return authorized


# we use a custom override so we can add the translation per request, which is parsed in the Request object in kwargs
class LocaleJinja2Templates(Jinja2Templates):
def _create_env(
self,
directory: str | PathLike[AnyStr] | Sequence[str | PathLike[AnyStr]],
**env_options: Any, # noqa: ANN401
) -> Environment:
env: Environment = super()._create_env(directory, **env_options) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType, reportArgumentType]
env.add_extension("jinja2.ext.i18n") # pyright: ignore [reportUnknownMemberType]
return env # pyright: ignore [reportUnknownVariableType]

def TemplateResponse( # pyright: ignore [reportIncompatibleMethodOverride]
self,
request: Request,
name: str,
context: dict[str, Any] | None = None,
status_code: int = 200,
headers: dict[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
content_language = get_supported_translation(get_requested_language(request))
translations = get_translation(content_language)
if headers is None:
headers = {}
headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
if "Content-Language" not in headers:
headers["Content-Language"] = ",".join(supported_translations)
self.env.install_gettext_translations(translations, newstyle=True) # pyright: ignore [reportUnknownMemberType]

if context is None:
context = {}

if hasattr(request.state, "csrftoken"):
context["csrftoken"] = request.state.csrftoken
else:
context["csrftoken"] = ""

return super().TemplateResponse(request, name, context, status_code, headers, media_type, background)

def Redirect(self, request: Request, url: str) -> HTMLResponse:
headers = {"HX-Redirect": url}
return self.TemplateResponse(request, "redirect.html.j2", headers=headers)


def instance(obj: Class, type_string: str) -> bool:
match type_string:
case "str":
Expand Down
53 changes: 7 additions & 46 deletions amt/api/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@

from starlette.requests import Request

from amt.api.editable_classes import Editable, EditableType, EditModes, ResolvedEditable
from amt.api.editable_classes import Editable, EditableType, EditModes, FormState, ResolvedEditable
from amt.api.editable_converters import EditableConverterForOrganizationInAlgorithm, StatusConverterForSystemcard
from amt.api.editable_enforcers import EditableEnforcerForOrganizationInAlgorithm
from amt.api.editable_hooks import PreConfirmAIActHook, RedirectOrganizationHook, UpdateAIActHook
from amt.api.editable_util import (
extract_number_and_string,
replace_digits_in_brackets,
replace_wildcard_with_digits_in_brackets,
)
from amt.api.editable_validators import EditableValidatorMinMaxLength, EditableValidatorSlug
from amt.api.editable_value_providers import AIActValuesProvider
from amt.api.lifecycles import get_localized_lifecycles
from amt.api.routes.shared import nested_value
from amt.api.update_utils import extract_number_and_string, set_path
from amt.api.utils import SafeDict
from amt.core.exceptions import AMTNotFound
from amt.models import Algorithm, Organization
Expand Down Expand Up @@ -269,11 +269,13 @@ class Editables:
full_resource_path="organization/{organization_id}/slug",
implementation_type=WebFormFieldImplementationType.TEXT,
validator=EditableValidatorSlug(),
hooks={FormState.POST_SAVE: RedirectOrganizationHook()},
)

ALGORITHM_EDITABLE_AIACT = Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile",
implementation_type=WebFormFieldImplementationType.PARENT,
hooks={FormState.PRE_CONFIRM: PreConfirmAIActHook(), FormState.POST_SAVE: UpdateAIActHook()},
children=[
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/role",
Expand Down Expand Up @@ -481,6 +483,7 @@ def resolve_editable_path(
converter=editable.converter,
enforcer=editable.enforcer,
validator=editable.validator,
hooks=editable.hooks,
)

editables_resolved: list[ResolvedEditable] = []
Expand All @@ -494,6 +497,7 @@ def resolve_editable_path(
return {editable.full_resource_path: editable for editable in editables_resolved}


# TODO: this probably should be a method of ResolvedEditable
async def save_editable( # noqa: C901
editable: ResolvedEditable,
editable_context: dict[str, Any],
Expand All @@ -516,7 +520,6 @@ async def save_editable( # noqa: C901
new_value = editable_context.get("new_values", {}).get(editable.last_path_item())

# we validate on 'raw' form fields, so validation is done before the converter
# TODO: validate all fields (child and couples) before saving!
if editable.validator and editable.relative_resource_path is not None:
await editable.validator.validate(new_value, editable) # pyright: ignore[reportUnknownMemberType]

Expand Down Expand Up @@ -566,45 +569,3 @@ async def save_editable( # noqa: C901
raise AMTNotFound()

return editable


def set_path(obj: dict[str, Any] | object, path: str, value: typing.Any) -> None: # noqa: ANN401, C901
if not path:
raise ValueError("Path cannot be empty")

attrs = path.lstrip("/").split("/")
for attr in attrs[:-1]:
attr, index = extract_number_and_string(attr)
if isinstance(obj, dict):
obj = cast(dict[str, Any], obj)
if attr not in obj:
obj[attr] = {}
obj = obj[attr]
else:
if not hasattr(obj, attr): # pyright: ignore[reportUnknownArgumentType]
setattr(obj, attr, {}) # pyright: ignore[reportUnknownArgumentType]
obj = getattr(obj, attr) # pyright: ignore[reportUnknownArgumentType]
if obj and index is not None:
obj = cast(list[Any], obj)[index] # pyright: ignore[reportArgumentType, reportUnknownVariableType, reportUnknownArgumentType]

if isinstance(obj, dict):
obj[attrs[-1]] = value
else:
attr, index = extract_number_and_string(attrs[-1])
if index is not None:
cast(list[Any], getattr(obj, attr))[index] = value
else:
setattr(obj, attrs[-1], value)


def is_editable_resource(full_resource_path: str, editables: dict[str, ResolvedEditable]) -> bool:
return editables.get(replace_digits_in_brackets(full_resource_path), None) is not None


def is_parent_editable(editables: dict[str, ResolvedEditable], full_resource_path: str) -> bool:
full_resource_path = replace_digits_in_brackets(full_resource_path)
editable = editables.get(full_resource_path)
if editable is None:
return False
result = editable.implementation_type == WebFormFieldImplementationType.PARENT
return result
126 changes: 124 additions & 2 deletions amt/api/editable_classes.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,99 @@
import logging
import re
from enum import StrEnum
from typing import Any, Final
from abc import ABC, abstractmethod
from enum import Enum, StrEnum, auto
from typing import Any, Final, cast

from fastapi import Request
from starlette.responses import HTMLResponse

from amt.api.editable_converters import (
EditableConverter,
)
from amt.api.editable_enforcers import EditableEnforcer
from amt.api.editable_validators import EditableValidator
from amt.api.editable_value_providers import EditableValuesProvider
from amt.api.template_classes import LocaleJinja2Templates
from amt.models.base import Base
from amt.schema.webform import WebFormFieldImplementationTypeFields, WebFormOption

type EditableType = Editable
type FormStateType = FormState
type ResolvedEditableType = ResolvedEditable

logger = logging.getLogger(__name__)


class FormState(Enum):
"""
The FormState is used to streamline the form flow for
the inline editor. States can have a hook attacked to it, which is
registered in the Editable object.
"""

VALIDATE = auto()
PRE_CONFIRM = auto()
CONFIRM_SAVE = auto()
PRE_SAVE = auto()
SAVE = auto()
POST_SAVE = auto()
COMPLETED = auto()

@classmethod
def pre_save_states(cls) -> frozenset[FormStateType]:
return frozenset({cls.PRE_CONFIRM, cls.CONFIRM_SAVE, cls.PRE_SAVE})

@classmethod
def post_save_states(cls) -> frozenset[FormStateType]:
return frozenset({cls.POST_SAVE, cls.COMPLETED})

def is_before_save(self) -> bool:
return self in self.pre_save_states()

def is_validate(self) -> bool:
return self == self.VALIDATE

def is_after_save(self) -> bool:
return self in self.post_save_states()

def is_save(self) -> bool:
return self == self.SAVE

@classmethod
def get_next_state(cls, state: FormStateType) -> FormStateType:
if state.value >= cls.COMPLETED.value:
return cls.COMPLETED
next_state = cls(state.value + 1)
logger.info(f"FormState is moving to next state: {next_state}")
return next_state

@classmethod
def all_states_after(cls, state: FormStateType) -> list[FormStateType]:
return [s for s in cls if s.value > state.value]

@classmethod
def from_string(cls, state_name: str) -> FormStateType:
try:
return cast(FormState, cls[state_name])
except KeyError as e:
raise ValueError(f"Invalid state name: {state_name}") from e


class EditableHook(ABC):
"""
Hooks can be used to run a function at a specific moment in the FormState flow.
"""

@abstractmethod
async def execute(
self,
request: Request,
templates: LocaleJinja2Templates,
editable: ResolvedEditableType,
editable_context: dict[str, str | dict[str, str]],
) -> HTMLResponse | None:
pass


class Editable:
"""
Expand Down Expand Up @@ -42,6 +122,7 @@ def __init__(
converter: EditableConverter | None = None,
enforcer: EditableEnforcer | None = None,
validator: EditableValidator | None = None,
hooks: dict[FormState, EditableHook] | None = None,
# TODO: determine if relative resource path is really required for editable
relative_resource_path: str | None = None,
) -> None:
Expand All @@ -54,6 +135,7 @@ def __init__(
self.enforcer = enforcer
self.validator = validator
self.relative_resource_path = relative_resource_path
self.hooks = hooks

def add_bidirectional_couple(self, target: EditableType) -> None:
"""
Expand All @@ -74,6 +156,12 @@ def add_child(self, target: EditableType) -> None:
"""
self.children.append(target)

def register_hook(self, state: FormState, hook: EditableHook) -> None:
if self.hooks is not None:
self.hooks[state] = hook
else:
raise ValueError("Cannot register hook because hooks is None")


class ResolvedEditable:
value: Any | None
Expand All @@ -97,6 +185,7 @@ def __init__(
value: str | None = None,
resource_object: Base | None = None,
relative_resource_path: str | None = None,
hooks: dict[FormState, EditableHook] | None = None,
) -> None:
self.full_resource_path = full_resource_path
self.implementation_type = implementation_type
Expand All @@ -110,6 +199,7 @@ def __init__(
self.value = value
self.resource_object = resource_object
self.relative_resource_path = relative_resource_path
self.hooks = hooks

def last_path_item(self) -> str:
return self.full_resource_path.split("/")[-1]
Expand All @@ -122,6 +212,38 @@ def safe_html_path(self) -> str:
return re.sub(r"[\[\]/*]", "_", self.relative_resource_path) # pyright: ignore[reportUnknownVariableType, reportCallIssue]
raise ValueError("Can not convert path to save html path as it is None")

def has_hook(self, state: FormState) -> bool:
if self.hooks is None:
return False
return state in self.hooks

def get_hook(self, state: FormState) -> EditableHook | None:
if self.hooks is None:
return None
return self.hooks.get(state)

async def run_hook(
self,
state: FormState,
request: Request,
templates: LocaleJinja2Templates,
editable: ResolvedEditableType,
editable_context: dict[str, str | dict[str, str]],
) -> HTMLResponse | None:
if self.hooks is not None and self.has_hook(state):
logger.info(f"Running hook for state {state} for editable {self.full_resource_path}")
return await self.hooks[state].execute(request, templates, editable, editable_context)
return None

async def validate(self, editable_context: dict[str, Any]) -> None:
editables_to_validate = list(self.couples or set()) + (self.children or []) + [self]
for editable in editables_to_validate:
new_value = editable_context.get("new_values", {}).get(editable.last_path_item())
if editable.validator and editable.relative_resource_path is not None:
await editable.validator.validate(new_value, editable) # pyright: ignore[reportUnknownMemberType]
if editable.enforcer:
await editable.enforcer.enforce(**editable_context)


class EditModes(StrEnum):
EDIT = "EDIT"
Expand Down
Loading

0 comments on commit df9112c

Please sign in to comment.