From 940f9372e554a54718d693599774e01578622fab Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 15 Aug 2024 09:08:35 -0500 Subject: [PATCH] run guard client in app context; use thread lock on singletons --- guardrails_api/app.py | 12 +++++++----- guardrails_api/clients/cache_client.py | 5 ++++- guardrails_api/clients/postgres_client.py | 5 ++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 9ea2ce5..92e72d8 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -99,7 +99,7 @@ def create_app( cache_client.initialize(app) from guardrails_api.blueprints.root import root_bp - from guardrails_api.blueprints.guards import guards_bp, guard_client + from guardrails_api.blueprints.guards import guards_bp app.register_blueprint(root_bp) app.register_blueprint(guards_bp) @@ -111,11 +111,13 @@ def create_app( console.print(":green_circle: Active guards and OpenAI compatible endpoints:") - for g in guard_client.get_guards(): - g = g.to_dict() - console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1") + with app.app_context(): + from guardrails_api.blueprints.guards import guard_client + for g in guard_client.get_guards(): + g = g.to_dict() + console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1") console.print("") console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")) - return app + return app \ No newline at end of file diff --git a/guardrails_api/clients/cache_client.py b/guardrails_api/clients/cache_client.py index dc45886..550bc4b 100644 --- a/guardrails_api/clients/cache_client.py +++ b/guardrails_api/clients/cache_client.py @@ -1,13 +1,16 @@ +import threading from flask_caching import Cache # TODO: Add option to connect to Redis or MemCached backend with environment variables class CacheClient: _instance = None + _lock = threading.Lock() def __new__(cls): if cls._instance is None: - cls._instance = super(CacheClient, cls).__new__(cls) + with cls._lock: + cls._instance = super(CacheClient, cls).__new__(cls) return cls._instance def initialize(self, app): diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index d7b2ac5..951a4f4 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -1,6 +1,7 @@ import boto3 import json import os +import threading from flask import Flask from sqlalchemy import text from typing import Tuple @@ -13,10 +14,12 @@ def postgres_is_enabled() -> bool: class PostgresClient: _instance = None + _lock = threading.Lock() def __new__(cls): if cls._instance is None: - cls._instance = super(PostgresClient, cls).__new__(cls) + with cls._lock: + cls._instance = super(PostgresClient, cls).__new__(cls) return cls._instance def fetch_pg_secret(self, secret_arn: str) -> dict: