Skip to content

Commit

Permalink
Merge pull request #225 from mobiusml/movchan74/webhook-support
Browse files Browse the repository at this point in the history
Webhook Support for Task Queue
  • Loading branch information
movchan74 authored Feb 4, 2025
2 parents 79e88d3 + 6d16a22 commit 533ab95
Show file tree
Hide file tree
Showing 21 changed files with 965 additions and 86 deletions.
43 changes: 43 additions & 0 deletions aana/alembic/versions/acb40dabc2c0_added_webhooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Added webhooks.
Revision ID: acb40dabc2c0
Revises: d40eba8ebc4c
Create Date: 2025-01-30 14:32:16.596842
"""
from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "acb40dabc2c0"
down_revision: str | None = "d40eba8ebc4c"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade database to this revision from previous."""
# fmt: off
op.create_table('webhooks',
sa.Column('id', sa.UUID(), nullable=False, comment='Webhook ID'),
sa.Column('user_id', sa.String(), nullable=True, comment='The user ID associated with the webhook'),
sa.Column('url', sa.String(), nullable=False, comment='The URL to which the webhook will send requests'),
sa.Column('events', sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), 'postgresql'), nullable=False, comment='List of events the webhook is subscribed to. If None, the webhook is subscribed to all events.'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is inserted'),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is updated'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_webhooks'))
)
with op.batch_alter_table('webhooks', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_webhooks_user_id'), ['user_id'], unique=False)
# fmt: on


def downgrade() -> None:
"""Downgrade database from this revision to previous."""
with op.batch_alter_table("webhooks", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("ix_webhooks_user_id"))

op.drop_table("webhooks")
18 changes: 7 additions & 11 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

import orjson
from fastapi import FastAPI, Form, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import ConfigDict, Field, ValidationError, create_model
Expand All @@ -19,7 +18,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import check_admin_permissions
from aana.api.security import require_admin_access
from aana.configs.settings import settings as aana_settings
from aana.core.models.api_service import ApiKey
from aana.core.models.exception import ExceptionResponseModel
Expand Down Expand Up @@ -270,11 +269,11 @@ async def route_func_body( # noqa: C901
if event_manager:
event_manager.handle(bound_path, defer=defer)

# Parse json data from the body
body = orjson.loads(body)
# parse form data as a pydantic model and validate it
data = RequestModel.model_validate_json(body)

# Add api_key_info to the body if API service is enabled
api_key_info: dict = {}
api_key_info: ApiKey | None = None
if aana_settings.api_service.enabled:
api_key_info = request.state.api_key_info
api_key_field = next(
Expand All @@ -286,10 +285,7 @@ async def route_func_body( # noqa: C901
None,
)
if api_key_field:
body[api_key_field] = api_key_info

# parse form data as a pydantic model and validate it
data = RequestModel.model_validate(body)
setattr(data, api_key_field, api_key_info)

# if the input requires file upload, add the files to the data
if files:
Expand All @@ -314,7 +310,7 @@ async def route_func_body( # noqa: C901
task = TaskRepository(session).save(
endpoint=bound_path,
data=data_dict,
user_id=api_key_info.get("user_id"),
user_id=api_key_info.user_id if api_key_info else None,
)
return AanaJSONResponse(content={"task_id": str(task.id)})

Expand Down Expand Up @@ -349,7 +345,7 @@ async def route_func(
),
):
if aana_settings.api_service.enabled and self.admin_required:
check_admin_permissions(request)
require_admin_access(request)

if self.defer_option == DeferOption.ALWAYS:
defer = True
Expand Down
4 changes: 3 additions & 1 deletion aana/api/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError

from aana.api.exception_handler import (
Expand All @@ -18,6 +19,7 @@
app = FastAPI()

app.add_exception_handler(ValidationError, validation_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, aana_exception_handler)


Expand Down Expand Up @@ -48,7 +50,7 @@ async def api_key_check(request: Request, call_next):
if not api_key_info.is_subscription_active:
raise InactiveSubscription(key=api_key)

request.state.api_key_info = api_key_info.to_dict()
request.state.api_key_info = api_key_info.to_model()

response = await call_next(request)
return response
56 changes: 31 additions & 25 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import json
import time
from typing import Annotated, Any
from typing import Any
from uuid import UUID, uuid4

import orjson
import ray
from fastapi import Depends
from fastapi import APIRouter
from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse
from ray import serve
from sqlalchemy.orm import Session

from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
from aana.api.app import app
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import AdminRequired
from aana.api.security import AdminAccessDependency
from aana.api.webhook import (
WebhookEventType,
trigger_task_webhooks,
)
from aana.api.webhook import router as webhook_router
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
Expand All @@ -25,16 +29,7 @@
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.storage.models.task import Status as TaskStatus
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session


def get_db():
"""Get a database session."""
db = get_session()
try:
yield db
finally:
db.close()
from aana.storage.session import GetDbDependency, get_session


@serve.deployment(ray_actor_options={"num_cpus": 0.1})
Expand All @@ -45,19 +40,31 @@ class RequestHandler:
ready = False

def __init__(
self, app_name: str, endpoints: list[Endpoint], deployments: list[str]
self,
app_name: str,
endpoints: list[Endpoint],
deployments: list[str],
routers: list[APIRouter] | None = None,
):
"""Constructor.
Args:
app_name (str): The name of the application.
endpoints (dict): List of endpoints for the request.
deployments (list[str]): List of deployment names for the app.
routers (list[APIRouter]): List of FastAPI routers to include in the app.
"""
self.app_name = app_name
self.endpoints = endpoints
self.deployments = deployments

# Include the webhook router
app.include_router(webhook_router)
# Include the custom routers
if routers is not None:
for router in routers:
app.include_router(router)

self.event_manager = EventManager()
self.custom_schemas: dict[str, dict] = {}
for endpoint in self.endpoints:
Expand Down Expand Up @@ -121,7 +128,8 @@ async def execute_task(self, task_id: str | UUID) -> Any:
path = task.endpoint
kwargs = task.data

task_repo.update_status(task_id, TaskStatus.RUNNING, 0)
task = task_repo.update_status(task_id, TaskStatus.RUNNING, 0)
await trigger_task_webhooks(WebhookEventType.TASK_STARTED, task)

for e in self.endpoints:
if e.path == path:
Expand All @@ -139,16 +147,18 @@ async def execute_task(self, task_id: str | UUID) -> Any:
out = await endpoint.run(**kwargs)

with get_session() as session:
TaskRepository(session).update_status(
task = TaskRepository(session).update_status(
task_id, TaskStatus.COMPLETED, 100, out
)
await trigger_task_webhooks(WebhookEventType.TASK_COMPLETED, task)
except Exception as e:
error_response = custom_exception_handler(None, e)
error = orjson.loads(error_response.body)
with get_session() as session:
TaskRepository(session).update_status(
task = TaskRepository(session).update_status(
task_id, TaskStatus.FAILED, 0, error
)
await trigger_task_webhooks(WebhookEventType.TASK_FAILED, task)
else:
return out
finally:
Expand All @@ -160,9 +170,7 @@ async def execute_task(self, task_id: str | UUID) -> Any:
description="Get the task status by task ID.",
include_in_schema=aana_settings.task_queue.enabled,
)
async def get_task_endpoint(
self, task_id: str, db: Annotated[Session, Depends(get_db)]
) -> TaskInfo:
async def get_task_endpoint(self, task_id: str, db: GetDbDependency) -> TaskInfo:
"""Get the task with the given ID.
Args:
Expand All @@ -186,9 +194,7 @@ async def get_task_endpoint(
description="Delete the task by task ID.",
include_in_schema=aana_settings.task_queue.enabled,
)
async def delete_task_endpoint(
self, task_id: str, db: Annotated[Session, Depends(get_db)]
) -> TaskId:
async def delete_task_endpoint(self, task_id: str, db: GetDbDependency) -> TaskId:
"""Delete the task with the given ID.
Args:
Expand Down Expand Up @@ -288,7 +294,7 @@ async def _async_chat_completions(
}

@app.get("/api/status", response_model=SDKStatusResponse)
async def status(self, is_admin: AdminRequired) -> SDKStatusResponse:
async def status(self, is_admin: AdminAccessDependency) -> SDKStatusResponse:
"""The endpoint for checking the status of the application."""
app_names = [
self.app_name,
Expand Down
34 changes: 22 additions & 12 deletions aana/api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from fastapi import Depends, Request

from aana.configs.settings import settings as aana_settings
from aana.core.models.api_service import ApiKey
from aana.exceptions.api_service import AdminOnlyAccess


def check_admin_permissions(request: Request):
"""Check if the user is an admin.
def require_admin_access(request: Request) -> bool:
"""Check if the user is an admin. If not, raise an exception.
Args:
request (Request): The request object
Expand All @@ -16,20 +17,29 @@ def check_admin_permissions(request: Request):
AdminOnlyAccess: If the user is not an admin
"""
if aana_settings.api_service.enabled:
api_key_info = request.state.api_key_info
is_admin = api_key_info.get("is_admin", False)
api_key_info: ApiKey = request.state.api_key_info
is_admin = api_key_info.is_admin if api_key_info else False
if not is_admin:
raise AdminOnlyAccess()
return True


class AdminCheck:
"""Dependency to check if the user is an admin."""
def extract_api_key_info(request: Request) -> ApiKey | None:
"""Get the API key info dependency."""
return getattr(request.state, "api_key_info", None)

async def __call__(self, request: Request) -> bool:
"""Check if the user is an admin."""
check_admin_permissions(request)
return True

def extract_user_id(request: Request) -> str | None:
"""Get the user ID dependency."""
api_key_info = extract_api_key_info(request)
return api_key_info.user_id if api_key_info else None

AdminRequired = Annotated[bool, Depends(AdminCheck())]
""" Annotation to check if the user is an admin. If not, it will raise an exception. """

AdminAccessDependency = Annotated[bool, Depends(require_admin_access)]
""" Dependency to check if the user is an admin. If not, it will raise an exception. """

UserIdDependency = Annotated[str | None, Depends(extract_user_id)]
""" Dependency to get the user ID. """

ApiKeyInfoDependency = Annotated[ApiKey | None, Depends(extract_api_key_info)]
""" Dependency to get the API key info. """
Loading

0 comments on commit 533ab95

Please sign in to comment.