Skip to content

Commit

Permalink
Merge pull request #224 from mobiusml/admin_only_endpoints
Browse files Browse the repository at this point in the history
Enhance API Key Validation and Admin Access Control
  • Loading branch information
movchan74 authored Feb 3, 2025
2 parents 8d9a405 + 5cdd089 commit ae08b1d
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 8 deletions.
32 changes: 31 additions & 1 deletion aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

Expand All @@ -18,6 +19,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import check_admin_permissions
from aana.configs.settings import settings as aana_settings
from aana.core.models.api_service import ApiKey
from aana.core.models.exception import ExceptionResponseModel
Expand All @@ -37,6 +39,21 @@ def get_default_values(func):
}


class DeferOption(str, Enum):
"""Enum for defer option.
Attributes:
ALWAYS (str): Always defer. Endpoints with this option will always be defer execution to the task queue.
NEVER (str): Never defer. Endpoints with this option will never be defer execution to the task queue.
OPTIONAL (str): Optionally defer. Endpoints with this option can be defer execution to the task queue if
the defer query parameter is set to True.
"""

ALWAYS = "always"
NEVER = "never"
OPTIONAL = "optional"


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand All @@ -45,12 +62,16 @@ class Endpoint:
name (str): Name of the endpoint.
path (str): Path of the endpoint (e.g. "/video/transcribe").
summary (str): Description of the endpoint that will be shown in the API documentation.
admin_required (bool): Flag indicating if the endpoint requires admin access.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler] | None): The list of event handlers to register for the endpoint.
"""

name: str
path: str
summary: str
admin_required: bool = False
defer_option: DeferOption = DeferOption.OPTIONAL
initialized: bool = False
event_handlers: list[EventHandler] | None = None

Expand Down Expand Up @@ -323,9 +344,18 @@ async def route_func(
defer: bool = Query(
description="Defer execution of the endpoint to the task queue.",
default=False,
include_in_schema=aana_settings.task_queue.enabled,
include_in_schema=aana_settings.task_queue.enabled
and self.defer_option == DeferOption.OPTIONAL,
),
):
if aana_settings.api_service.enabled and self.admin_required:
check_admin_permissions(request)

if self.defer_option == DeferOption.ALWAYS:
defer = True
elif self.defer_option == DeferOption.NEVER:
defer = False

form_data = await request.form()

# Parse files from the form data
Expand Down
15 changes: 11 additions & 4 deletions aana/api/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from fastapi import FastAPI, Request
from pydantic import ValidationError

Expand All @@ -10,6 +9,7 @@
from aana.exceptions.api_service import (
ApiKeyNotFound,
ApiKeyNotProvided,
ApiKeyValidationFailed,
InactiveSubscription,
)
from aana.storage.models.api_key import ApiKeyEntity
Expand All @@ -24,16 +24,23 @@
@app.middleware("http")
async def api_key_check(request: Request, call_next):
"""Middleware to check the API key and subscription status."""
excluded_paths = ["/openapi.json", "/docs", "/redoc"]
if request.url.path in excluded_paths:
return await call_next(request)

if aana_settings.api_service.enabled:
api_key = request.headers.get("x-api-key")

if not api_key:
raise ApiKeyNotProvided()

with get_session() as session:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
try:
api_key_info = (
session.query(ApiKeyEntity).filter_by(api_key=api_key).first()
)
except Exception as e:
raise ApiKeyValidationFailed() from e

if not api_key_info:
raise ApiKeyNotFound(key=api_key)
Expand Down
18 changes: 16 additions & 2 deletions aana/api/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.api.security import AdminRequired
from aana.configs.settings import settings as aana_settings
from aana.core.models.api import DeploymentStatus, SDKStatus, SDKStatusResponse
from aana.core.models.chat import ChatCompletion, ChatCompletionRequest, ChatDialog
Expand Down Expand Up @@ -201,9 +202,22 @@ async def delete_task_endpoint(
task = task_repo.delete(task_id)
return TaskId(task_id=str(task.id))

@app.post("/chat/completions", response_model=ChatCompletion)
@app.post(
"/chat/completions",
response_model=ChatCompletion,
include_in_schema=aana_settings.openai_endpoint_enabled,
)
async def chat_completions(self, request: ChatCompletionRequest):
"""Handle chat completions requests for OpenAI compatible API."""
if not aana_settings.openai_endpoint_enabled:
return AanaJSONResponse(
content={
"error": {
"message": "The OpenAI-compatible endpoint is not enabled."
}
},
status_code=404,
)

async def _async_chat_completions(
handle: AanaDeploymentHandle,
Expand Down Expand Up @@ -274,7 +288,7 @@ async def _async_chat_completions(
}

@app.get("/api/status", response_model=SDKStatusResponse)
async def status(self) -> SDKStatusResponse:
async def status(self, is_admin: AdminRequired) -> SDKStatusResponse:
"""The endpoint for checking the status of the application."""
app_names = [
self.app_name,
Expand Down
35 changes: 35 additions & 0 deletions aana/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Annotated

from fastapi import Depends, Request

from aana.configs.settings import settings as aana_settings
from aana.exceptions.api_service import AdminOnlyAccess


def check_admin_permissions(request: Request):
"""Check if the user is an admin.
Args:
request (Request): The request object
Raises:
AdminOnlyAccess: If the user is not an admin
"""
if aana_settings.api_service.enabled:
api_key_info = request.state.api_key_info
is_admin = api_key_info.get("is_admin", False)
if not is_admin:
raise AdminOnlyAccess()


class AdminCheck:
"""Dependency to check if the user is an admin."""

async def __call__(self, request: Request) -> bool:
"""Check if the user is an admin."""
check_admin_permissions(request)
return True


AdminRequired = Annotated[bool, Depends(AdminCheck())]
""" Annotation to check if the user is an admin. If not, it will raise an exception. """
3 changes: 3 additions & 0 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Settings(BaseSettings):
audio_dir (Path): The temporary audio directory.
model_dir (Path): The temporary model directory.
num_workers (int): The number of web workers.
openai_endpoint_enabled (bool): Flag indicating if the OpenAI-compatible endpoint is enabled. Enabled by default.
task_queue (TaskQueueSettings): The task queue settings.
db_config (DbSettings): The database configuration.
test (TestSettings): The test settings.
Expand All @@ -79,6 +80,8 @@ class Settings(BaseSettings):

num_workers: int = 2

openai_endpoint_enabled: bool = True

task_queue: TaskQueueSettings = TaskQueueSettings()

db_config: DbSettings = DbSettings()
Expand Down
26 changes: 26 additions & 0 deletions aana/exceptions/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,29 @@ def __init__(self, key: str):
def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.key,))


class AdminOnlyAccess(BaseException):
"""Exception raised when the user does not have enough permissions."""

def __init__(self):
"""Initialize the exception."""
self.message = "Admin only access"
super().__init__(message=self.message)

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, ())


class ApiKeyValidationFailed(BaseException):
"""Exception raised when the API key validation fails."""

def __init__(self):
"""Initialize the exception."""
self.message = "API key validation failed"
super().__init__(message=self.message)

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, ())
8 changes: 7 additions & 1 deletion aana/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ray.serve.schema import ApplicationStatusOverview
from rich import print as rprint

from aana.api.api_generation import Endpoint
from aana.api.api_generation import DeferOption, Endpoint
from aana.api.event_handlers.event_handler import EventHandler
from aana.api.request_handler import RequestHandler
from aana.configs.settings import settings as aana_settings
Expand Down Expand Up @@ -266,6 +266,8 @@ def register_endpoint(
path: str,
summary: str,
endpoint_cls: type[Endpoint],
admin_required: bool = False,
defer_option: DeferOption = DeferOption.OPTIONAL,
event_handlers: list[EventHandler] | None = None,
):
"""Register an endpoint.
Expand All @@ -275,12 +277,16 @@ def register_endpoint(
path (str): The path of the endpoint.
summary (str): The summary of the endpoint.
endpoint_cls (Type[Endpoint]): The class of the endpoint.
admin_required (bool, optional): If True, the endpoint requires admin access. Defaults to False.
defer_option (DeferOption): Defer option for the endpoint (always, never, optional).
event_handlers (list[EventHandler], optional): The event handlers to register for the endpoint.
"""
endpoint = endpoint_cls(
name=name,
path=path,
summary=summary,
admin_required=admin_required,
defer_option=defer_option,
event_handlers=event_handlers,
)
self.endpoints[name] = endpoint
Expand Down
5 changes: 5 additions & 0 deletions aana/storage/models/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ApiKeyEntity(ApiServiceBase):
user_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the user who owns this API key"
)
is_admin: Mapped[bool] = mapped_column(
nullable=False, default=False, comment="Whether the user is an admin"
)
subscription_id: Mapped[str] = mapped_column(
nullable=False, comment="ID of the associated subscription"
)
Expand All @@ -34,7 +37,9 @@ def __repr__(self) -> str:
"""String representation of the API key."""
return (
f"<APIKeyEntity(id={self.id}, "
f"api_key={self.api_key}, "
f"user_id={self.user_id}, "
f"is_admin={self.is_admin}, "
f"subscription_id={self.subscription_id}, "
f"is_subscription_active={self.is_subscription_active})>"
)
Expand Down
3 changes: 3 additions & 0 deletions docs/pages/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to i

Chat Completions API is available at the `/chat/completions` endpoint.

!!! Tip
The endpoint is enabled by default but can be disabled by setting the environment variable: `OPENAI_ENDPOINT_ENABLED=False`.

It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for OpenAI API.

```python
Expand Down

0 comments on commit ae08b1d

Please sign in to comment.