diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index 47578cd3..964caa35 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -74,4 +74,4 @@ def time_json_schema_to_regex(self, schema_name):
 
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
-        RegexGuide(regex, self.tokenizer)
+        RegexGuide.from_regex(regex, self.tokenizer)
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index 287d5f51..a47adae2 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -23,7 +23,7 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
 
 
 class MemoryRegexGuideBenchmark:
@@ -34,4 +34,4 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
diff --git a/python/outlines_core/fsm/fsm.py b/python/outlines_core/fsm/fsm.py
index 4daf3c86..6e68085d 100644
--- a/python/outlines_core/fsm/fsm.py
+++ b/python/outlines_core/fsm/fsm.py
@@ -37,7 +37,7 @@ def __init__(self, regex_string: str, tokenizer):
                 "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
             )
         )
-        super().__init__(regex_string, tokenizer)
+        self.from_regex(regex_string, tokenizer)
 
     def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py
index b7a1378c..966a18fb 100644
--- a/python/outlines_core/fsm/guide.py
+++ b/python/outlines_core/fsm/guide.py
@@ -121,31 +121,31 @@ def create_states_mapping(
     tokenizer: "Tokenizer",
     regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
     frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
     """Create the variables related to the mapping between states and tokens
     The parameters of the function are used for caching purpose.
 
     Parameters
     ----------
-    regex_string: (`str`):
+    regex_string:
         The regular expression string to generate a states mapping for.
-    tokenizer: (`Tokenizer`):
+    tokenizer:
         The model's tokenizer.
-    regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
+    regex_parser:
         A function that parses a regex string into an `interegular` Pattern object.
-    frozen_tokens: (`List[str]`, *optional*):
+    frozen_tokens:
         A list of tokens that should be kept as-is when expanding the token-level
         FSM into a byte-level FSM. Defaults to an empty list.
 
     Returns
     -------
-    states_to_token_maps: (`Dict[int, Dict[int, int]]`):
+    states_to_token_maps:
         A mapping from states to a mapping from token ids originating from that state
         to the next state to transition to given that token. The structure is as follows:
         (origin_state -> (token_id -> next_state))
-    empty_token_ids: (`Set[int]`):
+    empty_token_ids:
         A set of token ids that correspond to empty strings.
-    final_states: (`set`):
+    final_states:
         A set of final states in the FSM.
     """
     regex_pattern = regex_parser(regex_string)
@@ -170,21 +170,72 @@ def create_states_mapping(
     return states_to_token_maps, empty_token_ids, regex_fsm.finals
 
 
+def create_states_mapping_from_interegular_fsm(
+    fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+    # We make sure that it is possible to generate strings in the language
+    # of the regular expression with the tokens present in the model's
+    # vocabulary.
+    if not any(
+        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
+    ):
+        raise ValueError(
+            "The vocabulary does not allow us to build a sequence that matches the input regex"
+        )
+
+    return states_to_token_maps, empty_token_ids, regex_fsm.finals
+
+
 class RegexGuide(Guide):
     """Guide to generate text in the language of a regular expression."""
 
     initial_state = 0
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
-        (
-            self.states_to_token_maps,
-            self.empty_token_ids,
-            fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
+    def __init__(self, states_to_token_maps, empty_token_ids, fsm_finals, eos_token_id):
+        self.states_to_token_maps = states_to_token_maps
+        self.empty_token_ids = empty_token_ids
+        self.eos_token_id = eos_token_id
         self.final_states = fsm_finals | {-1}
         self._cache_state_to_token_tensor()
 
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer: "Tokenizer",
+        _create_states_mapping=create_states_mapping,
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping(regex_string, tokenizer)
+        return cls(
+            states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
+        )
+
+    @classmethod
+    def from_interegular_fsm(
+        cls,
+        interegular_fsm: interegular.fsm.FSM,
+        tokenizer: "Tokenizer",
+        _create_states_mapping_from_interegular_fsm=create_states_mapping_from_interegular_fsm,
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping_from_interegular_fsm(interegular_fsm, tokenizer)
+        return cls(
+            states_to_token_maps, empty_token_ids, fsm_finals, tokenizer.eos_token_id
+        )
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
 
@@ -242,45 +293,6 @@ def get_next_state(self, state: int, token_id: int) -> int:
 
         return next_state
 
-    @classmethod
-    def from_interegular_fsm(
-        cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
-    ):
-        from_interegular_instance = cls.__new__(cls)
-
-        def create_states_mapping_from_interegular_fsm(
-            fsm: interegular.fsm.FSM,
-        ) -> Tuple[dict, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids
-
-        (
-            from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
-        ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
-        from_interegular_instance._cache_state_to_token_tensor()
-        return from_interegular_instance
-
     def _cache_state_to_token_tensor(self):
         """
         cache state -> token int tensor
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index 0bd28d4f..d5d5292b 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -42,7 +42,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     with pytest.raises(ValueError, match="The vocabulary"):
-        RegexGuide(regex_str, MockTokenizer())
+        RegexGuide.from_regex(regex_str, MockTokenizer())
 
 
 def test_regex():
@@ -56,7 +56,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
 
@@ -97,7 +97,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 4: 2},
@@ -144,7 +144,7 @@ def convert_token_to_string(self, token):
 
     regex_str = " [😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 10: 2},
@@ -179,7 +179,7 @@ def convert_token_to_string(self, token):
 
     regex_str = r"`\n(\.\n)?`\n"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     state = fsm.get_next_state(state=4, token_id=103)
     assert state == 5
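
The hunks above replace direct `RegexGuide(regex_string, tokenizer)` construction with the `RegexGuide.from_regex` and `RegexGuide.from_interegular_fsm` classmethods, since `__init__` now takes the precomputed state mapping. A minimal migration sketch, not part of the patch; `tokenizer` stands in for any object satisfying the library's `Tokenizer` protocol (for example something equivalent to the tests' `MockTokenizer`):

    import interegular
    from outlines_core.fsm.guide import RegexGuide

    # Old call site:  guide = RegexGuide("[1-9]", tokenizer)
    # New call site:
    guide = RegexGuide.from_regex("[1-9]", tokenizer)

    # A guide can also be built from an already-parsed interegular FSM:
    fsm = interegular.parse_pattern("[1-9]").to_fsm()
    guide_from_fsm = RegexGuide.from_interegular_fsm(fsm, tokenizer)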