From 4030d70c59af9bd2f97063e66ff801ba93003d66 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Mon, 3 Feb 2025 12:22:51 +0000 Subject: [PATCH] Moved webhook handling in a separate router. Added optional routers to Aana SDK. --- aana/api/request_handler.py | 61 ++++++++--------------------- aana/api/webhook.py | 40 +++++++++++++++++-- aana/sdk.py | 21 ++++++++++ aana/storage/session.py | 11 +++++- aana/tests/units/test_task_queue.py | 4 +- 5 files changed, 86 insertions(+), 51 deletions(-) diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index 2de7d3ec..3faa5583 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -5,7 +5,7 @@ import orjson import ray -from fastapi import Depends +from fastapi import APIRouter, Depends from fastapi.openapi.utils import get_openapi from fastapi.responses import StreamingResponse from ray import serve @@ -18,10 +18,9 @@ from aana.api.responses import AanaJSONResponse from aana.api.webhook import ( WebhookEventType, - WebhookRegistrationRequest, - WebhookRegistrationResponse, trigger_task_webhooks, ) +from aana.api.webhook import router as webhook_router from aana.configs.settings import settings as aana_settings from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog @@ -29,19 +28,8 @@ from aana.core.models.task import TaskId, TaskInfo from aana.deployments.aana_deployment_handle import AanaDeploymentHandle from aana.storage.models.task import Status as TaskStatus -from aana.storage.models.webhook import WebhookEntity from aana.storage.repository.task import TaskRepository -from aana.storage.repository.webhook import WebhookRepository -from aana.storage.session import get_session - - -def get_db(): - """Get a database session.""" - db = get_session() - try: - yield db - finally: - db.close() +from aana.storage.session import get_db, get_session @serve.deployment(ray_actor_options={"num_cpus": 0.1}) @@ -52,7 +40,11 @@ class RequestHandler: ready = False def __init__( - self, app_name: str, endpoints: list[Endpoint], deployments: list[str] + self, + app_name: str, + endpoints: list[Endpoint], + deployments: list[str], + routers: list[APIRouter] | None = None, ): """Constructor. @@ -60,11 +52,19 @@ def __init__( app_name (str): The name of the application. endpoints (dict): List of endpoints for the request. deployments (list[str]): List of deployment names for the app. + routers (list[APIRouter]): List of FastAPI routers to include in the app. """ self.app_name = app_name self.endpoints = endpoints self.deployments = deployments + # Include the webhook router + app.include_router(webhook_router) + # Include the custom routers + if routers is not None: + for router in routers: + app.include_router(router) + self.event_manager = EventManager() self.custom_schemas: dict[str, dict] = {} for endpoint in self.endpoints: @@ -344,32 +344,3 @@ async def status(self) -> SDKStatusResponse: message=app_status_message, deployments=deployment_statuses, ) - - @app.post("/webhooks", status_code=201) - async def register_webhook( - self, - request: WebhookRegistrationRequest, - db: Annotated[Session, Depends(get_db)], - ) -> WebhookRegistrationResponse: - """Register a new webhook. - - Args: - request (WebhookRegistrationRequest): The webhook registration request. - db (Session): The database session. - - Returns: - WebhookRegistrationResponse: The response message. - """ - webhook_repo = WebhookRepository(db) - try: - webhook = WebhookEntity( - user_id=request.user_id, - url=request.url, - events=request.events, - ) - webhook_repo.save(webhook) - except Exception: - return WebhookRegistrationResponse(message="Failed to register webhook") - return WebhookRegistrationResponse( - id=str(webhook.id), message="Webhook registered successfully" - ) diff --git a/aana/api/webhook.py b/aana/api/webhook.py index bab30184..acdf988e 100644 --- a/aana/api/webhook.py +++ b/aana/api/webhook.py @@ -3,20 +3,25 @@ import hmac import json import logging +from typing import Annotated, Any import httpx +from fastapi import APIRouter, Depends from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy.orm import Session from tenacity import retry, stop_after_attempt, wait_exponential from aana.configs.settings import settings as aana_settings from aana.storage.models.api_key import ApiKeyEntity from aana.storage.models.task import TaskEntity -from aana.storage.models.webhook import WebhookEventType +from aana.storage.models.webhook import WebhookEntity, WebhookEventType from aana.storage.repository.webhook import WebhookRepository -from aana.storage.session import get_session +from aana.storage.session import get_db, get_session logger = logging.getLogger(__name__) +router = APIRouter(tags=["webhooks"]) + class WebhookRegistrationRequest(BaseModel): """Request to register a webhook.""" @@ -47,7 +52,7 @@ class TaskStatusChangeWebhookPayload(BaseModel): task_id: str status: str - result: dict | None + result: Any | None num_retries: int @@ -142,3 +147,32 @@ async def trigger_task_webhooks(event: WebhookEventType, task: TaskEntity): ) body = WebhookBody(event=event, payload=payload) await trigger_webhooks(event, body, task.user_id) + + +@router.post("/webhooks", status_code=201) +async def register_webhook( + request: WebhookRegistrationRequest, + db: Annotated[Session, Depends(get_db)], +) -> WebhookRegistrationResponse: + """Register a new webhook. + + Args: + request (WebhookRegistrationRequest): The webhook registration request. + db (Session): The database session. + + Returns: + WebhookRegistrationResponse: The response message. + """ + webhook_repo = WebhookRepository(db) + try: + webhook = WebhookEntity( + user_id=request.user_id, + url=request.url, + events=request.events, + ) + webhook_repo.save(webhook) + except Exception: + return WebhookRegistrationResponse(message="Failed to register webhook") + return WebhookRegistrationResponse( + id=str(webhook.id), message="Webhook registered successfully" + ) diff --git a/aana/sdk.py b/aana/sdk.py index 47005780..be574c9d 100644 --- a/aana/sdk.py +++ b/aana/sdk.py @@ -7,6 +7,7 @@ import ray import yaml +from fastapi import APIRouter from ray import serve from ray.autoscaler.v2.schema import ResourceDemand from ray.autoscaler.v2.sdk import get_cluster_status @@ -50,6 +51,7 @@ def __init__( self.migration_func = migration_func self.endpoints: dict[str, Endpoint] = {} self.deployments: dict[str, Deployment] = {} + self.routers: dict[str, APIRouter] = {} if retryable_exceptions is None: self.retryable_exceptions = [ @@ -258,6 +260,7 @@ def get_main_app(self) -> Application: app_name=self.name, endpoints=self.endpoints.values(), deployments=list(self.deployments.keys()), + routers=list(self.routers.values()), ) def register_endpoint( @@ -294,6 +297,24 @@ def unregister_endpoint(self, name: str): if name in self.endpoints: del self.endpoints[name] + def register_router(self, name: str, router: APIRouter): + """Register a FastAPI router. + + Args: + name (str): The name of the router. + router (APIRouter): The instance of the APIRouter to be registered. + """ + self.routers[name] = router + + def unregister_router(self, name: str): + """Unregister a FastAPI router. + + Args: + name (str): The name of the router to be unregistered. + """ + if name in self.routers: + del self.routers[name] + def wait_for_deployment(self): # noqa: C901 """Wait for the deployment to complete.""" consecutive_resource_unavailable = 0 diff --git a/aana/storage/session.py b/aana/storage/session.py index 573eca32..141a750e 100644 --- a/aana/storage/session.py +++ b/aana/storage/session.py @@ -4,7 +4,7 @@ from aana.storage.models.api_key import ApiServiceBase from aana.storage.models.base import BaseEntity -__all__ = ["get_session"] +__all__ = ["get_session", "get_db"] engine = settings.db_config.get_engine() @@ -28,3 +28,12 @@ def get_session() -> Session: Session: SQLAlchemy Session object. """ return SessionLocal() + + +def get_db(): + """Get a database session.""" + db = get_session() + try: + yield db + finally: + db.close() diff --git a/aana/tests/units/test_task_queue.py b/aana/tests/units/test_task_queue.py index 9bdf3036..3900fb93 100644 --- a/aana/tests/units/test_task_queue.py +++ b/aana/tests/units/test_task_queue.py @@ -149,7 +149,7 @@ def test_task_queue(create_app): # noqa: C901 break time.sleep(0.1) - assert task_status == "completed" + assert task_status == "completed", response.text assert result == {"text": ["hello world!", "this is a test."]} # Delete the task @@ -201,7 +201,7 @@ def test_task_queue(create_app): # noqa: C901 break time.sleep(0.1) - assert task_status == "completed" + assert task_status == "completed", response.text assert [chunk["text"] for chunk in result] == lowercase_text # Send 30 tasks to the task queue