diff --git a/guardrails/validators.py b/guardrails/validators.py
new file mode 100644
index 000000000..0ae1bcdf2
--- /dev/null
+++ b/guardrails/validators.py
@@ -0,0 +1,2542 @@
+"""This module contains the validators for the Guardrails framework.
+
+The name with which a validator is registered is the name that is used
+in the `RAIL` spec to specify formatters.
+"""
+import ast
+import contextvars
+import inspect
+import itertools
+import logging
+import os
+import re
+import string
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+import rstr
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+
+from guardrails.utils.casting_utils import to_int
+from guardrails.utils.docs_utils import get_chunks_from_text, sentence_split
+from guardrails.utils.openai_utils import (
+ OpenAIClient,
+ get_static_openai_chat_create_func,
+)
+from guardrails.utils.sql_utils import SQLDriver, create_sql_driver
+from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
+from guardrails.validator_base import (
+ FailResult,
+ PassResult,
+ ValidationResult,
+ Validator,
+ register_validator,
+)
+
+try:
+ import numpy as np
+except ImportError:
+ _HAS_NUMPY = False
+else:
+ _HAS_NUMPY = True
+
+try:
+ import detect_secrets # type: ignore
+except ImportError:
+ detect_secrets = None
+
+try:
+ from presidio_analyzer import AnalyzerEngine
+ from presidio_anonymizer import AnonymizerEngine
+except ImportError:
+ AnalyzerEngine = None
+ AnonymizerEngine = None
+
+try:
+ import nltk # type: ignore
+except ImportError:
+ nltk = None # type: ignore
+
+if nltk is not None:
+ try:
+ nltk.data.find("tokenizers/punkt")
+ except LookupError:
+ nltk.download("punkt")
+
+try:
+ import spacy
+except ImportError:
+ spacy = None
+
+
+logger = logging.getLogger(__name__)
+
+
+@register_validator(name="pydantic_field_validator", data_type="all")
+class PydanticFieldValidator(Validator):
+ """Validates a specific field in a Pydantic model with the specified
+ validator method.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `pydantic_field_validator` |
+ | Supported data types | `Any` |
+ | Programmatic fix | Override with return value from `field_validator`. |
+
+ Parameters: Arguments
+
+ field_validator (Callable): A validator for a specific field in a Pydantic model.
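+
+    Example: a minimal sketch of standalone use; `name_must_be_title` is a
+    hypothetical field validator, not part of this module.
+
+    ```py
+    def name_must_be_title(value: str) -> str:
+        if value != value.title():
+            raise ValueError("name must be title case")
+        return value
+
+    validator = PydanticFieldValidator(field_validator=name_must_be_title)
+    validator.validate("guardrails", {})  # FailResult
+    validator.validate("Guardrails", {})  # PassResult overriding the value
+    ```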
+ """ # noqa
+
+ override_value_on_pass = True
+
+ def __init__(
+ self,
+ field_validator: Callable,
+ on_fail: Optional[Callable[..., Any]] = None,
+ **kwargs,
+ ):
+ self.field_validator = field_validator
+ super().__init__(on_fail, **kwargs)
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ try:
+ validated_field = self.field_validator(value)
+ except Exception as e:
+ return FailResult(
+ error_message=str(e),
+ fix_value=None,
+ )
+ return PassResult(
+ value_override=validated_field,
+ )
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+        return getattr(self.field_validator, "__name__", str(self.field_validator))
+
+
+@register_validator(name="valid-range", data_type=["integer", "float", "percentage"])
+class ValidRange(Validator):
+ """Validates that a value is within a range.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `valid-range` |
+ | Supported data types | `integer`, `float`, `percentage` |
+ | Programmatic fix | Closest value within the range. |
+
+ Parameters: Arguments
+ min: The inclusive minimum value of the range.
+ max: The inclusive maximum value of the range.
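+
+    Example: a minimal sketch of direct use (validators normally run inside a
+    `Guard`; this is for illustration only).
+
+    ```py
+    validator = ValidRange(min=0, max=100)
+    validator.validate(150, {})  # FailResult with fix_value=100
+    validator.validate(42, {})   # PassResult
+    ```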
+ """
+
+ def __init__(
+ self,
+ min: Optional[int] = None,
+ max: Optional[int] = None,
+ on_fail: Optional[Callable] = None,
+ ):
+ super().__init__(on_fail=on_fail, min=min, max=max)
+
+ self._min = min
+ self._max = max
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ """Validates that a value is within a range."""
+ logger.debug(f"Validating {value} is in range {self._min} - {self._max}...")
+
+ val_type = type(value)
+
+ if self._min is not None and value < val_type(self._min):
+ return FailResult(
+ error_message=f"Value {value} is less than {self._min}.",
+ fix_value=self._min,
+ )
+
+ if self._max is not None and value > val_type(self._max):
+ return FailResult(
+ error_message=f"Value {value} is greater than {self._max}.",
+ fix_value=self._max,
+ )
+
+ return PassResult()
+
+
+@register_validator(name="valid-choices", data_type="all")
+class ValidChoices(Validator):
+ """Validates that a value is within the acceptable choices.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `valid-choices` |
+ | Supported data types | `all` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ choices: The list of valid choices.
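+
+    Example: a minimal sketch of direct use.
+
+    ```py
+    validator = ValidChoices(choices=["cat", "dog"])
+    validator.validate("bird", {})  # FailResult
+    validator.validate("cat", {})   # PassResult
+    ```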
+ """
+
+ def __init__(self, choices: List[Any], on_fail: Optional[Callable] = None):
+ super().__init__(on_fail=on_fail, choices=choices)
+ self._choices = choices
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ """Validates that a value is within a range."""
+ logger.debug(f"Validating {value} is in choices {self._choices}...")
+
+ if value not in self._choices:
+ return FailResult(
+ error_message=f"Value {value} is not in choices {self._choices}.",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="lower-case", data_type="string")
+class LowerCase(Validator):
+ """Validates that a value is lower case.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `lower-case` |
+ | Supported data types | `string` |
+ | Programmatic fix | Convert to lower case. |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is lower case...")
+
+ if value.lower() != value:
+ return FailResult(
+ error_message=f"Value {value} is not lower case.",
+ fix_value=value.lower(),
+ )
+
+ return PassResult()
+
+
+@register_validator(name="upper-case", data_type="string")
+class UpperCase(Validator):
+ """Validates that a value is upper case.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `upper-case` |
+ | Supported data types | `string` |
+ | Programmatic fix | Convert to upper case. |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is upper case...")
+
+ if value.upper() != value:
+ return FailResult(
+ error_message=f"Value {value} is not upper case.",
+ fix_value=value.upper(),
+ )
+
+ return PassResult()
+
+
+@register_validator(name="length", data_type=["string", "list"])
+class ValidLength(Validator):
+ """Validates that the length of value is within the expected range.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `length` |
+    | Supported data types          | `string`, `list`                  |
+ | Programmatic fix | If shorter than the minimum, pad with empty last elements. If longer than the maximum, truncate. |
+
+ Parameters: Arguments
+ min: The inclusive minimum length.
+ max: The inclusive maximum length.
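+
+    Example: a minimal sketch of direct use; the fix pads with the last
+    element (or truncates) to bring the value into range.
+
+    ```py
+    validator = ValidLength(min=3, max=6)
+    validator.validate("ab", {})        # FailResult with fix_value="abb"
+    validator.validate("abcdefgh", {})  # FailResult with fix_value="abcdef"
+    ```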
+ """ # noqa
+
+ def __init__(
+ self,
+ min: Optional[int] = None,
+ max: Optional[int] = None,
+ on_fail: Optional[Callable] = None,
+ ):
+ super().__init__(on_fail=on_fail, min=min, max=max)
+ self._min = to_int(min)
+ self._max = to_int(max)
+
+ def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
+ """Validates that the length of value is within the expected range."""
+ logger.debug(
+ f"Validating {value} is in length range {self._min} - {self._max}..."
+ )
+
+ if self._min is not None and len(value) < self._min:
+ logger.debug(f"Value {value} is less than {self._min}.")
+
+            # Repeat the last element to pad the value to the minimum length.
+            if isinstance(value, str):
+                if not value:
+                    last_val = rstr.rstr(string.ascii_lowercase, 1)
+                else:
+                    last_val = value[-1]
+                corrected_value = value + last_val * (self._min - len(value))
+            else:
+                if not value:
+                    last_val = rstr.rstr(string.ascii_lowercase, 1)
+                else:
+                    last_val = value[-1]
+                # list.extend returns None, so build a new padded list instead.
+                corrected_value = value + [last_val] * (self._min - len(value))
+
+ return FailResult(
+ error_message=f"Value has length less than {self._min}. "
+ f"Please return a longer output, "
+ f"that is shorter than {self._max} characters.",
+ fix_value=corrected_value,
+ )
+
+ if self._max is not None and len(value) > self._max:
+ logger.debug(f"Value {value} is greater than {self._max}.")
+ return FailResult(
+ error_message=f"Value has length greater than {self._max}. "
+ f"Please return a shorter output, "
+ f"that is shorter than {self._max} characters.",
+ fix_value=value[: self._max],
+ )
+
+ return PassResult()
+
+
+@register_validator(name="regex_match", data_type="string")
+class RegexMatch(Validator):
+ """Validates that a value matches a regular expression.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `regex_match` |
+ | Supported data types | `string` |
+ | Programmatic fix | Generate a string that matches the regular expression |
+
+ Parameters: Arguments
+ regex: Str regex pattern
+ match_type: Str in {"search", "fullmatch"} for a regex search or full-match option
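+
+    Example: a minimal sketch of direct use. The fix value is generated with
+    `rstr.xeger`, so it is random but guaranteed to match the pattern.
+
+    ```py
+    validator = RegexMatch(regex=r"[0-9]{3}-[0-9]{4}", match_type="fullmatch")
+    validator.validate("555-0123", {})  # PassResult
+    validator.validate("hello", {})     # FailResult with a matching fix_value
+    ```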
+ """ # noqa
+
+ def __init__(
+ self,
+ regex: str,
+ match_type: Optional[str] = None,
+ on_fail: Optional[Callable] = None,
+ ):
+ # todo -> something forces this to be passed as kwargs and therefore xml-ized.
+ # match_types = ["fullmatch", "search"]
+
+ if match_type is None:
+ match_type = "fullmatch"
+ assert match_type in [
+ "fullmatch",
+ "search",
+ ], 'match_type must be in ["fullmatch", "search"]'
+
+ super().__init__(on_fail=on_fail, match_type=match_type, regex=regex)
+ self._regex = regex
+ self._match_type = match_type
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+        """Validates that value matches the provided regular expression."""
+        p = re.compile(self._regex)
+        # Pad the generated fix string on either side, e.g. when performing
+        # a regex search rather than a full match.
+ str_padding = (
+ "" if self._match_type == "fullmatch" else rstr.rstr(string.ascii_lowercase)
+ )
+ self._fix_str = str_padding + rstr.xeger(self._regex) + str_padding
+
+ if not getattr(p, self._match_type)(value):
+ return FailResult(
+ error_message=f"Result must match {self._regex}",
+ fix_value=self._fix_str,
+ )
+ return PassResult()
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+ return "results should match " + self._regex
+
+
+@register_validator(name="two-words", data_type="string")
+class TwoWords(Validator):
+ """Validates that a value is two words.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `two-words` |
+ | Supported data types | `string` |
+ | Programmatic fix | Pick the first two words. |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is two words...")
+
+ if len(value.split()) != 2:
+ return FailResult(
+ error_message="must be exactly two words",
+ fix_value=" ".join(value.split()[:2]),
+ )
+
+ return PassResult()
+
+
+@register_validator(name="one-line", data_type="string")
+class OneLine(Validator):
+ """Validates that a value is a single line or sentence.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `one-line` |
+ | Supported data types | `string` |
+ | Programmatic fix | Pick the first line. |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is a single line...")
+
+ if len(value.splitlines()) > 1:
+ return FailResult(
+ error_message=f"Value {value} is not a single line.",
+ fix_value=value.splitlines()[0],
+ )
+
+ return PassResult()
+
+
+@register_validator(name="valid-url", data_type=["string"])
+class ValidURL(Validator):
+ """Validates that a value is a valid URL.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `valid-url` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is a valid URL...")
+
+ from urllib.parse import urlparse
+
+ # Check that the URL is valid
+ try:
+ result = urlparse(value)
+ # Check that the URL has a scheme and network location
+ if not result.scheme or not result.netloc:
+ return FailResult(
+ error_message=f"URL {value} is not valid.",
+ )
+ except ValueError:
+ return FailResult(
+ error_message=f"URL {value} is not valid.",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="is-reachable", data_type=["string"])
+class EndpointIsReachable(Validator):
+ """Validates that a value is a reachable URL.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `is-reachable` |
+    | Supported data types          | `string`                          |
+ | Programmatic fix | None |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is a valid URL...")
+
+ import requests
+
+ # Check that the URL exists and can be reached
+ try:
+ response = requests.get(value)
+ if response.status_code != 200:
+ return FailResult(
+ error_message=f"URL {value} returned "
+ f"status code {response.status_code}",
+ )
+ except requests.exceptions.ConnectionError:
+ return FailResult(
+ error_message=f"URL {value} could not be reached",
+ )
+ except requests.exceptions.InvalidSchema:
+ return FailResult(
+ error_message=f"URL {value} does not specify "
+ f"a valid connection adapter",
+ )
+ except requests.exceptions.MissingSchema:
+ return FailResult(
+ error_message=f"URL {value} does not contain " f"a http schema",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="bug-free-python", data_type="string")
+class BugFreePython(Validator):
+ """Validates that there are no Python syntactic bugs in the generated code.
+
+    This validator checks for syntax errors by running `ast.parse(code)`,
+    and fails validation if parsing raises a `SyntaxError`. The code is
+    only parsed, never executed.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `bug-free-python` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
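+
+    Example: a minimal sketch of direct use.
+
+    ```py
+    validator = BugFreePython()
+    validator.validate("x = 1", {})      # PassResult
+    validator.validate("def f(:", {})    # FailResult (syntax error)
+    ```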
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is not a bug...")
+
+ # The value is a Python code snippet. We need to check for syntax errors.
+ try:
+ ast.parse(value)
+ except SyntaxError as e:
+ return FailResult(
+ error_message=f"Syntax error: {e.msg}",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="bug-free-sql", data_type=["string"])
+class BugFreeSQL(Validator):
+ """Validates that there are no SQL syntactic bugs in the generated code.
+
+    This is a very minimal implementation that uses the PyPI `sqlvalidator` package
+ to check if the SQL query is valid. You can implement a custom SQL validator
+ that uses a database connection to check if the query is valid.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `bug-free-sql` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+ """
+
+ def __init__(
+ self,
+ conn: Optional[str] = None,
+ schema_file: Optional[str] = None,
+ on_fail: Optional[Callable] = None,
+ ):
+ super().__init__(on_fail=on_fail)
+ self._driver: SQLDriver = create_sql_driver(schema_file=schema_file, conn=conn)
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ errors = self._driver.validate_sql(value)
+ if len(errors) > 0:
+ return FailResult(
+ error_message=". ".join(errors),
+ )
+
+ return PassResult()
+
+
+@register_validator(name="sql-column-presence", data_type="string")
+class SqlColumnPresence(Validator):
+ """Validates that all columns in the SQL query are present in the schema.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `sql-column-presence` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ cols: The list of valid columns.
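+
+    Example: a minimal sketch of direct use; column names in the query are
+    checked against the allowed list via `sqlglot`.
+
+    ```py
+    validator = SqlColumnPresence(cols=["id", "name"])
+    validator.validate("SELECT id, email FROM users", {})  # FailResult (email)
+    ```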
+ """
+
+ def __init__(self, cols: List[str], on_fail: Optional[Callable] = None):
+ super().__init__(on_fail=on_fail, cols=cols)
+ self._cols = set(cols)
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ from sqlglot import exp, parse
+
+ expressions = parse(value)
+ cols = set()
+ for expression in expressions:
+ if expression is None:
+ continue
+ for col in expression.find_all(exp.Column):
+ cols.add(col.alias_or_name)
+
+ diff = cols.difference(self._cols)
+ if len(diff) > 0:
+ return FailResult(
+ error_message=f"Columns [{', '.join(diff)}] "
+ f"not in [{', '.join(self._cols)}]",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="exclude-sql-predicates", data_type="string")
+class ExcludeSqlPredicates(Validator):
+ """Validates that the SQL query does not contain certain predicates.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `exclude-sql-predicates` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ predicates: The list of predicates to avoid.
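+
+    Example: a minimal sketch of direct use; each predicate must name a
+    `sqlglot.exp` expression class, e.g. `Drop`.
+
+    ```py
+    validator = ExcludeSqlPredicates(predicates=["Drop"])
+    validator.validate("DROP TABLE users", {})     # FailResult
+    validator.validate("SELECT * FROM users", {})  # PassResult
+    ```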
+ """
+
+ def __init__(self, predicates: List[str], on_fail: Optional[Callable] = None):
+ super().__init__(on_fail=on_fail, predicates=predicates)
+ self._predicates = set(predicates)
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ from sqlglot import exp, parse
+
+ expressions = parse(value)
+ for expression in expressions:
+ if expression is None:
+ continue
+ for pred in self._predicates:
+ try:
+ getattr(exp, pred)
+ except AttributeError:
+ raise ValueError(f"Predicate {pred} does not exist")
+ if len(list(expression.find_all(getattr(exp, pred)))):
+ return FailResult(
+ error_message=f"SQL query contains predicate {pred}",
+ fix_value="",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="similar-to-document", data_type="string")
+class SimilarToDocument(Validator):
+ """Validates that a value is similar to the document.
+
+ This validator checks if the value is similar to the document by checking
+ the cosine similarity between the value and the document, using an
+ embedding.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `similar-to-document` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ document: The document to use for the similarity check.
+ threshold: The minimum cosine similarity to be considered similar. Defaults to 0.7.
+ model: The embedding model to use. Defaults to text-embedding-ada-002.
+ """ # noqa
+
+ def __init__(
+ self,
+ document: str,
+ threshold: float = 0.7,
+ model: str = "text-embedding-ada-002",
+ on_fail: Optional[Callable] = None,
+ ):
+ super().__init__(
+ on_fail=on_fail, document=document, threshold=threshold, model=model
+ )
+ if not _HAS_NUMPY:
+ raise ImportError(
+ f"The {self.__class__.__name__} validator requires the numpy package.\n"
+ "`poetry add numpy` to install it."
+ )
+
+ self.client = OpenAIClient()
+
+ self._document = document
+ embedding_response = self.client.create_embedding(input=[document], model=model)
+ embedding = embedding_response[0] # type: ignore
+ self._document_embedding = np.array(embedding)
+ self._model = model
+ self._threshold = float(threshold)
+
+ @staticmethod
+ def cosine_similarity(a: "np.ndarray", b: "np.ndarray") -> float:
+ """Calculate the cosine similarity between two vectors.
+
+ Args:
+ a: The first vector.
+ b: The second vector.
+
+ Returns:
+ float: The cosine similarity between the two vectors.
+ """
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} is similar to document...")
+
+ embedding_response = self.client.create_embedding(
+ input=[value], model=self._model
+ )
+
+ value_embedding = np.array(embedding_response[0]) # type: ignore
+
+ similarity = self.cosine_similarity(
+ self._document_embedding,
+ value_embedding,
+ )
+ if similarity < self._threshold:
+ return FailResult(
+ error_message=f"Value {value} is not similar enough "
+ f"to document {self._document}.",
+ )
+
+ return PassResult()
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+ return ""
+
+
+@register_validator(name="is-profanity-free", data_type="string")
+class IsProfanityFree(Validator):
+ """Validates that a translated text does not contain profanity language.
+
+ This validator uses the `alt-profanity-check` package to check if a string
+ contains profanity language.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `is-profanity-free` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+ """
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ try:
+ from profanity_check import predict # type: ignore
+ except ImportError:
+ raise ImportError(
+ "`is-profanity-free` validator requires the `alt-profanity-check`"
+ "package. Please install it with `poetry add profanity-check`."
+ )
+
+ prediction = predict([value])
+ if prediction[0] == 1:
+ return FailResult(
+ error_message=f"{value} contains profanity. "
+ f"Please return a profanity-free output.",
+ fix_value="",
+ )
+ return PassResult()
+
+
+@register_validator(name="is-high-quality-translation", data_type="string")
+class IsHighQualityTranslation(Validator):
+ """Using inpiredco.critique to check if a translation is high quality.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `is-high-quality-translation` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Other parameters: Metadata
+ translation_source (str): The source of the translation.
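+
+    Example: a minimal sketch of passing the required metadata through
+    `guard.parse` (hypothetical values).
+
+    ```py
+    guard.parse(
+        llm_output="Bonjour le monde",
+        metadata={"translation_source": "Hello world"},
+    )
+    ```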
+ """
+
+ required_metadata_keys = ["translation_source"]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ try:
+ from inspiredco.critique import Critique # type: ignore
+
+ inspiredco_api_key = os.environ.get("INSPIREDCO_API_KEY")
+ if not inspiredco_api_key:
+ raise ValueError(
+ "The INSPIREDCO_API_KEY environment variable must be set"
+ "in order to use the is-high-quality-translation validator!"
+ )
+
+ self._critique = Critique(api_key=inspiredco_api_key)
+
+ except ImportError:
+ raise ImportError(
+ "`is-high-quality-translation` validator requires the `inspiredco`"
+ "package. Please install it with `poetry add inspiredco`."
+ )
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ if "translation_source" not in metadata:
+ raise RuntimeError(
+ "is-high-quality-translation validator expects "
+ "`translation_source` key in metadata"
+ )
+ src = metadata["translation_source"]
+ prediction = self._critique.evaluate(
+ metric="comet",
+ config={"model": "unbabel_comet/wmt21-comet-qe-da"},
+ dataset=[{"source": src, "target": value}],
+ )
+ quality = prediction["examples"][0]["value"]
+ if quality < -0.1:
+ return FailResult(
+ error_message=f"{value} is a low quality translation."
+ "Please return a higher quality output.",
+ fix_value="",
+ )
+ return PassResult()
+
+
+@register_validator(name="ends-with", data_type="list")
+class EndsWith(Validator):
+ """Validates that a list ends with a given value.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `ends-with` |
+ | Supported data types | `list` |
+ | Programmatic fix | Append the given value to the list. |
+
+ Parameters: Arguments
+ end: The required last element.
+ """
+
+ def __init__(self, end: str, on_fail: str = "fix"):
+ super().__init__(on_fail=on_fail, end=end)
+ self._end = end
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(f"Validating {value} ends with {self._end}...")
+
+        if not value or value[-1] != self._end:
+ return FailResult(
+ error_message=f"{value} must end with {self._end}",
+ fix_value=value + [self._end],
+ )
+
+ return PassResult()
+
+
+@register_validator(name="extracted-summary-sentences-match", data_type="string")
+class ExtractedSummarySentencesMatch(Validator):
+ """Validates that the extracted summary sentences match the original text
+ by performing a cosine similarity in the embedding space.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `extracted-summary-sentences-match` |
+ | Supported data types | `string` |
+    | Programmatic fix              | Remove any sentences that cannot be verified. |
+
+ Parameters: Arguments
+
+    threshold: The minimum cosine similarity to be considered similar. Defaults to 0.7.
+
+ Other parameters: Metadata
+
+ filepaths (List[str]): A list of strings that specifies the filepaths for any documents that should be used for asserting the summary's similarity.
+ document_store (DocumentStoreBase, optional): The document store to use during validation. Defaults to EphemeralDocumentStore.
+ vector_db (VectorDBBase, optional): A vector database to use for embeddings. Defaults to Faiss.
+        embedding_model (EmbeddingBase, optional): The embedding model to use. Defaults to OpenAIEmbedding.
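+
+    Example: a minimal sketch of supplying the required metadata
+    (`./source.txt` is a hypothetical file path).
+
+    ```py
+    guard = Guard.from_rail(...)
+    guard(
+        openai.ChatCompletion.create,
+        prompt_params={...},
+        metadata={"filepaths": ["./source.txt"]},
+    )
+    ```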
+ """ # noqa
+
+ required_metadata_keys = ["filepaths"]
+
+ def __init__(
+ self,
+ threshold: float = 0.7,
+ on_fail: Optional[Callable] = None,
+ **kwargs: Optional[Dict[str, Any]],
+ ):
+ super().__init__(on_fail, **kwargs)
+ # TODO(shreya): Pass embedding_model, vector_db, document_store from spec
+
+ self._threshold = float(threshold)
+
+ @staticmethod
+ def _instantiate_store(
+ metadata, api_key: Optional[str] = None, api_base: Optional[str] = None
+ ):
+ if "document_store" in metadata:
+ return metadata["document_store"]
+
+ from guardrails.document_store import EphemeralDocumentStore
+
+ if "vector_db" in metadata:
+ vector_db = metadata["vector_db"]
+ else:
+ from guardrails.vectordb import Faiss
+
+ if "embedding_model" in metadata:
+ embedding_model = metadata["embedding_model"]
+ else:
+ from guardrails.embedding import OpenAIEmbedding
+
+ embedding_model = OpenAIEmbedding(api_key=api_key, api_base=api_base)
+
+ vector_db = Faiss.new_flat_ip_index(
+ embedding_model.output_dim, embedder=embedding_model
+ )
+
+ return EphemeralDocumentStore(vector_db)
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ if "filepaths" not in metadata:
+ raise RuntimeError(
+ "extracted-sentences-summary-match validator expects "
+ "`filepaths` key in metadata"
+ )
+ filepaths = metadata["filepaths"]
+
+ kwargs = {}
+ context_copy = contextvars.copy_context()
+ for key, context_var in context_copy.items():
+ if key.name == "kwargs" and isinstance(kwargs, dict):
+ kwargs = context_var
+ break
+
+ api_key = kwargs.get("api_key")
+ api_base = kwargs.get("api_base")
+
+ store = self._instantiate_store(metadata, api_key, api_base)
+
+ sources = []
+ for filepath in filepaths:
+ with open(filepath) as f:
+ doc = f.read()
+ store.add_text(doc, {"path": filepath})
+ sources.append(filepath)
+
+ # Split the value into sentences.
+ sentences = re.split(r"(?<=[.!?]) +", value)
+
+ # Check if any of the sentences in the value match any of the sentences
+ # in the documents.
+ unverified = []
+ verified = []
+ citations = {}
+ for id_, sentence in enumerate(sentences):
+ page = store.search_with_threshold(sentence, self._threshold)
+ if not page or page[0].metadata["path"] not in sources:
+ unverified.append(sentence)
+ else:
+ sentence_id = id_ + 1
+ citation_path = page[0].metadata["path"]
+ citation_id = sources.index(citation_path) + 1
+
+ citations[sentence_id] = citation_id
+ verified.append(sentence + f" [{citation_id}]")
+
+ fixed_summary = (
+ " ".join(verified)
+ + "\n\n"
+ + "\n".join(f"[{i + 1}] {s}" for i, s in enumerate(sources))
+ )
+ metadata["summary_with_citations"] = fixed_summary
+ metadata["citations"] = citations
+
+ if unverified:
+ unverified_sentences = "\n".join(unverified)
+ return FailResult(
+ metadata=metadata,
+ error_message=(
+ f"The summary \nSummary: {value}\n has sentences\n"
+ f"{unverified_sentences}\n that are not similar to any document."
+ ),
+ fix_value=fixed_summary,
+ )
+
+ return PassResult(metadata=metadata)
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+ return ""
+
+
+@register_validator(name="reading-time", data_type="string")
+class ReadingTime(Validator):
+ """Validates that the a string can be read in less than a certain amount of
+ time.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `reading-time` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+
+        reading_time: The maximum reading time, in seconds.
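+
+    Example: a minimal sketch of direct use, assuming a reading speed of
+    200 words per minute.
+
+    ```py
+    validator = ReadingTime(reading_time=60)  # allow up to ~60 seconds
+    validator.validate("a short sentence", {})  # PassResult
+    ```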
+ """
+
+ def __init__(self, reading_time: int, on_fail: str = "fix"):
+ super().__init__(on_fail=on_fail, reading_time=reading_time)
+ self._max_time = reading_time
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ logger.debug(
+ f"Validating {value} can be read in less than {self._max_time} seconds..."
+ )
+
+        # Estimate the reading time in seconds, assuming 200 words per minute.
+        reading_time = len(value.split()) / 200 * 60
+        logger.debug(f"Estimated reading time {reading_time} seconds...")
+
+        if reading_time > self._max_time:
+            logger.error(f"{value} took {reading_time} seconds to read")
+            return FailResult(
+                error_message=f"String should be readable "
+                f"within {self._max_time} seconds.",
+ fix_value=value,
+ )
+
+ return PassResult()
+
+
+@register_validator(name="extractive-summary", data_type="string")
+class ExtractiveSummary(Validator):
+ """Validates that a string is a valid extractive summary of a given
+ document.
+
+ This validator does a fuzzy match between the sentences in the
+ summary and the sentences in the document. Each sentence in the
+ summary must be similar to at least one sentence in the document.
+ After the validation, the summary is updated to include the
+ sentences from the document that were matched, and the citations for
+ those sentences are added to the end of the summary.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `extractive-summary` |
+ | Supported data types | `string` |
+    | Programmatic fix              | Remove any sentences that cannot be verified. |
+
+ Parameters: Arguments
+
+ threshold: The minimum fuzz ratio to be considered summarized. Defaults to 85.
+
+ Other parameters: Metadata
+
+ filepaths (List[str]): A list of strings that specifies the filepaths for any documents that should be used for asserting the summary's similarity.
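+
+    Example: a minimal sketch of supplying the required metadata
+    (`./source.txt` is a hypothetical file path).
+
+    ```py
+    guard.parse(
+        llm_output=...,
+        metadata={"filepaths": ["./source.txt"]},
+    )
+    ```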
+ """ # noqa
+
+ required_metadata_keys = ["filepaths"]
+
+ def __init__(
+ self,
+ threshold: int = 85,
+ on_fail: Optional[Callable] = None,
+ **kwargs,
+ ):
+ super().__init__(on_fail, **kwargs)
+
+ self._threshold = threshold
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ """Make sure each sentence was precisely copied from the document."""
+
+ if "filepaths" not in metadata:
+ raise RuntimeError(
+ "extractive-summary validator expects " "`filepaths` key in metadata"
+ )
+
+ filepaths = metadata["filepaths"]
+
+ # Load documents
+ store = {}
+ for filepath in filepaths:
+ with open(filepath) as f:
+ doc = f.read()
+ store[filepath] = sentence_split(doc)
+
+ try:
+ from thefuzz import fuzz # type: ignore
+ except ImportError:
+ raise ImportError(
+ "`thefuzz` library is required for `extractive-summary` validator. "
+ "Please install it with `poetry add thefuzz`."
+ )
+
+ # Split the value into sentences.
+ sentences = sentence_split(value)
+
+        # Check if any of the sentences in the value match any of the sentences
+        # in the documents.
+ unverified = []
+ verified = []
+ citations = {}
+
+ for id_, sentence in enumerate(sentences):
+ highest_ratio = 0
+ highest_ratio_doc = None
+
+ # Check fuzzy match against all sentences in all documents
+ for doc_path, doc_sentences in store.items():
+ for doc_sentence in doc_sentences:
+ ratio = fuzz.ratio(sentence, doc_sentence)
+ if ratio > highest_ratio:
+ highest_ratio = ratio
+ highest_ratio_doc = doc_path
+
+ if highest_ratio < self._threshold:
+ unverified.append(sentence)
+ else:
+ sentence_id = id_ + 1
+ citation_id = list(store).index(highest_ratio_doc) + 1
+
+ citations[sentence_id] = citation_id
+ verified.append(sentence + f" [{citation_id}]")
+
+ verified_sentences = (
+ " ".join(verified)
+ + "\n\n"
+ + "\n".join(f"[{i + 1}] {s}" for i, s in enumerate(store))
+ )
+
+ metadata["summary_with_citations"] = verified_sentences
+ metadata["citations"] = citations
+
+        if len(unverified):
+            # `unverified` holds the sentences themselves, not their indices.
+            unverified_sentences = "\n".join("- " + s for s in unverified)
+            return FailResult(
+                metadata=metadata,
+                error_message=(
+                    f"The summary \nSummary: {value}\n has sentences\n"
+                    f"{unverified_sentences}\n that are not similar to any document."
+                ),
+                fix_value=verified_sentences,
+ )
+
+ return PassResult(
+ metadata=metadata,
+ )
+
+
+@register_validator(name="remove-redundant-sentences", data_type="string")
+class RemoveRedundantSentences(Validator):
+ """Removes redundant sentences from a string.
+
+ This validator removes sentences from a string that are similar to
+ other sentences in the string. This is useful for removing
+ repetitive sentences from a string.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `remove-redundant-sentences` |
+ | Supported data types | `string` |
+ | Programmatic fix | Remove any redundant sentences. |
+
+ Parameters: Arguments
+
+ threshold: The minimum fuzz ratio to be considered redundant. Defaults to 70.
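+
+    Example: a minimal sketch of direct use; near-duplicate sentences above
+    the fuzz threshold are dropped from the fix value.
+
+    ```py
+    validator = RemoveRedundantSentences(threshold=70)
+    validator.validate(
+        "The sky is blue. The sky is very blue. Grass is green.", {}
+    )  # FailResult with the redundant sentence removed from fix_value
+    ```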
+ """
+
+ def __init__(
+ self, threshold: int = 70, on_fail: Optional[Callable] = None, **kwargs
+ ):
+ super().__init__(on_fail, **kwargs)
+ self._threshold = threshold
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ """Remove redundant sentences from a string."""
+
+ try:
+ from thefuzz import fuzz # type: ignore
+ except ImportError:
+ raise ImportError(
+ "`thefuzz` library is required for `remove-redundant-sentences` "
+ "validator. Please install it with `poetry add thefuzz`."
+ )
+
+ # Split the value into sentences.
+ sentences = sentence_split(value)
+ filtered_sentences = []
+ redundant_sentences = []
+
+        if not sentences:
+            return PassResult()
+
+        sentence = sentences[0]
+        other_sentences = sentences[1:]
+ while len(other_sentences):
+ # Check fuzzy match against all other sentences
+ filtered_sentences.append(sentence)
+ unique_sentences = []
+ for other_sentence in other_sentences:
+ ratio = fuzz.ratio(sentence, other_sentence)
+ if ratio > self._threshold:
+ redundant_sentences.append(other_sentence)
+ else:
+ unique_sentences.append(other_sentence)
+ if len(unique_sentences) == 0:
+ break
+ sentence = unique_sentences[0]
+ other_sentences = unique_sentences[1:]
+
+ filtered_summary = " ".join(filtered_sentences)
+
+ if len(redundant_sentences):
+ redundant_sentences = "\n".join(redundant_sentences)
+ return FailResult(
+ error_message=(
+ f"The summary \nSummary: {value}\n has sentences\n"
+ f"{redundant_sentences}\n that are similar to other sentences."
+ ),
+ fix_value=filtered_summary,
+ )
+
+ return PassResult()
+
+
+@register_validator(name="saliency-check", data_type="string")
+class SaliencyCheck(Validator):
+ """Checks that the summary covers the list of topics present in the
+ document.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `saliency-check` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+
+ docs_dir: Path to the directory containing the documents.
+    threshold: Threshold for overlap between topics in document and summary. Defaults to 0.25.
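+
+    Example: a minimal sketch of construction; `./docs` is a hypothetical
+    directory, and topic extraction calls the configured LLM at init time.
+
+    ```py
+    validator = SaliencyCheck(docs_dir="./docs", threshold=0.25)
+    ```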
+ """ # noqa
+
+ def __init__(
+ self,
+ docs_dir: str,
+ llm_callable: Optional[Callable] = None,
+ on_fail: Optional[Callable] = None,
+ threshold: float = 0.25,
+ **kwargs,
+ ):
+ """Initialize the SalienceCheck validator.
+
+ Args:
+ docs_dir: Path to the directory containing the documents.
+ on_fail: Function to call when validation fails.
+ threshold: Threshold for overlap between topics in document and summary.
+ """
+
+ super().__init__(on_fail, **kwargs)
+
+ if llm_callable is not None and inspect.iscoroutinefunction(llm_callable):
+ raise ValueError(
+ "SaliencyCheck validator does not support async LLM callables."
+ )
+
+ self.llm_callable = (
+ llm_callable if llm_callable else get_static_openai_chat_create_func()
+ )
+
+ self._threshold = threshold
+
+ # Load documents
+ self._document_store = {}
+ for doc_path in os.listdir(docs_dir):
+ with open(os.path.join(docs_dir, doc_path)) as f:
+ text = f.read()
+ # Precompute topics for each document
+ self._document_store[doc_path] = self._get_topics(text)
+
+ @property
+ def _topics(self) -> List[str]:
+ """Return a list of topics that can be used in the validator."""
+ # Merge topics from all documents
+ topics = set()
+ for doc_topics in self._document_store.values():
+ topics.update(doc_topics)
+ return list(topics)
+
+ def _get_topics(self, text: str, topics: Optional[List[str]] = None) -> List[str]:
+ """Extract topics from a string."""
+
+ from guardrails import Guard
+
+ topics_seed = ""
+ if topics is not None:
+ topics_seed = (
+ "Here's a seed list of topics, select topics from this list"
+ " if they are covered in the doc:\n\n" + ", ".join(topics)
+ )
+
+ spec = f"""
+
+
+
+
+Extract a list of topics from the following text:
+
+{text}
+
+{topics_seed}
+
+Return the output as a JSON with a single key "topics" containing a list of topics.
+
+Make sure that topics are relevant to text, and topics are not too specific or general.
+
+
+ """
+
+ guard = Guard.from_rail_string(spec)
+ _, validated_output = guard(llm_api=self.llm_callable) # type: ignore
+ return validated_output["topics"]
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ topics_in_summary = self._get_topics(value, topics=self._topics)
+
+ # Compute overlap between topics in document and summary
+ intersection = set(topics_in_summary).intersection(set(self._topics))
+ overlap = len(intersection) / len(self._topics)
+
+ if overlap < self._threshold:
+ return FailResult(
+ error_message=(
+ f"The summary \nSummary: {value}\n does not cover these topics:\n"
+ f"{set(self._topics).difference(intersection)}"
+ ),
+ fix_value="",
+ )
+
+ return PassResult()
+
+
+@register_validator(name="qa-relevance-llm-eval", data_type="string")
+class QARelevanceLLMEval(Validator):
+ """Validates that an answer is relevant to the question asked by asking the
+ LLM to self evaluate.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `qa-relevance-llm-eval` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Other parameters: Metadata
+        question (str): The original question the LLM was given to answer.
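+
+    Example: a minimal sketch of passing the required metadata (hypothetical
+    values).
+
+    ```py
+    guard.parse(
+        llm_output="Paris is the capital of France.",
+        metadata={"question": "What is the capital of France?"},
+    )
+    ```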
+ """
+
+ required_metadata_keys = ["question"]
+
+ def __init__(
+ self,
+ llm_callable: Optional[Callable] = None,
+ on_fail: Optional[Callable] = None,
+ **kwargs,
+ ):
+ super().__init__(on_fail, **kwargs)
+
+ if llm_callable is not None and inspect.iscoroutinefunction(llm_callable):
+ raise ValueError(
+ "QARelevanceLLMEval validator does not support async LLM callables."
+ )
+
+ self.llm_callable = (
+ llm_callable if llm_callable else get_static_openai_chat_create_func()
+ )
+
+ def _selfeval(self, question: str, answer: str):
+ from guardrails import Guard
+
+ spec = """
+
+
+
+
+Is the answer below relevant to the question asked?
+Question: {question}
+Answer: {answer}
+
+Relevant (as a JSON with a single boolean key, "relevant"):\
+
+
+ """.format(
+ question=question,
+ answer=answer,
+ )
+ guard = Guard.from_rail_string(spec)
+
+ _, validated_output = guard(
+ self.llm_callable, # type: ignore
+ max_tokens=10,
+ temperature=0.1,
+ )
+ return validated_output
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ if "question" not in metadata:
+ raise RuntimeError(
+ "qa-relevance-llm-eval validator expects " "`question` key in metadata"
+ )
+
+ question = metadata["question"]
+
+ relevant = self._selfeval(question, value)["relevant"]
+ if relevant:
+ return PassResult()
+
+ fixed_answer = "No relevant answer found."
+ return FailResult(
+ error_message=f"The answer {value} is not relevant "
+ f"to the question {question}.",
+ fix_value=fixed_answer,
+ )
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+ return ""
+
+
+@register_validator(name="provenance-v0", data_type="string")
+class ProvenanceV0(Validator):
+ """Validates that LLM-generated text matches some source text based on
+ distance in embedding space.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `provenance-v0` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ threshold: The minimum cosine similarity between the generated text and
+ the source text. Defaults to 0.8.
+        validation_method: Whether to validate at the sentence level or over the full text. Must be one of `sentence` or `full`. Defaults to `sentence`.
+
+ Other parameters: Metadata
+ query_function (Callable, optional): A callable that takes a string and returns a list of (chunk, score) tuples.
+ sources (List[str], optional): The source text.
+ embed_function (Callable, optional): A callable that creates embeddings for the sources. Must accept a list of strings and return an np.array of floats.
+
+ In order to use this validator, you must provide either a `query_function` or
+ `sources` with an `embed_function` in the metadata.
+
+ If providing query_function, it should take a string as input and return a list of
+ (chunk, score) tuples. The chunk is a string and the score is a float representing
+ the cosine distance between the chunk and the input string. The list should be
+ sorted in ascending order by score.
+
+ Note: The score should represent distance in embedding space, not similarity. I.e.,
+ lower is better and the score should be 0 if the chunk is identical to the input
+ string.
+
+ Example:
+ ```py
+ def query_function(text: str, k: int) -> List[Tuple[str, float]]:
+ return [("This is a chunk", 0.9), ("This is another chunk", 0.8)]
+
+ guard = Guard.from_rail(...)
+ guard(
+        openai.ChatCompletion.create,
+ prompt_params={...},
+ temperature=0.0,
+ metadata={"query_function": query_function},
+ )
+ ```
+
+
+ If providing sources, it should be a list of strings. The embed_function should
+    take a string or a list of strings as input and return an np.ndarray of floats.
+    The vectors should be normalized to unit length.
+
+ Example:
+ ```py
+ def embed_function(text: Union[str, List[str]]) -> np.ndarray:
+ return np.array([[0.1, 0.2, 0.3]])
+
+ guard = Guard.from_rail(...)
+ guard(
+        openai.ChatCompletion.create,
+ prompt_params={...},
+ temperature=0.0,
+ metadata={
+ "sources": ["This is a source text"],
+ "embed_function": embed_function
+ },
+ )
+ ```
+ """ # noqa
+
+ def __init__(
+ self,
+ threshold: float = 0.8,
+ validation_method: str = "sentence",
+ on_fail: Optional[Callable] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ on_fail, threshold=threshold, validation_method=validation_method, **kwargs
+ )
+ self._threshold = float(threshold)
+ if validation_method not in ["sentence", "full"]:
+ raise ValueError("validation_method must be 'sentence' or 'full'.")
+ self._validation_method = validation_method
+
+ def get_query_function(self, metadata: Dict[str, Any]) -> Callable:
+ query_fn = metadata.get("query_function", None)
+ sources = metadata.get("sources", None)
+
+ # Check that query_fn or sources are provided
+ if query_fn is not None:
+ if sources is not None:
+ warnings.warn(
+ "Both `query_function` and `sources` are provided in metadata. "
+ "`query_function` will be used."
+ )
+ return query_fn
+
+ if sources is None:
+ raise ValueError(
+ "You must provide either `query_function` or `sources` in metadata."
+ )
+
+ # Check chunking strategy
+ chunk_strategy = metadata.get("chunk_strategy", "sentence")
+ if chunk_strategy not in ["sentence", "word", "char", "token"]:
+ raise ValueError(
+ "`chunk_strategy` must be one of 'sentence', 'word', 'char', "
+ "or 'token'."
+ )
+ chunk_size = metadata.get("chunk_size", 5)
+ chunk_overlap = metadata.get("chunk_overlap", 2)
+
+ # Check distance metric
+ distance_metric = metadata.get("distance_metric", "cosine")
+ if distance_metric not in ["cosine", "euclidean"]:
+ raise ValueError(
+ "`distance_metric` must be one of 'cosine' or 'euclidean'."
+ )
+
+ # Check embed model
+ embed_function = metadata.get("embed_function", None)
+ if embed_function is None:
+ raise ValueError(
+ "You must provide `embed_function` in metadata in order to "
+ "use the default query function."
+ )
+ return partial(
+ self.query_vector_collection,
+ sources=metadata["sources"],
+ chunk_strategy=chunk_strategy,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ distance_metric=distance_metric,
+ embed_function=embed_function,
+ )
+
+ def validate_each_sentence(
+ self, value: Any, query_function: Callable, metadata: Dict[str, Any]
+ ) -> ValidationResult:
+ if nltk is None:
+ raise ImportError(
+ "`nltk` library is required for `provenance-v0` validator. "
+ "Please install it with `poetry add nltk`."
+ )
+ # Split the value into sentences using nltk sentence tokenizer.
+ sentences = nltk.sent_tokenize(value)
+
+ unsupported_sentences = []
+ supported_sentences = []
+ for sentence in sentences:
+ most_similar_chunks = query_function(text=sentence, k=1)
+ if most_similar_chunks is None:
+ unsupported_sentences.append(sentence)
+ continue
+ most_similar_chunk = most_similar_chunks[0]
+ if most_similar_chunk[1] < self._threshold:
+ supported_sentences.append((sentence, most_similar_chunk[0]))
+ else:
+ unsupported_sentences.append(sentence)
+
+ metadata["unsupported_sentences"] = "- " + "\n- ".join(unsupported_sentences)
+ metadata["supported_sentences"] = supported_sentences
+ if unsupported_sentences:
+ unsupported_sentences = "- " + "\n- ".join(unsupported_sentences)
+ return FailResult(
+ metadata=metadata,
+ error_message=(
+ f"None of the following sentences in your response are supported "
+ "by provided context:"
+ f"\n{metadata['unsupported_sentences']}"
+ ),
+ fix_value="\n".join(s[0] for s in supported_sentences),
+ )
+ return PassResult(metadata=metadata)
+
+ def validate_full_text(
+ self, value: Any, query_function: Callable, metadata: Dict[str, Any]
+ ) -> ValidationResult:
+ most_similar_chunks = query_function(text=value, k=1)
+ if most_similar_chunks is None:
+ metadata["unsupported_text"] = value
+ metadata["supported_text_citations"] = {}
+ return FailResult(
+ metadata=metadata,
+ error_message=(
+ "The following text in your response is not supported by the "
+ "supported by the provided context:\n" + value
+ ),
+ )
+ most_similar_chunk = most_similar_chunks[0]
+ if most_similar_chunk[1] > self._threshold:
+ metadata["unsupported_text"] = value
+ metadata["supported_text_citations"] = {}
+ return FailResult(
+ metadata=metadata,
+ error_message=(
+ "The following text in your response is not supported by the "
+ "supported by the provided context:\n" + value
+ ),
+ )
+
+ metadata["unsupported_text"] = ""
+ metadata["supported_text_citations"] = {
+ value: most_similar_chunk[0],
+ }
+ return PassResult(metadata=metadata)
+
+ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
+ query_function = self.get_query_function(metadata)
+
+ if self._validation_method == "sentence":
+ return self.validate_each_sentence(value, query_function, metadata)
+ elif self._validation_method == "full":
+ return self.validate_full_text(value, query_function, metadata)
+ else:
+ raise ValueError("validation_method must be 'sentence' or 'full'.")
+
+ @staticmethod
+ def query_vector_collection(
+ text: str,
+ k: int,
+ sources: List[str],
+ embed_function: Callable,
+ chunk_strategy: str = "sentence",
+ chunk_size: int = 5,
+ chunk_overlap: int = 2,
+ distance_metric: str = "cosine",
+ ) -> List[Tuple[str, float]]:
+ chunks = [
+ get_chunks_from_text(source, chunk_strategy, chunk_size, chunk_overlap)
+ for source in sources
+ ]
+ chunks = list(itertools.chain.from_iterable(chunks))
+
+        if not _HAS_NUMPY:
+            raise ValueError(
+                "You must install numpy in order to use "
+                "the default query function."
+            )
+
+        # Create embeddings
+        source_embeddings = np.array(embed_function(chunks)).squeeze()
+        query_embedding = embed_function(text).squeeze()
+
+        # Compute distances
+        if distance_metric == "cosine":
+ cos_sim = 1 - (
+ np.dot(source_embeddings, query_embedding)
+ / (
+ np.linalg.norm(source_embeddings, axis=1)
+ * np.linalg.norm(query_embedding)
+ )
+ )
+ top_indices = np.argsort(cos_sim)[:k]
+ top_similarities = [cos_sim[j] for j in top_indices]
+ top_chunks = [chunks[j] for j in top_indices]
+ else:
+ raise ValueError("distance_metric must be 'cosine'.")
+
+ return list(zip(top_chunks, top_similarities))
+
+ def to_prompt(self, with_keywords: bool = True) -> str:
+ return ""
+
+
+@register_validator(name="provenance-v1", data_type="string")
+class ProvenanceV1(Validator):
+ """Validates that the LLM-generated text is supported by the provided
+ contexts.
+
+ This validator uses an LLM callable to evaluate the generated text against the
+ provided contexts (LLM-ception).
+
+ In order to use this validator, you must provide either:
+ 1. a 'query_function' in the metadata. That function should take a string as input
+ (the LLM-generated text) and return a list of relevant
+ chunks. The list should be sorted in ascending order by the distance between the
+ chunk and the LLM-generated text.
+
+ Example using str callable:
+ >>> def query_function(text: str, k: int) -> List[str]:
+ ... return ["This is a chunk", "This is another chunk"]
+
+ >>> guard = Guard.from_string(validators=[
+ ProvenanceV1(llm_callable="gpt-3.5-turbo", ...)
+ ]
+ )
+ >>> guard.parse(
+ ... llm_output=...,
+ ... metadata={"query_function": query_function}
+ ... )
+
+ Example using a custom llm callable:
+ >>> def query_function(text: str, k: int) -> List[str]:
+ ... return ["This is a chunk", "This is another chunk"]
+
+ >>> guard = Guard.from_string(validators=[
+ ProvenanceV1(llm_callable=your_custom_callable, ...)
+ ]
+ )
+ >>> guard.parse(
+ ... llm_output=...,
+ ... metadata={"query_function": query_function}
+ ... )
+
+ OR
+
+ 2. `sources` with an `embed_function` in the metadata. The embed_function should
+    take a string or a list of strings as input and return an np.ndarray of floats.
+    The vectors should be normalized to unit length.
+
+ Example:
+ ```py
+ def embed_function(text: Union[str, List[str]]) -> np.ndarray:
+ return np.array([[0.1, 0.2, 0.3]])
+
+ guard = Guard.from_rail(...)
+ guard(
+        openai.ChatCompletion.create,
+ prompt_params={...},
+ temperature=0.0,
+ metadata={
+ "sources": ["This is a source text"],
+ "embed_function": embed_function
+ },
+ )
+ """
+
+ def __init__(
+ self,
+ validation_method: str = "sentence",
+ llm_callable: Union[str, Callable] = "gpt-3.5-turbo",
+ top_k: int = 3,
+ max_tokens: int = 2,
+ on_fail: Optional[Callable] = None,
+ **kwargs,
+ ):
+ """
+        Args:
+            validation_method (str): Whether to validate at the sentence level or over
+                the full text. One of `sentence` or `full`. Defaults to `sentence`.
+ llm_callable (Union[str, Callable]): Either the name of the OpenAI model,
+ or a callable that takes a prompt and returns a response.
+ top_k (int): The number of chunks to return from the query function.
+ Defaults to 3.
+ max_tokens (int): The maximum number of tokens to send to the LLM.
+ Defaults to 2.
+
+        Other parameters: Metadata
+ query_function (Callable): A callable that takes a string and returns a
+ list of chunks.
+ sources (List[str], optional): The source text.
+ embed_function (Callable, optional): A callable that creates embeddings for
+                the sources. Must accept a list of strings and return an np.array of floats.
+ """
+ super().__init__(
+ on_fail,
+ validation_method=validation_method,
+ llm_callable=llm_callable,
+ top_k=top_k,
+ max_tokens=max_tokens,
+ **kwargs,
+ )
+ if validation_method not in ["sentence", "full"]:
+ raise ValueError("validation_method must be 'sentence' or 'full'.")
+ self._validation_method = validation_method
+ self.set_callable(llm_callable)
+ self._top_k = int(top_k)
+ self._max_tokens = int(max_tokens)
+
+ self.client = OpenAIClient()
+
+ def set_callable(self, llm_callable: Union[str, Callable]) -> None:
+ """Set the LLM callable.
+
+ Args:
+ llm_callable: Either the name of the OpenAI model, or a callable that takes
+ a prompt and returns a response.
+ """
+ if isinstance(llm_callable, str):
+ if llm_callable not in ["gpt-3.5-turbo", "gpt-4"]:
+ raise ValueError(
+ "llm_callable must be one of 'gpt-3.5-turbo' or 'gpt-4'."
+ "If you want to use a custom LLM, please provide a callable."
+ "Check out ProvenanceV1 documentation for an example."
+ )
+
+ def openai_callable(prompt: str) -> str:
+ response = self.client.create_chat_completion(
+ model=llm_callable,
+ messages=[
+ {"role": "user", "content": prompt},
+ ],
+ max_tokens=self._max_tokens,
+ )
+ return response.output
+
+ self._llm_callable = openai_callable
+ elif isinstance(llm_callable, Callable):
+ self._llm_callable = llm_callable
+ else:
+ raise ValueError(
+ "llm_callable must be either a string or a callable that takes a string"
+ " and returns a string."
+ )
+
+ def get_query_function(self, metadata: Dict[str, Any]) -> Callable:
+ # Exact same as ProvenanceV0
+
+ query_fn = metadata.get("query_function", None)
+ sources = metadata.get("sources", None)
+
+ # Check that query_fn or sources are provided
+ if query_fn is not None:
+ if sources is not None:
+ warnings.warn(
+ "Both `query_function` and `sources` are provided in metadata. "
+ "`query_function` will be used."
+ )
+ return query_fn
+
+ if sources is None:
+ raise ValueError(
+ "You must provide either `query_function` or `sources` in metadata."
+ )
+
+ # Check chunking strategy
+ chunk_strategy = metadata.get("chunk_strategy", "sentence")
+ if chunk_strategy not in ["sentence", "word", "char", "token"]:
+ raise ValueError(
+ "`chunk_strategy` must be one of 'sentence', 'word', 'char', "
+ "or 'token'."
+ )
+ chunk_size = metadata.get("chunk_size", 5)
+ chunk_overlap = metadata.get("chunk_overlap", 2)
+
+ # Check distance metric
+ distance_metric = metadata.get("distance_metric", "cosine")
+ if distance_metric not in ["cosine", "euclidean"]:
+ raise ValueError(
+ "`distance_metric` must be one of 'cosine' or 'euclidean'."
+ )
+
+ # Check embed model
+ embed_function = metadata.get("embed_function", None)
+ if embed_function is None:
+ raise ValueError(
+ "You must provide `embed_function` in metadata in order to "
+ "use the default query function."
+ )
+ return partial(
+ self.query_vector_collection,
+ sources=metadata["sources"],
+ chunk_strategy=chunk_strategy,
+ chunk_size=chunk_size,
+ chunk_overlap=chunk_overlap,
+ distance_metric=distance_metric,
+ embed_function=embed_function,
+ )
+
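+    # Retry with exponential backoff to smooth over transient LLM/API errors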
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+ def call_llm(self, prompt: str) -> str:
+ """Call the LLM with the given prompt.
+
+ Expects a function that takes a string and returns a string.
+
+ Args:
+ prompt (str): The prompt to send to the LLM.
+
+ Returns:
+ response (str): String representing the LLM response.
+ """
+ return self._llm_callable(prompt)
+
+ def evaluate_with_llm(self, text: str, query_function: Callable) -> bool:
+ """Validate that the LLM-generated text is supported by the provided
+ contexts.
+
+ Args:
+            text (str): The LLM-generated text.
+ query_function (Callable): The query function.
+
+ Returns:
+ self_eval: The self-evaluation boolean
+ """
+ # Get the relevant chunks using the query function
+ relevant_chunks = query_function(text=text, k=self._top_k)
+
+ # Create the prompt to ask the LLM
+ prompt = PROVENANCE_V1_PROMPT.format(text, "\n".join(relevant_chunks))
+
+        # Get self-evaluation; "Yes" means the text is supported by the context
+        self_eval = self.call_llm(prompt)
+        return self_eval == "Yes"
+
+ def validate_each_sentence(
+ self, value: Any, query_function: Callable, metadata: Dict[str, Any]
+ ) -> ValidationResult:
+        if nltk is None:
+            raise ImportError(
+                "`nltk` library is required for the `provenance-v1` validator. "
+                "Please install it with `poetry add nltk`."
+            )
+ # Split the value into sentences using nltk sentence tokenizer.
+ sentences = nltk.sent_tokenize(value)
+
+ unsupported_sentences = []
+ supported_sentences = []
+ for sentence in sentences:
+ self_eval = self.evaluate_with_llm(sentence, query_function)
+ if not self_eval:
+ unsupported_sentences.append(sentence)
+ else:
+ supported_sentences.append(sentence)
+
+ if unsupported_sentences:
+            unsupported_sentences_text = "- " + "\n- ".join(unsupported_sentences)
+            return FailResult(
+                metadata=metadata,
+                error_message=(
+                    "The following sentences in your response are not supported "
+                    "by the provided context:"
+                    f"\n{unsupported_sentences_text}"
+                ),
+ fix_value="\n".join(supported_sentences),
+ )
+ return PassResult(metadata=metadata)
+
+ def validate_full_text(
+ self, value: Any, query_function: Callable, metadata: Dict[str, Any]
+ ) -> ValidationResult:
+ # Self-evaluate LLM with entire text
+ self_eval = self.evaluate_with_llm(value, query_function)
+ if not self_eval:
+ # if false
+ return FailResult(
+ metadata=metadata,
+                error_message=(
+                    "The following text in your response is not supported by the "
+                    "provided context:\n" + value
+                ),
+ )
+ return PassResult(metadata=metadata)
+
+ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
+ kwargs = {}
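+        # Recover any kwargs (e.g. api_key, api_base) passed to guard() or
+        # parse(), which guardrails stashes in a contextvar named "kwargs".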
+ context_copy = contextvars.copy_context()
+ for key, context_var in context_copy.items():
+            if key.name == "kwargs" and isinstance(context_var, dict):
+ kwargs = context_var
+ break
+
+ api_key = kwargs.get("api_key")
+ api_base = kwargs.get("api_base")
+
+ # Set the OpenAI API key
+ if os.getenv("OPENAI_API_KEY"): # Check if set in environment
+ self.client.api_key = os.getenv("OPENAI_API_KEY")
+ elif api_key: # Check if set when calling guard() or parse()
+ self.client.api_key = api_key
+
+ # Set the OpenAI API base if specified
+ if api_base:
+ self.client.api_base = api_base
+
+ query_function = self.get_query_function(metadata)
+ if self._validation_method == "sentence":
+ return self.validate_each_sentence(value, query_function, metadata)
+ elif self._validation_method == "full":
+ return self.validate_full_text(value, query_function, metadata)
+ else:
+ raise ValueError("validation_method must be 'sentence' or 'full'.")
+
+ @staticmethod
+ def query_vector_collection(
+ text: str,
+ k: int,
+ sources: List[str],
+ embed_function: Callable,
+ chunk_strategy: str = "sentence",
+ chunk_size: int = 5,
+ chunk_overlap: int = 2,
+ distance_metric: str = "cosine",
+    ) -> List[str]:
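+        """Return the `k` source chunks closest to `text` in embedding space.
+
+        The sources are chunked per `chunk_strategy`, embedded with
+        `embed_function`, and ranked by distance to the query embedding.
+        Only the 'cosine' distance metric is currently implemented.
+        """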
+ chunks = [
+ get_chunks_from_text(source, chunk_strategy, chunk_size, chunk_overlap)
+ for source in sources
+ ]
+ chunks = list(itertools.chain.from_iterable(chunks))
+
+        # numpy is required for the embedding and distance computations below
+        if not _HAS_NUMPY:
+            raise ValueError(
+                "You must install numpy in order to use the default query function."
+            )
+
+        # Create embeddings
+        source_embeddings = np.array(embed_function(chunks)).squeeze()
+        query_embedding = embed_function(text).squeeze()
+
+        # Compute distances (lower means more similar)
+        if distance_metric == "cosine":
+            cos_distances = 1 - (
+                np.dot(source_embeddings, query_embedding)
+                / (
+                    np.linalg.norm(source_embeddings, axis=1)
+                    * np.linalg.norm(query_embedding)
+                )
+            )
+            top_indices = np.argsort(cos_distances)[:k]
+            top_chunks = [chunks[j] for j in top_indices]
+        else:
+            raise ValueError(
+                "Only the 'cosine' distance metric is currently implemented."
+            )
+
+ return top_chunks
+
+
+@register_validator(name="pii", data_type="string")
+class PIIFilter(Validator):
+ """Validates that any text does not contain any PII.
+
+ This validator uses Microsoft's Presidio (https://github.com/microsoft/presidio)
+ to detect PII in the text. If PII is detected, the validator will fail with a
+ programmatic fix that anonymizes the text. Otherwise, the validator will pass.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | ----------------------------------- |
+ | Name for `format` attribute | `pii` |
+ | Supported data types | `string` |
+ | Programmatic fix | Anonymized text with PII filtered |
+
+ Parameters: Arguments
+        pii_entities (str | List[str], optional): The PII entities to filter. Either
+            one of the predefined groups `pii` or `spi`, or a list of entity names.
+            Defaults to None. Can also be set in metadata.
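+
+    Example:
+        A minimal sketch (assumes `Guard` is imported from `guardrails`;
+        `llm_output` is a placeholder for the text to check):
+
+        ```py
+        guard = Guard.from_string(validators=[
+            PIIFilter(pii_entities="pii", on_fail="fix")
+        ])
+        guard.parse(llm_output=llm_output)
+        ```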
+ """
+
+ PII_ENTITIES_MAP = {
+ "pii": [
+ "EMAIL_ADDRESS",
+ "PHONE_NUMBER",
+ "DOMAIN_NAME",
+ "IP_ADDRESS",
+ "DATE_TIME",
+ "LOCATION",
+ "PERSON",
+ "URL",
+ ],
+ "spi": [
+ "CREDIT_CARD",
+ "CRYPTO",
+ "IBAN_CODE",
+ "NRP",
+ "MEDICAL_LICENSE",
+ "US_BANK_NUMBER",
+ "US_DRIVER_LICENSE",
+ "US_ITIN",
+ "US_PASSPORT",
+ "US_SSN",
+ ],
+ }
+
+ def __init__(
+ self,
+ pii_entities: Union[str, List[str], None] = None,
+ on_fail: Union[Callable[..., Any], None] = None,
+ **kwargs,
+ ):
+ if AnalyzerEngine is None or AnonymizerEngine is None:
+ raise ImportError(
+ "You must install the `presidio-analyzer`, `presidio-anonymizer`"
+ "and a spaCy language model to use the PII validator."
+ "Refer to https://microsoft.github.io/presidio/installation/"
+ )
+
+ super().__init__(on_fail, pii_entities=pii_entities, **kwargs)
+ self.pii_entities = pii_entities
+ self.pii_analyzer = AnalyzerEngine()
+ self.pii_anonymizer = AnonymizerEngine()
+
+ def get_anonymized_text(self, text: str, entities: List[str]) -> str:
+ """Analyze and anonymize the text for PII.
+
+ Args:
+ text (str): The text to analyze.
+            entities (List[str]): The PII entities to filter.
+
+ Returns:
+ anonymized_text (str): The anonymized text.
+ """
+ results = self.pii_analyzer.analyze(text=text, entities=entities, language="en")
+ results = cast(List[Any], results)
+ anonymized_text = self.pii_anonymizer.anonymize(
+ text=text, analyzer_results=results
+ ).text
+ return anonymized_text
+
+ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
+ # Entities to filter passed through metadata take precedence
+ pii_entities = metadata.get("pii_entities", self.pii_entities)
+ if pii_entities is None:
+            raise ValueError(
+                "`pii_entities` must be set in order to use the `PIIFilter` "
+                "validator. Pass a list like `pii_entities=['PERSON', "
+                "'PHONE_NUMBER']`, or a group name: `pii_entities='pii'` or "
+                "`pii_entities='spi'`, either in init or in metadata."
+ )
+
+ # Check that pii_entities is a string OR list of strings
+ if isinstance(pii_entities, str):
+ # A key to the PII_ENTITIES_MAP
+ entities_to_filter = self.PII_ENTITIES_MAP.get(pii_entities, None)
+ if entities_to_filter is None:
+ raise ValueError(
+ f"`pii_entities` must be one of {self.PII_ENTITIES_MAP.keys()}"
+ )
+ elif isinstance(pii_entities, list):
+ entities_to_filter = pii_entities
+ else:
+            raise ValueError(
+                f"`pii_entities` must be one of {list(self.PII_ENTITIES_MAP)}"
+                " or a list of strings."
+ )
+
+ # Analyze the text, and anonymize it if there is PII
+ anonymized_text = self.get_anonymized_text(
+ text=value, entities=entities_to_filter
+ )
+
+ # If anonymized value text is different from original value, then there is PII
+ if anonymized_text != value:
+ return FailResult(
+ error_message=(
+ f"The following text in your response contains PII:\n{value}"
+ ),
+ fix_value=anonymized_text,
+ )
+ return PassResult()
+
+
+@register_validator(name="similar-to-list", data_type="string")
+class SimilarToList(Validator):
+ """Validates that a value is similar to a list of previously known values.
+
+ **Key Properties**
+
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `similar-to-list` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ standard_deviations (int): The number of standard deviations from the mean to check.
+ threshold (float): The threshold for the average semantic similarity for strings.
+
+ For integer values, this validator checks whether the value lies
+ within 'k' standard deviations of the mean of the previous values.
+ (Assumes that the previous values are normally distributed.) For
+ string values, this validator checks whether the average semantic
+ similarity between the generated value and the previous values is
+ less than a threshold.
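+
+    Example:
+        A minimal sketch (assumes `Guard` is imported from `guardrails`;
+        `embed_fn` is a hypothetical embedding function, only needed for
+        string values):
+
+        ```py
+        guard = Guard.from_string(validators=[
+            SimilarToList(standard_deviations=3, threshold=0.1, on_fail="reask")
+        ])
+        guard.parse(
+            llm_output=llm_output,
+            metadata={"prev_values": [10, 12, 11], "embed_function": embed_fn},
+        )
+        ```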
+ """ # noqa
+
+ def __init__(
+ self,
+ standard_deviations: int = 3,
+ threshold: float = 0.1,
+ on_fail: Optional[Callable] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ on_fail,
+ standard_deviations=standard_deviations,
+ threshold=threshold,
+ **kwargs,
+ )
+ self._standard_deviations = int(standard_deviations)
+ self._threshold = float(threshold)
+
+ def get_semantic_similarity(
+ self, text1: str, text2: str, embed_function: Callable
+ ) -> float:
+ """Get the semantic similarity between two strings.
+
+ Args:
+ text1 (str): The first string.
+ text2 (str): The second string.
+ embed_function (Callable): The embedding function.
+ Returns:
+            similarity (float): The cosine distance between the two embeddings;
+                lower values mean the strings are more similar.
+ """
+ text1_embedding = embed_function(text1)
+ text2_embedding = embed_function(text2)
+ similarity = 1 - (
+ np.dot(text1_embedding, text2_embedding)
+ / (np.linalg.norm(text1_embedding) * np.linalg.norm(text2_embedding))
+ )
+ return similarity
+
+ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
+ prev_values = metadata.get("prev_values", [])
+ if not prev_values:
+ raise ValueError("You must provide a list of previous values in metadata.")
+
+ # Check if np is installed
+ if not _HAS_NUMPY:
+ raise ValueError(
+ "You must install numpy in order to "
+ "use the distribution check validator."
+ )
+ try:
+ value = int(value)
+ is_int = True
+ except ValueError:
+ is_int = False
+
+ if is_int:
+ # Check whether prev_values are also all integers
+ if not all(isinstance(prev_value, int) for prev_value in prev_values):
+ raise ValueError(
+ "Both given value and all the previous values must be "
+ "integers in order to use the distribution check validator."
+ )
+
+ # Check whether the value lies in a similar distribution as the prev_values
+ # Get mean and std of prev_values
+ prev_values = np.array(prev_values)
+ prev_mean = np.mean(prev_values) # type: ignore
+ prev_std = np.std(prev_values)
+
+ # Check whether the value lies outside specified stds of the mean
+ if value < prev_mean - (
+ self._standard_deviations * prev_std
+ ) or value > prev_mean + (self._standard_deviations * prev_std):
+ return FailResult(
+ error_message=(
+ f"The value {value} lies outside of the expected distribution "
+ f"of {prev_mean} +/- {self._standard_deviations * prev_std}."
+ ),
+ )
+ return PassResult()
+ else:
+ # Check whether prev_values are also all strings
+ if not all(isinstance(prev_value, str) for prev_value in prev_values):
+ raise ValueError(
+ "Both given value and all the previous values must be "
+ "strings in order to use the distribution check validator."
+ )
+
+ # Check embed model
+ embed_function = metadata.get("embed_function", None)
+ if embed_function is None:
+ raise ValueError(
+ "You must provide `embed_function` in metadata in order to "
+ "check the semantic similarity of the generated string."
+ )
+
+            # Check whether the value is semantically similar to the prev_values
+            # by computing the average cosine distance: the lower the score,
+            # the more similar the strings are.
+ avg_semantic_similarity = np.mean(
+ np.array(
+ [
+ self.get_semantic_similarity(value, prev_value, embed_function)
+ for prev_value in prev_values
+ ]
+ )
+ )
+
+            # If the average semantic distance is above the threshold,
+            # the value is not semantically similar to the prev_values
+            if avg_semantic_similarity > self._threshold:
+                return FailResult(
+                    error_message=(
+                        f"The value {value} is not semantically similar to the "
+                        f"previous values. The average semantic distance is "
+                        f"{avg_semantic_similarity}, which is above the threshold of "
+                        f"{self._threshold}."
+ ),
+ )
+ return PassResult()
+
+
+@register_validator(name="detect-secrets", data_type="string")
+class DetectSecrets(Validator):
+ """Validates whether the generated code snippet contains any secrets.
+
+ **Key Properties**
+ | Property | Description |
+ | ----------------------------- | --------------------------------- |
+ | Name for `format` attribute | `detect-secrets` |
+ | Supported data types | `string` |
+ | Programmatic fix | None |
+
+ Parameters: Arguments
+ None
+
+ This validator uses the detect-secrets library to check whether the generated code
+ snippet contains any secrets. If any secrets are detected, the validator fails and
+ returns the generated code snippet with the secrets replaced with asterisks.
+ Else the validator returns the generated code snippet.
+
+ Following are some caveats:
+ - Multiple secrets on the same line may not be caught. e.g.
+ - Minified code
+ - One-line lists/dictionaries
+ - Multi-variable assignments
+ - Multi-line secrets may not be caught. e.g.
+ - RSA/SSH keys
+
+    Example:
+        ```py
+        guard = Guard.from_string(validators=[
+            DetectSecrets(on_fail="fix")
+        ])
+        guard.parse(
+            llm_output=code_snippet,
+        )
+        ```
+    """
+
+ def __init__(self, on_fail: Union[Callable[..., Any], None] = None, **kwargs):
+ super().__init__(on_fail, **kwargs)
+
+ # Check if detect-secrets is installed
+ if detect_secrets is None:
+ raise ValueError(
+ "You must install detect-secrets in order to "
+ "use the DetectSecrets validator."
+ )
+ self.temp_file_name = "temp.txt"
+ self.mask = "********"
+
+ def get_unique_secrets(self, value: str) -> Tuple[Dict[str, Any], List[str]]:
+ """Get unique secrets from the value.
+
+ Args:
+ value (str): The generated code snippet.
+
+ Returns:
+ unique_secrets (Dict[str, Any]): A dictionary of unique secrets and their
+ line numbers.
+ lines (List[str]): The lines of the generated code snippet.
+ """
+        try:
+            # Write the value to a temporary file
+            with open(self.temp_file_name, "w") as f:
+                f.write(value)
+        except Exception as e:
+            raise OSError(
+                "Problems creating the temporary file. "
+                "Please check the permissions of the current directory."
+            ) from e
+
+ try:
+ # Create a new secrets collection
+ from detect_secrets import settings
+ from detect_secrets.core.secrets_collection import SecretsCollection
+
+ secrets = SecretsCollection()
+
+ # Scan the file for secrets
+ with settings.default_settings():
+ secrets.scan_file(self.temp_file_name)
+ except ImportError:
+ raise ValueError(
+ "You must install detect-secrets in order to "
+ "use the DetectSecrets validator."
+ )
+ except Exception as e:
+ raise RuntimeError(
+ "Problems with creating a SecretsCollection or "
+ "scanning the file for secrets."
+ ) from e
+
+ # Get unique secrets from these secrets
+ unique_secrets = {}
+ for secret in secrets:
+ _, potential_secret = secret
+ actual_secret = potential_secret.secret_value
+ line_number = potential_secret.line_number
+ if actual_secret not in unique_secrets:
+ unique_secrets[actual_secret] = [line_number]
+ else:
+ # if secret already exists, avoid duplicate line numbers
+ if line_number not in unique_secrets[actual_secret]:
+ unique_secrets[actual_secret].append(line_number)
+
+ try:
+            # Read the lines back from the file before deleting it
+ with open(self.temp_file_name, "r") as f:
+ lines = f.readlines()
+ except Exception as e:
+ raise OSError(
+ "Problems reading the temporary file. "
+ "Please check the permissions of the current directory."
+ ) from e
+
+ try:
+ # Delete the file
+ os.remove(self.temp_file_name)
+ except Exception as e:
+ raise OSError(
+ "Problems deleting the temporary file. "
+ "Please check the permissions of the current directory."
+ ) from e
+ return unique_secrets, lines
+
+ def get_modified_value(
+ self, unique_secrets: Dict[str, Any], lines: List[str]
+ ) -> str:
+ """Replace the secrets on the lines with asterisks.
+
+ Args:
+ unique_secrets (Dict[str, Any]): A dictionary of unique secrets and their
+ line numbers.
+ lines (List[str]): The lines of the generated code snippet.
+
+ Returns:
+ modified_value (str): The generated code snippet with secrets replaced with
+ asterisks.
+ """
+ # Replace the secrets on the lines with asterisks
+ for secret, line_numbers in unique_secrets.items():
+ for line_number in line_numbers:
+ lines[line_number - 1] = lines[line_number - 1].replace(
+ secret, self.mask
+ )
+
+ # Convert lines to a multiline string
+ modified_value = "".join(lines)
+ return modified_value
+
+ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
+        # detect-secrets scans line by line, so it works best on multiline snippets
+        if "\n" not in value:
+ warnings.warn(
+ "The DetectSecrets validator works best with "
+ "multiline code snippets. "
+ "Refer validator docs for more details."
+ )
+
+ # Add a newline to value
+ value += "\n"
+
+ # Get unique secrets from the value
+ unique_secrets, lines = self.get_unique_secrets(value)
+
+ if unique_secrets:
+ # Replace the secrets on the lines with asterisks
+ modified_value = self.get_modified_value(unique_secrets, lines)
+
+ return FailResult(
+ error_message=(
+ "The following secrets were detected in your response:\n"
+ + "\n".join(unique_secrets.keys())
+ ),
+ fix_value=modified_value,
+ )
+ return PassResult()
diff --git a/guardrails/validators/__init__.py b/guardrails/validators/__init__.py
index 81bac20c1..b42eea66f 100644
--- a/guardrails/validators/__init__.py
+++ b/guardrails/validators/__init__.py
@@ -13,6 +13,7 @@
)
from guardrails.validators.bug_free_python import BugFreePython
from guardrails.validators.bug_free_sql import BugFreeSQL
+from guardrails.validators.competitor_check import CompetitorCheck
from guardrails.validators.detect_secrets import DetectSecrets, detect_secrets
from guardrails.validators.endpoint_is_reachable import EndpointIsReachable
from guardrails.validators.ends_with import EndsWith
@@ -75,6 +76,7 @@
"PIIFilter",
"SimilarToList",
"DetectSecrets",
+ "CompetitorCheck",
# Validator helpers
"detect_secrets",
"AnalyzerEngine",
diff --git a/guardrails/validators/competitor_check.py b/guardrails/validators/competitor_check.py
new file mode 100644
index 000000000..714918527
--- /dev/null
+++ b/guardrails/validators/competitor_check.py
@@ -0,0 +1,174 @@
+import re
+from typing import Any, Callable, Dict, List, Optional
+
+from guardrails.logger import logger
+from guardrails.validators import (
+ FailResult,
+ PassResult,
+ ValidationResult,
+ Validator,
+ register_validator,
+)
+
+try:
+ import nltk # type: ignore
+except ImportError:
+ nltk = None # type: ignore
+
+if nltk is not None:
+ try:
+ nltk.data.find("tokenizers/punkt")
+ except LookupError:
+ nltk.download("punkt")
+
+try:
+ import spacy
+except ImportError:
+ spacy = None
+
+
+@register_validator(name="competitor-check", data_type="string")
+class CompetitorCheck(Validator):
+ """Validates that LLM-generated text is not naming any competitors from a
+ given list.
+
+    In order to use this validator you need to provide an extensive list of the
+    competitors you want to avoid naming, including all common variations.
+
+ Args:
+        competitors (List[str]): List of competitors you want to avoid naming.
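+
+    Example:
+        A minimal sketch (assumes `Guard` is imported from `guardrails`; the
+        competitor names and `llm_output` are placeholders):
+
+        ```py
+        guard = Guard.from_string(validators=[
+            CompetitorCheck(competitors=["Acme Corp", "Acme"], on_fail="fix")
+        ])
+        guard.parse(llm_output=llm_output)
+        ```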
+ """
+
+ def __init__(
+ self,
+ competitors: List[str],
+ on_fail: Optional[Callable] = None,
+ ):
+ super().__init__(competitors=competitors, on_fail=on_fail)
+ self._competitors = competitors
+ model = "en_core_web_trf"
+ if spacy is None:
+ raise ImportError(
+ "You must install spacy in order to use the CompetitorCheck validator."
+ )
+
+ if not spacy.util.is_package(model):
+ logger.info(
+ f"Spacy model {model} not installed. "
+ "Download should start now and take a few minutes."
+ )
+ spacy.cli.download(model) # type: ignore
+
+ self.nlp = spacy.load(model)
+
+ def exact_match(self, text: str, competitors: List[str]) -> List[str]:
+ """Performs exact match to find competitors from a list in a given
+ text.
+
+ Args:
+ text (str): The text to search for competitors.
+ competitors (list): A list of competitor entities to match.
+
+ Returns:
+ list: A list of matched entities.
+ """
+
+ found_entities = []
+ for entity in competitors:
+            pattern = rf"\b{re.escape(entity)}\b"
+            match = re.search(pattern, text, re.IGNORECASE)
+ if match:
+ found_entities.append(entity)
+ return found_entities
+
+ def perform_ner(self, text: str, nlp) -> List[str]:
+ """Performs named entity recognition on text using a provided NLP
+ model.
+
+ Args:
+ text (str): The text to perform named entity recognition on.
+ nlp: The NLP model to use for entity recognition.
+
+ Returns:
+ entities: A list of entities found.
+ """
+
+ doc = nlp(text)
+ entities = []
+ for ent in doc.ents:
+ entities.append(ent.text)
+ return entities
+
+ def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List:
+ """Checks if any entity from a list is present in a given list of
+ competitors.
+
+ Args:
+ entities (list): A list of entities to check
+ competitors (list): A list of competitor names to match
+
+ Returns:
+ List: List of found competitors
+ """
+
+ found_competitors = []
+ for entity in entities:
+ for item in competitors:
+                pattern = rf"\b{re.escape(item)}\b"
+                match = re.search(pattern, entity, re.IGNORECASE)
+ if match:
+ found_competitors.append(item)
+ return found_competitors
+
+    def validate(self, value: str, metadata: Dict) -> ValidationResult:
+ """Checks a text to find competitors' names in it.
+
+ While running, store sentences naming competitors and generate a fixed output
+ filtering out all flagged sentences.
+
+ Args:
+ value (str): The value to be validated.
+            metadata (Dict): Additional metadata.
+
+ Returns:
+ ValidationResult: The validation result.
+ """
+
+ if nltk is None:
+ raise ImportError(
+ "`nltk` library is required for `competitors-check` validator. "
+ "Please install it with `poetry add nltk`."
+ )
+ sentences = nltk.sent_tokenize(value)
+ flagged_sentences = []
+ filtered_sentences = []
+ list_of_competitors_found = []
+
+ for sentence in sentences:
+ entities = self.exact_match(sentence, self._competitors)
+ if entities:
+ ner_entities = self.perform_ner(sentence, self.nlp)
+ found_competitors = self.is_entity_in_list(ner_entities, entities)
+
+ if found_competitors:
+ flagged_sentences.append((found_competitors, sentence))
+ list_of_competitors_found.append(found_competitors)
+ logger.debug(f"Found: {found_competitors} named in '{sentence}'")
+ else:
+ filtered_sentences.append(sentence)
+
+ else:
+ filtered_sentences.append(sentence)
+
+ filtered_output = " ".join(filtered_sentences)
+
+        if flagged_sentences:
+            return FailResult(
+                error_message=(
+                    f"Found the following competitors: {list_of_competitors_found}. "
+                    "Please avoid naming those competitors next time."
+ ),
+ fix_value=filtered_output,
+ )
+ else:
+ return PassResult()