Skip to content

Commit

Permalink
Add regex DSL and re-organize custom types
Browse files: browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 5, 2025
1 parent 9b586db commit fde47f5
Show file tree
Hide file tree
Showing 17 changed files with 781 additions and 197 deletions.
13 changes: 5 additions & 8 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args
from outlines.types.dsl import Regex

INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
Expand All @@ -19,14 +19,11 @@ def __call__(self, sequence: str) -> Any:

def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
# If it is a custom type
if isinstance(python_type, _AnnotatedAlias):
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]
if isinstance(python_type, Regex):
custom_regex_str = python_type.pattern

custom_regex_str = json_schema["pattern"]

def custom_format_fn(sequence: str) -> Any:
return type_class(sequence)
def custom_format_fn(sequence: str) -> str:
return str(sequence)

return custom_regex_str, custom_format_fn

Expand Down
11 changes: 9 additions & 2 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
)
from outlines.models import OpenAI, TransformersVision
from outlines.samplers import Sampler, multinomial
from outlines.types import Regex


@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
def regex(model, regex_str: str | Regex, sampler: Sampler = multinomial()):
"""Generate structured text in the language of a regular expression.
Parameters
Expand All @@ -31,18 +32,24 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
"""
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
def regex_vision(
model,
regex_str: str,
regex_str: str | Regex,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

if isinstance(regex_str, Regex):
regex_str = regex_str.pattern

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)

Expand Down
16 changes: 12 additions & 4 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from . import airports, countries
from .email import Email
from .isbn import ISBN
from .locales import locale
from enum import Enum

from . import airports, countries, locale
from .dsl import Regex

email = Regex(
r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])"""
)

isbn = Regex(
r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]"
)
2 changes: 0 additions & 2 deletions outlines/types/airports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@
AIRPORT_IATA_LIST = [
(v["iata"], v["iata"]) for v in airportsdata.load().values() if v["iata"]
]


IATA = Enum("Airport", AIRPORT_IATA_LIST) # type:ignore
10 changes: 5 additions & 5 deletions outlines/types/countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import pycountry

ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries]
Alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore
alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore

ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries]
Alpha3 = Enum("Alpha_2", ALPHA_3_CODE) # type:ignore
alpha3 = Enum("Alpha_2", ALPHA_3_CODE) # type:ignore

NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries]
Numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore
numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore

NAME = [(country.name, country.name) for country in pycountry.countries]
Name = Enum("Name", NAME) # type:ignore
name = Enum("Name", NAME) # type:ignore

FLAG = [(country.flag, country.flag) for country in pycountry.countries]
Flag = Enum("Flag", FLAG) # type:ignore
flag = Enum("Flag", FLAG) # type:ignore
Loading

0 comments on commit fde47f5

Please sign in to comment.