Skip to content

Commit

Permalink
Add regex DSL and re-organize custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 17, 2025
1 parent 774fe56 commit 6c18c6e
Show file tree
Hide file tree
Showing 23 changed files with 1,158 additions and 228 deletions.
1 change: 0 additions & 1 deletion docs/reference/functions.md

This file was deleted.

229 changes: 229 additions & 0 deletions docs/reference/regex_dsl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# DSL to express constraints

This library provides a Domain-Specific Language (DSL) to construct regular expressions in a more intuitive and modular way. It allows you to create complex regexes using simple building blocks that represent literal strings, patterns, and various quantifiers. Additionally, these custom regex types can be used directly as types in [Pydantic](https://pydantic-docs.helpmanual.io/) schemas to enforce pattern constraints on your data.

---

## Why Use This DSL?

1. **Modularity & Readability**: Instead of writing cryptic regular expression strings, you compose a regex as a tree of objects.
2. **Enhanced Debugging**: Each expression can be visualized as an ASCII tree, making it easier to understand and debug complex regexes.
3. **Pydantic Integration**: Use your DSL-defined regex as types in Pydantic models. The DSL seamlessly converts to JSON Schema with proper pattern constraints.
4. **Extensibility**: Easily add or modify quantifiers and other regex components by extending the provided classes.

---

## Building Blocks


Every regex component in this DSL is a **Term**. Here are two primary types:

- **`String`**: Represents a literal string.
- **`Regex`**: Represents an existing regex pattern string.

```python
from outlines.types import String, Regex

# A literal string "hello"
literal = String("hello") # Internally represents "hello"

# A regex pattern to match one or more digits
digit = Regex(r"[0-9]+") # Internally represents the pattern [0-9]+

# Converting to standard regex strings:
from outlines.types.regex import to_regex

print(to_regex(literal)) # Output: hello
print(to_regex(digit)) # Output: [0-9]+
```

---

## Early Introduction to Quantifiers & Operators

The DSL supports common regex quantifiers as methods on every `Term`. These methods allow you to specify how many times a pattern should be matched. They include:

- **`times(count)`**: Matches the term exactly `count` times.
- **`optional()`**: Matches the term zero or one time.
- **`one_or_more()`**: Matches the term one or more times (Kleene Plus).
- **`zero_or_more()`**: Matches the term zero or more times (Kleene Star).
- **`repeat(min_count, max_count)`**: Matches the term between `min_count` and `max_count` times (or open-ended if one value is omitted).

Let’s see these quantifiers side by side with examples.

### Quantifiers in Action

#### `times(count)`

This method restricts the term to appear exactly `count` times.

```python
# Example: exactly 5 digits
five_digits = Regex(r"\d").times(5)
print(to_regex(five_digits)) # Output: (\d){5}
```

#### `optional()`

The `optional()` method makes a term optional, meaning it may occur zero or one time.

```python
# Example: an optional "s" at the end of a word
maybe_s = String("s").optional()
print(to_regex(maybe_s)) # Output: (s)?
```

#### `one_or_more()`

This method indicates that the term must appear at least once.

```python
# Example: one or more alphabetic characters
letters = Regex(r"[A-Za-z]").one_or_more()
print(to_regex(letters)) # Output: ([A-Za-z])+
```

#### `zero_or_more()`

This method means that the term can occur zero or more times.

```python
# Example: zero or more spaces
spaces = String(" ").zero_or_more()
print(to_regex(spaces)) # Output: ( )*
```

#### `repeat(min_count, max_count)`

The `repeat` method provides flexibility to set a lower and/or upper bound on the number of occurrences.

```python
# Example: Between 2 and 4 word characters
word_chars = Regex(r"\w").repeat(2, 4)
print(to_regex(word_chars)) # Output: (\w){2,4}

# Example: At least 3 digits (min specified, max left open)
at_least_three = Regex(r"\d").repeat(3, None)
print(to_regex(at_least_three)) # Output: (\d){3,}

# Example: Up to 2 punctuation marks (max specified, min omitted)
up_to_two = Regex(r"[,.]").repeat(None, 2)
print(to_regex(up_to_two)) # Output: ([,.]){,2}
```

---

## Combining Terms

The DSL allows you to combine basic terms into more complex patterns using concatenation and alternation.

### Concatenation (`+`)

The `+` operator (and its reflected variant) concatenates terms, meaning that the terms are matched in sequence.

```python
# Example: Match "hello world"
pattern = String("hello") + " " + String("world")
print(to_regex(pattern)) # Output: hello\ world
```

### Alternation (`|`)

The `|` operator creates alternatives, allowing a match for one of several patterns.

```python
# Example: Match either "cat" or "dog"
animal = String("cat") | "dog"
print(to_regex(animal)) # Output: (cat|dog)
```

*Note:* When using operators with plain strings (such as `"dog"`), the DSL automatically wraps them in a `String` object.

---

## Practical Examples

### Example 1: Matching a Custom ID Format

Suppose you want to create a regex that matches an ID format like "ID-12345", where:
- The literal "ID-" must be at the start.
- Followed by exactly 5 digits.

```python
id_pattern = "ID-" + Regex(r"\d").times(5)
print(to_regex(id_pattern)) # Output: ID-(\d){5}
```

### Example 2: Email Validation with Pydantic

You can define a regex for email validation and use it as a type in a Pydantic model.

```python
from pydantic import BaseModel, ValidationError

# Define an email regex term (this is a simplified version)
email_regex = Regex(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")

class User(BaseModel):
name: str
email: email_regex # Use our DSL regex as a field type

# Valid input
user = User(name="Alice", email="[email protected]")
print(user)

# Invalid input (raises a ValidationError)
try:
User(name="Bob", email="not-an-email")
except ValidationError as e:
print(e)
```

When used in a Pydantic model, the email field is automatically validated against the regex pattern and its JSON Schema includes the `pattern` constraint.

### Example 3: Building a Complex Pattern

Consider a pattern to match a simple date format: `YYYY-MM-DD`.

```python
year = Regex(r"\d").times(4) # Four digits for the year
month = Regex(r"\d").times(2) # Two digits for the month
day = Regex(r"\d").times(2) # Two digits for the day

# Combine with literal hyphens
date_pattern = year + "-" + month + "-" + day
print(to_regex(date_pattern))
# Output: (\d){4}\-(\d){2}\-(\d){2}
```

---

## Visualizing Your Regex

One of the unique features of this DSL is that each term can print its underlying structure as an ASCII tree. This visualization can be particularly helpful when dealing with complex expressions.

```python
# A composite pattern using concatenation and quantifiers
pattern = "a" + String("b").one_or_more() + "c"
print(pattern)
```

*Expected Output:*

```
└── Sequence
├── String('a')
├── KleenePlus(+)
│ └── String('b')
└── String('c')
```

This tree representation makes it easy to see the hierarchy and order of operations in your regular expression.

---

## Final Words

This DSL is designed to simplify the creation and management of regular expressions—whether you're validating inputs in a web API, constraining the output of an LLM, or just experimenting with regex patterns. With intuitive methods for common quantifiers and operators, clear visual feedback, and built-in integration with Pydantic, you can build robust and maintainable regex-based validations with ease.

Feel free to explore the library further and adapt the examples to your use cases. Happy regexing!
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ nav:
- Generation:
- Overview: reference/generation/generation.md
- Chat templating: reference/chat_templating.md
- Regex DSL: reference/regex_dsl.md
- Text: reference/text.md
- Samplers: reference/samplers.md
- Structured generation:
Expand Down
40 changes: 16 additions & 24 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,25 @@
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args

INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])"
TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"
DATETIME = rf"({DATE})(\s)({TIME})"
from outlines.types import Regex, boolean, date
from outlines.types import datetime as datetime_type
from outlines.types import integer, number, time


