Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into movchan74/webhook-sup…
Browse files Browse the repository at this point in the history
…port
  • Loading branch information
Aleksandr Movchan committed Feb 3, 2025
2 parents b9f6323 + ae08b1d commit cb2fbe7
Show file tree
Hide file tree
Showing 13 changed files with 1,735 additions and 1,953 deletions.
32 changes: 31 additions & 1 deletion aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

Expand All @@ -18,6 +19,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import check_admin_permissions
from aana.configs.settings import settings as aana_settings
from aana.core.models.api_service import ApiKey
from aana.core.models.exception import ExceptionResponseModel
Expand All @@ -37,6 +39,21 @@ def get_default_values(func):
}


class DeferOption(str, Enum):
    """Enum for defer option.

    Controls whether an endpoint's execution is deferred to the task queue.

    Attributes:
        ALWAYS (str): Always defer. Endpoints with this option always defer execution to the task queue.
        NEVER (str): Never defer. Endpoints with this option never defer execution to the task queue.
        OPTIONAL (str): Optionally defer. Endpoints with this option defer execution to the task queue only if
            the defer query parameter is set to True.
    """

    ALWAYS = "always"
    NEVER = "never"
    OPTIONAL = "optional"


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand All @@ -45,12 +62,16 @@ class Endpoint:
name (str): Name of the endpoint.
path (str): Path of the endpoint (e.g. "/video/transcribe").
summary (str): Description of the endpoint that will be shown in the API documentation.
admin_required (bool): Flag indicating if the endpoint requires admin access.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler] | None): The list of event handlers to register for the endpoint.
"""

name: str
path: str
summary: str
admin_required: bool = False
defer_option: DeferOption = DeferOption.OPTIONAL
initialized: bool = False
event_handlers: list[EventHandler] | None = None

Expand Down Expand Up @@ -323,9 +344,18 @@ async def route_func(
defer: bool = Query(
description="Defer execution of the endpoint to the task queue.",
default=False,
include_in_schema=aana_settings.task_queue.enabled,
include_in_schema=aana_settings.task_queue.enabled
and self.defer_option == DeferOption.OPTIONAL,
),
):
if aana_settings.api_service.enabled and self.admin_required:
check_admin_permissions(request)

if self.defer_option == DeferOption.ALWAYS:
defer = True
elif self.defer_option == DeferOption.NEVER:
defer = False

form_data = await request.form()

# Parse files from the form data
Expand Down
10 changes: 7 additions & 3 deletions aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aana.exceptions.api_service import (
ApiKeyNotFound,
ApiKeyNotProvided,
ApiKeyValidationFailed,
InactiveSubscription,
)
from aana.storage.models.api_key import ApiKeyEntity
Expand All @@ -34,9 +35,12 @@ async def api_key_check(request: Request, call_next):
raise ApiKeyNotProvided()

with get_session() as session:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
try:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
except Exception as e:
raise ApiKeyValidationFailed() from e

if not api_key_info:
raise ApiKeyNotFound(key=api_key)
Expand Down
18 changes: 16 additions & 2 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import AdminRequired
from aana.api.webhook import (
WebhookEventType,
trigger_task_webhooks,
Expand Down Expand Up @@ -212,9 +213,22 @@ async def delete_task_endpoint(
task = task_repo.delete(task_id)
return TaskId(task_id=str(task.id))

@app.post("/chat/completions", response_model=ChatCompletion)
@app.post(
"/chat/completions",
response_model=ChatCompletion,
include_in_schema=aana_settings.openai_endpoint_enabled,
)
async def chat_completions(self, request: ChatCompletionRequest):
"""Handle chat completions requests for OpenAI compatible API."""
if not aana_settings.openai_endpoint_enabled:
return AanaJSONResponse(
content={
"error": {
"message": "The OpenAI-compatible endpoint is not enabled."
}
},
status_code=404,
)

async def _async_chat_completions(
handle: AanaDeploymentHandle,
Expand Down Expand Up @@ -285,7 +299,7 @@ async def _async_chat_completions(
}

@app.get("/api/status", response_model=SDKStatusResponse)
async def status(self) -> SDKStatusResponse:
async def status(self, is_admin: AdminRequired) -> SDKStatusResponse:
"""The endpoint for checking the status of the application."""
app_names = [
self.app_name,
Expand Down
35 changes: 35 additions & 0 deletions aana/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Annotated

from fastapi import Depends, Request

from aana.configs.settings import settings as aana_settings
from aana.exceptions.api_service import AdminOnlyAccess


def check_admin_permissions(request: Request):
    """Check that the request was made with an admin API key.

    The check is only enforced when the API service is enabled in the settings;
    otherwise the function returns without inspecting the request.

    Args:
        request (Request): The request object. When the API service is enabled,
            ``request.state.api_key_info`` is read as a mapping with an optional
            ``is_admin`` entry (presumably populated by the API key middleware;
            confirm against the caller).

    Raises:
        AdminOnlyAccess: If the user is not an admin, or if no API key
            information is attached to the request.
    """
    if not aana_settings.api_service.enabled:
        return

    # Fail closed: a request that carries no API key info (attribute missing or
    # None) is rejected as non-admin instead of crashing with an AttributeError.
    api_key_info = getattr(request.state, "api_key_info", None) or {}
    if not api_key_info.get("is_admin", False):
        raise AdminOnlyAccess()


class AdminCheck:
    """Dependency to check if the user is an admin.

    Instances are callable, so an instance can be wrapped in ``Depends(...)``.
    """

    async def __call__(self, request: Request) -> bool:
        """Check if the user is an admin.

        Args:
            request (Request): The incoming request, forwarded to check_admin_permissions.

        Returns:
            bool: Always True. A failed check surfaces as an exception raised by
                check_admin_permissions, never as a False return value.
        """
        check_admin_permissions(request)
        return True


AdminRequired = Annotated[bool, Depends(AdminCheck())]
""" Annotation to check if the user is an admin. If not, it will raise an exception. """
3 changes: 3 additions & 0 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Settings(BaseSettings):
audio_dir (Path): The temporary audio directory.
model_dir (Path): The temporary model directory.
num_workers (int): The number of web workers.
openai_endpoint_enabled (bool): Flag indicating if the OpenAI-compatible endpoint is enabled. Enabled by default.
task_queue (TaskQueueSettings): The task queue settings.
db_config (DbSettings): The database configuration.
test (TestSettings): The test settings.
Expand All @@ -91,6 +92,8 @@ class Settings(BaseSettings):

num_workers: int = 2

openai_endpoint_enabled: bool = True

task_queue: TaskQueueSettings = TaskQueueSettings()

db_config: DbSettings = DbSettings()
Expand Down
13 changes: 11 additions & 2 deletions aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from aana.utils.lazy_import import LazyImport

with LazyImport(
"Run 'pip install mobius-faster-whisper' or 'pip install aana[asr]'"
"Run 'pip install faster-whisper' or 'pip install aana[asr]'"
) as faster_whisper_import:
from faster_whisper import BatchedInferencePipeline, WhisperModel

Expand Down Expand Up @@ -318,7 +318,16 @@ async def transcribe_in_chunks(
params = BatchedWhisperParams()
audio_array = audio.get_numpy()
if vad_segments:
sampling_rate = self.model.feature_extractor.sampling_rate
vad_input = [seg.to_whisper_dict() for seg in vad_segments]
vad_input = [
{
"start": int(seg["start"] * sampling_rate),
"end": int(seg["end"] * sampling_rate),
}
for seg in vad_input
]

if not audio_array.any():
# For silent audios/no audio tracks, return empty output with language as silence
yield WhisperOutput(
Expand All @@ -332,7 +341,7 @@ async def transcribe_in_chunks(
try:
segments, info = self.batched_model.transcribe(
audio_array,
vad_segments=vad_input if vad_segments else None,
clip_timestamps=vad_input if vad_segments else None,
batch_size=batch_size,
**params.model_dump(),
)
Expand Down
26 changes: 26 additions & 0 deletions aana/exceptions/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,29 @@ def __init__(self, key: str):
def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.key,))


class AdminOnlyAccess(BaseException):
    """Error signalling that an operation is restricted to admin users.

    Attributes:
        message (str): Fixed human-readable description of the error.
    """

    def __init__(self):
        """Store the fixed message and forward it to the base class."""
        # NOTE(review): the base class must accept a `message` keyword, which the
        # builtin BaseException does not; presumably this is the project's own
        # BaseException imported at the top of the module. Confirm.
        message = "Admin only access"
        self.message = message
        super().__init__(message=message)

    def __reduce__(self):
        """Support pickling: rebuild the exception by calling the class with no arguments."""
        return self.__class__, ()


class ApiKeyValidationFailed(BaseException):
    """Error signalling that an API key could not be validated.

    Attributes:
        message (str): Fixed human-readable description of the error.
    """

    def __init__(self):
        """Store the fixed message and forward it to the base class."""
        # NOTE(review): the base class must accept a `message` keyword, which the
        # builtin BaseException does not; presumably this is the project's own
        # BaseException imported at the top of the module. Confirm.
        message = "API key validation failed"
        self.message = message
        super().__init__(message=message)

    def __reduce__(self):
        """Support pickling: rebuild the exception by calling the class with no arguments."""
        return self.__class__, ()
8 changes: 7 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ray.serve.schema import ApplicationStatusOverview
from rich import print as rprint

from aana.api.api_generation import Endpoint
from aana.api.api_generation import DeferOption, Endpoint
from aana.api.event_handlers.event_handler import EventHandler
from aana.api.request_handler import RequestHandler
from aana.configs.settings import settings as aana_settings
Expand Down Expand Up @@ -269,6 +269,8 @@ def register_endpoint(
path: str,
summary: str,
endpoint_cls: type[Endpoint],
admin_required: bool = False,
defer_option: DeferOption = DeferOption.OPTIONAL,
event_handlers: list[EventHandler] | None = None,
):
"""Register an endpoint.
Expand All @@ -278,12 +280,16 @@ def register_endpoint(
path (str): The path of the endpoint.
summary (str): The summary of the endpoint.
endpoint_cls (Type[Endpoint]): The class of the endpoint.
admin_required (bool, optional): If True, the endpoint requires admin access. Defaults to False.
defer_option (DeferOption, optional): Defer option for the endpoint (always, never, optional). Defaults to DeferOption.OPTIONAL.
event_handlers (list[EventHandler], optional): The event handlers to register for the endpoint.
"""
endpoint = endpoint_cls(
name=name,
path=path,
summary=summary,
admin_required=admin_required,
defer_option=defer_option,
event_handlers=event_handlers,
)
self.endpoints[name] = endpoint
Expand Down
5 changes: 5 additions & 0 deletions aana/storage/models/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ApiKeyEntity(ApiServiceBase):
user_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the user who owns this API key"
)
is_admin: Mapped[bool] = mapped_column(
nullable=False, default=False, comment="Whether the user is an admin"
)
subscription_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the associated subscription"
)
Expand All @@ -37,7 +40,9 @@ def __repr__(self) -> str:
"""String representation of the API key."""
return (
f"<APIKeyEntity(id={self.id}, "
f"api_key={self.api_key}, "
f"user_id={self.user_id}, "
f"is_admin={self.is_admin}, "
f"subscription_id={self.subscription_id}, "
f"is_subscription_active={self.is_subscription_active})>"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/model_hub/asr.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[WhisperDeployment](./../../reference/deployments.md#aana.deployments.whisper_deployment.WhisperDeployment) allows you to transcribe or translate audio with Whisper models. The deployment is based on the [faster-whisper](https://github.com/SYSTRAN/faster-whisper) library.

!!! Tip
To use Whisper deployment, install required libraries with `pip install mobius-faster-whisper` or include extra dependencies using `pip install aana[asr]`.
To use Whisper deployment, install required libraries with `pip install faster-whisper` or include extra dependencies using `pip install aana[asr]`.

[WhisperConfig](./../../reference/deployments.md#aana.deployments.whisper_deployment.WhisperConfig) is used to configure the Whisper deployment.

Expand Down
3 changes: 3 additions & 0 deletions docs/pages/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to i

Chat Completions API is available at the `/chat/completions` endpoint.

!!! Tip
The endpoint is enabled by default but can be disabled by setting the environment variable: `OPENAI_ENDPOINT_ENABLED=False`.

It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for OpenAI API.

```python
Expand Down
Loading

0 comments on commit cb2fbe7

Please sign in to comment.