Skip to content

Commit

Permalink
Moved webhook handling in a separate router. Added optional routers t…
Browse files Browse the repository at this point in the history
…o Aana SDK.
  • Loading branch information
Aleksandr Movchan committed Feb 3, 2025
1 parent 7186bc3 commit 4030d70
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 51 deletions.
61 changes: 16 additions & 45 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import orjson
import ray
from fastapi import Depends
from fastapi import APIRouter, Depends
from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse
from ray import serve
Expand All @@ -18,30 +18,18 @@
from aana.api.responses import AanaJSONResponse
from aana.api.webhook import (
WebhookEventType,
WebhookRegistrationRequest,
WebhookRegistrationResponse,
trigger_task_webhooks,
)
from aana.api.webhook import router as webhook_router
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.task import TaskId, TaskInfo
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.storage.models.task import Status as TaskStatus
from aana.storage.models.webhook import WebhookEntity
from aana.storage.repository.task import TaskRepository
from aana.storage.repository.webhook import WebhookRepository
from aana.storage.session import get_session


def get_db():
"""Get a database session."""
db = get_session()
try:
yield db
finally:
db.close()
from aana.storage.session import get_db, get_session


@serve.deployment(ray_actor_options={"num_cpus": 0.1})
Expand All @@ -52,19 +40,31 @@ class RequestHandler:
ready = False

def __init__(
self, app_name: str, endpoints: list[Endpoint], deployments: list[str]
self,
app_name: str,
endpoints: list[Endpoint],
deployments: list[str],
routers: list[APIRouter] | None = None,
):
"""Constructor.
Args:
app_name (str): The name of the application.
endpoints (dict): List of endpoints for the request.
deployments (list[str]): List of deployment names for the app.
routers (list[APIRouter]): List of FastAPI routers to include in the app.
"""
self.app_name = app_name
self.endpoints = endpoints
self.deployments = deployments

# Include the webhook router
app.include_router(webhook_router)
# Include the custom routers
if routers is not None:
for router in routers:
app.include_router(router)

self.event_manager = EventManager()
self.custom_schemas: dict[str, dict] = {}
for endpoint in self.endpoints:
Expand Down Expand Up @@ -344,32 +344,3 @@ async def status(self) -> SDKStatusResponse:
message=app_status_message,
deployments=deployment_statuses,
)

@app.post("/webhooks", status_code=201)
async def register_webhook(
self,
request: WebhookRegistrationRequest,
db: Annotated[Session, Depends(get_db)],
) -> WebhookRegistrationResponse:
"""Register a new webhook.
Args:
request (WebhookRegistrationRequest): The webhook registration request.
db (Session): The database session.
Returns:
WebhookRegistrationResponse: The response message.
"""
webhook_repo = WebhookRepository(db)
try:
webhook = WebhookEntity(
user_id=request.user_id,
url=request.url,
events=request.events,
)
webhook_repo.save(webhook)
except Exception:
return WebhookRegistrationResponse(message="Failed to register webhook")
return WebhookRegistrationResponse(
id=str(webhook.id), message="Webhook registered successfully"
)
40 changes: 37 additions & 3 deletions aana/api/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
import hmac
import json
import logging
from typing import Annotated, Any

import httpx
from fastapi import APIRouter, Depends
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session
from tenacity import retry, stop_after_attempt, wait_exponential

from aana.configs.settings import settings as aana_settings
from aana.storage.models.api_key import ApiKeyEntity
from aana.storage.models.task import TaskEntity
from aana.storage.models.webhook import WebhookEventType
from aana.storage.models.webhook import WebhookEntity, WebhookEventType
from aana.storage.repository.webhook import WebhookRepository
from aana.storage.session import get_session
from aana.storage.session import get_db, get_session

logger = logging.getLogger(__name__)

router = APIRouter(tags=["webhooks"])


