Refactor RegexGuide constructor interface
brandonwillard committed Oct 1, 2024
1 parent bcf655a commit d5cac52
Showing 2 changed files with 67 additions and 55 deletions.
python/outlines_core/fsm/fsm.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def __init__(self, regex_string: str, tokenizer):
"The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
)
)
super().__init__(regex_string, tokenizer)
self.from_regex(regex_string, tokenizer)

def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
next_instruction = self.get_next_instruction(state)
python/outlines_core/fsm/guide.py (120 changes: 66 additions & 54 deletions)
@@ -121,31 +121,31 @@ def create_states_mapping(
tokenizer: "Tokenizer",
regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose.
Parameters
----------
regex_string: (`str`):
regex_string:
The regular expression string to generate a states mapping for.
tokenizer: (`Tokenizer`):
tokenizer:
The model's tokenizer.
regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
regex_parser:
A function that parses a regex string into an `interegular` Pattern object.
frozen_tokens: (`List[str]`, *optional*):
frozen_tokens:
A list of tokens that should be kept as-is when expanding the token-level FSM
into a byte-level FSM. Defaults to an empty list.
Returns
-------
states_to_token_maps: (`Dict[int, Dict[int, int]]`):
states_to_token_maps:
A mapping from states to a mapping from token ids originating from that state
to the next state to transition to given that token. The structure is as follows:
(origin_state -> (token_id -> next_state))
empty_token_ids: (`Set[int]`):
empty_token_ids:
A set of token ids that correspond to empty strings.
final_states: (`set`):
final_states:
A set of final states in the FSM.
"""
regex_pattern = regex_parser(regex_string)
@@ -170,21 +170,72 @@ def create_states_mapping(
return states_to_token_maps, empty_token_ids, regex_fsm.finals
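
A minimal sketch of how `create_states_mapping` might be called, assuming `tokenizer` is an adapted model tokenizer that exposes the vocabulary and an `eos_token_id` (the regex and variable names are illustrative):

    # Build the state/token mapping for a simple date pattern.
    states_to_token_maps, empty_token_ids, final_states = create_states_mapping(
        r"\d{4}-\d{2}-\d{2}",
        tokenizer,
    )
    # states_to_token_maps[0] maps each token id allowed at the initial state
    # to the state reached after emitting that token.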


def create_states_mapping_from_interegular_fsm(
fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)

# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)

return states_to_token_maps, empty_token_ids, regex_fsm.finals
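
The new module-level helper mirrors `create_states_mapping` but starts from an already-built automaton; a sketch of a direct call, again assuming `tokenizer` as above:

    import interegular

    # Parse a pattern into an interegular FSM, then index it against the
    # tokenizer's vocabulary.
    fsm = interegular.parse_pattern(r"(yes|no)").to_fsm()
    states_to_token_maps, empty_token_ids, final_states = (
        create_states_mapping_from_interegular_fsm(fsm, tokenizer)
    )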


class RegexGuide(Guide):
"""Guide to generate text in the language of a regular expression."""

initial_state = 0

def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
(
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
def __init__(self, states_to_token_maps, empty_token_ids, fsm_finals, eos_token_id):
self.states_to_token_maps = states_to_token_maps
self.empty_token_ids = empty_token_ids
self.eos_token_id = eos_token_id
self.final_states = fsm_finals | {-1}
self._cache_state_to_token_tensor()

@classmethod
def from_regex(
cls,
regex_string: str,
tokenizer: "Tokenizer",
_create_states_mapping=create_states_mapping,
):
(
states_to_token_maps,
empty_token_ids,
fsm_finals,
) = _create_states_mapping(regex_string, tokenizer)
return cls(
states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
)

@classmethod
def from_interegular_fsm(
cls,
interegular_fsm: interegular.fsm.FSM,
tokenizer: "Tokenizer",
_create_states_mapping_from_interegular_fsm=create_states_mapping_from_interegular_fsm,
):
(
states_to_token_maps,
empty_token_ids,
fsm_finals,
) = _create_states_mapping_from_interegular_fsm(interegular_fsm, tokenizer)
return cls(
states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
)
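
Together these classmethods replace the old single-purpose constructor; a usage sketch of the new interface, assuming `tokenizer` is an adapted model tokenizer with an `eos_token_id` attribute:

    # Before this commit: guide = RegexGuide(r"[0-9]+", tokenizer)
    # After: build via a classmethod,
    guide = RegexGuide.from_regex(r"[0-9]+", tokenizer)

    # or pass precomputed pieces straight to the constructor, e.g. to reuse a
    # cached mapping across several guides.
    mapping, empty_ids, finals = create_states_mapping(r"[0-9]+", tokenizer)
    other_guide = RegexGuide(mapping, empty_ids, finals, tokenizer.eos_token_id)

The `_create_states_mapping` keyword argument presumably exists so that a caller can inject a cached or instrumented variant of `create_states_mapping` without subclassing.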

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -242,45 +293,6 @@ def get_next_state(self, state: int, token_id: int) -> int:

return next_state

@classmethod
def from_interegular_fsm(
cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
):
from_interegular_instance = cls.__new__(cls)

def create_states_mapping_from_interegular_fsm(
fsm: interegular.fsm.FSM,
) -> Tuple[dict, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)

# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)

return states_to_token_maps, empty_token_ids

(
from_interegular_instance.states_to_token_maps,
from_interegular_instance.empty_token_ids,
) = create_states_mapping_from_interegular_fsm(interegular_fsm)
from_interegular_instance.eos_token_id = tokenizer.eos_token_id
from_interegular_instance._cache_state_to_token_tensor()
return from_interegular_instance

def _cache_state_to_token_tensor(self):
"""
cache state -> token int tensor
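
The body of `_cache_state_to_token_tensor` is truncated in this view; one plausible shape for the cache it builds, assuming a torch dependency and an attribute name that this diff does not confirm, would be:

    import torch

    # Sketch only: one tensor of allowed token ids per FSM state, so that
    # instruction generation does not rebuild Python lists on every step.
    states_to_token_mask = {
        state: torch.tensor(list(transitions.keys()))
        for state, transitions in states_to_token_maps.items()
    }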
