Use Vocabulary wrapper in Python code
umut-sahin committed Sep 26, 2024
1 parent f6768c1 commit 9dd8a2d
Showing 4 changed files with 31 additions and 31 deletions.
python/outlines_core/fsm/outlines_core_rs.pyi (6 changes: 3 additions, 3 deletions)

@@ -32,7 +32,7 @@ def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
     fsm_initial: int,
     fsm_finals: Set[int],
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     vocabulary_transition_keys: Dict[str, List[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]: ...
@@ -44,12 +44,12 @@ def get_token_transition_keys(
 def get_vocabulary_transition_keys(
     alphabet_symbol_mapping: Dict[str, int],
     alphabet_anything_value: int,
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     frozen_tokens: Set[str],
 ) -> Dict[str, List[int]]: ...
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     frozen_tokens: frozenset[str],
 ) -> Dict[int, Dict[int, int]]: ...

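Net effect of the stub changes: all three bindings now accept the same Rust-backed Vocabulary type instead of a list of (token, token_ids) pairs. A minimal sketch of the changed argument, with a toy token map standing in for the mapping that reduced_vocabulary returns:

from outlines_core.fsm.outlines_core_rs import Vocabulary

tokens = {"blah": [0], "<EOS>": [1]}  # toy Dict[str, List[int]]

# Old argument shape: list(tokens.items())  ->  List[Tuple[str, List[int]]]
# New argument shape: a Vocabulary built from the same mapping.
vocabulary = Vocabulary.from_dict(tokens)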
python/outlines_core/fsm/regex.py (5 changes: 3 additions, 2 deletions)

@@ -25,6 +25,7 @@
 
 from .outlines_core_rs import (  # noqa: F401
     FSMInfo,
+    Vocabulary,
     _walk_fsm,
     create_fsm_index_end_to_end,
     get_token_transition_keys,
@@ -470,11 +471,11 @@ def create_fsm_index_tokenizer(
     `fsm` needs to be deterministically ordered so that future caching makes sense.
     """
-    vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
+    tokens, empty_token_ids = reduced_vocabulary(tokenizer)
 
     states_to_token_subsets = create_fsm_index_end_to_end(
         fsm.fsm_info,
-        list(vocabulary.items()),
+        Vocabulary.from_dict(tokens),
         frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(),
     )
 
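Putting the pieces together, a minimal end-to-end sketch of the updated flow; the toy token map stands in for reduced_vocabulary(tokenizer) output, and import locations are assumed from the modules shown in this diff:

import interegular
from outlines_core.fsm.outlines_core_rs import Vocabulary
from outlines_core.fsm.regex import create_fsm_index_end_to_end, make_deterministic_fsm

tokens = {"0": [0], "1": [1], "<EOS>": [2]}  # toy vocabulary

# Compile a regex into a deterministic FSM, as the tests below do.
regex_pattern = interegular.parse_pattern("[01]+")
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

index = create_fsm_index_end_to_end(
    regex_fsm.fsm_info,
    Vocabulary.from_dict(tokens),  # replaces list(tokens.items())
    frozenset(),
)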
src/python_bindings/mod.rs (18 changes: 7 additions, 11 deletions)

@@ -93,16 +93,15 @@ pub fn state_scan_tokens_py(
     fsm_transitions: HashMap<(State, TransitionKey), State>,
     fsm_initial: State,
     fsm_finals: HashSet<State>,
-    vocabulary: Vec<(String, Vec<TokenId>)>,
+    vocabulary: &PyVocabulary,
     vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
     start_state: State,
 ) -> PyResult<HashSet<(TokenId, State)>> {
-    let vocabulary = Vocabulary::from_iter(vocabulary);
     Ok(state_scan_tokens(
         &fsm_transitions,
         fsm_initial,
         &fsm_finals,
-        &vocabulary,
+        &vocabulary.0,
         &vocabulary_transition_keys,
         start_state,
     ))
@@ -129,14 +128,13 @@ pub fn get_token_transition_keys_py(
 pub fn get_vocabulary_transition_keys_py(
     alphabet_symbol_mapping: HashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
-    vocabulary: Vec<(String, Vec<TokenId>)>,
+    vocabulary: &PyVocabulary,
     frozen_tokens: HashSet<String>,
 ) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
-    let vocabulary = Vocabulary::from_iter(vocabulary);
     Ok(get_vocabulary_transition_keys(
         &alphabet_symbol_mapping,
         alphabet_anything_value,
-        &vocabulary,
+        &vocabulary.0,
         &frozen_tokens,
     ))
 }
@@ -146,19 +144,17 @@ pub fn get_vocabulary_transition_keys_py(
 pub fn create_fsm_index_end_to_end_py<'py>(
     py: Python<'py>,
     fsm_info: &FSMInfo,
-    vocabulary: Vec<(String, Vec<TokenId>)>,
+    vocabulary: &PyVocabulary,
     frozen_tokens: HashSet<String>,
 ) -> PyResult<Bound<'py, PyDict>> {
-    let vocabulary = Vocabulary::from_iter(vocabulary);
-
     let states_to_token_subsets = PyDict::new_bound(py);
     let mut seen: HashSet<State> = HashSet::new();
     let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
 
     let vocabulary_transition_keys = get_vocabulary_transition_keys(
         &fsm_info.alphabet_symbol_mapping,
         fsm_info.alphabet_anything_value,
-        &vocabulary,
+        &vocabulary.0,
         &frozen_tokens,
     );
 
@@ -170,7 +166,7 @@ pub fn create_fsm_index_end_to_end_py<'py>(
             &fsm_info.transitions,
             fsm_info.initial,
             &fsm_info.finals,
-            &vocabulary,
+            &vocabulary.0,
             &vocabulary_transition_keys,
             start_state,
         );
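On the Rust side, each binding used to rebuild a Vocabulary from a Vec of pairs on every call; it now borrows the PyVocabulary newtype and reaches the inner Vocabulary as vocabulary.0. From Python this means one instance can be built once and reused across bindings. A hypothetical continuation of the sketch above, with regex_fsm and vocabulary = Vocabulary.from_dict(tokens) as before:

# The same Vocabulary object backs both calls; nothing is re-serialized per call.
transition_keys = get_vocabulary_transition_keys(
    regex_fsm.fsm_info.alphabet_symbol_mapping,
    regex_fsm.fsm_info.alphabet_anything_value,
    vocabulary,
    frozenset(),
)
index = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary, frozenset())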
tests/fsm/test_regex.py (33 changes: 18 additions, 15 deletions)

@@ -1,6 +1,7 @@
 import interegular
 import numpy as np
 import pytest
+from outlines_core.fsm.outlines_core_rs import Vocabulary
 from outlines_core.fsm.regex import (
     BetterAlphabet,
     BetterFSM,
@@ -182,23 +183,24 @@ def test_create_fsm_index_end_to_end():
     regex_pattern = interegular.parse_pattern(regex_str)
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
 
-    vocabulary = {
+    tokens = {
         "blah": [0],
         "1a": [1],
         "2": [2],
         "0": [3],
         "<EOS>": [4],
     }
 
-    vocabulary_nb = []
-    for token_tuple, token_ids in vocabulary.items():
-        token = merge_symbols(token_tuple)
+    vocabulary = Vocabulary()
+    for token_tuple, token_ids in tokens.items():
+        token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token, token_ids_np))
+        for token_id in token_ids_np:
+            vocabulary.insert(token_tuple_np, token_id)
 
     res = create_fsm_index_end_to_end(
         regex_fsm.fsm_info,
-        vocabulary_nb,
+        vocabulary,
         frozenset(),
     )
 
@@ -212,7 +214,7 @@ def test_create_fsm_index_end_to_end_multi_byte():
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
 
-    vocabulary = {
+    tokens = {
         "blah": [0],
         "😈a": [1],
         "😇": [2],
@@ -224,15 +226,16 @@
         "<EOS>": [8],
     }
 
-    vocabulary_nb = []
-    for token_tuple, token_ids in vocabulary.items():
+    vocabulary = Vocabulary()
+    for token_tuple, token_ids in tokens.items():
         token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        for token_id in token_ids_np:
+            vocabulary.insert(token_tuple_np, token_id)
 
     res = create_fsm_index_end_to_end(
         byte_fsm.fsm_info,
-        vocabulary_nb,
+        vocabulary,
         frozenset(),
     )
 
@@ -433,11 +436,11 @@ def convert_token_to_string(self, token):
     regex_pattern = interegular.parse_pattern(pattern)
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokens, _ = reduced_vocabulary(tokenizer)
     token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
-        list(vocabulary.items()),
+        Vocabulary.from_dict(tokens),
         frozenset(),
     )
 
@@ -465,11 +468,11 @@ def convert_token_to_string(self, token):
     regex_pattern = interegular.parse_pattern(pattern)
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokens, _ = reduced_vocabulary(tokenizer)
     token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
-        list(vocabulary.items()),
+        Vocabulary.from_dict(tokens),
         frozenset(),
     )
 
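The tests also exercise a second construction path: starting from an empty Vocabulary and inserting one (token, token_id) pair at a time. A condensed sketch of that pattern (the tests additionally run tokens through merge_symbols and numpy int64 conversion, omitted here):

from outlines_core.fsm.outlines_core_rs import Vocabulary

vocabulary = Vocabulary()
for token, token_ids in {"blah": [0], "<EOS>": [1]}.items():
    for token_id in token_ids:
        vocabulary.insert(token, token_id)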
