diff --git a/guardrails/constants.xml b/guardrails/constants.xml index 187e497b2..eb189dcd3 100644 --- a/guardrails/constants.xml +++ b/guardrails/constants.xml @@ -52,6 +52,12 @@ ${output_schema} ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`. + +${gr.json_suffix_without_examples} +Here's an example of the structure: +${json_example} + + Given below is XML that describes the information to extract from this document and the tags to extract it into. diff --git a/guardrails/datatypes.py b/guardrails/datatypes.py index 6ca697ee9..2cdc18756 100644 --- a/guardrails/datatypes.py +++ b/guardrails/datatypes.py @@ -71,6 +71,9 @@ def __init__( self.description = description self.optional = optional + def get_example(self): + raise NotImplementedError + @property def validators(self) -> TypedList: return self.validators_attr.validators @@ -188,6 +191,9 @@ class String(ScalarType): tag = "string" + def get_example(self): + return "string" + def from_str(self, s: str) -> Optional[str]: """Create a String from a string.""" return to_string(s) @@ -214,6 +220,9 @@ class Integer(ScalarType): tag = "integer" + def get_example(self): + return 1 + def from_str(self, s: str) -> Optional[int]: """Create an Integer from a string.""" return to_int(s) @@ -225,6 +234,9 @@ class Float(ScalarType): tag = "float" + def get_example(self): + return 1.5 + def from_str(self, s: str) -> Optional[float]: """Create a Float from a string.""" return to_float(s) @@ -236,6 +248,9 @@ class Boolean(ScalarType): tag = "bool" + def get_example(self): + return True + def from_str(self, s: Union[str, bool]) -> Optional[bool]: """Create a Boolean from a string.""" if s is None: @@ -273,6 +288,9 @@ def __init__( super().__init__(children, validators_attr, optional, name, description) self.date_format = None + def get_example(self): + return datetime.date.today() + def from_str(self, s: str) -> Optional[datetime.date]: """Create a Date from a string.""" if s is None: @@ -312,6 +330,9 @@ def __init__( self.time_format = "%H:%M:%S" super().__init__(children, validators_attr, optional, name, description) + def get_example(self): + return datetime.time() + def from_str(self, s: str) -> Optional[datetime.time]: """Create a Time from a string.""" if s is None: @@ -340,6 +361,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) deprecate_type(type(self)) + def get_example(self): + return "hello@example.com" + @deprecate_type @register_type("url") @@ -352,6 +376,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) deprecate_type(type(self)) + def get_example(self): + return "https://example.com" + @deprecate_type @register_type("pythoncode") @@ -364,6 +391,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) deprecate_type(type(self)) + def get_example(self): + return "print('hello world')" + @deprecate_type @register_type("sql") @@ -376,6 +406,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) deprecate_type(type(self)) + def get_example(self): + return "SELECT * FROM table" + @register_type("percentage") class Percentage(ScalarType): @@ -383,6 +416,9 @@ class Percentage(ScalarType): tag = "percentage" + def get_example(self): + return "20%" + @register_type("enum") class Enum(ScalarType): @@ -402,6 +438,9 @@ def __init__( super().__init__(children, validators_attr, optional, name, description) self.enum_values = enum_values + def get_example(self): + return self.enum_values[0] + def from_str(self, s: str) -> Optional[str]: """Create an Enum from a string.""" if s is None: @@ -434,6 +473,9 @@ class List(NonScalarType): tag = "list" + def get_example(self): + return [e.get_example() for e in self._children.values()] + def collect_validation( self, key: str, @@ -476,6 +518,9 @@ class Object(NonScalarType): tag = "object" + def get_example(self): + return {k: v.get_example() for k, v in self._children.items()} + def collect_validation( self, key: str, @@ -546,6 +591,14 @@ def __init__( super().__init__(children, validators_attr, optional, name, description) self.discriminator_key = discriminator_key + def get_example(self): + first_discriminator = list(self._children.keys())[0] + first_child = list(self._children.values())[0] + return { + self.discriminator_key: first_discriminator, + **first_child.get_example(), + } + @classmethod def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self: # grab `discriminator` attribute @@ -606,6 +659,9 @@ def __init__( ) -> None: super().__init__(children, validators_attr, optional, name, description) + def get_example(self): + return {k: v.get_example() for k, v in self._children.items()} + def collect_validation( self, key: str, diff --git a/guardrails/schema.py b/guardrails/schema.py index 348dc7dd3..814d526cb 100644 --- a/guardrails/schema.py +++ b/guardrails/schema.py @@ -219,7 +219,7 @@ def check_valid_reask_prompt(self, reask_prompt: Optional[str]) -> None: class JsonSchema(Schema): - reask_prompt_vars = {"previous_response", "output_schema"} + reask_prompt_vars = {"previous_response", "output_schema", "json_example"} def __init__( self, @@ -269,7 +269,7 @@ def get_reask_setup( if reask_prompt_template is None: reask_prompt_template = Prompt( constants["high_level_skeleton_reask_prompt"] - + constants["json_suffix_without_examples"] + + constants["json_suffix_with_structure_example"] ) # This is incorrect @@ -300,6 +300,10 @@ def get_reask_setup( ) pruned_tree_string = pruned_tree_schema.transpile() + json_example = json.dumps( + pruned_tree_schema.root_datatype.get_example(), + indent=2, + ) def reask_decoder(obj): decoded = {} @@ -317,6 +321,7 @@ def reask_decoder(obj): reask_value, indent=2, default=reask_decoder, ensure_ascii=False ), output_schema=pruned_tree_string, + json_example=json_example, **(prompt_params or {}), ) diff --git a/tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt b/tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt index e505b6699..3e73e6ca5 100644 --- a/tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt +++ b/tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt @@ -69,6 +69,7 @@ I was given the following JSON response, which had problems due to incorrect val Help me correct the incorrect values based on the given error messages. + Given below is XML that describes the information to extract from this document and the tags to extract it into. @@ -85,6 +86,18 @@ Given below is XML that describes the information to extract from this document ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`. +Here's an example of the structure: +{ + "fees": [ + { + "name": "string", + "explanation": "string", + "value": 1.5 + } + ], + "interest_rates": {} +} + Json Output: diff --git a/tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt b/tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt index 225e652ea..bdbbaaad6 100644 --- a/tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt +++ b/tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt @@ -13,6 +13,7 @@ I was given the following JSON response, which had problems due to incorrect val Help me correct the incorrect values based on the given error messages. + Given below is XML that describes the information to extract from this document and the tags to extract it into. @@ -23,3 +24,10 @@ Given below is XML that describes the information to extract from this document ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`. + +Here's an example of the structure: +{ + "name": "string", + "director": "string", + "release_year": 1 +} diff --git a/tests/unit_tests/utils/test_reask_utils.py b/tests/unit_tests/utils/test_reask_utils.py index f2d80aef9..989aa1d50 100644 --- a/tests/unit_tests/utils/test_reask_utils.py +++ b/tests/unit_tests/utils/test_reask_utils.py @@ -3,7 +3,6 @@ import pytest from lxml import etree as ET -from guardrails import Instructions, Prompt from guardrails.classes.history.iteration import Iteration from guardrails.datatypes import Object from guardrails.schema import JsonSchema @@ -443,10 +442,14 @@ def test_get_reask_prompt( Help me correct the incorrect values based on the given error messages. + Given below is XML that describes the information to extract from this document and the tags to extract it into. %s ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`. + +Here's an example of the structure: +%s """ # noqa: E501 expected_instructions = """ You are a helpful assistant only capable of communicating with valid JSON, and no other text. @@ -467,13 +470,15 @@ def test_get_reask_prompt( result_prompt, instructions, ) = output_schema.get_reask_setup(reasks, reask_json, False) + json_example = output_schema.root_datatype.get_example() - assert result_prompt == Prompt( + assert result_prompt.source == ( expected_result_template % ( json.dumps(reask_json, indent=2), expected_rail, + json.dumps(json_example, indent=2), ) ) - assert instructions == Instructions(expected_instructions) + assert instructions.source == expected_instructions