Skip to content

Commit

Permalink
Implement CRUD operations for webhooks and added corresponding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Feb 3, 2025
1 parent 4030d70 commit b9f6323
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 19 deletions.
15 changes: 15 additions & 0 deletions aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
@app.middleware("http")
async def api_key_check(request: Request, call_next):
"""Middleware to check the API key and subscription status."""
excluded_paths = ["/openapi.json", "/docs", "/redoc"]
if request.url.path in excluded_paths:
return await call_next(request)

if aana_settings.api_service.enabled:
api_key = request.headers.get("x-api-key")

Expand All @@ -44,3 +48,14 @@ async def api_key_check(request: Request, call_next):

response = await call_next(request)
return response


def get_api_key_info(request: Request) -> dict | None:
"""Get the API key info dependency."""
return getattr(request.state, "api_key_info", None)


def get_user_id(request: Request) -> str | None:
"""Get the user ID dependency."""
api_key_info = get_api_key_info(request)
return api_key_info.get("user_id") if api_key_info else None
142 changes: 123 additions & 19 deletions aana/api/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from typing import Annotated, Any

import httpx
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session
from tenacity import retry, stop_after_attempt, wait_exponential

from aana.api.app import get_user_id
from aana.configs.settings import settings as aana_settings
from aana.storage.models.api_key import ApiKeyEntity
from aana.storage.models.task import TaskEntity
Expand All @@ -23,12 +24,12 @@
router = APIRouter(tags=["webhooks"])


# Request models


class WebhookRegistrationRequest(BaseModel):
"""Request to register a webhook."""

user_id: str | None = Field(
None, description="The user ID. If None, the webhook is a system-wide webhook."
)
url: str = Field(
..., description="The URL to which the webhook will send requests."
)
Expand All @@ -40,13 +41,57 @@ class WebhookRegistrationRequest(BaseModel):
model_config = ConfigDict(extra="forbid")


class WebhookUpdateRequest(BaseModel):
"""Request to update a webhook."""

url: str | None = Field(None, description="New URL for the webhook.")
events: list[WebhookEventType] | None = Field(
None, description="New list of events to subscribe to."
)

model_config = ConfigDict(extra="forbid")


# Response models


class WebhookRegistrationResponse(BaseModel):
"""Response for a webhook registration."""

id: str | None
message: str


class WebhookResponse(BaseModel):
"""Response for a webhook registration."""

id: str = Field(..., description="The webhook ID.")
url: str = Field(
..., description="The URL to which the webhook will send requests."
)
events: list[WebhookEventType] = Field(
..., description="The events that the webhook is subscribed to."
)

@classmethod
def from_entity(cls, webhook: WebhookEntity) -> "WebhookResponse":
"""Create a WebhookResponse from a WebhookEntity."""
return WebhookResponse(
id=str(webhook.id),
url=webhook.url,
events=webhook.events,
)


class WebhookListResponse(BaseModel):
"""Response for a list of webhooks."""

webhooks: list[WebhookResponse] = Field(..., description="The list of webhooks.")


# Webhook Models


class TaskStatusChangeWebhookPayload(BaseModel):
"""Payload for a task status change webhook."""

Expand All @@ -63,6 +108,9 @@ class WebhookBody(BaseModel):
payload: TaskStatusChangeWebhookPayload


# Webhook functions


