From 4da9e8d81d49558e7708354149558fc1195c7033 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:53:21 -0700 Subject: [PATCH 01/12] fix: added pg advisory lock due to worker init concurrency issues (when starting with more than one worker using postgres client) --- guardrails_api/clients/postgres_client.py | 45 +++++++++++++---------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index 56e1501..2c4e0c4 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -63,24 +63,21 @@ def get_db(self): else: yield None + + def generate_lock_id(self, name: str) -> int: + import hashlib + return int(hashlib.sha256(name.encode()).hexdigest(), 16) % (2**63) + def initialize(self, app: FastAPI): pg_user, pg_password = self.get_pg_creds() pg_host = os.environ.get("PGHOST", "localhost") pg_port = os.environ.get("PGPORT", "5432") pg_database = os.environ.get("PGDATABASE", "postgres") - pg_endpoint = ( - pg_host - if pg_host.endswith( - f":{pg_port}" - ) # FIXME: This is a cheap check; maybe use a regex instead? - else f"{pg_host}:{pg_port}" - ) - + pg_endpoint = f"{pg_host}:{pg_port}" conf = f"postgresql://{pg_user}:{pg_password}@{pg_endpoint}/{pg_database}" - if os.environ.get("NODE_ENV") == "production": - conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" + conf += "?sslmode=verify-ca&sslrootcert=global-bundle.pem" engine = create_engine(conf) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -88,17 +85,26 @@ def initialize(self, app: FastAPI): self.app = app self.engine = engine self.SessionLocal = SessionLocal - # Create tables - from guardrails_api.models import GuardItem, GuardItemAudit # noqa - Base.metadata.create_all(bind=engine) + lock_id = self.generate_lock_id("guardrails-api") + + # Use advisory lock to ensure only one worker runs initialization + with engine.begin() as connection: + lock_acquired = connection.execute(text(f"SELECT pg_try_advisory_lock({lock_id});")).scalar() + if lock_acquired: + self.run_initialization(connection) + # Release the lock after initialization is complete + connection.execute(text("SELECT pg_advisory_unlock(12345);")) - # Execute custom SQL - with engine.connect() as connection: - connection.execute(text(INIT_EXTENSIONS)) - connection.execute(text(AUDIT_FUNCTION)) - connection.execute(text(AUDIT_TRIGGER)) - connection.commit() + def run_initialization(self, connection): + # Perform the actual initialization tasks + from guardrails_api.models import GuardItem, GuardItemAudit # noqa + Base.metadata.create_all(bind=self.engine) + + # Execute custom SQL extensions and triggers + connection.execute(text(INIT_EXTENSIONS)) + connection.execute(text(AUDIT_FUNCTION)) + connection.execute(text(AUDIT_TRIGGER)) # Define INIT_EXTENSIONS, AUDIT_FUNCTION, and AUDIT_TRIGGER here as they were in your original code @@ -143,3 +149,4 @@ def initialize(self, app: FastAPI): FOR EACH ROW EXECUTE PROCEDURE guard_audit_function(); """ + From 6e2a84daf11cf0469fe6921361dea16495bb29c2 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:53:31 -0700 Subject: [PATCH 02/12] fix for http_error class interface --- guardrails_api/classes/http_error.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guardrails_api/classes/http_error.py b/guardrails_api/classes/http_error.py index fca445b..c5ff00f 100644 --- a/guardrails_api/classes/http_error.py +++ b/guardrails_api/classes/http_error.py @@ -8,8 +8,10 @@ def __init__( context: str = None, ): self.status = status + self.status_code = status self.message = message self.cause = cause + self.detail = f"{message} :: {cause}" if cause is not None else message self.fields = fields self.context = context From b31e84500e262c01582a4954a9d31937017281cd Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:53:57 -0700 Subject: [PATCH 03/12] support for uvicorn factory method creation --- guardrails_api/app.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 29e2d44..425238f 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -20,13 +20,16 @@ from starlette.middleware.base import BaseHTTPMiddleware +GR_ENV_FILE = os.environ.get("GR_ENV_FILE", None) +GR_CONFIG_FILE_PATH = os.environ.get("GR_CONFIG_FILE_PATH", None) +PORT = int(os.environ.get("PORT", 8000)) class RequestInfoMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): tracer = trace.get_tracer(__name__) # Get the current context and attach it to this task with tracer.start_as_current_span("request_info") as span: - client_ip = request.client.host + client_ip = request.client.host if request.client else None user_agent = request.headers.get("user-agent", "unknown") referrer = request.headers.get("referrer", "unknown") user_id = request.headers.get("x-user-id", "unknown") @@ -40,13 +43,15 @@ async def dispatch(self, request: Request, call_next): context.attach(baggage.set_baggage("organization", organization)) context.attach(baggage.set_baggage("app", app)) - span.set_attribute("client.ip", client_ip) span.set_attribute("http.user_agent", user_agent) span.set_attribute("http.referrer", referrer) span.set_attribute("user.id", user_id) span.set_attribute("organization", organization) span.set_attribute("app", app) + if client_ip: + span.set_attribute("client.ip", client_ip) + response = await call_next(request) return response @@ -70,9 +75,13 @@ def register_config(config: Optional[str] = None): config_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(config_module) + return config_file_path +# Support for providing env vars as uvicorn does not support supplying args to create_app +# - Usage: uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port $PORT --workers 2 --timeout-keep-alive 90 +# - Usage: gunicorn -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:$PORT --timeout=90 --workers=2 "guardrails_api.app:create_app(None, None, $PORT)" def create_app( - env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None + env: Optional[str] = GR_ENV_FILE, config: Optional[str] = GR_CONFIG_FILE_PATH, port: Optional[int] = PORT ): trace_server_start_if_enabled() # used to print user-facing messages during server startup @@ -89,12 +98,12 @@ def create_app( env_file_path = os.path.abspath(env) load_dotenv(env_file_path, override=True) - set_port = port or os.environ.get("PORT", 8000) + set_port = port or PORT host = os.environ.get("HOST", "http://localhost") self_endpoint = os.environ.get("SELF_ENDPOINT", f"{host}:{set_port}") os.environ["SELF_ENDPOINT"] = self_endpoint - register_config(config) + resolved_config_file_path = register_config(config) app = FastAPI(openapi_url="") @@ -159,6 +168,10 @@ async def value_error_handler(request: Request, exc: ValueError): ) console.print("") + console.print("Using the following configuration:") + console.print(f"- Guardrails Log Level: {guardrails_log_level}") + console.print(f"- Self Endpoint: {self_endpoint}") + console.print(f"- Config File Path: {resolved_config_file_path} [Provided: {config}]") console.print( Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white") ) @@ -170,4 +183,4 @@ async def value_error_handler(request: Request, exc: ValueError): import uvicorn app = create_app() - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", port=PORT) \ No newline at end of file From 3877ac7c4af14a5e3b5a05d44917d6aadb649a81 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:54:07 -0700 Subject: [PATCH 04/12] updated requirements lock --- requirements-lock.txt | 138 +++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/requirements-lock.txt b/requirements-lock.txt index 1462843..3192c15 100644 --- a/requirements-lock.txt +++ b/requirements-lock.txt @@ -1,114 +1,114 @@ -aiohttp==3.9.5 +aiocache==0.12.3 +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 aiosignal==1.3.1 annotated-types==0.7.0 -anyio==4.4.0 +anyio==4.6.2.post1 arrow==1.3.0 -attrs==23.2.0 -blinker==1.8.2 -boto3==1.34.149 -botocore==1.34.149 -cachelib==0.9.0 -certifi==2024.7.4 -charset-normalizer==3.3.2 +asgiref==3.8.1 +async-timeout==4.0.3 +attrs==24.2.0 +boto3==1.35.54 +botocore==1.35.54 +certifi==2024.8.30 +charset-normalizer==3.4.0 click==8.1.7 colorama==0.4.6 coloredlogs==15.0.1 -coverage==7.6.0 Deprecated==1.2.14 diff-match-patch==20230430 distro==1.9.0 +exceptiongroup==1.2.2 Faker==25.9.2 -filelock==3.15.4 +fastapi==0.115.4 +filelock==3.16.1 fqdn==1.5.1 -frozenlist==1.4.1 -fsspec==2024.6.1 -googleapis-common-protos==1.63.2 +frozenlist==1.5.0 +fsspec==2024.10.0 +googleapis-common-protos==1.65.0 griffe==0.36.9 -grpcio==1.65.1 -guardrails-ai==0.5.9 -guardrails-api-client==0.3.12 +grpcio==1.67.1 +guardrails-ai==0.5.15 +guardrails-api-client==0.3.13 guardrails_hub_types==0.0.4 -gunicorn==22.0.0 h11==0.14.0 -httpcore==1.0.5 -httpx==0.27.0 -huggingface-hub==0.24.2 +httpcore==1.0.6 +httpx==0.27.2 +huggingface-hub==0.26.2 humanfriendly==10.0 -idna==3.7 -importlib_metadata==8.0.0 -iniconfig==2.0.0 +idna==3.10 +importlib_metadata==8.4.0 isoduration==20.11.0 -itsdangerous==2.2.0 Jinja2==3.1.4 +jiter==0.7.0 jmespath==1.0.1 joblib==1.4.2 jsonpatch==1.33 jsonpointer==3.0.0 jsonref==1.1.0 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -langchain-core==0.2.23 -langsmith==0.1.93 -litellm==1.42.3 +jsonschema-specifications==2024.10.1 +langchain-core==0.3.15 +langsmith==0.1.139 +litellm==1.51.2 lxml==4.9.4 markdown-it-py==3.0.0 -MarkupSafe==2.1.5 +MarkupSafe==3.0.2 mdurl==0.1.2 -multidict==6.0.5 +multidict==6.1.0 nltk==3.8.1 -openai==1.37.1 -opentelemetry-api==1.26.0 -opentelemetry-exporter-otlp-proto-common==1.26.0 -opentelemetry-exporter-otlp-proto-grpc==1.26.0 -opentelemetry-exporter-otlp-proto-http==1.26.0 -opentelemetry-instrumentation==0.47b0 -opentelemetry-instrumentation-flask==0.47b0 -opentelemetry-instrumentation-wsgi==0.47b0 -opentelemetry-proto==1.26.0 -opentelemetry-sdk==1.26.0 -opentelemetry-semantic-conventions==0.47b0 -opentelemetry-util-http==0.47b0 -orjson==3.10.6 +openai==1.53.0 +opentelemetry-api==1.27.0 +opentelemetry-exporter-otlp-proto-common==1.27.0 +opentelemetry-exporter-otlp-proto-grpc==1.27.0 +opentelemetry-exporter-otlp-proto-http==1.27.0 +opentelemetry-instrumentation==0.48b0 +opentelemetry-instrumentation-asgi==0.48b0 +opentelemetry-instrumentation-fastapi==0.48b0 +opentelemetry-proto==1.27.0 +opentelemetry-sdk==1.27.0 +opentelemetry-semantic-conventions==0.48b0 +opentelemetry-util-http==0.48b0 +orjson==3.10.10 packaging==24.1 -pluggy==1.5.0 -protobuf==4.25.4 -psycopg2-binary==2.9.9 -pydantic==2.8.2 -pydantic_core==2.20.1 +propcache==0.2.0 +protobuf==4.25.5 +psycopg2-binary==2.9.10 +pydantic==2.9.2 +pydantic_core==2.23.4 pydash==7.0.7 Pygments==2.18.0 -PyJWT==2.8.0 -pytest==8.3.2 -pytest-mock==3.14.0 +PyJWT==2.9.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 -PyYAML==6.0.1 +PyYAML==6.0.2 referencing==0.35.1 regex==2023.12.25 requests==2.32.3 +requests-toolbelt==1.0.0 rfc3339-validator==0.1.4 rfc3987==1.3.8 -rich==13.7.1 -rpds-py==0.19.1 +rich==13.9.4 +rpds-py==0.20.1 rstr==3.2.2 -ruff==0.5.5 -s3transfer==0.10.2 -setuptools==71.1.0 +s3transfer==0.10.3 +semver==3.0.2 shellingham==1.5.4 six==1.16.0 sniffio==1.3.1 -SQLAlchemy==2.0.31 -tenacity==8.5.0 -tiktoken==0.7.0 -tokenizers==0.19.1 -tqdm==4.66.4 -typer==0.9.4 -types-python-dateutil==2.9.0.20240316 +SQLAlchemy==2.0.36 +starlette==0.41.2 +tenacity==9.0.0 +tiktoken==0.8.0 +tokenizers==0.20.1 +tqdm==4.66.6 +typer==0.12.5 +types-python-dateutil==2.9.0.20241003 typing_extensions==4.12.2 uri-template==1.3.0 urllib3==2.0.7 -webcolors==24.6.0 -Werkzeug==3.0.3 +uvicorn==0.32.0 +webcolors==24.8.0 wrapt==1.16.0 -yarl==1.9.4 -zipp==3.19.2 +yarl==1.17.1 +zipp==3.20.2 From b21d4b5681483830335c39b894ffa15c731e5694 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:54:18 -0700 Subject: [PATCH 05/12] fixes for dockerfile --- Dockerfile | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/Dockerfile b/Dockerfile index d01a5ee..57ede76 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 +1,32 @@ -FROM public.ecr.aws/docker/library/python:3.12-slim +FROM python:3.12-slim # Accept a build arg for the Guardrails token # We'll add this to the config using the configure command below -# ARG GUARDRAILS_TOKEN +ARG GUARDRAILS_TOKEN # Create app directory WORKDIR /app +# Enable venv +ENV PATH="/opt/venv/bin:$PATH" + +# Set the directory for nltk data +ENV NLTK_DATA=/opt/nltk_data + +# Set env vars for server +ENV GR_CONFIG_FILE_PATH="sample-config.py" +ENV GR_ENV_FILE=".env" +ENV PORT=8000 + # print the version just to verify RUN python3 --version # start the virtual environment RUN python3 -m venv /opt/venv -# Enable venv -ENV PATH="/opt/venv/bin:$PATH" - -# Install some utilities; you may not need all of these -RUN apt-get update -RUN apt-get install -y git +# Install some utilities +RUN apt-get update && \ + apt-get install -y git pkg-config curl gcc g++ && \ + rm -rf /var/lib/apt/lists/* # Copy the requirements file COPY requirements*.txt . @@ -26,26 +35,27 @@ COPY requirements*.txt . # If you use Poetry this step might be different RUN pip install -r requirements-lock.txt -# Set the directory for nltk data -ENV NLTK_DATA=/opt/nltk_data - # Download punkt data RUN python -m nltk.downloader -d /opt/nltk_data punkt # Run the Guardrails configure command to create a .guardrailsrc file -# RUN guardrails configure --enable-metrics --enable-remote-inferencing --token $GUARDRAILS_TOKEN +RUN guardrails configure --enable-metrics --enable-remote-inferencing --token $GUARDRAILS_TOKEN # Install any validators from the hub you want -RUN guardrails hub install hub://guardrails/valid_length +RUN guardrails hub install hub://guardrails/detect_pii --no-install-local-models && \ + guardrails hub install hub://guardrails/competitor_check --no-install-local-models + +# Fetch AWS RDS cert +RUN curl https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem -o ./global-bundle.pem # Copy the rest over # We use a .dockerignore to keep unwanted files exluded COPY . . -EXPOSE 8000 +EXPOSE ${PORT} # This is our start command; yours might be different. # The guardrails-api is a standard FastAPI application. # You can use whatever production server you want that support FastAPI. # Here we use gunicorn -CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")' \ No newline at end of file +CMD uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port ${PORT} --timeout-keep-alive=90 --workers=4 \ No newline at end of file From a328b7c3f1a28e7768b684edfb6fab84766456e4 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:54:45 -0700 Subject: [PATCH 06/12] updates to docker compose file for testing with pg locally --- compose.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/compose.yml b/compose.yml index b445876..78db072 100644 --- a/compose.yml +++ b/compose.yml @@ -9,7 +9,7 @@ services: volumes: - ./postgres:/data/postgres ports: - - "5432:5432" + - "5932:5432" restart: always pgadmin: profiles: ["all", "db", "infra"] @@ -30,23 +30,23 @@ services: - postgres guardrails-api: profiles: ["all", "api"] - image: guardrails-api:latest build: context: . dockerfile: Dockerfile args: PORT: "8000" + GUARDRAILS_TOKEN: ${GUARDRAILS_TOKEN:-changeme} ports: - "8000:8000" environment: # APP_ENVIRONMENT: local # AWS_PROFILE: dev # AWS_DEFAULT_REGION: us-east-1 - # PGPORT: 5432 - # PGDATABASE: postgres - # PGHOST: postgres - # PGUSER: ${PGUSER:-postgres} - # PGPASSWORD: ${PGPASSWORD:-changeme} + PGPORT: 5432 + PGDATABASE: postgres + PGHOST: postgres + PGUSER: ${PGUSER:-postgres} + PGPASSWORD: ${PGPASSWORD:-changeme} NLTK_DATA: /opt/nltk_data # OTEL_PYTHON_TRACER_PROVIDER: sdk_tracer_provider # OTEL_SERVICE_NAME: guardrails-api @@ -68,8 +68,8 @@ services: # OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: http://otel-collector:4317 # OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: http://otel-collector:4317 # OTEL_PYTHON_LOG_FORMAT: "%(msg)s [span_id=%(span_id)s]" - # depends_on: - # - postgres + depends_on: + - postgres # - otel-collector opensearch-node1: profiles: ["all", "otel", "infra"] From a7aed377f8ba9e87fa1a36f45081ebcfd5b89bfa Mon Sep 17 00:00:00 2001 From: Alejandro Date: Fri, 1 Nov 2024 15:56:49 -0700 Subject: [PATCH 07/12] fix lock reference --- guardrails_api/clients/postgres_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index 2c4e0c4..dfae03f 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -94,7 +94,7 @@ def initialize(self, app: FastAPI): if lock_acquired: self.run_initialization(connection) # Release the lock after initialization is complete - connection.execute(text("SELECT pg_advisory_unlock(12345);")) + connection.execute(text(f"SELECT pg_advisory_unlock({lock_id});")) def run_initialization(self, connection): # Perform the actual initialization tasks From 01e8e0dfb9f684c8003f1a0d55648fe4fc829005 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Mon, 4 Nov 2024 12:23:10 -0800 Subject: [PATCH 08/12] ported over pg guard client --- guardrails_api/clients/pg_guard_client.py | 151 +++++++++++----------- 1 file changed, 79 insertions(+), 72 deletions(-) diff --git a/guardrails_api/clients/pg_guard_client.py b/guardrails_api/clients/pg_guard_client.py index 226232a..ab821e1 100644 --- a/guardrails_api/clients/pg_guard_client.py +++ b/guardrails_api/clients/pg_guard_client.py @@ -1,4 +1,5 @@ -from typing import List +from contextlib import contextmanager +from typing import List, Optional from guardrails_api_client import Guard as GuardStruct from guardrails_api.classes.http_error import HttpError from guardrails_api.clients.guard_client import GuardClient @@ -18,48 +19,20 @@ def __init__(self): self.initialized = True self.pgClient = PostgresClient() - def get_db(self): # generator for local sessions + @contextmanager + def get_db_context(self): db = self.pgClient.SessionLocal() try: yield db finally: db.close() - def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct: - db = next(self.get_db()) - latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first() - audit_item = None - if as_of_date is not None: - audit_item = ( - db.query(GuardItemAudit) - .filter_by(name=guard_name) - .filter(GuardItemAudit.replaced_on > as_of_date) - .order_by(GuardItemAudit.replaced_on.asc()) - .first() - ) - guard_item = audit_item if audit_item is not None else latest_guard_item - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause="A Guard with the name {guard_name} does not exist!".format( - guard_name=guard_name - ), - ) - return from_guard_item(guard_item) - - def get_guard_item(self, guard_name: str) -> GuardItem: - db = next(self.get_db()) - return db.query(GuardItem).filter_by(name=guard_name).first() + def util_get_guard_item(self, guard_name: str, db) -> GuardItem: + item = db.query(GuardItem).filter_by(name=guard_name).first() + return item - def get_guards(self) -> List[GuardStruct]: - db = next(self.get_db()) - guard_items = db.query(GuardItem).all() - return [from_guard_item(gi) for gi in guard_items] - - def create_guard(self, guard: GuardStruct) -> GuardStruct: - db = next(self.get_db()) + def util_create_guard(self, guard: GuardStruct, db) -> GuardStruct: guard_item = GuardItem( name=guard.name, railspec=guard.to_dict(), @@ -69,48 +42,82 @@ def create_guard(self, guard: GuardStruct) -> GuardStruct: db.add(guard_item) db.commit() return from_guard_item(guard_item) + + # Below are used directly by Controllers and start db sessions - def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: - db = next(self.get_db()) - guard_item = self.get_guard_item(guard_name) - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause="A Guard with the name {guard_name} does not exist!".format( - guard_name=guard_name - ), - ) - # guard_item.num_reasks = guard.num_reasks - guard_item.railspec = guard.to_dict() - guard_item.description = guard.description - db.commit() - return from_guard_item(guard_item) + def get_guard(self, guard_name: str, as_of_date: Optional[str] = None) -> GuardStruct: + with self.get_db_context() as db: + latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first() + audit_item = None + if as_of_date is not None: + audit_item = ( + db.query(GuardItemAudit) + .filter_by(name=guard_name) + .filter(GuardItemAudit.replaced_on > as_of_date) + .order_by(GuardItemAudit.replaced_on.asc()) + .first() + ) + guard_item = audit_item if audit_item is not None else latest_guard_item + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + return from_guard_item(guard_item) - def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: - db = next(self.get_db()) - guard_item = self.get_guard_item(guard_name) - if guard_item is not None: + def get_guards(self) -> List[GuardStruct]: + with self.get_db_context() as db: + guard_items = db.query(GuardItem).all() + return [from_guard_item(gi) for gi in guard_items] + + def create_guard(self, guard: GuardStruct) -> GuardStruct: + with self.get_db_context() as db: + return self.util_create_guard(guard, db) + + def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + with self.get_db_context() as db: + guard_item = self.util_get_guard_item(guard_name, db) + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + # guard_item.num_reasks = guard.num_reasks guard_item.railspec = guard.to_dict() guard_item.description = guard.description - # guard_item.num_reasks = guard.num_reasks db.commit() return from_guard_item(guard_item) - else: - return self.create_guard(guard) + + def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct: + with self.get_db_context() as db: + guard_item = self.util_get_guard_item(guard_name, db) + if guard_item is not None: + guard_item.railspec = guard.to_dict() + guard_item.description = guard.description + # guard_item.num_reasks = guard.num_reasks + db.commit() + return from_guard_item(guard_item) + else: + return self.util_create_guard(guard, db) def delete_guard(self, guard_name: str) -> GuardStruct: - db = next(self.get_db()) - guard_item = self.get_guard_item(guard_name) - if guard_item is None: - raise HttpError( - status=404, - message="NotFound", - cause="A Guard with the name {guard_name} does not exist!".format( - guard_name=guard_name - ), - ) - db.delete(guard_item) - db.commit() - guard = from_guard_item(guard_item) - return guard + with self.get_db_context() as db: + guard_item = self.util_get_guard_item(guard_name, db) + if guard_item is None: + raise HttpError( + status=404, + message="NotFound", + cause="A Guard with the name {guard_name} does not exist!".format( + guard_name=guard_name + ), + ) + db.delete(guard_item) + db.commit() + guard = from_guard_item(guard_item) + return guard From 7e03f543646f2d2311b893610baf71787f261f41 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Tue, 12 Nov 2024 23:36:25 -0800 Subject: [PATCH 09/12] Fix tests due to update pgguard client interface --- tests/clients/test_pg_guard_client.py | 44 ++++++++++++++------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/tests/clients/test_pg_guard_client.py b/tests/clients/test_pg_guard_client.py index add0048..a40a328 100644 --- a/tests/clients/test_pg_guard_client.py +++ b/tests/clients/test_pg_guard_client.py @@ -178,7 +178,7 @@ def test_get_guard_item(mocker): guard_client = PGGuardClient() - result = guard_client.get_guard_item("guard") + result = guard_client.util_get_guard_item("guard", mock_session) query_spy.assert_called_once_with(GuardItem) filter_by_spy.assert_called_once_with(name="guard") @@ -286,7 +286,7 @@ def test_raises_not_found(self, mocker): return_value=mock_pg_client, ) mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) mock_get_guard_item.return_value = None @@ -330,10 +330,10 @@ def test_updates_guard_item(self, mocker): "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + mock_util_get_guard_item = mocker.patch( + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) - mock_get_guard_item.return_value = old_guard_item + mock_util_get_guard_item.return_value = old_guard_item commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( @@ -347,7 +347,7 @@ def test_updates_guard_item(self, mocker): result = guard_client.update_guard("mock-guard", updated_guard) - mock_get_guard_item.assert_called_once_with("mock-guard") + mock_util_get_guard_item.assert_called_once_with("mock-guard", mock_session) assert commit_spy.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard_item) @@ -364,23 +364,24 @@ def test_guard_doesnt_exist_yet(self, mocker): new_guard = MockGuardStruct() mock_pg_client = MockPostgresClient() mock_pg_client.SessionLocal = MagicMock(return_value=MagicMock()) + mock_session = mock_pg_client.SessionLocal() mocker.patch( "guardrails_api.clients.pg_guard_client.PostgresClient", return_value=mock_pg_client, ) - mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + mock_util_get_guard_item = mocker.patch( + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) - mock_get_guard_item.return_value = None + mock_util_get_guard_item.return_value = None commit_spy = mocker.spy(mock_pg_client.db.session, "commit") mock_from_guard_item = mocker.patch( "guardrails_api.clients.pg_guard_client.from_guard_item" ) - mock_create_guard = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.create_guard" + mock_util_create_guard = mocker.patch( + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_create_guard" ) - mock_create_guard.return_value = new_guard + mock_util_create_guard.return_value = new_guard from guardrails_api.clients.pg_guard_client import PGGuardClient @@ -388,10 +389,11 @@ def test_guard_doesnt_exist_yet(self, mocker): result = guard_client.upsert_guard("mock-guard", input_guard) - mock_get_guard_item.assert_called_once_with("mock-guard") + mock_util_get_guard_item.assert_called_once_with("mock-guard", mock_session) assert commit_spy.call_count == 0 assert mock_from_guard_item.call_count == 0 - mock_create_guard.assert_called_once_with(input_guard) + mock_util_create_guard.assert_called_once_with(input_guard, mock_session) + assert result == new_guard @@ -415,10 +417,10 @@ def test_guard_already_exists(self, mocker): mock_session = mock_pg_client.SessionLocal() - mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + mock_util_get_guard_item = mocker.patch( + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) - mock_get_guard_item.return_value = old_guard_item + mock_util_get_guard_item.return_value = old_guard_item commit_spy = mocker.spy(mock_session, "commit") mock_from_guard_item = mocker.patch( @@ -432,7 +434,7 @@ def test_guard_already_exists(self, mocker): result = guard_client.upsert_guard("mock-guard", updated_guard) - mock_get_guard_item.assert_called_once_with("mock-guard") + mock_util_get_guard_item.assert_called_once_with("mock-guard", mock_session) assert commit_spy.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard_item) @@ -457,7 +459,7 @@ def test_raises_not_found(self, mocker): mock_session = mock_pg_client.SessionLocal() mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) mock_get_guard_item.return_value = None @@ -496,7 +498,7 @@ def test_deletes_guard_item(self, mocker): mock_session = mock_pg_client.SessionLocal() mock_get_guard_item = mocker.patch( - "guardrails_api.clients.pg_guard_client.PGGuardClient.get_guard_item" + "guardrails_api.clients.pg_guard_client.PGGuardClient.util_get_guard_item" ) mock_get_guard_item.return_value = old_guard @@ -515,7 +517,7 @@ def test_deletes_guard_item(self, mocker): result = guard_client.delete_guard("mock-guard") - mock_get_guard_item.assert_called_once_with("mock-guard") + mock_get_guard_item.assert_called_once_with("mock-guard", mock_session) assert mock_session.delete.call_count == 1 assert mock_session.commit.call_count == 1 mock_from_guard_item.assert_called_once_with(old_guard) From d901ba8a50bbf554a1627724b0e93b6105fac29d Mon Sep 17 00:00:00 2001 From: Alejandro Date: Tue, 12 Nov 2024 23:37:50 -0800 Subject: [PATCH 10/12] revert changes to postgres_client init --- guardrails_api/clients/postgres_client.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/guardrails_api/clients/postgres_client.py b/guardrails_api/clients/postgres_client.py index dfae03f..5a95506 100644 --- a/guardrails_api/clients/postgres_client.py +++ b/guardrails_api/clients/postgres_client.py @@ -74,10 +74,18 @@ def initialize(self, app: FastAPI): pg_port = os.environ.get("PGPORT", "5432") pg_database = os.environ.get("PGDATABASE", "postgres") - pg_endpoint = f"{pg_host}:{pg_port}" + pg_endpoint = ( + pg_host + if pg_host.endswith( + f":{pg_port}" + ) # FIXME: This is a cheap check; maybe use a regex instead? + else f"{pg_host}:{pg_port}" + ) + conf = f"postgresql://{pg_user}:{pg_password}@{pg_endpoint}/{pg_database}" + if os.environ.get("NODE_ENV") == "production": - conf += "?sslmode=verify-ca&sslrootcert=global-bundle.pem" + conf = f"{conf}?sslmode=verify-ca&sslrootcert=global-bundle.pem" engine = create_engine(conf) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) From 19f6813e982d2cdb5b420a3fcebbaae3cee704a3 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 13 Nov 2024 12:16:45 -0800 Subject: [PATCH 11/12] Add comment --- guardrails_api/clients/pg_guard_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/guardrails_api/clients/pg_guard_client.py b/guardrails_api/clients/pg_guard_client.py index ab821e1..68fdec9 100644 --- a/guardrails_api/clients/pg_guard_client.py +++ b/guardrails_api/clients/pg_guard_client.py @@ -27,6 +27,9 @@ def get_db_context(self): finally: db.close() + + # These are only internal utilities and do not start db sessions + def util_get_guard_item(self, guard_name: str, db) -> GuardItem: item = db.query(GuardItem).filter_by(name=guard_name).first() return item From 20c664a8aab2d4940d373169a23a7eb773e11118 Mon Sep 17 00:00:00 2001 From: Alejandro Esquivel Date: Wed, 13 Nov 2024 12:34:54 -0800 Subject: [PATCH 12/12] Update guardrails_api/app.py --- guardrails_api/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guardrails_api/app.py b/guardrails_api/app.py index 425238f..f2ffbad 100644 --- a/guardrails_api/app.py +++ b/guardrails_api/app.py @@ -78,7 +78,7 @@ def register_config(config: Optional[str] = None): return config_file_path # Support for providing env vars as uvicorn does not support supplying args to create_app -# - Usage: uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port $PORT --workers 2 --timeout-keep-alive 90 +# - Usage: GR_CONFIG_FILE_PATH=config.py GR_ENV_FILE=.env PORT=8080 uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port $PORT --workers 2 --timeout-keep-alive 90 # - Usage: gunicorn -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:$PORT --timeout=90 --workers=2 "guardrails_api.app:create_app(None, None, $PORT)" def create_app( env: Optional[str] = GR_ENV_FILE, config: Optional[str] = GR_CONFIG_FILE_PATH, port: Optional[int] = PORT