diff --git a/Dockerfile b/Dockerfile index ca847f0..d01a5ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,7 +45,7 @@ COPY . . EXPOSE 8000 # This is our start command; yours might be different. -# The guardrails-api is a standard Flask application. -# You can use whatever production server you want that support Flask. +# The guardrails-api is a standard FastAPI application. +# You can use whatever production server you want that support FastAPI. # Here we use gunicorn CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")' \ No newline at end of file diff --git a/guardrails_api/blueprints/__init__.py b/guardrails_api/api/__init__.py similarity index 100% rename from guardrails_api/blueprints/__init__.py rename to guardrails_api/api/__init__.py diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py new file mode 100644 index 0000000..afdd206 --- /dev/null +++ b/guardrails_api/api/guards.py @@ -0,0 +1,324 @@ +import json +import os +import inspect +from typing import Any, Dict, Optional +from fastapi import HTTPException, Request, APIRouter +from fastapi.responses import JSONResponse, StreamingResponse +from urllib.parse import unquote_plus +from guardrails import AsyncGuard, Guard +from guardrails.classes import ValidationOutcome +from opentelemetry.trace import Span +from guardrails_api_client import Guard as GuardStruct +from guardrails_api.clients.cache_client import CacheClient +from guardrails_api.clients.memory_guard_client import MemoryGuardClient +from guardrails_api.clients.pg_guard_client import PGGuardClient +from guardrails_api.clients.postgres_client import postgres_is_enabled +from guardrails_api.utils.get_llm_callable import get_llm_callable +from guardrails_api.utils.openai import ( + outcome_to_chat_completion, + outcome_to_stream_response, +) +from guardrails_api.utils.handle_error import handle_error +from string import Template + +# if no pg_host is set, use in memory guards +if postgres_is_enabled(): + guard_client = PGGuardClient() +else: + guard_client = MemoryGuardClient() + # Will be defined at runtime + import config # noqa + + exports = config.__dir__() + for export_name in exports: + export = getattr(config, export_name) + is_guard = isinstance(export, Guard) + if is_guard: + guard_client.create_guard(export) + +cache_client = CacheClient() + +cache_client.initialize() + +router = APIRouter() + + +@router.get("/guards") +@handle_error +async def get_guards(): + guards = guard_client.get_guards() + return [g.to_dict() for g in guards] + + +@router.post("/guards") +@handle_error +async def create_guard(guard: GuardStruct): + if not postgres_is_enabled(): + raise HTTPException( + status_code=501, + detail="Not Implemented POST /guards is not implemented for in-memory guards.", + ) + new_guard = guard_client.create_guard(guard) + return new_guard.to_dict() + + +@router.get("/guards/{guard_name}") +@handle_error +async def get_guard(guard_name: str, asOf: Optional[str] = None): + decoded_guard_name = unquote_plus(guard_name) + guard = guard_client.get_guard(decoded_guard_name, asOf) + if guard is None: + raise HTTPException( + status_code=404, + detail=f"A Guard with the name {decoded_guard_name} does not exist!", + ) + return guard.to_dict() + + +@router.put("/guards/{guard_name}") +@handle_error +async def update_guard(guard_name: str, guard: GuardStruct): + if not postgres_is_enabled(): + raise HTTPException( + status_code=501, + detail="PUT / is not implemented for in-memory guards.", + ) + decoded_guard_name = unquote_plus(guard_name) + updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) + return updated_guard.to_dict() + + +@router.delete("/guards/{guard_name}") +@handle_error +async def delete_guard(guard_name: str): + if not postgres_is_enabled(): + raise HTTPException( + status_code=501, + detail="DELETE / is not implemented for in-memory guards.", + ) + decoded_guard_name = unquote_plus(guard_name) + guard = guard_client.delete_guard(decoded_guard_name) + return guard.to_dict() + + +@router.post("/guards/{guard_name}/openai/v1/chat/completions") +@handle_error +async def openai_v1_chat_completions(guard_name: str, request: Request): + payload = await request.json() + decoded_guard_name = unquote_plus(guard_name) + guard_struct = guard_client.get_guard(decoded_guard_name) + if guard_struct is None: + raise HTTPException( + status_code=404, + detail=f"A Guard with the name {decoded_guard_name} does not exist!", + ) + + guard = ( + Guard.from_dict(guard_struct.to_dict()) + if not isinstance(guard_struct, Guard) + else guard_struct + ) + stream = payload.get("stream", False) + has_tool_gd_tool_call = any( + tool.get("function", {}).get("name") == "gd_response_tool" + for tool in payload.get("tools", []) + ) + + if not stream: + validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload) + llm_response = guard.history.last.iterations.last.outputs.llm_response_info + result = outcome_to_chat_completion( + validation_outcome=validation_outcome, + llm_response=llm_response, + has_tool_gd_tool_call=has_tool_gd_tool_call, + ) + return JSONResponse(content=result) + else: + + async def openai_streamer(): + guard_stream = guard(num_reasks=0, **payload) + for result in guard_stream: + chunk = json.dumps( + outcome_to_stream_response(validation_outcome=result) + ) + yield f"data: {chunk}\n\n" + yield "\n" + + return StreamingResponse(openai_streamer(), media_type="text/event-stream") + + +@router.post("/guards/{guard_name}/validate") +@handle_error +async def validate(guard_name: str, request: Request): + payload = await request.json() + openai_api_key = request.headers.get( + "x-openai-api-key", os.environ.get("OPENAI_API_KEY") + ) + decoded_guard_name = unquote_plus(guard_name) + guard_struct = guard_client.get_guard(decoded_guard_name) + + llm_output = payload.pop("llmOutput", None) + num_reasks = payload.pop("numReasks", None) + prompt_params = payload.pop("promptParams", {}) + llm_api = payload.pop("llmApi", None) + args = payload.pop("args", []) + stream = payload.pop("stream", False) + + payload["api_key"] = payload.get("api_key", openai_api_key) + + if llm_api is not None: + llm_api = get_llm_callable(llm_api) + if openai_api_key is None: + raise HTTPException( + status_code=400, + detail="Cannot perform calls to OpenAI without an api key.", + ) + + guard = guard_struct + is_async = inspect.iscoroutinefunction(llm_api) + + if not isinstance(guard_struct, Guard): + if is_async: + guard = AsyncGuard.from_dict(guard_struct.to_dict()) + else: + guard: Guard = Guard.from_dict(guard_struct.to_dict()) + elif is_async: + guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict()) + + if llm_api is None and num_reasks and num_reasks > 1: + raise HTTPException( + status_code=400, + detail="Cannot perform re-asks without an LLM API. Specify llm_api when calling guard(...).", + ) + + if llm_output is not None: + if stream: + raise HTTPException( + status_code=400, detail="Streaming is not supported for parse calls!" + ) + result: ValidationOutcome = guard.parse( + llm_output=llm_output, + num_reasks=num_reasks, + prompt_params=prompt_params, + llm_api=llm_api, + **payload, + ) + else: + if stream: + + async def guard_streamer(): + guard_stream = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + stream=stream, + *args, + **payload, + ) + for result in guard_stream: + validation_output = ValidationOutcome.from_guard_history( + guard.history.last + ) + yield validation_output, result + + async def validate_streamer(guard_iter): + async for validation_output, result in guard_iter: + fragment_dict = result.to_dict() + fragment_dict["error_spans"] = [ + json.dumps({"start": x.start, "end": x.end, "reason": x.reason}) + for x in guard.error_spans_in_output() + ] + yield json.dumps(fragment_dict) + "\n" + + call = guard.history.last + final_validation_output = ValidationOutcome( + callId=call.id, + validation_passed=result.validation_passed, + validated_output=result.validated_output, + history=guard.history, + raw_llm_output=result.raw_llm_output, + ) + final_output_dict = final_validation_output.to_dict() + final_output_dict["error_spans"] = [ + json.dumps({"start": x.start, "end": x.end, "reason": x.reason}) + for x in guard.error_spans_in_output() + ] + yield json.dumps(final_output_dict) + "\n" + + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{final_validation_output.call_id}" + await cache_client.set(cache_key, serialized_history, 300) + + return StreamingResponse( + validate_streamer(guard_streamer()), media_type="application/json" + ) + else: + if inspect.iscoroutinefunction(guard): + result: ValidationOutcome = await guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + *args, + **payload, + ) + else: + result: ValidationOutcome = guard( + llm_api=llm_api, + prompt_params=prompt_params, + num_reasks=num_reasks, + *args, + **payload, + ) + + serialized_history = [call.to_dict() for call in guard.history] + cache_key = f"{guard.name}-{result.call_id}" + await cache_client.set(cache_key, serialized_history, 300) + return result.to_dict() + + +@router.get("/guards/{guard_name}/history/{call_id}") +@handle_error +async def guard_history(guard_name: str, call_id: str): + cache_key = f"{guard_name}-{call_id}" + return await cache_client.get(cache_key) + + +def collect_telemetry( + *, + guard: Guard, + validate_span: Span, + validation_output: ValidationOutcome, + prompt_params: Dict[str, Any], + result: ValidationOutcome, +): + # Below is all telemetry collection and + # should have no impact on what is returned to the user + prompt = guard.history.last.inputs.prompt + if prompt: + prompt = Template(prompt).safe_substitute(**prompt_params) + validate_span.set_attribute("prompt", prompt) + + instructions = guard.history.last.inputs.instructions + if instructions: + instructions = Template(instructions).safe_substitute(**prompt_params) + validate_span.set_attribute("instructions", instructions) + + validate_span.set_attribute("validation_status", guard.history.last.status) + validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) + + # Use the serialization from the class instead of re-writing it + valid_output: str = ( + json.dumps(validation_output.validated_output) + if isinstance(validation_output.validated_output, dict) + else str(validation_output.validated_output) + ) + validate_span.set_attribute("validated_output", valid_output) + + validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) + + num_of_reasks = ( + guard.history.last.iterations.length - 1 + if guard.history.last.iterations.length > 0 + else 0 + ) + validate_span.set_attribute("num_of_reasks", num_of_reasks) diff --git a/guardrails_api/api/root.py b/guardrails_api/api/root.py new file mode 100644 index 0000000..6a13a9b --- /dev/null +++ b/guardrails_api/api/root.py @@ -0,0 +1,78 @@ +import os +from string import Template + +from fastapi import HTTPException, APIRouter +from fastapi.responses import HTMLResponse, JSONResponse +from pydantic import BaseModel + +from guardrails_api.open_api_spec import get_open_api_spec +from sqlalchemy import text +from guardrails_api.classes.health_check import HealthCheck +from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled +from guardrails_api.utils.logger import logger + + +class HealthCheckResponse(BaseModel): + status: int + message: str + + +router = APIRouter() + + +@router.get("/") +async def home(): + return "Hello, world!" + + +@router.get("/health-check", response_model=HealthCheckResponse) +async def health_check(): + try: + if not postgres_is_enabled(): + return HealthCheck(200, "Ok").to_dict() + + pg_client = PostgresClient() + query = text("SELECT count(datid) FROM pg_stat_activity;") + response = pg_client.db.session.execute(query).all() + + logger.info("response: %s", response) + + return HealthCheck(200, "Ok").to_dict() + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + raise HTTPException(status_code=500, detail="Internal Server Error") + + +@router.get("/api-docs", response_class=JSONResponse) +async def api_docs(): + api_spec = get_open_api_spec() + return JSONResponse(content=api_spec) + + +@router.get("/docs", response_class=HTMLResponse) +async def docs(): + host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") + swagger_ui = Template(""" + + + + + + SwaggerUI + + + +
+ + + +""").safe_substitute(apiDocUrl=f"{host}/api-docs") + + return HTMLResponse(content=swagger_ui) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index a33fbbc..29e2d44 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -1,22 +1,58 @@ -import os -from typing import Optional -from flask import Flask -from flask.json.provider import DefaultJSONProvider -from flask_cors import CORS -from werkzeug.middleware.proxy_fix import ProxyFix -from urllib.parse import urlparse +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from guardrails import configure_logging -from opentelemetry.instrumentation.flask import FlaskInstrumentor +from guardrails_api.clients.cache_client import CacheClient from guardrails_api.clients.postgres_client import postgres_is_enabled from guardrails_api.otel import otel_is_disabled, initialize -from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled -from guardrails_api.clients.cache_client import CacheClient +from guardrails_api.utils.trace_server_start_if_enabled import ( + trace_server_start_if_enabled, +) +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry import trace, context, baggage + from rich.console import Console from rich.rule import Rule +from typing import Optional +import importlib.util +import json +import os - -# TODO: Move this to a separate file -class OverrideJsonProvider(DefaultJSONProvider): +from starlette.middleware.base import BaseHTTPMiddleware + + +class RequestInfoMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + tracer = trace.get_tracer(__name__) + # Get the current context and attach it to this task + with tracer.start_as_current_span("request_info") as span: + client_ip = request.client.host + user_agent = request.headers.get("user-agent", "unknown") + referrer = request.headers.get("referrer", "unknown") + user_id = request.headers.get("x-user-id", "unknown") + organization = request.headers.get("x-organization", "unknown") + app = request.headers.get("x-app", "unknown") + + context.attach(baggage.set_baggage("client.ip", client_ip)) + context.attach(baggage.set_baggage("http.user_agent", user_agent)) + context.attach(baggage.set_baggage("http.referrer", referrer)) + context.attach(baggage.set_baggage("user.id", user_id)) + context.attach(baggage.set_baggage("organization", organization)) + context.attach(baggage.set_baggage("app", app)) + + span.set_attribute("client.ip", client_ip) + span.set_attribute("http.user_agent", user_agent) + span.set_attribute("http.referrer", referrer) + span.set_attribute("user.id", user_id) + span.set_attribute("organization", organization) + span.set_attribute("app", app) + + response = await call_next(request) + return response + + +# Custom JSON encoder +class CustomJSONEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, set): return list(o) @@ -25,28 +61,14 @@ def default(self, o): return super().default(o) -class ReverseProxied(object): - def __init__(self, app): - self.app = app - - def __call__(self, environ, start_response): - self_endpoint = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") - url = urlparse(self_endpoint) - environ["wsgi.url_scheme"] = url.scheme - return self.app(environ, start_response) - - def register_config(config: Optional[str] = None): default_config_file = os.path.join(os.getcwd(), "./config.py") config_file = config or default_config_file config_file_path = os.path.abspath(config_file) if os.path.isfile(config_file_path): - from importlib.machinery import SourceFileLoader - - # This creates a module named "validators" with the contents of the init file - # This allow statements like `from validators import StartsWith` - # But more importantly, it registers all of the validators imported in the init - SourceFileLoader("config", config_file_path).load_module() + spec = importlib.util.spec_from_file_location("config", config_file_path) + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) def create_app( @@ -74,21 +96,27 @@ def create_app( register_config(config) - app = Flask(__name__) - app.json = OverrideJsonProvider(app) + app = FastAPI(openapi_url="") + + # Add the custom middleware + app.add_middleware(RequestInfoMiddleware) - app.config["APPLICATION_ROOT"] = "/" - app.config["PREFERRED_URL_SCHEME"] = "https" - app.wsgi_app = ReverseProxied(app.wsgi_app) - CORS(app) + # Initialize FastAPIInstrumentor + FastAPIInstrumentor.instrument_app(app) - app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) guardrails_log_level = os.environ.get("GUARDRAILS_LOG_LEVEL", "INFO") configure_logging(log_level=guardrails_log_level) if not otel_is_disabled(): - FlaskInstrumentor().instrument_app(app) initialize() # if no pg_host is set, don't set up postgres @@ -99,28 +127,47 @@ def create_app( pg_client.initialize(app) cache_client = CacheClient() - cache_client.initialize(app) + cache_client.initialize() - from guardrails_api.blueprints.root import root_bp - from guardrails_api.blueprints.guards import guards_bp + from guardrails_api.api.root import router as root_router + from guardrails_api.api.guards import router as guards_router, guard_client - app.register_blueprint(root_bp) - app.register_blueprint(guards_bp) + app.include_router(root_router) + app.include_router(guards_router) + # Custom JSON encoder + @app.exception_handler(ValueError) + async def value_error_handler(request: Request, exc: ValueError): + return JSONResponse( + status_code=400, + content={"message": str(exc)}, + ) + + console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}") console.print( - f"\n:rocket: Guardrails API is available at {self_endpoint}" + f":book: Visit {self_endpoint}/docs to see available API endpoints.\n" ) - console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n") console.print(":green_circle: Active guards and OpenAI compatible endpoints:") - with app.app_context(): - from guardrails_api.blueprints.guards import guard_client - for g in guard_client.get_guards(): - g = g.to_dict() - console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1") + guards = guard_client.get_guards() + + for g in guards: + g_dict = g.to_dict() + console.print( + f"- Guard: [bold white]{g_dict.get('name')}[/bold white] {self_endpoint}/guards/{g_dict.get('name')}/openai/v1" + ) console.print("") - console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")) + console.print( + Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white") + ) + + return app + + +if __name__ == "__main__": + import uvicorn - return app \ No newline at end of file + app = create_app() + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/guardrails_api/blueprints/guards.py b/guardrails_api/blueprints/guards.py deleted file mode 100644 index dd8db10..0000000 --- a/guardrails_api/blueprints/guards.py +++ /dev/null @@ -1,516 +0,0 @@ -import asyncio -import json -import os -import inspect -from guardrails.hub import * # noqa -from string import Template -from typing import Any, Dict, cast -from flask import Blueprint, Response, request, stream_with_context -from urllib.parse import unquote_plus -from guardrails import AsyncGuard, Guard -from guardrails.classes import ValidationOutcome -from opentelemetry.trace import Span -from guardrails_api_client import Guard as GuardStruct -from guardrails_api.classes.http_error import HttpError -from guardrails_api.clients.cache_client import CacheClient -from guardrails_api.clients.memory_guard_client import MemoryGuardClient -from guardrails_api.clients.pg_guard_client import PGGuardClient -from guardrails_api.clients.postgres_client import postgres_is_enabled -from guardrails_api.utils.handle_error import handle_error -from guardrails_api.utils.get_llm_callable import get_llm_callable -from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response - -guards_bp = Blueprint("guards", __name__, url_prefix="/guards") - - -# if no pg_host is set, use in memory guards -if postgres_is_enabled(): - guard_client = PGGuardClient() -else: - guard_client = MemoryGuardClient() - # Will be defined at runtime - import config # noqa - - exports = config.__dir__() - for export_name in exports: - export = getattr(config, export_name) - is_guard = isinstance(export, Guard) - if is_guard: - guard_client.create_guard(export) - -cache_client = CacheClient() - - -@guards_bp.route("/", methods=["GET", "POST"]) -@handle_error -def guards(): - if request.method == "GET": - guards = guard_client.get_guards() - return [g.to_dict() for g in guards] - elif request.method == "POST": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "POST /guards is not implemented for in-memory guards.", - ) - payload = request.json - guard = GuardStruct.from_dict(payload) - new_guard = guard_client.create_guard(guard) - return new_guard.to_dict() - else: - raise HttpError( - 405, - "Method Not Allowed", - "/guards only supports the GET and POST methods. You specified" - " {request_method}".format(request_method=request.method), - ) - - -@guards_bp.route("/", methods=["GET", "PUT", "DELETE"]) -@handle_error -def guard(guard_name: str): - decoded_guard_name = unquote_plus(guard_name) - if request.method == "GET": - as_of_query = request.args.get("asOf") - guard = guard_client.get_guard(decoded_guard_name, as_of_query) - if guard is None: - raise HttpError( - 404, - "NotFound", - "A Guard with the name {guard_name} does not exist!".format( - guard_name=decoded_guard_name - ), - ) - return guard.to_dict() - elif request.method == "PUT": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "PUT / is not implemented for in-memory guards.", - ) - payload = request.json - guard = GuardStruct.from_dict(payload) - updated_guard = guard_client.upsert_guard(decoded_guard_name, guard) - return updated_guard.to_dict() - elif request.method == "DELETE": - if not postgres_is_enabled(): - raise HttpError( - 501, - "NotImplemented", - "DELETE / is not implemented for in-memory guards.", - ) - guard = guard_client.delete_guard(decoded_guard_name) - return guard.to_dict() - else: - raise HttpError( - 405, - "Method Not Allowed", - "/guard/ only supports the GET, PUT, and DELETE methods." - " You specified {request_method}".format(request_method=request.method), - ) - - -def collect_telemetry( - *, - guard: Guard, - validate_span: Span, - validation_output: ValidationOutcome, - prompt_params: Dict[str, Any], - result: ValidationOutcome, -): - # Below is all telemetry collection and - # should have no impact on what is returned to the user - prompt = guard.history.last.inputs.prompt - if prompt: - prompt = Template(prompt).safe_substitute(**prompt_params) - validate_span.set_attribute("prompt", prompt) - - instructions = guard.history.last.inputs.instructions - if instructions: - instructions = Template(instructions).safe_substitute(**prompt_params) - validate_span.set_attribute("instructions", instructions) - - validate_span.set_attribute("validation_status", guard.history.last.status) - validate_span.set_attribute("raw_llm_ouput", result.raw_llm_output) - - # Use the serialization from the class instead of re-writing it - valid_output: str = ( - json.dumps(validation_output.validated_output) - if isinstance(validation_output.validated_output, dict) - else str(validation_output.validated_output) - ) - validate_span.set_attribute("validated_output", valid_output) - - validate_span.set_attribute("tokens_consumed", guard.history.last.tokens_consumed) - - num_of_reasks = ( - guard.history.last.iterations.length - 1 - if guard.history.last.iterations.length > 0 - else 0 - ) - validate_span.set_attribute("num_of_reasks", num_of_reasks) - - -@guards_bp.route("//openai/v1/chat/completions", methods=["POST"]) -@handle_error -def openai_v1_chat_completions(guard_name: str): - # This endpoint implements the OpenAI Chat API - # It is mean to be fully compatible - # The only difference is that it uses the Guard API under the hood - # instead of the OpenAI API and supports guardrail API error handling - # To use this with the OpenAI SDK you can use the following code: - # import openai - # openai.base_url = "http://localhost:8000/guards//openai/v1/" - # response = openai.chat.completions( - # model="gpt-3.5-turbo-0125", - # messages=[ - # {"role": "user", "content": "Hello, how are you?"}, - # ], - # stream=True, - # ) - # print(response) - # to configure guard rails error handling from the server side you can use the following code: - # - - payload = request.json - decoded_guard_name = unquote_plus(guard_name) - guard_struct = guard_client.get_guard(decoded_guard_name) - guard = guard_struct - if guard_struct is None: - raise HttpError( - 404, - "NotFound", - "A Guard with the name {guard_name} does not exist!".format( - guard_name=decoded_guard_name - ), - ) - - if not isinstance(guard_struct, Guard): - guard: Guard = Guard.from_dict(guard_struct.to_dict()) - stream = payload.get("stream", False) - has_tool_gd_tool_call = False - - try: - tools = payload.get("tools", []) - tools.filter(lambda tool: tool["funcion"]["name"] == "gd_response_tool") - has_tool_gd_tool_call = len(tools) > 0 - except (KeyError, AttributeError): - pass - - if not stream: - validation_outcome: ValidationOutcome = guard( - # todo make this come from the guard struct? - # currently we dont support .configure - num_reasks=0, - **payload, - ) - llm_response = guard.history.last.iterations.last.outputs.llm_response_info - result = outcome_to_chat_completion( - validation_outcome=validation_outcome, - llm_response=llm_response, - has_tool_gd_tool_call=has_tool_gd_tool_call, - ) - return result - - else: - # need to return validated chunks that look identical to openai's - # should look something like - # data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":None,"finish_reason":None}]} - # .... - # data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":None,"finish_reason":"stop"}]} - def openai_streamer(): - guard_stream = guard( - num_reasks=0, - **payload, - ) - for result in guard_stream: - chunk_string = f"data: {json.dumps(outcome_to_stream_response(validation_outcome=result))}\n\n" - yield chunk_string.encode("utf-8") - # close the stream - yield b"\n" - - return Response( - stream_with_context(openai_streamer()), - ) - - -@guards_bp.route("//validate", methods=["POST"]) -@handle_error -def validate(guard_name: str): - # Do we actually need a child span here? - # We could probably use the existing span from the request unless we forsee - # capturing the same attributes on non-GaaS Guard runs. - if request.method != "POST": - raise HttpError( - 405, - "Method Not Allowed", - "/guards//validate only supports the POST method. You specified" - " {request_method}".format(request_method=request.method), - ) - payload = request.json - openai_api_key = request.headers.get( - "x-openai-api-key", os.environ.get("OPENAI_API_KEY") - ) - decoded_guard_name = unquote_plus(guard_name) - guard_struct = guard_client.get_guard(decoded_guard_name) - llm_output = payload.pop("llmOutput", None) - num_reasks = payload.pop("numReasks", None) - prompt_params = payload.pop("promptParams", {}) - llm_api = payload.pop("llmApi", None) - args = payload.pop("args", []) - stream = payload.pop("stream", False) - - # service_name = os.environ.get("OTEL_SERVICE_NAME", "guardrails-api") - # otel_tracer = get_tracer(service_name) - - payload["api_key"] = payload.get("api_key", openai_api_key) - - # with otel_tracer.start_as_current_span( - # f"validate-{decoded_guard_name}" - # ) as validate_span: - # guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer) - - - # validate_span.set_attribute("guardName", decoded_guard_name) - if llm_api is not None: - llm_api = get_llm_callable(llm_api) - if openai_api_key is None: - raise HttpError( - status=400, - message="BadRequest", - cause=( - "Cannot perform calls to OpenAI without an api key. Pass" - " openai_api_key when initializing the Guard or set the" - " OPENAI_API_KEY environment variable." - ), - ) - - guard = guard_struct - is_async = inspect.iscoroutinefunction(llm_api) - if not isinstance(guard_struct, Guard): - if is_async: - guard = AsyncGuard.from_dict(guard_struct.to_dict()) - else: - guard: Guard = Guard.from_dict(guard_struct.to_dict()) - elif is_async: - guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict()) - - if llm_api is None and num_reasks and num_reasks > 1: - raise HttpError( - status=400, - message="BadRequest", - cause=( - "Cannot perform re-asks without an LLM API. Specify llm_api when" - " calling guard(...)." - ), - ) - if llm_output is not None: - if stream: - raise HttpError( - status=400, - message="BadRequest", - cause="Streaming is not supported for parse calls!", - ) - result: ValidationOutcome = guard.parse( - llm_output=llm_output, - num_reasks=num_reasks, - prompt_params=prompt_params, - llm_api=llm_api, - **payload, - ) - else: - if stream: - def guard_streamer(): - guard_stream = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, - ) - for result in guard_stream: - # TODO: Just make this a ValidationOutcome with history - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, cast(ValidationOutcome, result) - - async def async_guard_streamer(): - guard_stream = await guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - stream=stream, - *args, - **payload, - ) - async for result in guard_stream: - validation_output: ValidationOutcome = ( - ValidationOutcome.from_guard_history(guard.history.last) - ) - yield validation_output, cast(ValidationOutcome, result) - - def validate_streamer(guard_iter): - next_result = None - for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" - call = guard.history.last - final_validation_output: ValidationOutcome = ValidationOutcome( - callId=call.id, - validation_passed=next_result.validation_passed, - validated_output=next_result.validated_output, - history=guard.history, - raw_llm_output=next_result.raw_llm_output, - ) - # I don't know if these are actually making it to OpenSearch - # because the span may be ended already - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=next_validation_output, - # prompt_params=prompt_params, - # result=next_result - # ) - final_output_dict = final_validation_output.to_dict() - final_output_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - final_output_json = json.dumps(final_output_dict) - - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{final_validation_output.call_id}" - cache_client.set(cache_key, serialized_history, 300) - yield f"{final_output_json}\n" - - async def async_validate_streamer(guard_iter): - next_result = None - # next_validation_output = None - async for validation_output, result in guard_iter: - next_result = result - # next_validation_output = validation_output - fragment_dict = result.to_dict() - fragment_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - fragment = json.dumps(fragment_dict) - yield f"{fragment}\n" - - call = guard.history.last - final_validation_output: ValidationOutcome = ValidationOutcome( - callId=call.id, - validation_passed=next_result.validation_passed, - validated_output=next_result.validated_output, - history=guard.history, - raw_llm_output=next_result.raw_llm_output, - ) - # I don't know if these are actually making it to OpenSearch - # because the span may be ended already - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=next_validation_output, - # prompt_params=prompt_params, - # result=next_result - # ) - final_output_dict = final_validation_output.to_dict() - final_output_dict["error_spans"] = list( - map( - lambda x: json.dumps( - {"start": x.start, "end": x.end, "reason": x.reason} - ), - guard.error_spans_in_output(), - ) - ) - final_output_json = json.dumps(final_output_dict) - - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{final_validation_output.call_id}" - cache_client.set(cache_key, serialized_history, 300) - yield f"{final_output_json}\n" - # apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async - def iter_over_async(ait, loop): - ait = ait.__aiter__() - async def get_next(): - try: - obj = await ait.__anext__() - return False, obj - except StopAsyncIteration: - return True, None - while True: - done, obj = loop.run_until_complete(get_next()) - if done: - break - yield obj - if is_async: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop) - else: - iter = validate_streamer(guard_streamer()) - return Response( - stream_with_context(iter), - content_type="application/json", - # content_type="text/event-stream" - ) - - result: ValidationOutcome = guard( - llm_api=llm_api, - prompt_params=prompt_params, - num_reasks=num_reasks, - # api_key=openai_api_key, - *args, - **payload, - ) - - # TODO: Just make this a ValidationOutcome with history - # validation_output = ValidationOutcome( - # validation_passed = result.validation_passed, - # validated_output=result.validated_output, - # history=guard.history, - # raw_llm_output=result.raw_llm_output, - # ) - - # collect_telemetry( - # guard=guard, - # validate_span=validate_span, - # validation_output=validation_output, - # prompt_params=prompt_params, - # result=result - # ) - serialized_history = [call.to_dict() for call in guard.history] - cache_key = f"{guard.name}-{result.call_id}" - cache_client.set(cache_key, serialized_history, 300) - return result.to_dict() - - -@guards_bp.route("//history/", methods=["GET"]) -@handle_error -def guard_history(guard_name: str, call_id: str): - if request.method == "GET": - cache_key = f"{guard_name}-{call_id}" - return cache_client.get(cache_key) diff --git a/guardrails_api/blueprints/root.py b/guardrails_api/blueprints/root.py deleted file mode 100644 index ed388be..0000000 --- a/guardrails_api/blueprints/root.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -import json -import flask -from string import Template -from flask import Blueprint -from guardrails_api.open_api_spec import get_open_api_spec -from sqlalchemy import text -from guardrails_api.classes.health_check import HealthCheck -from guardrails_api.clients.postgres_client import PostgresClient, postgres_is_enabled -from guardrails_api.utils.handle_error import handle_error -from guardrails_api.utils.logger import logger - - -root_bp = Blueprint("root", __name__, url_prefix="/") - - -@root_bp.route("/") -@handle_error -def home(): - return "Hello, Flask!" - - -@root_bp.route("/health-check") -@handle_error -def health_check(): - # If we're not using postgres, just return Ok - if not postgres_is_enabled(): - return HealthCheck(200, "Ok").to_dict() - # Make sure we're connected to the database and can run queries - pg_client = PostgresClient() - query = text("SELECT count(datid) FROM pg_stat_activity;") - response = pg_client.db.session.execute(query).all() - # # This works with otel logging - # logger.info(f"response: {response}") - # As does this - logger.info("response: %s", response) - # # This throws an error - # print("response: ", response) - return HealthCheck(200, "Ok").to_dict() - - -@root_bp.route("/api-docs") -@handle_error -def api_docs(): - api_spec = get_open_api_spec() - return json.dumps(api_spec) - - -@root_bp.route("/docs") -@handle_error -def docs(): - host = os.environ.get("SELF_ENDPOINT", "http://localhost:8000") - swagger_ui = Template(""" - - - - - - SwaggerUI - - - -
- - - -""").safe_substitute(apiDocUrl=f"{host}/api-docs") # noqa - - return flask.render_template_string(swagger_ui) diff --git a/guardrails_api/cli/start.py b/guardrails_api/cli/start.py index eb8f027..95d93d2 100644 --- a/guardrails_api/cli/start.py +++ b/guardrails_api/cli/start.py @@ -3,6 +3,8 @@ from guardrails_api.cli.cli import cli from guardrails_api.app import create_app from guardrails_api.utils.configuration import valid_configuration +import uvicorn + @cli.command("start") def start( @@ -24,4 +26,5 @@ def start( env = env or None config = config or None valid_configuration(config) - create_app(env, config, port).run(port=port) + app = create_app(env, config, port) + uvicorn.run(app, port=port) diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index 550bc4b..4ddace1 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -1,8 +1,7 @@ import threading -from flask_caching import Cache +from aiocache import caches -# TODO: Add option to connect to Redis or MemCached backend with environment variables class CacheClient: _instance = None _lock = threading.Lock() @@ -10,27 +9,30 @@ class CacheClient: def __new__(cls): if cls._instance is None: with cls._lock: - cls._instance = super(CacheClient, cls).__new__(cls) + if cls._instance is None: # Double-checked locking + cls._instance = super().__new__(cls) return cls._instance - def initialize(self, app): - self.cache = Cache( - app, - config={ - "CACHE_TYPE": "SimpleCache", - "CACHE_DEFAULT_TIMEOUT": 300, - "CACHE_THRESHOLD": 50, - }, + def initialize(self): + caches.set_config( + { + "default": { + "cache": "aiocache.SimpleMemoryCache", + "serializer": {"class": "aiocache.serializers.JsonSerializer"}, + "ttl": 300, + } + } ) + self.cache = caches.get("default") - def get(self, key): - return self.cache.get(key) + async def get(self, key: str): + return await self.cache.get(key) - def set(self, key, value, ttl): - self.cache.set(key, value, timeout=ttl) + async def set(self, key: str, value: str, ttl: int): + await self.cache.set(key, value, ttl=ttl) - def delete(self, key): - self.cache.delete(key) + async def delete(self, key: str): + await self.cache.delete(key) - def clear(self): - self.cache.clear() + async def clear(self): + await self.cache.clear() diff --git a/guardrails_api/clients/pg_guard_client.py b/guardrails_api/clients/pg_guard_client.py index c7a1f48..226232a 100644 --- a/guardrails_api/clients/pg_guard_client.py +++ b/guardrails_api/clients/pg_guard_client.py @@ -18,14 +18,20 @@ def __init__(self): self.initialized = True self.pgClient = PostgresClient() + def get_db(self): # generator for local sessions + db = self.pgClient.SessionLocal() + try: + yield db + finally: + db.close() + def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: - latest_guard_item = ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first() audit_item = None if as_of_date is not None: audit_item = ( - self.pgClient.db.session.query(GuardItemAudit) + db.query(GuardItemAudit) .filter_by(name=guard_name) .filter(GuardItemAudit.replaced_on > as_of_date) .order_by(GuardItemAudit.replaced_on.asc()) @@ -43,27 +49,29 @@ def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: return from_guard_item(guard_item) def get_guard_item(self, guard_name: str) -> GuardItem: - return ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + return db.query(GuardItem).filter_by(name=guard_name).first() def get_guards(self) -> List[GuardStruct]: - guard_items = self.pgClient.db.session.query(GuardItem).all() + db = next(self.get_db()) + guard_items = db.query(GuardItem).all() return [from_guard_item(gi) for gi in guard_items] def create_guard(self, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = GuardItem( name=guard.name, railspec=guard.to_dict(), num_reasks=None, description=guard.description, ) - self.pgClient.db.session.add(guard_item) - self.pgClient.db.session.commit() + db.add(guard_item) + db.commit() return from_guard_item(guard_item) def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -76,21 +84,23 @@ def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: # guard_item.num_reasks = guard.num_reasks guard_item.railspec = guard.to_dict() guard_item.description = guard.description - self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is not None: guard_item.railspec = guard.to_dict() guard_item.description = guard.description # guard_item.num_reasks = guard.num_reasks - self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) else: return self.create_guard(guard) def delete_guard(self, guard_name: str) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -100,7 +110,7 @@ def delete_guard(self, guard_name: str) -> GuardStruct: guard_name=guard_name ), ) - self.pgClient.db.session.delete(guard_item) - self.pgClient.db.session.commit() + db.delete(guard_item) + db.commit() guard = from_guard_item(guard_item) return guard diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index 951a4f4..56e1501 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -2,16 +2,24 @@ import json import os import threading -from flask import Flask -from sqlalchemy import text +from fastapi import FastAPI from typing import Tuple -from guardrails_api.models.base import db, INIT_EXTENSIONS +from sqlalchemy import create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() def postgres_is_enabled() -> bool: return os.environ.get("PGHOST", None) is not None +# Global variables for database session +postgres_client = None +SessionLocal = None + + class PostgresClient: _instance = None _lock = threading.Lock() @@ -45,7 +53,17 @@ def get_pg_creds(self) -> Tuple[str, str]: pg_password = pg_password or os.environ.get("PGPASSWORD") return pg_user, pg_password - def initialize(self, app: Flask): + def get_db(self): + if postgres_is_enabled(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + def initialize(self, app: FastAPI): pg_user, pg_password = self.get_pg_creds() pg_host = os.environ.get("PGHOST", "localhost") pg_port = os.environ.get("PGPORT", "5432") @@ -64,23 +82,64 @@ def initialize(self, app: Flask): if os.environ.get("NODE_ENV") == "production": conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" - app.config["SQLALCHEMY_DATABASE_URI"] = conf + engine = create_engine(conf) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - app.secret_key = "secret" self.app = app - self.db = db - db.init_app(app) - from guardrails_api.models.guard_item import GuardItem # NOQA - from guardrails_api.models.guard_item_audit import ( # NOQA - GuardItemAudit, - AUDIT_FUNCTION, - AUDIT_TRIGGER, - ) + self.engine = engine + self.SessionLocal = SessionLocal + # Create tables + from guardrails_api.models import GuardItem, GuardItemAudit # noqa + + Base.metadata.create_all(bind=engine) + + # Execute custom SQL + with engine.connect() as connection: + connection.execute(text(INIT_EXTENSIONS)) + connection.execute(text(AUDIT_FUNCTION)) + connection.execute(text(AUDIT_TRIGGER)) + connection.commit() + + +# Define INIT_EXTENSIONS, AUDIT_FUNCTION, and AUDIT_TRIGGER here as they were in your original code +INIT_EXTENSIONS = """ +-- Your SQL for initializing extensions +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp') THEN + CREATE EXTENSION "uuid-ossp"; + END IF; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN + CREATE EXTENSION "vector"; + END IF; +END $$; +""" + +AUDIT_FUNCTION = """ +CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ +BEGIN + IF (TG_OP = 'DELETE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D'; + ELSIF (TG_OP = 'UPDATE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; + ELSIF (TG_OP = 'INSERT') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; + END IF; + RETURN null; +END; +$guard_audit$ +LANGUAGE plpgsql; +""" - with self.app.app_context(): - self.db.session.execute(text(INIT_EXTENSIONS)) - self.db.create_all() - self.db.session.execute(text(AUDIT_FUNCTION)) - self.db.session.execute(text(AUDIT_TRIGGER)) - self.db.session.commit() +AUDIT_TRIGGER = """ +DROP TRIGGER IF EXISTS guard_audit_trigger + ON guards; +CREATE TRIGGER guard_audit_trigger + AFTER INSERT OR UPDATE OR DELETE ON guards + FOR EACH ROW + EXECUTE PROCEDURE guard_audit_function(); +""" diff --git a/guardrails_api/models/__init__.py b/guardrails_api/models/__init__.py index e69de29..391a299 100644 --- a/guardrails_api/models/__init__.py +++ b/guardrails_api/models/__init__.py @@ -0,0 +1,5 @@ +# __init__.py +from .guard_item_audit import GuardItemAudit +from .guard_item import GuardItem + +__all__ = ["GuardItemAudit", "GuardItem"] diff --git a/guardrails_api/models/base.py b/guardrails_api/models/base.py deleted file mode 100644 index 29f0169..0000000 --- a/guardrails_api/models/base.py +++ /dev/null @@ -1,8 +0,0 @@ -from flask_sqlalchemy import SQLAlchemy - -db = SQLAlchemy() - -INIT_EXTENSIONS = """ -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -CREATE EXTENSION IF NOT EXISTS "vector"; -""" diff --git a/guardrails_api/models/guard_item.py b/guardrails_api/models/guard_item.py index a2dbf32..6bcacfa 100644 --- a/guardrails_api/models/guard_item.py +++ b/guardrails_api/models/guard_item.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItem(db.Model): +class GuardItem(Base): __tablename__ = "guards" # TODO: Make primary key a composite between guard.name and the guard owner's userId name = Column(String, primary_key=True) diff --git a/guardrails_api/models/guard_item_audit.py b/guardrails_api/models/guard_item_audit.py index 183626e..13ee3c2 100644 --- a/guardrails_api/models/guard_item_audit.py +++ b/guardrails_api/models/guard_item_audit.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP, CHAR -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItemAudit(db.Model): +class GuardItemAudit(Base): __tablename__ = "guards_audit" id = Column(String, primary_key=True) name = Column(String, nullable=False, index=True) @@ -35,29 +35,3 @@ def __init__( self.replaced_on = replaced_on self.operation = operation # self.owner = owner - - -AUDIT_FUNCTION = """ -CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ -BEGIN - IF (TG_OP = 'DELETE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D'; - ELSIF (TG_OP = 'UPDATE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; - ELSIF (TG_OP = 'INSERT') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; - END IF; - RETURN null; -END; -$guard_audit$ -LANGUAGE plpgsql; -""" - -AUDIT_TRIGGER = """ -DROP TRIGGER IF EXISTS guard_audit_trigger - ON guards; -CREATE TRIGGER guard_audit_trigger - AFTER INSERT OR UPDATE OR DELETE ON guards - FOR EACH ROW - EXECUTE PROCEDURE guard_audit_function(); -""" diff --git a/guardrails_api/start-dev.sh b/guardrails_api/start-dev.sh index a27f2d3..83a2d70 100755 --- a/guardrails_api/start-dev.sh +++ b/guardrails_api/start-dev.sh @@ -1 +1,12 @@ -gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()" --reload --capture-output --enable-stdio-inheritance +gunicorn --bind 0.0.0.0:8000 \ + --timeout 120 \ + --workers 2 \ + --threads 2 \ + --worker-class=uvicorn.workers.UvicornWorker \ + "guardrails_api.app:create_app()" \ + --reload \ + --capture-output \ + --enable-stdio-inheritance \ + --access-logfile - \ + --error-logfile - \ + --access-logformat '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" pid=%(p)s' \ No newline at end of file diff --git a/guardrails_api/start.sh b/guardrails_api/start.sh index 2696b88..4a4a353 100755 --- a/guardrails_api/start.sh +++ b/guardrails_api/start.sh @@ -1 +1 @@ -gunicorn --bind 0.0.0.0:8000 --timeout=5 --threads=10 "guardrails_api.app:create_app()" +gunicorn --bind 0.0.0.0:8000 --timeout=5 --workers=3 --worker-class=uvicorn.workers.UvicornWorker "guardrails_api.app:create_app()" \ No newline at end of file diff --git a/guardrails_api/utils/configuration.py b/guardrails_api/utils/configuration.py index 1fdb965..f793a68 100644 --- a/guardrails_api/utils/configuration.py +++ b/guardrails_api/utils/configuration.py @@ -2,21 +2,31 @@ from typing import Optional import os -def valid_configuration(config: Optional[str]=""): + +def valid_configuration(config: Optional[str] = ""): default_config_file = os.path.join(os.getcwd(), "./config.py") default_config_file_path = os.path.abspath(default_config_file) - # If config.py is not present and + # If config.py is not present and # if a config filepath is not passed and - # if postgres is not there (i.e. we’re using in-mem db) + # if postgres is not there (i.e. we’re using in-mem db) # then raise ConfigurationError has_default_config_file = os.path.isfile(default_config_file_path) - has_config_file = (config != "" and config is not None) and os.path.isfile(os.path.abspath(config)) - if not has_default_config_file and not has_config_file and not postgres_is_enabled(): - raise ConfigurationError("Can not start. Configuration not provided and default" - " configuration not found and postgres is not enabled.") + has_config_file = (config != "" and config is not None) and os.path.isfile( + os.path.abspath(config) + ) + if ( + not has_default_config_file + and not has_config_file + and not postgres_is_enabled() + ): + raise ConfigurationError( + "Can not start. Configuration not provided and default" + " configuration not found and postgres is not enabled." + ) return True + class ConfigurationError(Exception): - pass \ No newline at end of file + pass diff --git a/guardrails_api/utils/handle_error.py b/guardrails_api/utils/handle_error.py index 4fcf231..1458d9d 100644 --- a/guardrails_api/utils/handle_error.py +++ b/guardrails_api/utils/handle_error.py @@ -1,32 +1,48 @@ from functools import wraps import traceback -from werkzeug.exceptions import HTTPException from guardrails_api.classes.http_error import HttpError from guardrails_api.utils.logger import logger from guardrails.errors import ValidationError -def handle_error(fn): - @wraps(fn) - def decorator(*args, **kwargs): - try: - return fn(*args, **kwargs) - except ValidationError as validation_error: - logger.error(validation_error) - traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__) - return str(validation_error), 400 - except HttpError as http_error: - logger.error(http_error) - traceback.print_exception(type(http_error), http_error, http_error.__traceback__) - return http_error.to_dict(), http_error.status - except HTTPException as http_exception: - logger.error(http_exception) - traceback.print_exception(http_exception) - http_error = HttpError(http_exception.code, http_exception.description) - return http_error.to_dict(), http_error.status - except Exception as e: - logger.error(e) - traceback.print_exception(e) - return HttpError(500, "Internal Server Error").to_dict(), 500 +from fastapi import HTTPException + +def handle_error(func=None): + def decorator(fn): + @wraps(fn) + async def wrapper(*args, **kwargs): + try: + return await fn(*args, **kwargs) + except ValidationError as validation_error: + logger.error(validation_error) + traceback.print_exception( + type(validation_error), + validation_error, + validation_error.__traceback__, + ) + raise HTTPException(status_code=400, detail=str(validation_error)) + except HttpError as http_error: + logger.error(http_error) + traceback.print_exception( + type(http_error), http_error, http_error.__traceback__ + ) + raise HTTPException( + status_code=http_error.status_code, detail=http_error.detail + ) + except HTTPException as http_exception: + logger.error(http_exception) + traceback.print_exception( + type(http_exception), http_exception, http_exception.__traceback__ + ) + raise + except Exception as e: + logger.error(e) + traceback.print_exception(type(e), e, e.__traceback__) + raise HTTPException(status_code=500, detail="Internal Server Error") + + return wrapper + + if func: + return decorator(func) return decorator diff --git a/guardrails_api/utils/has_internet_connection.py b/guardrails_api/utils/has_internet_connection.py index 8a7099c..1a92721 100644 --- a/guardrails_api/utils/has_internet_connection.py +++ b/guardrails_api/utils/has_internet_connection.py @@ -7,4 +7,4 @@ def has_internet_connection() -> bool: res.raise_for_status() return True except requests.ConnectionError: - return False \ No newline at end of file + return False diff --git a/guardrails_api/utils/openai.py b/guardrails_api/utils/openai.py index 10ecfa2..79cc7b5 100644 --- a/guardrails_api/utils/openai.py +++ b/guardrails_api/utils/openai.py @@ -1,5 +1,6 @@ from guardrails.classes import ValidationOutcome + def outcome_to_stream_response(validation_outcome: ValidationOutcome): stream_chunk_template = { "choices": [ diff --git a/guardrails_api/utils/trace_server_start_if_enabled.py b/guardrails_api/utils/trace_server_start_if_enabled.py index 467abd6..91fbbcf 100644 --- a/guardrails_api/utils/trace_server_start_if_enabled.py +++ b/guardrails_api/utils/trace_server_start_if_enabled.py @@ -8,6 +8,7 @@ def trace_server_start_if_enabled(): config = Credentials.from_rc_file() if config.enable_metrics is True and has_internet_connection(): from guardrails.utils.hub_telemetry_utils import HubTelemetry + HubTelemetry().create_new_span( "guardrails-api/start", [ @@ -21,4 +22,4 @@ def trace_server_start_if_enabled(): ], True, False, - ) \ No newline at end of file + ) diff --git a/pyproject.toml b/pyproject.toml index ef12ba0..2520656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,14 +10,9 @@ readme = "README.md" keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ - "guardrails-ai>=0.5.6", - "flask>=3.0.3,<4", - "Flask-SQLAlchemy>=3.1.1,<4", - "Flask-Caching>=2.3.0,<3", - "Werkzeug>=3.0.3,<4", + "guardrails-ai>=0.5.10", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", - "Flask-Cors>=4.0.1,<6", "boto3>=1.34.115,<2", "psycopg2-binary>=2.9.9,<3", "litellm>=1.39.3,<2", @@ -26,8 +21,12 @@ dependencies = [ "opentelemetry-sdk>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", - "opentelemetry-instrumentation-flask>=0.12b0,<1", - "requests>=2.32.3" + "opentelemetry-instrumentation-fastapi>=0.48b0", + "requests>=2.32.3", + "aiocache>=0.11.1", + "fastapi>=0.114.1", + "SQLAlchemy>=2.0.34", + "uvicorn>=0.30.6", ] [tool.setuptools.dynamic] diff --git a/requirements-lock.txt b/requirements-lock.txt index 0c18e6e..1462843 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -19,17 +19,13 @@ diff-match-patch==20230430 distro==1.9.0 Faker==25.9.2 filelock==3.15.4 -Flask==3.0.3 -Flask-Caching==2.3.0 -Flask-Cors==5.0.0 -Flask-SQLAlchemy==3.1.1 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.6.1 googleapis-common-protos==1.63.2 griffe==0.36.9 grpcio==1.65.1 -guardrails-ai==0.5.7 +guardrails-ai==0.5.9 guardrails-api-client==0.3.12 guardrails_hub_types==0.0.4 gunicorn==22.0.0 diff --git a/setup.py b/setup.py index 78528e2..806bde1 100644 --- a/setup.py +++ b/setup.py @@ -16,21 +16,23 @@ packages=find_packages(), python_requires=">=3.8, <4", install_requires=[ - "guardrails-ai>=0.4.5", - "flask>=3.0.3,<4", - "Flask-SQLAlchemy>=3.1.1,<4", - "Werkzeug>=3.0.3,<4", + "guardrails-ai>=0.5.10", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", - "Flask-Cors>=4.0.1,<6", "boto3>=1.34.115,<2", "psycopg2-binary>=2.9.9,<3", "litellm>=1.39.3,<2", "typer>=0.9.4,<1", - "opentelemetry-api>1,<2", - "opentelemetry-exporter-otlp-proto-grpc>1,<2", - "opentelemetry-exporter-otlp-proto-http>1,<2", - "opentelemetry-instrumentation-flask>=0.12b0,<1" + "opentelemetry-api>=1.0.0,<2", + "opentelemetry-sdk>=1.0.0,<2", + "opentelemetry-exporter-otlp-proto-grpc>=1.0.0,<2", + "opentelemetry-exporter-otlp-proto-http>=1.0.0,<2", + "opentelemetry-instrumentation-fastapi>=0.48b0", + "requests>=2.32.3", + "aiocache>=0.11.1", + "fastapi>=0.114.1", + "SQLAlchemy>=2.0.34", + "uvicorn>=0.30.6", ], package_data={"guardrails_api": ["py.typed", "open-api-spec.json"]}, ) diff --git a/tests/blueprints/__init__.py b/tests/api/__init__.py similarity index 100% rename from tests/blueprints/__init__.py rename to tests/api/__init__.py diff --git a/tests/api/test_guards.py b/tests/api/test_guards.py new file mode 100644 index 0000000..453a976 --- /dev/null +++ b/tests/api/test_guards.py @@ -0,0 +1,396 @@ +import os +from unittest.mock import PropertyMock + +import pytest +from fastapi.testclient import TestClient +from fastapi import FastAPI + +from guardrails.classes import ValidationOutcome +from guardrails.classes.generic import Stack +from guardrails.classes.history import Call, Iteration +from guardrails.errors import ValidationError + +from guardrails_api.app import register_config +from tests.mocks.mock_guard_client import MockGuardStruct +from guardrails_api.api.guards import router as guards_router + +# TODO: Should we mock this somehow? +# Right now it's just empty, but it technically does a file read +register_config() + +app = FastAPI() + +app.include_router(guards_router) +client = TestClient(app) + +MOCK_GUARD_STRING = { + "id": "mock-guard-id", + "name": "mock-guard", + "description": "mock guard description", + "history": Stack(), +} + + +@pytest.fixture(autouse=True) +def around_each(): + # Code that will run before the test + openai_api_key_bak = os.environ.get("OPENAI_API_KEY") + if openai_api_key_bak: + del os.environ["OPENAI_API_KEY"] + yield + # Code that will run after the test + if openai_api_key_bak: + os.environ["OPENAI_API_KEY"] = openai_api_key_bak + + +def test_guards__get(mocker): + mock_guard = MockGuardStruct() + mock_get_guards = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guards", + return_value=[mock_guard], + ) + mocker.patch("guardrails_api.api.guards.collect_telemetry") + + response = client.get("/guards") + + assert mock_get_guards.call_count == 1 + assert response.status_code == 200 + assert response.json() == [MOCK_GUARD_STRING] + + +def test_guards__post_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mocker.patch( + "guardrails_api.api.guards.GuardStruct.from_dict", + return_value=mock_guard, + ) + mocker.patch( + "guardrails_api.api.guards.guard_client.create_guard", + return_value=mock_guard, + ) + + response = client.post("/guards", json=mock_guard.to_dict()) + + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + + del os.environ["PGHOST"] + + +def test_guards__post_mem(mocker): + old = None + if "PGHOST" in os.environ: + old = os.environ.get("PGHOST") + del os.environ["PGHOST"] + mock_guard = MockGuardStruct() + + response = client.post("/guards", json=mock_guard.to_dict()) + + assert response.status_code == 501 + assert "Not Implemented" in response.json()["detail"] + if old: + os.environ["PGHOST"] = old + + +def test_guard__get_mem(mocker): + mock_guard = MockGuardStruct() + timestamp = "2024-03-04T14:11:42-06:00" + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + response = client.get(f"/guards/My%20Guard's%20Name?asOf={timestamp}") + + mock_get_guard.assert_called_once_with("My Guard's Name", timestamp) + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + + +def test_guard__put_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + json_guard = { + "name": "mock-guard", + "id": "mock-guard-id", + "description": "mock guard description", + "history": Stack(), + } + mocker.patch( + "guardrails_api.api.guards.GuardStruct.from_dict", + return_value=mock_guard, + ) + mocker.patch( + "guardrails_api.api.guards.guard_client.upsert_guard", + return_value=mock_guard, + ) + + response = client.put("/guards/My%20Guard's%20Name", json=json_guard) + + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + del os.environ["PGHOST"] + + +def test_guard__delete_pg(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mock_delete_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.delete_guard", + return_value=mock_guard, + ) + + response = client.delete("/guards/my-guard-name") + + mock_delete_guard.assert_called_once_with("my-guard-name") + assert response.status_code == 200 + assert response.json() == MOCK_GUARD_STRING + del os.environ["PGHOST"] + + +def test_validate__parse(mocker): + os.environ["PGHOST"] = "localhost" + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output="Hello world!", + validation_passed=True, + ) + + mock_parse = mocker.patch.object(MockGuardStruct, "parse") + mock_parse.return_value = mock_outcome + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "pass" + mock_guard.history = Stack(Call()) + + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, + ) + + mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock_parse.call_count == 1 + mock_parse.assert_called_once_with( + llm_output="Hello world!", + num_reasks=None, + prompt_params={}, + llm_api=None, + some_kwarg="foo", + api_key=None, + ) + + assert response.status_code == 200 + assert response.json() == { + "callId": "mock-call-id", + "validatedOutput": "Hello world!", + "validationPassed": True, + "rawLlmOutput": "Hello world!", + } + + del os.environ["PGHOST"] + + +def test_validate__call(mocker): + os.environ["PGHOST"] = "localhost" + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output=None, + validation_passed=False, + ) + + mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") + mock___call__.return_value = mock_outcome + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_guard.history = Stack(Call()) + + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!", + }, + headers={"x-openai-api-key": "mock-key"}, + ) + + mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock___call__.call_count == 1 + mock___call__.assert_called_once_with( + 1, + 2, + 3, + llm_api=None, + prompt_params={"p1": "bar"}, + num_reasks=None, + some_kwarg="foo", + api_key="mock-key", + prompt="Hello world!", + ) + + assert response.status_code == 200 + assert response.json() == { + "callId": "mock-call-id", + "validationPassed": False, + "validatedOutput": None, + "rawLlmOutput": "Hello world!", + } + + del os.environ["PGHOST"] + + +def test_validate__call_throws_validation_error(mocker): + os.environ["PGHOST"] = "localhost" + error = ValidationError("Test guard validation error") + mock_parse = mocker.patch.object(MockGuardStruct, "__call__") + mock_parse.side_effect = error + + mock_guard = MockGuardStruct() + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_guard.history = Stack(Call()) + + response = client.post( + "/guards/My%20Guard's%20Name/validate", + json={ + "promptParams": {"p1": "bar"}, + "args": [1, 2, 3], + "some_kwarg": "foo", + "prompt": "Hello world!", + }, + ) + + mock_get_guard.assert_called_once_with("My Guard's Name") + + assert response.status_code == 400 + assert response.json() == {"detail": "Test guard validation error"} + + del os.environ["PGHOST"] + + +def test_openai_v1_chat_completions__raises_404(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = None + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + response = client.post( + "/guards/My%20Guard's%20Name/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello world!"}], + }, + headers={"x-openai-api-key": "mock-key"}, + ) + + assert response.status_code == 404 + assert ( + response.json()["detail"] + == "A Guard with the name My Guard's Name does not exist!" + ) + + mock_get_guard.assert_called_once_with("My Guard's Name") + + del os.environ["PGHOST"] + + +def test_openai_v1_chat_completions__call(mocker): + os.environ["PGHOST"] = "localhost" + mock_guard = MockGuardStruct() + mock_outcome = ValidationOutcome( + call_id="mock-call-id", + raw_llm_output="Hello world!", + validated_output="Hello world!", + validation_passed=False, + ) + + mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") + mock___call__.return_value = mock_outcome + + mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict") + mock_from_dict.return_value = mock_guard + + mock_get_guard = mocker.patch( + "guardrails_api.api.guards.guard_client.get_guard", + return_value=mock_guard, + ) + + mock_status = mocker.patch( + "guardrails.classes.history.call.Call.status", new_callable=PropertyMock + ) + mock_status.return_value = "fail" + mock_call = Call() + mock_call.iterations = Stack(Iteration("some-id", 1)) + mock_guard.history = Stack(mock_call) + + response = client.post( + "/guards/My%20Guard's%20Name/openai/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello world!"}], + }, + headers={"x-openai-api-key": "mock-key"}, + ) + + mock_get_guard.assert_called_once_with("My Guard's Name") + assert mock___call__.call_count == 1 + mock___call__.assert_called_once_with( + num_reasks=0, + messages=[{"role": "user", "content": "Hello world!"}], + ) + + assert response.status_code == 200 + assert response.json() == { + "choices": [ + { + "message": { + "content": "Hello world!", + }, + } + ], + "guardrails": { + "reask": None, + "validation_passed": False, + "error": None, + }, + } + + del os.environ["PGHOST"] diff --git a/tests/api/test_root.py b/tests/api/test_root.py new file mode 100644 index 0000000..753dee0 --- /dev/null +++ b/tests/api/test_root.py @@ -0,0 +1,63 @@ +import os +from fastapi.testclient import TestClient +from fastapi import FastAPI +import pytest + +from guardrails_api.utils.logger import logger +from tests.mocks.mock_postgres_client import MockPostgresClient + +# Assuming you have a similar structure in your FastAPI app +from guardrails_api.api import root + + +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(root.router) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +def test_home(client): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == "Hello, world!" + + # Check if all expected routes are registered + routes = [route.path for route in client.app.routes] + assert "/" in routes + assert "/health-check" in routes + assert "/openapi.json" in routes # This is FastAPI's equivalent to /api-docs + assert "/docs" in routes + + +def test_health_check(client, mocker): + os.environ["PGHOST"] = "localhost" + + mock_pg = MockPostgresClient() + mock_pg.db.session._set_rows([(1,)]) + mocker.patch("guardrails_api.api.root.PostgresClient", return_value=mock_pg) + + def text_side_effect(query: str): + return query + + mock_text = mocker.patch( + "guardrails_api.api.root.text", side_effect=text_side_effect + ) + + info_spy = mocker.spy(logger, "info") + + response = client.get("/health-check") + + mock_text.assert_called_once_with("SELECT count(datid) FROM pg_stat_activity;") + assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] + + info_spy.assert_called_once_with("response: %s", [(1,)]) + + assert response.json() == {"status": 200, "message": "Ok"} + + del os.environ["PGHOST"] diff --git a/tests/blueprints/test_guards.py b/tests/blueprints/test_guards.py deleted file mode 100644 index 3cce59d..0000000 --- a/tests/blueprints/test_guards.py +++ /dev/null @@ -1,719 +0,0 @@ -import os -from unittest.mock import PropertyMock -from typing import Dict, Tuple - -import pytest - -from tests.mocks.mock_blueprint import MockBlueprint -from tests.mocks.mock_guard_client import MockGuardStruct -from tests.mocks.mock_request import MockRequest -from guardrails.classes import ValidationOutcome -from guardrails.classes.generic import Stack -from guardrails.classes.history import Call, Iteration -from guardrails_api.app import register_config -from guardrails.errors import ValidationError - -# TODO: Should we mock this somehow? -# Right now it's just empty, but it technically does a file read -register_config() - - -MOCK_GUARD_STRING = { - "id": "mock-guard-id", - "name": "mock-guard", - "description": "mock guard description", - "history": Stack(), -} - - -# FIXME: Why doesn't this work when running a single test? -# Either a config issue or a pytest issue -@pytest.fixture(autouse=True) -def around_each(): - # Code that will run before the test - openai_api_key_bak = os.environ.get("OPENAI_API_KEY") - if openai_api_key_bak: - del os.environ["OPENAI_API_KEY"] - yield - # Code that will run after the test - if openai_api_key_bak: - os.environ["OPENAI_API_KEY"] = openai_api_key_bak - - -def test_route_setup(mocker): - mocker.patch("flask.Blueprint", new=MockBlueprint) - - from guardrails_api.blueprints.guards import guards_bp - - assert guards_bp.route_call_count == 5 - assert guards_bp.routes == [ - "/", - "/", - "//openai/v1/chat/completions", - "//validate", - "//history/", - ] - - -def test_guards__get(mocker): - mock_guard = MockGuardStruct() - mock_request = MockRequest("GET") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guards = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guards", - return_value=[mock_guard], - ) - mocker.patch("guardrails_api.blueprints.guards.collect_telemetry") - - # >>> Conflict - # mock_get_guards = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.get_guards", return_value=[mock_guard] - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guards - - response = guards() - - assert mock_get_guards.call_count == 1 - - assert response == [MOCK_GUARD_STRING] - - -def test_guards__post_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_request = MockRequest("POST", mock_guard.to_dict()) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_from_request = mocker.patch( - "guardrails_api.blueprints.guards.GuardStruct.from_dict", - return_value=mock_guard, - ) - mock_create_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.create_guard", - return_value=mock_guard, - ) - - from guardrails_api.blueprints.guards import guards - - response = guards() - - mock_from_request.assert_called_once_with(mock_guard.to_dict()) - mock_create_guard.assert_called_once_with(mock_guard) - - assert response == MOCK_GUARD_STRING - - del os.environ["PGHOST"] - - -def test_guards__post_mem(mocker): - mock_guard = MockGuardStruct() - mock_request = MockRequest("POST", mock_guard.to_dict()) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - from guardrails_api.blueprints.guards import guards - - response = guards() - - error_body, status = response - - assert status == 501 - - -def test_guards__raises(mocker): - mock_request = MockRequest("PUT") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import guards - - response = guards() - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guards only supports the GET and POST methods. You specified PUT" - ) - assert status == 405 - - -def test_guard__get_mem(mocker): - mock_guard = MockGuardStruct() - timestamp = "2024-03-04T14:11:42-06:00" - mock_request = MockRequest("GET", args={"asOf": timestamp}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_get_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.get_guard", return_value=mock_guard - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name", timestamp) - assert response == MOCK_GUARD_STRING - - -def test_guard__put_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - json_guard = { - "name": "mock-guard", - "id": "mock-guard-id", - "description": "mock guard description", - "history": Stack(), - } - mock_request = MockRequest("PUT", json=json_guard) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_from_request = mocker.patch( - "guardrails_api.blueprints.guards.GuardStruct.from_dict", - return_value=mock_guard, - ) - mock_upsert_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.upsert_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_from_request = mocker.patch( - # "guardrails_api.blueprints.guards.GuardStruct.from_request", return_value=mock_guard - # ) - # mock_upsert_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.upsert_guard", return_value=mock_guard - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("My%20Guard's%20Name") - - mock_from_request.assert_called_once_with(json_guard) - mock_upsert_guard.assert_called_once_with("My Guard's Name", mock_guard) - assert response == MOCK_GUARD_STRING - del os.environ["PGHOST"] - - -def test_guard__delete_pg(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_request = MockRequest("DELETE") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - - mock_delete_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.delete_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - # >>> Conflict - # mock_delete_guard = mocker.patch( - # "guardrails_api.blueprints.guards.guard_client.delete_guard", return_value=mock_guard - # ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - - from guardrails_api.blueprints.guards import guard - - response = guard("my-guard-name") - - mock_delete_guard.assert_called_once_with("my-guard-name") - assert response == MOCK_GUARD_STRING - del os.environ["PGHOST"] - - -def test_guard__raises(mocker): - mock_request = MockRequest("POST") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import guard - - response = guard("guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guard/ only supports the GET, PUT, and DELETE methods. You specified POST" - ) - assert status == 405 - - -def test_validate__raises_method_not_allowed(mocker): - mock_request = MockRequest("PUT") - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer") - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 405 - assert error.get("message") == "Method Not Allowed" - assert ( - error.get("cause") - == "/guards//validate only supports the POST method. You specified PUT" - ) - assert status == 405 - - -def test_validate__raises_bad_request__openai_api_key(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - # mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={"llmApi": "bar"}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("mock-guard") - - mock_get_guard.assert_called_once_with("mock-guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 400 - assert error.get("message") == "BadRequest" - assert error.get("cause") == ( - "Cannot perform calls to OpenAI without an api key. Pass" - " openai_api_key when initializing the Guard or set the" - " OPENAI_API_KEY environment variable." - ) - assert status == 400 - del os.environ["PGHOST"] - - -def test_validate__raises_bad_request__num_reasks(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - # mock_tracer = MockTracer() - mock_request = MockRequest("POST", json={"numReasks": 3}) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - mocker.patch("guardrails_api.utils.handle_error.logger.error") - mocker.patch("guardrails_api.utils.handle_error.traceback.print_exception") - from guardrails_api.blueprints.guards import validate - - response = validate("mock-guard") - - mock_get_guard.assert_called_once_with("mock-guard") - - assert isinstance(response, Tuple) - error, status = response - assert isinstance(error, Dict) - assert error.get("status") == 400 - assert error.get("message") == "BadRequest" - assert error.get("cause") == ( - "Cannot perform re-asks without an LLM API. Specify llm_api when" - " calling guard(...)." - ) - assert status == 400 - del os.environ["PGHOST"] - - -def test_validate__parse(mocker): - os.environ["PGHOST"] = "localhost" - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output="Hello world!", - validation_passed=True, - ) - - mock_parse = mocker.patch.object(MockGuardStruct, "parse") - mock_parse.return_value = mock_outcome - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={"llmOutput": "Hello world!", "args": [1, 2, 3], "some_kwarg": "foo"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # >>> Conflict - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "pass" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock_parse.call_count == 1 - - mock_parse.assert_called_once_with( - llm_output="Hello world!", - num_reasks=None, - prompt_params={}, - llm_api=None, - some_kwarg="foo", - api_key=None, - ) - - # Temporarily Disabled - # assert set_attribute_spy.call_count == 7 - # expected_calls = [ - # call("guardName", "My Guard's Name"), - # call("prompt", "Hello world prompt!"), - # call("validation_status", "pass"), - # call("raw_llm_ouput", "Hello world!"), - # call("validated_output", "Hello world!"), - # call("tokens_consumed", None), - # call("num_of_reasks", 0), - # ] - # set_attribute_spy.assert_has_calls(expected_calls) - - assert response == { - "callId": "mock-call-id", - "validatedOutput": "Hello world!", - "validationPassed": True, - "rawLlmOutput": "Hello world!", - } - - del os.environ["PGHOST"] - - -def test_validate__call(mocker): - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output=None, - validation_passed=False, - ) - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.return_value = mock_outcome - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={ - "llmApi": "openai.Completion.create", - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!", - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # >>> Conflict - # mocker.patch("guardrails_api.blueprints.guards.get_tracer", return_value=mock_tracer) - - # set_attribute_spy = mocker.spy(mock_tracer.span, "set_attribute") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - 1, - 2, - 3, - llm_api="openai.Completion.create", - prompt_params={"p1": "bar"}, - num_reasks=None, - some_kwarg="foo", - api_key="mock-key", - prompt="Hello world!", - ) - - # Temporarily Disabled - # assert set_attribute_spy.call_count == 8 - # expected_calls = [ - # call("guardName", "My Guard's Name"), - # call("prompt", "Hello world prompt!"), - # call("instructions", "Hello world instructions!"), - # call("validation_status", "fail"), - # call("raw_llm_ouput", "Hello world!"), - # call("validated_output", "None"), - # call("tokens_consumed", None), - # call("num_of_reasks", 0), - # ] - # set_attribute_spy.assert_has_calls(expected_calls) - - assert response == { - "callId": "mock-call-id", - "validationPassed": False, - "validatedOutput": None, - "rawLlmOutput": "Hello world!", - } - - del os.environ["PGHOST"] - -def test_validate__call_throws_validation_error(mocker): - os.environ["PGHOST"] = "localhost" - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.side_effect = ValidationError("Test guard validation error") - - mock_guard = MockGuardStruct() - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - # mock_tracer = MockTracer() - mock_request = MockRequest( - "POST", - json={ - "llmApi": "openai.Completion.create", - "promptParams": {"p1": "bar"}, - "args": [1, 2, 3], - "some_kwarg": "foo", - "prompt": "Hello world!", - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_guard.history = Stack(Call()) - from guardrails_api.blueprints.guards import validate - - response = validate("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - 1, - 2, - 3, - llm_api="openai.Completion.create", - prompt_params={"p1": "bar"}, - num_reasks=None, - some_kwarg="foo", - api_key="mock-key", - prompt="Hello world!", - ) - - assert response == ('Test guard validation error', 400) - - del os.environ["PGHOST"] - -def test_openai_v1_chat_completions__raises_404(mocker): - from guardrails_api.blueprints.guards import openai_v1_chat_completions - os.environ["PGHOST"] = "localhost" - mock_guard = None - - mock_request = MockRequest( - "POST", - json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - response = openai_v1_chat_completions("My%20Guard's%20Name") - assert response[1] == 404 - assert response[0]["message"] == 'NotFound' - - - mock_get_guard.assert_called_once_with("My Guard's Name") - - del os.environ["PGHOST"] - -def test_openai_v1_chat_completions__call(mocker): - from guardrails_api.blueprints.guards import openai_v1_chat_completions - os.environ["PGHOST"] = "localhost" - mock_guard = MockGuardStruct() - mock_outcome = ValidationOutcome( - call_id="mock-call-id", - raw_llm_output="Hello world!", - validated_output="Hello world!", - validation_passed=False, - ) - - mock___call__ = mocker.patch.object(MockGuardStruct, "__call__") - mock___call__.return_value = mock_outcome - - mock_from_dict = mocker.patch("guardrails_api.blueprints.guards.Guard.from_dict") - mock_from_dict.return_value = mock_guard - - mock_request = MockRequest( - "POST", - json={ - "messages": [{"role":"user", "content":"Hello world!"}], - }, - headers={"x-openai-api-key": "mock-key"}, - ) - - mocker.patch("flask.Blueprint", new=MockBlueprint) - mocker.patch("guardrails_api.blueprints.guards.request", mock_request) - mock_get_guard = mocker.patch( - "guardrails_api.blueprints.guards.guard_client.get_guard", - return_value=mock_guard, - ) - mocker.patch( - "guardrails_api.blueprints.guards.get_llm_callable", - return_value="openai.Completion.create", - ) - - mocker.patch("guardrails_api.blueprints.guards.CacheClient.set") - - mock_status = mocker.patch( - "guardrails.classes.history.call.Call.status", new_callable=PropertyMock - ) - mock_status.return_value = "fail" - mock_call = Call() - mock_call.iterations= Stack(Iteration('some-id', 1)) - mock_guard.history = Stack(mock_call) - - response = openai_v1_chat_completions("My%20Guard's%20Name") - - mock_get_guard.assert_called_once_with("My Guard's Name") - - assert mock___call__.call_count == 1 - - mock___call__.assert_called_once_with( - num_reasks=0, - messages=[{"role":"user", "content":"Hello world!"}], - ) - - assert response == { - "choices": [ - { - "message": { - "content": "Hello world!", - }, - } - ], - "guardrails": { - "reask": None, - "validation_passed": False, - "error": None, - }, - } - - del os.environ["PGHOST"] \ No newline at end of file diff --git a/tests/blueprints/test_root.py b/tests/blueprints/test_root.py deleted file mode 100644 index 7ef611f..0000000 --- a/tests/blueprints/test_root.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from guardrails_api.utils.logger import logger -from tests.mocks.mock_blueprint import MockBlueprint -from tests.mocks.mock_postgres_client import MockPostgresClient - - -def test_home(mocker): - mocker.patch("flask.Blueprint", new=MockBlueprint) - from guardrails_api.blueprints.root import home, root_bp - - response = home() - - assert root_bp.route_call_count == 4 - assert root_bp.routes == ["/", "/health-check", "/api-docs", "/docs"] - assert response == "Hello, Flask!" - - mocker.resetall() - - -def test_health_check(mocker): - os.environ["PGHOST"] = "localhost" - mocker.patch("flask.Blueprint", new=MockBlueprint) - - mock_pg = MockPostgresClient() - mock_pg.db.session._set_rows([(1,)]) - mocker.patch("guardrails_api.blueprints.root.PostgresClient", return_value=mock_pg) - - def text_side_effect(query: str): - return query - - mock_text = mocker.patch( - "guardrails_api.blueprints.root.text", side_effect=text_side_effect - ) - - from guardrails_api.blueprints.root import health_check - - info_spy = mocker.spy(logger, "info") - - response = health_check() - - mock_text.assert_called_once_with("SELECT count(datid) FROM pg_stat_activity;") - assert mock_pg.db.session.queries == ["SELECT count(datid) FROM pg_stat_activity;"] - - info_spy.assert_called_once_with("response: %s", [(1,)]) - assert response == {"status": 200, "message": "Ok"} - - mocker.resetall() - del os.environ["PGHOST"] diff --git a/tests/cli/test_start.py b/tests/cli/test_start.py index 2fd10da..3138a66 100644 --- a/tests/cli/test_start.py +++ b/tests/cli/test_start.py @@ -1,15 +1,19 @@ from unittest.mock import MagicMock import os + def test_start(mocker): mocker.patch("guardrails_api.cli.start.cli") - mock_flask_app = MagicMock() + mock_app = MagicMock() mock_create_app = mocker.patch( - "guardrails_api.cli.start.create_app", return_value=mock_flask_app + "guardrails_api.cli.start.create_app", return_value=mock_app ) + mocker.patch("uvicorn.run") + from guardrails_api.cli.start import start + # pg enabled os.environ["PGHOST"] = "localhost" start("env", "config", 8000) diff --git a/tests/clients/test_pg_guard_client.py b/tests/clients/test_pg_guard_client.py index 0b94224..add0048 100644 --- a/tests/clients/test_pg_guard_client.py +++ b/tests/clients/test_pg_guard_client.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import ANY as AnyMatcher +from unittest.mock import ANY as AnyMatcher, MagicMock from guardrails_api.classes.http_error import HttpError from guardrails_api.models.guard_item import GuardItem @@ -28,14 +28,19 @@ def test_init(mocker): class TestGetGuard: def test_get_latest(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + query_spy = mock_session.query + query_spy.return_value = mock_session - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = MockGuardStruct() mock_first.return_value = latest_guard @@ -59,16 +64,25 @@ def test_get_latest(self, mocker): def test_with_as_of_date(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + query_spy = mock_session.query + query_spy.return_value = mock_session + + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + filter_spy = mock_session.filter + filter_spy.return_value = mock_session - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - filter_spy = mocker.spy(mock_pg_client.db.session, "filter") - order_by_spy = mocker.spy(mock_pg_client.db.session, "order_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + order_by_spy = mock_session.order_by + order_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = MockGuardStruct(name="latest") previous_guard = MockGuardStruct(name="previous") mock_first.side_effect = [latest_guard, previous_guard] @@ -107,13 +121,20 @@ def test_with_as_of_date(self, mocker): def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") - mock_first.return_value = None + # Mock the query method to return a mock query object + mock_query = mock_session.query.return_value + + # Mock the filter_by method to return a mock filter object + mock_filter_by = mock_query.filter_by.return_value + mock_first = mock_filter_by.first + # Mock the first method on the mock filter object to return None + mock_filter_by.first.return_value = None mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -136,14 +157,20 @@ def test_raises_not_found(self, mocker): def test_get_guard_item(mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - query_spy = mocker.spy(mock_pg_client.db.session, "query") - filter_by_spy = mocker.spy(mock_pg_client.db.session, "filter_by") - mock_first = mocker.patch.object(mock_pg_client.db.session, "first") + query_spy = mock_session.query + query_spy.return_value = mock_session + + filter_by_spy = mock_session.filter_by + filter_by_spy.return_value = mock_session + + mock_first = mock_session.first latest_guard = MockGuardStruct(name="latest") mock_first.return_value = latest_guard @@ -162,17 +189,23 @@ def test_get_guard_item(mocker): def test_get_guards(mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - query_spy = mocker.spy(mock_pg_client.db.session, "query") - mock_all = mocker.patch.object(mock_pg_client.db.session, "all") + # Ensure that query returns the mock session itself + mock_session.query.return_value = mock_session + query_spy = mock_session.query + guard_one = MockGuardStruct(name="guard one") guard_two = MockGuardStruct(name="guard two") guards = [guard_one, guard_two] - mock_all.return_value = guards + # Mock the all method on the mock session + mock_session.all.return_value = guards + mock_all = mock_session.all mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" @@ -198,7 +231,9 @@ def test_get_guards(mocker): def test_create_guard(mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mock_guard_struct_init_spy = mocker.spy(MockGuardStruct, "__init__") + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -207,8 +242,8 @@ def test_create_guard(mocker): "guardrails_api.clients.pg_guard_client.GuardItem", new=MockGuardStruct ) - add_spy = mocker.spy(mock_pg_client.db.session, "add") - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + add_spy = mocker.spy(mock_session, "add") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" @@ -244,6 +279,8 @@ class TestUpdateGuard: def test_raises_not_found(self, mocker): mock_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -253,7 +290,7 @@ def test_raises_not_found(self, mocker): ) mock_get_guard_item.return_value = None - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -287,6 +324,8 @@ def test_updates_guard_item(self, mocker): updated_guard = MockGuardStruct(description="updated description") mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -296,7 +335,7 @@ def test_updates_guard_item(self, mocker): ) mock_get_guard_item.return_value = old_guard_item - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -324,6 +363,7 @@ def test_guard_doesnt_exist_yet(self, mocker): input_guard = MockGuardStruct() new_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, @@ -367,16 +407,20 @@ def test_guard_already_exists(self, mocker): updated_guard = MockGuardStruct(description="updated description") mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard_item - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -404,16 +448,21 @@ def test_guard_already_exists(self, mocker): class TestDeleteGuard: def test_raises_not_found(self, mocker): mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = None - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + commit_spy = mocker.spy(mock_session, "commit") + mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -438,17 +487,23 @@ def test_raises_not_found(self, mocker): def test_deletes_guard_item(self, mocker): old_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() + mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) + + mock_session = mock_pg_client.SessionLocal() + mock_get_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" ) mock_get_guard_item.return_value = old_guard - delete_spy = mocker.spy(mock_pg_client.db.session, "delete") - commit_spy = mocker.spy(mock_pg_client.db.session, "commit") + # Mock the query and delete operations + mock_query = mock_session.query.return_value + mock_filter = mock_query.filter_by.return_value + mock_filter.first.return_value = old_guard mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) @@ -461,8 +516,8 @@ def test_deletes_guard_item(self, mocker): result = guard_client.delete_guard("mock-guard") mock_get_guard_item.assert_called_once_with("mock-guard") - assert delete_spy.call_count == 1 - assert commit_spy.call_count == 1 + assert mock_session.delete.call_count == 1 + assert mock_session.commit.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) assert result == old_guard diff --git a/tests/mocks/mock_guard_client.py b/tests/mocks/mock_guard_client.py index 04bb77f..beca0a7 100644 --- a/tests/mocks/mock_guard_client.py +++ b/tests/mocks/mock_guard_client.py @@ -3,6 +3,7 @@ from pydantic import ConfigDict from guardrails.classes.generic import Stack + class MockGuardStruct(GuardStruct): # Pydantic Config model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/tests/mocks/mock_postgres_client.py b/tests/mocks/mock_postgres_client.py index 4197882..7ee9b2b 100644 --- a/tests/mocks/mock_postgres_client.py +++ b/tests/mocks/mock_postgres_client.py @@ -49,7 +49,14 @@ class MockDb: def __init__(self) -> None: self.session = MockSession() + def SessionLocal(self): + return self.session + class MockPostgresClient: def __init__(self): self.db = MockDb() + self.pgClient = self.db + + def get_db(self): + return MockSession() diff --git a/tests/utils/test_configuration.py b/tests/utils/test_configuration.py index 635893a..0ad5098 100644 --- a/tests/utils/test_configuration.py +++ b/tests/utils/test_configuration.py @@ -2,15 +2,16 @@ import pytest from guardrails_api.utils.configuration import valid_configuration, ConfigurationError + def test_valid_configuration(mocker): with pytest.raises(ConfigurationError): valid_configuration() - + # pg enabled os.environ["PGHOST"] = "localhost" valid_configuration("config.py") os.environ.pop("PGHOST") - + # custom config mock_isfile = mocker.patch("os.path.isfile") mock_isfile.side_effect = [False, True] @@ -20,7 +21,7 @@ def test_valid_configuration(mocker): mock_isfile.side_effect = [False, False] with pytest.raises(ConfigurationError): valid_configuration("") - + # default config mock_isfile = mocker.patch("os.path.isfile") mock_isfile.side_effect = [True, False]