Skip to content

Commit

Permalink
Merge pull request #616 from guardrails-ai/fix-v2-func-call
Browse files Browse the repository at this point in the history
fix function calling schema for pydantic v2
CalebCourier authored Mar 14, 2024
2 parents a8c46a5 + 82500da commit 2aeb9dd
Showing 7 changed files with 363 additions and 46 deletions.
49 changes: 34 additions & 15 deletions guardrails/llm_providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import asyncio
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Union, cast
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
Union,
cast,
)

from guardrails_api_client.models.validate_payload_llm_api import ValidatePayloadLlmApi
from pydantic import BaseModel
@@ -154,7 +165,9 @@ def _invoke_llm(
model: str = "gpt-3.5-turbo",
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
base_model: Optional[BaseModel] = None,
base_model: Optional[
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
function_call: Optional[Any] = None,
*args,
**kwargs,
@@ -184,13 +197,15 @@ def _invoke_llm(
)

# Configure function calling if applicable (only for non-streaming)
fn_kwargs = {}
if base_model and not kwargs.get("stream", False):
function_params = [convert_pydantic_model_to_openai_fn(base_model)]
if function_call is None:
function_call = {"name": function_params[0]["name"]}
fn_kwargs = {"functions": function_params, "function_call": function_call}
else:
fn_kwargs = {}
function_params = convert_pydantic_model_to_openai_fn(base_model)
if function_call is None and function_params:
function_call = {"name": function_params["name"]}
fn_kwargs = {
"functions": [function_params],
"function_call": function_call,
}

# Call OpenAI
if "api_key" in kwargs:
@@ -688,7 +703,9 @@ async def invoke_llm(
model: str = "gpt-3.5-turbo",
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
base_model: Optional[BaseModel] = None,
base_model: Optional[
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
function_call: Optional[Any] = None,
*args,
**kwargs,
@@ -718,13 +735,15 @@ async def invoke_llm(
)

# Configure function calling if applicable
fn_kwargs = {}
if base_model:
function_params = [convert_pydantic_model_to_openai_fn(base_model)]
if function_call is None:
function_call = {"name": function_params[0]["name"]}
fn_kwargs = {"functions": function_params, "function_call": function_call}
else:
fn_kwargs = {}
function_params = convert_pydantic_model_to_openai_fn(base_model)
if function_call is None and function_params:
function_call = {"name": function_params["name"]}
fn_kwargs = {
"functions": [function_params],
"function_call": function_call,
}

# Call OpenAI
if "api_key" in kwargs:
11 changes: 11 additions & 0 deletions guardrails/utils/dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pydantic.version

PYDANTIC_VERSION = pydantic.version.VERSION

# Select the `dataclass` decorator this module exports, based on the
# installed pydantic major version.
if not PYDANTIC_VERSION.startswith("1"):
    # Anything other than pydantic 1.x: re-export the standard-library decorator.
    from dataclasses import dataclass  # type: ignore # noqa
else:

    def dataclass(cls):  # type: ignore
        """Return ``cls`` unchanged; no-op stand-in used with pydantic 1.x."""
        return cls
138 changes: 119 additions & 19 deletions guardrails/utils/pydantic_utils/v1.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,17 @@
from copy import deepcopy
from datetime import date, time
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, validator
from pydantic.fields import ModelField
@@ -22,6 +32,7 @@
from guardrails.datatypes import Object as ObjectDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr

@@ -228,7 +239,78 @@ def process_validators(vals, fld):
return model_fields


def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
def create_bare_model(model: Type[BaseModel]):
    """Build a new pydantic model class that reuses ``model``'s annotations.

    Only ``model.__annotations__`` is carried over (an empty dict is used when
    the attribute is missing); nothing else is copied from ``model`` here.

    Args:
        model: The pydantic model class whose annotations are reused.

    Returns:
        A freshly defined ``BareModel`` class deriving from ``BaseModel``.
    """
    # Resolved once here; the class body below reads it from the enclosing scope.
    model_annotations = getattr(model, "__annotations__", {})

    class BareModel(BaseModel):
        __annotations__ = model_annotations

    return BareModel


def reduce_to_annotations(type_annotation: Any) -> Type[Any]:
    """Swap a pydantic model class for its bare equivalent.

    Args:
        type_annotation: Any type annotation (or other value).

    Returns:
        ``create_bare_model(type_annotation)`` when the argument is a truthy
        class that subclasses ``BaseModel``; otherwise the argument unchanged.
    """
    # `isinstance(..., type)` is checked before `issubclass` because
    # `issubclass` raises TypeError when its first argument is not a class.
    is_model_class = (
        type_annotation
        and isinstance(type_annotation, type)
        and issubclass(type_annotation, BaseModel)
    )
    if not is_model_class:
        return type_annotation
    return create_bare_model(type_annotation)


def find_models_in_type(type_annotation: Any) -> Type[Any]:
    """Rebuild a type annotation with nested pydantic models made bare.

    Recurses through ``Union``, ``List`` and ``Dict`` annotations and passes
    every leaf type through ``reduce_to_annotations``.

    Args:
        type_annotation: The annotation to walk.

    Returns:
        The rebuilt annotation. ``Union``/``List``/``Dict`` annotations are
        returned wrapped in ``Type[...]``; any other annotation is returned as
        ``reduce_to_annotations`` produces it.

    Raises:
        ValueError: If a ``List`` annotation has more than one type argument.
    """
    type_origin = get_origin(type_annotation)
    inner_types = get_args(type_annotation)

    if type_origin == Union:
        data_types = tuple(find_models_in_type(t) for t in inner_types)
        return Type[Union[data_types]]  # type: ignore

    if type_origin == list:
        if len(inner_types) > 1:
            raise ValueError("List data type must have exactly one child.")
        # No List[List] support; we've already declared that in our types
        item_type = safe_get(inner_types, 0)
        return Type[List[find_models_in_type(item_type)]]

    if type_origin == dict:
        # First arg is the key type, which must be primitive, so it is kept
        # as-is. It lives at index 0; reading index 1 here (as before) reused
        # the value type as the key type.
        key_type = safe_get(inner_types, 0)
        # Second arg is the value type, which is potentially a model.
        value_type = safe_get(inner_types, 1)
        return Type[Dict[key_type, find_models_in_type(value_type)]]

    return reduce_to_annotations(type_annotation)


def schema_to_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
    """Prepare a pydantic (v1) model's fields for schema generation.

    For every field, any ``validators`` entry in the field's extras is replaced
    by a ``format`` entry holding ``to_prompt()`` of each validator that has
    that method, and the field's annotation is rewritten with
    ``find_models_in_type``.

    NOTE(review): ``copy.deepcopy`` returns classes unchanged, so the edits
    below appear to land on ``model`` itself rather than on a copy - confirm
    whether that is intended.

    Args:
        model: The pydantic model class to process.

    Returns:
        The result of ``deepcopy(model)`` with its fields updated in place.
    """
    bare = deepcopy(model)
    model_fields = bare.__fields__

    for name in model_fields:
        model_field = model_fields.get(name)
        if not model_field:
            continue

        field_extras = model_field.field_info.extra
        if "validators" in field_extras:
            # Pop first, then store: `validators` is removed before `format`
            # is written.
            validators = field_extras.pop("validators", [])
            field_extras["format"] = [
                v.to_prompt() for v in validators if hasattr(v, "to_prompt")
            ]
        model_field.field_info.extra = field_extras

        model_field.annotation = find_models_in_type(model_field.annotation)
        model_fields[name] = model_field

    return bare


def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.
Args:
@@ -237,23 +319,41 @@ def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
Returns:
OpenAI function parameters.
"""

# Create a bare model with no extra fields
class BareModel(BaseModel):
__annotations__ = model.__annotations__

# Convert Pydantic model to JSON schema
json_schema = BareModel.schema()

# Create OpenAI function parameters
fn_params = {
"name": json_schema["title"],
"parameters": json_schema,
}
if "description" in json_schema and json_schema["description"] is not None:
fn_params["description"] = json_schema["description"]

return fn_params
return {}

# schema_model = model

# type_origin = get_origin(model)
# if type_origin == list:
# item_types = get_args(model)
# if len(item_types) > 1:
# raise ValueError("List data type must have exactly one child.")
# # No List[List] support; we've already declared that in our types
# schema_model = safe_get(item_types, 0)

# # Create a bare model with no extra fields
# bare_model = schema_to_bare_model(schema_model)

# # Convert Pydantic model to JSON schema
# json_schema = bare_model.schema()
# json_schema["title"] = schema_model.__name__

# if type_origin == list:
# json_schema = {
# "title": f"Array<{json_schema.get('title')}>",
# "type": "array",
# "items": json_schema,
# }

# # Create OpenAI function parameters
# fn_params = {
# "name": json_schema["title"],
# "parameters": json_schema,
# }
# if "description" in json_schema and json_schema["description"] is not None:
# fn_params["description"] = json_schema["description"]

# return fn_params


def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]:
48 changes: 37 additions & 11 deletions guardrails/utils/pydantic_utils/v2.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,19 @@
from copy import deepcopy
from datetime import date, time
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.fields import FieldInfo
@@ -21,6 +33,7 @@
from guardrails.datatypes import Object as ObjectDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr

@@ -114,14 +127,9 @@ def is_enum(type_annotation: Any) -> bool:
return False


def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
class BareModel(BaseModel):
__annotations__ = getattr(model, "__annotations__", {})

return BareModel


def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.
Args:
@@ -131,10 +139,28 @@ def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
OpenAI function parameters.
"""

bare_model = _create_bare_model(type(model))
schema_model = model

type_origin = get_origin(model)
if type_origin == list:
item_types = get_args(model)
if len(item_types) > 1:
raise ValueError("List data type must have exactly one child.")
# No List[List] support; we've already declared that in our types
schema_model = safe_get(item_types, 0)

schema_model = cast(Type[BaseModel], schema_model)

# Convert Pydantic model to JSON schema
json_schema = bare_model.model_json_schema()
json_schema = schema_model.model_json_schema()
json_schema["title"] = schema_model.__name__

if type_origin == list:
json_schema = {
"title": f"Array<{json_schema.get('title')}>",
"type": "array",
"items": json_schema,
}

# Create OpenAI function parameters
fn_params = {
4 changes: 3 additions & 1 deletion guardrails/validator_base.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
from guardrails.classes import InputType
from guardrails.constants import hub
from guardrails.errors import ValidationError
from guardrails.utils.dataclass import dataclass

VALIDATOR_IMPORT_WARNING = """Accessing `{validator_name}` using
`from guardrails.validators import {validator_name}` is deprecated and
@@ -372,10 +373,11 @@ class FailResult(ValidationResult):
fix_value: Optional[Any] = None


@dataclass # type: ignore
class Validator(Runnable):
"""Base class for validators."""

rail_alias: str
rail_alias: str = ""

run_in_separate_process = False
override_value_on_pass = False
84 changes: 84 additions & 0 deletions tests/unit_tests/utils/pydantic_utils/v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from copy import deepcopy
from typing import List
from warnings import warn

import pydantic.version
import pytest
from pydantic import BaseModel, Field

from guardrails.utils.pydantic_utils.v1 import convert_pydantic_model_to_openai_fn

PYDANTIC_VERSION = pydantic.version.VERSION


# Minimal single-field model passed to convert_pydantic_model_to_openai_fn
# by the tests below.
class Foo(BaseModel):
    bar: str = Field(description="some string value")


# fmt: off
# Reference JSON schema for `Foo`; the tests below work on deep copies of it.
foo_schema = {
    "title": "Foo",
    "type": "object",
    "properties": {
        "bar": {
            "title": "Bar",
            "description": "some string value",
            "type": "string"
        }
    },
    "required": [
        "bar"
    ]
}
# fmt: on


# This test is descriptive, not prescriptive.
@pytest.mark.skipif(
    not PYDANTIC_VERSION.startswith("1"),
    reason="Tests function calling syntax for Pydantic v1",
)
class TestConvertPydanticModelToOpenaiFn:
    """Pin the current (disabled) function-calling output under pydantic 1.x."""

    def test_object_schema(self):
        """A single model class currently converts to an empty dict."""
        bare_schema = deepcopy(foo_schema)
        # When pushed through BareModel it loses the description on any
        # properties.
        bare_schema["properties"]["bar"].pop("description")

        # Kept to document the output of the conversion that is currently
        # commented out; unused while the assertion below stays disabled.
        expected_fn_params = {"name": "Foo", "parameters": bare_schema}  # noqa

        actual_fn_params = convert_pydantic_model_to_openai_fn(Foo)

        # assert actual_fn_params == expected_fn_params
        warn("Function calling is disabled for pydantic 1.x")
        assert actual_fn_params == {}

    def test_list_schema(self):
        """A ``List[Model]`` annotation currently converts to an empty dict."""
        item_schema = deepcopy(foo_schema)
        # When pushed through BareModel it loses the description on any
        # properties.
        item_schema["properties"]["bar"].pop("description")

        array_schema = {
            "title": f"Array<{item_schema.get('title')}>",
            "type": "array",
            "items": item_schema,
        }

        # Kept to document the output of the conversion that is currently
        # commented out; unused while the assertion below stays disabled.
        expected_fn_params = {  # noqa
            "name": "Array<Foo>",
            "parameters": array_schema,
        }

        actual_fn_params = convert_pydantic_model_to_openai_fn(List[Foo])

        # assert actual_fn_params == expected_fn_params
        warn("Function calling is disabled for pydantic 1.x")
        assert actual_fn_params == {}
75 changes: 75 additions & 0 deletions tests/unit_tests/utils/pydantic_utils/v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from copy import deepcopy
from typing import List

import pydantic.version
import pytest
from pydantic import BaseModel, Field

from guardrails.utils.pydantic_utils.v2 import convert_pydantic_model_to_openai_fn

PYDANTIC_VERSION = pydantic.version.VERSION


# Minimal single-field model passed to convert_pydantic_model_to_openai_fn
# by the tests below.
class Foo(BaseModel):
    bar: str = Field(description="some string value")


# fmt: off
# Reference JSON schema for `Foo`; the tests below work on deep copies of it.
foo_schema = {
    "title": "Foo",
    "type": "object",
    "properties": {
        "bar": {
            "title": "Bar",
            "description": "some string value",
            "type": "string"
        }
    },
    "required": [
        "bar"
    ]
}
# fmt: on


# This test is descriptive, not prescriptive.
@pytest.mark.skipif(
    not PYDANTIC_VERSION.startswith("2"),
    reason="Tests function calling syntax for Pydantic v2",
)
class TestConvertPydanticModelToOpenaiFn:
    """Pin the OpenAI function parameters produced under pydantic 2.x."""

    def test_object_schema(self):
        """A model class yields a function named after the model."""
        object_schema = deepcopy(foo_schema)
        expected_fn_params = {"name": "Foo", "parameters": object_schema}

        actual_fn_params = convert_pydantic_model_to_openai_fn(Foo)

        assert actual_fn_params == expected_fn_params

    def test_list_schema(self):
        """A ``List[Model]`` annotation yields an array schema of the model."""
        item_schema = deepcopy(foo_schema)
        array_schema = {
            "title": f"Array<{item_schema.get('title')}>",
            "type": "array",
            "items": item_schema,
        }
        expected_fn_params = {"name": "Array<Foo>", "parameters": array_schema}

        actual_fn_params = convert_pydantic_model_to_openai_fn(List[Foo])

        assert actual_fn_params == expected_fn_params

0 comments on commit 2aeb9dd

Please sign in to comment.