Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed Jan 16, 2025
1 parent 679a1aa commit 90b4f39
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ ENV PYTHONUNBUFFERED=1 \
PATH="/app/venv/bin:$PATH" \
PYTHONPATH=/app \
PORT=8080 \
OPENAI_API_KEY="" \
OPENAI_API_KEY="put your key or any keys, xai, deepseek, etc" \
WORKSPACE_DIR="agent_workspace"

# Set working directory
Expand Down
135 changes: 91 additions & 44 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,32 @@

load_dotenv()


@dataclass
class RateLimitWindow:
requests: int
start_time: float


class RateLimiter:
def __init__(self):
# Store rate limit windows per IP
self.ip_windows: Dict[str, Dict[str, RateLimitWindow]] = defaultdict(dict)

# Configure different rate limits
self.limits = {
'per_second': 2, # 2 requests per second
'per_minute': 30, # 30 requests per minute
'per_hour': 1000, # 1000 requests per hour
'per_day': 10000, # 10000 requests per day
"per_second": 2, # 2 requests per second
"per_minute": 30, # 30 requests per minute
"per_hour": 1000, # 1000 requests per hour
"per_day": 10000, # 10000 requests per day
}

# Window durations in seconds
self.windows = {
'per_second': 1,
'per_minute': 60,
'per_hour': 3600,
'per_day': 86400,
"per_second": 1,
"per_minute": 60,
"per_hour": 3600,
"per_day": 86400,
}

def _clean_old_windows(self, ip: str):
Expand All @@ -68,12 +70,14 @@ def _get_window(self, ip: str, window_type: str) -> RateLimitWindow:
"""Get or create a rate limit window"""
current_time = time.time()
window = self.ip_windows[ip].get(window_type)

# If window doesn't exist or is expired, create new one
if not window or (current_time - window.start_time >= self.windows[window_type]):
if not window or (
current_time - window.start_time >= self.windows[window_type]
):
window = RateLimitWindow(requests=0, start_time=current_time)
self.ip_windows[ip][window_type] = window

return window

async def check_rate_limit(self, request: Request) -> None:
Expand All @@ -88,23 +92,27 @@ async def check_rate_limit(self, request: Request) -> None:
# Check each window type
for window_type, limit in self.limits.items():
window = self._get_window(ip, window_type)

if window.requests >= limit:
retry_after = self.windows[window_type] - (time.time() - window.start_time)
retry_after = self.windows[window_type] - (
time.time() - window.start_time
)
raise HTTPException(
status_code=429,
detail={
"error": "Rate limit exceeded",
"window": window_type,
"limit": limit,
"retry_after": int(retry_after)
"retry_after": int(retry_after),
},
headers={
"Retry-After": str(int(retry_after)),
"X-RateLimit-Limit": str(limit),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(window.start_time + self.windows[window_type]))
}
"X-RateLimit-Reset": str(
int(window.start_time + self.windows[window_type])
),
},
)

window.requests += 1
Expand All @@ -113,20 +121,21 @@ def get_rate_limit_status(self, ip: str) -> Dict[str, Dict[str, int]]:
"""Get current rate limit status for an IP"""
self._clean_old_windows(ip)
status = {}

for window_type, limit in self.limits.items():
window = self._get_window(ip, window_type)
remaining = max(0, limit - window.requests)
reset_time = int(window.start_time + self.windows[window_type])

status[window_type] = {
"limit": limit,
"remaining": remaining,
"reset": reset_time
"reset": reset_time,
}

return status


class UvicornServer(uvicorn.Server):
"""Customized uvicorn server with graceful shutdown support"""

Expand All @@ -139,42 +148,66 @@ async def shutdown(self, sockets=None):
logger.info("Shutting down server...")
await super().shutdown(sockets)


class AgentStatus(str, Enum):
"""Enum for agent status."""

IDLE = "idle"
PROCESSING = "processing"
ERROR = "error"
MAINTENANCE = "maintenance"


class AgentConfig(BaseModel):
"""Configuration model for creating a new agent."""

agent_name: str = Field(..., description="Name of the agent")
model_name: str = Field(..., description="Name of the llm you want to use provided by litellm")
description: str = Field(default="", description="Description of the agent's purpose")
model_name: str = Field(
..., description="Name of the llm you want to use provided by litellm"
)
description: str = Field(
default="", description="Description of the agent's purpose"
)
system_prompt: str = Field(..., description="System prompt for the agent")
temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Temperature for the model")
temperature: float = Field(
default=0.1, ge=0.0, le=2.0, description="Temperature for the model"
)
max_loops: int = Field(default=1, ge=1, description="Maximum number of loops")
dynamic_temperature_enabled: bool = Field(default=True, description="Enable dynamic temperature")
dynamic_temperature_enabled: bool = Field(
default=True, description="Enable dynamic temperature"
)
user_name: str = Field(default="default_user", description="Username for the agent")
retry_attempts: int = Field(default=1, ge=1, description="Number of retry attempts")
context_length: int = Field(default=200000, ge=1000, description="Context length")
output_type: str = Field(default="string", description="Output type (string or json)")
output_type: str = Field(
default="string", description="Output type (string or json)"
)
streaming_on: bool = Field(default=False, description="Enable streaming")
tags: List[str] = Field(default_factory=list, description="Tags for categorizing the agent")
stopping_token: str = Field(default="<DONE>", description="Stopping token for the agent")
auto_generate_prompt: bool = Field(default=False, description="Auto-generate prompt based on agent details")
tags: List[str] = Field(
default_factory=list, description="Tags for categorizing the agent"
)
stopping_token: str = Field(
default="<DONE>", description="Stopping token for the agent"
)
auto_generate_prompt: bool = Field(
default=False, description="Auto-generate prompt based on agent details"
)


