Skip to content

Commit

Permalink
Add regex DSL and re-organize custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 7, 2025
1 parent 69418da commit b41b6e3
Show file tree
Hide file tree
Showing 20 changed files with 875 additions and 232 deletions.
44 changes: 18 additions & 26 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,25 @@
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args

INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])"
TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"
DATETIME = rf"({DATE})(\s)({TIME})"
from outlines.types import Regex, boolean, date
from outlines.types import datetime as datetime_type
from outlines.types import integer, number, time


class FormatFunction(Protocol):
    """Structural type for callables that post-process generated text.

    Any callable that accepts the generated string as ``sequence`` and
    returns a value (of any type) satisfies this protocol.
    """

    def __call__(self, sequence: str) -> Any:
        """Convert the generated ``sequence`` into the caller-facing value."""


def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
def python_types_to_regex(python_type: Type) -> Tuple[Regex, FormatFunction]:
# If it is a custom type
if isinstance(python_type, _AnnotatedAlias):
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]

custom_regex_str = json_schema["pattern"]
if isinstance(python_type, Regex):
custom_regex_str = python_type.pattern

def custom_format_fn(sequence: str) -> Any:
return type_class(sequence)
def custom_format_fn(sequence: str) -> str:
return str(sequence)

return custom_regex_str, custom_format_fn
return Regex(custom_regex_str), custom_format_fn

if isinstance(python_type, EnumMeta):
values = python_type.__members__.keys()
Expand All @@ -37,44 +29,44 @@ def custom_format_fn(sequence: str) -> Any:
def enum_format_fn(sequence: str) -> str:
return str(sequence)

return enum_regex_str, enum_format_fn
return Regex(enum_regex_str), enum_format_fn

if python_type is float:

def float_format_fn(sequence: str) -> float:
return float(sequence)

return FLOAT, float_format_fn
elif python_type is int:
return number, float_format_fn
elif python_type == int:

def int_format_fn(sequence: str) -> int:
return int(sequence)

return INTEGER, int_format_fn
elif python_type is bool:
return integer, int_format_fn
elif python_type == bool:

def bool_format_fn(sequence: str) -> bool:
    """Convert the generated literal ``"True"`` / ``"False"`` to a ``bool``.

    ``bool(sequence)`` is truthy for every non-empty string, so it would
    turn the text ``"False"`` into ``True``. The boolean pattern only admits
    the two literals ``True`` and ``False``, so comparing against ``"True"``
    yields the intended value.
    """
    return sequence == "True"

return BOOLEAN, bool_format_fn
return boolean, bool_format_fn
elif python_type == datetime.date:

def date_format_fn(sequence: str) -> datetime.date:
return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()

return DATE, date_format_fn
return date, date_format_fn
elif python_type == datetime.time:

def time_format_fn(sequence: str) -> datetime.time:
return datetime.datetime.strptime(sequence, "%H:%M:%S").time()

return TIME, time_format_fn
return time, time_format_fn
elif python_type == datetime.datetime:

def datetime_format_fn(sequence: str) -> datetime.datetime:
return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")

