diff --git a/guardrails/datatypes.py b/guardrails/datatypes.py
index 54a60d317..196362d2c 100644
--- a/guardrails/datatypes.py
+++ b/guardrails/datatypes.py
@@ -3,20 +3,19 @@
import warnings
from dataclasses import dataclass
from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any, Dict, Iterable
+from typing import Any, Dict, Iterable
from typing import List as TypedList
-from typing import Optional, Type, TypeVar, Union
+from typing import Optional, Sequence, Type, TypeVar, Union
from dateutil.parser import parse
from lxml import etree as ET
+from pydantic.fields import ModelField
from typing_extensions import Self
+from guardrails.formatattr import FormatAttr
from guardrails.utils.casting_utils import to_float, to_int, to_string
from guardrails.utils.xml_utils import cast_xml_to_string
-from guardrails.validator_base import Validator
-
-if TYPE_CHECKING:
- from guardrails.schema import FormatAttr
+from guardrails.validator_base import Validator, ValidatorSpec
logger = logging.getLogger(__name__)
@@ -64,7 +63,7 @@ class DataType:
def __init__(
self,
children: Dict[str, Any],
- format_attr: "FormatAttr",
+ format_attr: FormatAttr,
optional: bool,
name: Optional[str],
description: Optional[str],
@@ -120,12 +119,10 @@ def set_children_from_xml(self, element: ET._Element):
@classmethod
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
- from guardrails.schema import FormatAttr
-
# TODO: don't want to pass strict through to DataType,
- # but need to pass it to FormatAttr.from_element
+ # but need to pass it to FormatAttr.from_xml
# how to handle this?
- format_attr = FormatAttr.from_element(element, cls.tag, strict)
+ format_attr = FormatAttr.from_xml(element, cls.tag, strict)
is_optional = element.attrib.get("required", "true") == "false"
@@ -141,6 +138,28 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
data_type.set_children_from_xml(element)
return data_type
+    @classmethod
+    def from_pydantic_field(
+        cls,
+        field: ModelField,
+        children: Optional[Dict[str, "DataType"]] = None,
+        strict: bool = False,
+        **kwargs,
+    ) -> Self:
+        """Build this data type from a pydantic model field.
+
+        Args:
+            field: The pydantic field to convert. Its name, description,
+                ``required`` flag and the ``validators`` entry of its extra
+                field info are read.
+            children: Child data types keyed by name; an empty dict is used
+                when omitted.
+            strict: Forwarded to ``FormatAttr.from_validators``.
+            **kwargs: Extra keyword arguments forwarded to the constructor.
+
+        Returns:
+            A new instance of ``cls``.
+        """
+        child_types = {} if children is None else children
+
+        # Validators are declared through the field's extra info; a field
+        # that declares none yields an empty list.
+        format_attr = FormatAttr.from_validators(
+            field.field_info.extra.get("validators", []), cls.tag, strict
+        )
+
+        # Only an explicit ``required is False`` marks the field as optional.
+        optional = field.required is False
+
+        return cls(
+            child_types,
+            format_attr,
+            optional,
+            field.name,
+            field.field_info.description,
+            **kwargs,
+        )
+
@property
def children(self) -> SimpleNamespace:
"""Return a SimpleNamespace of the children of this DataType."""
@@ -175,6 +194,7 @@ def deprecate_type(cls: type):
versions 0.3.0 and beyond. Use the pydantic 'str' primitive instead.""",
DeprecationWarning,
)
+ return cls
class ScalarType(DataType):
@@ -197,6 +217,21 @@ def from_str(self, s: str) -> Optional[str]:
"""Create a String from a string."""
return to_string(s)
+    @classmethod
+    def from_string_rail(
+        cls,
+        validators: Sequence[ValidatorSpec],
+        description: Optional[str] = None,
+        strict: bool = False,
+    ) -> Self:
+        """Create an instance of this type from a sequence of validator specs.
+
+        Args:
+            validators: Validator specs turned into the format attribute.
+            description: Optional description stored on the instance.
+            strict: Forwarded to ``FormatAttr.from_validators``.
+
+        Returns:
+            A required, unnamed instance of ``cls`` without children.
+        """
+        format_attr = FormatAttr.from_validators(validators, cls.tag, strict)
+        return cls(
+            children={},
+            format_attr=format_attr,
+            optional=False,
+            name=None,
+            description=description,
+        )
+
@register_type("integer")
class Integer(ScalarType):
@@ -326,6 +361,10 @@ class Email(ScalarType):
tag = "email"
+    def __init__(self, *args, **kwargs):
+        """Initialize via the parent class, then report the deprecation."""
+        super().__init__(*args, **kwargs)
+        # Runs on every instantiation, passing the instance's concrete class.
+        concrete_cls = type(self)
+        deprecate_type(concrete_cls)
+
@deprecate_type
@register_type("url")
@@ -334,6 +373,10 @@ class URL(ScalarType):
tag = "url"
+    def __init__(self, *args, **kwargs):
+        """Initialize via the parent class, then report the deprecation."""
+        super().__init__(*args, **kwargs)
+        # Runs on every instantiation, passing the instance's concrete class.
+        concrete_cls = type(self)
+        deprecate_type(concrete_cls)
+
@deprecate_type
@register_type("pythoncode")
@@ -342,6 +385,10 @@ class PythonCode(ScalarType):
tag = "pythoncode"
+    def __init__(self, *args, **kwargs):
+        """Initialize via the parent class, then report the deprecation."""
+        super().__init__(*args, **kwargs)
+        # Runs on every instantiation, passing the instance's concrete class.
+        concrete_cls = type(self)
+        deprecate_type(concrete_cls)
+
@deprecate_type
@register_type("sql")
@@ -350,6 +397,10 @@ class SQLCode(ScalarType):
tag = "sql"
+    def __init__(self, *args, **kwargs):
+        """Initialize via the parent class, then report the deprecation."""
+        super().__init__(*args, **kwargs)
+        # Runs on every instantiation, passing the instance's concrete class.
+        concrete_cls = type(self)
+        deprecate_type(concrete_cls)
+
@register_type("percentage")
class Percentage(ScalarType):
diff --git a/guardrails/formatattr.py b/guardrails/formatattr.py
new file mode 100644
index 000000000..9c0de5258
--- /dev/null
+++ b/guardrails/formatattr.py
@@ -0,0 +1,320 @@
+import re
+import warnings
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+
+import lxml.etree as ET
+import pydantic
+
+from guardrails.utils.xml_utils import cast_xml_to_string
+from guardrails.validator_base import Validator, ValidatorSpec
+
+
+class FormatAttr(pydantic.BaseModel):
+ """Class for parsing and manipulating the `format` attribute of an element.
+
+ The format attribute is a string that contains semi-colon separated
+ validators e.g. "valid-url; is-reachable". Each validator is itself either:
+ - the name of a parameter-less validator, e.g. "valid-url"
+ - the name of a validator with parameters, separated by a colon with a
+ space-separated list of parameters, e.g. "is-in: 1 2 3"
+
+ Parameters can either be written in plain text, or in python expressions
+ enclosed in curly braces. For example, the following are all valid:
+ - "is-in: 1 2 3"
+ - "is-in: {1} {2} {3}"
+ - "is-in: {1 + 2} {2 + 3} {3 + 4}"
+ """
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ # The format attribute string.
+ format: Optional[str]
+
+ # The on-fail handlers.
+ on_fail_handlers: Dict[str, str]
+
+ # The validator arguments.
+ validator_args: Mapping[str, Union[Dict[str, Any], List[Any]]]
+
+ # The validators.
+ validators: List[Validator]
+
+ # The unregistered validators.
+ unregistered_validators: List[str]
+
+    @property
+    def empty(self) -> bool:
+        """Whether no validators, registered or unregistered, are present."""
+        has_any = bool(self.validators or self.unregistered_validators)
+        return not has_any
+
+    @classmethod
+    def from_validators(
+        cls,
+        validators: Sequence[ValidatorSpec],
+        tag: str,
+        strict: bool = False,
+    ) -> "FormatAttr":
+        """Create a FormatAttr object from a sequence of validator specs.
+
+        Each spec is either a ``Validator`` instance, or a
+        ``(validator, on_fail)`` tuple whose first item is a ``Validator``
+        instance, a validator name, or a callable carrying a ``rail_alias``
+        attribute.
+
+        Args:
+            validators: The validator specs to convert.
+            tag: The tag of the element the validators are attached to.
+            strict: Forwarded to `get_validators`; if True, a validator that
+                is not registered for `tag` raises instead of warning.
+
+        Returns:
+            A FormatAttr object whose `format` string is None.
+
+        Raises:
+            ValueError: If a spec has none of the supported shapes, or if a
+                callable validator has no ``rail_alias`` attribute.
+        """
+        validators_with_args = {}
+        on_fails = {}
+        for val in validators:
+            # Normalize every spec to a (validator, on_fail) pair first, so
+            # the name/argument extraction below exists only once.
+            if isinstance(val, Validator):
+                # A bare Validator instance carries its own on-fail setting.
+                validator, on_fail = val, val.on_fail_descriptor
+            elif isinstance(val, tuple) and len(val) == 2:
+                validator, on_fail = val
+            else:
+                raise ValueError(
+                    f"Validator {val} must be a (validator, on_fail) tuple or "
+                    f"Validator class instance"
+                )
+
+            if isinstance(validator, Validator):
+                # Instances provide both their alias and their arguments.
+                validator_name = validator.rail_alias
+                validator_args = validator.get_args()
+            elif isinstance(validator, str):
+                # A plain string is the validator name; it takes no arguments.
+                validator_name = validator
+                validator_args = []
+            elif callable(validator):
+                if not hasattr(validator, "rail_alias"):
+                    # Not every callable defines __name__ (functools.partial
+                    # objects do not), so fall back to repr() instead of
+                    # failing with AttributeError while building the message.
+                    label = getattr(validator, "__name__", repr(validator))
+                    raise ValueError(
+                        f"Validator {label} must be registered with "
+                        f"the gd.register_validator decorator"
+                    )
+                validator_name = validator.rail_alias
+                validator_args = []
+            else:
+                raise ValueError(
+                    f"Validator tuple {val} must be a (validator, on_fail) tuple, "
+                    f"where the validator is a string or a callable"
+                )
+
+            # A later spec with the same name replaces an earlier one.
+            validators_with_args[validator_name] = validator_args
+            on_fails[validator_name] = on_fail
+
+        registered_validators, unregistered_validators = cls.get_validators(
+            validator_args=validators_with_args,
+            tag=tag,
+            on_fail_handlers=on_fails,
+            strict=strict,
+        )
+
+        return cls(
+            format=None,
+            on_fail_handlers=on_fails,
+            validator_args=validators_with_args,
+            validators=registered_validators,
+            unregistered_validators=unregistered_validators,
+        )
+
+    @classmethod
+    def from_xml(
+        cls, element: ET._Element, tag: str, strict: bool = False
+    ) -> "FormatAttr":
+        """Create a FormatAttr object from an XML element.
+
+        Args:
+            element (ET._Element): The XML element whose `format` and
+                `on-fail-*` attributes are read.
+            tag (str): The element tag, forwarded to `get_validators`.
+            strict (bool): Forwarded to `get_validators`.
+
+        Returns:
+            A FormatAttr object; an empty one when the element has no
+            `format` attribute.
+        """
+        raw_format = element.get("format")
+        if raw_format is None:
+            # Without a format attribute nothing is parsed at all, on-fail
+            # attributes included.
+            return cls(
+                format=None,
+                on_fail_handlers={},
+                validator_args={},
+                validators=[],
+                unregistered_validators=[],
+            )
+
+        parsed_args = cls.parse(raw_format)
+
+        # Collect `on-fail-<validator>` attributes, keyed by validator name.
+        prefix = "on-fail-"
+        handlers = {}
+        for attr_name, attr_value in element.attrib.items():
+            attr_name = cast_xml_to_string(attr_name)
+            if attr_name.startswith(prefix):
+                handlers[attr_name[len(prefix) :]] = attr_value
+
+        registered, unregistered = cls.get_validators(
+            validator_args=parsed_args,
+            tag=tag,
+            on_fail_handlers=handlers,
+            strict=strict,
+        )
+
+        return cls(
+            format=raw_format,
+            on_fail_handlers=handlers,
+            validator_args=parsed_args,
+            validators=registered,
+            unregistered_validators=unregistered,
+        )
+
+ @staticmethod
+ def parse_token(token: str) -> Tuple[str, List[Any]]:
+ """Parse a single token in the format attribute, and return the
+ validator name and the list of arguments.
+
+ Args:
+ token (str): The token to parse, one of the tokens returned by
+ `self.tokens`.
+
+ Returns:
+ A tuple of the validator name and the list of arguments.
+ """
+ validator_with_args = token.strip().split(":", 1)
+ if len(validator_with_args) == 1:
+ return validator_with_args[0].strip(), []
+
+ validator, args_token = validator_with_args
+
+ # Split using whitespace as a delimiter, but not if it is inside curly braces or
+ # single quotes.
+ pattern = re.compile(r"\s(?![^{}]*})|(? Dict[str, List[Any]]:
+ """Parse the format attribute into a dictionary of validators.
+
+ Returns:
+ A dictionary of validators, where the key is the validator name, and
+ the value is a list of arguments.
+ """
+ # Split the format attribute into tokens: each is a validator.
+ # Then, parse each token into a validator name and a list of parameters.
+ pattern = re.compile(r";(?![^{}]*})")
+ tokens = re.split(pattern, format_string)
+ tokens = list(filter(None, tokens))
+
+ validators = {}
+ for token in tokens:
+ # Parse the token into a validator name and a list of parameters.
+ validator_name, args = FormatAttr.parse_token(token)
+ validators[validator_name] = args
+
+ return validators
+
+    @staticmethod
+    def get_validators(
+        validator_args: Mapping[str, Union[Dict[str, Any], List[Any]]],
+        tag: str,
+        on_fail_handlers: Dict[str, str],
+        strict: bool = False,
+    ) -> Tuple[List[Validator], List[str]]:
+        """Instantiate the validators that are registered for an element.
+
+        For example, if the format attribute is "valid-url; is-reachable", and
+        "is-reachable" is not registered for this element, then only the
+        ValidUrl validator will be instantiated (with the arguments specified
+        in the format attribute, if any), and "is-reachable" is reported as
+        unregistered.
+
+        Args:
+            validator_args: Maps each validator name to its arguments, either
+                a list (passed positionally) or a dict (passed as keywords).
+            tag: The tag of the element the validators are attached to.
+            on_fail_handlers: Maps validator names to their on-fail handler;
+                validators without an entry get `on_fail=None`.
+            strict: If True, raise an error if a validator is not registered
+                for this element. If False, skip the validator and emit a
+                warning.
+
+        Returns:
+            A tuple of the instantiated validators and the names of the
+            validators that are not registered for this element.
+
+        Raises:
+            ValueError: If `strict` is True and a validator is not registered
+                for this element, or if a validator's arguments are neither a
+                list nor a dict.
+        """
+        from guardrails.validator_base import types_to_validators, validators_registry
+
+        _validators = []
+        _unregistered_validators = []
+        for validator_name, args in validator_args.items():
+            # Check if the validator is registered for this element.
+            # The validators that are not registered for this element are
+            # skipped (with an error or warning, depending on the value of
+            # `strict`) and only reported by name.
+            if validator_name not in types_to_validators[tag]:
+                if strict:
+                    raise ValueError(
+                        f"Validator {validator_name} is not valid for"
+                        f" element {tag}."
+                    )
+                else:
+                    warnings.warn(
+                        f"Validator {validator_name} is not valid for"
+                        f" element {tag}."
+                    )
+                    _unregistered_validators.append(validator_name)
+                continue
+
+            validator = validators_registry[validator_name]
+
+            # See if the formatter has an associated on_fail method.
+            on_fail = on_fail_handlers.get(validator_name, None)
+            # TODO(shreya): Load the on_fail method.
+            # This method should be loaded from an optional script given at the
+            # beginning of a rail file.
+
+            # Create the validator: list arguments are positional, dict
+            # arguments are keywords.
+            if isinstance(args, list):
+                v = validator(*args, on_fail=on_fail)
+            elif isinstance(args, dict):
+                v = validator(**args, on_fail=on_fail)
+            else:
+                raise ValueError(
+                    f"Validator {validator_name} has invalid arguments: {args}."
+                )
+            _validators.append(v)
+
+        return _validators, _unregistered_validators
+
+    def to_prompt(self, with_keywords: bool = True) -> str:
+        """Render the validators as a single string for use in prompting.
+
+        Each registered validator contributes the result of its own
+        `to_prompt(with_keywords)` call; unregistered validators contribute
+        their bare names. All pieces are joined with "; ", registered
+        validators first. For example, one registered validator rendering as
+        "valid-url" plus an unregistered "my-check" gives
+        "valid-url; my-check".
+
+        Returns:
+            The prompt string, or "" when there are no validators at all.
+        """
+        if self.empty:
+            return ""
+        registered_part = "; ".join(
+            [v.to_prompt(with_keywords) for v in self.validators]
+        )
+        unregistered_part = "; ".join(self.unregistered_validators)
+        # Skip whichever side is empty so no dangling separator is produced.
+        return "; ".join(
+            part for part in (registered_part, unregistered_part) if part
+        )
diff --git a/guardrails/guard.py b/guardrails/guard.py
index 99f6a72f5..f0156a650 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -8,6 +8,7 @@
Dict,
List,
Optional,
+ Sequence,
Tuple,
Type,
Union,
@@ -191,7 +192,7 @@ def from_pydantic(
@classmethod
def from_string(
cls,
- validators: List[Validator],
+ validators: Sequence[Validator],
description: Optional[str] = None,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
diff --git a/guardrails/rail.py b/guardrails/rail.py
index fcded5170..ca97c6fa4 100644
--- a/guardrails/rail.py
+++ b/guardrails/rail.py
@@ -1,20 +1,15 @@
"""Rail class."""
import warnings
from dataclasses import dataclass
-from typing import List, Optional, Type
+from typing import Optional, Sequence, Type
from lxml import etree as ET
-from lxml.etree import Element, SubElement
from pydantic import BaseModel
from guardrails.prompt import Instructions, Prompt
from guardrails.schema import JsonSchema, Schema, StringSchema
-from guardrails.utils.pydantic_utils import (
- attach_validators_to_element,
- create_xml_element_for_base_model,
-)
from guardrails.utils.xml_utils import cast_xml_to_string
-from guardrails.validators import Validator
+from guardrails.validator_base import ValidatorSpec
# TODO: Logging
XMLPARSER = ET.XMLParser(encoding="utf-8")
@@ -49,14 +44,20 @@ def from_pydantic(
reask_prompt: Optional[str] = None,
reask_instructions: Optional[str] = None,
):
- xml = generate_xml_code(
- output_class=output_class,
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
+ input_schema = None
+
+ output_schema = cls.load_json_schema_from_pydantic(
+ output_class,
+ reask_prompt_template=reask_prompt,
+ reask_instructions_template=reask_instructions,
+ )
+
+ return cls(
+ input_schema=input_schema,
+ output_schema=output_schema,
+ instructions=cls.load_instructions(instructions, output_schema),
+ prompt=cls.load_prompt(prompt, output_schema),
)
- return cls.from_xml(xml)
@classmethod
def from_file(cls, file_path: str) -> "Rail":
@@ -82,7 +83,7 @@ def from_xml(cls, xml: ET._Element):
# No input schema, so do no input checking.
input_schema = None
else:
- input_schema = cls.load_input_schema(raw_input_schema)
+ input_schema = cls.load_input_schema_from_xml(raw_input_schema)
# Load schema
raw_output_schema = xml.find("output")
@@ -97,7 +98,7 @@ def from_xml(cls, xml: ET._Element):
reask_instructions = xml.find("reask_instructions")
if reask_instructions is not None:
reask_instructions = reask_instructions.text
- output_schema = cls.load_output_schema(
+ output_schema = cls.load_output_schema_from_xml(
raw_output_schema,
reask_prompt=reask_prompt,
reask_instructions=reask_instructions,
@@ -108,14 +109,14 @@ def from_xml(cls, xml: ET._Element):
# prepended to the prompt.
instructions = xml.find("instructions")
if instructions is not None:
- instructions = cls.load_instructions(instructions, output_schema)
+ instructions = cls.load_instructions(instructions.text, output_schema)
# Load
prompt = xml.find("prompt")
if prompt is None:
warnings.warn("Prompt must be provided during __call__.")
else:
- prompt = cls.load_prompt(prompt, output_schema)
+ prompt = cls.load_prompt(prompt.text, output_schema)
# Get version
version = xml.attrib["version"]
@@ -132,37 +133,37 @@ def from_xml(cls, xml: ET._Element):
@classmethod
def from_string_validators(
cls,
- validators: List[Validator],
+ validators: Sequence[ValidatorSpec],
description: Optional[str] = None,
prompt: Optional[str] = None,
instructions: Optional[str] = None,
reask_prompt: Optional[str] = None,
reask_instructions: Optional[str] = None,
):
- xml = generate_xml_code(
- prompt=prompt,
- instructions=instructions,
- reask_prompt=reask_prompt,
- reask_instructions=reask_instructions,
- validators=validators,
+ input_schema = None
+
+ output_schema = cls.load_string_schema_from_string(
+ validators,
description=description,
+ reask_prompt_template=reask_prompt,
+ reask_instructions_template=reask_instructions,
)
- return cls.from_xml(xml)
- @staticmethod
- def load_schema(root: ET._Element) -> Schema:
- """Given the RAIL or