
Input validation #483

Closed
wants to merge 14 commits
add async tests
irgolic committed Dec 4, 2023
commit a0a693e63b3726aa47bc193776679dc64bd1cd55
guardrails/guard.py: 1 change (0 additions, 1 deletion)
@@ -29,7 +29,6 @@
from guardrails.rail import Rail
from guardrails.run import AsyncRunner, Runner
from guardrails.schema import Schema, StringSchema
from guardrails.utils.reask_utils import sub_reasks_with_fixed_values
from guardrails.validators import Validator

logger = logging.getLogger(__name__)
guardrails/run.py: 18 changes (9 additions, 9 deletions)
@@ -11,13 +11,7 @@
from guardrails.prompt import Instructions, Prompt
from guardrails.schema import Schema, StringSchema
from guardrails.utils.llm_response import LLMResponse
from guardrails.utils.reask_utils import (
FieldReAsk,
NonParseableReAsk,
ReAsk,
reasks_to_dict,
sub_reasks_with_fixed_values,
)
from guardrails.utils.reask_utils import NonParseableReAsk, ReAsk, reasks_to_dict
from guardrails.validator_base import ValidatorError

logger = logging.getLogger(__name__)
@@ -361,7 +355,8 @@ def prepare(
iteration.outputs.validation_output = validated_msg_history
if isinstance(validated_msg_history, ReAsk):
raise ValidatorError(
f"Message history validation failed: {validated_msg_history}"
f"Message history validation failed: "
f"{validated_msg_history}"
)
if validated_msg_history != msg_str:
raise ValidatorError("Message history validation failed")
@@ -698,6 +693,8 @@ async def async_run(
output_schema,
prompt_params=prompt_params,
)
except (ValidatorError, ValueError) as e:
raise e
except Exception as e:
error_message = str(e)

@@ -934,7 +931,8 @@ async def async_prepare(
)
if isinstance(validated_msg_history, ReAsk):
raise ValidatorError(
f"Message history validation failed: {validated_msg_history}"
f"Message history validation failed: "
f"{validated_msg_history}"
)
if validated_msg_history != msg_str:
raise ValidatorError("Message history validation failed")
@@ -963,6 +961,7 @@ async def async_prepare(
validated_prompt = await prompt_schema.async_validate(
iteration, prompt.source, self.metadata
)
iteration.outputs.validation_output = validated_prompt
if validated_prompt is None:
raise ValidatorError("Prompt validation failed")
if isinstance(validated_prompt, ReAsk):
@@ -981,6 +980,7 @@ async def async_prepare(
validated_instructions = await instructions_schema.async_validate(
iteration, instructions.source, self.metadata
)
iteration.outputs.validation_output = validated_instructions
if validated_instructions is None:
raise ValidatorError("Instructions validation failed")
if isinstance(validated_instructions, ReAsk):
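
For context, this is the caller-visible behaviour the run.py changes are meant to guarantee: because async_run now re-raises ValidatorError (and ValueError) ahead of the generic Exception handler, a failing input validator surfaces directly from an async guard call instead of being folded into the run's error message. The sketch below is not part of the diff; it mirrors the new tests further down, where Pet and TwoWords are the model and validator defined in tests/unit_tests/test_validators.py, and it assumes the OpenAI acreate function is mocked the same way those tests do.

import pytest
from pydantic import BaseModel, Field

from guardrails import Guard
from guardrails.utils.openai_utils import get_static_openai_acreate_func
from guardrails.validator_base import ValidatorError


class Pet(BaseModel):
    name: str = Field(description="a unique pet name")


@pytest.mark.asyncio
async def test_async_prompt_validation_raises():
    # TwoWords is the two-word validator registered in the test module; the
    # OpenAI acreate call is assumed to be patched with a mock completion.
    guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
        validators=[TwoWords(on_fail="exception")]
    )
    # With the new `except (ValidatorError, ValueError)` re-raise in async_run,
    # the input-validation failure propagates to the caller instead of being
    # swallowed by the generic exception handler.
    with pytest.raises(ValidatorError):
        await guard(
            get_static_openai_acreate_func(),
            prompt="What kind of pet should I get?",
        )
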
tests/unit_tests/test_validators.py: 203 changes (181 additions, 22 deletions)
@@ -9,7 +9,11 @@
from guardrails import Guard
from guardrails.datatypes import DataType
from guardrails.schema import StringSchema
from guardrails.utils.openai_utils import OPENAI_VERSION, get_static_openai_create_func
from guardrails.utils.openai_utils import (
OPENAI_VERSION,
get_static_openai_acreate_func,
get_static_openai_create_func,
)
from guardrails.utils.reask_utils import FieldReAsk
from guardrails.validator_base import (
FailResult,
@@ -627,15 +631,7 @@ class Pet(BaseModel):
name: str = Field(description="a unique pet name")


def test_input_validation_fix(mocker):
if OPENAI_VERSION.startswith("0"):
mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion)
else:
mocker.patch(
"openai.resources.chat.completions.Completions.create",
new=mock_chat_completion,
)

def test_input_validation_fix():
# fix returns an amended value for prompt/instructions validation,
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
validators=[TwoWords(on_fail="fix")]
@@ -674,7 +670,7 @@ def test_input_validation_fix(mocker):

# rail prompt validation
guard = Guard.from_rail_string(
f"""
"""
<rail version="0.1">
<prompt
validators="two-words"
@@ -694,7 +690,7 @@ def test_input_validation_fix(mocker):

# rail instructions validation
guard = Guard.from_rail_string(
f"""
"""
<rail version="0.1">
<prompt>
This is not two words
@@ -716,6 +712,89 @@ def test_input_validation_fix(mocker):
assert guard.history.first.iterations.first.outputs.validation_output == "This also"


@pytest.mark.asyncio
@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1")
async def test_async_input_validation_fix():
# fix returns an amended value for prompt/instructions validation,
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
validators=[TwoWords(on_fail="fix")]
)
await guard(
get_static_openai_acreate_func(),
prompt="What kind of pet should I get?",
)
assert guard.history.first.iterations.first.outputs.validation_output == "What kind"
guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation(
validators=[TwoWords(on_fail="fix")]
)
await guard(
get_static_openai_acreate_func(),
prompt="What kind of pet should I get and what should I name it?",
instructions="But really, what kind of pet should I get?",
)
assert (
guard.history.first.iterations.first.outputs.validation_output == "But really,"
)

# but raises for msg_history validation
with pytest.raises(ValidatorError):
guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation(
validators=[TwoWords(on_fail="fix")]
)
await guard(
get_static_openai_acreate_func(),
msg_history=[
{
"role": "user",
"content": "What kind of pet should I get?",
}
],
)

# rail prompt validation
guard = Guard.from_rail_string(
"""
<rail version="0.1">
<prompt
validators="two-words"
on-fail-two-words="fix"
>
This is not two words
</prompt>
<output type="string">
</output>
</rail>
"""
)
await guard(
get_static_openai_acreate_func(),
)
assert guard.history.first.iterations.first.outputs.validation_output == "This is"

# rail instructions validation
guard = Guard.from_rail_string(
"""
<rail version="0.1">
<prompt>
This is not two words
</prompt>
<instructions
validators="two-words"
on-fail-two-words="fix"
>
This also is not two words
</instructions>
<output type="string">
</output>
</rail>
"""
)
await guard(
get_static_openai_acreate_func(),
)
assert guard.history.first.iterations.first.outputs.validation_output == "This also"


@pytest.mark.parametrize(
"on_fail",
[
@@ -725,15 +804,7 @@ def test_input_validation_fix(mocker):
"exception",
],
)
def test_input_validation_fail(mocker, on_fail):
if OPENAI_VERSION.startswith("0"):
mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion)
else:
mocker.patch(
"openai.resources.chat.completions.Completions.create",
new=mock_chat_completion,
)

def test_input_validation_fail(on_fail):
# with_prompt_validation
with pytest.raises(ValidatorError):
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
@@ -771,7 +842,7 @@ def test_input_validation_fail(mocker, on_fail):
guard = Guard.from_rail_string(
f"""
<rail version="0.1">
<prompt
<prompt
validators="two-words"
on-fail-two-words="{on_fail}"
>
@@ -810,6 +881,94 @@ def test_input_validation_fail(mocker, on_fail):
)


@pytest.mark.parametrize(
"on_fail",
[
"reask",
"filter",
"refrain",
"exception",
],
)
@pytest.mark.asyncio
@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1")
async def test_input_validation_fail_async(mocker, on_fail):
# with_prompt_validation
with pytest.raises(ValidatorError):
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
validators=[TwoWords(on_fail=on_fail)]
)
await guard(
get_static_openai_acreate_func(),
prompt="What kind of pet should I get?",
)
# with_instructions_validation
with pytest.raises(ValidatorError):
guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation(
validators=[TwoWords(on_fail=on_fail)]
)
await guard(
get_static_openai_acreate_func(),
prompt="What kind of pet should I get and what should I name it?",
instructions="What kind of pet should I get?",
)
# with_msg_history_validation
with pytest.raises(ValidatorError):
guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation(
validators=[TwoWords(on_fail=on_fail)]
)
await guard(
get_static_openai_acreate_func(),
msg_history=[
{
"role": "user",
"content": "What kind of pet should I get?",
}
],
)
# rail prompt validation
guard = Guard.from_rail_string(
f"""
<rail version="0.1">
<prompt
validators="two-words"
on-fail-two-words="{on_fail}"
>
This is not two words
</prompt>
<output type="string">
</output>
</rail>
"""
)
with pytest.raises(ValidatorError):
await guard(
get_static_openai_acreate_func(),
)
# rail instructions validation
guard = Guard.from_rail_string(
f"""
<rail version="0.1">
<prompt>
This is not two words
</prompt>
<instructions
validators="two-words"
on-fail-two-words="{on_fail}"
>
This also is not two words
</instructions>
<output type="string">
</output>
</rail>
"""
)
with pytest.raises(ValidatorError):
await guard(
get_static_openai_acreate_func(),
)


def test_input_validation_mismatch_raise():
# prompt validation, msg_history argument
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(