diff --git a/guardrails/datatypes.py b/guardrails/datatypes.py index 54a60d317..196362d2c 100644 --- a/guardrails/datatypes.py +++ b/guardrails/datatypes.py @@ -3,20 +3,19 @@ import warnings from dataclasses import dataclass from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, Dict, Iterable +from typing import Any, Dict, Iterable from typing import List as TypedList -from typing import Optional, Type, TypeVar, Union +from typing import Optional, Sequence, Type, TypeVar, Union from dateutil.parser import parse from lxml import etree as ET +from pydantic.fields import ModelField from typing_extensions import Self +from guardrails.formatattr import FormatAttr from guardrails.utils.casting_utils import to_float, to_int, to_string from guardrails.utils.xml_utils import cast_xml_to_string -from guardrails.validator_base import Validator - -if TYPE_CHECKING: - from guardrails.schema import FormatAttr +from guardrails.validator_base import Validator, ValidatorSpec logger = logging.getLogger(__name__) @@ -64,7 +63,7 @@ class DataType: def __init__( self, children: Dict[str, Any], - format_attr: "FormatAttr", + format_attr: FormatAttr, optional: bool, name: Optional[str], description: Optional[str], @@ -120,12 +119,10 @@ def set_children_from_xml(self, element: ET._Element): @classmethod def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self: - from guardrails.schema import FormatAttr - # TODO: don't want to pass strict through to DataType, - # but need to pass it to FormatAttr.from_element + # but need to pass it to FormatAttr.from_xml # how to handle this? - format_attr = FormatAttr.from_element(element, cls.tag, strict) + format_attr = FormatAttr.from_xml(element, cls.tag, strict) is_optional = element.attrib.get("required", "true") == "false" @@ -141,6 +138,28 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self: data_type.set_children_from_xml(element) return data_type + @classmethod + def from_pydantic_field( + cls, + field: ModelField, + children: Optional[Dict[str, "DataType"]] = None, + strict: bool = False, + **kwargs, + ) -> Self: + if children is None: + children = {} + + validators = field.field_info.extra.get("validators", []) + format_attr = FormatAttr.from_validators(validators, cls.tag, strict) + + is_optional = field.required is False + + name = field.name + description = field.field_info.description + + data_type = cls(children, format_attr, is_optional, name, description, **kwargs) + return data_type + @property def children(self) -> SimpleNamespace: """Return a SimpleNamespace of the children of this DataType.""" @@ -175,6 +194,7 @@ def deprecate_type(cls: type): versions 0.3.0 and beyond. Use the pydantic 'str' primitive instead.""", DeprecationWarning, ) + return cls class ScalarType(DataType): @@ -197,6 +217,21 @@ def from_str(self, s: str) -> Optional[str]: """Create a String from a string.""" return to_string(s) + @classmethod + def from_string_rail( + cls, + validators: Sequence[ValidatorSpec], + description: Optional[str] = None, + strict: bool = False, + ) -> Self: + return cls( + children={}, + format_attr=FormatAttr.from_validators(validators, cls.tag, strict), + optional=False, + name=None, + description=description, + ) + @register_type("integer") class Integer(ScalarType): @@ -326,6 +361,10 @@ class Email(ScalarType): tag = "email" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + deprecate_type(type(self)) + @deprecate_type @register_type("url") @@ -334,6 +373,10 @@ class URL(ScalarType): tag = "url" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + deprecate_type(type(self)) + @deprecate_type @register_type("pythoncode") @@ -342,6 +385,10 @@ class PythonCode(ScalarType): tag = "pythoncode" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + deprecate_type(type(self)) + @deprecate_type @register_type("sql") @@ -350,6 +397,10 @@ class SQLCode(ScalarType): tag = "sql" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + deprecate_type(type(self)) + @register_type("percentage") class Percentage(ScalarType): diff --git a/guardrails/formatattr.py b/guardrails/formatattr.py new file mode 100644 index 000000000..9c0de5258 --- /dev/null +++ b/guardrails/formatattr.py @@ -0,0 +1,320 @@ +import re +import warnings +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import lxml.etree as ET +import pydantic + +from guardrails.utils.xml_utils import cast_xml_to_string +from guardrails.validator_base import Validator, ValidatorSpec + + +class FormatAttr(pydantic.BaseModel): + """Class for parsing and manipulating the `format` attribute of an element. + + The format attribute is a string that contains semi-colon separated + validators e.g. "valid-url; is-reachable". Each validator is itself either: + - the name of an parameter-less validator, e.g. "valid-url" + - the name of a validator with parameters, separated by a colon with a + space-separated list of parameters, e.g. "is-in: 1 2 3" + + Parameters can either be written in plain text, or in python expressions + enclosed in curly braces. For example, the following are all valid: + - "is-in: 1 2 3" + - "is-in: {1} {2} {3}" + - "is-in: {1 + 2} {2 + 3} {3 + 4}" + """ + + class Config: + arbitrary_types_allowed = True + + # The format attribute string. + format: Optional[str] + + # The on-fail handlers. + on_fail_handlers: Dict[str, str] + + # The validator arguments. + validator_args: Mapping[str, Union[Dict[str, Any], List[Any]]] + + # The validators. + validators: List[Validator] + + # The unregistered validators. + unregistered_validators: List[str] + + @property + def empty(self) -> bool: + """Return True if the format attribute is empty, False otherwise.""" + return not self.validators and not self.unregistered_validators + + @classmethod + def from_validators( + cls, + validators: Sequence[ValidatorSpec], + tag: str, + strict: bool = False, + ): + validators_with_args = {} + on_fails = {} + for val in validators: + # must be either a tuple with two elements or a gd.Validator + if isinstance(val, Validator): + # `validator` is of type gd.Validator, use the to_xml_attrib method + validator_name = val.rail_alias + validator_args = val.get_args() + validators_with_args[validator_name] = validator_args + # Set the on-fail attribute based on the on_fail value + on_fail = val.on_fail_descriptor + on_fails[val.rail_alias] = on_fail + elif isinstance(val, tuple) and len(val) == 2: + validator, on_fail = val + if isinstance(validator, Validator): + # `validator` is of type gd.Validator, use the to_xml_attrib method + validator_name = validator.rail_alias + validator_args = validator.get_args() + validators_with_args[validator_name] = validator_args + # Set the on-fail attribute based on the on_fail value + on_fails[validator.rail_alias] = on_fail + elif isinstance(validator, str): + # `validator` is a string, use it as the validator prompt + validator_name = validator + validator_args = [] + validators_with_args[validator_name] = validator_args + on_fails[validator] = on_fail + elif isinstance(validator, Callable): + # `validator` is a callable, use it as the validator prompt + if not hasattr(validator, "rail_alias"): + raise ValueError( + f"Validator {validator.__name__} must be registered with " + f"the gd.register_validator decorator" + ) + validator_name = validator.rail_alias + validator_args = [] + validators_with_args[validator_name] = validator_args + on_fails[validator.rail_alias] = on_fail + else: + raise ValueError( + f"Validator tuple {val} must be a (validator, on_fail) tuple, " + f"where the validator is a string or a callable" + ) + else: + raise ValueError( + f"Validator {val} must be a (validator, on_fail) tuple or " + f"Validator class instance" + ) + + registered_validators, unregistered_validators = cls.get_validators( + validator_args=validators_with_args, + tag=tag, + on_fail_handlers=on_fails, + strict=strict, + ) + + return cls( + format=None, + on_fail_handlers=on_fails, + validator_args=validators_with_args, + validators=registered_validators, + unregistered_validators=unregistered_validators, + ) + + @classmethod + def from_xml( + cls, element: ET._Element, tag: str, strict: bool = False + ) -> "FormatAttr": + """Create a FormatAttr object from an XML element. + + Args: + element (ET._Element): The XML element. + + Returns: + A FormatAttr object. + """ + format_str = element.get("format") + if format_str is None: + return cls( + format=None, + on_fail_handlers={}, + validator_args={}, + validators=[], + unregistered_validators=[], + ) + + validator_args = cls.parse(format_str) + + on_fail_handlers = {} + for key, value in element.attrib.items(): + key = cast_xml_to_string(key) + if key.startswith("on-fail-"): + on_fail_handler_name = key[len("on-fail-") :] + on_fail_handler = value + on_fail_handlers[on_fail_handler_name] = on_fail_handler + + validators, unregistered_validators = cls.get_validators( + validator_args=validator_args, + tag=tag, + on_fail_handlers=on_fail_handlers, + strict=strict, + ) + + return cls( + format=format_str, + on_fail_handlers=on_fail_handlers, + validator_args=validator_args, + validators=validators, + unregistered_validators=unregistered_validators, + ) + + @staticmethod + def parse_token(token: str) -> Tuple[str, List[Any]]: + """Parse a single token in the format attribute, and return the + validator name and the list of arguments. + + Args: + token (str): The token to parse, one of the tokens returned by + `self.tokens`. + + Returns: + A tuple of the validator name and the list of arguments. + """ + validator_with_args = token.strip().split(":", 1) + if len(validator_with_args) == 1: + return validator_with_args[0].strip(), [] + + validator, args_token = validator_with_args + + # Split using whitespace as a delimiter, but not if it is inside curly braces or + # single quotes. + pattern = re.compile(r"\s(?![^{}]*})|(? Dict[str, List[Any]]: + """Parse the format attribute into a dictionary of validators. + + Returns: + A dictionary of validators, where the key is the validator name, and + the value is a list of arguments. + """ + # Split the format attribute into tokens: each is a validator. + # Then, parse each token into a validator name and a list of parameters. + pattern = re.compile(r";(?![^{}]*})") + tokens = re.split(pattern, format_string) + tokens = list(filter(None, tokens)) + + validators = {} + for token in tokens: + # Parse the token into a validator name and a list of parameters. + validator_name, args = FormatAttr.parse_token(token) + validators[validator_name] = args + + return validators + + @staticmethod + def get_validators( + validator_args: Dict[str, List[Any]], + tag: str, + on_fail_handlers: Dict[str, str], + strict: bool = False, + ) -> Tuple[List[Validator], List[str]]: + """Get the list of validators from the format attribute. Only the + validators that are registered for this element will be returned. + + For example, if the format attribute is "valid-url; is-reachable", and + "is-reachable" is not registered for this element, then only the ValidUrl + validator will be returned, after instantiating it with the arguments + specified in the format attribute (if any). + + Args: + strict: If True, raise an error if a validator is not registered for + this element. If False, ignore the validator and print a warning. + + Returns: + A list of validators. + """ + from guardrails.validator_base import types_to_validators, validators_registry + + _validators = [] + _unregistered_validators = [] + for validator_name, args in validator_args.items(): + # Check if the validator is registered for this element. + # The validators in `format` that are not registered for this element + # will be ignored (with an error or warning, depending on the value of + # `strict`), and the registered validators will be returned. + if validator_name not in types_to_validators[tag]: + if strict: + raise ValueError( + f"Validator {validator_name} is not valid for" + f" element {tag}." + ) + else: + warnings.warn( + f"Validator {validator_name} is not valid for" + f" element {tag}." + ) + _unregistered_validators.append(validator_name) + continue + + validator = validators_registry[validator_name] + + # See if the formatter has an associated on_fail method. + on_fail = on_fail_handlers.get(validator_name, None) + # TODO(shreya): Load the on_fail method. + # This method should be loaded from an optional script given at the + # beginning of a rail file. + + # Create the validator. + if isinstance(args, list): + v = validator(*args, on_fail=on_fail) + elif isinstance(args, dict): + v = validator(**args, on_fail=on_fail) + else: + raise ValueError( + f"Validator {validator_name} has invalid arguments: {args}." + ) + _validators.append(v) + + return _validators, _unregistered_validators + + def to_prompt(self, with_keywords: bool = True) -> str: + """Convert the format string to another string representation for use + in prompting. Uses the validators' to_prompt method in order to + construct the string to use in prompting. + + For example, the format string "valid-url; other-validator: 1.0 + {1 + 2}" will be converted to "valid-url other-validator: + arg1=1.0 arg2=3". + """ + if self.empty: + return "" + # Use the validators' to_prompt method to convert the format string to + # another string representation. + prompt = "; ".join([v.to_prompt(with_keywords) for v in self.validators]) + unreg_prompt = "; ".join(self.unregistered_validators) + if prompt and unreg_prompt: + prompt += f"; {unreg_prompt}" + elif unreg_prompt: + prompt += unreg_prompt + return prompt diff --git a/guardrails/guard.py b/guardrails/guard.py index 99f6a72f5..f0156a650 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -8,6 +8,7 @@ Dict, List, Optional, + Sequence, Tuple, Type, Union, @@ -191,7 +192,7 @@ def from_pydantic( @classmethod def from_string( cls, - validators: List[Validator], + validators: Sequence[Validator], description: Optional[str] = None, prompt: Optional[str] = None, instructions: Optional[str] = None, diff --git a/guardrails/rail.py b/guardrails/rail.py index fcded5170..ca97c6fa4 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -1,20 +1,15 @@ """Rail class.""" import warnings from dataclasses import dataclass -from typing import List, Optional, Type +from typing import Optional, Sequence, Type from lxml import etree as ET -from lxml.etree import Element, SubElement from pydantic import BaseModel from guardrails.prompt import Instructions, Prompt from guardrails.schema import JsonSchema, Schema, StringSchema -from guardrails.utils.pydantic_utils import ( - attach_validators_to_element, - create_xml_element_for_base_model, -) from guardrails.utils.xml_utils import cast_xml_to_string -from guardrails.validators import Validator +from guardrails.validator_base import ValidatorSpec # TODO: Logging XMLPARSER = ET.XMLParser(encoding="utf-8") @@ -49,14 +44,20 @@ def from_pydantic( reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, ): - xml = generate_xml_code( - output_class=output_class, - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, + input_schema = None + + output_schema = cls.load_json_schema_from_pydantic( + output_class, + reask_prompt_template=reask_prompt, + reask_instructions_template=reask_instructions, + ) + + return cls( + input_schema=input_schema, + output_schema=output_schema, + instructions=cls.load_instructions(instructions, output_schema), + prompt=cls.load_prompt(prompt, output_schema), ) - return cls.from_xml(xml) @classmethod def from_file(cls, file_path: str) -> "Rail": @@ -82,7 +83,7 @@ def from_xml(cls, xml: ET._Element): # No input schema, so do no input checking. input_schema = None else: - input_schema = cls.load_input_schema(raw_input_schema) + input_schema = cls.load_input_schema_from_xml(raw_input_schema) # Load schema raw_output_schema = xml.find("output") @@ -97,7 +98,7 @@ def from_xml(cls, xml: ET._Element): reask_instructions = xml.find("reask_instructions") if reask_instructions is not None: reask_instructions = reask_instructions.text - output_schema = cls.load_output_schema( + output_schema = cls.load_output_schema_from_xml( raw_output_schema, reask_prompt=reask_prompt, reask_instructions=reask_instructions, @@ -108,14 +109,14 @@ def from_xml(cls, xml: ET._Element): # prepended to the prompt. instructions = xml.find("instructions") if instructions is not None: - instructions = cls.load_instructions(instructions, output_schema) + instructions = cls.load_instructions(instructions.text, output_schema) # Load prompt = xml.find("prompt") if prompt is None: warnings.warn("Prompt must be provided during __call__.") else: - prompt = cls.load_prompt(prompt, output_schema) + prompt = cls.load_prompt(prompt.text, output_schema) # Get version version = xml.attrib["version"] @@ -132,37 +133,37 @@ def from_xml(cls, xml: ET._Element): @classmethod def from_string_validators( cls, - validators: List[Validator], + validators: Sequence[ValidatorSpec], description: Optional[str] = None, prompt: Optional[str] = None, instructions: Optional[str] = None, reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, ): - xml = generate_xml_code( - prompt=prompt, - instructions=instructions, - reask_prompt=reask_prompt, - reask_instructions=reask_instructions, - validators=validators, + input_schema = None + + output_schema = cls.load_string_schema_from_string( + validators, description=description, + reask_prompt_template=reask_prompt, + reask_instructions_template=reask_instructions, ) - return cls.from_xml(xml) - @staticmethod - def load_schema(root: ET._Element) -> Schema: - """Given the RAIL or element, create a Schema - object.""" - return Schema.from_element(root) + return cls( + input_schema=input_schema, + output_schema=output_schema, + instructions=cls.load_instructions(instructions, output_schema), + prompt=cls.load_prompt(prompt, output_schema), + ) @staticmethod - def load_input_schema(root: ET._Element) -> Schema: + def load_input_schema_from_xml(root: ET._Element) -> Schema: """Given the RAIL element, create a Schema object.""" # Recast the schema as an InputSchema. - return Schema.from_element(root) + return Schema.from_xml(root) @staticmethod - def load_output_schema( + def load_output_schema_from_xml( root: ET._Element, reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, @@ -179,100 +180,61 @@ def load_output_schema( """ # If root contains a `type="string"` attribute, then it's a StringSchema if "type" in root.attrib and root.attrib["type"] == "string": - return StringSchema.from_element( + return StringSchema.from_xml( root, reask_prompt_template=reask_prompt, reask_instructions_template=reask_instructions, ) - return JsonSchema.from_element( + return JsonSchema.from_xml( root, reask_prompt_template=reask_prompt, reask_instructions_template=reask_instructions, ) @staticmethod - def load_instructions(root: ET._Element, output_schema: Schema) -> Instructions: + def load_string_schema_from_string( + validators: Sequence[ValidatorSpec], + description: Optional[str] = None, + reask_prompt_template: Optional[str] = None, + reask_instructions_template: Optional[str] = None, + ): + return StringSchema.from_string( + validators, + description=description, + reask_prompt_template=reask_prompt_template, + reask_instructions_template=reask_instructions_template, + ) + + @staticmethod + def load_json_schema_from_pydantic( + output_class: Type[BaseModel], + reask_prompt_template: Optional[str] = None, + reask_instructions_template: Optional[str] = None, + ): + return JsonSchema.from_pydantic( + output_class, + reask_prompt_template=reask_prompt_template, + reask_instructions_template=reask_instructions_template, + ) + + @staticmethod + def load_instructions( + text: Optional[str], output_schema: Schema + ) -> Optional[Instructions]: """Given the RAIL element, create Instructions.""" + if text is None: + return None return Instructions( - source=root.text or "", + source=text or "", output_schema=output_schema.transpile(), ) @staticmethod - def load_prompt(root: ET._Element, output_schema: Schema) -> Prompt: + def load_prompt(text: Optional[str], output_schema: Schema) -> Optional[Prompt]: """Given the RAIL element, create a Prompt object.""" + if text is None: + return None return Prompt( - source=root.text or "", + source=text or "", output_schema=output_schema.transpile(), ) - - -def generate_xml_code( - prompt: Optional[str] = None, - output_class: Optional[Type[BaseModel]] = None, - instructions: Optional[str] = None, - reask_prompt: Optional[str] = None, - reask_instructions: Optional[str] = None, - validators: Optional[List[Validator]] = None, - description: Optional[str] = None, -) -> ET._Element: - """Generate XML RAIL Spec from a pydantic model and a prompt. - - Parameters: Arguments: - prompt (str, optional): The prompt for this RAIL spec. - output_class (BaseModel, optional): The Pydantic model that represents the desired output schema. Do not specify if using a string schema. Defaults to None. - instructions (str, optional): Instructions for chat models. Defaults to None. - reask_prompt (str, optional): An alternative prompt to use during reasks. Defaults to None. - reask_instructions (str, optional): Alternative instructions to use during reasks. Defaults to None. - validators (List[Validator], optional): The list of validators to apply to the string schema. Do not specify if using a Pydantic model. Defaults to None. - description (str, optional): The description for a string schema. Do not specify if using a Pydantic model. Defaults to None. - """ # noqa - - # Create the root element - root = Element("rail") - root.set("version", "0.1") - - # Create the output element - output_element = SubElement(root, "output") - - if output_class and validators: - warnings.warn( - "Do not specify root level validators on a Pydantic model." - " These validators will be ignored." - ) - - if output_class is not None: - # Create XML elements for the output_class - create_xml_element_for_base_model(output_class, output_element) - else: - if validators is not None: - attach_validators_to_element(output_element, validators) - if description is not None: - output_element.set("description", description) - output_element.set("type", "string") - - if prompt is not None: - # Create the prompt element - prompt_element = SubElement(root, "prompt") - prompt_text = f"{prompt}" - prompt_element.text = prompt_text - - if instructions is not None: - # Create the instructions element - instructions_element = SubElement(root, "instructions") - instructions_text = f"{instructions}" - instructions_element.text = instructions_text - - if reask_prompt is not None: - # Create the reask_prompt element - reask_prompt_element = SubElement(root, "reask_prompt") - reask_prompt_text = f"{reask_prompt}" - reask_prompt_element.text = reask_prompt_text - - if reask_instructions is not None: - # Create the reask_instructions element - reask_instructions_element = SubElement(root, "reask_instructions") - reask_instructions_text = f"{reask_instructions}" - reask_instructions_element.text = reask_instructions_text - - return root diff --git a/guardrails/schema.py b/guardrails/schema.py index 9d493308d..690cbe1e4 100644 --- a/guardrails/schema.py +++ b/guardrails/schema.py @@ -1,13 +1,22 @@ import json import logging import pprint -import re -import warnings from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) -import pydantic from lxml import etree as ET +from pydantic import BaseModel from typing_extensions import Self from guardrails import validator_service @@ -27,6 +36,7 @@ ) from guardrails.utils.logs_utils import FieldValidationLogs, GuardLogs from guardrails.utils.parsing_utils import get_template_variables +from guardrails.utils.pydantic_utils import convert_pydantic_model_to_datatype from guardrails.utils.reask_utils import ( FieldReAsk, NonParseableReAsk, @@ -35,10 +45,9 @@ get_pruned_tree, prune_obj_for_reasking, ) -from guardrails.utils.xml_utils import cast_xml_to_string from guardrails.validator_base import ( FailResult, - Validator, + ValidatorSpec, check_refrain_in_dict, filter_in_dict, ) @@ -49,238 +58,6 @@ logger = logging.getLogger(__name__) -class FormatAttr(pydantic.BaseModel): - """Class for parsing and manipulating the `format` attribute of an element. - - The format attribute is a string that contains semi-colon separated - validators e.g. "valid-url; is-reachable". Each validator is itself either: - - the name of an parameter-less validator, e.g. "valid-url" - - the name of a validator with parameters, separated by a colon with a - space-separated list of parameters, e.g. "is-in: 1 2 3" - - Parameters can either be written in plain text, or in python expressions - enclosed in curly braces. For example, the following are all valid: - - "is-in: 1 2 3" - - "is-in: {1} {2} {3}" - - "is-in: {1 + 2} {2 + 3} {3 + 4}" - """ - - class Config: - arbitrary_types_allowed = True - - # The format attribute string. - format: Optional[str] - - # The on-fail handlers. - on_fail_handlers: Dict[str, str] - - # The validator arguments. - validator_args: Dict[str, List[Any]] - - # The validators. - validators: List[Validator] - - # The unregistered validators. - unregistered_validators: List[str] - - @property - def empty(self) -> bool: - """Return True if the format attribute is empty, False otherwise.""" - return self.format is None - - @classmethod - def from_element( - cls, element: ET._Element, tag: str, strict: bool = False - ) -> "FormatAttr": - """Create a FormatAttr object from an XML element. - - Args: - element (ET._Element): The XML element. - - Returns: - A FormatAttr object. - """ - format_str = element.get("format") - if format_str is None: - return cls( - format=None, - on_fail_handlers={}, - validator_args={}, - validators=[], - unregistered_validators=[], - ) - - validator_args = cls.parse(format_str) - - on_fail_handlers = {} - for key, value in element.attrib.items(): - key = cast_xml_to_string(key) - if key.startswith("on-fail-"): - on_fail_handler_name = key[len("on-fail-") :] - on_fail_handler = value - on_fail_handlers[on_fail_handler_name] = on_fail_handler - - validators, unregistered_validators = cls.get_validators( - validator_args=validator_args, - tag=tag, - on_fail_handlers=on_fail_handlers, - strict=strict, - ) - - return cls( - format=format_str, - on_fail_handlers=on_fail_handlers, - validator_args=validator_args, - validators=validators, - unregistered_validators=unregistered_validators, - ) - - @staticmethod - def parse_token(token: str) -> Tuple[str, List[Any]]: - """Parse a single token in the format attribute, and return the - validator name and the list of arguments. - - Args: - token (str): The token to parse, one of the tokens returned by - `self.tokens`. - - Returns: - A tuple of the validator name and the list of arguments. - """ - validator_with_args = token.strip().split(":", 1) - if len(validator_with_args) == 1: - return validator_with_args[0].strip(), [] - - validator, args_token = validator_with_args - - # Split using whitespace as a delimiter, but not if it is inside curly braces or - # single quotes. - pattern = re.compile(r"\s(?![^{}]*})|(? Dict[str, List[Any]]: - """Parse the format attribute into a dictionary of validators. - - Returns: - A dictionary of validators, where the key is the validator name, and - the value is a list of arguments. - """ - # Split the format attribute into tokens: each is a validator. - # Then, parse each token into a validator name and a list of parameters. - pattern = re.compile(r";(?![^{}]*})") - tokens = re.split(pattern, format_string) - tokens = list(filter(None, tokens)) - - validators = {} - for token in tokens: - # Parse the token into a validator name and a list of parameters. - validator_name, args = FormatAttr.parse_token(token) - validators[validator_name] = args - - return validators - - @staticmethod - def get_validators( - validator_args: Dict[str, List[Any]], - tag: str, - on_fail_handlers: Dict[str, str], - strict: bool = False, - ) -> Tuple[List[Validator], List[str]]: - """Get the list of validators from the format attribute. Only the - validators that are registered for this element will be returned. - - For example, if the format attribute is "valid-url; is-reachable", and - "is-reachable" is not registered for this element, then only the ValidUrl - validator will be returned, after instantiating it with the arguments - specified in the format attribute (if any). - - Args: - strict: If True, raise an error if a validator is not registered for - this element. If False, ignore the validator and print a warning. - - Returns: - A list of validators. - """ - from guardrails.validator_base import types_to_validators, validators_registry - - _validators = [] - _unregistered_validators = [] - for validator_name, args in validator_args.items(): - # Check if the validator is registered for this element. - # The validators in `format` that are not registered for this element - # will be ignored (with an error or warning, depending on the value of - # `strict`), and the registered validators will be returned. - if validator_name not in types_to_validators[tag]: - if strict: - raise ValueError( - f"Validator {validator_name} is not valid for" - f" element {tag}." - ) - else: - warnings.warn( - f"Validator {validator_name} is not valid for" - f" element {tag}." - ) - _unregistered_validators.append(validator_name) - continue - - validator = validators_registry[validator_name] - - # See if the formatter has an associated on_fail method. - on_fail = on_fail_handlers.get(validator_name, None) - # TODO(shreya): Load the on_fail method. - # This method should be loaded from an optional script given at the - # beginning of a rail file. - - # Create the validator. - _validators.append(validator(*args, on_fail=on_fail)) - - return _validators, _unregistered_validators - - def to_prompt(self, with_keywords: bool = True) -> str: - """Convert the format string to another string representation for use - in prompting. Uses the validators' to_prompt method in order to - construct the string to use in prompting. - - For example, the format string "valid-url; other-validator: 1.0 - {1 + 2}" will be converted to "valid-url other-validator: - arg1=1.0 arg2=3". - """ - if self.format is None: - return "" - # Use the validators' to_prompt method to convert the format string to - # another string representation. - prompt = "; ".join([v.to_prompt(with_keywords) for v in self.validators]) - unreg_prompt = "; ".join(self.unregistered_validators) - if prompt and unreg_prompt: - prompt += f"; {unreg_prompt}" - elif unreg_prompt: - prompt += unreg_prompt - return prompt - - class Schema: """Schema class that holds a _schema attribute.""" @@ -308,7 +85,7 @@ def __init__( self._reask_instructions_template = None @classmethod - def from_element( + def from_xml( cls, root: ET._Element, reask_prompt_template: Optional[str] = None, @@ -528,7 +305,7 @@ def reask_decoder(obj): return pruned_tree_schema, prompt, instructions @classmethod - def from_element( + def from_xml( cls, root: ET._Element, reask_prompt_template: Optional[str] = None, @@ -546,6 +323,23 @@ def from_element( reask_instructions_template=reask_instructions_template, ) + @classmethod + def from_pydantic( + cls, + model: Type[BaseModel], + reask_prompt_template: Optional[str] = None, + reask_instructions_template: Optional[str] = None, + ) -> Self: + strict = False + + schema = convert_pydantic_model_to_datatype(model, strict=strict) + + return cls( + schema, + reask_prompt_template=reask_prompt_template, + reask_instructions_template=reask_instructions_template, + ) + def parse( self, output: str ) -> Tuple[Union[Optional[Dict], NonParseableReAsk], Optional[Exception]]: @@ -748,7 +542,7 @@ def __init__( self.root_datatype = schema @classmethod - def from_element( + def from_xml( cls, root: ET._Element, reask_prompt_template: Optional[str] = None, @@ -769,6 +563,26 @@ def from_element( reask_instructions_template=reask_instructions_template, ) + @classmethod + def from_string( + cls, + validators: Sequence[ValidatorSpec], + description: Optional[str] = None, + reask_prompt_template: Optional[str] = None, + reask_instructions_template: Optional[str] = None, + ): + strict = False + + schema = String.from_string_rail( + validators, description=description, strict=strict + ) + + return cls( + schema=schema, + reask_prompt_template=reask_prompt_template, + reask_instructions_template=reask_instructions_template, + ) + def get_reask_setup( self, reasks: List[FieldReAsk], diff --git a/guardrails/utils/pydantic_utils.py b/guardrails/utils/pydantic_utils.py index c328ab935..f04acde57 100644 --- a/guardrails/utils/pydantic_utils.py +++ b/guardrails/utils/pydantic_utils.py @@ -4,26 +4,27 @@ import warnings from copy import deepcopy from datetime import date, time -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Type, - Union, - get_args, - get_origin, -) - -import lxml.etree as ET +from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin + from griffe.dataclasses import Docstring from griffe.docstrings.parsers import Parser, parse -from lxml.etree import Element as E -from pydantic import BaseModel, HttpUrl, validator +from pydantic import BaseModel, validator from pydantic.fields import ModelField -from guardrails.validators import Validator +from guardrails.datatypes import Boolean as BooleanDataType +from guardrails.datatypes import Case as CaseDataType +from guardrails.datatypes import Choice +from guardrails.datatypes import Choice as ChoiceDataType +from guardrails.datatypes import DataType +from guardrails.datatypes import Date as DateDataType +from guardrails.datatypes import Float as FloatDataType +from guardrails.datatypes import Integer as IntegerDataType +from guardrails.datatypes import List as ListDataType +from guardrails.datatypes import Object as ObjectDataType +from guardrails.datatypes import String as StringDataType +from guardrails.datatypes import Time as TimeDataType +from guardrails.formatattr import FormatAttr +from guardrails.validator_base import Validator griffe_docstrings_google_logger = logging.getLogger("griffe.docstrings.google") griffe_agents_nodes_logger = logging.getLogger("griffe.agents.nodes") @@ -100,7 +101,7 @@ def is_dict(type_annotation: Any) -> bool: return False -def prepare_type_annotation(type_annotation: Any) -> Type: +def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type: """Get the raw type annotation that can be used for downstream processing. This function does the following: @@ -129,268 +130,6 @@ def prepare_type_annotation(type_annotation: Any) -> Type: return type_annotation -def type_annotation_to_string(type_annotation: Any) -> str: - """Map a type_annotation to the name of the corresponding field type. - - This function checks if the type_annotation is a list, dict, or a - primitive type, and returns the corresponding type name, e.g. - "list", "object", "bool", "date", etc. - """ - - # Get the type annotation from the type_annotation - type_annotation = prepare_type_annotation(type_annotation) - - # Use inline import to avoid circular dependency - from guardrails.datatypes import PythonCode - - # Map the type annotation to the corresponding field type - if is_list(type_annotation): - return "list" - elif is_dict(type_annotation): - return "object" - elif type_annotation == bool: - return "bool" - elif type_annotation == date: - return "date" - elif type_annotation == float: - return "float" - elif type_annotation == int: - return "integer" - elif type_annotation == str or typing.get_origin(type_annotation) == typing.Literal: - return "string" - elif type_annotation == time: - return "time" - elif type_annotation == HttpUrl: - return "url" - elif typing.get_origin(type_annotation) == Union: - return "choice" - elif type_annotation == PythonCode: - return "string" - else: - raise ValueError(f"Unsupported type: {type_annotation}") - - -def add_validators_to_xml_element( - field_info: ModelField, element: ET._Element -) -> ET._Element: - """Extract validators from a pydantic ModelField and add to XML element. - - Args: - field_info: The field info to extract validators from - element: The XML element to add the validators to - - Returns: - The XML element with the validators added - """ - - if not isinstance(field_info, ModelField): - return element - if "validators" in field_info.field_info.extra: - validators = field_info.field_info.extra["validators"] - if not isinstance(validators, list): - validators = [validators] - - attach_validators_to_element(element, validators) - - # construct a valid-choices validator for Literal types - if typing.get_origin(field_info.annotation) is typing.Literal: - valid_choices = typing.get_args(field_info.annotation) - element.set("format", "valid-choices") - element.set("valid-choices", ",".join(valid_choices)) - - return element - - -def attach_validators_to_element( - element: ET._Element, - validators: Union[List[Validator], List[str]], -): - format_prompt = [] - on_fails = {} - for val in validators: - # must be either a tuple with two elements or a gd.Validator - if isinstance(val, Validator): - # `validator` is of type gd.Validator, use the to_xml_attrib method - validator_prompt = val.to_xml_attrib() - # Set the on-fail attribute based on the on_fail value - on_fail = val.on_fail_descriptor - on_fails[val.rail_alias] = on_fail - elif isinstance(val, tuple) and len(val) == 2: - validator, on_fail = val - if isinstance(validator, Validator): - # `validator` is of type gd.Validator, use the to_xml_attrib method - validator_prompt = validator.to_xml_attrib() - # Set the on-fail attribute based on the on_fail value - on_fails[validator.rail_alias] = on_fail - elif isinstance(validator, str): - # `validator` is a string, use it as the validator prompt - validator_prompt = validator - on_fails[validator] = on_fail - elif isinstance(validator, Callable): - # `validator` is a callable, use it as the validator prompt - if not hasattr(validator, "rail_alias"): - raise ValueError( - f"Validator {validator.__name__} must be registered with " - f"the gd.register_validator decorator" - ) - validator_prompt = validator.rail_alias - on_fails[validator.rail_alias] = on_fail - else: - raise ValueError( - f"Validator tuple {val} must be a (validator, on_fail) tuple, " - f"where the validator is a string or a callable" - ) - else: - raise ValueError( - f"Validator {val} must be a (validator, on_fail) tuple or " - f"Validator class instance" - ) - format_prompt.append(validator_prompt) - - if len(format_prompt) > 0: - format_prompt = "; ".join(format_prompt) - element.set("format", format_prompt) - for rail_alias, on_fail in on_fails.items(): - element.set("on-fail-" + rail_alias, on_fail) - - return element - - -def create_xml_element_for_field( - field: Union[ModelField, Type], - field_name: Optional[str] = None, - exclude_subfields: Optional[typing.List[str]] = None, -) -> ET._Element: - """Create an XML element corresponding to a field. - - Args: - field_info: Field's type. This could be a Pydantic ModelField or a type. - field_name: Field's name. For some fields (e.g. list), this is not required. - exclude_fields: List of fields to exclude from the XML element. - - Returns: - The XML element corresponding to the field. - """ - if exclude_subfields is None: - exclude_subfields = [] - - # Create the element based on the field type - field_type = type_annotation_to_string(field) - element = E(field_type) - - # Add name attribute - if field_name: - element.set("name", field_name) - - # Add validators - element = add_validators_to_xml_element(field, element) - - # Add description attribute - if isinstance(field, ModelField): - if field.field_info.description is not None: - element.set("description", field.field_info.description) - - if field.field_info.discriminator is not None: - assert field_type == "choice" - assert typing.get_origin(field.annotation) is Union - discriminator = field.field_info.discriminator - element.set("discriminator", discriminator) - for case in typing.get_args(field.annotation): - case_discriminator_type = case.__fields__[discriminator].type_ - assert typing.get_origin(case_discriminator_type) is typing.Literal - assert len(typing.get_args(case_discriminator_type)) == 1 - discriminator_value = typing.get_args(case_discriminator_type)[0] - case_element = E("case", name=discriminator_value) - nested_element = create_xml_element_for_field( - case, exclude_subfields=[discriminator] - ) - for child in nested_element: - case_element.append(child) - element.append(case_element) - - # Add other attributes from the field_info - for key, value in field.field_info.extra.items(): - if key not in ["validators", "description"]: - element.set(key, value) - - # Create XML elements for the field's children - if field_type in ["list", "object"]: - type_annotation = prepare_type_annotation(field) - - if is_list(type_annotation): - inner_type = get_args(type_annotation) - if len(inner_type) == 0: - # If the list is empty, we cannot infer the type of the elements - return element - - inner_type = inner_type[0] - if is_pydantic_base_model(inner_type): - object_element = create_xml_element_for_base_model(inner_type) - element.append(object_element) - else: - inner_element = create_xml_element_for_field(inner_type) - element.append(inner_element) - - elif is_dict(type_annotation): - if is_pydantic_base_model(type_annotation): - element = create_xml_element_for_base_model( - type_annotation, - element, - exclude_subfields=exclude_subfields, - ) - else: - dict_args = get_args(type_annotation) - if len(dict_args) == 2: - key_type, val_type = dict_args - assert key_type == str, "Only string keys are supported for dicts" - inner_element = create_xml_element_for_field(val_type) - element.append(inner_element) - else: - raise ValueError(f"Unsupported type: {type_annotation}") - - return element - - -def create_xml_element_for_base_model( - model: Type[BaseModel], - element: Optional[ET._Element] = None, - exclude_subfields: Optional[typing.List[str]] = None, -) -> ET._Element: - """Create an XML element for a Pydantic BaseModel. - - This function does the following: - 1. Iterates through fields of the model and creates XML elements for each field - 2. If a field is a Pydantic BaseModel, it creates a nested XML element - - Args: - model: The Pydantic BaseModel to create an XML element for - element: The XML element to add the fields to. If None, a new XML element - exclude_subfields: List of fields to exclude from the XML element. - - Returns: - The XML element with the fields added - """ - if exclude_subfields is None: - exclude_subfields = [] - - if element is None: - element_ = E("object") - else: - element_ = element - - # Extract pydantic validators from the model and add them as guardrails validators - model_fields = add_pydantic_validators_as_guardrails_validators(model) - - # Add fields to the XML element, except for fields with `when` attribute - for field_name, field in model_fields.items(): - if field_name in exclude_subfields: - continue - field_element = create_xml_element_for_field(field, field_name) - element_.append(field_element) - - return element_ - - def add_validator( *fields: str, pre: bool = False, @@ -492,6 +231,14 @@ def process_validators(vals, fld): model_fields = {} for field_name, field in model.__fields__.items(): field_copy = deepcopy(field) + + if "validators" in field.field_info.extra and not isinstance( + field.field_info.extra["validators"], list + ): + field_copy.field_info.extra["validators"] = [ + field_copy.field_info.extra["validators"] + ] + process_validators(field.pre_validators, field_copy) process_validators(field.post_validators, field_copy) model_fields[field_name] = field_copy @@ -526,3 +273,136 @@ class BareModel(BaseModel): fn_params["description"] = json_schema["description"] return fn_params + + +def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]: + """Map a type_annotation to the name of the corresponding field type. + + This function checks if the type_annotation is a list, dict, or a + primitive type, and returns the corresponding type name, e.g. + "list", "object", "bool", "date", etc. + """ + + # FIXME: inaccessible datatypes: + # - Email + # - SQLCode + # - Percentage + + # Get the type annotation from the type_annotation + type_annotation = prepare_type_annotation(field) + + # Map the type annotation to the corresponding field type + if is_list(type_annotation): + return ListDataType + elif is_dict(type_annotation): + return ObjectDataType + elif type_annotation == bool: + return BooleanDataType + elif type_annotation == date: + return DateDataType + elif type_annotation == float: + return FloatDataType + elif type_annotation == int: + return IntegerDataType + elif type_annotation == str or typing.get_origin(type_annotation) == typing.Literal: + return StringDataType + elif type_annotation == time: + return TimeDataType + elif typing.get_origin(type_annotation) == Union: + return ChoiceDataType + else: + raise ValueError(f"Unsupported type: {type_annotation}") + + +T = typing.TypeVar("T", bound=DataType) + + +def convert_pydantic_model_to_datatype( + model_field: Union[ModelField, Type[BaseModel]], + datatype: Type[T] = ObjectDataType, + excluded_fields: Optional[typing.List[str]] = None, + name: Optional[str] = None, + strict: bool = False, +) -> T: + """Create an Object from a Pydantic model.""" + if excluded_fields is None: + excluded_fields = [] + + if isinstance(model_field, ModelField): + model = model_field.type_ + else: + model = model_field + + model_fields = add_pydantic_validators_as_guardrails_validators(model) + + children = {} + for field_name, field in model_fields.items(): + if field_name in excluded_fields: + continue + type_annotation = prepare_type_annotation(field) + target_datatype = field_to_datatype(field) + if target_datatype == ListDataType: + inner_type = get_args(type_annotation) + if len(inner_type) == 0: + # If the list is empty, we cannot infer the type of the elements + children[field_name] = ListDataType.from_pydantic_field( + field, strict=strict + ) + inner_type = inner_type[0] + if is_pydantic_base_model(inner_type): + child = convert_pydantic_model_to_datatype(inner_type) + else: + inner_target_datatype = field_to_datatype(inner_type) + child = inner_target_datatype.from_pydantic_field( + inner_type, strict=strict + ) + children[field_name] = ListDataType.from_pydantic_field( + field, children={"item": child}, strict=strict + ) + elif target_datatype == ChoiceDataType: + discriminator = field.discriminator_key or "discriminator" + choice_children = {} + for case in typing.get_args(field.type_): + case_discriminator_type = case.__fields__[discriminator].type_ + assert typing.get_origin(case_discriminator_type) is typing.Literal + assert len(typing.get_args(case_discriminator_type)) == 1 + discriminator_value = typing.get_args(case_discriminator_type)[0] + choice_children[ + discriminator_value + ] = convert_pydantic_model_to_datatype( + case, + datatype=CaseDataType, + name=discriminator_value, + strict=strict, + excluded_fields=[discriminator], + ) + children[field_name] = Choice.from_pydantic_field( + field, + children=choice_children, + strict=strict, + discriminator_key=discriminator, + ) + elif isinstance(field.type_, type) and issubclass(field.type_, BaseModel): + children[field_name] = convert_pydantic_model_to_datatype( + field, datatype=target_datatype, strict=strict + ) + else: + children[field_name] = target_datatype.from_pydantic_field( + field, strict=strict + ) + + if isinstance(model_field, ModelField): + return datatype.from_pydantic_field( + model_field, + children=children, + strict=strict, + ) + else: + format_attr = FormatAttr.from_validators([], ObjectDataType.tag, strict) + return datatype( + children=children, + format_attr=format_attr, + optional=False, + name=name, + description=None, + ) diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index abba4295f..574b47451 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -1,6 +1,6 @@ import inspect from collections import defaultdict -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union from pydantic import BaseModel, Field @@ -258,6 +258,10 @@ def to_xml_attrib(self): params = " ".join(validator_args) return f"{self.rail_alias}: {params}" + def get_args(self): + """Get the arguments for the validator.""" + return self._kwargs + def __call__(self, value): result = self.validate(value, {}) if isinstance(result, FailResult): @@ -273,3 +277,6 @@ def __eq__(self, other): if not isinstance(other, Validator): return False return self.to_prompt() == other.to_prompt() + + +ValidatorSpec = Union[Validator, Tuple[Union[Validator, str, Callable], str]] diff --git a/tests/integration_tests/test_assets/python_rail/compiled_prompt_1.txt b/tests/integration_tests/test_assets/python_rail/compiled_prompt_1.txt index 83d7de3a5..7d9f17df2 100644 --- a/tests/integration_tests/test_assets/python_rail/compiled_prompt_1.txt +++ b/tests/integration_tests/test_assets/python_rail/compiled_prompt_1.txt @@ -12,7 +12,7 @@ Given below is XML that describes the information to extract from this document