Skip to content

Commit

Permalink
Merge branch 'main' into feature/claude
Browse files Browse the repository at this point in the history
  • Loading branch information
irgolic committed Nov 22, 2023
2 parents 7781485 + f1f511c commit 9cae16f
Show file tree
Hide file tree
Showing 48 changed files with 1,050 additions and 473 deletions.
30 changes: 22 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
pydantic-version: ['1.10.9', '2.4.2']
openai-version: ['0.28.1', '1.2.4']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -72,16 +73,27 @@ jobs:
run: |
make full
poetry run pip install pydantic==${{ matrix.pydantic-version }}
poetry run pip install openai==${{ matrix.openai-version }}
- if: matrix.pydantic-version == '2.4.2'
name: Static analysis with pyright (ignoring pydantic v1)
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '0.28.1'
name: Static analysis with pyright (ignoring pydantic v1 and openai v1)
run: |
make type-pydantic-v2
make type-pydantic-v2-openai-v0
- if: matrix.pydantic-version == '1.10.9'
name: Static analysis with mypy (ignoring pydantic v2)
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '0.28.1'
name: Static analysis with mypy (ignoring pydantic v2 and openai v1)
run: |
make type-pydantic-v1
make type-pydantic-v1-openai-v0
- if: matrix.pydantic-version == '2.4.2' && matrix.openai-version == '1.2.4'
name: Static analysis with pyright (ignoring pydantic v1 and openai v0)
run: |
make type-pydantic-v2-openai-v1
- if: matrix.pydantic-version == '1.10.9' && matrix.openai-version == '1.2.4'
name: Static analysis with mypy (ignoring pydantic v2 and openai v0)
run: |
make type-pydantic-v1-openai-v1
Pytests:
runs-on: ubuntu-latest
Expand All @@ -92,6 +104,7 @@ jobs:
# dependencies: ['dev', 'full']
dependencies: ['full']
pydantic-version: ['1.10.9', '2.4.2']
openai-version: ['0.28.1', '1.2.4']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -103,15 +116,16 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}-${{ matrix.openai-version }}

- name: Install Poetry
uses: snok/install-poetry@v1

- name: Install Dependencies
run: |
make ${{ matrix.dependencies }}
python -m pip install pydantic==${{ matrix.pydantic-version }}
poetry run pip install pydantic==${{ matrix.pydantic-version }}
poetry run pip install openai==${{ matrix.openai-version }}
- name: Run Pytests
run: |
Expand Down
18 changes: 14 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,23 @@ autoformat:
type:
poetry run pyright guardrails/

type-pydantic-v1:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py"]}' > pyrightconfig.json
type-pydantic-v1-openai-v0:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v2:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py"]}' > pyrightconfig.json
type-pydantic-v1-openai-v1:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v2-openai-v0:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v1.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

type-pydantic-v2-openai-v1:
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py", "guardrails/utils/openai_utils/v0.py"]}' > pyrightconfig.json
poetry run pyright guardrails/
rm pyrightconfig.json

Expand Down
16 changes: 8 additions & 8 deletions guardrails/applications/text2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from string import Template
from typing import Callable, Dict, Optional, Type

import openai

from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
from guardrails.guard import Guard
from guardrails.utils.openai_utils import get_static_openai_create_func
from guardrails.utils.sql_utils import create_sql_driver
from guardrails.vectordb import Faiss, VectorDBBase

