From ac33a296ff3923f331eac3d98a0cc7648bb2b982 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Fri, 13 Dec 2024 21:00:51 +0000 Subject: [PATCH] Use FxHash* as default Hash* --- src/python_bindings/mod.rs | 50 +++++++++++++++++++------------------- src/regex.rs | 26 ++++++++++---------- src/vocabulary/mod.rs | 18 +++++++------- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 944ca150..a3ed0d11 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -10,21 +10,21 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::wrap_pyfunction; -use rustc_hash::{FxHashMap, FxHashSet}; use serde_json::Value; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; #[pyclass(name = "FSMInfo")] pub struct PyFSMInfo { #[pyo3(get)] initial: State, #[pyo3(get)] - finals: FxHashSet, + finals: HashSet, #[pyo3(get)] - transitions: FxHashMap<(State, TransitionKey), State>, + transitions: HashMap<(State, TransitionKey), State>, #[pyo3(get)] alphabet_anything_value: TransitionKey, #[pyo3(get)] - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, } impl From for PyFSMInfo { @@ -57,10 +57,10 @@ impl PyFSMInfo { #[new] fn new( initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, + finals: HashSet, + transitions: HashMap<(State, TransitionKey), State>, alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, ) -> Self { FSMInfo::new( initial, @@ -84,7 +84,7 @@ impl PyIndex { fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, eos_token_id: u32, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> PyResult { py.allow_threads(|| { Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) @@ -135,11 +135,11 @@ impl PyIndex { self.0.is_final(state) } - fn final_states(&self) -> FxHashSet { + fn final_states(&self) -> HashSet { self.0.final_states().clone() } - fn get_transitions(&self) -> FxHashMap> { + fn get_transitions(&self) -> HashMap> { self.0.transitions().clone() } @@ -171,9 +171,9 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" )] pub fn walk_fsm_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, + fsm_transitions: HashMap<(State, TransitionKey), State>, fsm_initial: State, - fsm_finals: FxHashSet, + fsm_finals: HashSet, token_transition_keys: Vec, start_state: State, full_match: bool, @@ -193,13 +193,13 @@ pub fn walk_fsm_py( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" )] pub fn state_scan_tokens_py( - fsm_transitions: FxHashMap<(State, TransitionKey), State>, + fsm_transitions: HashMap<(State, TransitionKey), State>, fsm_initial: State, - fsm_finals: FxHashSet, + fsm_finals: HashSet, vocabulary: &PyVocabulary, - vocabulary_transition_keys: FxHashMap>, + vocabulary_transition_keys: HashMap>, start_state: State, -) -> PyResult> { +) -> PyResult> { Ok(state_scan_tokens( &fsm_transitions, fsm_initial, @@ -213,7 +213,7 @@ pub fn state_scan_tokens_py( #[pyfunction(name = "get_token_transition_keys")] #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] pub fn get_token_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, alphabet_anything_value: TransitionKey, token_str: String, ) -> PyResult> { @@ -229,11 +229,11 @@ pub fn get_token_transition_keys_py( text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" )] pub fn get_vocabulary_transition_keys_py( - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, alphabet_anything_value: TransitionKey, vocabulary: &PyVocabulary, - frozen_tokens: FxHashSet, -) -> PyResult>> { + frozen_tokens: HashSet, +) -> PyResult>> { Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, @@ -248,11 +248,11 @@ pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> PyResult> { let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter(vec![fsm_info.initial]); + let mut seen: HashSet = HashSet::default(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, @@ -300,13 +300,13 @@ pub struct PyVocabulary(Vocabulary); #[pymethods] impl PyVocabulary { #[staticmethod] - fn from_dict(map: FxHashMap>) -> PyVocabulary { + fn from_dict(map: HashMap>) -> PyVocabulary { PyVocabulary(Vocabulary::from(map)) } #[staticmethod] fn from_dict_with_eos_token_id( - map: FxHashMap>, + map: HashMap>, eos_token_id: TokenId, ) -> PyVocabulary { let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); diff --git a/src/regex.rs b/src/regex.rs index 24687f1e..c9270b69 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,10 +1,10 @@ use crate::prelude::*; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; pub fn walk_fsm( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, + fsm_transitions: &HashMap<(State, TransitionKey), State>, _fsm_initial: State, - fsm_finals: &FxHashSet, + fsm_finals: &HashSet, token_transition_keys: &[TransitionKey], start_state: State, full_match: bool, @@ -39,14 +39,14 @@ pub fn walk_fsm( } pub fn state_scan_tokens( - fsm_transitions: &FxHashMap<(State, TransitionKey), State>, + fsm_transitions: &HashMap<(State, TransitionKey), State>, fsm_initial: State, - fsm_finals: &FxHashSet, + fsm_finals: &HashSet, vocabulary: &Vocabulary, - vocabulary_transition_keys: &FxHashMap>, + vocabulary_transition_keys: &HashMap>, start_state: State, -) -> FxHashSet<(TokenId, State)> { - let mut res = FxHashSet::default(); +) -> HashSet<(TokenId, State)> { + let mut res = HashSet::default(); for (token, token_ids) in vocabulary.iter() { let token_transition_keys = &vocabulary_transition_keys[token]; @@ -72,7 +72,7 @@ pub fn state_scan_tokens( } pub fn get_token_transition_keys( - alphabet_symbol_mapping: &FxHashMap, + alphabet_symbol_mapping: &HashMap, alphabet_anything_value: TransitionKey, token_str: &str, ) -> Vec { @@ -105,12 +105,12 @@ pub fn get_token_transition_keys( } pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: &FxHashMap, + alphabet_symbol_mapping: &HashMap, alphabet_anything_value: TransitionKey, vocabulary: &Vocabulary, - frozen_tokens: &FxHashSet, -) -> FxHashMap> { - let mut vocab_transition_keys = FxHashMap::default(); + frozen_tokens: &HashSet, +) -> HashMap> { + let mut vocab_transition_keys = HashMap::default(); for item in vocabulary.iter() { let token_str = item.0.clone(); diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index c95ed99f..13156ade 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::FxHashMap as HashMap; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; @@ -29,7 +29,7 @@ mod processor; pub struct Vocabulary { // TODO: Option is temp for back compatibility eos_token_id: Option, - tokens: FxHashMap>, + tokens: HashMap>, } impl Vocabulary { @@ -37,7 +37,7 @@ impl Vocabulary { pub fn new(eos_token_id: Option) -> Self { Self { eos_token_id, - tokens: FxHashMap::default(), + tokens: HashMap::default(), } } @@ -102,7 +102,7 @@ impl Vocabulary { } /// Returns all tokens with their token ids in vocabulary - pub fn tokens_to_ids(&self) -> &FxHashMap> { + pub fn tokens_to_ids(&self) -> &HashMap> { &self.tokens } @@ -185,9 +185,9 @@ impl Vocabulary { } impl std::ops::Deref for Vocabulary { - type Target = FxHashMap>; + type Target = HashMap>; - fn deref(&self) -> &FxHashMap> { + fn deref(&self) -> &HashMap> { &self.tokens } } @@ -205,8 +205,8 @@ impl std::fmt::Display for Vocabulary { } } -impl From>> for Vocabulary { - fn from(tokens: FxHashMap>) -> Vocabulary { +impl From>> for Vocabulary { + fn from(tokens: HashMap>) -> Vocabulary { Vocabulary { eos_token_id: None, tokens, @@ -268,7 +268,7 @@ mod tests { #[test] fn new_empty_vocabulary_from_hashmap() { - let map = FxHashMap::default(); + let map = HashMap::default(); let vocabulary = Vocabulary::from(map); assert!(vocabulary.eos_token_id.is_none()); assert!(vocabulary.tokens.is_empty());