diff --git a/guardrails/guard.py b/guardrails/guard.py index 9b45c4ae5..658758ce0 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -238,9 +238,10 @@ def __call__( msg_history: Optional[List[Dict]] = None, metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, + stream: Optional[bool] = False, *args, **kwargs, - ) -> ValidationOutcome[OT]: + ) -> Union[ValidationOutcome[OT], Iterable[str]]: ... @overload @@ -272,7 +273,7 @@ def __call__( *args, **kwargs, ) -> Union[ - Union[ValidationOutcome[OT], Iterable], Awaitable[ValidationOutcome[OT]] + Union[ValidationOutcome[OT], Iterable[str]], Awaitable[ValidationOutcome[OT]] ]: """Call the LLM and validate the output. Pass an async LLM API to return a coroutine.