Skip to content

Commit

Permalink
Use FxHash* as default Hash*
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Dec 13, 2024
1 parent ead67fb commit ac33a29
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 47 deletions.
50 changes: 25 additions & 25 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use rustc_hash::{FxHashMap, FxHashSet};
use serde_json::Value;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

#[pyclass(name = "FSMInfo")]
pub struct PyFSMInfo {
#[pyo3(get)]
initial: State,
#[pyo3(get)]
finals: FxHashSet<State>,
finals: HashSet<State>,
#[pyo3(get)]
transitions: FxHashMap<(State, TransitionKey), State>,
transitions: HashMap<(State, TransitionKey), State>,
#[pyo3(get)]
alphabet_anything_value: TransitionKey,
#[pyo3(get)]
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
}

impl From<FSMInfo> for PyFSMInfo {
Expand Down Expand Up @@ -57,10 +57,10 @@ impl PyFSMInfo {
#[new]
fn new(
initial: State,
finals: FxHashSet<State>,
transitions: FxHashMap<(State, TransitionKey), State>,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
) -> Self {
FSMInfo::new(
initial,
Expand All @@ -84,7 +84,7 @@ impl PyIndex {
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
eos_token_id: u32,
frozen_tokens: FxHashSet<String>,
frozen_tokens: HashSet<String>,
) -> PyResult<Self> {
py.allow_threads(|| {
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
Expand Down Expand Up @@ -135,11 +135,11 @@ impl PyIndex {
self.0.is_final(state)
}

fn final_states(&self) -> FxHashSet<State> {
fn final_states(&self) -> HashSet<State> {
self.0.final_states().clone()
}

fn get_transitions(&self) -> FxHashMap<u32, FxHashMap<u32, u32>> {
fn get_transitions(&self) -> HashMap<u32, HashMap<u32, u32>> {
self.0.transitions().clone()
}

Expand Down Expand Up @@ -171,9 +171,9 @@ pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyR
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
)]
pub fn walk_fsm_py(
fsm_transitions: FxHashMap<(State, TransitionKey), State>,
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: FxHashSet<State>,
fsm_finals: HashSet<State>,
token_transition_keys: Vec<TransitionKey>,
start_state: State,
full_match: bool,
Expand All @@ -193,13 +193,13 @@ pub fn walk_fsm_py(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
)]
pub fn state_scan_tokens_py(
fsm_transitions: FxHashMap<(State, TransitionKey), State>,
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: FxHashSet<State>,
fsm_finals: HashSet<State>,
vocabulary: &PyVocabulary,
vocabulary_transition_keys: FxHashMap<String, Vec<TransitionKey>>,
vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
start_state: State,
) -> PyResult<FxHashSet<(TokenId, State)>> {
) -> PyResult<HashSet<(TokenId, State)>> {
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
Expand All @@ -213,7 +213,7 @@ pub fn state_scan_tokens_py(
#[pyfunction(name = "get_token_transition_keys")]
#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
pub fn get_token_transition_keys_py(
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: String,
) -> PyResult<Vec<TransitionKey>> {
Expand All @@ -229,11 +229,11 @@ pub fn get_token_transition_keys_py(
text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
)]
pub fn get_vocabulary_transition_keys_py(
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &PyVocabulary,
frozen_tokens: FxHashSet<String>,
) -> PyResult<FxHashMap<String, Vec<TransitionKey>>> {
frozen_tokens: HashSet<String>,
) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -248,11 +248,11 @@ pub fn create_fsm_index_end_to_end_py<'py>(
py: Python<'py>,
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
frozen_tokens: FxHashSet<String>,
frozen_tokens: HashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: FxHashSet<State> = FxHashSet::default();
let mut next_states: FxHashSet<State> = FxHashSet::from_iter(vec![fsm_info.initial]);
let mut seen: HashSet<State> = HashSet::default();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down Expand Up @@ -300,13 +300,13 @@ pub struct PyVocabulary(Vocabulary);
#[pymethods]
impl PyVocabulary {
#[staticmethod]
fn from_dict(map: FxHashMap<Token, Vec<TokenId>>) -> PyVocabulary {
fn from_dict(map: HashMap<Token, Vec<TokenId>>) -> PyVocabulary {
PyVocabulary(Vocabulary::from(map))
}

#[staticmethod]
fn from_dict_with_eos_token_id(
map: FxHashMap<Token, Vec<TokenId>>,
map: HashMap<Token, Vec<TokenId>>,
eos_token_id: TokenId,
) -> PyVocabulary {
let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id));
Expand Down
26 changes: 13 additions & 13 deletions src/regex.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

pub fn walk_fsm(
fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
fsm_transitions: &HashMap<(State, TransitionKey), State>,
_fsm_initial: State,
fsm_finals: &FxHashSet<State>,
fsm_finals: &HashSet<State>,
token_transition_keys: &[TransitionKey],
start_state: State,
full_match: bool,
Expand Down Expand Up @@ -39,14 +39,14 @@ pub fn walk_fsm(
}

pub fn state_scan_tokens(
fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &FxHashSet<State>,
fsm_finals: &HashSet<State>,
vocabulary: &Vocabulary,
vocabulary_transition_keys: &FxHashMap<Token, Vec<TransitionKey>>,
vocabulary_transition_keys: &HashMap<Token, Vec<TransitionKey>>,
start_state: State,
) -> FxHashSet<(TokenId, State)> {
let mut res = FxHashSet::default();
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::default();

for (token, token_ids) in vocabulary.iter() {
let token_transition_keys = &vocabulary_transition_keys[token];
Expand All @@ -72,7 +72,7 @@ pub fn state_scan_tokens(
}

pub fn get_token_transition_keys(
alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: &str,
) -> Vec<TransitionKey> {
Expand Down Expand Up @@ -105,12 +105,12 @@ pub fn get_token_transition_keys(
}

pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &Vocabulary,
frozen_tokens: &FxHashSet<String>,
) -> FxHashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = FxHashMap::default();
frozen_tokens: &HashSet<String>,
) -> HashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = HashMap::default();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand Down
18 changes: 9 additions & 9 deletions src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_hash::FxHashMap;
use rustc_hash::FxHashMap as HashMap;

use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
Expand Down Expand Up @@ -29,15 +29,15 @@ mod processor;
pub struct Vocabulary {
// TODO: Option is temp for back compatibility
eos_token_id: Option<TokenId>,
tokens: FxHashMap<Token, Vec<TokenId>>,
tokens: HashMap<Token, Vec<TokenId>>,
}

impl Vocabulary {
/// Creates an empty vocabulary.
pub fn new(eos_token_id: Option<TokenId>) -> Self {
Self {
eos_token_id,
tokens: FxHashMap::default(),
tokens: HashMap::default(),
}
}

Expand Down Expand Up @@ -102,7 +102,7 @@ impl Vocabulary {
}

/// Returns all tokens with their token ids in vocabulary
pub fn tokens_to_ids(&self) -> &FxHashMap<Token, Vec<TokenId>> {
pub fn tokens_to_ids(&self) -> &HashMap<Token, Vec<TokenId>> {
&self.tokens
}

Expand Down Expand Up @@ -185,9 +185,9 @@ impl Vocabulary {
}

impl std::ops::Deref for Vocabulary {
type Target = FxHashMap<Token, Vec<TokenId>>;
type Target = HashMap<Token, Vec<TokenId>>;

fn deref(&self) -> &FxHashMap<Token, Vec<TokenId>> {
fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
&self.tokens
}
}
Expand All @@ -205,8 +205,8 @@ impl std::fmt::Display for Vocabulary {
}
}

impl From<FxHashMap<Token, Vec<TokenId>>> for Vocabulary {
fn from(tokens: FxHashMap<Token, Vec<TokenId>>) -> Vocabulary {
impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
Vocabulary {
eos_token_id: None,
tokens,
Expand Down Expand Up @@ -268,7 +268,7 @@ mod tests {

#[test]
fn new_empty_vocabulary_from_hashmap() {
let map = FxHashMap::default();
let map = HashMap::default();
let vocabulary = Vocabulary::from(map);
assert!(vocabulary.eos_token_id.is_none());
assert!(vocabulary.tokens.is_empty());
Expand Down

0 comments on commit ac33a29

Please sign in to comment.