return DATETIME, datetime_format_fn
return datetime_type, datetime_format_fn
else:
raise NotImplementedError(
f"The Python type {python_type} is not supported. Please open an issue."
Expand Down
1 change: 1 addition & 0 deletions outlines/generate/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def format(
"""
regex_str, format_fn = python_types_to_regex(python_type)
regex_str = regex_str.pattern
generator = regex(model, regex_str, sampler)
generator.format_sequence = format_fn

Expand Down
11 changes: 9 additions & 2 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
)
from outlines.models import OpenAI, TransformersVision
from outlines.samplers import Sampler, multinomial
from outlines.types import Regex


@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()):
"""Generate structured text in the language of a regular expression.
Parameters
Expand All @@ -31,18 +32,24 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
"""
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
def regex_vision(
model,
regex_str: str,
regex_str: str | Regex,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)

Expand Down
35 changes: 31 additions & 4 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
from . import airports, countries
from .email import Email
from .isbn import ISBN
from .locales import locale
from enum import Enum

from . import airports, countries, locale
from .dsl import Regex, json_schema, one_or_more, optional, regex, repeat, zero_or_more

# Python types
# Each constant wraps a regular expression describing the textual form of a
# Python value.

# Optional sign, then either a single 0 or a digit run with no leading zero.
integer = Regex(r"[+-]?(0|[1-9][0-9]*)")
# Exactly the literal text `True` or `False`.
boolean = Regex("(True|False)")
# An integer, an optional fractional part, and an optional exponent; the
# exponent, when present, must carry an explicit sign.
# NOTE(review): `{integer}` interpolates the `Regex` object itself, so this
# relies on its string formatting yielding the raw pattern -- confirm in `.dsl`.
number = Regex(rf"{integer}(\.[0-9]+)?([eE][+-][0-9]+)?")
# YYYY-MM-DD with month 01-12; the day part accepts 00-31 for every month
# (no per-month validation).
date = Regex(r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])")
# HH:MM:SS with hours 00-23 and minutes/seconds 00-59.
time = Regex(r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])")
# A date, exactly one whitespace character, then a time.
# NOTE(review): same reliance on `Regex` string formatting as `number`.
datetime = Regex(rf"({date})(\s)({time})")

# Basic regex types
# A single decimal digit.
digit = Regex(r"\d")
# A single word character (`\w`), not an arbitrary character.
char = Regex(r"\w")
newline = Regex(r"(\r\n|\r|\n)") # Matches newlines as written on Windows (\r\n), classic Mac OS (\r) and Linux/macOS (\n)
# A single whitespace character.
whitespace = Regex(r"\s")

# Document-specific types
# One or more whitespace characters, then a run of letters, commas,
# semicolons, quotes and whitespace, closed by terminal punctuation.
# `\s` takes a single backslash here: inside a raw string `\\s` is the regex
# "literal backslash followed by s", not the whitespace class.
sentence = Regex(r"\s+[A-Za-z,;'\"\s]+[.?!]")
# One or more non-newline characters, each optionally followed by a newline.
# The literal holds only the pattern itself: a nested `r'...'` wrapper would
# be matched as literal text. No `(?s)` flag is needed because the pattern
# contains no `.`.
paragraph = Regex(r"((?:[^\n][\n]?)+)")


# Custom types
email = Regex(
r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])"""
)
isbn = Regex(
r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]"
)
2 changes: 0 additions & 2 deletions outlines/types/airports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@
AIRPORT_IATA_LIST = [
(v["iata"], v["iata"]) for v in airportsdata.load().values() if v["iata"]
]


IATA = Enum("Airport", AIRPORT_IATA_LIST) # type:ignore
10 changes: 5 additions & 5 deletions outlines/types/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import pycountry

ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries]
Alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore
alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore

# (name, value) pairs for every ISO 3166-1 alpha-3 country code known to
# pycountry; each member's name and value are the same code.
ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries]
# The first argument of the functional Enum API is the class name. It is
# "Alpha_3" so this enum's repr/__name__ is distinct from the alpha-2 enum.
# Member names and values are unaffected.
Alpha3 = Enum("Alpha_3", ALPHA_3_CODE) # type:ignore
alpha3 = Enum("Alpha_3", ALPHA_3_CODE) # type:ignore

NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries]
Numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore
numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore

NAME = [(country.name, country.name) for country in pycountry.countries]
Name = Enum("Name", NAME) # type:ignore
name = Enum("Name", NAME) # type:ignore

FLAG = [(country.flag, country.flag) for country in pycountry.countries]
Flag = Enum("Flag", FLAG) # type:ignore
flag = Enum("Flag", FLAG) # type:ignore
Loading

0 comments on commit b41b6e3

Please sign in to comment.