lint
dtam committed Sep 13, 2024
1 parent 77e9bce commit 78fde5c
Showing 15 changed files with 222 additions and 98 deletions.
94 changes: 72 additions & 22 deletions guardrails_api/api/guards.py
@@ -1,11 +1,9 @@
-import asyncio
import json
import os
import inspect
-from typing import Any, Dict, List, Optional
-from fastapi import FastAPI, HTTPException, Request, Response, APIRouter
+from typing import Any, Dict, Optional
+from fastapi import HTTPException, Request, APIRouter
from fastapi.responses import JSONResponse, StreamingResponse
-from pydantic import BaseModel
from urllib.parse import unquote_plus
from guardrails import AsyncGuard, Guard
from guardrails.classes import ValidationOutcome
@@ -16,7 +14,10 @@
from guardrails_api.clients.pg_guard_client import PGGuardClient
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.utils.get_llm_callable import get_llm_callable
-from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response
+from guardrails_api.utils.openai import (
+    outcome_to_chat_completion,
+    outcome_to_stream_response,
+)
from guardrails_api.utils.handle_error import handle_error
from string import Template

@@ -39,59 +40,87 @@

router = APIRouter()


@router.get("/guards")
@handle_error
async def get_guards():
guards = guard_client.get_guards()
return [g.to_dict() for g in guards]


@router.post("/guards")
@handle_error
async def create_guard(guard: GuardStruct):
if not postgres_is_enabled():
-        raise HTTPException(status_code=501, detail="Not Implemented POST /guards is not implemented for in-memory guards.")
+        raise HTTPException(
+            status_code=501,
+            detail="Not Implemented POST /guards is not implemented for in-memory guards.",
+        )
new_guard = guard_client.create_guard(guard)
return new_guard.to_dict()


@router.get("/guards/{guard_name}")
@handle_error
async def get_guard(guard_name: str, asOf: Optional[str] = None):
decoded_guard_name = unquote_plus(guard_name)
guard = guard_client.get_guard(decoded_guard_name, asOf)
if guard is None:
-        raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!")
+        raise HTTPException(
+            status_code=404,
+            detail=f"A Guard with the name {decoded_guard_name} does not exist!",
+        )
return guard.to_dict()


@router.put("/guards/{guard_name}")
@handle_error
async def update_guard(guard_name: str, guard: GuardStruct):
if not postgres_is_enabled():
-        raise HTTPException(status_code=501, detail="PUT /<guard_name> is not implemented for in-memory guards.")
+        raise HTTPException(
+            status_code=501,
+            detail="PUT /<guard_name> is not implemented for in-memory guards.",
+        )
decoded_guard_name = unquote_plus(guard_name)
updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
return updated_guard.to_dict()


@router.delete("/guards/{guard_name}")
@handle_error
async def delete_guard(guard_name: str):
if not postgres_is_enabled():
-        raise HTTPException(status_code=501, detail="DELETE /<guard_name> is not implemented for in-memory guards.")
+        raise HTTPException(
+            status_code=501,
+            detail="DELETE /<guard_name> is not implemented for in-memory guards.",
+        )
decoded_guard_name = unquote_plus(guard_name)
guard = guard_client.delete_guard(decoded_guard_name)
return guard.to_dict()


@router.post("/guards/{guard_name}/openai/v1/chat/completions")
@handle_error
async def openai_v1_chat_completions(guard_name: str, request: Request):
payload = await request.json()
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)
if guard_struct is None:
-        raise HTTPException(status_code=404, detail=f"A Guard with the name {decoded_guard_name} does not exist!")
+        raise HTTPException(
+            status_code=404,
+            detail=f"A Guard with the name {decoded_guard_name} does not exist!",
+        )

-    guard = Guard.from_dict(guard_struct.to_dict()) if not isinstance(guard_struct, Guard) else guard_struct
+    guard = (
+        Guard.from_dict(guard_struct.to_dict())
+        if not isinstance(guard_struct, Guard)
+        else guard_struct
+    )
stream = payload.get("stream", False)
-    has_tool_gd_tool_call = any(tool.get("function", {}).get("name") == "gd_response_tool" for tool in payload.get("tools", []))
+    has_tool_gd_tool_call = any(
+        tool.get("function", {}).get("name") == "gd_response_tool"
+        for tool in payload.get("tools", [])
+    )

if not stream:
validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
@@ -103,23 +132,29 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
)
return JSONResponse(content=result)
else:

async def openai_streamer():
guard_stream = guard(num_reasks=0, **payload)
for result in guard_stream:
-                chunk = json.dumps(outcome_to_stream_response(validation_outcome=result))
+                chunk = json.dumps(
+                    outcome_to_stream_response(validation_outcome=result)
+                )
yield f"data: {chunk}\n\n"
yield "\n"

return StreamingResponse(openai_streamer(), media_type="text/event-stream")


@router.post("/guards/{guard_name}/validate")
@handle_error
async def validate(guard_name: str, request: Request):
payload = await request.json()
-    openai_api_key = request.headers.get("x-openai-api-key", os.environ.get("OPENAI_API_KEY"))
+    openai_api_key = request.headers.get(
+        "x-openai-api-key", os.environ.get("OPENAI_API_KEY")
+    )
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)

llm_output = payload.pop("llmOutput", None)
num_reasks = payload.pop("numReasks", None)
prompt_params = payload.pop("promptParams", {})
@@ -132,25 +167,33 @@ async def validate(guard_name: str, request: Request):
if llm_api is not None:
llm_api = get_llm_callable(llm_api)
if openai_api_key is None:
-            raise HTTPException(status_code=400, detail="Cannot perform calls to OpenAI without an api key.")
+            raise HTTPException(
+                status_code=400,
+                detail="Cannot perform calls to OpenAI without an api key.",
+            )

guard = guard_struct
is_async = inspect.iscoroutinefunction(llm_api)

if not isinstance(guard_struct, Guard):
if is_async:
guard = AsyncGuard.from_dict(guard_struct.to_dict())
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
-        guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict())
+        guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

if llm_api is None and num_reasks and num_reasks > 1:
-        raise HTTPException(status_code=400, detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).")
+        raise HTTPException(
+            status_code=400,
+            detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).",
+        )

if llm_output is not None:
if stream:
-            raise HTTPException(status_code=400, detail="Streaming is not supported for parse calls!")
+            raise HTTPException(
+                status_code=400, detail="Streaming is not supported for parse calls!"
+            )
result: ValidationOutcome = guard.parse(
llm_output=llm_output,
num_reasks=num_reasks,
@@ -160,6 +203,7 @@ async def validate(guard_name: str, request: Request):
)
else:
if stream:

async def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
@@ -170,7 +214,9 @@ async def guard_streamer():
**payload,
)
for result in guard_stream:
-                    validation_output = ValidationOutcome.from_guard_history(guard.history.last)
+                    validation_output = ValidationOutcome.from_guard_history(
+                        guard.history.last
+                    )
yield validation_output, result

async def validate_streamer(guard_iter):
@@ -201,7 +247,9 @@ async def validate_streamer(guard_iter):
cache_key = f"{guard.name}-{final_validation_output.call_id}"
await cache_client.set(cache_key, serialized_history, 300)

-            return StreamingResponse(validate_streamer(guard_streamer()), media_type="application/json")
+            return StreamingResponse(
+                validate_streamer(guard_streamer()), media_type="application/json"
+            )
else:
result: ValidationOutcome = guard(
llm_api=llm_api,
Expand All @@ -216,12 +264,14 @@ async def validate_streamer(guard_iter):
# await cache_client.set(cache_key, serialized_history, 300)
return result.to_dict()


@router.get("/guards/{guard_name}/history/{call_id}")
@handle_error
async def guard_history(guard_name: str, call_id: str):
cache_key = f"{guard_name}-{call_id}"
return await cache_client.get(cache_key)


def collect_telemetry(
*,
guard: Guard,
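The reformatting above does not change behavior, so the routes can be exercised exactly as before. As a quick reference, a minimal sketch of calling the OpenAI-compatible endpoint with the standard openai client; the host (http://localhost:8000) and guard name (my-guard) are placeholders, and how the api key reaches the upstream LLM is not shown in this diff:

# Minimal sketch: point the OpenAI client at a guard's OpenAI-compatible
# endpoint; the client appends /chat/completions, matching the route
# POST /guards/{guard_name}/openai/v1/chat/completions above.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/guards/my-guard/openai/v1",  # placeholder host/guard
    api_key="sk-...",  # placeholder; upstream key handling is not shown in this diff
)
response = client.chat.completions.create(
    model="gpt-4o-mini",  # assumption: any model the configured LLM accepts
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)

The validate route can be driven the same way over plain HTTP; the camelCase keys mirror the fields the handler pops from the payload:

# Minimal sketch of POST /guards/{guard_name}/validate with an llmOutput
# payload, which maps to the guard.parse(...) branch in the handler above.
import requests

resp = requests.post(
    "http://localhost:8000/guards/my-guard/validate",  # placeholder host/guard
    json={"llmOutput": "some text to validate", "numReasks": 0},
    timeout=30,
)
print(resp.json())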
12 changes: 8 additions & 4 deletions guardrails_api/api/root.py
@@ -1,7 +1,5 @@
import os
-import json
from string import Template
-from typing import Dict

from fastapi import HTTPException, APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
@@ -13,16 +11,20 @@
from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled
from guardrails_api.utils.logger import logger


class HealthCheckResponse(BaseModel):
status: int
message: str


router = APIRouter()


@router.get("/")
async def home():
return "Hello, FastAPI!"


@router.get("/health-check", response_model=HealthCheckResponse)
async def health_check():
try:
@@ -32,19 +34,21 @@ async def health_check():
pg_client = PostgresClient()
query = text("SELECT count(datid) FROM pg_stat_activity;")
response = pg_client.db.session.execute(query).all()

logger.info("response: %s", response)

return HealthCheck(200, "Ok").to_dict()
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")


@router.get("/api-docs", response_class=JSONResponse)
async def api_docs():
api_spec = get_open_api_spec()
return JSONResponse(content=api_spec)


@router.get("/docs", response_class=HTMLResponse)
async def docs():
host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000")
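For completeness, the health-check round trip reformatted above can be verified with a short script; the host is a placeholder and the requests package is assumed:

# Minimal sketch: GET /health-check runs SELECT count(datid) FROM
# pg_stat_activity when Postgres is enabled and returns HealthCheck(200, "Ok").
import requests

resp = requests.get("http://localhost:8000/health-check", timeout=5)  # placeholder host
resp.raise_for_status()
print(resp.json())  # expected shape: {"status": 200, "message": "Ok"}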
25 changes: 18 additions & 7 deletions guardrails_api/app.py
@@ -3,10 +3,11 @@
from fastapi.responses import JSONResponse
from guardrails import configure_logging
-from guardrails_api.clients.cache_client import CacheClient
from guardrails_api.clients.cache_client import CacheClient
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.otel import otel_is_disabled, initialize
-from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled
+from guardrails_api.utils.trace_server_start_if_enabled import (
+    trace_server_start_if_enabled,
+)
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from rich.console import Console
from rich.rule import Rule
@@ -68,6 +69,7 @@
# else:
# return await call_next(request)


# Custom JSON encoder
class CustomJSONEncoder(json.JSONEncoder):
def default(self, o):
@@ -77,6 +79,7 @@ def default(self, o):
return str(o)
return super().default(o)


# Custom middleware for reverse proxy
class ReverseProxyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
@@ -86,6 +89,7 @@ async def dispatch(self, request: Request, call_next):
response = await call_next(request)
return response


def register_config(config: Optional[str] = None):
default_config_file = os.path.join(os.getcwd(), "./config.py")
config_file = config or default_config_file
@@ -95,6 +99,7 @@ def register_config(config: Optional[str] = None):
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)


def create_app(
env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
):
@@ -126,7 +131,7 @@ def create_app(
FastAPIInstrumentor.instrument_app(app)

# app.add_middleware(ProfilingMiddleware)

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
@@ -169,25 +174,31 @@ async def value_error_handler(request: Request, exc: ValueError):
content={"message": str(exc)},
)

-    console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}")
-    console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n")
+    console.print(
+        f"\n:rocket: Guardrails API is available at {self_endpoint}"
+        f":book: Visit {self_endpoint}/docs to see available API endpoints.\n"
+    )

console.print(":green_circle: Active guards and OpenAI compatible endpoints:")

guards = guard_client.get_guards()

for g in guards:
g_dict = g.to_dict()
-        console.print(f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1")
+        console.print(
+            f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1"
+        )

console.print("")
-    console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white"))
+    console.print(
+        Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")
+    )

return app


if __name__ == "__main__":
import uvicorn

app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000)
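Since create_app is the same entry point the __main__ block uses, a custom launcher only needs to pass its own arguments. A minimal sketch, where my_config.py is a placeholder (per register_config above, the app falls back to ./config.py in the working directory when no config is given):

# Minimal sketch of embedding the app in a custom entry point.
import uvicorn

from guardrails_api.app import create_app

app = create_app(config="my_config.py", port=8000)  # placeholder config path
uvicorn.run(app, host="0.0.0.0", port=8000)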
1 change: 1 addition & 0 deletions guardrails_api/cli/start.py
@@ -4,6 +4,7 @@
from guardrails_api.app import create_app
from guardrails_api.utils.configuration import valid_configuration


@cli.command("start")
def start(
env: Optional[str] = typer.Option(