diff --git a/aana/alembic/versions/acb40dabc2c0_added_webhooks.py b/aana/alembic/versions/acb40dabc2c0_added_webhooks.py new file mode 100644 index 00000000..2958f63f --- /dev/null +++ b/aana/alembic/versions/acb40dabc2c0_added_webhooks.py @@ -0,0 +1,43 @@ +"""Added webhooks. + +Revision ID: acb40dabc2c0 +Revises: d40eba8ebc4c +Create Date: 2025-01-30 14:32:16.596842 + +""" +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "acb40dabc2c0" +down_revision: str | None = "d40eba8ebc4c" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade database to this revision from previous.""" + # fmt: off + op.create_table('webhooks', + sa.Column('id', sa.UUID(), nullable=False, comment='Webhook ID'), + sa.Column('user_id', sa.String(), nullable=True, comment='The user ID associated with the webhook'), + sa.Column('url', sa.String(), nullable=False, comment='The URL to which the webhook will send requests'), + sa.Column('events', sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), 'postgresql'), nullable=False, comment='List of events the webhook is subscribed to. If None, the webhook is subscribed to all events.'), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is inserted'), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is updated'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_webhooks')) + ) + with op.batch_alter_table('webhooks', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_webhooks_user_id'), ['user_id'], unique=False) + # fmt: on + + +def downgrade() -> None: + """Downgrade database from this revision to previous.""" + with op.batch_alter_table("webhooks", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_webhooks_user_id")) + + op.drop_table("webhooks") diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 413b0be6..70c3745c 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -8,7 +8,6 @@ from inspect import isasyncgenfunction from typing import Annotated, Any, get_origin -import orjson from fastapi import FastAPI, Form, Query, Request from fastapi.responses import StreamingResponse from pydantic import ConfigDict, Field, ValidationError, create_model @@ -19,7 +18,7 @@ from aana.api.event_handlers.event_manager import EventManager from aana.api.exception_handler import custom_exception_handler from aana.api.responses import AanaJSONResponse -from aana.api.security import check_admin_permissions +from aana.api.security import require_admin_access from aana.configs.settings import settings as aana_settings from aana.core.models.api_service import ApiKey from aana.core.models.exception import ExceptionResponseModel @@ -270,11 +269,11 @@ async def route_func_body( # noqa: C901 if event_manager: event_manager.handle(bound_path, defer=defer) - # Parse json data from the body - body = orjson.loads(body) + # parse form data as a pydantic model and validate it + data = RequestModel.model_validate_json(body) # Add api_key_info to the body if API service is enabled - api_key_info: dict = {} + api_key_info: ApiKey | None = None if aana_settings.api_service.enabled: api_key_info = request.state.api_key_info api_key_field = next( @@ -286,10 +285,7 @@ async def route_func_body( # noqa: C901 None, ) if api_key_field: - body[api_key_field] = api_key_info - - # parse form data as a pydantic model and validate it - data = RequestModel.model_validate(body) + setattr(data, api_key_field, api_key_info) # if the input requires file upload, add the files to the data if files: @@ -314,7 +310,7 @@ async def route_func_body( # noqa: C901 task = TaskRepository(session).save( endpoint=bound_path, data=data_dict, - user_id=api_key_info.get("user_id"), + user_id=api_key_info.user_id if api_key_info else None, ) return AanaJSONResponse(content={"task_id": str(task.id)}) @@ -349,7 +345,7 @@ async def route_func( ), ): if aana_settings.api_service.enabled and self.admin_required: - check_admin_permissions(request) + require_admin_access(request) if self.defer_option == DeferOption.ALWAYS: defer = True diff --git a/aana/api/app.py b/aana/api/app.py index 2e300a06..ad006a08 100644 --- a/aana/api/app.py +++ b/aana/api/app.py @@ -1,4 +1,5 @@ from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from aana.api.exception_handler import ( @@ -18,6 +19,7 @@ app = FastAPI() app.add_exception_handler(ValidationError, validation_exception_handler) +app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(Exception, aana_exception_handler) @@ -48,7 +50,7 @@ async def api_key_check(request: Request, call_next): if not api_key_info.is_subscription_active: raise InactiveSubscription(key=api_key) - request.state.api_key_info = api_key_info.to_dict() + request.state.api_key_info = api_key_info.to_model() response = await call_next(request) return response diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index ba318d8c..8ca942ff 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -1,22 +1,26 @@ import json import time -from typing import Annotated, Any +from typing import Any from uuid import UUID, uuid4 import orjson import ray -from fastapi import Depends +from fastapi import APIRouter from fastapi.openapi.utils import get_openapi from fastapi.responses import StreamingResponse from ray import serve -from sqlalchemy.orm import Session from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema from aana.api.app import app from aana.api.event_handlers.event_manager import EventManager from aana.api.exception_handler import custom_exception_handler from aana.api.responses import AanaJSONResponse -from aana.api.security import AdminRequired +from aana.api.security import AdminAccessDependency +from aana.api.webhook import ( + WebhookEventType, + trigger_task_webhooks, +) +from aana.api.webhook import router as webhook_router from aana.configs.settings import settings as aana_settings from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog @@ -25,16 +29,7 @@ from aana.deployments.aana_deployment_handle import AanaDeploymentHandle from aana.storage.models.task import Status as TaskStatus from aana.storage.repository.task import TaskRepository -from aana.storage.session import get_session - - -def get_db(): - """Get a database session.""" - db = get_session() - try: - yield db - finally: - db.close() +from aana.storage.session import GetDbDependency, get_session @serve.deployment(ray_actor_options={"num_cpus": 0.1}) @@ -45,7 +40,11 @@ class RequestHandler: ready = False def __init__( - self, app_name: str, endpoints: list[Endpoint], deployments: list[str] + self, + app_name: str, + endpoints: list[Endpoint], + deployments: list[str], + routers: list[APIRouter] | None = None, ): """Constructor. @@ -53,11 +52,19 @@ def __init__( app_name (str): The name of the application. endpoints (dict): List of endpoints for the request. deployments (list[str]): List of deployment names for the app. + routers (list[APIRouter]): List of FastAPI routers to include in the app. """ self.app_name = app_name self.endpoints = endpoints self.deployments = deployments + # Include the webhook router + app.include_router(webhook_router) + # Include the custom routers + if routers is not None: + for router in routers: + app.include_router(router) + self.event_manager = EventManager() self.custom_schemas: dict[str, dict] = {} for endpoint in self.endpoints: @@ -121,7 +128,8 @@ async def execute_task(self, task_id: str | UUID) -> Any: path = task.endpoint kwargs = task.data - task_repo.update_status(task_id, TaskStatus.RUNNING, 0) + task = task_repo.update_status(task_id, TaskStatus.RUNNING, 0) + await trigger_task_webhooks(WebhookEventType.TASK_STARTED, task) for e in self.endpoints: if e.path == path: @@ -139,16 +147,18 @@ async def execute_task(self, task_id: str | UUID) -> Any: out = await endpoint.run(**kwargs) with get_session() as session: - TaskRepository(session).update_status( + task = TaskRepository(session).update_status( task_id, TaskStatus.COMPLETED, 100, out ) + await trigger_task_webhooks(WebhookEventType.TASK_COMPLETED, task) except Exception as e: error_response = custom_exception_handler(None, e) error = orjson.loads(error_response.body) with get_session() as session: - TaskRepository(session).update_status( + task = TaskRepository(session).update_status( task_id, TaskStatus.FAILED, 0, error ) + await trigger_task_webhooks(WebhookEventType.TASK_FAILED, task) else: return out finally: @@ -160,9 +170,7 @@ async def execute_task(self, task_id: str | UUID) -> Any: description="Get the task status by task ID.", include_in_schema=aana_settings.task_queue.enabled, ) - async def get_task_endpoint( - self, task_id: str, db: Annotated[Session, Depends(get_db)] - ) -> TaskInfo: + async def get_task_endpoint(self, task_id: str, db: GetDbDependency) -> TaskInfo: """Get the task with the given ID. Args: @@ -186,9 +194,7 @@ async def get_task_endpoint( description="Delete the task by task ID.", include_in_schema=aana_settings.task_queue.enabled, ) - async def delete_task_endpoint( - self, task_id: str, db: Annotated[Session, Depends(get_db)] - ) -> TaskId: + async def delete_task_endpoint(self, task_id: str, db: GetDbDependency) -> TaskId: """Delete the task with the given ID. Args: @@ -288,7 +294,7 @@ async def _async_chat_completions( } @app.get("/api/status", response_model=SDKStatusResponse) - async def status(self, is_admin: AdminRequired) -> SDKStatusResponse: + async def status(self, is_admin: AdminAccessDependency) -> SDKStatusResponse: """The endpoint for checking the status of the application.""" app_names = [ self.app_name, diff --git a/aana/api/security.py b/aana/api/security.py index 14c99d75..92649470 100644 --- a/aana/api/security.py +++ b/aana/api/security.py @@ -3,11 +3,12 @@ from fastapi import Depends, Request from aana.configs.settings import settings as aana_settings +from aana.core.models.api_service import ApiKey from aana.exceptions.api_service import AdminOnlyAccess -def check_admin_permissions(request: Request): - """Check if the user is an admin. +def require_admin_access(request: Request) -> bool: + """Check if the user is an admin. If not, raise an exception. Args: request (Request): The request object @@ -16,20 +17,29 @@ def check_admin_permissions(request: Request): AdminOnlyAccess: If the user is not an admin """ if aana_settings.api_service.enabled: - api_key_info = request.state.api_key_info - is_admin = api_key_info.get("is_admin", False) + api_key_info: ApiKey = request.state.api_key_info + is_admin = api_key_info.is_admin if api_key_info else False if not is_admin: raise AdminOnlyAccess() + return True -class AdminCheck: - """Dependency to check if the user is an admin.""" +def extract_api_key_info(request: Request) -> ApiKey | None: + """Get the API key info dependency.""" + return getattr(request.state, "api_key_info", None) - async def __call__(self, request: Request) -> bool: - """Check if the user is an admin.""" - check_admin_permissions(request) - return True +def extract_user_id(request: Request) -> str | None: + """Get the user ID dependency.""" + api_key_info = extract_api_key_info(request) + return api_key_info.user_id if api_key_info else None -AdminRequired = Annotated[bool, Depends(AdminCheck())] -""" Annotation to check if the user is an admin. If not, it will raise an exception. """ + +AdminAccessDependency = Annotated[bool, Depends(require_admin_access)] +""" Dependency to check if the user is an admin. If not, it will raise an exception. """ + +UserIdDependency = Annotated[str | None, Depends(extract_user_id)] +""" Dependency to get the user ID. """ + +ApiKeyInfoDependency = Annotated[ApiKey | None, Depends(extract_api_key_info)] +""" Dependency to get the API key info. """ diff --git a/aana/api/webhook.py b/aana/api/webhook.py new file mode 100644 index 00000000..1d8e99b9 --- /dev/null +++ b/aana/api/webhook.py @@ -0,0 +1,284 @@ +import asyncio +import hashlib +import hmac +import json +import logging +from typing import Any + +import httpx +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, ConfigDict, Field, HttpUrl +from tenacity import retry, stop_after_attempt, wait_exponential + +from aana.api.security import UserIdDependency +from aana.configs.settings import settings as aana_settings +from aana.storage.models.api_key import ApiKeyEntity +from aana.storage.models.task import TaskEntity +from aana.storage.models.webhook import WebhookEntity, WebhookEventType +from aana.storage.repository.webhook import WebhookRepository +from aana.storage.session import GetDbDependency, get_session + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["webhooks"]) + + +# Request models + + +class WebhookRegistrationRequest(BaseModel): + """Request to register a webhook.""" + + url: HttpUrl = Field( + ..., description="The URL to which the webhook will send requests." + ) + events: list[WebhookEventType] = Field( + None, + description="The events to subscribe to. If None, the webhook is subscribed to all events.", + ) + + model_config = ConfigDict(extra="forbid") + + +class WebhookUpdateRequest(BaseModel): + """Request to update a webhook.""" + + url: HttpUrl | None = Field(None, description="New URL for the webhook.") + events: list[WebhookEventType] | None = Field( + None, description="New list of events to subscribe to." + ) + + model_config = ConfigDict(extra="forbid") + + +# Response models + + +class WebhookRegistrationResponse(BaseModel): + """Response for a webhook registration.""" + + id: str | None + message: str + + +class WebhookResponse(BaseModel): + """Response for a webhook registration.""" + + id: str = Field(..., description="The webhook ID.") + url: HttpUrl = Field( + ..., description="The URL to which the webhook will send requests." + ) + events: list[WebhookEventType] = Field( + ..., description="The events that the webhook is subscribed to." + ) + + @classmethod + def from_entity(cls, webhook: WebhookEntity) -> "WebhookResponse": + """Create a WebhookResponse from a WebhookEntity.""" + return WebhookResponse( + id=str(webhook.id), + url=str(webhook.url), + events=webhook.events, + ) + + +class WebhookListResponse(BaseModel): + """Response for a list of webhooks.""" + + webhooks: list[WebhookResponse] = Field(..., description="The list of webhooks.") + + +# Webhook Models + + +class TaskStatusChangeWebhookPayload(BaseModel): + """Payload for a task status change webhook.""" + + task_id: str + status: str + result: Any | None + num_retries: int + + +class WebhookBody(BaseModel): + """Body for a task status change webhook.""" + + event: WebhookEventType + payload: TaskStatusChangeWebhookPayload + + +# Webhook functions + + +def generate_hmac_signature(body: dict, user_id: str | None) -> str: + """Generate HMAC signature for a payload for a given user. + + Args: + body (dict): The webhook body. + user_id (str | None): The user ID associated with the payload. + + Returns: + str: The generated HMAC signature. + """ + # Use the default secret if no user ID is provided + secret = aana_settings.webhook.hmac_secret + # Get the user-specific secret if a user ID is provided + if user_id: + with get_session() as session: + api_key_info = ( + session.query(ApiKeyEntity).filter_by(user_id=user_id).first() + ) + if api_key_info and api_key_info.hmac_secret: + secret = api_key_info.hmac_secret + + payload_str = json.dumps(body, separators=(",", ":")) + return hmac.new(secret.encode(), payload_str.encode(), hashlib.sha256).hexdigest() + + +@retry( + stop=stop_after_attempt(aana_settings.webhook.retry_attempts), + wait=wait_exponential(), + reraise=True, +) +async def send_webhook_request(url: str, body: dict, headers: dict): + """Send a webhook request with retries. + + Args: + url (str): The webhook URL. + body (dict): The body of the request. + headers (dict): The headers to include in the request. + """ + async with httpx.AsyncClient(timeout=1.0) as client: + response = await client.post(url, json=body, headers=headers) + response.raise_for_status() + + +async def send_webhook_request_with_retry(url: str, body: dict, headers: dict): + """Send a webhook request with retries. + + Args: + url (str): The webhook URL. + body (dict): The body of the request. + headers (dict): The headers to include in the request. + """ + try: + await send_webhook_request(url, body, headers) + except Exception: + logger.exception(f"Failed to send webhook request to {url}.") + + +async def trigger_webhooks( + event: WebhookEventType, body: WebhookBody, user_id: str | None +): + """Trigger webhooks for an event. + + Args: + event (WebhookEventType): The event type. + body (WebhookBody): The body of the webhook request. + user_id (str | None): The user ID associated with the event. + """ + with get_session() as session: + webhook_repo = WebhookRepository(session) + webhooks = webhook_repo.get_webhooks(user_id, event) + body_dict = body.model_dump() + + for webhook in webhooks: + signature = generate_hmac_signature(body_dict, user_id) + headers = {"X-Signature": signature} + asyncio.create_task( # noqa: RUF006 + send_webhook_request_with_retry(webhook.url, body_dict, headers) + ) + + +async def trigger_task_webhooks(event: WebhookEventType, task: TaskEntity): + """Trigger webhooks for a task event. + + Args: + event (WebhookEventType): The event type. + task (TaskEntity): The task entity. + """ + payload = TaskStatusChangeWebhookPayload( + task_id=str(task.id), + status=task.status, + result=task.result, + num_retries=task.num_retries, + ) + body = WebhookBody(event=event, payload=payload) + await trigger_webhooks(event, body, task.user_id) + + +# Webhook endpoints + + +@router.post("/webhooks", status_code=201) +async def create_webhook( + request: WebhookRegistrationRequest, + db: GetDbDependency, + user_id: UserIdDependency, +) -> WebhookResponse: + """This endpoint is used to register a webhook.""" + webhook_repo = WebhookRepository(db) + webhook = WebhookEntity( + user_id=user_id, + url=str(request.url), + events=request.events, + ) + webhook = webhook_repo.save(webhook) + return WebhookResponse.from_entity(webhook) + + +@router.get("/webhooks") +async def list_webhooks( + db: GetDbDependency, user_id: UserIdDependency +) -> WebhookListResponse: + """This endpoint is used to list all registered webhooks.""" + webhook_repo = WebhookRepository(db) + webhooks = webhook_repo.get_webhooks(user_id, None) + return WebhookListResponse( + webhooks=[WebhookResponse.from_entity(webhook) for webhook in webhooks] + ) + + +@router.get("/webhooks/{webhook_id}") +async def get_webhook( + webhook_id: str, db: GetDbDependency, user_id: UserIdDependency +) -> WebhookResponse: + """This endpoint is used to fetch a webhook by ID.""" + webhook_repo = WebhookRepository(db) + webhook = webhook_repo.read(webhook_id, check=False) + if not webhook or webhook.user_id != user_id: + raise HTTPException(status_code=404, detail="Webhook not found") + return WebhookResponse.from_entity(webhook) + + +@router.put("/webhooks/{webhook_id}") +async def update_webhook( + webhook_id: str, + request: WebhookUpdateRequest, + db: GetDbDependency, + user_id: UserIdDependency, +) -> WebhookResponse: + """This endpoint is used to update a webhook.""" + webhook_repo = WebhookRepository(db) + webhook = webhook_repo.read(webhook_id, check=False) + if not webhook or webhook.user_id != user_id: + raise HTTPException(status_code=404, detail="Webhook not found") + if request.url is not None: + webhook.url = str(request.url) + if request.events is not None: + webhook.events = request.events + webhook = webhook_repo.save(webhook) + return WebhookResponse.from_entity(webhook) + + +@router.delete("/webhooks/{webhook_id}") +async def delete_webhook( + webhook_id: str, db: GetDbDependency, user_id: UserIdDependency +) -> WebhookResponse: + """This endpoint is used to delete a webhook.""" + webhook_repo = WebhookRepository(db) + webhook = webhook_repo.read(webhook_id, check=False) + if not webhook or webhook.user_id != user_id: + raise HTTPException(status_code=404, detail="Webhook not found") + webhook = webhook_repo.delete(webhook.id) + return WebhookResponse.from_entity(webhook) diff --git a/aana/configs/settings.py b/aana/configs/settings.py index e960ea27..d3e90d9c 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -56,6 +56,18 @@ class ApiServiceSettings(BaseModel): lago_api_key: str | None = None +class WebhookSettings(BaseModel): + """A pydantic model for webhook settings. + + Attributes: + retry_attempts (int): The number of retry attempts for webhook delivery. + hmac_secret (str): The secret key for HMAC signature generation. + """ + + retry_attempts: int = 5 + hmac_secret: str = "webhook_secret" + + class Settings(BaseSettings): """A pydantic model for SDK settings. @@ -95,6 +107,8 @@ class Settings(BaseSettings): datastore_config=SQLiteConfig(path="/var/lib/aana_api_service_data"), ) + webhook: WebhookSettings = WebhookSettings() + @model_validator(mode="after") def setup_resource_directories(self): """Create the resource directories if they do not exist.""" diff --git a/aana/core/models/api_service.py b/aana/core/models/api_service.py index ef2097ec..3768467e 100644 --- a/aana/core/models/api_service.py +++ b/aana/core/models/api_service.py @@ -5,12 +5,23 @@ class ApiKey(BaseModel): - """Pydantic model for API key entity.""" + """Pydantic model for API key entity. + + Attributes: + api_key (str): The API key. + user_id (str): ID of the user who owns this API key. + subscription_id (str): ID of the associated subscription. + is_subscription_active (bool): Whether the subscription is active (credits are available). + is_admin (bool): Whether the user is an admin. + hmac_secret (str | None): The secret key for HMAC signature generation. + """ api_key: str user_id: str subscription_id: str is_subscription_active: bool + is_admin: bool + hmac_secret: str | None ApiKeyType = SkipJsonSchema[Annotated[ApiKey, Field(default=None)]] diff --git a/aana/exceptions/runtime.py b/aana/exceptions/runtime.py index 72d3d2a2..825b8409 100644 --- a/aana/exceptions/runtime.py +++ b/aana/exceptions/runtime.py @@ -182,3 +182,20 @@ def __init__(self, filename: str): def __reduce__(self): """Used for pickling.""" return (self.__class__, (self.filename,)) + + +class InvalidWebhookEventType(BaseException): + """Exception raised when an invalid webhook event type is provided.""" + + def __init__(self, event_type: str): + """Initialize the exception. + + Args: + event_type (str): the invalid event type + """ + super().__init__(event_type=event_type) + self.event_type = event_type + + def __reduce__(self): + """Used for pickling.""" + return (self.__class__, (self.event_type,)) diff --git a/aana/sdk.py b/aana/sdk.py index 1c56518e..252458e8 100644 --- a/aana/sdk.py +++ b/aana/sdk.py @@ -7,6 +7,7 @@ import ray import yaml +from fastapi import APIRouter from ray import serve from ray.autoscaler.v2.schema import ResourceDemand from ray.autoscaler.v2.sdk import get_cluster_status @@ -50,6 +51,7 @@ def __init__( self.migration_func = migration_func self.endpoints: dict[str, Endpoint] = {} self.deployments: dict[str, Deployment] = {} + self.routers: dict[str, APIRouter] = {} if retryable_exceptions is None: self.retryable_exceptions = [ @@ -258,6 +260,7 @@ def get_main_app(self) -> Application: app_name=self.name, endpoints=self.endpoints.values(), deployments=list(self.deployments.keys()), + routers=list(self.routers.values()), ) def register_endpoint( @@ -300,6 +303,24 @@ def unregister_endpoint(self, name: str): if name in self.endpoints: del self.endpoints[name] + def register_router(self, name: str, router: APIRouter): + """Register a FastAPI router. + + Args: + name (str): The name of the router. + router (APIRouter): The instance of the APIRouter to be registered. + """ + self.routers[name] = router + + def unregister_router(self, name: str): + """Unregister a FastAPI router. + + Args: + name (str): The name of the router to be unregistered. + """ + if name in self.routers: + del self.routers[name] + def wait_for_deployment(self): # noqa: C901 """Wait for the deployment to complete.""" consecutive_resource_unavailable = 0 diff --git a/aana/storage/models/__init__.py b/aana/storage/models/__init__.py index 8f47d600..0ccaecfd 100644 --- a/aana/storage/models/__init__.py +++ b/aana/storage/models/__init__.py @@ -15,6 +15,7 @@ from aana.storage.models.task import TaskEntity from aana.storage.models.transcript import TranscriptEntity from aana.storage.models.video import VideoEntity +from aana.storage.models.webhook import WebhookEntity __all__ = [ "BaseEntity", diff --git a/aana/storage/models/api_key.py b/aana/storage/models/api_key.py index db7d38b8..e1b974d3 100644 --- a/aana/storage/models/api_key.py +++ b/aana/storage/models/api_key.py @@ -1,6 +1,8 @@ from sqlalchemy import Boolean from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from aana.core.models.api_service import ApiKey + class ApiServiceBase(DeclarativeBase): """Base class.""" @@ -32,6 +34,9 @@ class ApiKeyEntity(ApiServiceBase): default=True, comment="Whether the subscription is active (credits are available)", ) + hmac_secret: Mapped[str] = mapped_column( + nullable=True, comment="The secret key for HMAC signature generation" + ) def __repr__(self) -> str: """String representation of the API key.""" @@ -44,6 +49,13 @@ def __repr__(self) -> str: f"is_subscription_active={self.is_subscription_active})>" ) - def to_dict(self) -> dict: + def to_model(self) -> ApiKey: """Convert the object to a dictionary.""" - return {c.name: getattr(self, c.name) for c in self.__table__.columns} + return ApiKey( + api_key=self.api_key, + user_id=self.user_id, + is_admin=self.is_admin, + subscription_id=self.subscription_id, + is_subscription_active=self.is_subscription_active, + hmac_secret=self.hmac_secret, + ) diff --git a/aana/storage/models/webhook.py b/aana/storage/models/webhook.py new file mode 100644 index 00000000..dbc910d0 --- /dev/null +++ b/aana/storage/models/webhook.py @@ -0,0 +1,45 @@ +import uuid +from enum import Enum + +from sqlalchemy import JSON, UUID +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from aana.storage.models.base import BaseEntity, TimeStampEntity + + +class WebhookEventType(str, Enum): + """Enum for webhook event types.""" + + TASK_COMPLETED = "task.completed" + TASK_FAILED = "task.failed" + TASK_STARTED = "task.started" + + +class WebhookEntity(BaseEntity, TimeStampEntity): + """Table for webhook items.""" + + __tablename__ = "webhooks" + + id: Mapped[uuid.UUID] = mapped_column( + UUID, primary_key=True, default=uuid.uuid4, comment="Webhook ID" + ) + user_id: Mapped[str | None] = mapped_column( + nullable=True, index=True, comment="The user ID associated with the webhook" + ) + url: Mapped[str] = mapped_column( + nullable=False, comment="The URL to which the webhook will send requests" + ) + events: Mapped[list[str]] = mapped_column( + JSON().with_variant(JSONB, "postgresql"), + nullable=False, + comment="List of events the webhook is subscribed to. If the list is empty, the webhook is subscribed to all events.", + ) + + def __repr__(self) -> str: + """String representation of the webhook.""" + return ( + f"" + ) diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index d4c38796..be01d467 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -241,7 +241,7 @@ def update_status( progress: int | None = None, result: Any = None, commit: bool = True, - ): + ) -> TaskEntity: """Update the status of a task. Args: @@ -250,6 +250,9 @@ def update_status( progress (int | None): The progress. If None, the progress will not be updated. result (Any): The result. commit (bool): Whether to commit the transaction. + + Returns: + TaskEntity: The updated task. """ task = self.read(task_id) if status == TaskStatus.COMPLETED or status == TaskStatus.FAILED: @@ -263,6 +266,7 @@ def update_status( task.result = result if commit: self.session.commit() + return task def get_active_tasks(self) -> list[TaskEntity]: """Fetches all active tasks. diff --git a/aana/storage/repository/webhook.py b/aana/storage/repository/webhook.py new file mode 100644 index 00000000..727e1a60 --- /dev/null +++ b/aana/storage/repository/webhook.py @@ -0,0 +1,113 @@ +from uuid import UUID + +from sqlalchemy import func, or_ +from sqlalchemy.orm import Session + +from aana.exceptions.runtime import InvalidWebhookEventType +from aana.storage.models.webhook import WebhookEntity, WebhookEventType +from aana.storage.repository.base import BaseRepository + + +class WebhookRepository(BaseRepository[WebhookEntity]): + """Repository for webhooks.""" + + def __init__(self, session: Session): + """Constructor.""" + super().__init__(session, WebhookEntity) + + def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: + """Reads a single webhook from the database. + + Args: + item_id (str | UUID): ID of the webhook to retrieve. + check (bool): whether to raise if the entity is not found (defaults to True). + + Returns: + The corresponding entity from the database if found. + + Raises: + NotFoundException if the entity is not found and `check` is True. + """ + if isinstance(item_id, str): + item_id = UUID(item_id) + return super().read(item_id, check) + + def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: + """Delete a webhook from the database. + + Args: + item_id (str | UUID): The ID of the webhook to delete. + check (bool): whether to raise if the entity is not found (defaults to True). + + Returns: + WebhookEntity: The deleted webhook. + + Raises: + NotFoundException: The id does not correspond to a record in the database. + """ + if isinstance(item_id, str): + item_id = UUID(item_id) + return super().delete(item_id, check) + + def save(self, webhook: WebhookEntity) -> WebhookEntity: + """Save a webhook to the database. + + Args: + webhook (WebhookEntity): The webhook to save. + + Returns: + WebhookEntity: The saved webhook. + """ + if webhook.events is None: + webhook.events = [] + + # Check if events are in WebhookEventType enum + if webhook.events: + try: + webhook.events = [WebhookEventType(event) for event in webhook.events] + except ValueError as e: + raise InvalidWebhookEventType(event_type=e.args[0]) from e + + self.session.add(webhook) + self.session.commit() + return webhook + + def get_webhooks( + self, user_id: str | None, event_type: str | None = None + ) -> list[WebhookEntity]: + """Get webhooks for a user. + + Args: + user_id (str | None): The user ID. If None, get system-wide webhooks. + event_type (str | None): Filter webhooks by event type. If None, return all webhooks. + + Returns: + List[WebhookEntity]: The list of webhooks. + """ + query = self.session.query(WebhookEntity).filter_by(user_id=user_id) + if event_type: + if self.session.bind.dialect.name == "postgresql": + query = query.filter( + or_( + WebhookEntity.events.op("@>")([event_type]), + WebhookEntity.events == [], + ) + ) + elif self.session.bind.dialect.name == "sqlite": + events_func = func.json_each(WebhookEntity.events).table_valued( + "value", joins_implicitly=True + ) + query = query.filter( + or_( + self.session.query(events_func) + .filter(events_func.c.value == event_type) + .exists(), + WebhookEntity.events == "[]", + ) + ) + else: + raise NotImplementedError( + f"Filtering by event type is not supported for {self.session.bind.dialect.name}" + ) + + return query.all() diff --git a/aana/storage/session.py b/aana/storage/session.py index 573eca32..19e6fcce 100644 --- a/aana/storage/session.py +++ b/aana/storage/session.py @@ -1,10 +1,13 @@ +from typing import Annotated + +from fastapi import Depends from sqlalchemy.orm import Session, sessionmaker from aana.configs.settings import settings from aana.storage.models.api_key import ApiServiceBase from aana.storage.models.base import BaseEntity -__all__ = ["get_session"] +__all__ = ["get_session", "get_db"] engine = settings.db_config.get_engine() @@ -28,3 +31,16 @@ def get_session() -> Session: Session: SQLAlchemy Session object. """ return SessionLocal() + + +def get_db(): + """Get a database session.""" + db = get_session() + try: + yield db + finally: + db.close() + + +GetDbDependency = Annotated[Session, Depends(get_db)] +""" Dependency to get a database session. """ diff --git a/aana/tests/db/datastore/test_webhook_repo.py b/aana/tests/db/datastore/test_webhook_repo.py new file mode 100644 index 00000000..fb4ba44e --- /dev/null +++ b/aana/tests/db/datastore/test_webhook_repo.py @@ -0,0 +1,99 @@ +# ruff: noqa: S101 + +import pytest + +from aana.exceptions.runtime import InvalidWebhookEventType +from aana.storage.models.webhook import WebhookEntity +from aana.storage.repository.webhook import WebhookRepository + + +@pytest.fixture +def webhook_entities(): + """Create webhook entities for testing.""" + # fmt: off + webhooks = [ + WebhookEntity(user_id="user1", url="https://example1.com", events=["task.completed", "task.failed"]), + WebhookEntity(user_id="user1", url="https://example2.com", events=["task.failed"]), + WebhookEntity(user_id="user2", url="https://example3.com", events=["task.completed", "task.failed"]), + WebhookEntity(user_id=None, url="https://example4.com", events=["task.failed"]), # System webhook for task.failed + WebhookEntity(user_id=None, url="https://example5.com", events=None), # System webhook for all events + ] + return webhooks + # fmt: on + + +def test_save_webhook(db_session, webhook_entities): + """Test saving a webhook.""" + webhook_repo = WebhookRepository(db_session) + for webhook in webhook_entities: + saved_webhook = webhook_repo.save(webhook) + assert saved_webhook.id is not None + assert saved_webhook.user_id == webhook.user_id + assert saved_webhook.url == webhook.url + assert saved_webhook.events == webhook.events + + retrieved_webhook = webhook_repo.read(saved_webhook.id) + assert retrieved_webhook.id == saved_webhook.id + assert retrieved_webhook.user_id == saved_webhook.user_id + assert retrieved_webhook.url == saved_webhook.url + + # Test saving a webhook with invalid event type + with pytest.raises(InvalidWebhookEventType): + webhook = WebhookEntity( + user_id="user1", url="https://example7.com", events=["invalid.event"] + ) + webhook_repo.save(webhook) + + +def test_get_webhooks(db_session, webhook_entities): + """Test fetching webhooks.""" + webhook_repo = WebhookRepository(db_session) + + # Save webhooks + for webhook in webhook_entities: + webhook_repo.save(webhook) + + # Test webhooks with user ID set to None + webhooks = webhook_repo.get_webhooks(user_id=None) + assert len(webhooks) == 2 + assert {webhook.url for webhook in webhooks} == { + "https://example4.com", + "https://example5.com", + } + + # Test webhooks with user ID set to "user1" + webhooks = webhook_repo.get_webhooks(user_id="user1") + assert len(webhooks) == 2 + assert {webhook.url for webhook in webhooks} == { + "https://example1.com", + "https://example2.com", + } + + # Test webhooks with user ID set to "user2" + webhooks = webhook_repo.get_webhooks(user_id="user2") + assert len(webhooks) == 1 + + # Test webhooks with user ID set to "user3" + webhooks = webhook_repo.get_webhooks(user_id="user3") + assert len(webhooks) == 0 + + # Test webhooks with user ID set to None and event type set to "task.failed" + webhooks = webhook_repo.get_webhooks(user_id=None, event_type="task.failed") + assert len(webhooks) == 2 + assert {webhook.url for webhook in webhooks} == { + "https://example4.com", + "https://example5.com", + } + + # Test webhooks with user ID set to "user1" and event type set to "task.failed" + webhooks = webhook_repo.get_webhooks(user_id="user1", event_type="task.failed") + assert len(webhooks) == 2 + assert {webhook.url for webhook in webhooks} == { + "https://example1.com", + "https://example2.com", + } + + # Test webhooks with user ID set to "user1" and event type set to "task.completed" + webhooks = webhook_repo.get_webhooks(user_id="user1", event_type="task.completed") + assert len(webhooks) == 1 + assert {webhook.url for webhook in webhooks} == {"https://example1.com"} diff --git a/aana/tests/units/test_task_queue.py b/aana/tests/units/test_task_queue.py index 900d8744..3900fb93 100644 --- a/aana/tests/units/test_task_queue.py +++ b/aana/tests/units/test_task_queue.py @@ -116,17 +116,15 @@ def test_task_queue(create_app): # noqa: C901 aana_app = create_app(deployments, endpoints) port = aana_app.port - route_prefix = "" - # Check that the server is ready - response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready") + response = requests.get(f"http://localhost:{port}/api/ready") assert response.status_code == 200 assert response.json() == {"ready": True} # Test lowercase endpoint data = {"text": ["Hello World!", "This is a test."]} response = requests.post( - f"http://localhost:{port}{route_prefix}/lowercase", + f"http://localhost:{port}/lowercase", data={"body": json.dumps(data)}, ) assert response.status_code == 200 @@ -135,7 +133,7 @@ def test_task_queue(create_app): # noqa: C901 # Defer endpoint execution response = requests.post( - f"http://localhost:{port}{route_prefix}/lowercase?defer=True", + f"http://localhost:{port}/lowercase?defer=True", data={"body": json.dumps(data)}, ) assert response.status_code == 200 @@ -144,44 +142,36 @@ def test_task_queue(create_app): # noqa: C901 # Check the task status with timeout of 10 seconds start_time = time.time() while time.time() - start_time < 10: - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") task_status = response.json().get("status") result = response.json().get("result") if task_status == "completed": break time.sleep(0.1) - assert task_status == "completed" + assert task_status == "completed", response.text assert result == {"text": ["hello world!", "this is a test."]} # Delete the task - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/delete/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/delete/{task_id}") assert response.status_code == 200 assert response.json().get("task_id") == task_id # Check that the task is deleted - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") assert response.status_code == 404 assert response.json().get("error") == "NotFoundException" # Check non-existent task task_id = "d1b1b1b1-1b1b-1b1b-1b1b-1b1b1b1b1b1b" - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") assert response.status_code == 404 assert response.json().get("error") == "NotFoundException" # Test lowercase streaming endpoint data = {"text": ["Hello World!", "This is a test."]} response = requests.post( - f"http://localhost:{port}{route_prefix}/lowercase_stream", + f"http://localhost:{port}/lowercase_stream", data={"body": json.dumps(data)}, stream=True, ) @@ -195,7 +185,7 @@ def test_task_queue(create_app): # noqa: C901 # Test task queue with streaming endpoint data = {"text": ["Hello World!", "This is a test."]} response = requests.post( - f"http://localhost:{port}{route_prefix}/lowercase_stream?defer=True", + f"http://localhost:{port}/lowercase_stream?defer=True", data={"body": json.dumps(data)}, ) assert response.status_code == 200 @@ -204,16 +194,14 @@ def test_task_queue(create_app): # noqa: C901 # Check the task status with timeout of 10 seconds start_time = time.time() while time.time() - start_time < 10: - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") task_status = response.json().get("status") result = response.json().get("result") if task_status == "completed": break time.sleep(0.1) - assert task_status == "completed" + assert task_status == "completed", response.text assert [chunk["text"] for chunk in result] == lowercase_text # Send 30 tasks to the task queue @@ -221,7 +209,7 @@ def test_task_queue(create_app): # noqa: C901 for i in range(30): data = {"text": [f"Task {i}"]} response = requests.post( - f"http://localhost:{port}{route_prefix}/lowercase_stream?defer=True", + f"http://localhost:{port}/lowercase_stream?defer=True", data={"body": json.dumps(data)}, ) assert response.status_code == 200 @@ -234,9 +222,7 @@ def test_task_queue(create_app): # noqa: C901 for task_id in task_ids: if task_id in completed_tasks: continue - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") task_status = response.json().get("status") result = response.json().get("result") if task_status == "completed": @@ -248,9 +234,7 @@ def test_task_queue(create_app): # noqa: C901 # Check that all tasks are completed for task_id in task_ids: - response = requests.get( - f"http://localhost:{port}{route_prefix}/tasks/get/{task_id}" - ) + response = requests.get(f"http://localhost:{port}/tasks/get/{task_id}") response = response.json() task_status = response.get("status") assert task_status == "completed", response diff --git a/aana/tests/units/test_webhooks.py b/aana/tests/units/test_webhooks.py new file mode 100644 index 00000000..79adb183 --- /dev/null +++ b/aana/tests/units/test_webhooks.py @@ -0,0 +1,164 @@ +# ruff: noqa: S101, S113 +import hashlib +import hmac +import json +from typing import Annotated, TypedDict + +import requests +from pydantic import Field + +from aana.api.api_generation import Endpoint +from aana.configs.settings import settings as aana_settings + +TextList = Annotated[list[str], Field(description="List of text to lowercase.")] + + +class LowercaseEndpointOutput(TypedDict): + """The output of the lowercase endpoint.""" + + text: list[str] + + +class LowercaseEndpoint(Endpoint): + """Lowercase endpoint.""" + + async def run(self, text: TextList) -> LowercaseEndpointOutput: + """Lowercase the text. + + Args: + text (TextList): The list of text to lowercase + + Returns: + LowercaseEndpointOutput: The lowercase texts + """ + return {"text": [t.lower() for t in text]} + + +deployments = [] + +endpoints = [ + { + "name": "lowercase", + "path": "/lowercase", + "summary": "Lowercase text", + "endpoint_cls": LowercaseEndpoint, + } +] + + +def test_webhooks(create_app, httpserver): + """Test webhooks.""" + aana_app = create_app(deployments, endpoints) + + port = aana_app.port + + # Check that the server is ready + response = requests.get(f"http://localhost:{port}/api/ready") + assert response.status_code == 200, response.text + assert response.json() == {"ready": True} + + # Setup the webhook listener + def webhook_listener(request): + payload = request.json + + # Validate the HMAC signature + secret_key = aana_settings.webhook.hmac_secret + signature = request.headers.get("X-Signature") + payload_str = json.dumps(payload, separators=(",", ":")) + actual_signature = hmac.new( + secret_key.encode(), payload_str.encode(), hashlib.sha256 + ).hexdigest() + assert signature == actual_signature + assert payload["event"] == "task.completed" + + httpserver.expect_request("/webhooks").respond_with_handler(webhook_listener) + + # Register the webhook + data = { + "url": f"http://localhost:{httpserver.port}/webhooks", + "events": ["task.completed"], + } + response = requests.post( + f"http://localhost:{port}/webhooks", + json=data, + ) + assert response.status_code == 201, response.text + + # Test lowercase endpoint + data = {"text": ["Hello World!", "This is a test."]} + response = requests.post( + f"http://localhost:{port}/lowercase", + data={"body": json.dumps(data)}, + ) + assert response.status_code == 200, response.text + lowercase_text = response.json().get("text") + assert lowercase_text == ["hello world!", "this is a test."] + + # Defer endpoint execution + response = requests.post( + f"http://localhost:{port}/lowercase?defer=True", + data={"body": json.dumps(data)}, + ) + assert response.status_code == 200, response.text + + httpserver.check_assertions() + httpserver.check_handler_errors() + + +def test_webhook_crud(create_app): + """Test webhook CRUD operations.""" + aana_app = create_app(deployments, endpoints) + port = aana_app.port + base_url = f"http://localhost:{port}" + + # Clear existing webhooks + response = requests.get(f"{base_url}/webhooks") + assert response.status_code == 200 + webhooks = response.json()["webhooks"] + for webhook in webhooks: + response = requests.delete(f"{base_url}/webhooks/{webhook['id']}") + assert response.status_code == 200 + + # Test Create + webhook_data = { + "url": "http://example.com/webhook", + "events": ["task.completed", "task.failed"], + } + response = requests.post(f"{base_url}/webhooks", json=webhook_data) + assert response.status_code == 201 + webhook_id = response.json()["id"] + assert response.json()["url"] == webhook_data["url"] + assert response.json()["events"] == webhook_data["events"] + + # Test Read (List) + response = requests.get(f"{base_url}/webhooks") + assert response.status_code == 200 + webhooks = response.json()["webhooks"] + assert len(webhooks) == 1 + assert webhooks[0]["id"] == webhook_id + + # Test Read (Single) + response = requests.get(f"{base_url}/webhooks/{webhook_id}") + assert response.status_code == 200 + assert response.json()["id"] == webhook_id + assert response.json()["url"] == webhook_data["url"] + + # Test Update + update_data = {"url": "http://example.com/webhook2", "events": ["task.completed"]} + response = requests.put(f"{base_url}/webhooks/{webhook_id}", json=update_data) + assert response.status_code == 200 + assert response.json()["url"] == update_data["url"] + assert response.json()["events"] == update_data["events"] + + # Test Delete + response = requests.delete(f"{base_url}/webhooks/{webhook_id}") + assert response.status_code == 200 + + # Verify webhook is deleted + response = requests.get(f"{base_url}/webhooks/{webhook_id}") + assert response.status_code == 404 + + # Test validation for URL + response = requests.post(f"{base_url}/webhooks", json={"url": "invalid-url"}) + assert response.status_code == 422 + assert response.json()["error"] == "ValidationError" diff --git a/poetry.lock b/poetry.lock index 5f04981b..859d0f06 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5763,6 +5763,22 @@ tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] +[[package]] +name = "pytest-httpserver" +version = "1.1.1" +description = "pytest-httpserver is a httpserver for pytest" +optional = false +python-versions = ">=3.8" +groups = ["tests"] +markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" +files = [ + {file = "pytest_httpserver-1.1.1-py3-none-any.whl", hash = "sha256:aadc744bfac773a2ea93d05c2ef51fa23c087e3cc5dace3ea9d45cdd4bfe1fe8"}, + {file = "pytest_httpserver-1.1.1.tar.gz", hash = "sha256:e5c46c62c0aa65e5d4331228cb2cb7db846c36e429c3e74ca806f284806bf7c6"}, +] + +[package.dependencies] +Werkzeug = ">=2.0.0" + [[package]] name = "pytest-mock" version = "3.14.0" @@ -8509,6 +8525,25 @@ files = [ {file = "websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5"}, ] +[[package]] +name = "werkzeug" +version = "3.1.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.9" +groups = ["tests"] +markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" +files = [ + {file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"}, + {file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "wrapt" version = "1.17.2" @@ -8912,4 +8947,4 @@ vllm = ["outlines", "vllm"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "672d596ed660946588deac33967e77a0209b3e47364269b41ecd5bc4a87480b9" +content-hash = "0cb9a93cddd2784da8f52d3dc6cdb4834b9dee16b6cbce57d8459ffd13359ab7" diff --git a/pyproject.toml b/pyproject.toml index b01216bd..6627de33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ optional = true pytest-asyncio = "^0.23.6" pytest-dotenv = "^0.5.2" pytest-env = "^1.1.3" +pytest-httpserver = "^1.1.1" pytest-mock = "^3.12.0" pytest-postgresql = "^6.0.0" pytest-timeout = "^2.2.0" @@ -113,6 +114,7 @@ rapidfuzz = "^3.4.0" sentence-transformers = ">=2.6.1" sqlalchemy-utils = "^0.41.1" + [project.scripts] aana = "aana.cli:cli"