Input validation #483

Closed · wants to merge 14 commits
Internal plumbing for input validation
irgolic committed Nov 29, 2023
commit 73a1d093cdc86a89ac4706c9c776141664ae5f4f
15 changes: 11 additions & 4 deletions guardrails/guard.py
@@ -1,6 +1,7 @@
 import asyncio
 import contextvars
 import logging
+import warnings
 from typing import (
     Any,
     Awaitable,
@@ -22,7 +23,7 @@
 from guardrails.prompt import Instructions, Prompt
 from guardrails.rail import Rail
 from guardrails.run import AsyncRunner, Runner
-from guardrails.schema import Schema
+from guardrails.schema import Schema, StringSchema
 from guardrails.utils.logs_utils import GuardState
 from guardrails.utils.reask_utils import sub_reasks_with_fixed_values
 from guardrails.validators import Validator
@@ -62,9 +63,14 @@ def __init__(
         self.base_model = base_model

     @property
-    def input_schema(self) -> Optional[Schema]:
+    def prompt_schema(self) -> Optional[Schema]:
         """Return the input schema."""
-        return self.rail.input_schema
+        return self.rail.prompt_schema
+
+    @property
+    def instructions_schema(self) -> Optional[Schema]:
+        """Return the input schema."""
+        return self.rail.instructions_schema

     @property
     def output_schema(self) -> Schema:
@@ -351,7 +357,8 @@ def _call_sync(
             prompt=prompt_obj,
             msg_history=msg_history_obj,
             api=get_llm_ask(llm_api, *args, **kwargs),
-            input_schema=self.input_schema,
+            prompt_schema=self.prompt_schema,
+            instructions_schema=self.instructions_schema,
             output_schema=self.output_schema,
             num_reasks=num_reasks,
             metadata=metadata,
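
A note on the new `StringSchema` import: prompts and instructions are flat strings, so their validation schemas are string schemas rather than the structured `Schema` used for outputs (the `Runner` and `prepare` signatures in run.py below are narrowed to `Optional[StringSchema]` accordingly). As a minimal sketch of what string-level validation amounts to, with `StringCheck` and `no_override` as hypothetical stand-ins rather than the library's real API:

    from dataclasses import dataclass
    from typing import Callable, List, Optional

    # Hypothetical stand-ins for illustration; the real StringSchema and
    # Validator classes in guardrails have richer interfaces.
    Validator = Callable[[str], Optional[str]]  # returns an error message, or None if OK

    @dataclass
    class StringCheck:
        validators: List[Validator]

        def validate(self, value: str) -> str:
            for check in self.validators:
                error = check(value)
                if error is not None:
                    raise ValueError(f"Input validation failed: {error}")
            return value

    def no_override(s: str) -> Optional[str]:
        return "override attempt" if "ignore previous" in s.lower() else None

    schema = StringCheck(validators=[no_override])
    schema.validate("Summarize the following article.")  # passes unchanged

Failing validators surface before any LLM call is made, which is the point of the plumbing below.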
19 changes: 11 additions & 8 deletions guardrails/rail.py
@@ -29,7 +29,8 @@ class Rail:
     4. `<instructions>`, which contains the instructions to be passed to the LLM
     """

-    input_schema: Optional[Schema]
+    prompt_schema: Optional[Schema]
+    instructions_schema: Optional[Schema]
     output_schema: Schema
     instructions: Optional[Instructions]
     prompt: Optional[Prompt]
@@ -44,16 +45,15 @@ def from_pydantic(
         reask_prompt: Optional[str] = None,
         reask_instructions: Optional[str] = None,
     ):
-        input_schema = None
-
         output_schema = cls.load_json_schema_from_pydantic(
             output_class,
             reask_prompt_template=reask_prompt,
             reask_instructions_template=reask_instructions,
         )

         return cls(
-            input_schema=input_schema,
+            prompt_schema=None,
+            instructions_schema=None,
             output_schema=output_schema,
             instructions=cls.load_instructions(instructions, output_schema),
             prompt=cls.load_prompt(prompt, output_schema),
@@ -78,12 +78,15 @@ def from_xml(cls, xml: ET._Element):
         )

         # Load <input /> schema
-        raw_input_schema = xml.find("input")
-        if raw_input_schema is None:
-            # No input schema, so do no input checking.
-            input_schema = None
-        else:
-            input_schema = cls.load_input_schema_from_xml(raw_input_schema)
+        # TODO change this to `prompt_validators` and `instructions_validators`
+        prompt_schema = None
+        instructions_schema = None

         # Load <output /> schema
         raw_output_schema = xml.find("output")
@@ -123,7 +126,8 @@ def from_xml(cls, xml: ET._Element):
             version = cast_xml_to_string(version)

         return cls(
-            input_schema=input_schema,
+            prompt_schema=prompt_schema,
+            instructions_schema=instructions_schema,
             output_schema=output_schema,
             instructions=instructions,
             prompt=prompt,
@@ -140,8 +144,6 @@ def from_string_validators(
         reask_prompt: Optional[str] = None,
         reask_instructions: Optional[str] = None,
     ):
-        input_schema = None
-
         output_schema = cls.load_string_schema_from_string(
             validators,
             description=description,
@@ -150,7 +152,8 @@
         )

         return cls(
-            input_schema=input_schema,
+            prompt_schema=None,
+            instructions_schema=None,
             output_schema=output_schema,
             instructions=cls.load_instructions(instructions, output_schema),
             prompt=cls.load_prompt(prompt, output_schema),
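
All three constructors (`from_pydantic`, `from_xml`, `from_string_validators`) now pass `None` for both input schemas, so existing rails behave exactly as before; the old `<input>` loading path is dropped pending the `prompt_validators`/`instructions_validators` split flagged in the TODO. A compressed sketch of the resulting shape, where `RailSketch` is a hypothetical simplification of the real class:

    from dataclasses import dataclass
    from typing import Any, List, Optional

    @dataclass
    class RailSketch:
        # None means "skip validation for that input"; every constructor in
        # this commit passes None, pending the prompt_validators rework.
        prompt_schema: Optional[Any]
        instructions_schema: Optional[Any]
        output_schema: Any

        @classmethod
        def from_string_validators(cls, validators: List[Any]) -> "RailSketch":
            # Only the output schema is actually built from the validators.
            output_schema = f"string schema with {len(validators)} validators"
            return cls(
                prompt_schema=None,
                instructions_schema=None,
                output_schema=output_schema,
            )

    rail = RailSketch.from_string_validators(validators=[])
    assert rail.prompt_schema is None and rail.instructions_schema is None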
68 changes: 55 additions & 13 deletions guardrails/run.py
@@ -8,7 +8,7 @@
 from guardrails.datatypes import verify_metadata_requirements
 from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase
 from guardrails.prompt import Instructions, Prompt
-from guardrails.schema import Schema
+from guardrails.schema import Schema, StringSchema
 from guardrails.utils.llm_response import LLMResponse
 from guardrails.utils.logs_utils import GuardHistory, GuardLogs, GuardState
 from guardrails.utils.reask_utils import (
@@ -18,6 +18,7 @@
     reasks_to_dict,
     sub_reasks_with_fixed_values,
 )
+from guardrails.validator_base import ValidatorError

 logger = logging.getLogger(__name__)
 actions_logger = logging.getLogger(f"{__name__}.actions")
@@ -52,7 +53,8 @@ def __init__(
         instructions: Optional[Union[str, Instructions]] = None,
         msg_history: Optional[List[Dict]] = None,
         api: Optional[PromptCallableBase] = None,
-        input_schema: Optional[Schema] = None,
+        prompt_schema: Optional[StringSchema] = None,
+        instructions_schema: Optional[StringSchema] = None,
         metadata: Optional[Dict[str, Any]] = None,
         output: Optional[str] = None,
         guard_history: Optional[GuardHistory] = None,
@@ -88,7 +90,8 @@ def __init__(
             self.msg_history = None

         self.api = api
-        self.input_schema = input_schema
+        self.prompt_schema = prompt_schema
+        self.instructions_schema = instructions_schema
         self.output_schema = output_schema
         self.guard_state = guard_state
         self.num_reasks = num_reasks
@@ -138,16 +141,18 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory:
             instructions=self.instructions,
             prompt=self.prompt,
             api=self.api,
-            input_schema=self.input_schema,
+            prompt_schema=self.prompt_schema,
+            instructions_schema=self.instructions_schema,
             output_schema=self.output_schema,
             num_reasks=self.num_reasks,
             metadata=self.metadata,
         ):
-            instructions, prompt, msg_history, input_schema, output_schema = (
+            instructions, prompt, msg_history, prompt_schema, instructions_schema, output_schema = (
                 self.instructions,
                 self.prompt,
                 self.msg_history,
-                self.input_schema,
+                self.prompt_schema,
+                self.instructions_schema,
                 self.output_schema,
             )
             for index in range(self.num_reasks + 1):
@@ -159,7 +164,8 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory:
                     prompt=prompt,
                     msg_history=msg_history,
                     prompt_params=prompt_params,
-                    input_schema=input_schema,
+                    prompt_schema=prompt_schema,
+                    instructions_schema=instructions_schema,
                     output_schema=output_schema,
                     output=self.output if index == 0 else None,
                 )
@@ -186,7 +192,8 @@ def step(
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict]],
         prompt_params: Dict,
-        input_schema: Optional[Schema],
+        prompt_schema: Optional[StringSchema],
+        instructions_schema: Optional[StringSchema],
         output_schema: Schema,
         output: Optional[str] = None,
     ):
@@ -199,7 +206,8 @@ def step(
             instructions=instructions,
             prompt=prompt,
             prompt_params=prompt_params,
-            input_schema=input_schema,
+            prompt_schema=prompt_schema,
+            instructions_schema=instructions_schema,
             output_schema=output_schema,
         ):
             # Prepare: run pre-processing, and input validation.
@@ -209,13 +217,15 @@ def step(
                 msg_history = None
             else:
                 instructions, prompt, msg_history = self.prepare(
+                    guard_logs,  # TODO pass something else here
                     index,
                     instructions,
                     prompt,
                     msg_history,
                     prompt_params,
                     api,
-                    input_schema,
+                    prompt_schema,
+                    instructions_schema,
                     output_schema,
                 )

@@ -265,13 +275,15 @@ def step(

     def prepare(
         self,
+        guard_logs: GuardLogs,
         index: int,
         instructions: Optional[Instructions],
         prompt: Optional[Prompt],
         msg_history: Optional[List[Dict]],
         prompt_params: Dict,
         api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]],
-        input_schema: Optional[Schema],
+        prompt_schema: Optional[StringSchema],
+        instructions_schema: Optional[StringSchema],
         output_schema: Schema,
     ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]:
         """Prepare by running pre-processing and input validation.
@@ -293,6 +305,8 @@ def prepare(
                 msg["content"] = msg["content"].format(**prompt_params)

             prompt, instructions = None, None
+
+            # TODO figure out what to do with msg_history in terms of input validation
         elif prompt is not None:
             if isinstance(prompt, str):
                 prompt = Prompt(prompt)
@@ -307,6 +321,32 @@ def prepare(
             instructions, prompt = output_schema.preprocess_prompt(
                 api, instructions, prompt
             )
+
+            # validate prompt
+            if prompt_schema is not None:
+                validated_prompt = prompt_schema.validate(
+                    guard_logs, prompt.source, self.metadata
+                )
+                if validated_prompt is None:
+                    raise ValidatorError("Prompt validation failed")
+                if isinstance(validated_prompt, ReAsk):
+                    raise ValidatorError(
+                        f"Prompt validation failed: {validated_prompt}"
+                    )
+                prompt = Prompt(validated_prompt)
+
+            # validate instructions
+            if instructions_schema is not None:
+                validated_instructions = instructions_schema.validate(
+                    guard_logs, instructions.source, self.metadata
+                )
+                if validated_instructions is None:
+                    raise ValidatorError("Instructions validation failed")
+                if isinstance(validated_instructions, ReAsk):
+                    raise ValidatorError(
+                        f"Instructions validation failed: {validated_instructions}"
+                    )
+                instructions = Instructions(validated_instructions)
         else:
             raise ValueError("Prompt or message history must be provided.")

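The control flow in the new validation block above: run the string schema over the rendered text, treat `None` or a `ReAsk` result as a hard failure (there is no caller to re-ask before the LLM call happens), and rewrap the validated text in its `Prompt` or `Instructions` container. A condensed, self-contained sketch of that shape, using stand-in classes rather than the real guardrails types:

    class ValidatorError(Exception):
        """Stand-in for guardrails' ValidatorError, for illustration only."""

    class ReAsk:
        """Stand-in for the re-ask marker produced by validators."""

    def validate_input(schema, text: str, kind: str) -> str:
        # Mirrors prepare(): anything other than a clean validated string is
        # fatal, since there is no way to re-ask before the LLM call runs.
        if schema is None:
            return text  # no schema configured, so validation is skipped
        validated = schema.validate(text)
        if validated is None or isinstance(validated, ReAsk):
            raise ValidatorError(f"{kind} validation failed: {validated}")
        return validated

    print(validate_input(None, "You are a helpful assistant.", "Instructions"))
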
@@ -591,14 +631,16 @@ async def async_step(
                 prompt = None
                 msg_history = None
             else:
-                instructions, prompt, msg_history = self.prepare(
+                instructions, prompt, msg_history = await self.async_prepare(
+                    guard_logs,  # TODO pass something else here
                     index,
                     instructions,
                     prompt,
                     msg_history,
                     prompt_params,
                     api,
-                    input_schema,
+                    prompt_schema,
+                    instructions_schema,
                     output_schema,
                 )
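
The async path mirrors the sync one: `async_step` now awaits a new `async_prepare` with the same expanded argument list, so prompt and instructions validation behave identically under both runners. A toy sketch of that parity, with stand-in functions in place of the real methods:

    import asyncio

    def prepare(prompt: str) -> str:
        # Stand-in for the shared validation step; raises on failure.
        if not prompt.strip():
            raise ValueError("Prompt validation failed")
        return prompt

    async def async_prepare(prompt: str) -> str:
        # Same semantics as prepare(), in awaitable form for AsyncRunner.
        return prepare(prompt)

    print(asyncio.run(async_prepare("Summarize the article in one sentence.")))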