Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a trie for scanning during index construction #507

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 80 additions & 20 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,47 @@ def get_sub_fsms_from_seq(
)


# Flat trie over the token vocabulary, used to walk tokens in prefix order.
#   parent_children_map: token index (-1 for the virtual root) -> indices of
#       tokens whose longest existing proper prefix that token is
#   idx_to_token_str:    token index -> token string
#   token_str_to_idx:    token string -> token index
VocabTrie = namedtuple(
    "VocabTrie",
    ["parent_children_map", "idx_to_token_str", "token_str_to_idx"],
)


@numba.njit(cache=True, nogil=True)
def create_vocab_trie(vocabulary: Dict[str, List[int]]) -> VocabTrie:
    """Build a flat prefix trie over the vocabulary's token strings.

    Each token is attached as a child of its *longest* proper prefix that is
    itself a vocabulary token; tokens with no such prefix hang off a virtual
    root node with id -1.  The result lets a caller enumerate tokens so that
    every prefix is visited before its extensions.

    NOTE: runs in numba nopython mode, hence the explicit ``numba.typed``
    containers and ``numba.int64`` casts throughout.
    """
    # Flat trie: parent token id (-1 = root) -> list of child token ids.
    # Initialize the trie as a flat dictionary
    trie = numba.typed.Dict.empty(numba.int64, nb_int_list_type)

    # Assign each token a dense integer id (enumeration order of the dict)
    # and build both directions of the id <-> string mapping.
    vocab_list = numba.typed.List.empty_list(numba.types.string)
    reverse_vocab = numba.typed.Dict.empty(numba.types.string, numba.int64)
    for i, token in enumerate(vocabulary):
        vocab_list.append(token)
        reverse_vocab[token] = numba.int64(i)

    for token, token_id in reverse_vocab.items():
        # Scan prefixes from longest to shortest so the first hit is the
        # longest proper prefix present in the vocabulary.
        for i in range(len(token), 0, -1):
            prefix = token[:i]
            # i == len(token) yields the token itself; only *proper*
            # prefixes count as parents.
            if prefix == token:
                continue
            if prefix in vocabulary:
                prefix_id = numba.int64(reverse_vocab[prefix])
                trie.setdefault(
                    prefix_id, numba.typed.List.empty_list(numba.int64)
                ).append(numba.int64(token_id))
                break
        else:
            # No vocabulary token is a proper prefix: attach to the root.
            root_prefix_id = -1
            trie.setdefault(
                root_prefix_id, numba.typed.List.empty_list(numba.int64)
            ).append(numba.int64(token_id))

    return VocabTrie(trie, vocab_list, reverse_vocab)


@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
fsm_transitions: Dict[Tuple[int, int], int],
Expand All @@ -466,11 +507,19 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
vocab_trie: VocabTrie,
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary.items():
# Initialize the stack with tokens having no prefixes
stack = numba.typed.List()
for next_token_idx in vocab_trie.parent_children_map[-1]:
stack.append(vocab_trie.idx_to_token_str[next_token_idx])

# Process the tokens using the stack
while len(stack) > 0:
token = stack.pop()
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
Expand All @@ -485,46 +534,52 @@ def state_scan_tokens(
if state_seq is not None and len(state_seq) < len(token):
continue

for token_id in token_ids:
for token_id in vocabulary[token]:
res.add((token_id, state_seq[-1]))

# Add successors to the stack
token_idx = vocab_trie.token_str_to_idx[token]
if token_idx in vocab_trie.parent_children_map:
for next_token_idx in vocab_trie.parent_children_map[token_idx]:
stack.append(vocab_trie.idx_to_token_str[next_token_idx])

return res


@numba.njit(cache=True, nogil=True)
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
) -> List[List[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
seen: Set[int] = set()
next_states = {fsm_info.initial}
finals_list = numba.typed.List.empty_list(numba.int64)
for final in fsm_info.finals:
finals_list.append(final)

all_states = set(fsm_info.transitions.values())
all_states.add(fsm_info.initial)
num_states = max(all_states)
states_to_token_subsets = numba.typed.List(
[numba.typed.List.empty_list(nb_int_pair_type) for _ in range(num_states + 1)]
)

while next_states:
start_state = next_states.pop()
vocab_trie = create_vocab_trie(vocabulary)

for state_id in all_states:
token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
finals_list,
vocabulary,
start_state,
vocab_trie,
state_id,
)

for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)

seen.add(start_state)
states_to_token_subsets[state_id].append(token_id_and_end_state)

return states_to_token_subsets

Expand Down Expand Up @@ -571,6 +626,11 @@ def create_fsm_index_tokenizer(
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {
state: set(tokens_id_end_state_list)
for state, tokens_id_end_state_list in enumerate(states_to_token_subsets)
if tokens_id_end_state_list
}

# Allow transitions to EOS from all terminals FSM states that are
# reachable
Expand Down
5 changes: 5 additions & 0 deletions tests/fsm/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def test_create_fsm_index_end_to_end():
vocabulary_nb.update(vocabulary)

res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
res = {
state: set(tokens_id_end_state_list)
for state, tokens_id_end_state_list in enumerate(res)
if tokens_id_end_state_list
}

assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}}

Expand Down