From 889e30123b00594142666b2dbe0f8e58b3d4b392 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard" 
Date: Tue, 1 Oct 2024 11:15:28 -0500
Subject: [PATCH] Refactor RegexGuide constructor interface

---
 benchmarks/bench_json_schema.py   |   2 +-
 benchmarks/bench_regex_guide.py   |   4 +-
 python/outlines_core/fsm/guide.py | 185 +++++++++++++++++++-----------
 tests/fsm/test_guide.py           |  43 ++++++-
 4 files changed, 157 insertions(+), 77 deletions(-)

diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index 47578cd3..964caa35 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -74,4 +74,4 @@ def time_json_schema_to_regex(self, schema_name):
 
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
-        RegexGuide(regex, self.tokenizer)
+        RegexGuide.from_regex(regex, self.tokenizer)
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index 287d5f51..a47adae2 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -23,7 +23,7 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
 
 
 class MemoryRegexGuideBenchmark:
@@ -34,4 +34,4 @@ def setup(self, pattern_name):
         self.pattern = regex_samples[pattern_name]
 
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py
index b7a1378c..b2de683a 100644
--- a/python/outlines_core/fsm/guide.py
+++ b/python/outlines_core/fsm/guide.py
@@ -121,40 +121,74 @@ def create_states_mapping(
     tokenizer: "Tokenizer",
     regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
     frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], set]:
-    """Create the variables related to the mapping between states and tokens
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    """Create the variables related to the mapping between states and tokens from a regex string.
+
     The parameters of the function are used for caching purposes.
 
     Parameters
     ----------
-    regex_string: (`str`):
+    regex_string:
         The regular expression string to generate a states mapping for.
-    tokenizer: (`Tokenizer`):
+    tokenizer:
         The model's tokenizer.
-    regex_parser: (`Callable[[str], interegular.Pattern]`, *optional*):
+    regex_parser:
         A function that parses a regex string into an `interegular` Pattern object.
-    frozen_tokens: (`List[str]`, *optional*):
+    frozen_tokens:
+        A list of tokens that should be kept as-is when expanding the token-level FSM
+        into a byte-level FSM. Defaults to an empty list.
+
+    Returns
+    -------
+    states_to_token_maps:
+        A mapping from states to a mapping from token ids originating from that state
+        to the next state to transition to given that token. The structure is as follows:
+        (origin_state -> (token_id -> next_state))
+    empty_token_ids:
+        A set of token ids that correspond to empty strings.
+    final_states:
+        A set of final states in the FSM.
+    """
+    regex_fsm = regex_parser(regex_string).to_fsm()
+    return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)
+
+
+def create_states_mapping_from_fsm(
+    fsm: interegular.fsm.FSM,
+    tokenizer: "Tokenizer",
+    frozen_tokens: List[str] = [],
+) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+    """Create the variables related to the mapping between states and tokens from an FSM.
+
+    The parameters of the function are used for caching purposes.
+
+    Parameters
+    ----------
+    fsm:
+        An FSM for the regular expression.
+    tokenizer:
+        The model's tokenizer.
+    frozen_tokens:
         A list of tokens that should be kept as-is when expanding the token-level FSM
         into a byte-level FSM. Defaults to an empty list.
 
     Returns
     -------
-    states_to_token_maps: (`Dict[int, Dict[int, int]]`):
+    states_to_token_maps:
         A mapping from states to a mapping from token ids originating from that state
         to the next state to transition to given that token. The structure is as follows:
         (origin_state -> (token_id -> next_state))
-    empty_token_ids: (`Set[int]`):
+    empty_token_ids:
         A set of token ids that correspond to empty strings.
-    final_states: (`set`):
+    final_states:
         A set of final states in the FSM.
     """
-    regex_pattern = regex_parser(regex_string)
     byte_fsm = make_byte_level_fsm(
-        regex_pattern.to_fsm().reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
+        fsm.reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
     )
     regex_fsm, _ = make_deterministic_fsm(byte_fsm)
     states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-        regex_fsm, tokenizer, frozen_tokens=frozen_tokens
+        regex_fsm, tokenizer
     )
 
     # We make sure that it is possible to generate strings in the language
@@ -175,15 +209,79 @@ class RegexGuide(Guide):
 
     initial_state = 0
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
+    def __init__(
+        self,
+        states_to_token_maps,
+        empty_token_ids,
+        fsm_finals,
+        eos_token_id,
+        states_to_token_mask,
+    ):
+        self.states_to_token_maps = states_to_token_maps
+        self.empty_token_ids = empty_token_ids
+        self.eos_token_id = eos_token_id
+        self.final_states = fsm_finals | {-1}
+        self.states_to_token_mask = states_to_token_mask
+
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer: "Tokenizer",
+        _create_states_mapping=create_states_mapping,
+        device=None,
+        regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
+        frozen_tokens: List[str] = [],
+    ):
+        (
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+        ) = _create_states_mapping(
+            regex_string,
+            tokenizer,
+            regex_parser=regex_parser,
+            frozen_tokens=frozen_tokens,
+        )
+        states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
+            for state, next_tokens_to_end_states in states_to_token_maps.items()
+        }
+        return cls(
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+            tokenizer.eos_token_id,
+            states_to_token_mask,
+        )
+
+    @classmethod
+    def from_interegular_fsm(
+        cls,
+        interegular_fsm: interegular.fsm.FSM,
+        tokenizer: "Tokenizer",
+        _create_states_mapping_from_fsm=create_states_mapping_from_fsm,
+        device=None,
+        frozen_tokens: List[str] = [],
+    ):
         (
-            self.states_to_token_maps,
-            self.empty_token_ids,
+            states_to_token_maps,
+            empty_token_ids,
             fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
-        self._cache_state_to_token_tensor()
+        ) = _create_states_mapping_from_fsm(
+            interegular_fsm, tokenizer, frozen_tokens=frozen_tokens
+        )
+        states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
+            for state, next_tokens_to_end_states in states_to_token_maps.items()
+        }
+        return cls(
+            states_to_token_maps,
+            empty_token_ids,
+            fsm_finals,
+            tokenizer.eos_token_id,
+            states_to_token_mask,
+        )
 
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -242,55 +340,6 @@ def get_next_state(self, state: int, token_id: int) -> int:
 
         return next_state
 
-    @classmethod
-    def from_interegular_fsm(
-        cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer"
-    ):
-        from_interegular_instance = cls.__new__(cls)
-
-        def create_states_mapping_from_interegular_fsm(
-            fsm: interegular.fsm.FSM,
-        ) -> Tuple[dict, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids
-
-        (
-            from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
-        ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
-        from_interegular_instance.eos_token_id = tokenizer.eos_token_id
-        from_interegular_instance._cache_state_to_token_tensor()
-        return from_interegular_instance
-
-    def _cache_state_to_token_tensor(self):
-        """
-        cache state -> token int tensor
-        this increases performance of mask construction substantially
-        """
-        self.states_to_token_mask = {
-            state: torch.tensor(list(next_tokens_to_end_states.keys()))
-            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
-        }
-
     def is_final_state(self, state: int) -> bool:
         """Determine whether the current state of the guide is a final state."""
         return state in self.final_states
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index 0bd28d4f..174063b0 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -1,3 +1,4 @@
+import interegular
 import pytest
 
 from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write
@@ -42,10 +43,10 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     with pytest.raises(ValueError, match="The vocabulary"):
-        RegexGuide(regex_str, MockTokenizer())
+        RegexGuide.from_regex(regex_str, MockTokenizer())
 
 
-def test_regex():
+def test_from_regex():
     class MockTokenizer:
         vocabulary = {"1": 1, "a": 2, "eos": 3}
         special_tokens = {"eos"}
@@ -56,7 +57,37 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
+
+    assert fsm.states_to_token_maps == {0: {1: 1}}
+
+    instruction = fsm.get_next_instruction(0)
+    assert isinstance(instruction, Generate)
+    assert_expected_tensor_ids(instruction.tokens, [1])
+
+    assert fsm.get_next_state(state=0, token_id=1) == 1
+    assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1
+
+    assert fsm.is_final_state(0) is False
+
+    for state in fsm.final_states:
+        assert fsm.is_final_state(state) is True
+
+
+def test_from_fsm():
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = "[1-9]"
+    tokenizer = MockTokenizer()
+    fsm = RegexGuide.from_interegular_fsm(
+        interegular.parse_pattern(regex_str).to_fsm(), tokenizer
+    )
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
@@ -97,7 +128,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 4: 2},
@@ -144,7 +175,7 @@ def convert_token_to_string(self, token):
 
     regex_str = " [😁-😎]"
    tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 10: 2},
@@ -179,7 +210,7 @@ def convert_token_to_string(self, token):
 
     regex_str = r"`\n(\.\n)?`\n"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     state = fsm.get_next_state(state=4, token_id=103)
     assert state == 5
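
Reviewer note (not part of the patch): a minimal sketch of the refactored constructor interface, reusing the MockTokenizer defined in the tests above. The `device` keyword is the new classmethod argument; None keeps the default tensor placement, and a string such as "cuda" is a plausible alternative.

import interegular

from outlines_core.fsm.guide import RegexGuide


# Stand-in tokenizer, copied from the tests in this patch.
class MockTokenizer:
    vocabulary = {"1": 1, "a": 2, "eos": 3}
    special_tokens = {"eos"}
    eos_token_id = 3

    def convert_token_to_string(self, token):
        return token


tokenizer = MockTokenizer()

# Before this patch: RegexGuide("[1-9]", tokenizer)
# After: build from a regex string; `device` controls where the per-state
# token-id tensors (states_to_token_mask) are allocated.
guide = RegexGuide.from_regex("[1-9]", tokenizer, device=None)
assert guide.get_next_state(state=0, token_id=1) == 1

# Or build from an already-compiled interegular FSM, so a pattern can be
# parsed once and reused across guides.
fsm = interegular.parse_pattern("[1-9]").to_fsm()
guide_from_fsm = RegexGuide.from_interegular_fsm(fsm, tokenizer)
assert guide_from_fsm.states_to_token_maps == {0: {1: 1}}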
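Design note: RegexGuide.__init__ now only stores precomputed state, so the expensive FSM-to-token-map compilation lives entirely in the from_regex and from_interegular_fsm classmethods, the _cache_state_to_token_tensor helper goes away, and the per-state token-id tensors are built once at construction time on the requested device rather than always on the framework default.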