Expand Down Expand Up @@ -71,7 +70,7 @@ def __init__(
rail_params: Optional[Dict] = None,
example_formatter: Callable = example_formatter,
reask_prompt: str = REASK_PROMPT,
llm_api: Callable = openai.Completion.create,
llm_api: Optional[Callable] = None,
llm_api_kwargs: Optional[Dict] = None,
num_relevant_examples: int = 2,
):
Expand All @@ -88,6 +87,8 @@ def __init__(
example_formatter: Fn to format examples. Defaults to example_formatter.
reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
"""
if llm_api is None:
llm_api = get_static_openai_create_func()

self.example_formatter = example_formatter
self.llm_api = llm_api
Expand Down Expand Up @@ -185,9 +186,10 @@ def __call__(self, text: str) -> Optional[str]:
"Async API is not supported in Text2SQL application. "
"Please use a synchronous API."
)

if self.llm_api is None:
return None
try:
output = self.guard(
return self.guard(
self.llm_api,
prompt_params={
"nl_instruction": text,
Expand All @@ -201,6 +203,4 @@ def __call__(self, text: str) -> Optional[str]:
"generated_sql"
]
except TypeError:
output = None

return output
return None
32 changes: 16 additions & 16 deletions guardrails/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from lxml import etree as ET
from typing_extensions import Self

from guardrails.formatattr import FormatAttr
from guardrails.utils.casting_utils import to_float, to_int, to_string
from guardrails.utils.xml_utils import cast_xml_to_string
from guardrails.validator_base import Validator, ValidatorSpec
from guardrails.validatorsattr import ValidatorsAttr

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,20 +62,20 @@ class DataType:
def __init__(
self,
children: Dict[str, Any],
format_attr: FormatAttr,
validators_attr: ValidatorsAttr,
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
self._children = children
self.format_attr = format_attr
self.validators_attr = validators_attr
self.name = name
self.description = description
self.optional = optional

@property
def validators(self) -> TypedList:
return self.format_attr.validators
return self.validators_attr.validators

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._children})"
Expand Down Expand Up @@ -119,9 +119,9 @@ def set_children_from_xml(self, element: ET._Element):
@classmethod
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
# TODO: don't want to pass strict through to DataType,
# but need to pass it to FormatAttr.from_xml
# but need to pass it to ValidatorsAttr.from_element
# how to handle this?
format_attr = FormatAttr.from_xml(element, cls.tag, strict)
validators_attr = ValidatorsAttr.from_xml(element, cls.tag, strict)

is_optional = element.attrib.get("required", "true") == "false"

Expand All @@ -133,7 +133,7 @@ def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
if description is not None:
description = cast_xml_to_string(description)

data_type = cls({}, format_attr, is_optional, name, description, **kwargs)
data_type = cls({}, validators_attr, is_optional, name, description, **kwargs)
data_type.set_children_from_xml(element)
return data_type

Expand Down Expand Up @@ -203,7 +203,7 @@ def from_string_rail(
) -> Self:
return cls(
children={},
format_attr=FormatAttr.from_validators(validators, cls.tag, strict),
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
optional=False,
name=None,
description=description,
Expand Down Expand Up @@ -267,12 +267,12 @@ class Date(ScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)
self.date_format = None

def from_str(self, s: str) -> Optional[datetime.date]:
Expand Down Expand Up @@ -306,13 +306,13 @@ class Time(ScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
self.time_format = "%H:%M:%S"
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)

def from_str(self, s: str) -> Optional[datetime.time]:
"""Create a Time from a string."""
Expand Down Expand Up @@ -486,13 +486,13 @@ class Choice(NonScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
discriminator_key: str,
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)
self.discriminator_key = discriminator_key

@classmethod
Expand Down Expand Up @@ -548,12 +548,12 @@ class Case(NonScalarType):
def __init__(
self,
children: Dict[str, Any],
format_attr: "FormatAttr",
validators_attr: "ValidatorsAttr",
optional: bool,
name: Optional[str],
description: Optional[str],
) -> None:
super().__init__(children, format_attr, optional, name, description)
super().__init__(children, validators_attr, optional, name, description)

def collect_validation(
self,
Expand Down
22 changes: 10 additions & 12 deletions guardrails/embedding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import islice
from typing import Callable, List, Optional

import openai
from guardrails.utils.openai_utils import OpenAIClient


class EmbeddingBase(ABC):
Expand Down Expand Up @@ -114,9 +113,9 @@ def output_dim(self) -> int:
class OpenAIEmbedding(EmbeddingBase):
def __init__(
self,
model: Optional[str] = "text-embedding-ada-002",
encoding_name: Optional[str] = "cl100k_base",
max_tokens: Optional[int] = 8191,
model: str = "text-embedding-ada-002",
encoding_name: str = "cl100k_base",
max_tokens: int = 8191,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
):
Expand All @@ -137,15 +136,14 @@ def embed_query(self, query: str) -> List[float]:
return resp[0]

def _get_embedding(self, texts: List[str]) -> List[List[float]]:
api_key = (
self.api_key
if self.api_key is not None
else os.environ.get("OPENAI_API_KEY")
client = OpenAIClient(
api_key=self.api_key,
api_base=self.api_base,
)
resp = openai.Embedding.create(
api_key=api_key, model=self._model, input=texts, api_base=self.api_base
return client.create_embedding(
model=self._model,
input=texts,
)
return [r["embedding"] for r in resp["data"]] # type: ignore

@property
def output_dim(self) -> int:
Expand Down
Loading

0 comments on commit 9cae16f

Please sign in to comment.