diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 64e5180..27512f7 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -18,7 +18,6 @@ import json import os - # from pyinstrument import Profiler # from pyinstrument.renderers.html import HTMLRenderer # from pyinstrument.renderers.speedscope import SpeedscopeRenderer diff --git a/guardrails_api/clients/pg_guard_client.py b/guardrails_api/clients/pg_guard_client.py index c7a1f48..226232a 100644 --- a/guardrails_api/clients/pg_guard_client.py +++ b/guardrails_api/clients/pg_guard_client.py @@ -18,14 +18,20 @@ def __init__(self): self.initialized = True self.pgClient = PostgresClient() + def get_db(self): # generator for local sessions + db = self.pgClient.SessionLocal() + try: + yield db + finally: + db.close() + def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: - latest_guard_item = ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first() audit_item = None if as_of_date is not None: audit_item = ( - self.pgClient.db.session.query(GuardItemAudit) + db.query(GuardItemAudit) .filter_by(name=guard_name) .filter(GuardItemAudit.replaced_on > as_of_date) .order_by(GuardItemAudit.replaced_on.asc()) @@ -43,27 +49,29 @@ def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: return from_guard_item(guard_item) def get_guard_item(self, guard_name: str) -> GuardItem: - return ( - self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first() - ) + db = next(self.get_db()) + return db.query(GuardItem).filter_by(name=guard_name).first() def get_guards(self) -> List[GuardStruct]: - guard_items = self.pgClient.db.session.query(GuardItem).all() + db = next(self.get_db()) + guard_items = db.query(GuardItem).all() return [from_guard_item(gi) for gi in guard_items] def create_guard(self, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = GuardItem( name=guard.name, railspec=guard.to_dict(), num_reasks=None, description=guard.description, ) - self.pgClient.db.session.add(guard_item) - self.pgClient.db.session.commit() + db.add(guard_item) + db.commit() return from_guard_item(guard_item) def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -76,21 +84,23 @@ def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: # guard_item.num_reasks = guard.num_reasks guard_item.railspec = guard.to_dict() guard_item.description = guard.description - self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is not None: guard_item.railspec = guard.to_dict() guard_item.description = guard.description # guard_item.num_reasks = guard.num_reasks - self.pgClient.db.session.commit() + db.commit() return from_guard_item(guard_item) else: return self.create_guard(guard) def delete_guard(self, guard_name: str) -> GuardStruct: + db = next(self.get_db()) guard_item = self.get_guard_item(guard_name) if guard_item is None: raise HttpError( @@ -100,7 +110,7 @@ def delete_guard(self, guard_name: str) -> GuardStruct: guard_name=guard_name ), ) - self.pgClient.db.session.delete(guard_item) - self.pgClient.db.session.commit() + db.delete(guard_item) + db.commit() guard = from_guard_item(guard_item) return guard diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index 951a4f4..56e1501 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -2,16 +2,24 @@ import json import os import threading -from flask import Flask -from sqlalchemy import text +from fastapi import FastAPI from typing import Tuple -from guardrails_api.models.base import db, INIT_EXTENSIONS +from sqlalchemy import create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() def postgres_is_enabled() -> bool: return os.environ.get("PGHOST", None) is not None +# Global variables for database session +postgres_client = None +SessionLocal = None + + class PostgresClient: _instance = None _lock = threading.Lock() @@ -45,7 +53,17 @@ def get_pg_creds(self) -> Tuple[str, str]: pg_password = pg_password or os.environ.get("PGPASSWORD") return pg_user, pg_password - def initialize(self, app: Flask): + def get_db(self): + if postgres_is_enabled(): + db = self.SessionLocal() + try: + yield db + finally: + db.close() + else: + yield None + + def initialize(self, app: FastAPI): pg_user, pg_password = self.get_pg_creds() pg_host = os.environ.get("PGHOST", "localhost") pg_port = os.environ.get("PGPORT", "5432") @@ -64,23 +82,64 @@ def initialize(self, app: Flask): if os.environ.get("NODE_ENV") == "production": conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" - app.config["SQLALCHEMY_DATABASE_URI"] = conf + engine = create_engine(conf) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - app.secret_key = "secret" self.app = app - self.db = db - db.init_app(app) - from guardrails_api.models.guard_item import GuardItem # NOQA - from guardrails_api.models.guard_item_audit import ( # NOQA - GuardItemAudit, - AUDIT_FUNCTION, - AUDIT_TRIGGER, - ) + self.engine = engine + self.SessionLocal = SessionLocal + # Create tables + from guardrails_api.models import GuardItem, GuardItemAudit # noqa + + Base.metadata.create_all(bind=engine) + + # Execute custom SQL + with engine.connect() as connection: + connection.execute(text(INIT_EXTENSIONS)) + connection.execute(text(AUDIT_FUNCTION)) + connection.execute(text(AUDIT_TRIGGER)) + connection.commit() + + +# Define INIT_EXTENSIONS, AUDIT_FUNCTION, and AUDIT_TRIGGER here as they were in your original code +INIT_EXTENSIONS = """ +-- Your SQL for initializing extensions +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp') THEN + CREATE EXTENSION "uuid-ossp"; + END IF; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN + CREATE EXTENSION "vector"; + END IF; +END $$; +""" + +AUDIT_FUNCTION = """ +CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ +BEGIN + IF (TG_OP = 'DELETE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D'; + ELSIF (TG_OP = 'UPDATE') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; + ELSIF (TG_OP = 'INSERT') THEN + INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; + END IF; + RETURN null; +END; +$guard_audit$ +LANGUAGE plpgsql; +""" - with self.app.app_context(): - self.db.session.execute(text(INIT_EXTENSIONS)) - self.db.create_all() - self.db.session.execute(text(AUDIT_FUNCTION)) - self.db.session.execute(text(AUDIT_TRIGGER)) - self.db.session.commit() +AUDIT_TRIGGER = """ +DROP TRIGGER IF EXISTS guard_audit_trigger + ON guards; +CREATE TRIGGER guard_audit_trigger + AFTER INSERT OR UPDATE OR DELETE ON guards + FOR EACH ROW + EXECUTE PROCEDURE guard_audit_function(); +""" diff --git a/guardrails_api/models/__init__.py b/guardrails_api/models/__init__.py index e69de29..391a299 100644 --- a/guardrails_api/models/__init__.py +++ b/guardrails_api/models/__init__.py @@ -0,0 +1,5 @@ +# __init__.py +from .guard_item_audit import GuardItemAudit +from .guard_item import GuardItem + +__all__ = ["GuardItemAudit", "GuardItem"] diff --git a/guardrails_api/models/base.py b/guardrails_api/models/base.py deleted file mode 100644 index 29f0169..0000000 --- a/guardrails_api/models/base.py +++ /dev/null @@ -1,8 +0,0 @@ -from flask_sqlalchemy import SQLAlchemy - -db = SQLAlchemy() - -INIT_EXTENSIONS = """ -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -CREATE EXTENSION IF NOT EXISTS "vector"; -""" diff --git a/guardrails_api/models/guard_item.py b/guardrails_api/models/guard_item.py index a2dbf32..6bcacfa 100644 --- a/guardrails_api/models/guard_item.py +++ b/guardrails_api/models/guard_item.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItem(db.Model): +class GuardItem(Base): __tablename__ = "guards" # TODO: Make primary key a composite between guard.name and the guard owner's userId name = Column(String, primary_key=True) diff --git a/guardrails_api/models/guard_item_audit.py b/guardrails_api/models/guard_item_audit.py index 183626e..13ee3c2 100644 --- a/guardrails_api/models/guard_item_audit.py +++ b/guardrails_api/models/guard_item_audit.py @@ -1,9 +1,9 @@ from sqlalchemy import Column, String, Integer from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP, CHAR -from guardrails_api.models.base import db +from guardrails_api.clients.postgres_client import Base -class GuardItemAudit(db.Model): +class GuardItemAudit(Base): __tablename__ = "guards_audit" id = Column(String, primary_key=True) name = Column(String, nullable=False, index=True) @@ -35,29 +35,3 @@ def __init__( self.replaced_on = replaced_on self.operation = operation # self.owner = owner - - -AUDIT_FUNCTION = """ -CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$ -BEGIN - IF (TG_OP = 'DELETE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D'; - ELSIF (TG_OP = 'UPDATE') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U'; - ELSIF (TG_OP = 'INSERT') THEN - INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I'; - END IF; - RETURN null; -END; -$guard_audit$ -LANGUAGE plpgsql; -""" - -AUDIT_TRIGGER = """ -DROP TRIGGER IF EXISTS guard_audit_trigger - ON guards; -CREATE TRIGGER guard_audit_trigger - AFTER INSERT OR UPDATE OR DELETE ON guards - FOR EACH ROW - EXECUTE PROCEDURE guard_audit_function(); -""" diff --git a/guardrails_api/start-dev.sh b/guardrails_api/start-dev.sh index 36f33ba..83a2d70 100755 --- a/guardrails_api/start-dev.sh +++ b/guardrails_api/start-dev.sh @@ -1,6 +1,6 @@ gunicorn --bind 0.0.0.0:8000 \ --timeout 120 \ - --workers 3 \ + --workers 2 \ --threads 2 \ --worker-class=uvicorn.workers.UvicornWorker \ "guardrails_api.app:create_app()" \ diff --git a/pyproject.toml b/pyproject.toml index d3d3d94..a2db2fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"] requires-python = ">= 3.8.1" dependencies = [ "guardrails-ai>=0.5.6", - "Flask-SQLAlchemy>=3.1.1,<4", "Werkzeug>=3.0.3,<4", "jsonschema>=4.22.0,<5", "referencing>=0.35.1,<1", @@ -27,6 +26,7 @@ dependencies = [ "requests>=2.32.3", "aiocache>=0.11.1", "fastapi>=0.114.1", + "SQLAlchemy>=2.0.34", ] [tool.setuptools.dynamic] @@ -42,6 +42,7 @@ dev = [ "coverage", "pytest-mock", "gunicorn>=22.0.0,<23", + "uvicorn", ] [tool.pytest.ini_options] diff --git a/requirements-lock.txt b/requirements-lock.txt index 950b069..cacc27e 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -28,16 +28,10 @@ frozenlist==1.4.1 fsspec==2024.6.1 googleapis-common-protos==1.63.2 griffe==0.36.9 -<<<<<<< Updated upstream grpcio==1.65.1 -guardrails-ai==0.5.7 +guardrails-ai==0.5.9 guardrails-api-client==0.3.12 guardrails_hub_types==0.0.4 -======= -grpcio==1.64.1 -guardrails-ai==0.5.0a2 -guardrails-api-client==0.3.8 ->>>>>>> Stashed changes gunicorn==22.0.0 h11==0.14.0 httpcore==1.0.5 diff --git a/tests/cli/test_start.py b/tests/cli/test_start.py index e6973d9..befe21a 100644 --- a/tests/cli/test_start.py +++ b/tests/cli/test_start.py @@ -5,9 +5,9 @@ def test_start(mocker): mocker.patch("guardrails_api.cli.start.cli") - mock_flask_app = MagicMock() + mock_app = MagicMock() mock_create_app = mocker.patch( - "guardrails_api.cli.start.create_app", return_value=mock_flask_app + "guardrails_api.cli.start.create_app", return_value=mock_app ) from guardrails_api.cli.start import start