diff --git a/docs/reference/functions.md b/docs/reference/functions.md deleted file mode 100644 index d29213a89..000000000 --- a/docs/reference/functions.md +++ /dev/null @@ -1 +0,0 @@ -# Outlines functions diff --git a/docs/reference/regex_dsl.md b/docs/reference/regex_dsl.md new file mode 100644 index 000000000..fc86262c7 --- /dev/null +++ b/docs/reference/regex_dsl.md @@ -0,0 +1,229 @@ +# DSL to express constraints + +This library provides a Domain-Specific Language (DSL) to construct regular expressions in a more intuitive and modular way. It allows you to create complex regexes using simple building blocks that represent literal strings, patterns, and various quantifiers. Additionally, these custom regex types can be used directly as types in [Pydantic](https://pydantic-docs.helpmanual.io/) schemas to enforce pattern constraints on your data. + +--- + +## Why Use This DSL? + +1. **Modularity & Readability**: Instead of writing cryptic regular expression strings, you compose a regex as a tree of objects. +2. **Enhanced Debugging**: Each expression can be visualized as an ASCII tree, making it easier to understand and debug complex regexes. +3. **Pydantic Integration**: Use your DSL-defined regex as types in Pydantic models. The DSL seamlessly converts to JSON Schema with proper pattern constraints. +4. **Extensibility**: Easily add or modify quantifiers and other regex components by extending the provided classes. + +--- + +## Building Blocks + + +Every regex component in this DSL is a **Term**. Here are two primary types: + +- **`String`**: Represents a literal string. +- **`Regex`**: Represents an existing regex pattern string. + +```python +from outlines.types import String, Regex + +# A literal string "hello" +literal = String("hello") # Internally represents "hello" + +# A regex pattern to match one or more digits +digit = Regex(r"[0-9]+") # Internally represents the pattern [0-9]+ + +# Converting to standard regex strings: +from outlines.types.regex import to_regex + +print(to_regex(literal)) # Output: hello +print(to_regex(digit)) # Output: [0-9]+ +``` + +--- + +## Early Introduction to Quantifiers & Operators + +The DSL supports common regex quantifiers as methods on every `Term`. These methods allow you to specify how many times a pattern should be matched. They include: + +- **`times(count)`**: Matches the term exactly `count` times. +- **`optional()`**: Matches the term zero or one time. +- **`one_or_more()`**: Matches the term one or more times (Kleene Plus). +- **`zero_or_more()`**: Matches the term zero or more times (Kleene Star). +- **`repeat(min_count, max_count)`**: Matches the term between `min_count` and `max_count` times (or open-ended if one value is omitted). + +Let’s see these quantifiers side by side with examples. + +### Quantifiers in Action + +#### `times(count)` + +This method restricts the term to appear exactly `count` times. + +```python +# Example: exactly 5 digits +five_digits = Regex(r"\d").times(5) +print(to_regex(five_digits)) # Output: (\d){5} +``` + +#### `optional()` + +The `optional()` method makes a term optional, meaning it may occur zero or one time. + +```python +# Example: an optional "s" at the end of a word +maybe_s = String("s").optional() +print(to_regex(maybe_s)) # Output: (s)? +``` + +#### `one_or_more()` + +This method indicates that the term must appear at least once. 
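+
+Each quantifier method is a thin wrapper that builds one node of the expression tree: `one_or_more()` wraps its term in a `KleenePlus` node, `zero_or_more()` in a `KleeneStar`, and `optional()` in an `Optional`. A minimal sketch, using the node classes from `outlines.types.dsl`, shows that constructing the node directly is equivalent to the chained method call:
+
+```python
+from outlines.types.dsl import KleenePlus, String
+from outlines.types.regex import to_regex
+
+# Building the node directly...
+direct = KleenePlus(String("ab"))
+# ...is equivalent to calling the quantifier method
+chained = String("ab").one_or_more()
+
+print(to_regex(direct))   # Output: (ab)+
+print(to_regex(chained))  # Output: (ab)+
+```
+
+The `one_or_more()` example below builds exactly this kind of node.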
+ +```python +# Example: one or more alphabetic characters +letters = Regex(r"[A-Za-z]").one_or_more() +print(to_regex(letters)) # Output: ([A-Za-z])+ +``` + +#### `zero_or_more()` + +This method means that the term can occur zero or more times. + +```python +# Example: zero or more spaces +spaces = String(" ").zero_or_more() +print(to_regex(spaces)) # Output: ( )* +``` + +#### `repeat(min_count, max_count)` + +The `repeat` method provides flexibility to set a lower and/or upper bound on the number of occurrences. + +```python +# Example: Between 2 and 4 word characters +word_chars = Regex(r"\w").repeat(2, 4) +print(to_regex(word_chars)) # Output: (\w){2,4} + +# Example: At least 3 digits (min specified, max left open) +at_least_three = Regex(r"\d").repeat(3, None) +print(to_regex(at_least_three)) # Output: (\d){3,} + +# Example: Up to 2 punctuation marks (max specified, min omitted) +up_to_two = Regex(r"[,.]").repeat(None, 2) +print(to_regex(up_to_two)) # Output: ([,.]){,2} +``` + +--- + +## Combining Terms + +The DSL allows you to combine basic terms into more complex patterns using concatenation and alternation. + +### Concatenation (`+`) + +The `+` operator (and its reflected variant) concatenates terms, meaning that the terms are matched in sequence. + +```python +# Example: Match "hello world" +pattern = String("hello") + " " + String("world") +print(to_regex(pattern)) # Output: hello\ world +``` + +### Alternation (`|`) + +The `|` operator creates alternatives, allowing a match for one of several patterns. + +```python +# Example: Match either "cat" or "dog" +animal = String("cat") | "dog" +print(to_regex(animal)) # Output: (cat|dog) +``` + +*Note:* When using operators with plain strings (such as `"dog"`), the DSL automatically wraps them in a `String` object. + +--- + +## Practical Examples + +### Example 1: Matching a Custom ID Format + +Suppose you want to create a regex that matches an ID format like "ID-12345", where: +- The literal "ID-" must be at the start. +- Followed by exactly 5 digits. + +```python +id_pattern = "ID-" + Regex(r"\d").times(5) +print(to_regex(id_pattern)) # Output: ID-(\d){5} +``` + +### Example 2: Email Validation with Pydantic + +You can define a regex for email validation and use it as a type in a Pydantic model. + +```python +from pydantic import BaseModel, ValidationError + +# Define an email regex term (this is a simplified version) +email_regex = Regex(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+") + +class User(BaseModel): + name: str + email: email_regex # Use our DSL regex as a field type + +# Valid input +user = User(name="Alice", email="alice@example.com") +print(user) + +# Invalid input (raises a ValidationError) +try: + User(name="Bob", email="not-an-email") +except ValidationError as e: + print(e) +``` + +When used in a Pydantic model, the email field is automatically validated against the regex pattern and its JSON Schema includes the `pattern` constraint. + +### Example 3: Building a Complex Pattern + +Consider a pattern to match a simple date format: `YYYY-MM-DD`. 
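+
+Each component can be checked on its own before the full pattern is assembled: every term exposes a `matches()` helper (defined on the `Term` base class) that tests whether a value belongs to the language the term describes. A small sketch:
+
+```python
+from outlines.types import Regex
+
+year = Regex(r"\d").times(4)  # four digits for a year component
+
+print(year.matches("2024"))   # True: exactly four digits
+print(year.matches("24"))     # False: too short
+print(year.matches("20245"))  # False: too long
+```
+
+The full date pattern is then assembled from such components: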
+ +```python +year = Regex(r"\d").times(4) # Four digits for the year +month = Regex(r"\d").times(2) # Two digits for the month +day = Regex(r"\d").times(2) # Two digits for the day + +# Combine with literal hyphens +date_pattern = year + "-" + month + "-" + day +print(to_regex(date_pattern)) +# Output: (\d){4}\-(\d){2}\-(\d){2} +``` + +--- + +## Visualizing Your Regex + +One of the unique features of this DSL is that each term can print its underlying structure as an ASCII tree. This visualization can be particularly helpful when dealing with complex expressions. + +```python +# A composite pattern using concatenation and quantifiers +pattern = "a" + String("b").one_or_more() + "c" +print(pattern) +``` + +*Expected Output:* + +``` +└── Sequence + ├── String('a') + ├── KleenePlus(+) + │ └── String('b') + └── String('c') +``` + +This tree representation makes it easy to see the hierarchy and order of operations in your regular expression. + +--- + +## Final Words + +This DSL is designed to simplify the creation and management of regular expressions—whether you're validating inputs in a web API, constraining the output of an LLM, or just experimenting with regex patterns. With intuitive methods for common quantifiers and operators, clear visual feedback, and built-in integration with Pydantic, you can build robust and maintainable regex-based validations with ease. + +Feel free to explore the library further and adapt the examples to your use cases. Happy regexing! diff --git a/mkdocs.yml b/mkdocs.yml index 8d051d1af..2c31494da 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -133,6 +133,7 @@ nav: - Generation: - Overview: reference/generation/generation.md - Chat templating: reference/chat_templating.md + - Regex DSL: reference/regex_dsl.md - Text: reference/text.md - Samplers: reference/samplers.md - Structured generation: diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index cdb1d2115..d44a009ec 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -2,14 +2,9 @@ from enum import EnumMeta from typing import Any, Protocol, Tuple, Type -from typing_extensions import _AnnotatedAlias, get_args - -INTEGER = r"[+-]?(0|[1-9][0-9]*)" -BOOLEAN = "(True|False)" -FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?" -DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])" -TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])" -DATETIME = rf"({DATE})(\s)({TIME})" +from outlines.types import Regex, boolean, date +from outlines.types import datetime as datetime_type +from outlines.types import integer, number, time class FormatFunction(Protocol): @@ -17,18 +12,15 @@ def __call__(self, sequence: str) -> Any: ... 
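+
+# `python_types_to_regex` returns a DSL `Regex` term rather than a raw pattern string;
+# callers that need the plain string read its `.pattern` attribute (as
+# `outlines/generate/format.py` does). A rough sketch of the return contract for the
+# built-in `int` mapping:
+#
+#     regex_term, format_fn = python_types_to_regex(int)
+#     assert regex_term.pattern == r"[+-]?(0|[1-9][0-9]*)"
+#     assert format_fn("42") == 42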
-def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: +def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]: # If it is a custom type - if isinstance(python_type, _AnnotatedAlias): - json_schema = get_args(python_type)[1].json_schema - type_class = get_args(python_type)[0] - - custom_regex_str = json_schema["pattern"] + if isinstance(python_type, Regex): + custom_regex_str = python_type.pattern - def custom_format_fn(sequence: str) -> Any: - return type_class(sequence) + def custom_format_fn(sequence: str) -> str: + return str(sequence) - return custom_regex_str, custom_format_fn + return Regex(custom_regex_str), custom_format_fn if isinstance(python_type, EnumMeta): values = python_type.__members__.keys() @@ -37,44 +29,44 @@ def custom_format_fn(sequence: str) -> Any: def enum_format_fn(sequence: str) -> str: return str(sequence) - return enum_regex_str, enum_format_fn + return Regex(enum_regex_str), enum_format_fn if python_type is float: def float_format_fn(sequence: str) -> float: return float(sequence) - return FLOAT, float_format_fn + return number, float_format_fn elif python_type is int: def int_format_fn(sequence: str) -> int: return int(sequence) - return INTEGER, int_format_fn + return integer, int_format_fn elif python_type is bool: def bool_format_fn(sequence: str) -> bool: return bool(sequence) - return BOOLEAN, bool_format_fn + return boolean, bool_format_fn elif python_type == datetime.date: def date_format_fn(sequence: str) -> datetime.date: return datetime.datetime.strptime(sequence, "%Y-%m-%d").date() - return DATE, date_format_fn + return date, date_format_fn elif python_type == datetime.time: def time_format_fn(sequence: str) -> datetime.time: return datetime.datetime.strptime(sequence, "%H:%M:%S").time() - return TIME, time_format_fn + return time, time_format_fn elif python_type == datetime.datetime: def datetime_format_fn(sequence: str) -> datetime.datetime: return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") - return DATETIME, datetime_format_fn + return datetime_type, datetime_format_fn else: raise NotImplementedError( f"The Python type {python_type} is not supported. Please open an issue." diff --git a/outlines/generate/format.py b/outlines/generate/format.py index 88acec75f..56ed10ac1 100644 --- a/outlines/generate/format.py +++ b/outlines/generate/format.py @@ -33,6 +33,7 @@ def format( """ regex_str, format_fn = python_types_to_regex(python_type) + regex_str = regex_str.pattern generator = regex(model, regex_str, sampler) generator.format_sequence = format_fn diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 673880e49..326701f4b 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -6,10 +6,11 @@ ) from outlines.models import OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial +from outlines.types import Regex @singledispatch -def regex(model, regex_str: str, sampler: Sampler = multinomial()): +def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()): """Generate structured text in the language of a regular expression. 
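+
+    The `regex_str` argument may be either a plain regular expression string or a DSL
+    `Regex` term; a term is unwrapped to its underlying `.pattern` before the logits
+    processor is built.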
Parameters @@ -31,6 +32,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): """ from outlines.processors import RegexLogitsProcessor + if isinstance(regex_str, Regex): + regex_str = regex_str.pattern + logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer) return SequenceGeneratorAdapter(model, logits_processor, sampler) @@ -38,11 +42,14 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): @regex.register(TransformersVision) def regex_vision( model, - regex_str: str, + regex_str: str | Regex, sampler: Sampler = multinomial(), ): from outlines.processors import RegexLogitsProcessor + if isinstance(regex_str, Regex): + regex_str = regex_str.pattern + logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer) return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index f4d2b8cd3..76885da0d 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,4 +1,39 @@ -from . import airports, countries -from .email import Email -from .isbn import ISBN -from .locales import locale +from enum import Enum + +from . import airports, countries, locale +from .dsl import Regex, json_schema, one_or_more, optional, regex, repeat, zero_or_more + +# Python types +integer = Regex(r"[+-]?(0|[1-9][0-9]*)") +boolean = Regex("(True|False)") +number = Regex(rf"{integer.pattern}(\.[0-9]+)?([eE][+-][0-9]+)?") +date = Regex(r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])") +time = Regex(r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])") +datetime = Regex(rf"({date.pattern})(\s)({time.pattern})") + +# Basic regex types +digit = Regex(r"\d") +char = Regex(r"\w") +newline = Regex(r"(\r\n|\r|\n)") # Matched new lines on Linux, Windows & MacOS +whitespace = Regex(r"\s") + +# Document-specific types +sentence = Regex(r"[A-Z].*\s*[.!?]") +paragraph = Regex(rf"{sentence.pattern}(?:\s+{sentence.pattern})*\n+") + + +# The following regex is FRC 5322 compliant and was found at: +# https://emailregex.com/ +email = Regex( + r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" +) + +# Matches any ISBN number. Note that this is not completely correct as not all +# 10 or 13 digits numbers are valid ISBNs. See https://en.wikipedia.org/wiki/ISBN +# Taken from O'Reilly's Regular Expression Cookbook: +# https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s13.html +# +# TODO: The check digit can only be computed by calling a function to compute it dynamically +isbn = Regex( + r"(?:ISBN(?:-1[03])?:? 
)?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]" +) diff --git a/outlines/types/airports.py b/outlines/types/airports.py index 934ae1844..ec0ef72bd 100644 --- a/outlines/types/airports.py +++ b/outlines/types/airports.py @@ -6,6 +6,4 @@ AIRPORT_IATA_LIST = [ (v["iata"], v["iata"]) for v in airportsdata.load().values() if v["iata"] ] - - IATA = Enum("Airport", AIRPORT_IATA_LIST) # type:ignore diff --git a/outlines/types/dsl.py b/outlines/types/dsl.py new file mode 100644 index 000000000..c86136d28 --- /dev/null +++ b/outlines/types/dsl.py @@ -0,0 +1,359 @@ +import json as json +import re +from dataclasses import dataclass +from typing import Any, List, Union + +from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema as cs + + +class Term: + """Represents types defined with a regular expression. + + `Regex` instances can be used as a type in a Pydantic model definittion. + They will be translated to JSON Schema as a "string" field with the + "pattern" keyword set to the regular expression this class represents. The + class also handles validation. + + Examples + -------- + + >>> from outlines.types import Regex + >>> from pydantic import BaseModel + >>> + >>> age_type = Regex("[0-9]+") + >>> + >>> class User(BaseModel): + >>> name: str + >>> age: age_type + + """ + + def __add__(self: "Term", other: Union[str, "Term"]) -> "Sequence": + if isinstance(other, str): + other = String(other) + + return Sequence([self, other]) + + def __radd__(self: "Term", other: Union[str, "Term"]) -> "Sequence": + if isinstance(other, str): + other = String(other) + + return Sequence([other, self]) + + def __or__(self: "Term", other: Union[str, "Term"]) -> "Alternatives": + if isinstance(other, str): + other = String(other) + + return Alternatives([self, other]) + + def __ror__(self: "Term", other: Union[str, "Term"]) -> "Alternatives": + if isinstance(other, str): + other = String(other) + + return Alternatives([other, self]) + + def __get_validator__(self, _core_schema): + def validate(input_value): + return self.validate(input_value) + + return validate + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ) -> cs.CoreSchema: + return cs.no_info_plain_validator_function(lambda value: self.validate(value)) + + def __get_pydantic_json_schema__( + self, core_schema: cs.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + from outlines.types.regex import to_regex + + return {"type": "string", "pattern": to_regex(self)} + + def validate(self, value: str) -> str: + from outlines.types.regex import to_regex + + pattern = to_regex(self) + compiled = re.compile(pattern) + if not compiled.fullmatch(str(value)): + raise ValueError( + f"Input should be in the language of the regular expression {pattern}" + ) + return value + + def matches(self, value: str) -> bool: + """Check that a given value is in the language defined by the Term. + + We make the assumption that the language defined by the term can + be defined with a regular expression. 
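+
+ For example, `String("ab").matches("ab")` returns True while `String("ab").matches("abc")` returns False.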
+ + """ + from outlines.types.regex import to_regex + + pattern = to_regex(self) + compiled = re.compile(pattern) + if compiled.fullmatch(str(value)): + return True + return False + + def display_ascii_tree(self, indent="", is_last=True) -> str: + """Display the regex tree in ASCII format.""" + branch = "└── " if is_last else "├── " + result = indent + branch + self._display_node() + "\n" + + # Calculate the new indent for children + new_indent = indent + (" " if is_last else "│ ") + + # Let each subclass handle its children + result += self._display_children(new_indent) + return result + + def _display_node(self): + raise NotImplementedError + + def _display_children(self, indent: str) -> str: + """Display the children of this node. Override in subclasses with children.""" + return "" + + def __str__(self): + return self.display_ascii_tree() + + +@dataclass +class String(Term): + value: str + + def _display_node(self) -> str: + return f"String('{self.value}')" + + def __repr__(self): + return f"String(value='{self.value}')" + + +@dataclass +class Regex(Term): + pattern: str + + def _display_node(self) -> str: + return f"Regex('{self.pattern}')" + + def __repr__(self): + return f"Regex(pattern='{self.pattern}')" + + +class JsonSchema(Term): + def __init__(self, schema: Union[dict, str, type[BaseModel]]): + if isinstance(schema, dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + elif issubclass(schema, BaseModel): + schema_str = json.dumps(schema.model_json_schema()) + else: + raise ValueError( + f"Cannot parse schema {schema}. The schema must be either " + + "a Pydantic class, a dictionary or a string that contains the JSON " + + "schema specification" + ) + + self.schema = schema_str + + def _display_node(self) -> str: + return f"JsonSchema('{self.schema}')" + + def __repr__(self): + return f"JsonSchema(schema='{self.schema}')" + + +@dataclass +class KleeneStar(Term): + term: Term + + def _display_node(self) -> str: + return "KleeneStar(*)" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return f"KleeneStar(term={repr(self.term)})" + + +@dataclass +class KleenePlus(Term): + term: Term + + def _display_node(self) -> str: + return "KleenePlus(+)" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return f"KleenePlus(term={repr(self.term)})" + + +@dataclass +class Optional(Term): + term: Term + + def _display_node(self) -> str: + return "Optional(?)" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return f"Optional(term={repr(self.term)})" + + +@dataclass +class Alternatives(Term): + terms: List[Term] + + def _display_node(self) -> str: + return "Alternatives(|)" + + def _display_children(self, indent: str) -> str: + return "".join( + term.display_ascii_tree(indent, i == len(self.terms) - 1) + for i, term in enumerate(self.terms) + ) + + def __repr__(self): + return f"Alternatives(terms={repr(self.terms)})" + + +@dataclass +class Sequence(Term): + terms: List[Term] + + def _display_node(self) -> str: + return "Sequence" + + def _display_children(self, indent: str) -> str: + return "".join( + term.display_ascii_tree(indent, i == len(self.terms) - 1) + for i, term in enumerate(self.terms) + ) + + def __repr__(self): + return f"Sequence(terms={repr(self.terms)})" + + +@dataclass +class 
QuantifyExact(Term): + term: Term + count: int + + def _display_node(self) -> str: + return f"Quantify({{{self.count}}})" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return f"QuantifyExact(term={repr(self.term)}, count={repr(self.count)})" + + +@dataclass +class QuantifyMinimum(Term): + term: Term + min_count: int + + def _display_node(self) -> str: + return f"Quantify({{{self.min_count},}})" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return ( + f"QuantifyMinimum(term={repr(self.term)}, min_count={repr(self.min_count)})" + ) + + +@dataclass +class QuantifyMaximum(Term): + term: Term + max_count: int + + def _display_node(self) -> str: + return f"Quantify({{,{self.max_count}}})" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return ( + f"QuantifyMaximum(term={repr(self.term)}, max_count={repr(self.max_count)})" + ) + + +@dataclass +class QuantifyBetween(Term): + term: Term + min_count: int + max_count: int + + def __post_init__(self): + if self.min_count > self.max_count: + raise ValueError( + "QuantifyBetween: `max_count` must be greater than `min_count`." + ) + + def _display_node(self) -> str: + return f"Quantify({{{self.min_count},{self.max_count}}})" + + def _display_children(self, indent: str) -> str: + return self.term.display_ascii_tree(indent, True) + + def __repr__(self): + return f"QuantifyBetween(term={repr(self.term)}, min_count={repr(self.min_count)}, max_count={repr(self.max_count)})" + + +def optional(self: Term) -> Optional: + return Optional(self) + + +def one_or_more(self: Term) -> KleenePlus: + return KleenePlus(self) + + +def repeat(self: Term, min_count: int, max_count: int) -> QuantifyBetween: + match (min_count, max_count): + case (None, None): + raise ValueError( + "repeat: you must provide a value for at least `min_count` or `max_count`" + ) + case (_, None): + return QuantifyMinimum(self, min_count) + case (None, _): + return QuantifyMaximum(self, max_count) + case _: + return QuantifyBetween(self, min_count, max_count) + + +def times(self: Term, count: int = 0) -> QuantifyExact: + return QuantifyExact(self, count) + + +def zero_or_more(self: Term) -> KleeneStar: + return KleeneStar(self) + + +Term.one_or_more = one_or_more # type: ignore +Term.optional = optional # type: ignore +Term.repeat = repeat # type: ignore +Term.times = times # type: ignore +Term.zero_or_more = zero_or_more # type: ignore + + +def regex(pattern: str): + return Regex(pattern) + + +def json_schema(schema: Union[str, dict, type[BaseModel]]): + return JsonSchema(schema) diff --git a/outlines/types/email.py b/outlines/types/email.py deleted file mode 100644 index 45f8c4b2c..000000000 --- a/outlines/types/email.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Email Address types.""" -from pydantic import WithJsonSchema -from typing_extensions import Annotated - -# Taken from StackOverflow -# https://stackoverflow.com/a/201378/14773537 -EMAIL_REGEX = 
r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" -Email = Annotated[ - str, - WithJsonSchema({"type": "string", "pattern": EMAIL_REGEX}), -] diff --git a/outlines/types/isbn.py b/outlines/types/isbn.py deleted file mode 100644 index 5aebb067e..000000000 --- a/outlines/types/isbn.py +++ /dev/null @@ -1,12 +0,0 @@ -"""ISBN type""" -from pydantic import WithJsonSchema -from typing_extensions import Annotated - -# Matches any ISBN number. Note that this is not completely correct as not all -# 10 or 13 digits numbers are valid ISBNs. See https://en.wikipedia.org/wiki/ISBN -# Taken from O'Reilly's Regular Expression Cookbook: -# https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s13.html -# TODO: Can this be represented by a grammar or do we need semantic checks? -ISBN_REGEX = r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]" - -ISBN = Annotated[str, WithJsonSchema({"type": "string", "pattern": ISBN_REGEX})] diff --git a/outlines/types/locale/__init__.py b/outlines/types/locale/__init__.py new file mode 100644 index 000000000..511631d84 --- /dev/null +++ b/outlines/types/locale/__init__.py @@ -0,0 +1 @@ +from . import us diff --git a/outlines/types/locale/us.py b/outlines/types/locale/us.py new file mode 100644 index 000000000..0bda82d44 --- /dev/null +++ b/outlines/types/locale/us.py @@ -0,0 +1,4 @@ +from outlines.types.dsl import Regex + +zip_code = Regex(r"\d{5}(?:-\d{4})?") +phone_number = Regex(r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}") diff --git a/outlines/types/locales.py b/outlines/types/locales.py deleted file mode 100644 index c5d251bae..000000000 --- a/outlines/types/locales.py +++ /dev/null @@ -1,21 +0,0 @@ -from dataclasses import dataclass - -from outlines.types.phone_numbers import USPhoneNumber -from outlines.types.zip_codes import USZipCode - - -@dataclass -class US: - ZipCode = USZipCode - PhoneNumber = USPhoneNumber - - -def locale(locale_str: str): - locales = {"us": US} - - if locale_str not in locales: - raise NotImplementedError( - f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for you locale and open a Pull Request." - ) - - return locales[locale_str] diff --git a/outlines/types/phone_numbers.py b/outlines/types/phone_numbers.py deleted file mode 100644 index 618687e75..000000000 --- a/outlines/types/phone_numbers.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Phone number types. - -We currently only support US phone numbers. We can however imagine having custom types -for each country, for instance leveraging the `phonenumbers` library. 
- -""" -from pydantic import WithJsonSchema -from typing_extensions import Annotated - -US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" - - -USPhoneNumber = Annotated[ - str, - WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), -] diff --git a/outlines/types/regex.py b/outlines/types/regex.py new file mode 100644 index 000000000..6f340118c --- /dev/null +++ b/outlines/types/regex.py @@ -0,0 +1,48 @@ +import re +from typing import Any, Union +import typing +import json + +from outlines_core.fsm.json_schema import build_regex_from_schema + +from .dsl import String, Regex, KleeneStar, KleenePlus, Optional, Alternatives, Sequence, QuantifyExact, QuantifyBetween, QuantifyMinimum, QuantifyMaximum, Term, JsonSchema + + +def to_regex(term: Term) -> str: + """Convert a term to a regular expression. + + We only consider self-contained terms that do not refer to another rule. + + """ + match term: + case String(): + return re.escape(term.value) + case Regex(): + return f"({term.pattern})" + case JsonSchema(): + regex_str = build_regex_from_schema(term.schema) + return f"({regex_str})" + case KleeneStar(): + return f"({to_regex(term.term)})*" + case KleenePlus(): + return f"({to_regex(term.term)})+" + case Optional(): + return f"({to_regex(term.term)})?" + case Alternatives(): + regexes = [to_regex(subterm) for subterm in term.terms] + return f"({'|'.join(regexes)})" + case Sequence(): + regexes = [to_regex(subterm) for subterm in term.terms] + return f"{''.join(regexes)}" + case QuantifyExact(): + return f"({to_regex(term.term)}){{{term.count}}}" + case QuantifyMinimum(): + return f"({to_regex(term.term)}){{{term.min_count},}}" + case QuantifyMaximum(): + return f"({to_regex(term.term)}){{,{term.max_count}}}" + case QuantifyBetween(): + return f"({to_regex(term.term)}){{{term.min_count},{term.max_count}}}" + case _: + raise TypeError( + f"Cannot convert object {repr(term)} to a regular expression." + ) diff --git a/outlines/types/zip_codes.py b/outlines/types/zip_codes.py deleted file mode 100644 index 67d994d5c..000000000 --- a/outlines/types/zip_codes.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Zip code types. - -We currently only support US Zip Codes. - -""" -from pydantic import WithJsonSchema -from typing_extensions import Annotated - -# This matches Zip and Zip+4 codes -US_ZIP_CODE = r"\d{5}(?:-\d{4})?" 
- - -USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/pyproject.toml b/pyproject.toml index e76f4f7b9..e60205242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ module = [ "cloudpickle.*", "diskcache.*", "pydantic.*", + "pydantic_core.*", "pytest", "referencing.*", "torch.*", diff --git a/tests/fsm/test_types.py b/tests/fsm/test_types.py index d5450434c..4017a817c 100644 --- a/tests/fsm/test_types.py +++ b/tests/fsm/test_types.py @@ -1,29 +1,22 @@ -import datetime +import datetime as pydatetime import pytest -from outlines.fsm.types import ( - BOOLEAN, - DATE, - DATETIME, - FLOAT, - INTEGER, - TIME, - python_types_to_regex, -) +from outlines.fsm.types import python_types_to_regex +from outlines import types @pytest.mark.parametrize( - "python_type,regex", + "python_type,custom_type", [ - (int, INTEGER), - (float, FLOAT), - (bool, BOOLEAN), - (datetime.date, DATE), - (datetime.time, TIME), - (datetime.datetime, DATETIME), + (int, types.integer), + (float, types.number), + (bool, types.boolean), + (pydatetime.date, types.date), + (pydatetime.time, types.time), + (pydatetime.datetime, types.datetime), ], ) -def test_python_types(python_type, regex): +def test_python_types(python_type, custom_type): test_regex, _ = python_types_to_regex(python_type) - assert regex == test_regex + assert custom_type.pattern == test_regex.pattern diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index 5e60348b2..000000000 --- a/tests/test_types.py +++ /dev/null @@ -1,103 +0,0 @@ -import re - -import pytest -from pydantic import BaseModel - -from outlines import types -from outlines.fsm.types import python_types_to_regex - - -@pytest.mark.parametrize( - "custom_type,test_string,should_match", - [ - (types.phone_numbers.USPhoneNumber, "12", False), - (types.phone_numbers.USPhoneNumber, "(123) 123-1234", True), - (types.phone_numbers.USPhoneNumber, "123-123-1234", True), - (types.zip_codes.USZipCode, "12", False), - (types.zip_codes.USZipCode, "12345", True), - (types.zip_codes.USZipCode, "12345-1234", True), - (types.ISBN, "ISBN 0-1-2-3-4-5", False), - (types.ISBN, "ISBN 978-0-596-52068-7", True), - # (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit - (types.ISBN, "ISBN-13: 978-0-596-52068-7", True), - (types.ISBN, "978 0 596 52068 7", True), - (types.ISBN, "9780596520687", True), - (types.ISBN, "ISBN-10: 0-596-52068-9", True), - (types.ISBN, "0-596-52068-9", True), - (types.Email, "eitan@gmail.com", True), - (types.Email, "99@yahoo.com", True), - (types.Email, "eitan@.gmail.com", False), - (types.Email, "myemail", False), - (types.Email, "eitan@gmail", False), - (types.Email, "eitan@my.custom.domain", True), - ], -) -def test_type_regex(custom_type, test_string, should_match): - class Model(BaseModel): - attr: custom_type - - schema = Model.model_json_schema() - assert schema["properties"]["attr"]["type"] == "string" - regex_str = schema["properties"]["attr"]["pattern"] - does_match = re.match(regex_str, test_string) is not None - assert does_match is should_match - - regex_str, format_fn = python_types_to_regex(custom_type) - assert isinstance(format_fn(1), str) - does_match = re.match(regex_str, test_string) is not None - assert does_match is should_match - - -def test_locale_not_implemented(): - with pytest.raises(NotImplementedError): - types.locale("fr") - - -@pytest.mark.parametrize( - "locale_str,base_types,locale_types", - [ - ( - "us", - ["ZipCode", "PhoneNumber"], - 
[types.zip_codes.USZipCode, types.phone_numbers.USPhoneNumber], - ) - ], -) -def test_locale(locale_str, base_types, locale_types): - for base_type, locale_type in zip(base_types, locale_types): - type = getattr(types.locale(locale_str), base_type) - assert type == locale_type - - -@pytest.mark.parametrize( - "custom_type,test_string,should_match", - [ - (types.airports.IATA, "CDG", True), - (types.airports.IATA, "XXX", False), - (types.countries.Alpha2, "FR", True), - (types.countries.Alpha2, "XX", False), - (types.countries.Alpha3, "UKR", True), - (types.countries.Alpha3, "XXX", False), - (types.countries.Numeric, "004", True), - (types.countries.Numeric, "900", False), - (types.countries.Name, "Ukraine", True), - (types.countries.Name, "Wonderland", False), - (types.countries.Flag, "🇿🇼", True), - (types.countries.Flag, "🤗", False), - ], -) -def test_type_enum(custom_type, test_string, should_match): - type_name = custom_type.__name__ - - class Model(BaseModel): - attr: custom_type - - schema = Model.model_json_schema() - assert isinstance(schema["$defs"][type_name]["enum"], list) - does_match = test_string in schema["$defs"][type_name]["enum"] - assert does_match is should_match - - regex_str, format_fn = python_types_to_regex(custom_type) - assert isinstance(format_fn(1), str) - does_match = re.match(regex_str, test_string) is not None - assert does_match is should_match diff --git a/tests/types/test_custom_types.py b/tests/types/test_custom_types.py new file mode 100644 index 000000000..f2686f1bf --- /dev/null +++ b/tests/types/test_custom_types.py @@ -0,0 +1,110 @@ +import re + +import pytest +from pydantic import BaseModel + +from outlines import types +from outlines.fsm.types import python_types_to_regex + + +@pytest.mark.parametrize( + "custom_type,test_string,should_match", + [ + (types.locale.us.phone_number, "12", False), + (types.locale.us.phone_number, "(123) 123-1234", True), + (types.locale.us.phone_number, "123-123-1234", True), + (types.locale.us.zip_code, "12", False), + (types.locale.us.zip_code, "12345", True), + (types.locale.us.zip_code, "12345-1234", True), + (types.isbn, "ISBN 0-1-2-3-4-5", False), + (types.isbn, "ISBN 978-0-596-52068-7", True), + (types.isbn, "ISBN-13: 978-0-596-52068-7", True), + (types.isbn, "978 0 596 52068 7", True), + (types.isbn, "9780596520687", True), + (types.isbn, "ISBN-10: 0-596-52068-9", True), + (types.isbn, "0-596-52068-9", True), + (types.email, "eitan@gmail.com", True), + (types.email, "99@yahoo.com", True), + (types.email, "eitan@.gmail.com", False), + (types.email, "myemail", False), + (types.email, "eitan@gmail", False), + (types.email, "eitan@my.custom.domain", True), + (types.integer, "-19", True), + (types.integer, "19", True), + (types.integer, "019", False), + (types.integer, "1.9", False), + (types.integer, "a", False), + (types.boolean, "True", True), + (types.boolean, "False", True), + (types.boolean, "true", False), + (types.number, "10", True), + (types.number, "10.9", True), + (types.number, "10.9e+3", True), + (types.number, "10.9e-3", True), + (types.number, "a", False), + (types.date, "2022-03-23", True), + (types.date, "2022-03-32", False), + (types.date, "2022-13-23", False), + (types.date, "32-03-2022", False), + (types.time, "01:23:59", True), + (types.time, "01:23:61", False), + (types.time, "01:61:59", False), + (types.time, "24:23:59", False), + (types.sentence, "The temperature is 23.5 degrees !", True), + (types.sentence, "Did you earn $1,234.56 last month ?", True), + (types.sentence, "The #1 player 
scored 100 points .", True), + (types.sentence, "Hello @world, this is a test!", True), + (types.sentence, "invalid sentence.", False), + (types.sentence, "Invalid sentence", False), + (types.paragraph, "This is a paragraph!\n", True), + (types.paragraph, "Line1\nLine2", False), + (types.paragraph, "One sentence. Two sentences.\n\n", True), + (types.paragraph, "One sentence. invalid sentence.", False), + (types.paragraph, "One sentence. Invalid sentence\n", False), + ], +) +def test_type_regex(custom_type, test_string, should_match): + class Model(BaseModel): + attr: custom_type + + schema = Model.model_json_schema() + assert schema["properties"]["attr"]["type"] == "string" + regex_str = schema["properties"]["attr"]["pattern"] + does_match = re.fullmatch(regex_str, test_string) is not None + assert does_match is should_match + + regex_str = types.regex.to_regex(custom_type) + does_match = re.fullmatch(regex_str, test_string) is not None + assert does_match is should_match + + +@pytest.mark.parametrize( + "custom_type,test_string,should_match", + [ + (types.airports.IATA, "CDG", True), + (types.airports.IATA, "XXX", False), + (types.countries.Alpha2, "FR", True), + (types.countries.Alpha2, "XX", False), + (types.countries.Alpha3, "UKR", True), + (types.countries.Alpha3, "XXX", False), + (types.countries.Numeric, "004", True), + (types.countries.Numeric, "900", False), + (types.countries.Name, "Ukraine", True), + (types.countries.Name, "Wonderland", False), + (types.countries.Flag, "🇿🇼", True), + (types.countries.Flag, "🤗", False), + ], +) +def test_type_enum(custom_type, test_string, should_match): + type_name = custom_type.__name__ + + class Model(BaseModel): + attr: custom_type + + schema = Model.model_json_schema() + assert isinstance(schema["$defs"][type_name]["enum"], list) + does_match = test_string in schema["$defs"][type_name]["enum"] + assert does_match is should_match + + does_match = test_string in custom_type.__members__ + assert does_match is should_match diff --git a/tests/types/test_dsl.py b/tests/types/test_dsl.py new file mode 100644 index 000000000..dc2779ef1 --- /dev/null +++ b/tests/types/test_dsl.py @@ -0,0 +1,256 @@ +import pytest +from pydantic import BaseModel + +from outlines.types.dsl import ( + Alternatives, + JsonSchema, + KleenePlus, + KleeneStar, + Optional, + QuantifyBetween, + QuantifyExact, + QuantifyMaximum, + QuantifyMinimum, + Regex, + Sequence, + String, + Term, + one_or_more, + optional, + repeat, + times, + regex, + json_schema, + zero_or_more, +) + + +def test_dsl_init(): + string = String("test") + assert string.value == "test" + assert repr(string) == "String(value='test')" + + regex = Regex("[0-9]") + assert regex.pattern == "[0-9]" + assert repr(regex) == "Regex(pattern='[0-9]')" + + schema = JsonSchema('{ "type": "string" }') + assert schema.schema == '{ "type": "string" }' + assert repr(schema) == 'JsonSchema(schema=\'{ "type": "string" }\')' + + kleene_star = KleeneStar(string) + assert kleene_star.term == string + assert repr(kleene_star) == "KleeneStar(term=String(value='test'))" + + kleene_plus = KleenePlus(string) + assert kleene_plus.term == string + assert repr(kleene_plus) == "KleenePlus(term=String(value='test'))" + + optional = Optional(string) + assert optional.term == string + assert repr(optional) == "Optional(term=String(value='test'))" + + alternatives = Alternatives([string, regex]) + assert alternatives.terms[0] == string + assert alternatives.terms[1] == regex + assert ( + repr(alternatives) + == 
"Alternatives(terms=[String(value='test'), Regex(pattern='[0-9]')])" + ) + + sequence = Sequence([string, regex]) + assert sequence.terms[0] == string + assert sequence.terms[1] == regex + assert ( + repr(sequence) + == "Sequence(terms=[String(value='test'), Regex(pattern='[0-9]')])" + ) + + exact = QuantifyExact(string, 3) + assert exact.term == string + assert exact.count == 3 + assert repr(exact) == "QuantifyExact(term=String(value='test'), count=3)" + + minimum = QuantifyMinimum(string, 3) + assert minimum.term == string + assert minimum.min_count == 3 + assert repr(minimum) == "QuantifyMinimum(term=String(value='test'), min_count=3)" + + maximum = QuantifyMaximum(string, 3) + assert maximum.term == string + assert maximum.max_count == 3 + assert repr(maximum) == "QuantifyMaximum(term=String(value='test'), max_count=3)" + + between = QuantifyBetween(string, 1, 3) + assert between.term == string + assert between.min_count == 1 + assert between.max_count == 3 + assert ( + repr(between) + == "QuantifyBetween(term=String(value='test'), min_count=1, max_count=3)" + ) + + with pytest.raises( + ValueError, match="`max_count` must be greater than `min_count`" + ): + QuantifyBetween(string, 3, 1) + + +def test_dsl_operations(): + a = String("a") + b = String("b") + assert isinstance(a + b, Sequence) + assert (a + b).terms[0] == a + assert (a + b).terms[1] == b + + assert isinstance(a | b, Alternatives) + assert (a | b).terms[0] == a + assert (a | b).terms[1] == b + + +def test_dsl_operations_string_conversion(): + b = String("b") + sequence = "a" + b + assert isinstance(sequence, Sequence) + assert isinstance(sequence.terms[0], String) + assert sequence.terms[0].value == "a" + assert sequence.terms[1].value == "b" + + sequence = b + "a" + assert isinstance(sequence, Sequence) + assert isinstance(sequence.terms[0], String) + assert sequence.terms[0].value == "b" + assert sequence.terms[1].value == "a" + + alternative = "a" | b + assert isinstance(alternative, Alternatives) + assert isinstance(alternative.terms[0], String) + assert alternative.terms[0].value == "a" + assert alternative.terms[1].value == "b" + + alternative = b | "a" + assert isinstance(alternative, Alternatives) + assert isinstance(alternative.terms[0], String) + assert alternative.terms[0].value == "b" + assert alternative.terms[1].value == "a" + + +def test_dsl_aliases(): + test = regex("[0-9]") + assert isinstance(test, Regex) + + test = json_schema('{"type": "string"}') + assert isinstance(test, JsonSchema) + + test = String("test") + + assert isinstance(test.times(3), QuantifyExact) + assert test.times(3).count == 3 + assert test.times(3).term == test + + assert isinstance(times(test, 3), QuantifyExact) + assert times(test, 3).count == 3 + assert times(test, 3).term == test + + assert isinstance(test.one_or_more(), KleenePlus) + assert test.one_or_more().term == test + + assert isinstance(one_or_more(test), KleenePlus) + assert one_or_more(test).term == test + + assert isinstance(test.zero_or_more(), KleeneStar) + assert test.zero_or_more().term == test + + assert isinstance(zero_or_more(test), KleeneStar) + assert zero_or_more(test).term == test + + assert isinstance(test.optional(), Optional) + assert test.optional().term == test + + assert isinstance(optional(test), Optional) + assert optional(test).term == test + + rep_min = test.repeat(2, None) + assert isinstance(rep_min, QuantifyMinimum) + assert rep_min.min_count == 2 + + rep_min = repeat(test, 2, None) + assert isinstance(rep_min, QuantifyMinimum) + assert 
rep_min.min_count == 2 + + rep_max = test.repeat(None, 2) + assert isinstance(rep_max, QuantifyMaximum) + assert rep_max.max_count == 2 + + rep_max = repeat(test, None, 2) + assert isinstance(rep_max, QuantifyMaximum) + assert rep_max.max_count == 2 + + rep_between = test.repeat(1, 2) + assert isinstance(rep_between, QuantifyBetween) + assert rep_between.min_count == 1 + assert rep_between.max_count == 2 + + rep_between = repeat(test, 1, 2) + assert isinstance(rep_between, QuantifyBetween) + assert rep_between.min_count == 1 + assert rep_between.max_count == 2 + + with pytest.raises(ValueError, match="QuantifyBetween: `max_count` must be"): + test.repeat(2, 1) + + with pytest.raises(ValueError, match="QuantifyBetween: `max_count` must be"): + repeat(test, 2, 1) + + with pytest.raises(ValueError, match="repeat: you must provide"): + test.repeat(None, None) + + with pytest.raises(ValueError, match="repeat: you must provide"): + repeat(test, None, None) + + +def test_dsl_term_pydantic_simple(): + a = String("a") + + class Model(BaseModel): + field: a + + schema = Model.model_json_schema() + assert schema == { + "properties": {"field": {"pattern": "a", "title": "Field", "type": "string"}}, + "required": ["field"], + "title": "Model", + "type": "object", + } + + +def test_dsl_term_pydantic_combination(): + a = String("a") + b = String("b") + c = String("c") + + class Model(BaseModel): + field: (a + b) | c + + schema = Model.model_json_schema() + assert schema == { + "properties": { + "field": {"pattern": "(ab|c)", "title": "Field", "type": "string"} + }, + "required": ["field"], + "title": "Model", + "type": "object", + } + + +def test_dsl_display(): + a = String("a") + b = String("b") + c = Regex("[0-9]") + d = KleeneStar(a | b) + c + + tree = str(d) + assert ( + tree + == "└── Sequence\n ├── KleeneStar(*)\n │ └── Alternatives(|)\n │ ├── String('a')\n │ └── String('b')\n └── Regex('[0-9]')\n" + ) diff --git a/tests/types/test_to_regex.py b/tests/types/test_to_regex.py new file mode 100644 index 000000000..0dec02b63 --- /dev/null +++ b/tests/types/test_to_regex.py @@ -0,0 +1,72 @@ +import pytest + + +from outlines.types.regex import to_regex +from outlines.types.dsl import String, Regex, JsonSchema, KleeneStar, KleenePlus, QuantifyBetween, QuantifyExact, QuantifyMaximum, QuantifyMinimum, Sequence, Alternatives, Optional, Term + + +def test_to_regex_simple(): + a = String("a") + assert to_regex(a) == "a" + assert a.matches("a") is True + + a = Regex("[0-9]") + assert to_regex(a) == "([0-9])" + assert a.matches(0) is True + assert a.matches(10) is False + assert a.matches("a") is False + + a = JsonSchema({"type": "integer"}) + assert to_regex(a) == r"((-)?(0|[1-9][0-9]*))" + assert a.matches(1) is True + assert a.matches("1") is True + assert a.matches("a") is False + + a = Optional(String("a")) + assert to_regex(a) == "(a)?" 
+ assert a.matches("") is True + assert a.matches("a") is True + + a = KleeneStar(String("a")) + assert to_regex(a) == "(a)*" + assert a.matches("") is True + assert a.matches("a") is True + assert a.matches("aaaaa") is True + + a = KleenePlus(String("a")) + assert to_regex(a) == "(a)+" + assert a.matches("") is False + assert a.matches("a") is True + assert a.matches("aaaaa") is True + + a = QuantifyExact(String("a"), 2) + assert to_regex(a) == "(a){2}" + assert a.matches("a") is False + assert a.matches("aa") is True + assert a.matches("aaa") is False + + a = QuantifyMinimum(String("a"), 2) + assert to_regex(a) == "(a){2,}" + assert a.matches("a") is False + assert a.matches("aa") is True + assert a.matches("aaa") is True + + a = QuantifyMaximum(String("a"), 2) + assert to_regex(a) == "(a){,2}" + assert a.matches("aa") is True + assert a.matches("aaa") is False + + a = QuantifyBetween(String("a"), 1, 2) + assert to_regex(a) == "(a){1,2}" + assert a.matches("") is False + assert a.matches("a") is True + assert a.matches("aa") is True + assert a.matches("aaa") is False + + with pytest.raises(TypeError, match="Cannot convert"): + to_regex(Term()) + + +def test_to_regex_combinations(): + a = Sequence([Regex("dog|cat"), String("fish")]) + assert to_regex(a) == "(dog|cat)fish"
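Taken together, these changes let a DSL term be passed anywhere a plain regular-expression string was previously expected: `outlines.generate.regex` now accepts either a pattern string or a `Regex` term and unwraps the term's `.pattern` internally. The sketch below shows the intended end-to-end use; the model constructor and model name are illustrative placeholders, and the generated text depends on the model and prompt:

```python
import outlines
from outlines.types import Regex

# A DSL term for a US ZIP or ZIP+4 code (same pattern as outlines.types.locale.us.zip_code)
zip_code = Regex(r"\d{5}(?:-\d{4})?")

# Illustrative model choice; any model constructor supported by outlines.models would do
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")

# generate.regex accepts the Regex term directly
generator = outlines.generate.regex(model, zip_code)
print(generator("What is the ZIP code of the White House?\n"))  # e.g. "20500" (model-dependent)
```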