class WebhookRegistrationRequest(BaseModel):
"""Request to register a webhook."""
Expand Down Expand Up @@ -47,7 +52,7 @@ class TaskStatusChangeWebhookPayload(BaseModel):

task_id: str
status: str
result: dict | None
result: Any | None
num_retries: int


Expand Down Expand Up @@ -142,3 +147,32 @@ async def trigger_task_webhooks(event: WebhookEventType, task: TaskEntity):
)
body = WebhookBody(event=event, payload=payload)
await trigger_webhooks(event, body, task.user_id)


@router.post("/webhooks", status_code=201)
async def register_webhook(
request: WebhookRegistrationRequest,
db: Annotated[Session, Depends(get_db)],
) -> WebhookRegistrationResponse:
"""Register a new webhook.
Args:
request (WebhookRegistrationRequest): The webhook registration request.
db (Session): The database session.
Returns:
WebhookRegistrationResponse: The response message.
"""
webhook_repo = WebhookRepository(db)
try:
webhook = WebhookEntity(
user_id=request.user_id,
url=request.url,
events=request.events,
)
webhook_repo.save(webhook)
except Exception:
return WebhookRegistrationResponse(message="Failed to register webhook")
return WebhookRegistrationResponse(
id=str(webhook.id), message="Webhook registered successfully"
)
21 changes: 21 additions & 0 deletions aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ray
import yaml
from fastapi import APIRouter
from ray import serve
from ray.autoscaler.v2.schema import ResourceDemand
from ray.autoscaler.v2.sdk import get_cluster_status
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
self.migration_func = migration_func
self.endpoints: dict[str, Endpoint] = {}
self.deployments: dict[str, Deployment] = {}
self.routers: dict[str, APIRouter] = {}

if retryable_exceptions is None:
self.retryable_exceptions = [
Expand Down Expand Up @@ -258,6 +260,7 @@ def get_main_app(self) -> Application:
app_name=self.name,
endpoints=self.endpoints.values(),
deployments=list(self.deployments.keys()),
routers=list(self.routers.values()),
)

def register_endpoint(
Expand Down Expand Up @@ -294,6 +297,24 @@ def unregister_endpoint(self, name: str):
if name in self.endpoints:
del self.endpoints[name]

def register_router(self, name: str, router: APIRouter):
"""Register a FastAPI router.
Args:
name (str): The name of the router.
router (APIRouter): The instance of the APIRouter to be registered.
"""
self.routers[name] = router

def unregister_router(self, name: str):
"""Unregister a FastAPI router.
Args:
name (str): The name of the router to be unregistered.
"""
if name in self.routers:
del self.routers[name]

def wait_for_deployment(self): # noqa: C901
"""Wait for the deployment to complete."""
consecutive_resource_unavailable = 0
Expand Down
11 changes: 10 additions & 1 deletion aana/storage/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aana.storage.models.api_key import ApiServiceBase
from aana.storage.models.base import BaseEntity

__all__ = ["get_session"]
__all__ = ["get_session", "get_db"]

engine = settings.db_config.get_engine()

Expand All @@ -28,3 +28,12 @@ def get_session() -> Session:
Session: SQLAlchemy Session object.
"""
return SessionLocal()


def get_db():
"""Get a database session."""
db = get_session()
try:
yield db
finally:
db.close()
4 changes: 2 additions & 2 deletions aana/tests/units/test_task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_task_queue(create_app): # noqa: C901
break
time.sleep(0.1)

assert task_status == "completed"
assert task_status == "completed", response.text
assert result == {"text": ["hello world!", "this is a test."]}

# Delete the task
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_task_queue(create_app): # noqa: C901
break
time.sleep(0.1)

assert task_status == "completed"
assert task_status == "completed", response.text
assert [chunk["text"] for chunk in result] == lowercase_text

# Send 30 tasks to the task queue
Expand Down

0 comments on commit 4030d70

Please sign in to comment.