Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AI-Act profile to editable fields #504

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 21 additions & 58 deletions amt/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import logging
from collections.abc import Sequence
from enum import Enum
from os import PathLike
from pyclbr import Class
from typing import Any, AnyStr, TypeVar
from typing import TypeVar

from fastapi import Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from jinja2 import Environment, StrictUndefined, Undefined
from starlette.background import BackgroundTask
from starlette.templating import _TemplateResponse # pyright: ignore [reportPrivateUsage]
from jinja2 import StrictUndefined, Undefined
from starlette.requests import Request

from amt.api.editable import is_editable_resource, is_parent_editable
from amt.api.editable_util import replace_digits_in_brackets, resolve_resource_list_path
from amt.api.editable_util import (
is_editable_resource,
is_parent_editable,
replace_digits_in_brackets,
resolve_resource_list_path,
)
from amt.api.http_browser_caching import url_for_cache
from amt.api.localizable import LocalizableEnum
from amt.api.navigation import NavigationItem, get_main_menu
Expand All @@ -25,6 +23,7 @@
nested_enum_value,
nested_value,
)
from amt.api.template_classes import LocaleJinja2Templates
from amt.core.authorization import AuthorizationVerb, get_user
from amt.core.config import VERSION, get_settings
from amt.core.internationalization import (
Expand All @@ -33,8 +32,6 @@
get_current_translation,
get_dynamic_field_translations,
get_requested_language,
get_supported_translation,
get_translation,
supported_translations,
time_ago,
)
Expand Down Expand Up @@ -78,51 +75,6 @@ def permission(permission: str, verb: AuthorizationVerb, permissions: dict[str,
return authorized


# we use a custom override so we can add the translation per request, which is parsed in the Request object in kwargs
class LocaleJinja2Templates(Jinja2Templates):
def _create_env(
self,
directory: str | PathLike[AnyStr] | Sequence[str | PathLike[AnyStr]],
**env_options: Any, # noqa: ANN401
) -> Environment:
env: Environment = super()._create_env(directory, **env_options) # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType, reportArgumentType]
env.add_extension("jinja2.ext.i18n") # pyright: ignore [reportUnknownMemberType]
return env # pyright: ignore [reportUnknownVariableType]

def TemplateResponse( # pyright: ignore [reportIncompatibleMethodOverride]
self,
request: Request,
name: str,
context: dict[str, Any] | None = None,
status_code: int = 200,
headers: dict[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
content_language = get_supported_translation(get_requested_language(request))
translations = get_translation(content_language)
if headers is None:
headers = {}
headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
if "Content-Language" not in headers:
headers["Content-Language"] = ",".join(supported_translations)
self.env.install_gettext_translations(translations, newstyle=True) # pyright: ignore [reportUnknownMemberType]

if context is None:
context = {}

if hasattr(request.state, "csrftoken"):
context["csrftoken"] = request.state.csrftoken
else:
context["csrftoken"] = ""

return super().TemplateResponse(request, name, context, status_code, headers, media_type, background)

def Redirect(self, request: Request, url: str) -> HTMLResponse:
headers = {"HX-Redirect": url}
return self.TemplateResponse(request, "redirect.html.j2", headers=headers)


def instance(obj: Class, type_string: str) -> bool:
match type_string:
case "str":
Expand Down Expand Up @@ -152,6 +104,15 @@ def hasattr_jinja(obj: object, attributes: str) -> bool:
return True


def equal_or_includes(my_value: str, check_against_value: str | list[str] | tuple[str]) -> bool:
"""Test if my_value equals or exists in check_against_value"""
if isinstance(check_against_value, list | tuple):
return my_value in check_against_value
elif isinstance(check_against_value, str):
return my_value == check_against_value
return False


templates = LocaleJinja2Templates(
directory="amt/site/templates/", context_processors=[custom_context_processor], undefined=get_undefined_behaviour()
)
Expand All @@ -172,5 +133,7 @@ def hasattr_jinja(obj: object, attributes: str) -> bool:
templates.env.globals.update(is_parent_editable=is_parent_editable) # pyright: ignore [reportUnknownMemberType]
templates.env.globals.update(resolve_resource_list_path=resolve_resource_list_path) # pyright: ignore [reportUnknownMemberType]
templates.env.globals.update(get_localized_value=get_localized_value) # pyright: ignore [reportUnknownMemberType]
# env tests allows for usage in templates like: if value is test_name(other_value)
templates.env.tests["permission"] = permission # pyright: ignore [reportUnknownMemberType]
templates.env.tests["equal_or_includes"] = equal_or_includes # pyright: ignore [reportUnknownMemberType]
templates.env.add_extension("jinja2_base64_filters.Base64Filters") # pyright: ignore [reportUnknownMemberType]
106 changes: 57 additions & 49 deletions amt/api/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

from starlette.requests import Request

from amt.api.editable_classes import Editable, EditableType, EditModes, ResolvedEditable
from amt.api.editable_classes import Editable, EditableType, EditModes, FormState, ResolvedEditable
from amt.api.editable_converters import EditableConverterForOrganizationInAlgorithm, StatusConverterForSystemcard
from amt.api.editable_enforcers import EditableEnforcerForOrganizationInAlgorithm
from amt.api.editable_hooks import PreConfirmAIActHook, RedirectOrganizationHook, UpdateAIActHook
from amt.api.editable_util import (
extract_number_and_string,
replace_digits_in_brackets,
replace_wildcard_with_digits_in_brackets,
)
from amt.api.editable_validators import EditableValidatorMinMaxLength, EditableValidatorSlug
from amt.api.editable_value_providers import AIActValuesProvider
from amt.api.lifecycles import get_localized_lifecycles
from amt.api.routes.shared import nested_value
from amt.api.update_utils import extract_number_and_string, set_path
from amt.api.utils import SafeDict
from amt.core.exceptions import AMTNotFound
from amt.models import Algorithm, Organization
Expand Down Expand Up @@ -268,6 +269,50 @@ class Editables:
full_resource_path="organization/{organization_id}/slug",
implementation_type=WebFormFieldImplementationType.TEXT,
validator=EditableValidatorSlug(),
hooks={FormState.POST_SAVE: RedirectOrganizationHook()},
)

ALGORITHM_EDITABLE_AIACT = Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile",
implementation_type=WebFormFieldImplementationType.PARENT,
hooks={FormState.PRE_CONFIRM: PreConfirmAIActHook(), FormState.POST_SAVE: UpdateAIActHook()},
children=[
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/role",
implementation_type=WebFormFieldImplementationType.MULTIPLE_CHECKBOX_AI_ACT,
values_provider=AIActValuesProvider(type="role"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/type",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="type"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/open_source",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="open_source"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/risk_group",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="risk_group"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/conformity_assessment_body",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="conformity_assessment_body"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/systemic_risk",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="systemic_risk"),
),
Editable(
full_resource_path="algorithm/{algorithm_id}/system_card/ai_act_profile/transparency_obligations",
implementation_type=WebFormFieldImplementationType.SELECT_AI_ACT,
values_provider=AIActValuesProvider(type="transparency_obligations"),
),
],
)

# TODO: rethink if this is a wise solution.. we do this to keep all elements in 1 class and still
Expand Down Expand Up @@ -375,8 +420,13 @@ async def enrich_editable( # noqa: C901
)

# TODO: can we move this to the editable object instead of here?
# TODO: consider if values_providers could solve & replace the specific conditions below
if edit_mode == EditModes.EDIT:
if editable.implementation_type == WebFormFieldImplementationType.SELECT_MY_ORGANIZATIONS:
if editable.values_provider:
if request is None:
raise ValueError("Request is required when resolving a 'editable values provider'")
editable.form_options = await editable.values_provider.get_values(request)
elif editable.implementation_type == WebFormFieldImplementationType.SELECT_MY_ORGANIZATIONS:
if organizations_service is None:
raise ValueError("Organization service is required when resolving an organization")
my_organizations = await organizations_service.get_organizations_for_user(user_id=user_id)
Expand Down Expand Up @@ -427,11 +477,13 @@ def resolve_editable_path(
full_resource_path=full_resource_path,
relative_resource_path=relative_resource_path,
implementation_type=editable.implementation_type,
values_provider=editable.values_provider,
couples=couples,
children=children,
converter=editable.converter,
enforcer=editable.enforcer,
validator=editable.validator,
hooks=editable.hooks,
)

editables_resolved: list[ResolvedEditable] = []
Expand All @@ -445,6 +497,7 @@ def resolve_editable_path(
return {editable.full_resource_path: editable for editable in editables_resolved}


# TODO: this probably should be a method of ResolvedEditable
async def save_editable( # noqa: C901
editable: ResolvedEditable,
editable_context: dict[str, Any],
Expand All @@ -467,7 +520,6 @@ async def save_editable( # noqa: C901
new_value = editable_context.get("new_values", {}).get(editable.last_path_item())

# we validate on 'raw' form fields, so validation is done before the converter
# TODO: validate all fields (child and couples) before saving!
if editable.validator and editable.relative_resource_path is not None:
await editable.validator.validate(new_value, editable) # pyright: ignore[reportUnknownMemberType]

Expand Down Expand Up @@ -517,47 +569,3 @@ async def save_editable( # noqa: C901
raise AMTNotFound()

return editable


def set_path(obj: dict[str, Any] | object, path: str, value: typing.Any) -> None: # noqa: ANN401, C901
if not path:
raise ValueError("Path cannot be empty")

attrs = path.lstrip("/").split("/")
for attr in attrs[:-1]:
attr, index = extract_number_and_string(attr)
if isinstance(obj, dict):
obj = cast(dict[str, Any], obj)
if attr not in obj:
obj[attr] = {}
obj = obj[attr]
else:
if not hasattr(obj, attr): # pyright: ignore[reportUnknownArgumentType]
setattr(obj, attr, {}) # pyright: ignore[reportUnknownArgumentType]
obj = getattr(obj, attr) # pyright: ignore[reportUnknownArgumentType]
if obj and index is not None:
obj = cast(list[Any], obj)[index] # pyright: ignore[reportArgumentType, reportUnknownVariableType, reportUnknownArgumentType]

if isinstance(obj, dict):
obj[attrs[-1]] = value
else:
attr, index = extract_number_and_string(attrs[-1])
if index is not None:
cast(list[Any], getattr(obj, attr))[index] = value
else:
setattr(obj, attrs[-1], value)


def is_editable_resource(full_resource_path: str, editables: dict[str, ResolvedEditable]) -> bool:
return editables.get(replace_digits_in_brackets(full_resource_path), None) is not None


def is_parent_editable(editables: dict[str, ResolvedEditable], full_resource_path: str) -> bool:
full_resource_path = replace_digits_in_brackets(full_resource_path)
editable = editables.get(full_resource_path)
if editable is None:
print(full_resource_path + " : " + "false, no match")
return False
result = editable.implementation_type == WebFormFieldImplementationType.PARENT
print(full_resource_path + " : " + str(result))
return result
Loading
Loading