diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index 4a913d40..a29b42c9 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -64,3 +64,56 @@ STRING_INNER: str
 TIME: str
 UUID: str
 WHITESPACE: str
+
+class Vocabulary:
+    def __init__(self):
+        """
+        Creates an empty vocabulary.
+        """
+        ...
+    @staticmethod
+    def from_dict(map: Dict[str, List[int]]) -> "Vocabulary":
+        """
+        Creates a vocabulary from a dictionary of tokens to token IDs.
+        """
+        ...
+    def __contains__(self, token: str) -> bool:
+        """
+        Checks if the vocabulary contains the token.
+        """
+        ...
+    def __getitem__(self, token: str) -> List[int]:
+        """
+        Gets the IDs of the token.
+        """
+        ...
+    def __len__(self) -> int:
+        """
+        Gets the number of tokens in the vocabulary.
+        """
+        ...
+    def __repr__(self) -> str:
+        """
+        Gets the debug string representation of the vocabulary.
+        """
+        ...
+    def __str__(self) -> str:
+        """
+        Gets the string representation of the vocabulary.
+        """
+        ...
+    def insert(self, token: str, id: int):
+        """
+        Inserts a token into the vocabulary.
+        """
+        ...
+    def extend(self, tokens_and_ids: List[Tuple[str, List[int]]]):
+        """
+        Extends the vocabulary with a list of tokens and their IDs.
+        """
+        ...
+    def tokens(self) -> List[str]:
+        """
+        Gets the list of tokens in the vocabulary.
+        """
+        ...
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 5017b1de..3338f38c 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -197,6 +197,54 @@ pub fn create_fsm_index_end_to_end_py<'py>(
     Ok(states_to_token_subsets)
 }
 
+#[pyclass(name = "Vocabulary")]
+pub struct PyVocabulary(Vocabulary);
+
+#[pymethods]
+impl PyVocabulary {
+    #[new]
+    fn new() -> PyVocabulary {
+        PyVocabulary(Vocabulary::new())
+    }
+
+    #[staticmethod]
+    fn from_dict(map: HashMap<Token, Vec<TokenId>>) -> PyVocabulary {
+        PyVocabulary(Vocabulary::from_iter(map))
+    }
+
+    fn __contains__(&self, token: Token) -> bool {
+        self.0.contains_key(&token)
+    }
+
+    fn __getitem__(&self, token: Token) -> Vec<TokenId> {
+        self.0.get(&token).cloned().unwrap_or_default()
+    }
+
+    fn __len__(&self) -> usize {
+        self.0.len()
+    }
+
+    fn __repr__(&self) -> String {
+        format!("{:#?}", self.0)
+    }
+
+    fn __str__(&self) -> String {
+        format!("{}", self.0)
+    }
+
+    fn insert(&mut self, token: Token, id: TokenId) {
+        self.0.insert_in_place(token, id);
+    }
+
+    fn extend(&mut self, tokens_and_ids: Vec<(Token, Vec<TokenId>)>) {
+        self.0.extend_in_place(tokens_and_ids);
+    }
+
+    fn tokens(&self) -> Vec<Token> {
+        self.0.keys().cloned().collect()
+    }
+}
+
 #[pymodule]
 fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?;
@@ -222,5 +270,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?;
     m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;
 
+    m.add_class::<PyVocabulary>()?;
+
     Ok(())
 }
diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py
new file mode 100644
index 00000000..ddfbdf7e
--- /dev/null
+++ b/tests/fsm/test_vocabulary.py
@@ -0,0 +1,54 @@
+from outlines_core.fsm.outlines_core_rs import Vocabulary
+
+
+def test_vocabulary_insert():
+    """
+    Tests `vocabulary.insert(token, id)`.
+    """
+
+    vocabulary = Vocabulary()
+
+    vocabulary.insert("foo", 0)
+    vocabulary.insert("bar", 1)
+    vocabulary.insert("baz", 2)
+    vocabulary.insert("foo", 3)
+
+    assert "foo" in vocabulary
+    assert "bar" in vocabulary
+    assert "baz" in vocabulary
+
+    assert len(vocabulary) == 3
+
+    assert sorted(vocabulary.tokens()) == sorted(["foo", "bar", "baz"])
+
+    assert vocabulary["foo"] == [0, 3]
+    assert vocabulary["bar"] == [1]
+    assert vocabulary["baz"] == [2]
+
+
+def test_vocabulary_extend():
+    """
+    Tests `vocabulary.extend(tokens_and_ids)`.
+    """
+
+    vocabulary = Vocabulary()
+
+    vocabulary.extend(
+        [
+            ("foo", [0, 3]),
+            ("bar", [1]),
+            ("baz", [2]),
+        ]
+    )
+
+    assert "foo" in vocabulary
+    assert "bar" in vocabulary
+    assert "baz" in vocabulary
+
+    assert len(vocabulary) == 3
+
+    assert sorted(vocabulary.tokens()) == sorted(["foo", "bar", "baz"])
+
+    assert vocabulary["foo"] == [0, 3]
+    assert vocabulary["bar"] == [1]
+    assert vocabulary["baz"] == [2]