def generate_hmac_signature(body: dict, user_id: str | None) -> str:
"""Generate HMAC signature for a payload for a given user.
Expand Down Expand Up @@ -149,30 +197,86 @@ async def trigger_task_webhooks(event: WebhookEventType, task: TaskEntity):
await trigger_webhooks(event, body, task.user_id)


# Webhook endpoints


@router.post("/webhooks", status_code=201)
async def register_webhook(
async def create_webhook(
request: WebhookRegistrationRequest,
db: Annotated[Session, Depends(get_db)],
) -> WebhookRegistrationResponse:
"""Register a new webhook.
Args:
request (WebhookRegistrationRequest): The webhook registration request.
db (Session): The database session.
Returns:
WebhookRegistrationResponse: The response message.
"""
user_id: Annotated[str | None, Depends(get_user_id)],
) -> WebhookResponse:
"""This endpoint is used to register a webhook."""
webhook_repo = WebhookRepository(db)
try:
webhook = WebhookEntity(
user_id=request.user_id,
user_id=user_id,
url=request.url,
events=request.events,
)
webhook_repo.save(webhook)
webhook = webhook_repo.save(webhook)
except Exception:
return WebhookRegistrationResponse(message="Failed to register webhook")
return WebhookRegistrationResponse(
id=str(webhook.id), message="Webhook registered successfully"
return WebhookResponse.from_entity(webhook)


@router.get("/webhooks")
async def list_webhooks(
db: Annotated[Session, Depends(get_db)],
user_id: Annotated[str | None, Depends(get_user_id)],
) -> WebhookListResponse:
"""This endpoint is used to list all registered webhooks."""
webhook_repo = WebhookRepository(db)
webhooks = webhook_repo.get_webhooks(user_id, None)
return WebhookListResponse(
webhooks=[WebhookResponse.from_entity(webhook) for webhook in webhooks]
)


@router.get("/webhooks/{webhook_id}")
async def get_webhook(
webhook_id: str,
db: Annotated[Session, Depends(get_db)],
user_id: Annotated[str | None, Depends(get_user_id)],
) -> WebhookResponse:
"""This endpoint is used to fetch a webhook by ID."""
webhook_repo = WebhookRepository(db)
webhook = webhook_repo.read(webhook_id, check=False)
if not webhook or webhook.user_id != user_id:
raise HTTPException(status_code=404, detail="Webhook not found")
return WebhookResponse.from_entity(webhook)


@router.put("/webhooks/{webhook_id}")
async def update_webhook(
webhook_id: str,
request: WebhookUpdateRequest,
db: Annotated[Session, Depends(get_db)],
user_id: Annotated[str | None, Depends(get_user_id)],
) -> WebhookResponse:
"""This endpoint is used to update a webhook."""
webhook_repo = WebhookRepository(db)
webhook = webhook_repo.read(webhook_id, check=False)
if not webhook or webhook.user_id != user_id:
raise HTTPException(status_code=404, detail="Webhook not found")
if request.url is not None:
webhook.url = request.url
if request.events is not None:
webhook.events = request.events
webhook = webhook_repo.save(webhook)
return WebhookResponse.from_entity(webhook)


@router.delete("/webhooks/{webhook_id}")
async def delete_webhook(
webhook_id: str,
db: Annotated[Session, Depends(get_db)],
user_id: Annotated[str | None, Depends(get_user_id)],
) -> WebhookResponse:
"""This endpoint is used to delete a webhook."""
webhook_repo = WebhookRepository(db)
webhook = webhook_repo.read(webhook_id, check=False)
if not webhook or webhook.user_id != user_id:
raise HTTPException(status_code=404, detail="Webhook not found")
webhook = webhook_repo.delete(webhook.id)
return WebhookResponse.from_entity(webhook)
36 changes: 36 additions & 0 deletions aana/storage/repository/webhook.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from sqlalchemy import func, or_
from sqlalchemy.orm import Session

Expand All @@ -13,6 +15,40 @@ def __init__(self, session: Session):
"""Constructor."""
super().__init__(session, WebhookEntity)

def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
"""Reads a single webhook from the database.
Args:
item_id (str | UUID): ID of the webhook to retrieve.
check (bool): whether to raise if the entity is not found (defaults to True).
Returns:
The corresponding entity from the database if found.
Raises:
NotFoundException if the entity is not found and `check` is True.
"""
if isinstance(item_id, str):
item_id = UUID(item_id)
return super().read(item_id, check)

def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
"""Delete a webhook from the database.
Args:
item_id (str | UUID): The ID of the webhook to delete.
check (bool): whether to raise if the entity is not found (defaults to True).
Returns:
WebhookEntity: The deleted webhook.
Raises:
NotFoundException: The id does not correspond to a record in the database.
"""
if isinstance(item_id, str):
item_id = UUID(item_id)
return super().delete(item_id, check)

def save(self, webhook: WebhookEntity) -> WebhookEntity:
"""Save a webhook to the database.
Expand Down
54 changes: 54 additions & 0 deletions aana/tests/units/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,57 @@ def webhook_listener(request):

httpserver.check_assertions()
httpserver.check_handler_errors()


def test_webhook_crud(create_app):
"""Test webhook CRUD operations."""
aana_app = create_app(deployments, endpoints)
port = aana_app.port
base_url = f"http://localhost:{port}"

# Clear existing webhooks
response = requests.get(f"{base_url}/webhooks")
assert response.status_code == 200
webhooks = response.json()["webhooks"]
for webhook in webhooks:
response = requests.delete(f"{base_url}/webhooks/{webhook['id']}")
assert response.status_code == 200

# Test Create
webhook_data = {
"url": "http://example.com/webhook",
"events": ["task.completed", "task.failed"],
}
response = requests.post(f"{base_url}/webhooks", json=webhook_data)
assert response.status_code == 201
webhook_id = response.json()["id"]
assert response.json()["url"] == webhook_data["url"]
assert response.json()["events"] == webhook_data["events"]

# Test Read (List)
response = requests.get(f"{base_url}/webhooks")
assert response.status_code == 200
webhooks = response.json()["webhooks"]
assert len(webhooks) == 1
assert webhooks[0]["id"] == webhook_id

# Test Read (Single)
response = requests.get(f"{base_url}/webhooks/{webhook_id}")
assert response.status_code == 200
assert response.json()["id"] == webhook_id
assert response.json()["url"] == webhook_data["url"]

# Test Update
update_data = {"url": "http://example.com/webhook2", "events": ["task.completed"]}
response = requests.put(f"{base_url}/webhooks/{webhook_id}", json=update_data)
assert response.status_code == 200
assert response.json()["url"] == update_data["url"]
assert response.json()["events"] == update_data["events"]

# Test Delete
response = requests.delete(f"{base_url}/webhooks/{webhook_id}")
assert response.status_code == 200

# Verify webhook is deleted
response = requests.get(f"{base_url}/webhooks/{webhook_id}")
assert response.status_code == 404

0 comments on commit b9f6323

Please sign in to comment.