diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index a96fa8e24..35880e169 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -458,6 +458,47 @@ def get_sub_fsms_from_seq(
     )
 
 
+VocabTrie = namedtuple(
+    "VocabTrie",
+    [
+        "parent_children_map",
+        "idx_to_token_str",
+        "token_str_to_idx",
+    ],
+)
+
+
+@numba.njit(cache=True, nogil=True)
+def create_vocab_trie(vocabulary: Dict[str, List[int]]) -> VocabTrie:
+    # Initialize the trie as a flat dictionary
+    trie = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
+
+    vocab_list = numba.typed.List.empty_list(numba.types.string)
+    reverse_vocab = numba.typed.Dict.empty(numba.types.string, numba.int64)
+    for i, token in enumerate(vocabulary):
+        vocab_list.append(token)
+        reverse_vocab[token] = numba.int64(i)
+
+    for token, token_id in reverse_vocab.items():
+        for i in range(len(token), 0, -1):
+            prefix = token[:i]
+            if prefix == token:
+                continue
+            if prefix in vocabulary:
+                prefix_id = numba.int64(reverse_vocab[prefix])
+                trie.setdefault(
+                    prefix_id, numba.typed.List.empty_list(numba.int64)
+                ).append(numba.int64(token_id))
+                break
+        else:
+            root_prefix_id = -1
+            trie.setdefault(
+                root_prefix_id, numba.typed.List.empty_list(numba.int64)
+            ).append(numba.int64(token_id))
+
+    return VocabTrie(trie, vocab_list, reverse_vocab)
+
+
 @numba.njit(cache=True, nogil=True)
 def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
@@ -466,11 +507,19 @@ def state_scan_tokens(
     fsm_initial: int,
     fsm_finals: Set[int],
     vocabulary: Dict[str, List[int]],
+    vocab_trie: VocabTrie,
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
 
-    for token, token_ids in vocabulary.items():
+    # Initialize the stack with tokens having no prefixes
+    stack = numba.typed.List()
+    for next_token_idx in vocab_trie.parent_children_map[-1]:
+        stack.append(vocab_trie.idx_to_token_str[next_token_idx])
+
+    # Process the tokens using the stack
+    while len(stack) > 0:
+        token = stack.pop()
         state_seq = _walk_fsm(
             fsm_transitions,
             alphabet_symbol_mapping,
@@ -485,46 +534,52 @@ def state_scan_tokens(
         if state_seq is not None and len(state_seq) < len(token):
             continue
 
-        for token_id in token_ids:
+        for token_id in vocabulary[token]:
             res.add((token_id, state_seq[-1]))
 
+        # Add successors to the stack
+        token_idx = vocab_trie.token_str_to_idx[token]
+        if token_idx in vocab_trie.parent_children_map:
+            for next_token_idx in vocab_trie.parent_children_map[token_idx]:
+                stack.append(vocab_trie.idx_to_token_str[next_token_idx])
+
     return res
 
 
+@numba.njit(cache=True, nogil=True)
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
     vocabulary: Dict[str, List[int]],
-) -> Dict[int, Set[Tuple[int, int]]]:
+) -> List[List[Tuple[int, int]]]:
     """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
 
-    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
-    # code, too.
-    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
-    seen: Set[int] = set()
-    next_states = {fsm_info.initial}
+    finals_list = numba.typed.List.empty_list(numba.int64)
+    for final in fsm_info.finals:
+        finals_list.append(final)
+
+    all_states = set(fsm_info.transitions.values())
+    all_states.add(fsm_info.initial)
+    num_states = max(all_states)
+    states_to_token_subsets = numba.typed.List(
+        [numba.typed.List.empty_list(nb_int_pair_type) for _ in range(num_states + 1)]
+    )
 
-    while next_states:
-        start_state = next_states.pop()
+    vocab_trie = create_vocab_trie(vocabulary)
 
+    for state_id in all_states:
         token_ids_end_states = state_scan_tokens(
             fsm_info.transitions,
             fsm_info.alphabet_symbol_mapping,
             fsm_info.alphabet_anything_value,
             fsm_info.initial,
-            fsm_info.finals,
+            finals_list,
             vocabulary,
-            start_state,
+            vocab_trie,
+            state_id,
         )
 
         for token_id_and_end_state in token_ids_end_states:
-            states_to_token_subsets.setdefault(start_state, set()).add(
-                token_id_and_end_state
-            )
-            end_state = token_id_and_end_state[1]
-            if end_state not in seen:
-                next_states.add(end_state)
-
-        seen.add(start_state)
+            states_to_token_subsets[state_id].append(token_id_and_end_state)
 
     return states_to_token_subsets
 
@@ -571,6 +626,11 @@ def create_fsm_index_tokenizer(
     vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
 
     states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
+    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {
+        state: set(tokens_id_end_state_list)
+        for state, tokens_id_end_state_list in enumerate(states_to_token_subsets)
+        if tokens_id_end_state_list
+    }
 
     # Allow transitions to EOS from all terminals FSM states that are
     # reachable
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index 2d2dbd993..f64e45ddf 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -257,6 +257,11 @@ def test_create_fsm_index_end_to_end():
     vocabulary_nb.update(vocabulary)
 
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
+    res = {
+        state: set(tokens_id_end_state_list)
+        for state, tokens_id_end_state_list in enumerate(res)
+        if tokens_id_end_state_list
+    }
 
     assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}}
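
The core idea of the patch: create_vocab_trie links each token to the tokens
whose longest in-vocabulary strict prefix it is (key -1 acting as a virtual
root for tokens with no such prefix), and state_scan_tokens walks that trie
depth-first, skipping a token's whole subtree once the token itself fails to
traverse the FSM, since every extension shares the failing prefix. Below is a
minimal plain-Python sketch of the same pruning, without the Numba typing and
glossing over final-state handling; build_vocab_trie, scan_with_pruning, and
matches_fsm are illustrative stand-ins, not functions from the patch, and the
sketch records trie indices where the real code records the token ids stored
in vocabulary[token].

from collections import defaultdict

def build_vocab_trie(vocabulary):
    # Flat trie: children[parent_idx] lists every token whose longest
    # strict prefix present in the vocabulary is tokens[parent_idx];
    # key -1 is a virtual root for tokens with no vocabulary prefix.
    tokens = list(vocabulary)
    index = {token: i for i, token in enumerate(tokens)}
    children = defaultdict(list)
    for token, token_idx in index.items():
        for i in range(len(token) - 1, 0, -1):  # strict prefixes, longest first
            prefix = token[:i]
            if prefix in index:
                children[index[prefix]].append(token_idx)
                break
        else:
            children[-1].append(token_idx)
    return children, tokens

def scan_with_pruning(matches_fsm, vocabulary):
    # Depth-first walk over the trie. If a token cannot be consumed by
    # the FSM, every extension of it shares that failing prefix, so the
    # whole subtree is skipped without ever being scanned.
    children, tokens = build_vocab_trie(vocabulary)
    result = set()
    stack = list(children[-1])
    while stack:
        token_idx = stack.pop()
        end_state = matches_fsm(tokens[token_idx])
        if end_state is None:
            continue  # prune this token's subtree
        result.add((token_idx, end_state))
        stack.extend(children.get(token_idx, ()))
    return result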
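
A toy run of the sketch, assuming a hand-rolled DFA in place of the real FSM
transitions used by the patch:

# Toy DFA for "ab*": state 0 --a--> 1, state 1 --b--> 1.
transitions = {(0, "a"): 1, (1, "b"): 1}

def matches_fsm(token, start_state=0):
    state = start_state
    for char in token:
        if (state, char) not in transitions:
            return None  # the walk dies early; so would any extension
        state = transitions[(state, char)]
    return state

vocab = ["a", "ab", "abb", "b", "ba"]
print(sorted(scan_with_pruning(matches_fsm, vocab)))
# [(0, 1), (1, 1), (2, 1)] -- "b" fails at the root, so "ba" is never scanned

Separately, create_fsm_index_end_to_end now returns a Numba typed list of
lists indexed by state instead of a dict of sets, which is what lets the whole
driver loop carry @numba.njit; create_fsm_index_tokenizer and the test convert
back to the previous Dict[int, Set[Tuple[int, int]]] shape at the boundary,
dropping states with no matching tokens.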