Skip to content

Commit

Permalink
fix pg support
Browse files Browse the repository at this point in the history
  • Loading branch information
dtam committed Sep 14, 2024
1 parent 2e3f050 commit 1ab9e52
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 85 deletions.
1 change: 0 additions & 1 deletion guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import json
import os


# from pyinstrument import Profiler
# from pyinstrument.renderers.html import HTMLRenderer
# from pyinstrument.renderers.speedscope import SpeedscopeRenderer
Expand Down
38 changes: 24 additions & 14 deletions guardrails_api/clients/pg_guard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ def __init__(self):
self.initialized = True
self.pgClient = PostgresClient()

def get_db(self): # generator for local sessions
db = self.pgClient.SessionLocal()
try:
yield db
finally:
db.close()

def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct:
latest_guard_item = (
self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first()
)
db = next(self.get_db())
latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first()
audit_item = None
if as_of_date is not None:
audit_item = (
self.pgClient.db.session.query(GuardItemAudit)
db.query(GuardItemAudit)
.filter_by(name=guard_name)
.filter(GuardItemAudit.replaced_on > as_of_date)
.order_by(GuardItemAudit.replaced_on.asc())
Expand All @@ -43,27 +49,29 @@ def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct:
return from_guard_item(guard_item)

def get_guard_item(self, guard_name: str) -> GuardItem:
return (
self.pgClient.db.session.query(GuardItem).filter_by(name=guard_name).first()
)
db = next(self.get_db())
return db.query(GuardItem).filter_by(name=guard_name).first()

def get_guards(self) -> List[GuardStruct]:
guard_items = self.pgClient.db.session.query(GuardItem).all()
db = next(self.get_db())
guard_items = db.query(GuardItem).all()

return [from_guard_item(gi) for gi in guard_items]

def create_guard(self, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
guard_item = GuardItem(
name=guard.name,
railspec=guard.to_dict(),
num_reasks=None,
description=guard.description,
)
self.pgClient.db.session.add(guard_item)
self.pgClient.db.session.commit()
db.add(guard_item)
db.commit()
return from_guard_item(guard_item)

def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is None:
raise HttpError(
Expand All @@ -76,21 +84,23 @@ def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
# guard_item.num_reasks = guard.num_reasks
guard_item.railspec = guard.to_dict()
guard_item.description = guard.description
self.pgClient.db.session.commit()
db.commit()
return from_guard_item(guard_item)

def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is not None:
guard_item.railspec = guard.to_dict()
guard_item.description = guard.description
# guard_item.num_reasks = guard.num_reasks
self.pgClient.db.session.commit()
db.commit()
return from_guard_item(guard_item)
else:
return self.create_guard(guard)

def delete_guard(self, guard_name: str) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is None:
raise HttpError(
Expand All @@ -100,7 +110,7 @@ def delete_guard(self, guard_name: str) -> GuardStruct:
guard_name=guard_name
),
)
self.pgClient.db.session.delete(guard_item)
self.pgClient.db.session.commit()
db.delete(guard_item)
db.commit()
guard = from_guard_item(guard_item)
return guard
101 changes: 80 additions & 21 deletions guardrails_api/clients/postgres_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@
import json
import os
import threading
from flask import Flask
from sqlalchemy import text
from fastapi import FastAPI
from typing import Tuple
from guardrails_api.models.base import db, INIT_EXTENSIONS
from sqlalchemy import create_engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


def postgres_is_enabled() -> bool:
return os.environ.get("PGHOST", None) is not None


# Global variables for database session
postgres_client = None
SessionLocal = None


class PostgresClient:
_instance = None
_lock = threading.Lock()
Expand Down Expand Up @@ -45,7 +53,17 @@ def get_pg_creds(self) -> Tuple[str, str]:
pg_password = pg_password or os.environ.get("PGPASSWORD")
return pg_user, pg_password

def initialize(self, app: Flask):
def get_db(self):
if postgres_is_enabled():
db = self.SessionLocal()
try:
yield db
finally:
db.close()
else:
yield None

def initialize(self, app: FastAPI):
pg_user, pg_password = self.get_pg_creds()
pg_host = os.environ.get("PGHOST", "localhost")
pg_port = os.environ.get("PGPORT", "5432")
Expand All @@ -64,23 +82,64 @@ def initialize(self, app: Flask):
if os.environ.get("NODE_ENV") == "production":
conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem"

app.config["SQLALCHEMY_DATABASE_URI"] = conf
engine = create_engine(conf)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.secret_key = "secret"
self.app = app
self.db = db
db.init_app(app)
from guardrails_api.models.guard_item import GuardItem # NOQA
from guardrails_api.models.guard_item_audit import ( # NOQA
GuardItemAudit,
AUDIT_FUNCTION,
AUDIT_TRIGGER,
)
self.engine = engine
self.SessionLocal = SessionLocal
# Create tables
from guardrails_api.models import GuardItem, GuardItemAudit # noqa

Base.metadata.create_all(bind=engine)

# Execute custom SQL
with engine.connect() as connection:
connection.execute(text(INIT_EXTENSIONS))
connection.execute(text(AUDIT_FUNCTION))
connection.execute(text(AUDIT_TRIGGER))
connection.commit()


# Define INIT_EXTENSIONS, AUDIT_FUNCTION, and AUDIT_TRIGGER here as they were in your original code
INIT_EXTENSIONS = """
-- Your SQL for initializing extensions
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'uuid-ossp') THEN
CREATE EXTENSION "uuid-ossp";
END IF;
END $$;
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
CREATE EXTENSION "vector";
END IF;
END $$;
"""

AUDIT_FUNCTION = """
CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$
BEGIN
IF (TG_OP = 'DELETE') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D';
ELSIF (TG_OP = 'UPDATE') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U';
ELSIF (TG_OP = 'INSERT') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I';
END IF;
RETURN null;
END;
$guard_audit$
LANGUAGE plpgsql;
"""

with self.app.app_context():
self.db.session.execute(text(INIT_EXTENSIONS))
self.db.create_all()
self.db.session.execute(text(AUDIT_FUNCTION))
self.db.session.execute(text(AUDIT_TRIGGER))
self.db.session.commit()
AUDIT_TRIGGER = """
DROP TRIGGER IF EXISTS guard_audit_trigger
ON guards;
CREATE TRIGGER guard_audit_trigger
AFTER INSERT OR UPDATE OR DELETE ON guards
FOR EACH ROW
EXECUTE PROCEDURE guard_audit_function();
"""
5 changes: 5 additions & 0 deletions guardrails_api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# __init__.py
from .guard_item_audit import GuardItemAudit
from .guard_item import GuardItem

__all__ = ["GuardItemAudit", "GuardItem"]
8 changes: 0 additions & 8 deletions guardrails_api/models/base.py

This file was deleted.

4 changes: 2 additions & 2 deletions guardrails_api/models/guard_item.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from sqlalchemy import Column, String, Integer
from sqlalchemy.dialects.postgresql import JSONB
from guardrails_api.models.base import db
from guardrails_api.clients.postgres_client import Base


class GuardItem(db.Model):
class GuardItem(Base):
__tablename__ = "guards"
# TODO: Make primary key a composite between guard.name and the guard owner's userId
name = Column(String, primary_key=True)
Expand Down
30 changes: 2 additions & 28 deletions guardrails_api/models/guard_item_audit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from sqlalchemy import Column, String, Integer
from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP, CHAR
from guardrails_api.models.base import db
from guardrails_api.clients.postgres_client import Base


class GuardItemAudit(db.Model):
class GuardItemAudit(Base):
__tablename__ = "guards_audit"
id = Column(String, primary_key=True)
name = Column(String, nullable=False, index=True)
Expand Down Expand Up @@ -35,29 +35,3 @@ def __init__(
self.replaced_on = replaced_on
self.operation = operation
# self.owner = owner


AUDIT_FUNCTION = """
CREATE OR REPLACE FUNCTION guard_audit_function() RETURNS TRIGGER AS $guard_audit$
BEGIN
IF (TG_OP = 'DELETE') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'D';
ELSIF (TG_OP = 'UPDATE') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), OLD.*, now(), 'U';
ELSIF (TG_OP = 'INSERT') THEN
INSERT INTO guards_audit SELECT uuid_generate_v4(), NEW.*, now(), 'I';
END IF;
RETURN null;
END;
$guard_audit$
LANGUAGE plpgsql;
"""

AUDIT_TRIGGER = """
DROP TRIGGER IF EXISTS guard_audit_trigger
ON guards;
CREATE TRIGGER guard_audit_trigger
AFTER INSERT OR UPDATE OR DELETE ON guards
FOR EACH ROW
EXECUTE PROCEDURE guard_audit_function();
"""
2 changes: 1 addition & 1 deletion guardrails_api/start-dev.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
gunicorn --bind 0.0.0.0:8000 \
--timeout 120 \
--workers 3 \
--workers 2 \
--threads 2 \
--worker-class=uvicorn.workers.UvicornWorker \
"guardrails_api.app:create_app()" \
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"]
requires-python = ">= 3.8.1"
dependencies = [
"guardrails-ai>=0.5.6",
"Flask-SQLAlchemy>=3.1.1,<4",
"Werkzeug>=3.0.3,<4",
"jsonschema>=4.22.0,<5",
"referencing>=0.35.1,<1",
Expand All @@ -27,6 +26,7 @@ dependencies = [
"requests>=2.32.3",
"aiocache>=0.11.1",
"fastapi>=0.114.1",
"SQLAlchemy>=2.0.34",
]

[tool.setuptools.dynamic]
Expand All @@ -42,6 +42,7 @@ dev = [
"coverage",
"pytest-mock",
"gunicorn>=22.0.0,<23",
"uvicorn",
]

[tool.pytest.ini_options]
Expand Down
8 changes: 1 addition & 7 deletions requirements-lock.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,10 @@ frozenlist==1.4.1
fsspec==2024.6.1
googleapis-common-protos==1.63.2
griffe==0.36.9
<<<<<<< Updated upstream
grpcio==1.65.1
guardrails-ai==0.5.7
guardrails-ai==0.5.9
guardrails-api-client==0.3.12
guardrails_hub_types==0.0.4
=======
grpcio==1.64.1
guardrails-ai==0.5.0a2
guardrails-api-client==0.3.8
>>>>>>> Stashed changes
gunicorn==22.0.0
h11==0.14.0
httpcore==1.0.5
Expand Down
4 changes: 2 additions & 2 deletions tests/cli/test_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
def test_start(mocker):
mocker.patch("guardrails_api.cli.start.cli")

mock_flask_app = MagicMock()
mock_app = MagicMock()
mock_create_app = mocker.patch(
"guardrails_api.cli.start.create_app", return_value=mock_flask_app
"guardrails_api.cli.start.create_app", return_value=mock_app
)

from guardrails_api.cli.start import start
Expand Down

0 comments on commit 1ab9e52

Please sign in to comment.