diff --git a/api/api.py b/api/api.py index b876f66..247a775 100644 --- a/api/api.py +++ b/api/api.py @@ -1,11 +1,13 @@ import asyncio import os import signal +import time import traceback +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum -from functools import lru_cache from pathlib import Path from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import UUID, uuid4 @@ -14,120 +16,116 @@ from dotenv import load_dotenv from fastapi import ( BackgroundTasks, - Depends, FastAPI, - Header, HTTPException, Query, + Request, status, ) from fastapi.concurrency import asynccontextmanager from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse from loguru import logger from pydantic import BaseModel, Field -from supabase import Client, create_client - from swarms.structs.agent import Agent -# Load environment variables load_dotenv() +@dataclass +class RateLimitWindow: + requests: int + start_time: float -class APIKey(BaseModel): - """Model matching Supabase api_keys table""" - - id: UUID - created_at: datetime - name: str - user_id: UUID - key: str - limit_credit_dollar: Optional[float] = None - is_deleted: bool = False - - -class User(BaseModel): - id: UUID - name: str - is_active: bool = True - is_admin: bool = False - - -@lru_cache() -def get_supabase() -> Client: - """Get cached Supabase client""" - supabase_url = os.getenv("SUPABASE_URL") - supabase_key = os.getenv("SUPABASE_SERVICE_KEY") - if not supabase_url or not supabase_key: - raise ValueError("Supabase configuration is missing") - return create_client(supabase_url, supabase_key) - - -async def get_current_user( - api_key: str = Header(..., description="API key for authentication"), -) -> User: - """Validate API key against Supabase and return current user.""" - if not api_key or not api_key.startswith("sk-"): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key format", - headers={"WWW-Authenticate": "ApiKey"}, - ) - - try: - supabase = get_supabase() - - # Query the api_keys table - response = ( - supabase.table("api_keys") - .select("id, name, user_id, key, limit_credit_dollar, is_deleted") - .eq("key", api_key) - .single() - .execute() - ) - - if not response.data: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers={"WWW-Authenticate": "ApiKey"}, - ) - - key_data = response.data - - # Check if key is deleted - if key_data["is_deleted"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="API key has been deleted", - headers={"WWW-Authenticate": "ApiKey"}, - ) - - # Check credit limit if applicable - if ( - key_data["limit_credit_dollar"] is not None - and key_data["limit_credit_dollar"] <= 0 - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="API key credit limit exceeded", - ) - - # Create user object - return User( - id=key_data["user_id"], - name=key_data["name"], - is_active=not key_data["is_deleted"], - ) - - except Exception as e: - logger.error(f"Error validating API key: {str(e)}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="API key validation failed", - headers={"WWW-Authenticate": "ApiKey"}, - ) +class RateLimiter: + def __init__(self): + # Store rate limit windows per IP + self.ip_windows: Dict[str, Dict[str, RateLimitWindow]] = defaultdict(dict) + + # Configure different rate limits + self.limits = { + 'per_second': 2, # 2 requests per second + 'per_minute': 30, # 30 requests per minute + 'per_hour': 1000, # 1000 requests per hour + 'per_day': 10000, # 10000 requests per day + } + + # Window durations in seconds + self.windows = { + 'per_second': 1, + 'per_minute': 60, + 'per_hour': 3600, + 'per_day': 86400, + } + + def _clean_old_windows(self, ip: str): + """Remove expired windows for an IP""" + current_time = time.time() + for window_type in list(self.ip_windows[ip].keys()): + window = self.ip_windows[ip][window_type] + if current_time - window.start_time >= self.windows[window_type]: + del self.ip_windows[ip][window_type] + + def _get_window(self, ip: str, window_type: str) -> RateLimitWindow: + """Get or create a rate limit window""" + current_time = time.time() + window = self.ip_windows[ip].get(window_type) + + # If window doesn't exist or is expired, create new one + if not window or (current_time - window.start_time >= self.windows[window_type]): + window = RateLimitWindow(requests=0, start_time=current_time) + self.ip_windows[ip][window_type] = window + + return window + + async def check_rate_limit(self, request: Request) -> None: + """Check if request exceeds rate limits""" + # Skip rate limiting for health check endpoint + if request.url.path == "/health": + return + + ip = request.client.host + self._clean_old_windows(ip) + + # Check each window type + for window_type, limit in self.limits.items(): + window = self._get_window(ip, window_type) + + if window.requests >= limit: + retry_after = self.windows[window_type] - (time.time() - window.start_time) + raise HTTPException( + status_code=429, + detail={ + "error": "Rate limit exceeded", + "window": window_type, + "limit": limit, + "retry_after": int(retry_after) + }, + headers={ + "Retry-After": str(int(retry_after)), + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(window.start_time + self.windows[window_type])) + } + ) + window.requests += 1 + + def get_rate_limit_status(self, ip: str) -> Dict[str, Dict[str, int]]: + """Get current rate limit status for an IP""" + self._clean_old_windows(ip) + status = {} + + for window_type, limit in self.limits.items(): + window = self._get_window(ip, window_type) + remaining = max(0, limit - window.requests) + reset_time = int(window.start_time + self.windows[window_type]) + + status[window_type] = { + "limit": limit, + "remaining": remaining, + "reset": reset_time + } + + return status class UvicornServer(uvicorn.Server): """Customized uvicorn server with graceful shutdown support""" @@ -141,85 +139,33 @@ async def shutdown(self, sockets=None): logger.info("Shutting down server...") await super().shutdown(sockets) - class AgentStatus(str, Enum): """Enum for agent status.""" - IDLE = "idle" PROCESSING = "processing" ERROR = "error" MAINTENANCE = "maintenance" - -# Security configurations -API_KEY_LENGTH = 32 # Length of generated API keys - - -class APIKeyCreate(BaseModel): - name: str # A friendly name for the API key - - -class User(BaseModel): - id: UUID - username: str - is_active: bool = True - is_admin: bool = False - api_keys: Dict[str, APIKey] = Field(default_factory=dict) - - def ensure_active_api_key(self) -> Optional[APIKey]: - """Ensure user has at least one active API key.""" - active_keys = [key for key in self.api_keys.values() if key.is_active] - if not active_keys: - return None - return active_keys[0] - - class AgentConfig(BaseModel): """Configuration model for creating a new agent.""" - agent_name: str = Field(..., description="Name of the agent") - model_name: str = Field( - ..., - description="Name of the llm you want to use provided by litellm", - ) - description: str = Field( - default="", description="Description of the agent's purpose" - ) + model_name: str = Field(..., description="Name of the llm you want to use provided by litellm") + description: str = Field(default="", description="Description of the agent's purpose") system_prompt: str = Field(..., description="System prompt for the agent") - model_name: str = Field(default="gpt-4", description="Model name to use") - temperature: float = Field( - default=0.1, - ge=0.0, - le=2.0, - description="Temperature for the model", - ) + temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Temperature for the model") max_loops: int = Field(default=1, ge=1, description="Maximum number of loops") - dynamic_temperature_enabled: bool = Field( - default=True, description="Enable dynamic temperature" - ) + dynamic_temperature_enabled: bool = Field(default=True, description="Enable dynamic temperature") user_name: str = Field(default="default_user", description="Username for the agent") retry_attempts: int = Field(default=1, ge=1, description="Number of retry attempts") context_length: int = Field(default=200000, ge=1000, description="Context length") - output_type: str = Field( - default="string", description="Output type (string or json)" - ) + output_type: str = Field(default="string", description="Output type (string or json)") streaming_on: bool = Field(default=False, description="Enable streaming") - tags: List[str] = Field( - default_factory=list, - description="Tags for categorizing the agent", - ) - stopping_token: str = Field( - default="", description="Stopping token for the agent" - ) - auto_generate_prompt: bool = Field( - default=False, - description="Auto-generate prompt based on agent details such as name, description, etc.", - ) - + tags: List[str] = Field(default_factory=list, description="Tags for categorizing the agent") + stopping_token: str = Field(default="", description="Stopping token for the agent") + auto_generate_prompt: bool = Field(default=False, description="Auto-generate prompt based on agent details") class AgentUpdate(BaseModel): """Model for updating agent configuration.""" - description: Optional[str] = None system_prompt: Optional[str] = None temperature: Optional[float] = 0.5 @@ -227,10 +173,8 @@ class AgentUpdate(BaseModel): tags: Optional[List[str]] = None status: Optional[AgentStatus] = None - class AgentSummary(BaseModel): """Summary model for agent listing.""" - agent_id: UUID agent_name: str description: str @@ -241,10 +185,8 @@ class AgentSummary(BaseModel): tags: List[str] status: AgentStatus - class AgentMetrics(BaseModel): """Model for agent performance metrics.""" - total_completions: int average_response_time: float error_rate: float @@ -254,20 +196,16 @@ class AgentMetrics(BaseModel): success_rate: float peak_tokens_per_minute: int - class CompletionRequest(BaseModel): """Model for completion requests.""" - prompt: str = Field(..., description="The prompt to process") agent_id: UUID = Field(..., description="ID of the agent to use") max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate") temperature_override: Optional[float] = 0.5 stream: bool = Field(default=False, description="Enable streaming response") - class CompletionResponse(BaseModel): """Model for completion responses.""" - agent_id: UUID response: str metadata: Dict[str, Any] @@ -275,14 +213,11 @@ class CompletionResponse(BaseModel): processing_time: float token_usage: Dict[str, int] - class AgentStore: - """Enhanced store for managing agents.""" - + """Store for managing agents.""" def __init__(self): self.agents: Dict[UUID, Agent] = {} self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} - self.user_agents: Dict[UUID, List[UUID]] = {} # user_id -> [agent_ids] self.executor = ThreadPoolExecutor(max_workers=4) self._ensure_directories() @@ -291,19 +226,9 @@ def _ensure_directories(self): Path("logs").mkdir(exist_ok=True) Path("states").mkdir(exist_ok=True) - async def verify_agent_access(self, agent_id: UUID, user_id: UUID) -> bool: - """Verify if a user has access to an agent.""" - if agent_id not in self.agents: - return False - return ( - self.agent_metadata[agent_id]["owner_id"] == user_id - or self.users[user_id].is_admin - ) - - async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID: + async def create_agent(self, config: AgentConfig) -> UUID: """Create a new agent with the given configuration.""" try: - agent = Agent( agent_name=config.agent_name, system_prompt=config.system_prompt, @@ -337,11 +262,6 @@ async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID: "successful_completions": 0, } - # Add to user's agents list - if user_id not in self.user_agents: - self.user_agents[user_id] = [] - self.user_agents[user_id].append(agent_id) - return agent_id except Exception as e: @@ -382,39 +302,6 @@ async def update_agent(self, agent_id: UUID, update: AgentUpdate) -> None: logger.info(f"Updated agent {agent_id}") - def ensure_user_api_key(self, user_id: UUID) -> APIKey: - """Ensure user has at least one active API key.""" - if user_id not in self.users: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - - user = self.users[user_id] - existing_key = user.ensure_active_api_key() - if existing_key: - return existing_key - - # Create new API key if none exists - return self.create_api_key(user_id, "Default Key") - - def validate_api_key(self, api_key: str) -> Optional[UUID]: - """Validate an API key and return the associated user ID.""" - if not api_key: - return None - - user_id = self.api_keys.get(api_key) - if not user_id or api_key not in self.users[user_id].api_keys: - return None - - key_object = self.users[user_id].api_keys[api_key] - if not key_object.is_active: - return None - - # Update last used timestamp - key_object.last_used = datetime.utcnow() - return user_id - async def list_agents( self, tags: Optional[List[str]] = None, @@ -425,7 +312,6 @@ async def list_agents( for agent_id, agent in self.agents.items(): metadata = self.agent_metadata[agent_id] - # Apply filters if tags and not any(tag in metadata["tags"] for tag in tags): continue if status and metadata["status"] != status: @@ -451,7 +337,6 @@ async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics: metadata = self.agent_metadata[agent_id] response_times = metadata["response_times"] - # Calculate metrics total_time = datetime.utcnow() - metadata["start_time"] uptime = total_time - metadata["downtime"] uptime_percentage = (uptime.total_seconds() / total_time.total_seconds()) * 100 @@ -506,7 +391,6 @@ async def delete_agent(self, agent_id: UUID) -> None: detail=f"Agent {agent_id} not found", ) - # Clean up any resources agent = self.agents[agent_id] if agent.autosave and os.path.exists(agent.saved_state_path): os.remove(agent.saved_state_path) @@ -521,33 +405,28 @@ async def process_completion( prompt: str, agent_id: UUID, max_tokens: Optional[int] = None, - temperature_override: Optional[float] = None, + temperature_override: Optional[float] = 0.5, ) -> CompletionResponse: """Process a completion request using the specified agent.""" start_time = datetime.utcnow() metadata = self.agent_metadata[agent_id] try: - # Update agent status metadata["status"] = AgentStatus.PROCESSING metadata["last_used"] = start_time - # Process the completion response = agent.run(prompt) - # Update metrics processing_time = (datetime.utcnow() - start_time).total_seconds() metadata["response_times"].append(processing_time) metadata["total_completions"] += 1 metadata["successful_completions"] += 1 - # Estimate token usage (this is a rough estimate) prompt_tokens = len(prompt.split()) * 1.3 completion_tokens = len(response.split()) * 1.3 total_tokens = int(prompt_tokens + completion_tokens) metadata["total_tokens"] += total_tokens - # Update tokens per minute tracking current_minute = datetime.utcnow().replace(second=0, microsecond=0) if "tokens_per_minute" not in metadata: metadata["tokens_per_minute"] = {} @@ -560,8 +439,6 @@ async def process_completion( response=response, metadata={ "agent_name": agent.agent_name, - # "model_name": agent.llm.model_name, - # "temperature": 0.5, }, timestamp=datetime.utcnow(), processing_time=processing_time, @@ -585,215 +462,116 @@ async def process_completion( finally: metadata["status"] = AgentStatus.IDLE - class StoreManager: _instance = None @classmethod - def get_instance(cls) -> "AgentStore": + def get_instance(cls) -> AgentStore: if cls._instance is None: cls._instance = AgentStore() return cls._instance - -# Modify the dependency function def get_store() -> AgentStore: """Dependency to get the AgentStore instance.""" return StoreManager.get_instance() - -# Modify the get_current_user dependency -async def get_current_user( - api_key: str = Header(..., description="API key for authentication"), - store: AgentStore = Depends(get_store), -) -> User: - """Validate API key and return current user.""" - if not api_key: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="API key is required", - headers={"WWW-Authenticate": "ApiKey"}, - ) - - user_id = store.validate_api_key(api_key) - if not user_id: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired API key", - headers={"WWW-Authenticate": "ApiKey"}, - ) - - user = store.users.get(user_id) - if not user: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - - if not user.ensure_active_api_key(): - # Attempt to create new API key - store.ensure_user_api_key(user_id) - - return user - - class SwarmsAPI: - """Enhanced API class for Swarms agent integration.""" - def __init__(self): self.app = FastAPI( title="Swarms Agent API", - description="Production-grade API for Swarms agent interaction", + description="Free API for Swarms agent interaction", version="1.0.0", docs_url="/v1/docs", redoc_url="/v1/redoc", ) - # Initialize the store using the singleton manager self.store = StoreManager.get_instance() - # Configure CORS self.app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) + self.rate_limiter = RateLimiter() self._setup_routes() def _setup_routes(self): - """Set up API routes with Supabase authentication.""" + """Set up API routes.""" + + def _setup_middleware(self): + @self.app.middleware("http") + async def rate_limit_middleware(request: Request, call_next): + # Apply rate limiting + await self.rate_limiter.check_rate_limit(request) + + # Add rate limit headers to response + response = await call_next(request) + + # Add rate limit status headers + status = self.rate_limiter.get_rate_limit_status(request.client.host) + for window_type, info in status.items(): + response.headers[f"X-RateLimit-{window_type}-Limit"] = str(info["limit"]) + response.headers[f"X-RateLimit-{window_type}-Remaining"] = str(info["remaining"]) + response.headers[f"X-RateLimit-{window_type}-Reset"] = str(info["reset"]) + + return response + + @self.app.get("/v1/rate-limit-status") + async def get_rate_limit_status(request: Request): + """Get current rate limit status""" + return self.rate_limiter.get_rate_limit_status(request.client.host) - @self.app.get("/v1/users/me/agents", response_model=List[AgentSummary]) - async def list_user_agents( - current_user: User = Depends(get_current_user), + @self.app.get("/v1/agents", response_model=List[AgentSummary]) + async def list_agents( tags: Optional[List[str]] = Query(None), status: Optional[AgentStatus] = None, ): - """List all agents owned by the current user.""" - user_agents = self.store.user_agents.get(current_user.id, []) - return [ - agent - for agent in await self.store.list_agents(tags, status) - if agent.agent_id in user_agents - ] + """List all agents, optionally filtered by tags and status.""" + return await self.store.list_agents(tags, status) @self.app.post("/v1/agent", response_model=Dict[str, UUID]) - async def create_agent( - config: AgentConfig, - current_user: User = Depends(get_current_user), - ): + async def create_agent(config: AgentConfig): """Create a new agent with the specified configuration.""" - logger.info(f"User {current_user.id} creating new agent") - agent_id = await self.store.create_agent(config, current_user.id) + agent_id = await self.store.create_agent(config) return {"agent_id": agent_id} - @self.app.get("/v1/agents", response_model=List[AgentSummary]) - async def list_agents( - current_user: User = Depends(get_current_user), - tags: Optional[List[str]] = Query(None), - status: Optional[AgentStatus] = None, - ): - """List all agents, optionally filtered by tags and status.""" - agents = await self.store.list_agents(tags, status) - # Filter agents based on user access - return [ - agent - for agent in agents - if await self.store.verify_agent_access(agent.agent_id, current_user.id) - ] - @self.app.patch("/v1/agent/{agent_id}", response_model=Dict[str, str]) async def update_agent( agent_id: UUID, update: AgentUpdate, - current_user: User = Depends(get_current_user), ): """Update an existing agent's configuration.""" - if not await self.store.verify_agent_access(agent_id, current_user.id): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to update this agent", - ) - await self.store.update_agent(agent_id, update) return {"status": "updated"} @self.app.get("/v1/agent/{agent_id}/metrics", response_model=AgentMetrics) - async def get_agent_metrics( - agent_id: UUID, current_user: User = Depends(get_current_user) - ): + async def get_agent_metrics(agent_id: UUID): """Get performance metrics for a specific agent.""" - if not await self.store.verify_agent_access(agent_id, current_user.id): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to view this agent's metrics", - ) - return await self.store.get_agent_metrics(agent_id) @self.app.post("/v1/agent/{agent_id}/clone", response_model=Dict[str, UUID]) - async def clone_agent( - agent_id: UUID, - new_name: str, - current_user: User = Depends(get_current_user), - ): + async def clone_agent(agent_id: UUID, new_name: str): """Clone an existing agent with a new name.""" - if not await self.store.verify_agent_access(agent_id, current_user.id): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to clone this agent", - ) - new_id = await self.store.clone_agent(agent_id, new_name) - # Add the cloned agent to user's agents - if current_user.id not in self.store.user_agents: - self.store.user_agents[current_user.id] = [] - self.store.user_agents[current_user.id].append(new_id) - return {"agent_id": new_id} @self.app.delete("/v1/agent/{agent_id}") - async def delete_agent( - agent_id: UUID, current_user: User = Depends(get_current_user) - ): + async def delete_agent(agent_id: UUID): """Delete an agent.""" - if not await self.store.verify_agent_access(agent_id, current_user.id): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to delete this agent", - ) - await self.store.delete_agent(agent_id) - # Remove from user's agents list - if current_user.id in self.store.user_agents: - self.store.user_agents[current_user.id] = [ - aid - for aid in self.store.user_agents[current_user.id] - if aid != agent_id - ] return {"status": "deleted"} @self.app.post("/v1/agent/completions", response_model=CompletionResponse) async def create_completion( request: CompletionRequest, background_tasks: BackgroundTasks, - current_user: User = Depends(get_current_user), ): """Process a completion request with the specified agent.""" - if not await self.store.verify_agent_access( - request.agent_id, current_user.id - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to use this agent", - ) - try: agent = await self.store.get_agent(request.agent_id) - # Process completion response = await self.store.process_completion( agent, request.prompt, @@ -802,7 +580,6 @@ async def create_completion( request.temperature_override, ) - # Schedule background cleanup background_tasks.add_task(self._cleanup_old_metrics, request.agent_id) return response @@ -813,18 +590,35 @@ async def create_completion( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error processing completion: {str(e)}", ) + + @self.app.get("/v1/agent/batch/completions/status") + async def get_batch_completion_status(requests: List[CompletionRequest]): + """Get status of batch completion requests.""" + responses = [] + for req in requests: + try: + agent = await self.store.get_agent(req.agent_id) + response = await self.store.process_completion( + agent, + req.prompt, + req.agent_id, + req.max_tokens, + req.temperature_override, + ) + responses.append(response) + except Exception as e: + logger.error(f"Error processing batch completion: {str(e)}") + responses.append( + { + "error": f"Error processing completion: {str(e)}", + "agent_id": req.agent_id, + } + ) + return responses @self.app.get("/v1/agent/{agent_id}/status") - async def get_agent_status( - agent_id: UUID, current_user: User = Depends(get_current_user) - ): + async def get_agent_status(agent_id: UUID): """Get the current status of an agent.""" - if not await self.store.verify_agent_access(agent_id, current_user.id): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to view this agent's status", - ) - metadata = self.store.agent_metadata.get(agent_id) if not metadata: raise HTTPException( @@ -842,22 +636,8 @@ async def get_agent_status( @self.app.get("/health") async def health_check(): - """Health check endpoint - no auth required.""" - try: - # Test Supabase connection - supabase = get_supabase() - supabase.table("api_keys").select("count", count="exact").execute() - return {"status": "healthy", "database": "connected"} - except Exception as e: - logger.error(f"Health check failed: {str(e)}") - return JSONResponse( - status_code=503, - content={ - "status": "unhealthy", - "database": "disconnected", - "error": str(e), - }, - ) + """Health check endpoint.""" + return {"status": "healthy"} async def _cleanup_old_metrics(self, agent_id: UUID): """Clean up old metrics data to prevent memory bloat.""" @@ -874,7 +654,9 @@ async def _cleanup_old_metrics(self, agent_id: UUID): # Clean up old tokens per minute data if "tokens_per_minute" in metadata: metadata["tokens_per_minute"] = { - k: v for k, v in metadata["tokens_per_minute"].items() if k > cutoff + k: v + for k, v in metadata["tokens_per_minute"].items() + if k > cutoff } @@ -960,4 +742,4 @@ def run_server(): if __name__ == "__main__": - run_server() + run_server() \ No newline at end of file diff --git a/api/experimental/api.py b/api/experimental/api.py new file mode 100644 index 0000000..b876f66 --- /dev/null +++ b/api/experimental/api.py @@ -0,0 +1,963 @@ +import asyncio +import os +import signal +import traceback +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from enum import Enum +from functools import lru_cache +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional +from uuid import UUID, uuid4 + +import uvicorn +from dotenv import load_dotenv +from fastapi import ( + BackgroundTasks, + Depends, + FastAPI, + Header, + HTTPException, + Query, + status, +) +from fastapi.concurrency import asynccontextmanager +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from loguru import logger +from pydantic import BaseModel, Field +from supabase import Client, create_client + +from swarms.structs.agent import Agent + +# Load environment variables +load_dotenv() + + +class APIKey(BaseModel): + """Model matching Supabase api_keys table""" + + id: UUID + created_at: datetime + name: str + user_id: UUID + key: str + limit_credit_dollar: Optional[float] = None + is_deleted: bool = False + + +class User(BaseModel): + id: UUID + name: str + is_active: bool = True + is_admin: bool = False + + +@lru_cache() +def get_supabase() -> Client: + """Get cached Supabase client""" + supabase_url = os.getenv("SUPABASE_URL") + supabase_key = os.getenv("SUPABASE_SERVICE_KEY") + if not supabase_url or not supabase_key: + raise ValueError("Supabase configuration is missing") + return create_client(supabase_url, supabase_key) + + +async def get_current_user( + api_key: str = Header(..., description="API key for authentication"), +) -> User: + """Validate API key against Supabase and return current user.""" + if not api_key or not api_key.startswith("sk-"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key format", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + try: + supabase = get_supabase() + + # Query the api_keys table + response = ( + supabase.table("api_keys") + .select("id, name, user_id, key, limit_credit_dollar, is_deleted") + .eq("key", api_key) + .single() + .execute() + ) + + if not response.data: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + key_data = response.data + + # Check if key is deleted + if key_data["is_deleted"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key has been deleted", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Check credit limit if applicable + if ( + key_data["limit_credit_dollar"] is not None + and key_data["limit_credit_dollar"] <= 0 + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API key credit limit exceeded", + ) + + # Create user object + return User( + id=key_data["user_id"], + name=key_data["name"], + is_active=not key_data["is_deleted"], + ) + + except Exception as e: + logger.error(f"Error validating API key: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key validation failed", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + +class UvicornServer(uvicorn.Server): + """Customized uvicorn server with graceful shutdown support""" + + async def setup(self, sockets=None): + """Setup the server""" + await super().setup(sockets) + + async def shutdown(self, sockets=None): + """Gracefully shutdown the server""" + logger.info("Shutting down server...") + await super().shutdown(sockets) + + +class AgentStatus(str, Enum): + """Enum for agent status.""" + + IDLE = "idle" + PROCESSING = "processing" + ERROR = "error" + MAINTENANCE = "maintenance" + + +# Security configurations +API_KEY_LENGTH = 32 # Length of generated API keys + + +class APIKeyCreate(BaseModel): + name: str # A friendly name for the API key + + +class User(BaseModel): + id: UUID + username: str + is_active: bool = True + is_admin: bool = False + api_keys: Dict[str, APIKey] = Field(default_factory=dict) + + def ensure_active_api_key(self) -> Optional[APIKey]: + """Ensure user has at least one active API key.""" + active_keys = [key for key in self.api_keys.values() if key.is_active] + if not active_keys: + return None + return active_keys[0] + + +class AgentConfig(BaseModel): + """Configuration model for creating a new agent.""" + + agent_name: str = Field(..., description="Name of the agent") + model_name: str = Field( + ..., + description="Name of the llm you want to use provided by litellm", + ) + description: str = Field( + default="", description="Description of the agent's purpose" + ) + system_prompt: str = Field(..., description="System prompt for the agent") + model_name: str = Field(default="gpt-4", description="Model name to use") + temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for the model", + ) + max_loops: int = Field(default=1, ge=1, description="Maximum number of loops") + dynamic_temperature_enabled: bool = Field( + default=True, description="Enable dynamic temperature" + ) + user_name: str = Field(default="default_user", description="Username for the agent") + retry_attempts: int = Field(default=1, ge=1, description="Number of retry attempts") + context_length: int = Field(default=200000, ge=1000, description="Context length") + output_type: str = Field( + default="string", description="Output type (string or json)" + ) + streaming_on: bool = Field(default=False, description="Enable streaming") + tags: List[str] = Field( + default_factory=list, + description="Tags for categorizing the agent", + ) + stopping_token: str = Field( + default="", description="Stopping token for the agent" + ) + auto_generate_prompt: bool = Field( + default=False, + description="Auto-generate prompt based on agent details such as name, description, etc.", + ) + + +class AgentUpdate(BaseModel): + """Model for updating agent configuration.""" + + description: Optional[str] = None + system_prompt: Optional[str] = None + temperature: Optional[float] = 0.5 + max_loops: Optional[int] = 1 + tags: Optional[List[str]] = None + status: Optional[AgentStatus] = None + + +class AgentSummary(BaseModel): + """Summary model for agent listing.""" + + agent_id: UUID + agent_name: str + description: str + system_prompt: str + created_at: datetime + last_used: datetime + total_completions: int + tags: List[str] + status: AgentStatus + + +class AgentMetrics(BaseModel): + """Model for agent performance metrics.""" + + total_completions: int + average_response_time: float + error_rate: float + last_24h_completions: int + total_tokens_used: int + uptime_percentage: float + success_rate: float + peak_tokens_per_minute: int + + +class CompletionRequest(BaseModel): + """Model for completion requests.""" + + prompt: str = Field(..., description="The prompt to process") + agent_id: UUID = Field(..., description="ID of the agent to use") + max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate") + temperature_override: Optional[float] = 0.5 + stream: bool = Field(default=False, description="Enable streaming response") + + +class CompletionResponse(BaseModel): + """Model for completion responses.""" + + agent_id: UUID + response: str + metadata: Dict[str, Any] + timestamp: datetime + processing_time: float + token_usage: Dict[str, int] + + +class AgentStore: + """Enhanced store for managing agents.""" + + def __init__(self): + self.agents: Dict[UUID, Agent] = {} + self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} + self.user_agents: Dict[UUID, List[UUID]] = {} # user_id -> [agent_ids] + self.executor = ThreadPoolExecutor(max_workers=4) + self._ensure_directories() + + def _ensure_directories(self): + """Ensure required directories exist.""" + Path("logs").mkdir(exist_ok=True) + Path("states").mkdir(exist_ok=True) + + async def verify_agent_access(self, agent_id: UUID, user_id: UUID) -> bool: + """Verify if a user has access to an agent.""" + if agent_id not in self.agents: + return False + return ( + self.agent_metadata[agent_id]["owner_id"] == user_id + or self.users[user_id].is_admin + ) + + async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID: + """Create a new agent with the given configuration.""" + try: + + agent = Agent( + agent_name=config.agent_name, + system_prompt=config.system_prompt, + model_name=config.model_name, + max_loops=config.max_loops, + dynamic_temperature_enabled=True, + user_name=config.user_name, + retry_attempts=config.retry_attempts, + context_length=config.context_length, + return_step_meta=False, + output_type="str", + streaming_on=config.streaming_on, + stopping_token=config.stopping_token, + auto_generate_prompt=config.auto_generate_prompt, + ) + + agent_id = uuid4() + self.agents[agent_id] = agent + self.agent_metadata[agent_id] = { + "description": config.description, + "created_at": datetime.utcnow(), + "last_used": datetime.utcnow(), + "total_completions": 0, + "tags": config.tags, + "total_tokens": 0, + "error_count": 0, + "response_times": [], + "status": AgentStatus.IDLE, + "start_time": datetime.utcnow(), + "downtime": timedelta(), + "successful_completions": 0, + } + + # Add to user's agents list + if user_id not in self.user_agents: + self.user_agents[user_id] = [] + self.user_agents[user_id].append(agent_id) + + return agent_id + + except Exception as e: + logger.error(f"Error creating agent: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create agent: {str(e)}", + ) + + async def get_agent(self, agent_id: UUID) -> Agent: + """Retrieve an agent by ID.""" + agent = self.agents.get(agent_id) + if not agent: + logger.error(f"Agent not found: {agent_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + return agent + + async def update_agent(self, agent_id: UUID, update: AgentUpdate) -> None: + """Update agent configuration.""" + agent = await self.get_agent(agent_id) + metadata = self.agent_metadata[agent_id] + + if update.system_prompt: + agent.system_prompt = update.system_prompt + if update.max_loops is not None: + agent.max_loops = update.max_loops + if update.tags is not None: + metadata["tags"] = update.tags + if update.description is not None: + metadata["description"] = update.description + if update.status is not None: + metadata["status"] = update.status + if update.status == AgentStatus.MAINTENANCE: + metadata["downtime"] += datetime.utcnow() - metadata["last_used"] + + logger.info(f"Updated agent {agent_id}") + + def ensure_user_api_key(self, user_id: UUID) -> APIKey: + """Ensure user has at least one active API key.""" + if user_id not in self.users: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + + user = self.users[user_id] + existing_key = user.ensure_active_api_key() + if existing_key: + return existing_key + + # Create new API key if none exists + return self.create_api_key(user_id, "Default Key") + + def validate_api_key(self, api_key: str) -> Optional[UUID]: + """Validate an API key and return the associated user ID.""" + if not api_key: + return None + + user_id = self.api_keys.get(api_key) + if not user_id or api_key not in self.users[user_id].api_keys: + return None + + key_object = self.users[user_id].api_keys[api_key] + if not key_object.is_active: + return None + + # Update last used timestamp + key_object.last_used = datetime.utcnow() + return user_id + + async def list_agents( + self, + tags: Optional[List[str]] = None, + status: Optional[AgentStatus] = None, + ) -> List[AgentSummary]: + """List all agents, optionally filtered by tags and status.""" + summaries = [] + for agent_id, agent in self.agents.items(): + metadata = self.agent_metadata[agent_id] + + # Apply filters + if tags and not any(tag in metadata["tags"] for tag in tags): + continue + if status and metadata["status"] != status: + continue + + summaries.append( + AgentSummary( + agent_id=agent_id, + agent_name=agent.agent_name, + system_prompt=agent.system_prompt, + description=metadata["description"], + created_at=metadata["created_at"], + last_used=metadata["last_used"], + total_completions=metadata["total_completions"], + tags=metadata["tags"], + status=metadata["status"], + ) + ) + return summaries + + async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics: + """Get performance metrics for an agent.""" + metadata = self.agent_metadata[agent_id] + response_times = metadata["response_times"] + + # Calculate metrics + total_time = datetime.utcnow() - metadata["start_time"] + uptime = total_time - metadata["downtime"] + uptime_percentage = (uptime.total_seconds() / total_time.total_seconds()) * 100 + + success_rate = ( + metadata["successful_completions"] / metadata["total_completions"] * 100 + if metadata["total_completions"] > 0 + else 0 + ) + + return AgentMetrics( + total_completions=metadata["total_completions"], + average_response_time=( + sum(response_times) / len(response_times) if response_times else 0 + ), + error_rate=( + metadata["error_count"] / metadata["total_completions"] + if metadata["total_completions"] > 0 + else 0 + ), + last_24h_completions=sum( + 1 for t in response_times if (datetime.utcnow() - t).days < 1 + ), + total_tokens_used=metadata["total_tokens"], + uptime_percentage=uptime_percentage, + success_rate=success_rate, + peak_tokens_per_minute=max(metadata.get("tokens_per_minute", [0])), + ) + + async def clone_agent(self, agent_id: UUID, new_name: str) -> UUID: + """Clone an existing agent with a new name.""" + original_agent = await self.get_agent(agent_id) + original_metadata = self.agent_metadata[agent_id] + + config = AgentConfig( + agent_name=new_name, + description=f"Clone of {original_agent.agent_name}", + system_prompt=original_agent.system_prompt, + model_name=original_agent.model_name, + temperature=0.5, + max_loops=original_agent.max_loops, + tags=original_metadata["tags"], + ) + + return await self.create_agent(config) + + async def delete_agent(self, agent_id: UUID) -> None: + """Delete an agent.""" + if agent_id not in self.agents: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + + # Clean up any resources + agent = self.agents[agent_id] + if agent.autosave and os.path.exists(agent.saved_state_path): + os.remove(agent.saved_state_path) + + del self.agents[agent_id] + del self.agent_metadata[agent_id] + logger.info(f"Deleted agent {agent_id}") + + async def process_completion( + self, + agent: Agent, + prompt: str, + agent_id: UUID, + max_tokens: Optional[int] = None, + temperature_override: Optional[float] = None, + ) -> CompletionResponse: + """Process a completion request using the specified agent.""" + start_time = datetime.utcnow() + metadata = self.agent_metadata[agent_id] + + try: + # Update agent status + metadata["status"] = AgentStatus.PROCESSING + metadata["last_used"] = start_time + + # Process the completion + response = agent.run(prompt) + + # Update metrics + processing_time = (datetime.utcnow() - start_time).total_seconds() + metadata["response_times"].append(processing_time) + metadata["total_completions"] += 1 + metadata["successful_completions"] += 1 + + # Estimate token usage (this is a rough estimate) + prompt_tokens = len(prompt.split()) * 1.3 + completion_tokens = len(response.split()) * 1.3 + total_tokens = int(prompt_tokens + completion_tokens) + metadata["total_tokens"] += total_tokens + + # Update tokens per minute tracking + current_minute = datetime.utcnow().replace(second=0, microsecond=0) + if "tokens_per_minute" not in metadata: + metadata["tokens_per_minute"] = {} + metadata["tokens_per_minute"][current_minute] = ( + metadata["tokens_per_minute"].get(current_minute, 0) + total_tokens + ) + + return CompletionResponse( + agent_id=agent_id, + response=response, + metadata={ + "agent_name": agent.agent_name, + # "model_name": agent.llm.model_name, + # "temperature": 0.5, + }, + timestamp=datetime.utcnow(), + processing_time=processing_time, + token_usage={ + "prompt_tokens": int(prompt_tokens), + "completion_tokens": int(completion_tokens), + "total_tokens": total_tokens, + }, + ) + + except Exception as e: + metadata["error_count"] += 1 + metadata["status"] = AgentStatus.ERROR + logger.error( + f"Error in completion processing: {str(e)}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + finally: + metadata["status"] = AgentStatus.IDLE + + +class StoreManager: + _instance = None + + @classmethod + def get_instance(cls) -> "AgentStore": + if cls._instance is None: + cls._instance = AgentStore() + return cls._instance + + +# Modify the dependency function +def get_store() -> AgentStore: + """Dependency to get the AgentStore instance.""" + return StoreManager.get_instance() + + +# Modify the get_current_user dependency +async def get_current_user( + api_key: str = Header(..., description="API key for authentication"), + store: AgentStore = Depends(get_store), +) -> User: + """Validate API key and return current user.""" + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key is required", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + user_id = store.validate_api_key(api_key) + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + user = store.users.get(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + + if not user.ensure_active_api_key(): + # Attempt to create new API key + store.ensure_user_api_key(user_id) + + return user + + +class SwarmsAPI: + """Enhanced API class for Swarms agent integration.""" + + def __init__(self): + self.app = FastAPI( + title="Swarms Agent API", + description="Production-grade API for Swarms agent interaction", + version="1.0.0", + docs_url="/v1/docs", + redoc_url="/v1/redoc", + ) + # Initialize the store using the singleton manager + self.store = StoreManager.get_instance() + + # Configure CORS + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + self._setup_routes() + + def _setup_routes(self): + """Set up API routes with Supabase authentication.""" + + @self.app.get("/v1/users/me/agents", response_model=List[AgentSummary]) + async def list_user_agents( + current_user: User = Depends(get_current_user), + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents owned by the current user.""" + user_agents = self.store.user_agents.get(current_user.id, []) + return [ + agent + for agent in await self.store.list_agents(tags, status) + if agent.agent_id in user_agents + ] + + @self.app.post("/v1/agent", response_model=Dict[str, UUID]) + async def create_agent( + config: AgentConfig, + current_user: User = Depends(get_current_user), + ): + """Create a new agent with the specified configuration.""" + logger.info(f"User {current_user.id} creating new agent") + agent_id = await self.store.create_agent(config, current_user.id) + return {"agent_id": agent_id} + + @self.app.get("/v1/agents", response_model=List[AgentSummary]) + async def list_agents( + current_user: User = Depends(get_current_user), + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents, optionally filtered by tags and status.""" + agents = await self.store.list_agents(tags, status) + # Filter agents based on user access + return [ + agent + for agent in agents + if await self.store.verify_agent_access(agent.agent_id, current_user.id) + ] + + @self.app.patch("/v1/agent/{agent_id}", response_model=Dict[str, str]) + async def update_agent( + agent_id: UUID, + update: AgentUpdate, + current_user: User = Depends(get_current_user), + ): + """Update an existing agent's configuration.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to update this agent", + ) + + await self.store.update_agent(agent_id, update) + return {"status": "updated"} + + @self.app.get("/v1/agent/{agent_id}/metrics", response_model=AgentMetrics) + async def get_agent_metrics( + agent_id: UUID, current_user: User = Depends(get_current_user) + ): + """Get performance metrics for a specific agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view this agent's metrics", + ) + + return await self.store.get_agent_metrics(agent_id) + + @self.app.post("/v1/agent/{agent_id}/clone", response_model=Dict[str, UUID]) + async def clone_agent( + agent_id: UUID, + new_name: str, + current_user: User = Depends(get_current_user), + ): + """Clone an existing agent with a new name.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to clone this agent", + ) + + new_id = await self.store.clone_agent(agent_id, new_name) + # Add the cloned agent to user's agents + if current_user.id not in self.store.user_agents: + self.store.user_agents[current_user.id] = [] + self.store.user_agents[current_user.id].append(new_id) + + return {"agent_id": new_id} + + @self.app.delete("/v1/agent/{agent_id}") + async def delete_agent( + agent_id: UUID, current_user: User = Depends(get_current_user) + ): + """Delete an agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to delete this agent", + ) + + await self.store.delete_agent(agent_id) + # Remove from user's agents list + if current_user.id in self.store.user_agents: + self.store.user_agents[current_user.id] = [ + aid + for aid in self.store.user_agents[current_user.id] + if aid != agent_id + ] + return {"status": "deleted"} + + @self.app.post("/v1/agent/completions", response_model=CompletionResponse) + async def create_completion( + request: CompletionRequest, + background_tasks: BackgroundTasks, + current_user: User = Depends(get_current_user), + ): + """Process a completion request with the specified agent.""" + if not await self.store.verify_agent_access( + request.agent_id, current_user.id + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to use this agent", + ) + + try: + agent = await self.store.get_agent(request.agent_id) + + # Process completion + response = await self.store.process_completion( + agent, + request.prompt, + request.agent_id, + request.max_tokens, + request.temperature_override, + ) + + # Schedule background cleanup + background_tasks.add_task(self._cleanup_old_metrics, request.agent_id) + + return response + + except Exception as e: + logger.error(f"Error processing completion: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + + @self.app.get("/v1/agent/{agent_id}/status") + async def get_agent_status( + agent_id: UUID, current_user: User = Depends(get_current_user) + ): + """Get the current status of an agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view this agent's status", + ) + + metadata = self.store.agent_metadata.get(agent_id) + if not metadata: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + + return { + "agent_id": agent_id, + "status": metadata["status"], + "last_used": metadata["last_used"], + "total_completions": metadata["total_completions"], + "error_count": metadata["error_count"], + } + + @self.app.get("/health") + async def health_check(): + """Health check endpoint - no auth required.""" + try: + # Test Supabase connection + supabase = get_supabase() + supabase.table("api_keys").select("count", count="exact").execute() + return {"status": "healthy", "database": "connected"} + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + return JSONResponse( + status_code=503, + content={ + "status": "unhealthy", + "database": "disconnected", + "error": str(e), + }, + ) + + async def _cleanup_old_metrics(self, agent_id: UUID): + """Clean up old metrics data to prevent memory bloat.""" + metadata = self.store.agent_metadata.get(agent_id) + if metadata: + # Keep only last 24 hours of response times + cutoff = datetime.utcnow() - timedelta(days=1) + metadata["response_times"] = [ + t + for t in metadata["response_times"] + if isinstance(t, (int, float)) and t > cutoff.timestamp() + ] + + # Clean up old tokens per minute data + if "tokens_per_minute" in metadata: + metadata["tokens_per_minute"] = { + k: v for k, v in metadata["tokens_per_minute"].items() if k > cutoff + } + + +class APIServer: + def __init__(self, app: FastAPI, host: str = "0.0.0.0", port: int = 8080): + self.app = app + self.host = host + self.port = port + self.config = uvicorn.Config( + app=app, + host=host, + port=port, + log_level="info", + access_log=True, + workers=os.cpu_count() * 2, + ) + self.server = UvicornServer(config=self.config) + + # Setup signal handlers + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signum, frame): + """Handle shutdown signals""" + logger.info(f"Received signal {signum}") + asyncio.create_task(self.shutdown()) + + async def startup(self) -> None: + """Start the server""" + try: + logger.info(f"Starting API server on http://{self.host}:{self.port}") + print(f"Starting API server on http://{self.host}:{self.port}") + await self.server.serve() + except Exception as e: + logger.error(f"Failed to start server: {str(e)}") + raise + + async def shutdown(self) -> None: + """Shutdown the server""" + try: + logger.info("Initiating graceful shutdown...") + await self.server.shutdown() + except Exception as e: + logger.error(f"Error during shutdown: {str(e)}") + raise + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator: + """Lifespan context manager for the FastAPI app""" + # Startup + logger.info("Starting up API server...") + yield + # Shutdown + logger.info("Shutting down API server...") + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application""" + logger.info("Creating FastAPI application") + api = SwarmsAPI() + app = api.app + + # Add lifespan handling + app.router.lifespan_context = lifespan + + logger.info("FastAPI application created successfully") + return app + + +def run_server(): + """Run the API server""" + try: + # Create the FastAPI app + app = create_app() + + # Create and run the server + server = APIServer(app) + asyncio.run(server.startup()) + except Exception as e: + logger.error(f"Failed to start API: {str(e)}") + print(f"Error starting server: {str(e)}") + + +if __name__ == "__main__": + run_server() diff --git a/api/skypilot.yaml b/api/skypilot.yaml deleted file mode 100644 index d0c558e..0000000 --- a/api/skypilot.yaml +++ /dev/null @@ -1,43 +0,0 @@ -name: agentapi - -service: - readiness_probe: - path: /docs - initial_delay_seconds: 300 - timeout_seconds: 30 - - replica_policy: - min_replicas: 1 - max_replicas: 50 - target_qps_per_replica: 5 - upscale_delay_seconds: 180 - downscale_delay_seconds: 600 - - -envs: - WORKSPACE_DIR: "agent_workspace" - OPENAI_API_KEY: "" - -resources: - ports: 8000 # FastAPI default port - cpus: 16 - memory: 64 - disk_size: 50 - use_spot: true - -workdir: . - -setup: | - git clone https://github.com/kyegomez/swarms.git - cd swarms/api - pip install -r requirements.txt - pip install swarms - -run: | - uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4 - -# env: -# PYTHONPATH: /app/swarms -# LOG_LEVEL: "INFO" -# # MAX_WORKERS: "4" - diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..1d24a8a --- /dev/null +++ b/tests.py @@ -0,0 +1,109 @@ +import requests +import json +from datetime import datetime + +# Base URL for the API +BASE_URL = "http://localhost:8080" + +# Test health check endpoint +response = requests.get(f"{BASE_URL}/health") +print("Health Check Response:", response.json()) + +# Test creating a new agent +agent_config = { + "agent_name": "test_agent", + "model_name": "gpt-3.5-turbo", + "description": "A test agent", + "system_prompt": "You are a helpful assistant", + "temperature": 0.5, + "max_loops": 1, + "dynamic_temperature_enabled": True, + "user_name": "test_user", + "retry_attempts": 1, + "context_length": 200000, + "output_type": "string", + "streaming_on": False, + "tags": ["test"], + "stopping_token": "", + "auto_generate_prompt": False +} + +response = requests.post(f"{BASE_URL}/v1/agent", json=agent_config) +print("\nCreate Agent Response:", response.json()) +agent_id = response.json()["agent_id"] + +# Test getting rate limit status +response = requests.get(f"{BASE_URL}/v1/rate-limit-status") +print("\nRate Limit Status:", response.json()) + +# Test listing all agents +response = requests.get(f"{BASE_URL}/v1/agents") +print("\nList Agents Response:", response.json()) + +# Test updating an agent +update_data = { + "description": "Updated test agent", + "system_prompt": "Updated system prompt", + "temperature": 0.7, + "max_loops": 2, + "tags": ["test", "updated"] +} + +response = requests.patch(f"{BASE_URL}/v1/agent/{agent_id}", json=update_data) +print("\nUpdate Agent Response:", response.json()) + +# Test getting agent metrics +response = requests.get(f"{BASE_URL}/v1/agent/{agent_id}/metrics") +print("\nAgent Metrics Response:", response.json()) + +# Test creating a completion +completion_request = { + "prompt": "Hello, how are you?", + "agent_id": agent_id, + "max_tokens": 100, + "temperature_override": 0.8, + "stream": False +} + +response = requests.post(f"{BASE_URL}/v1/agent/completions", json=completion_request) +print("\nCompletion Response:", response.json()) + +# Test cloning an agent +response = requests.post(f"{BASE_URL}/v1/agent/{agent_id}/clone", params={"new_name": "cloned_agent"}) +print("\nClone Agent Response:", response.json()) + +# Test getting agent status +response = requests.get(f"{BASE_URL}/v1/agent/{agent_id}/status") +print("\nAgent Status Response:", response.json()) + +# Test batch completion status +batch_requests = [ + { + "prompt": "What is 2+2?", + "agent_id": agent_id, + "max_tokens": 50, + "temperature_override": 0.5, + "stream": False + }, + { + "prompt": "Who are you?", + "agent_id": agent_id, + "max_tokens": 50, + "temperature_override": 0.5, + "stream": False + } +] + +response = requests.get(f"{BASE_URL}/v1/agent/batch/completions/status", json=batch_requests) +print("\nBatch Completion Status Response:", response.json()) + +# Test deleting an agent +response = requests.delete(f"{BASE_URL}/v1/agent/{agent_id}") +print("\nDelete Agent Response:", response.json()) + +# Test listing agents with filters +response = requests.get(f"{BASE_URL}/v1/agents", params={ + "tags": ["test"], + "status": "idle" +}) +print("\nFiltered Agents List Response:", response.json()) \ No newline at end of file