Refactor RegexGuide constructor interface
brandonwillard committed Oct 1, 2024
1 parent 6ef8820 commit 97a8a4c
Showing 4 changed files with 157 additions and 77 deletions.
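In short, the regex-parsing constructor call RegexGuide(regex_string, tokenizer) is replaced by the classmethods RegexGuide.from_regex and RegexGuide.from_interegular_fsm; the plain constructor now receives precomputed state mappings directly. A minimal migration sketch (the MockTokenizer stand-in mirrors the one in tests/fsm/test_guide.py and is illustrative only, not part of this commit):

    from outlines_core.fsm.guide import RegexGuide

    # Illustrative stand-in tokenizer, copied from the test suite in this commit.
    class MockTokenizer:
        vocabulary = {"1": 1, "a": 2, "eos": 3}
        special_tokens = {"eos"}
        eos_token_id = 3

        def convert_token_to_string(self, token):
            return token

    tokenizer = MockTokenizer()
    regex_str = "[1-9]"

    # Before this commit: guide = RegexGuide(regex_str, tokenizer)
    # After this commit:
    guide = RegexGuide.from_regex(regex_str, tokenizer)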
2 changes: 1 addition & 1 deletion benchmarks/bench_json_schema.py
@@ -74,4 +74,4 @@ def time_json_schema_to_regex(self, schema_name):
 
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
-        RegexGuide(regex, self.tokenizer)
+        RegexGuide.from_regex(regex, self.tokenizer)
4 changes: 2 additions & 2 deletions benchmarks/bench_regex_guide.py
@@ -23,7 +23,7 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
 
 
 class MemoryRegexGuideBenchmark:
@@ -34,4 +34,4 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
185 changes: 117 additions & 68 deletions python/outlines_core/fsm/guide.py
@@ -121,40 +121,74 @@ def create_states_mapping(
     tokenizer: "Tokenizer",
     regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
     frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
-    """Create the variables related to the mapping between states and tokens
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    """Create the variables related to the mapping between states and tokens from a regex string.
     The parameters of the function are used for caching purpose.
     Parameters
     ----------
-    regex_string: (`str`):
+    regex_string:
         The regular expression string to generate a states mapping for.
-    tokenizer: (`Tokenizer`):
+    tokenizer:
         The model's tokenizer.
-    regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
+    regex_parser:
         A function that parses a regex string into an `interegular` Pattern object.
-    frozen_tokens: (`List[str]`, *optional*):
+    frozen_tokens:
         A list of tokens that should be kept as-is when expanding the token-level FSM
         into a byte-level FSM. Defaults to an empty list.
     Returns
     -------
+    states_to_token_maps:
+        A mapping from states to a mapping from token ids originating from that state
+        to the next state to transition to given that token. The structure is as follows:
+        (origin_state -> (token_id -> next_state))
+    empty_token_ids:
+        A set of token ids that correspond to empty strings.
+    final_states:
+        A set of final states in the FSM.
+    """
+    regex_fsm = regex_parser(regex_string).to_fsm()
+    return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)
+
+
+def create_states_mapping_from_fsm(
+    fsm: interegular.fsm.FSM,
+    tokenizer: "Tokenizer",
+    frozen_tokens: List[str] = [],
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    """Create the variables related to the mapping between states and tokens from an FSM.
+    The parameters of the function are used for caching purpose.
+    Parameters
+    ----------
+    fsm:
+        An FSM for the regular expression.
+    tokenizer:
+        The model's tokenizer.
+    frozen_tokens:
+        A list of tokens that should be kept as-is when expanding the token-level FSM
+        into a byte-level FSM. Defaults to an empty list.
+    Returns
+    -------
-    states_to_token_maps: (`Dict[int, Dict[int, int]]`):
+    states_to_token_maps:
         A mapping from states to a mapping from token ids originating from that state
         to the next state to transition to given that token. The structure is as follows:
         (origin_state -> (token_id -> next_state))
-    empty_token_ids: (`Set[int]`):
+    empty_token_ids:
         A set of token ids that correspond to empty strings.
-    final_states: (`set`):
+    final_states:
         A set of final states in the FSM.
     """
-    regex_pattern = regex_parser(regex_string)
     byte_fsm = make_byte_level_fsm(
-        regex_pattern.to_fsm().reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
+        fsm.reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
    )
     regex_fsm, _ = make_deterministic_fsm(byte_fsm)
     states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-        regex_fsm, tokenizer, frozen_tokens=frozen_tokens
+        regex_fsm, tokenizer
     )
 
     # We make sure that it is possible to generate strings in the language
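The regex-parsing step is now split out: create_states_mapping parses the regex string and delegates to the new create_states_mapping_from_fsm, which accepts a prebuilt interegular FSM. A rough sketch of calling the new helper directly, reusing the illustrative MockTokenizer from the sketch above (the helper is defined at module level in this diff, so it is assumed importable from outlines_core.fsm.guide):

    import interegular
    from outlines_core.fsm.guide import create_states_mapping_from_fsm

    fsm = interegular.parse_pattern("[1-9]").to_fsm()
    states_to_token_maps, empty_token_ids, final_states = create_states_mapping_from_fsm(
        fsm, MockTokenizer()  # the illustrative tokenizer defined earlier
    )
    # states_to_token_maps has the shape (origin_state -> (token_id -> next_state)),
    # e.g. {0: {1: 1}} for this vocabulary and regex, per the tests below.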
@@ -175,15 +209,79 @@ class RegexGuide(Guide):
 
     initial_state = 0
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
+    def __init__(
+        self,
+        states_to_token_maps,
+        empty_token_ids,
+        fsm_finals,
+        eos_token_id,
+        states_to_token_mask,
+    ):
+        self.states_to_token_maps = states_to_token_maps
+        self.empty_token_ids = empty_token_ids
+        self.eos_token_id = eos_token_id
+        self.final_states = fsm_finals | {-1}
+        self.states_to_token_mask = states_to_token_mask
+
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer: "Tokenizer",
+        _create_states_mapping=create_states_mapping,
+        device=None,
+        regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
+        frozen_tokens: List[str] = [],
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping(
+            regex_string,
+            tokenizer,
+            regex_parser=regex_parser,
+            frozen_tokens=frozen_tokens,
+        )
+        states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
+            for state, next_tokens_to_end_states in states_to_token_maps.items()
+        }
+        return cls(
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+            tokenizer.eos_token_id,
+            states_to_token_mask,
+        )
+
+    @classmethod
+    def from_interegular_fsm(
+        cls,
+        interegular_fsm: interegular.fsm.FSM,
+        tokenizer: "Tokenizer",
+        _create_states_mapping_from_fsm=create_states_mapping_from_fsm,
+        device=None,
+        frozen_tokens: List[str] = [],
+    ):
         (
-            self.states_to_token_maps,
-            self.empty_token_ids,
+            states_to_token_maps,
+            empty_token_ids,
             fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
-        self._cache_state_to_token_tensor()
+        ) = _create_states_mapping_from_fsm(
+            interegular_fsm, tokenizer, frozen_tokens=frozen_tokens
+        )
+        states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
+            for state, next_tokens_to_end_states in states_to_token_maps.items()
+        }
+        return cls(
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+            tokenizer.eos_token_id,
+            states_to_token_mask,
+        )
 
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -242,55 +340,6 @@ def get_next_state(self, state: int, token_id: int) -> int:
 
         return next_state
 
-    @classmethod
-    def from_interegular_fsm(
-        cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
-    ):
-        from_interegular_instance = cls.__new__(cls)
-
-        def create_states_mapping_from_interegular_fsm(
-            fsm: interegular.fsm.FSM,
-        ) -> Tuple[dict, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids
-
-        (
-            from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
-        ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
-        from_interegular_instance._cache_state_to_token_tensor()
-        return from_interegular_instance
-
-    def _cache_state_to_token_tensor(self):
-        """
-        cache state -> token int tensor
-        this increases performance of mask construction substantially
-        """
-        self.states_to_token_mask = {
-            state: torch.tensor(list(next_tokens_to_end_states.keys()))
-            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
-        }
-
     def is_final_state(self, state: int) -> bool:
         """Determine whether the current state of the guide is a final state."""
         return state in self.final_states
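Note that the mask caching formerly done in _cache_state_to_token_tensor now happens inside the classmethods, which also accept a device argument forwarded to torch.tensor when building the per-state token masks. A hedged sketch of building a guide from a prebuilt interegular FSM, again with the illustrative MockTokenizer (device is optional and defaults to None):

    import interegular
    from outlines_core.fsm.guide import RegexGuide

    fsm = interegular.parse_pattern("[1-9]").to_fsm()
    guide = RegexGuide.from_interegular_fsm(fsm, MockTokenizer(), device=None)

    # The resulting guide is used exactly as before:
    instruction = guide.get_next_instruction(0)             # Generate instruction over allowed token ids
    next_state = guide.get_next_state(state=0, token_id=1)  # -> 1 for this FSM, per the tests below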
43 changes: 37 additions & 6 deletions tests/fsm/test_guide.py
@@ -1,3 +1,4 @@
+import interegular
 import pytest
 from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write
 
@@ -42,10 +43,10 @@ def convert_token_to_string(self, token):
     regex_str = "[1-9]"
 
     with pytest.raises(ValueError, match="The vocabulary"):
-        RegexGuide(regex_str, MockTokenizer())
+        RegexGuide.from_regex(regex_str, MockTokenizer())
 
 
-def test_regex():
+def test_from_regex():
     class MockTokenizer:
         vocabulary = {"1": 1, "a": 2, "eos": 3}
         special_tokens = {"eos"}
@@ -56,7 +57,37 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
 
+    instruction = fsm.get_next_instruction(0)
+    assert isinstance(instruction, Generate)
+    assert_expected_tensor_ids(instruction.tokens, [1])
+
+    assert fsm.get_next_state(state=0, token_id=1) == 1
+    assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1
+
+    assert fsm.is_final_state(0) is False
+
+    for state in fsm.final_states:
+        assert fsm.is_final_state(state) is True
+
+
+def test_from_fsm():
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = "[1-9]"
+    tokenizer = MockTokenizer()
+    fsm = RegexGuide.from_interegular_fsm(
+        interegular.parse_pattern(regex_str).to_fsm(), tokenizer
+    )
+
+    assert fsm.states_to_token_maps == {0: {1: 1}}
+
@@ -97,7 +128,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 4: 2},
@@ -144,7 +175,7 @@ def convert_token_to_string(self, token):
 
     regex_str = " [😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 10: 2},
@@ -179,7 +210,7 @@ def convert_token_to_string(self, token):
 
     regex_str = r"`\n(\.\n)?`\n"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     state = fsm.get_next_state(state=4, token_id=103)
     assert state == 5
