Skip to content

Commit

Permalink
Merge pull request #413 from irgolic/remove-xml-from-pydantic-guard
Browse files Browse the repository at this point in the history
Remove xml from pydantic guard
  • Loading branch information
zsimjee authored Nov 13, 2023
2 parents 0504873 + c9ad1f7 commit a4e19bc
Show file tree
Hide file tree
Showing 14 changed files with 728 additions and 826 deletions.
73 changes: 62 additions & 11 deletions guardrails/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
import warnings
from dataclasses import dataclass
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Dict, Iterable
from typing import Any, Dict, Iterable
from typing import List as TypedList
from typing import Optional, Type, TypeVar, Union
from typing import Optional, Sequence, Type, TypeVar, Union

from dateutil.parser import parse
from lxml import etree as ET
from pydantic.fields import ModelField
from typing_extensions import Self

from guardrails.formatattr import FormatAttr
from guardrails.utils.casting_utils import to_float, to_int, to_string
from guardrails.utils.xml_utils import cast_xml_to_string
from guardrails.validator_base import Validator

if TYPE_CHECKING:
from guardrails.schema import FormatAttr
from guardrails.validator_base import Validator, ValidatorSpec

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,7 +63,7 @@ class DataType:
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
format_attr: FormatAttr,
optional: bool,
name: Optional[str],
description: Optional[str],
Expand Down Expand Up @@ -120,12 +119,10 @@ def set_children_from_xml(self, element: ET._Element):

@classmethod
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
from guardrails.schema import FormatAttr

# TODO: don't want to pass strict through to DataType,
# but need to pass it to FormatAttr.from_element
# but need to pass it to FormatAttr.from_xml
# how to handle this?
format_attr = FormatAttr.from_element(element, cls.tag, strict)
format_attr = FormatAttr.from_xml(element, cls.tag, strict)

is_optional = element.attrib.get("required", "true") == "false"

Expand All @@ -141,6 +138,28 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
data_type.set_children_from_xml(element)
return data_type

@classmethod
def from_pydantic_field(
cls,
field: ModelField,
children: Optional[Dict[str, "DataType"]] = None,
strict: bool = False,
**kwargs,
) -> Self:
if children is None:
children = {}

validators = field.field_info.extra.get("validators", [])
format_attr = FormatAttr.from_validators(validators, cls.tag, strict)

is_optional = field.required is False

name = field.name
description = field.field_info.description

data_type = cls(children, format_attr, is_optional, name, description, **kwargs)
return data_type

@property
def children(self) -> SimpleNamespace:
"""Return a SimpleNamespace of the children of this DataType."""
Expand Down Expand Up @@ -175,6 +194,7 @@ def deprecate_type(cls: type):
versions 0.3.0 and beyond. Use the pydantic 'str' primitive instead.""",
DeprecationWarning,
)
return cls


class ScalarType(DataType):
Expand All @@ -197,6 +217,21 @@ def from_str(self, s: str) -> Optional[str]:
"""Create a String from a string."""
return to_string(s)

@classmethod
def from_string_rail(
cls,
validators: Sequence[ValidatorSpec],
description: Optional[str] = None,
strict: bool = False,
) -> Self:
return cls(
children={},
format_attr=FormatAttr.from_validators(validators, cls.tag, strict),
optional=False,
name=None,
description=description,
)


@register_type("integer")
class Integer(ScalarType):
Expand Down Expand Up @@ -326,6 +361,10 @@ class Email(ScalarType):

tag = "email"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))


@deprecate_type
@register_type("url")
Expand All @@ -334,6 +373,10 @@ class URL(ScalarType):

tag = "url"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))


@deprecate_type
@register_type("pythoncode")
Expand All @@ -342,6 +385,10 @@ class PythonCode(ScalarType):

tag = "pythoncode"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))


@deprecate_type
@register_type("sql")
Expand All @@ -350,6 +397,10 @@ class SQLCode(ScalarType):

tag = "sql"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
deprecate_type(type(self))


@register_type("percentage")
class Percentage(ScalarType):
Expand Down
Loading

0 comments on commit a4e19bc

Please sign in to comment.