Skip to content

Commit

Permalink
Rail XML: Rename "format" attr to "validators" (#439)
Browse files Browse the repository at this point in the history
* Rail XML: Rename "format" attr to "validators"

* added note that `format` will be removed in 0.4.x

* test_rail: Check format deprecation resolves properly
irgolic authored Nov 21, 2023
1 parent 58d2316 commit d677d51
Showing 16 changed files with 100 additions and 66 deletions.
32 changes: 16 additions & 16 deletions guardrails/datatypes.py
Original file line number Diff line number Diff line change
@@ -11,10 +11,10 @@
from lxml import etree as ET
from typing_extensions import Self

from guardrails.formatattr import FormatAttr
from guardrails.utils.casting_utils import to_float, to_int, to_string
from guardrails.utils.xml_utils import cast_xml_to_string
from guardrails.validator_base import Validator, ValidatorSpec
from guardrails.validatorsattr import ValidatorsAttr

logger = logging.getLogger(__name__)

@@ -62,20 +62,20 @@ class DataType:
def __init__(
self,
children: Dict[str, Any],
format_attr: FormatAttr,
validators_attr: ValidatorsAttr,
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
self._children = children
self.format_attr = format_attr
self.validators_attr = validators_attr
self.name = name
self.description = description
self.optional = optional

@property
def validators(self) -> TypedList:
return self.format_attr.validators
return self.validators_attr.validators

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._children})"
@@ -119,9 +119,9 @@ def set_children_from_xml(self, element: ET._Element):
@classmethod
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
# TODO: don't want to pass strict through to DataType,
# but need to pass it to FormatAttr.from_xml
# but need to pass it to ValidatorsAttr.from_element
# how to handle this?
format_attr = FormatAttr.from_xml(element, cls.tag, strict)
validators_attr = ValidatorsAttr.from_xml(element, cls.tag, strict)

is_optional = element.attrib.get("required", "true") == "false"

@@ -133,7 +133,7 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
if description is not None:
description = cast_xml_to_string(description)

data_type = cls({}, format_attr, is_optional, name, description, **kwargs)
data_type = cls({}, validators_attr, is_optional, name, description, **kwargs)
data_type.set_children_from_xml(element)
return data_type

