Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate from flask to fast api for uvicorn and asgi support #75

Merged
merged 13 commits into from
Sep 26, 2024
Merged
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ COPY . .
EXPOSE 8000

# This is our start command; yours might be different.
# The guardrails-api is a standard Flask application.
# You can use whatever production server you want that support Flask.
# The guardrails-api is a standard FastAPI application.
# You can use whatever production server you want that support FastAPI.
# Here we use gunicorn
CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")'
6 changes: 3 additions & 3 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ services:
profiles: ["all", "db", "infra"]
image: ankane/pgvector
environment:
POSTGRES_USER: ${PGUSER:-postgres}
POSTGRES_PASSWORD: ${PGPASSWORD:-changeme}
POSTGRES_USER: admin
POSTGRES_PASSWORD: admin
POSTGRES_DATA: /data/postgres
volumes:
- ./postgres:/data/postgres
Expand All @@ -21,7 +21,7 @@ services:
- "8088:80"
environment:
PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com"
PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme}
PGADMIN_DEFAULT_PASSWORD: admin
PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json
# FIXME: Copy over server.json file and create passfile
volumes:
Expand Down
23 changes: 16 additions & 7 deletions guardrails_api/api/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,22 @@ async def validate_streamer(guard_iter):
validate_streamer(guard_streamer()), media_type="application/json"
)
else:
result: ValidationOutcome = guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
*args,
**payload,
)
if inspect.iscoroutinefunction(guard):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super pedantic, but any reason to not call it once and inspect on the result? albeit, I think the type checker gets a little more angry with that method

result: ValidationOutcome = await guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
*args,
**payload,
)
else:
result: ValidationOutcome = guard(
llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
*args,
**payload,
)

serialized_history = [call.to_dict() for call in guard.history]
cache_key = f"{guard.name}-{result.call_id}"
Expand Down
37 changes: 36 additions & 1 deletion guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,46 @@
trace_server_start_if_enabled,
)
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry import trace, context, baggage

from rich.console import Console
from rich.rule import Rule
from typing import Optional
import importlib.util
import json
import os

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

class RequestInfoMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
tracer = trace.get_tracer(__name__)
# Get the current context and attach it to this task
with tracer.start_as_current_span("request_info") as span:
client_ip = request.client.host
user_agent = request.headers.get("user-agent", "unknown")
referrer = request.headers.get("referrer", "unknown")
user_id = request.headers.get("x-user-id", "unknown")
organization = request.headers.get("x-organization", "unknown")
app = request.headers.get("x-app", "unknown")

context.attach(baggage.set_baggage("client.ip", client_ip))
context.attach(baggage.set_baggage("http.user_agent", user_agent))
context.attach(baggage.set_baggage("http.referrer", referrer))
context.attach(baggage.set_baggage("user.id", user_id))
context.attach(baggage.set_baggage("organization", organization))
context.attach(baggage.set_baggage("app", app))

span.set_attribute("client.ip", client_ip)
span.set_attribute("http.user_agent", user_agent)
span.set_attribute("http.referrer", referrer)
span.set_attribute("user.id", user_id)
span.set_attribute("organization", organization)
span.set_attribute("app", app)

response = await call_next(request)
return response

# Custom JSON encoder
class CustomJSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -64,10 +97,12 @@ def create_app(

app = FastAPI(openapi_url="")

# Add the custom middleware
app.add_middleware(RequestInfoMiddleware)

# Initialize FastAPIInstrumentor
FastAPIInstrumentor.instrument_app(app)

# app.add_middleware(ProfilingMiddleware)

# Add CORS middleware
app.add_middleware(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"]
requires-python = ">= 3.8.1"
dependencies = [
"guardrails-ai>=0.5.6",
"guardrails-ai>=0.5.10",
"Werkzeug>=3.0.3,<4",
"jsonschema>=4.22.0,<5",
"referencing>=0.35.1,<1",
Expand All @@ -22,7 +22,7 @@ dependencies = [
"opentelemetry-sdk>=1.0.0,<2",
"opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2",
"opentelemetry-exporter-otlp-proto-http>=1.0.0,<2",
"opentelemetry-instrumentation-fastapi>=0.47b0",
"opentelemetry-instrumentation-fastapi>=0.48b0",
"requests>=2.32.3",
"aiocache>=0.11.1",
"fastapi>=0.114.1",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"opentelemetry-sdk>=1.0.0,<2",
"opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2",
"opentelemetry-exporter-otlp-proto-http>=1.0.0,<2",
"opentelemetry-instrumentation-fastapi>=0.47b0",
"opentelemetry-instrumentation-fastapi>=0.48b0",
"requests>=2.32.3",
"aiocache>=0.11.1",
"fastapi>=0.114.1",
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def client(app):
def test_home(client):
response = client.get("/")
assert response.status_code == 200
assert response.json() == "Hello, FastAPI!"
assert response.json() == "Hello, world!"

# Check if all expected routes are registered
routes = [route.path for route in client.app.routes]
Expand Down