class AgentUpdate(BaseModel):
"""Model for updating agent configuration."""

description: Optional[str] = None
system_prompt: Optional[str] = None
temperature: Optional[float] = 0.5
max_loops: Optional[int] = 1
tags: Optional[List[str]] = None
status: Optional[AgentStatus] = None


class AgentSummary(BaseModel):
"""Summary model for agent listing."""

agent_id: UUID
agent_name: str
description: str
Expand All @@ -185,8 +218,10 @@ class AgentSummary(BaseModel):
tags: List[str]
status: AgentStatus


class AgentMetrics(BaseModel):
"""Model for agent performance metrics."""

total_completions: int
average_response_time: float
error_rate: float
Expand All @@ -196,25 +231,31 @@ class AgentMetrics(BaseModel):
success_rate: float
peak_tokens_per_minute: int


class CompletionRequest(BaseModel):
"""Model for completion requests."""

prompt: str = Field(..., description="The prompt to process")
agent_id: UUID = Field(..., description="ID of the agent to use")
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
temperature_override: Optional[float] = 0.5
stream: bool = Field(default=False, description="Enable streaming response")


class CompletionResponse(BaseModel):
"""Model for completion responses."""

agent_id: UUID
response: str
metadata: Dict[str, Any]
timestamp: datetime
processing_time: float
token_usage: Dict[str, int]


class AgentStore:
"""Store for managing agents."""

def __init__(self):
self.agents: Dict[UUID, Agent] = {}
self.agent_metadata: Dict[UUID, Dict[str, Any]] = {}
Expand Down Expand Up @@ -462,6 +503,7 @@ async def process_completion(
finally:
metadata["status"] = AgentStatus.IDLE


class StoreManager:
_instance = None

Expand All @@ -471,10 +513,12 @@ def get_instance(cls) -> AgentStore:
cls._instance = AgentStore()
return cls._instance


def get_store() -> AgentStore:
"""Dependency to get the AgentStore instance."""
return StoreManager.get_instance()


class SwarmsAPI:
def __init__(self):
self.app = FastAPI(
Expand All @@ -499,25 +543,31 @@ def __init__(self):

def _setup_routes(self):
"""Set up API routes."""

def _setup_middleware(self):
@self.app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
# Apply rate limiting
await self.rate_limiter.check_rate_limit(request)

# Add rate limit headers to response
response = await call_next(request)

# Add rate limit status headers
status = self.rate_limiter.get_rate_limit_status(request.client.host)
for window_type, info in status.items():
response.headers[f"X-RateLimit-{window_type}-Limit"] = str(info["limit"])
response.headers[f"X-RateLimit-{window_type}-Remaining"] = str(info["remaining"])
response.headers[f"X-RateLimit-{window_type}-Reset"] = str(info["reset"])

response.headers[f"X-RateLimit-{window_type}-Limit"] = str(
info["limit"]
)
response.headers[f"X-RateLimit-{window_type}-Remaining"] = str(
info["remaining"]
)
response.headers[f"X-RateLimit-{window_type}-Reset"] = str(
info["reset"]
)

return response

@self.app.get("/v1/rate-limit-status")
async def get_rate_limit_status(request: Request):
"""Get current rate limit status"""
Expand Down Expand Up @@ -590,7 +640,7 @@ async def create_completion(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing completion: {str(e)}",
)

@self.app.get("/v1/agent/batch/completions/status")
async def get_batch_completion_status(requests: List[CompletionRequest]):
"""Get status of batch completion requests."""
Expand Down Expand Up @@ -654,9 +704,7 @@ async def _cleanup_old_metrics(self, agent_id: UUID):
# Clean up old tokens per minute data
if "tokens_per_minute" in metadata:
metadata["tokens_per_minute"] = {
k: v
for k, v in metadata["tokens_per_minute"].items()
if k > cutoff
k: v for k, v in metadata["tokens_per_minute"].items() if k > cutoff
}


Expand All @@ -671,7 +719,6 @@ def __init__(self, app: FastAPI, host: str = "0.0.0.0", port: int = 8080):
port=port,
log_level="info",
access_log=True,
workers=os.cpu_count() * 2,
)
self.server = UvicornServer(config=self.config)

Expand Down Expand Up @@ -742,4 +789,4 @@ def run_server():


if __name__ == "__main__":
run_server()
run_server()
Loading

0 comments on commit 90b4f39

Please sign in to comment.