Merge pull request #84 from guardrails-ai/forward-compatibility
Forward compatibility
CalebCourier authored Nov 12, 2024
2 parents 420fa37 + 4e36257 commit 8498640
Showing 13 changed files with 98 additions and 42 deletions.
2 changes: 1 addition & 1 deletion guardrails_api/__init__.py
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.5"
21 changes: 14 additions & 7 deletions guardrails_api/app.py
@@ -9,7 +9,9 @@
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.otel import otel_is_disabled, initialize
-from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled
+from guardrails_api.utils.trace_server_start_if_enabled import (
+    trace_server_start_if_enabled,
+)
from guardrails_api.clients.cache_client import CacheClient
from rich.console import Console
from rich.rule import Rule
@@ -84,7 +86,7 @@ def create_app(

@app.before_request
def basic_cors():
-        if request.method.lower() == 'options':
+        if request.method.lower() == "options":
return Response()

app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
@@ -112,20 +114,25 @@ def basic_cors():
app.register_blueprint(root_bp)
app.register_blueprint(guards_bp)

console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}")
console.print(
f"\n:rocket: Guardrails API is available at {self_endpoint}"
f":book: Visit {self_endpoint}/docs to see available API endpoints.\n"
)
console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n")

console.print(":green_circle: Active guards and OpenAI compatible endpoints:")

with app.app_context():
from guardrails_api.blueprints.guards import guard_client

for g in guard_client.get_guards():
g = g.to_dict()
console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1")
console.print(
f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1"
)

console.print("")
-    console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white"))
+    console.print(
+        Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")
+    )

return app
return app
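The startup banner above advertises one OpenAI-compatible endpoint per configured guard at `{self_endpoint}/guards/{name}/openai/v1`. As a rough illustration of how a client might target such an endpoint (the host, guard name, model, and API key below are placeholders, not part of this commit):

```python
from openai import OpenAI

# Point the OpenAI SDK at a single guard's OpenAI-compatible route.
client = OpenAI(
    base_url="http://localhost:8000/guards/my-guard/openai/v1",  # illustrative host and guard name
    api_key="sk-placeholder",
)

response = client.chat.completions.create(
    model="gpt-4o-mini",  # whatever model the guard is configured to call
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(response.choices[0].message.content)
```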
23 changes: 16 additions & 7 deletions guardrails_api/blueprints/guards.py
@@ -18,7 +18,10 @@
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.utils.handle_error import handle_error
from guardrails_api.utils.get_llm_callable import get_llm_callable
-from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response
+from guardrails_api.utils.openai import (
+    outcome_to_chat_completion,
+    outcome_to_stream_response,
+)

guards_bp = Blueprint("guards", __name__, url_prefix="/guards")

@@ -272,7 +275,6 @@ def validate(guard_name: str):
# ) as validate_span:
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)


# validate_span.set_attribute("guardName", decoded_guard_name)
if llm_api is not None:
llm_api = get_llm_callable(llm_api)
@@ -295,7 +297,7 @@ def validate(guard_name: str):
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
-        guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict())
+        guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

if llm_api is None and num_reasks and num_reasks > 1:
raise HttpError(
@@ -322,6 +324,7 @@ def validate(guard_name: str):
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
@@ -452,24 +455,30 @@ async def async_validate_streamer(guard_iter):
cache_key = f"{guard.name}-{final_validation_output.call_id}"
cache_client.set(cache_key, serialized_history, 300)
yield f"{final_output_json}\n"

# apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async
def iter_over_async(ait, loop):
ait = ait.__aiter__()

async def get_next():
try:
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
except StopAsyncIteration:
return True, None

while True:
done, obj = loop.run_until_complete(get_next())
if done:
if done:
break
yield obj

if is_async:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
-            iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
+            iter = iter_over_async(
+                async_validate_streamer(async_guard_streamer()), loop
+            )
else:
iter = validate_streamer(guard_streamer())
return Response(
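The `iter_over_async` helper retained above (via the linked Stack Overflow answer) is what lets the synchronous Flask response consume an async validation stream: it pulls one item at a time by running the event loop until `StopAsyncIteration`. A minimal, self-contained sketch of the same pattern, with a stand-in `numbers` generator in place of `async_validate_streamer`:

```python
import asyncio


async def numbers():
    # Stand-in for an async generator such as async_validate_streamer.
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def iter_over_async(ait, loop):
    # Bridge an async iterator into a plain sync generator.
    ait = ait.__aiter__()

    async def get_next():
        try:
            return False, await ait.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
print(list(iter_over_async(numbers(), loop)))  # -> [0, 1, 2]
```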
1 change: 1 addition & 0 deletions guardrails_api/cli/start.py
@@ -4,6 +4,7 @@
from guardrails_api.app import create_app
from guardrails_api.utils.configuration import valid_configuration


@cli.command("start")
def start(
env: Optional[str] = typer.Option(
27 changes: 19 additions & 8 deletions guardrails_api/utils/configuration.py
@@ -2,21 +2,32 @@
from typing import Optional
import os

-def valid_configuration(config: Optional[str]=""):
+
+def valid_configuration(config: Optional[str] = ""):
default_config_file = os.path.join(os.getcwd(), "./config.py")

default_config_file_path = os.path.abspath(default_config_file)
# If config.py is not present and
# If config.py is not present and
# if a config filepath is not passed and
# if postgres is not there (i.e. we’re using in-mem db)
# if postgres is not there (i.e. we’re using in-mem db)
# then raise ConfigurationError
has_default_config_file = os.path.isfile(default_config_file_path)

-    has_config_file = (config != "" and config is not None) and os.path.isfile(os.path.abspath(config))
-    if not has_default_config_file and not has_config_file and not postgres_is_enabled():
-        raise ConfigurationError("Can not start. Configuration not provided and default"
-            " configuration not found and postgres is not enabled.")
+    has_config_file = (config != "" and config is not None) and os.path.isfile(
+        os.path.abspath(config)
+    )
+
+    if (
+        not has_default_config_file
+        and not has_config_file
+        and not postgres_is_enabled()
+    ):
+        raise ConfigurationError(
+            "Can not start. Configuration not provided and default"
+            " configuration not found and postgres is not enabled."
+        )
return True


class ConfigurationError(Exception):
pass
pass
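The reformatted check keeps the rule spelled out in the comments: startup is allowed if a default `./config.py` exists, if a passed config path points to a real file, or if Postgres is enabled; otherwise `ConfigurationError` is raised. A rough usage sketch (the config file name is illustrative):

```python
from guardrails_api.utils.configuration import ConfigurationError, valid_configuration

try:
    # True when ./config.py exists, "custom_config.py" exists, or Postgres is enabled.
    valid_configuration("custom_config.py")
except ConfigurationError:
    # No config file found and no Postgres: the server refuses to start.
    raise SystemExit(1)
```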
27 changes: 21 additions & 6 deletions guardrails_api/utils/handle_error.py
@@ -13,20 +13,35 @@ def decorator(*args, **kwargs):
return fn(*args, **kwargs)
except ValidationError as validation_error:
logger.error(validation_error)
-            traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__)
-            return str(validation_error), 400
+            traceback.print_exception(
+                type(validation_error), validation_error, validation_error.__traceback__
+            )
+            resp_body = {"status_code": 400, "detail": str(validation_error)}
+            return resp_body, 400
except HttpError as http_error:
logger.error(http_error)
-            traceback.print_exception(type(http_error), http_error, http_error.__traceback__)
-            return http_error.to_dict(), http_error.status
+            traceback.print_exception(
+                type(http_error), http_error, http_error.__traceback__
+            )
+            resp_body = http_error.to_dict()
+            resp_body["status_code"] = http_error.status
+            resp_body["detail"] = http_error.message
+            return resp_body, http_error.status
except HTTPException as http_exception:
logger.error(http_exception)
traceback.print_exception(http_exception)
http_error = HttpError(http_exception.code, http_exception.description)
-            return http_error.to_dict(), http_error.status
+            resp_body = http_error.to_dict()
+            resp_body["status_code"] = http_error.status
+            resp_body["detail"] = http_error.message
+
+            return resp_body, http_error.status
except Exception as e:
logger.error(e)
traceback.print_exception(e)
-            return HttpError(500, "Internal Server Error").to_dict(), 500
+            resp_body = HttpError(500, "Internal Server Error").to_dict()
+            resp_body["status_code"] = 500
+            resp_body["detail"] = "Internal Server Error"
+            return resp_body, 500

return decorator
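With these changes every error path returns a JSON-serializable body carrying `status_code` and `detail` (plus the fields from `HttpError.to_dict()` where applicable) instead of a bare string. A hedged sketch of the ValidationError branch, assuming the `ValidationError` caught here is the one exposed as `guardrails.errors.ValidationError`:

```python
from guardrails.errors import ValidationError

from guardrails_api.utils.handle_error import handle_error


@handle_error
def failing_route():
    # Stand-in for a route body whose guard validation fails.
    raise ValidationError("rail failed")


resp_body, status = failing_route()
assert status == 400
assert resp_body == {"status_code": 400, "detail": "rail failed"}
```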
2 changes: 1 addition & 1 deletion guardrails_api/utils/has_internet_connection.py
@@ -7,4 +7,4 @@ def has_internet_connection() -> bool:
res.raise_for_status()
return True
except requests.ConnectionError:
return False
return False
1 change: 1 addition & 0 deletions guardrails_api/utils/openai.py
@@ -1,5 +1,6 @@
from guardrails.classes import ValidationOutcome


def outcome_to_stream_response(validation_outcome: ValidationOutcome):
stream_chunk_template = {
"choices": [
3 changes: 2 additions & 1 deletion guardrails_api/utils/trace_server_start_if_enabled.py
@@ -8,6 +8,7 @@ def trace_server_start_if_enabled():
config = Credentials.from_rc_file()
if config.enable_metrics is True and has_internet_connection():
from guardrails.utils.hub_telemetry_utils import HubTelemetry

HubTelemetry().create_new_span(
"guardrails-api/start",
[
@@ -21,4 +22,4 @@ def trace_server_start_if_enabled():
],
True,
False,
)
)
23 changes: 15 additions & 8 deletions tests/blueprints/test_guards.py
@@ -549,6 +549,7 @@ def test_validate__call(mocker):

del os.environ["PGHOST"]


def test_validate__call_throws_validation_error(mocker):
os.environ["PGHOST"] = "localhost"

@@ -610,19 +611,24 @@ def test_validate__call_throws_validation_error(mocker):
prompt="Hello world!",
)

-    assert response == ('Test guard validation error', 400)
+    assert response == (
+        {"status_code": 400, "detail": "Test guard validation error"},
+        400,
+    )

del os.environ["PGHOST"]


def test_openai_v1_chat_completions__raises_404(mocker):
from guardrails_api.blueprints.guards import openai_v1_chat_completions

os.environ["PGHOST"] = "localhost"
mock_guard = None

mock_request = MockRequest(
"POST",
json={
"messages": [{"role":"user", "content":"Hello world!"}],
"messages": [{"role": "user", "content": "Hello world!"}],
},
headers={"x-openai-api-key": "mock-key"},
)
@@ -637,15 +643,16 @@ def test_openai_v1_chat_completions__raises_404(mocker):

response = openai_v1_chat_completions("My%20Guard's%20Name")
assert response[1] == 404
-    assert response[0]["message"] == 'NotFound'

+    assert response[0]["message"] == "NotFound"

mock_get_guard.assert_called_once_with("My Guard's Name")

del os.environ["PGHOST"]


def test_openai_v1_chat_completions__call(mocker):
from guardrails_api.blueprints.guards import openai_v1_chat_completions

os.environ["PGHOST"] = "localhost"
mock_guard = MockGuardStruct()
mock_outcome = ValidationOutcome(
@@ -664,7 +671,7 @@ def test_openai_v1_chat_completions__call(mocker):
mock_request = MockRequest(
"POST",
json={
"messages": [{"role":"user", "content":"Hello world!"}],
"messages": [{"role": "user", "content": "Hello world!"}],
},
headers={"x-openai-api-key": "mock-key"},
)
@@ -687,7 +694,7 @@ def test_openai_v1_chat_completions__call(mocker):
)
mock_status.return_value = "fail"
mock_call = Call()
-    mock_call.iterations= Stack(Iteration('some-id', 1))
+    mock_call.iterations = Stack(Iteration("some-id", 1))
mock_guard.history = Stack(mock_call)

response = openai_v1_chat_completions("My%20Guard's%20Name")
@@ -698,7 +705,7 @@

mock___call__.assert_called_once_with(
num_reasks=0,
-        messages=[{"role":"user", "content":"Hello world!"}],
+        messages=[{"role": "user", "content": "Hello world!"}],
)

assert response == {
@@ -716,4 +723,4 @@
},
}

del os.environ["PGHOST"]
del os.environ["PGHOST"]
2 changes: 2 additions & 0 deletions tests/cli/test_start.py
@@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import os


def test_start(mocker):
mocker.patch("guardrails_api.cli.start.cli")

@@ -10,6 +11,7 @@ def test_start(mocker):
)

from guardrails_api.cli.start import start

# pg enabled
os.environ["PGHOST"] = "localhost"
start("env", "config", 8000)
1 change: 1 addition & 0 deletions tests/mocks/mock_guard_client.py
@@ -3,6 +3,7 @@
from pydantic import ConfigDict
from guardrails.classes.generic import Stack


class MockGuardStruct(GuardStruct):
# Pydantic Config
model_config = ConfigDict(arbitrary_types_allowed=True)
7 changes: 4 additions & 3 deletions tests/utils/test_configuration.py
@@ -2,15 +2,16 @@
import pytest
from guardrails_api.utils.configuration import valid_configuration, ConfigurationError


def test_valid_configuration(mocker):
with pytest.raises(ConfigurationError):
valid_configuration()

# pg enabled
os.environ["PGHOST"] = "localhost"
valid_configuration("config.py")
os.environ.pop("PGHOST")

# custom config
mock_isfile = mocker.patch("os.path.isfile")
mock_isfile.side_effect = [False, True]
@@ -20,7 +21,7 @@ def test_valid_configuration(mocker):
mock_isfile.side_effect = [False, False]
with pytest.raises(ConfigurationError):
valid_configuration("")

# default config
mock_isfile = mocker.patch("os.path.isfile")
mock_isfile.side_effect = [True, False]
