Skip to content

Commit

Permalink
Introduce Python wrapper for Vocabulary
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 26, 2024
1 parent 2d16373 commit f6768c1
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 0 deletions.
53 changes: 53 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,56 @@ STRING_INNER: str
TIME: str
UUID: str
WHITESPACE: str

class Vocabulary:
def __init__(self):
"""
Creates an empty vocabulary.
"""
...
@staticmethod
def from_dict(map: Dict[str, List[int]]) -> "Vocabulary":
"""
Creates a vocabulary from a dictionary of tokens to token ids.
"""
...
def __contains__(self, token: str) -> bool:
"""
Checks if the vocabulary contains the token.
"""
...
def __getitem__(self, token: str) -> List[int]:
"""
Gets the IDs of the token.
"""
...
def __len__(self) -> int:
"""
Gets the number of tokens in the vocabulary.
"""
...
def __repr__(self) -> str:
"""
Gets the debug string representation of the vocabulary.
"""
...
def __str__(self) -> str:
"""
Gets the string representation of the vocabulary.
"""
...
def insert(self, token: str, id: int):
"""
Inserts a token to the vocabulary.
"""
...
def extend(self, tokens_and_ids: List[Tuple[str, List[int]]]):
"""
Extends the vocabulary with a list of tokens and their IDs.
"""
...
def tokens(self) -> List[str]:
"""
Gets the list of tokens in the vocabulary.
"""
...
50 changes: 50 additions & 0 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,54 @@ pub fn create_fsm_index_end_to_end_py<'py>(
Ok(states_to_token_subsets)
}

#[pyclass(name = "Vocabulary")]
pub struct PyVocabulary(Vocabulary);

#[pymethods]
impl PyVocabulary {
#[new]
fn new() -> PyVocabulary {
PyVocabulary(Vocabulary::new())
}

#[staticmethod]
fn from_dict(map: HashMap<Token, Vec<TokenId>>) -> PyVocabulary {
PyVocabulary(Vocabulary::from_iter(map))
}

fn __contains__(&self, token: Token) -> bool {
self.0.contains_key(&token)
}

fn __getitem__(&self, token: Token) -> Vec<TokenId> {
self.0.get(&token).cloned().unwrap_or_default()
}

fn __len__(&self) -> usize {
self.0.len()
}

fn __repr__(&self) -> String {
format!("{:#?}", self.0)
}

fn __str__(&self) -> String {
format!("{}", self.0)
}

fn insert(&mut self, token: Token, id: TokenId) {
self.0.insert_in_place(token, id);
}

fn extend(&mut self, tokens_and_ids: Vec<(Token, Vec<TokenId>)>) {
self.0.extend_in_place(tokens_and_ids);
}

fn tokens(&self) -> Vec<Token> {
self.0.keys().cloned().collect()
}
}

#[pymodule]
fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?;
Expand All @@ -222,5 +270,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?;
m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;

m.add_class::<PyVocabulary>()?;

Ok(())
}
54 changes: 54 additions & 0 deletions tests/fsm/test_vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from outlines_core.fsm.outlines_core_rs import Vocabulary


def test_vocabulary_insert():
"""
Tests `vocabulary.insert(token, id)`.
"""

vocabulary = Vocabulary()

vocabulary.insert("foo", 0)
vocabulary.insert("bar", 1)
vocabulary.insert("baz", 2)
vocabulary.insert("foo", 3)

assert "foo" in vocabulary
assert "bar" in vocabulary
assert "baz" in vocabulary

assert len(vocabulary) == 3

assert sorted(vocabulary.tokens()) == sorted(["foo", "bar", "baz"])

assert vocabulary["foo"] == [0, 3]
assert vocabulary["bar"] == [1]
assert vocabulary["baz"] == [2]


def test_vocabulary_extend():
"""
Tests `vocabulary.extend(tokens_and_ids)`.
"""

vocabulary = Vocabulary()

vocabulary.extend(
[
("foo", [0, 3]),
("bar", [1]),
("baz", [2]),
]
)

assert "foo" in vocabulary
assert "bar" in vocabulary
assert "baz" in vocabulary

assert len(vocabulary) == 3

assert sorted(vocabulary.tokens()) == sorted(["foo", "bar", "baz"])

assert vocabulary["foo"] == [0, 3]
assert vocabulary["bar"] == [1]
assert vocabulary["baz"] == [2]

0 comments on commit f6768c1

Please sign in to comment.