From d5cac52cef58cfda75e10ab08bf4fc5296225418 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Tue, 1 Oct 2024 11:15:28 -0500
Subject: [PATCH] Refactor RegexGuide constructor interface

---
 python/outlines_core/fsm/fsm.py   |   2 +-
 python/outlines_core/fsm/guide.py | 120 ++++++++++++++++--------------
 2 files changed, 67 insertions(+), 55 deletions(-)

diff --git a/python/outlines_core/fsm/fsm.py b/python/outlines_core/fsm/fsm.py
index 4daf3c86..6e68085d 100644
--- a/python/outlines_core/fsm/fsm.py
+++ b/python/outlines_core/fsm/fsm.py
@@ -37,7 +37,7 @@ def __init__(self, regex_string: str, tokenizer):
                 "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
             )
         )
-        super().__init__(regex_string, tokenizer)
+        self.__dict__.update(RegexGuide.from_regex(regex_string, tokenizer).__dict__)
 
     def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py
index b7a1378c..966a18fb 100644
--- a/python/outlines_core/fsm/guide.py
+++ b/python/outlines_core/fsm/guide.py
@@ -121,31 +121,31 @@ def create_states_mapping(
     tokenizer: "Tokenizer",
     regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
     frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
     """Create the variables related to the mapping between states and tokens
     The parameters of the function are used for caching purpose.
 
     Parameters
     ----------
-    regex_string: (`str`):
+    regex_string:
         The regular expression string to generate a states mapping for.
-    tokenizer: (`Tokenizer`):
+    tokenizer:
         The model's tokenizer.
-    regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
+    regex_parser:
         A function that parses a regex string into an `interegular` Pattern object.
-    frozen_tokens: (`List[str]`, *optional*):
+    frozen_tokens:
         A list of tokens that should be kept as-is when expanding the token-level
         FSM into a byte-level FSM. Defaults to an empty list.
 
     Returns
     -------
-    states_to_token_maps: (`Dict[int, Dict[int, int]]`):
+    states_to_token_maps:
         A mapping from states to a mapping from token ids originating from that state
         to the next state to transition to given that token. The structure is as follows:
         (origin_state -> (token_id -> next_state))
-    empty_token_ids: (`Set[int]`):
+    empty_token_ids:
         A set of token ids that correspond to empty strings.
-    final_states: (`set`):
+    final_states:
         A set of final states in the FSM.
     """
     regex_pattern = regex_parser(regex_string)
@@ -170,21 +170,72 @@ def create_states_mapping(
     return states_to_token_maps, empty_token_ids, regex_fsm.finals
 
 
+def create_states_mapping_from_interegular_fsm(
+    fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+    # We make sure that it is possible to generate strings in the language
+    # of the regular expression with the tokens present in the model's
+    # vocabulary.
+    if not any(
+        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
+    ):
+        raise ValueError(
+            "The vocabulary does not allow us to build a sequence that matches the input regex"
+        )
+
+    return states_to_token_maps, empty_token_ids, regex_fsm.finals
+
+
 class RegexGuide(Guide):
     """Guide to generate text in the language of a regular expression."""
 
     initial_state = 0
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
-        (
-            self.states_to_token_maps,
-            self.empty_token_ids,
-            fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
+    def __init__(self, states_to_token_maps, empty_token_ids, fsm_finals, eos_token_id):
+        self.states_to_token_maps = states_to_token_maps
+        self.empty_token_ids = empty_token_ids
+        self.eos_token_id = eos_token_id
         self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
 
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer: "Tokenizer",
+        _create_states_mapping=create_states_mapping,
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping(regex_string, tokenizer)
+        return cls(
+            states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
+        )
+
+    @classmethod
+    def from_interegular_fsm(
+        cls,
+        interegular_fsm: interegular.fsm.FSM,
+        tokenizer: "Tokenizer",
+        _create_states_mapping_from_interegular_fsm=create_states_mapping_from_interegular_fsm,
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping_from_interegular_fsm(interegular_fsm, tokenizer)
+        return cls(
+            states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
+        )
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
@@ -242,45 +293,6 @@ def get_next_state(self, state: int, token_id: int) -> int:
 
         return next_state
 
-    @classmethod
-    def from_interegular_fsm(
-        cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
-    ):
-        from_interegular_instance = cls.__new__(cls)
-
-        def create_states_mapping_from_interegular_fsm(
-            fsm: interegular.fsm.FSM,
-        ) -> Tuple[dict, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids
-
-        (
-            from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
-        ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
-        from_interegular_instance._cache_state_to_token_tensor()
-        return from_interegular_instance
-
     def _cache_state_to_token_tensor(self):
         """
         cache state -> token int tensor
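
Reviewer note (not part of the patch): a minimal usage sketch of the refactored
constructor interface. The `tokenizer` below is a stand-in for any
outlines-core-compatible tokenizer adapter exposing `eos_token_id`, and the
regex is arbitrary.

    import interegular

    from outlines_core.fsm.guide import RegexGuide

    tokenizer = ...  # placeholder: any adapter exposing `eos_token_id`

    # Before this patch, construction went through __init__ directly:
    #     guide = RegexGuide(r"[0-9]{3}", tokenizer)

    # After this patch, use the factory classmethods instead:
    guide = RegexGuide.from_regex(r"[0-9]{3}", tokenizer)

    # Or build from a pre-compiled interegular FSM, skipping regex parsing:
    fsm = interegular.parse_pattern(r"[0-9]{3}").to_fsm()
    guide = RegexGuide.from_interegular_fsm(fsm, tokenizer)

    # Either way, the guide is then driven as before:
    instruction = guide.get_next_instruction(guide.initial_state)

Splitting construction into classmethods keeps __init__ a cheap attribute
assignment, so callers can cache the output of create_states_mapping and
rebuild guides without re-running the FSM-to-vocabulary indexing.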