class FormatFunction(Protocol):
def __call__(self, sequence: str) -> Any:
...


def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]:
# If it is a custom type
if isinstance(python_type, _AnnotatedAlias):
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]

custom_regex_str = json_schema["pattern"]
if isinstance(python_type, Regex):
custom_regex_str = python_type.pattern

def custom_format_fn(sequence: str) -> Any:
return type_class(sequence)
def custom_format_fn(sequence: str) -> str:
return str(sequence)

return custom_regex_str, custom_format_fn
return Regex(custom_regex_str), custom_format_fn

if isinstance(python_type, EnumMeta):
values = python_type.__members__.keys()
Expand All @@ -37,44 +29,44 @@ def custom_format_fn(sequence: str) -> Any:
def enum_format_fn(sequence: str) -> str:
return str(sequence)

return enum_regex_str, enum_format_fn
return Regex(enum_regex_str), enum_format_fn

if python_type is float:

def float_format_fn(sequence: str) -> float:
return float(sequence)

return FLOAT, float_format_fn
return number, float_format_fn
elif python_type is int:

def int_format_fn(sequence: str) -> int:
return int(sequence)

return INTEGER, int_format_fn
return integer, int_format_fn
elif python_type is bool:

def bool_format_fn(sequence: str) -> bool:
return bool(sequence)

return BOOLEAN, bool_format_fn
return boolean, bool_format_fn
elif python_type == datetime.date:

def date_format_fn(sequence: str) -> datetime.date:
return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()

return DATE, date_format_fn
return date, date_format_fn
elif python_type == datetime.time:

def time_format_fn(sequence: str) -> datetime.time:
return datetime.datetime.strptime(sequence, "%H:%M:%S").time()

return TIME, time_format_fn
return time, time_format_fn
elif python_type == datetime.datetime:

def datetime_format_fn(sequence: str) -> datetime.datetime:
return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")

return DATETIME, datetime_format_fn
return datetime_type, datetime_format_fn
else:
raise NotImplementedError(
f"The Python type {python_type} is not supported. Please open an issue."
Expand Down
1 change: 1 addition & 0 deletions outlines/generate/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def format(
"""
regex_str, format_fn = python_types_to_regex(python_type)
regex_str = regex_str.pattern
generator = regex(model, regex_str, sampler)
generator.format_sequence = format_fn

Expand Down
11 changes: 9 additions & 2 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
)
from outlines.models import OpenAI, TransformersVision
from outlines.samplers import Sampler, multinomial
from outlines.types import Regex


@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()):
"""Generate structured text in the language of a regular expression.
Parameters
Expand All @@ -31,18 +32,24 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
"""
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
def regex_vision(
model,
regex_str: str,
regex_str: str | Regex,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)

Expand Down
Loading

0 comments on commit 6c18c6e

Please sign in to comment.