diff --git a/guardrails_api/api/guards.py b/guardrails_api/api/guards.py
index afdd206..a9f70d6 100644
--- a/guardrails_api/api/guards.py
+++ b/guardrails_api/api/guards.py
@@ -114,7 +114,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     )
 
     guard = (
-        Guard.from_dict(guard_struct.to_dict())
+        AsyncGuard.from_dict(guard_struct.to_dict())
         if not isinstance(guard_struct, Guard)
         else guard_struct
     )
@@ -125,7 +125,7 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     )
 
     if not stream:
-        validation_outcome: ValidationOutcome = guard(num_reasks=0, **payload)
+        validation_outcome: ValidationOutcome = await guard(num_reasks=0, **payload)
         llm_response = guard.history.last.iterations.last.outputs.llm_response_info
         result = outcome_to_chat_completion(
             validation_outcome=validation_outcome,
@@ -136,8 +136,8 @@ async def openai_v1_chat_completions(guard_name: str, request: Request):
     else:
 
         async def openai_streamer():
-            guard_stream = guard(num_reasks=0, **payload)
-            for result in guard_stream:
+            guard_stream = await guard(num_reasks=0, **payload)
+            async for result in guard_stream:
                 chunk = json.dumps(
                     outcome_to_stream_response(validation_outcome=result)
                 )
@@ -253,22 +253,18 @@ async def validate_streamer(guard_iter):
             validate_streamer(guard_streamer()), media_type="application/json"
         )
     else:
-        if inspect.iscoroutinefunction(guard):
-            result: ValidationOutcome = await guard(
-                llm_api=llm_api,
-                prompt_params=prompt_params,
-                num_reasks=num_reasks,
-                *args,
-                **payload,
-            )
+        execution = guard(
+            llm_api=llm_api,
+            prompt_params=prompt_params,
+            num_reasks=num_reasks,
+            *args,
+            **payload,
+        )
+
+        if inspect.iscoroutine(execution):
+            result: ValidationOutcome = await execution
         else:
-            result: ValidationOutcome = guard(
-                llm_api=llm_api,
-                prompt_params=prompt_params,
-                num_reasks=num_reasks,
-                *args,
-                **payload,
-            )
+            result: ValidationOutcome = execution
 
         serialized_history = [call.to_dict() for call in guard.history]
         cache_key = f"{guard.name}-{result.call_id}"
diff --git a/pyproject.toml b/pyproject.toml
index 2520656..a46ef14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ readme = "README.md"
 keywords = ["Guardrails", "Guardrails AI", "Guardrails API", "Guardrails API"]
 requires-python = ">= 3.8.1"
 dependencies = [
-    "guardrails-ai>=0.5.10",
+    "guardrails-ai>=0.5.12",
     "jsonschema>=4.22.0,<5",
     "referencing>=0.35.1,<1",
     "boto3>=1.34.115,<2",
diff --git a/tests/api/test_guards.py b/tests/api/test_guards.py
index 453a976..416baf0 100644
--- a/tests/api/test_guards.py
+++ b/tests/api/test_guards.py
@@ -14,6 +14,9 @@ from tests.mocks.mock_guard_client import MockGuardStruct
 
 from guardrails_api.api.guards import router as guards_router
 
+
+import asyncio
+
 # TODO: Should we mock this somehow?
 # Right now it's just empty, but it technically does a file read
 register_config()
@@ -344,7 +347,9 @@ def test_openai_v1_chat_completions__call(mocker):
     )
 
     mock___call__ = mocker.patch.object(MockGuardStruct, "__call__")
-    mock___call__.return_value = mock_outcome
+    future = asyncio.Future()
+    future.set_result(mock_outcome)
+    mock___call__.return_value = future
 
     mock_from_dict = mocker.patch("guardrails_api.api.guards.Guard.from_dict")
     mock_from_dict.return_value = mock_guard
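
For context on the `validate` hunk above, here is a minimal, self-contained sketch (not the actual guardrails_api code) of the call-then-inspect dispatch it switches to: call the guard first, then `await` only if the call returned a coroutine. `SyncGuard`, `AsyncGuardLike`, and `run_guard` are hypothetical stand-ins for `Guard`, `AsyncGuard`, and the route handler.

```python
import asyncio
import inspect


class SyncGuard:
    # Stand-in for a synchronous Guard: __call__ returns the outcome directly.
    def __call__(self, **payload):
        return f"validated {payload}"


class AsyncGuardLike:
    # Stand-in for an AsyncGuard: calling it returns a coroutine object.
    async def __call__(self, **payload):
        return f"validated {payload}"


async def run_guard(guard, **payload):
    # Call first; await only if the call produced a coroutine.
    execution = guard(**payload)
    if inspect.iscoroutine(execution):
        return await execution
    return execution


print(asyncio.run(run_guard(SyncGuard(), num_reasks=0)))
print(asyncio.run(run_guard(AsyncGuardLike(), num_reasks=0)))
```

The same idea underlies the test change: patching `__call__` to return an already-resolved `asyncio.Future` lets the handler's `await` complete immediately with the mocked outcome.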