From 9e896c29908a8c56f4dc6979cc12860b872b7ff5 Mon Sep 17 00:00:00 2001
From: Umut
Date: Thu, 26 Sep 2024 10:13:32 +0300
Subject: [PATCH] Use `Vocabulary` wrapper in Python code

---
 python/outlines_core/fsm/outlines_core_rs.pyi |  6 ++--
 python/outlines_core/fsm/regex.py             |  5 +--
 src/python_bindings/mod.rs                    | 18 ++++------
 tests/fsm/test_regex.py                       | 33 ++++++++++---------
 4 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index 86dbc3ff..1780f2ca 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -32,7 +32,7 @@ def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
     fsm_initial: int,
     fsm_finals: Set[int],
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     vocabulary_transition_keys: Dict[str, List[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]: ...
@@ -44,12 +44,12 @@ def get_token_transition_keys(
 def get_vocabulary_transition_keys(
     alphabet_symbol_mapping: Dict[str, int],
     alphabet_anything_value: int,
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     frozen_tokens: Set[str],
 ) -> Dict[str, List[int]]: ...
 
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
-    vocabulary: List[Tuple[str, List[int]]],
+    vocabulary: Vocabulary,
     frozen_tokens: frozenset[str],
 ) -> Dict[int, Dict[int, int]]: ...
diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py
index b0c6f1c8..078755f0 100644
--- a/python/outlines_core/fsm/regex.py
+++ b/python/outlines_core/fsm/regex.py
@@ -25,6 +25,7 @@
 
 from .outlines_core_rs import (  # noqa: F401
     FSMInfo,
+    Vocabulary,
     _walk_fsm,
     create_fsm_index_end_to_end,
     get_token_transition_keys,
@@ -470,11 +471,11 @@ def create_fsm_index_tokenizer(
 
     `fsm` needs to be deterministically ordered so that future caching makes sense.
""" - vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) + tokens, empty_token_ids = reduced_vocabulary(tokenizer) states_to_token_subsets = create_fsm_index_end_to_end( fsm.fsm_info, - list(vocabulary.items()), + Vocabulary.from_dict(tokens), frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(), ) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 3338f38c..cb94e411 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -93,16 +93,15 @@ pub fn state_scan_tokens_py( fsm_transitions: HashMap<(State, TransitionKey), State>, fsm_initial: State, fsm_finals: HashSet, - vocabulary: Vec<(String, Vec)>, + vocabulary: &PyVocabulary, vocabulary_transition_keys: HashMap>, start_state: State, ) -> PyResult> { - let vocabulary = Vocabulary::from_iter(vocabulary); Ok(state_scan_tokens( &fsm_transitions, fsm_initial, &fsm_finals, - &vocabulary, + &vocabulary.0, &vocabulary_transition_keys, start_state, )) @@ -129,14 +128,13 @@ pub fn get_token_transition_keys_py( pub fn get_vocabulary_transition_keys_py( alphabet_symbol_mapping: HashMap, alphabet_anything_value: TransitionKey, - vocabulary: Vec<(String, Vec)>, + vocabulary: &PyVocabulary, frozen_tokens: HashSet, ) -> PyResult>> { - let vocabulary = Vocabulary::from_iter(vocabulary); Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, - &vocabulary, + &vocabulary.0, &frozen_tokens, )) } @@ -146,11 +144,9 @@ pub fn get_vocabulary_transition_keys_py( pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &FSMInfo, - vocabulary: Vec<(String, Vec)>, + vocabulary: &PyVocabulary, frozen_tokens: HashSet, ) -> PyResult> { - let vocabulary = Vocabulary::from_iter(vocabulary); - let states_to_token_subsets = PyDict::new_bound(py); let mut seen: HashSet = HashSet::new(); let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); @@ -158,7 +154,7 @@ pub fn create_fsm_index_end_to_end_py<'py>( let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, fsm_info.alphabet_anything_value, - &vocabulary, + &vocabulary.0, &frozen_tokens, ); @@ -170,7 +166,7 @@ pub fn create_fsm_index_end_to_end_py<'py>( &fsm_info.transitions, fsm_info.initial, &fsm_info.finals, - &vocabulary, + &vocabulary.0, &vocabulary_transition_keys, start_state, ); diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 6e18e477..99d6a275 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -1,6 +1,7 @@ import interegular import numpy as np import pytest +from outlines_core.fsm.outlines_core_rs import Vocabulary from outlines_core.fsm.regex import ( BetterAlphabet, BetterFSM, @@ -182,7 +183,7 @@ def test_create_fsm_index_end_to_end(): regex_pattern = interegular.parse_pattern(regex_str) regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - vocabulary = { + tokens = { "blah": [0], "1a": [1], "2": [2], @@ -190,15 +191,16 @@ def test_create_fsm_index_end_to_end(): "": [4], } - vocabulary_nb = [] - for token_tuple, token_ids in vocabulary.items(): - token = merge_symbols(token_tuple) + vocabulary = Vocabulary() + for token_tuple, token_ids in tokens.items(): + token_tuple_np = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token, token_ids_np)) + for token_id in token_ids_np: + vocabulary.insert(token_tuple_np, token_id) res = create_fsm_index_end_to_end( regex_fsm.fsm_info, - vocabulary_nb, + vocabulary, 
         frozenset(),
     )
 
@@ -212,7 +214,7 @@ def test_create_fsm_index_end_to_end_multi_byte():
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
 
-    vocabulary = {
+    tokens = {
         "blah": [0],
         "😈a": [1],
         "😇": [2],
@@ -224,15 +226,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
         "": [8],
     }
 
-    vocabulary_nb = []
-    for token_tuple, token_ids in vocabulary.items():
+    vocabulary = Vocabulary()
+    for token_tuple, token_ids in tokens.items():
         token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        for token_id in token_ids_np:
+            vocabulary.insert(token_tuple_np, token_id)
 
     res = create_fsm_index_end_to_end(
         byte_fsm.fsm_info,
-        vocabulary_nb,
+        vocabulary,
         frozenset(),
     )
 
@@ -433,11 +436,11 @@ def convert_token_to_string(self, token):
     regex_pattern = interegular.parse_pattern(pattern)
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokens, _ = reduced_vocabulary(tokenizer)
     token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
-        list(vocabulary.items()),
+        Vocabulary.from_dict(tokens),
         frozenset(),
     )
 
@@ -465,11 +468,11 @@ def convert_token_to_string(self, token):
     regex_pattern = interegular.parse_pattern(pattern)
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
-    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokens, _ = reduced_vocabulary(tokenizer)
     token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
-        list(vocabulary.items()),
+        Vocabulary.from_dict(tokens),
         frozenset(),
     )
 
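Reviewer note: the sketch below is illustrative, not part of the patch. It
shows the two construction paths the updated call sites rely on,
`Vocabulary.from_dict` (as in create_fsm_index_tokenizer) and
`Vocabulary.insert` (as in the updated tests); the token strings and ids are
made up for the example:

    from outlines_core.fsm.outlines_core_rs import Vocabulary

    # Build the wrapper in one call from a token -> token-ids mapping,
    # the shape returned by reduced_vocabulary().
    vocabulary = Vocabulary.from_dict({"blah": [0], "1a": [1], "2": [2]})

    # Or insert (token, token_id) pairs one at a time.
    vocabulary = Vocabulary()
    for token, token_ids in {"0": [3], "": [4]}.items():
        for token_id in token_ids:
            vocabulary.insert(token, token_id)

    # Either object is then passed where a List[Tuple[str, List[int]]]
    # used to go, e.g.
    # create_fsm_index_end_to_end(fsm_info, vocabulary, frozenset()).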