@@ -203,7 +203,7 @@ def from_string_rail(
) -> Self:
return cls(
children={},
format_attr=FormatAttr.from_validators(validators, cls.tag, strict),
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
optional=False,
name=None,
description=description,
@@ -267,12 +267,12 @@ class Date(ScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)
self.date_format = None

def from_str(self, s: str) -> Optional[datetime.date]:
@@ -306,13 +306,13 @@ class Time(ScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
self.time_format = "%H:%M:%S"
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)

def from_str(self, s: str) -> Optional[datetime.time]:
"""Create a Time from a string."""
@@ -486,13 +486,13 @@ class Choice(NonScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
discriminator_key: str,
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)
self.discriminator_key = discriminator_key

@classmethod
@@ -548,12 +548,12 @@ class Case(NonScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)

def collect_validation(
self,
8 changes: 4 additions & 4 deletions guardrails/schema.py
Original file line number Diff line number Diff line change
@@ -778,11 +778,11 @@ def transpile(self, method: str = "default") -> str:
"Here's a description of what I want you to generate: "
f"{obj.description}"
)
if not obj.format_attr.empty:
if not obj.validators_attr.empty:
schema += (
"\n\nYour generated response should satisfy the following properties:"
)
for validator in obj.format_attr.validators:
for validator in obj.validators_attr.validators:
schema += f"\n- {validator.to_prompt()}"

schema += "\n\nDon't talk; just go."
@@ -816,8 +816,8 @@ def datatypes_to_xml(
if dt.description:
el.attrib["description"] = dt.description

if dt.format_attr:
format_prompt = dt.format_attr.to_prompt()
if dt.validators_attr:
format_prompt = dt.validators_attr.to_prompt()
if format_prompt:
el.attrib["format"] = format_prompt

6 changes: 3 additions & 3 deletions guardrails/utils/pydantic_utils/v1.py
Original file line number Diff line number Diff line change
@@ -20,8 +20,8 @@
from guardrails.datatypes import Object as ObjectDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.formatattr import FormatAttr
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr


class ArbitraryModel(BaseModel):
@@ -426,5 +426,5 @@ def construct_datatype(
if validators is None:
validators = []

format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
return datatype(children, format_attr, optional, name, description, **kwargs)
validators_attr = ValidatorsAttr.from_validators(validators, datatype.tag, strict)
return datatype(children, validators_attr, optional, name, description, **kwargs)
6 changes: 3 additions & 3 deletions guardrails/utils/pydantic_utils/v2.py
Original file line number Diff line number Diff line change
@@ -21,8 +21,8 @@
from guardrails.datatypes import PythonCode as PythonCodeDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.formatattr import FormatAttr
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr

DataTypeT = TypeVar("DataTypeT", bound=DataType)

@@ -462,5 +462,5 @@ def construct_datatype(
if validators is None:
validators = []

format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
return datatype(children, format_attr, optional, name, description, **kwargs)
validators_attr = ValidatorsAttr.from_validators(validators, datatype.tag, strict)
return datatype(children, validators_attr, optional, name, description, **kwargs)
31 changes: 20 additions & 11 deletions guardrails/formatattr.py → guardrails/validatorsattr.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from guardrails.validator_base import Validator, ValidatorSpec


class FormatAttr(pydantic.BaseModel):
class ValidatorsAttr(pydantic.BaseModel):
"""Class for parsing and manipulating the `format` attribute of an element.
The format attribute is a string that contains semi-colon separated
@@ -29,7 +29,7 @@ class Config:
arbitrary_types_allowed = True

# The format attribute string.
format: Optional[str]
validators_spec: Optional[str]

# The on-fail handlers.
on_fail_handlers: Mapping[str, Union[str, Callable]]
@@ -115,7 +115,7 @@ def from_validators(
)

return cls(
format=None,
validators_spec=None,
on_fail_handlers=on_fails,
validator_args=validators_with_args,
validators=registered_validators,
@@ -125,26 +125,35 @@ def from_validators(
@classmethod
def from_xml(
cls, element: ET._Element, tag: str, strict: bool = False
) -> "FormatAttr":
"""Create a FormatAttr object from an XML element.
) -> "ValidatorsAttr":
"""Create a ValidatorsAttr object from an XML element.
Args:
element (ET._Element): The XML element.
Returns:
A FormatAttr object.
A ValidatorsAttr object.
"""
validators_str = element.get("validators")
format_str = element.get("format")
if format_str is None:
if format_str is not None:
warnings.warn(
"Attribute `format` is deprecated and will be removed in 0.4.x. "
"Use `validators` instead.",
DeprecationWarning,
)
validators_str = format_str

if validators_str is None:
return cls(
format=None,
validators_spec=None,
on_fail_handlers={},
validator_args={},
validators=[],
unregistered_validators=[],
)

validator_args = cls.parse(format_str)
validator_args = cls.parse(validators_str)

on_fail_handlers = {}
for key, value in element.attrib.items():
@@ -162,7 +171,7 @@ def from_xml(
)

return cls(
format=format_str,
validators_spec=validators_str,
on_fail_handlers=on_fail_handlers,
validator_args=validator_args,
validators=validators,
@@ -230,7 +239,7 @@ def parse(format_string: str) -> Dict[str, List[Any]]:
validators = {}
for token in tokens:
# Parse the token into a validator name and a list of parameters.
validator_name, args = FormatAttr.parse_token(token)
validator_name, args = ValidatorsAttr.parse_token(token)
validators[validator_name] = args

return validators
2 changes: 1 addition & 1 deletion tests/integration_tests/test_async.py
Original file line number Diff line number Diff line change
@@ -239,7 +239,7 @@ def string_rail_spec():
<rail version="0.1">
<output
type="string"
format="two-words"
validators="two-words"
on-fail-two-words="fix"
/>
<prompt>
2 changes: 1 addition & 1 deletion tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
<output>
<object name="patient_info">
<string name="gender" description="Patient's gender" />
<integer name="age" format="valid-range: 0 100" />
<integer name="age" validators="valid-range: 0 100" />
<string name="symptoms" description="Symptoms that the patient is currently experiencing" />
</object>
</output>
6 changes: 3 additions & 3 deletions tests/integration_tests/test_data_validation.py
Original file line number Diff line number Diff line change
@@ -33,11 +33,11 @@ def test_choice_validation(llm_output, raises):
<output>
<choice name="choice" on-fail-choice="exception" discriminator="action">
<case name="fight">
<string name="fight_move" format="valid-choices: {['punch','kick','headbutt']}" on-fail-valid-choices="exception" />
<string name="fight_move" validators="valid-choices: {['punch','kick','headbutt']}" on-fail-valid-choices="exception" />
</case>
<case name="flight">
<string name="flight_direction" format="valid-choices: {['north','south','east','west']}" on-fail-valid-choices="exception" />
<integer name="flight_speed" format="valid-choices: {[1,2,3,4]}" on-fail-valid-choices="exception" />
<string name="flight_direction" validators="valid-choices: {['north','south','east','west']}" on-fail-valid-choices="exception" />
<integer name="flight_speed" validators="valid-choices: {[1,2,3,4]}" on-fail-valid-choices="exception" />
</case>
</choice>
</output>
6 changes: 3 additions & 3 deletions tests/integration_tests/test_schema_to_prompt.py
Original file line number Diff line number Diff line change
@@ -13,20 +13,20 @@ def test_choice_schema():
<case name="fight">
<string
name="fight_move"
format="valid-choices: {['punch','kick','headbutt']}"
validators="valid-choices: {['punch','kick','headbutt']}"
on-fail-valid-choices="exception"
/>
</case>
<case name="flight">
<object name="flight">
<string
name="flight_direction"
format="valid-choices: {['north','south','east','west']}"
validators="valid-choices: {['north','south','east','west']}"
on-fail-valid-choices="exception"
/>
<integer
name="flight_speed"
format="valid-choices: {[1,2,3,4]}"
validators="valid-choices: {[1,2,3,4]}"
on-fail-valid-choices="exception"
/>
</object>
2 changes: 1 addition & 1 deletion tests/integration_tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ def embed_function(text: str):
# Check types remain intact
output_schema: StringSchema = guard.rail.output_schema
data_type: DataType = output_schema.root_datatype
validators = data_type.format_attr.validators
validators = data_type.validators_attr.validators
validator: SimilarToList = validators[0]

assert isinstance(validator._standard_deviations, int)
4 changes: 2 additions & 2 deletions tests/unit_tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from lxml.builder import E

import guardrails.datatypes as datatypes
from guardrails.formatattr import FormatAttr
from guardrails.validatorsattr import ValidatorsAttr


@pytest.mark.parametrize(
@@ -26,7 +26,7 @@
],
)
def test_get_args(input_string, expected):
_, args = FormatAttr.parse_token(input_string)
_, args = ValidatorsAttr.parse_token(input_string)
assert args == expected


8 changes: 4 additions & 4 deletions tests/unit_tests/test_guard.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ def validate(self, value, metadata):
"""
<rail version="0.1">
<output>
<string name="string_name" format="myrequiringvalidator" />
<string name="string_name" validators="myrequiringvalidator" />
</output>
</rail>
""",
@@ -43,10 +43,10 @@ def validate(self, value, metadata):
<rail version="0.1">
<output>
<object name="temp_name">
<string name="string_name" format="myrequiringvalidator" />
<string name="string_name" validators="myrequiringvalidator" />
</object>
<list name="list_name">
<string name="string_name" format="myrequiringvalidator2" />
<string name="string_name" validators="myrequiringvalidator2" />
</list>
</output>
</rail>
@@ -64,7 +64,7 @@ def validate(self, value, metadata):
<string name="string_name" />
</case>
<case name="hiya">
<string name="string_name" format="myrequiringvalidator" />
<string name="string_name" validators="myrequiringvalidator" />
</case>
</choice>
</list>
25 changes: 25 additions & 0 deletions tests/unit_tests/test_rail.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from guardrails.rail import Rail


@@ -147,3 +149,26 @@ def test_rail_list_with_object():
</rail>
"""
Rail.from_string(rail_spec)


def test_format_deprecated():
rail_spec = """
<rail version="0.1">
<output>
<string name="string_name" format="two-words"/>
</output>
<instructions>
Hello world
</instructions>
<prompt>
Hello world
</prompt>
</rail>
"""
with pytest.warns(DeprecationWarning):
rail = Rail.from_string(rail_spec)
validator = rail.output_schema.root_datatype.children.string_name.validators[0]
assert validator.rail_alias == "two-words"
8 changes: 4 additions & 4 deletions tests/unit_tests/test_reask_utils.py
Original file line number Diff line number Diff line change
@@ -20,14 +20,14 @@
empty_root = Element("root")
non_empty_root = Element("root")
property = SubElement(non_empty_root, "list", name="dummy")
property.attrib["format"] = "length: 2"
property.attrib["validators"] = "length: 2"
child = SubElement(property, "string")
child.attrib["format"] = "two-words"
child.attrib["validators"] = "two-words"
non_empty_output = Element("root")
output_property = SubElement(non_empty_output, "list", name="dummy")
output_property.attrib["format"] = "length: 2"
output_property.attrib["validators"] = "length: 2"
output_child = SubElement(output_property, "string")
output_child.attrib["format"] = "two-words"
output_child.attrib["validators"] = "two-words"


@pytest.mark.parametrize(
16 changes: 8 additions & 8 deletions tests/unit_tests/test_skeleton.py
Original file line number Diff line number Diff line change
@@ -87,19 +87,19 @@
<case name="fight">
<string
name="fight_move"
format="valid-choices: {['punch','kick','headbutt']}"
validators="valid-choices: {['punch','kick','headbutt']}"
on-fail-valid-choices="exception"
/>
</case>
<case name="flight">
<string
name="flight_direction"
format="valid-choices: {['north','south','east','west']}"
validators="valid-choices: {['north','south','east','west']}"
on-fail-valid-choices="exception"
/>
<integer
name="flight_speed"
format="valid-choices: {[1,2,3,4]}"
validators="valid-choices: {[1,2,3,4]}"
on-fail-valid-choices="exception"
/>
</case>
@@ -123,20 +123,20 @@
<case name="fight">
<list name="fight">
<string
format="valid-choices: {['punch','kick','headbutt']}"
validators="valid-choices: {['punch','kick','headbutt']}"
on-fail-valid-choices="exception"
/>
</list>
</case>
<case name="flight">
<string
name="flight_direction"
format="valid-choices: {['north','south','east','west']}"
validators="valid-choices: {['north','south','east','west']}"
on-fail-valid-choices="exception"
/>
<integer
name="flight_speed"
format="valid-choices: {[1,2,3,4]}"
validators="valid-choices: {[1,2,3,4]}"
on-fail-valid-choices="exception"
/>
</case>
@@ -169,15 +169,15 @@
<case name="fight">
<string
name="fight_move"
format="valid-choices: {['punch','kick','headbutt']}"
validators="valid-choices: {['punch','kick','headbutt']}"
on-fail-valid-choices="exception"
/>
</case>
<case name="flight">
<object name="flight">
<string
name="flight_direction"
format="valid-choices: {['north','south','east','west']}"
validators="valid-choices: {['north','south','east','west']}"
on-fail-valid-choices="exception"
/>
<integer
4 changes: 2 additions & 2 deletions tests/unit_tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -328,7 +328,7 @@ def test_custom_func_validator():
<rail version="0.1">
<output>
<string name="greeting"
format="mycustomhellovalidator"
validators="mycustomhellovalidator"
on-fail-mycustomhellovalidator="fix"/>
</output>
</rail>
@@ -391,7 +391,7 @@ def test_provenance_v1(mocker):

output_schema: StringSchema = string_guard.rail.output_schema
data_type: DataType = output_schema.root_datatype
validators = data_type.format_attr.validators
validators = data_type.validators_attr.validators
prov_validator: ProvenanceV1 = validators[0]

# Check types remain intact

0 comments on commit d677d51

Please sign in to comment.