From c01e1b81058a661da7bc7786065f7716514e8ec2 Mon Sep 17 00:00:00 2001
From: Nathan Hoos
Date: Tue, 25 Feb 2025 00:39:51 +0000
Subject: [PATCH 1/4] Add write_into_mask method on Guide. Add kernels for
 torch and numpy

---
 Cargo.toml                                |   1 -
 python/outlines_core/__init__.py          |   2 +
 python/outlines_core/kernels/__init__.py  |   0
 python/outlines_core/kernels/numpy.py     |  78 ++++++++++++++++
 python/outlines_core/kernels/torch.py     | 106 ++++++++++++++++++++++
 python/outlines_core/outlines_core_rs.pyi |   9 ++
 src/index.rs                              |  12 +++
 src/python_bindings/mod.rs                |  39 ++++++++
 src/vocabulary/mod.rs                     |   5 +
 tests/test_guide.py                       |   2 +-
 tests/test_index.py                       |   2 +-
 11 files changed, 253 insertions(+), 3 deletions(-)
 create mode 100644 python/outlines_core/kernels/__init__.py
 create mode 100644 python/outlines_core/kernels/numpy.py
 create mode 100644 python/outlines_core/kernels/torch.py

diff --git a/Cargo.toml b/Cargo.toml
index 0ff68646..84e555f3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,6 @@ hf-hub = "=0.3.2"
 tokenizers = { version = "=0.20.3", features = ["http"] }
 rustc-hash = "2.1.0"
 regex-automata = "0.4.9"
-
 [features]
 python-bindings = ["pyo3", "serde-pyobject"]
diff --git a/python/outlines_core/__init__.py b/python/outlines_core/__init__.py
index 951acdfb..793696d9 100644
--- a/python/outlines_core/__init__.py
+++ b/python/outlines_core/__init__.py
@@ -3,6 +3,8 @@
 
 from .outlines_core_rs import Guide, Index, Vocabulary
 
+from .kernels import torch
+
 try:
     __version__ = version("outlines_core")
 except PackageNotFoundError:
diff --git a/python/outlines_core/kernels/__init__.py b/python/outlines_core/kernels/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/outlines_core/kernels/numpy.py b/python/outlines_core/kernels/numpy.py
new file mode 100644
index 00000000..90efb5df
--- /dev/null
+++ b/python/outlines_core/kernels/numpy.py
@@ -0,0 +1,78 @@
+from outlines_core import Guide
+
+try:
+    import numpy as np
+    import numba
+except ImportError as e:
+    missing_dep = "numba" if "numba" in str(e) else "numpy"
+    raise ImportError(
+        f"To use the kernels in `outlines_core.kernels.numpy`, `{missing_dep}` must be installed."
+    ) from e
+
+def allocate_token_bitmask(vocab_size: int) -> np.ndarray:
+    return np.full(
+        (1, (vocab_size + 31) // 32),
+        -1,
+        dtype=np.int32,
+    )
+
+@numba.njit
+def _apply_token_bitmask_kernel(logits, mask):
+    mask_len = mask.shape[1]
+    cutoff = 32 * mask_len
+
+    if logits.shape[1] > cutoff:
+        logits[:, cutoff:] = -np.inf
+        logits = logits[:, :cutoff]
+
+    n_rows, n_cols = logits.shape
+
+    for i in range(n_rows):
+        for mi in range(mask_len):
+            mval = mask[i, mi]
+            base = mi * 32
+            for bit in range(32):
+                j = base + bit
+
+                if j >= n_cols:
+                    break
+
+                if ((mval >> bit) & 1) == 0:
+                    logits[i, j] = -np.inf
+
+def apply_token_bitmask_inplace(logits: np.ndarray, mask: np.ndarray) -> None:
+    if logits.ndim == 1:
+        logits = np.expand_dims(logits, axis=0)
+    if mask.ndim == 1:
+        mask = np.expand_dims(mask, axis=0)
+
+    if mask.dtype != np.int32:
+        raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
+    elif mask.ndim != 2:
+        raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
+    elif logits.ndim != 2:
+        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.ndim}D.")
+    elif mask.shape[0] != logits.shape[0]:
+        raise ValueError(
+            f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
+        )
+    _apply_token_bitmask_kernel(logits, mask)
+
+def fill_next_token_bitmask(
+    guide: Guide, mask: np.ndarray
+) -> None:
+    # timing: all checks take roughly 0.5 microseconds.
+    if mask.dtype != np.int32:
+        raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
+    elif mask.ndim != 2:
+        raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
+    elif mask.shape[0] != 1:
+        raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
+    elif not mask.flags["C_CONTIGUOUS"]:
+        raise ValueError("Mask array must be contiguous in memory. Use `np.ascontiguousarray(mask)`.")
+
+    return guide.write_mask_into(
+        mask.ctypes.data,
+        mask.size,
+        mask.itemsize
+    )
\ No newline at end of file
diff --git a/python/outlines_core/kernels/torch.py b/python/outlines_core/kernels/torch.py
new file mode 100644
index 00000000..9c7befd9
--- /dev/null
+++ b/python/outlines_core/kernels/torch.py
@@ -0,0 +1,106 @@
+# Provides kernels for masking a logits tensor,
+# using the write_mask_into method on the `Guide` object and the bitmask
+# which it writes into a tensor.
+#
+# Kernels inspired by https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/torch.py
+from outlines_core import Guide
+
+try:
+    import torch
+except Exception as e:
+    raise ModuleNotFoundError(
+        "`torch` is required to use the kernels from "
+        "`outlines_core.kernels.torch`. You can install "
+        "`torch` using the official guide at https://pytorch.org/get-started/locally/"
+    ) from e
+
+def allocate_token_bitmask(vocab_size: int) -> torch.Tensor:
+    """
+    Allocate a token bitmask for use with the `Guide.write_mask_into` API and logits masking,
+    based on the vocab_size.
+
+    Arguments:
+        - vocab_size: int
+    Returns:
+        - torch.Tensor
+    """
+    return torch.full(
+        (1, (vocab_size + 31) // 32),
+        -1,
+        dtype=torch.int32,
+        pin_memory=torch.cuda.is_available(),
+    )
+
+# This takes roughly 23 microseconds per run, with a bitmask of
+# 1k allowed tokens, and 128k logits tensor.
+# Also compiles to one graph with no graph breaks
+# Performance characteristics are:
+# - Larger the logits array ( length ), the longer the kernel takes
+# - Constant time for mask i.e. number of allowed tokens does not effect execution
+# time
+@torch.compile(dynamic=True)
+def _apply_token_bitmask_kernel(logits, mask):
+    # This should not modify, so long as the mask
+    # is allocated at the correct size
+    logits = torch.where(
+        torch.ge(
+            torch.arange(
+                logits.shape[1],
+                device=logits.device
+            ),
+            32 * mask.shape[1]
+        ),
+        -torch.inf,
+        logits
+    )
+
+    # Unpack each 32-bit mask value into 32 individual bits (as booleans)
+    bit_masks = (
+        (torch.bitwise_right_shift(
+            mask.unsqueeze(-1),
+            torch.arange(
+                32,
+                device=mask.device,
+                dtype=torch.int32
+            )) & 1
+        )
+        .bool()
+        .view(mask.shape[0], -1)
+        .narrow(1, 0, logits.shape[1])
+    )
+
+    # Possibly trim mask to match the logits width
+    bit_masks = bit_masks[:, :logits.shape[1]]
+    logits.masked_fill_(~bit_masks, -torch.inf)
+
+
+def apply_token_bitmask_inplace(logits: torch.Tensor, mask: torch.Tensor) -> None:
+    if mask.dtype != torch.int32:
+        raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
+    elif mask.dim() != 2:
+        raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
+    elif logits.dim() != 2:
+        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.dim()}D.")
+    elif mask.shape[0] != logits.shape[0]:
+        raise ValueError(
+            f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
+        )
+    _apply_token_bitmask_kernel(logits, mask)
+
+def fill_next_token_bitmask(guide: Guide, mask: torch.Tensor) -> None:
+    if mask.dtype != torch.int32:
+        raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
+    elif mask.dim() != 2:
+        raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
+    elif mask.shape[0] != 1:
+        raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
+    elif not mask.is_contiguous():
+        raise ValueError("Mask array must be contiguous in memory. Use `mask.contiguous()` to fix it.")
+    elif mask.device != torch.device("cpu"):
+        raise ValueError(f"Invalid device: Expected `mask` tensor to be on device `cpu`, but found it on `{mask.device}`.")
+
+    guide.write_mask_into(
+        mask.data_ptr(),
+        mask.numel(),
+        mask.element_size()
+    )
diff --git a/python/outlines_core/outlines_core_rs.pyi b/python/outlines_core/outlines_core_rs.pyi
index eb578b14..7f96bf6d 100644
--- a/python/outlines_core/outlines_core_rs.pyi
+++ b/python/outlines_core/outlines_core_rs.pyi
@@ -35,6 +35,15 @@ class Guide:
     def is_finished(self) -> bool:
         """Checks if the automaton is in a final state."""
         ...
+    def write_mask_into(self, data_ptr: int, numel: int, element_size: int) -> None:
+        """Write the mask of allowed tokens into the memory specified by data_ptr.
+        Size of the memory to be written to is indicated by `numel`, and `element_size`.
+        `element_size` must be 4.
+
+        `data_ptr` should be the data ptr to a `torch.tensor`, or `np.ndarray`, or other
+        continuous memory array"""
+        ...
+
     def __repr__(self) -> str:
         """Gets the debug string representation of the guide."""
         ...
diff --git a/src/index.rs b/src/index.rs
index 0b38e25e..75c57943 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -55,6 +55,8 @@ pub struct Index {
     transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
     /// The token ID reserved for the "end-of-sequence" token.
     eos_token_id: TokenId,
+    /// The size of th vocabulary used to build the index
+    vocab_size: usize
 }
 /// The `Index` structure is designed to efficiently map tokens from a given vocabulary
 /// to state transitions within a finite-state automaton.
@@ -99,6 +101,7 @@ impl Index {
     /// Builds an `Index` from regular expression and vocabulary tokens.
     pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
+        let vocab_size = vocabulary.len();
         let eos_token_id = vocabulary.eos_token_id();
         let dfa = DFA::new(regex).map_err(Box::new)?;
         let start_state = match dfa.universal_start_state(Anchored::Yes) {
@@ -160,6 +163,7 @@
             final_states,
             transitions,
             eos_token_id,
+            vocab_size
         })
     }
@@ -190,6 +194,10 @@
             .map(|res| res.keys().cloned().collect())
     }
 
+    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
+        self.transitions.get(state).map(|map| map.keys())
+    }
+
     /// Returns transition state for a given state and token id or `None` otherwise.
     pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
         if token_id == &self.eos_token_id {
@@ -197,6 +205,10 @@
         }
         Some(*self.transitions.get(state)?.get(token_id)?)
     }
+
+    pub fn vocab_size(&self) -> usize {
+        self.vocab_size
+    }
 }
 
 impl std::fmt::Display for Index {
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 2acab368..bacea078 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -70,6 +70,45 @@ impl PyGuide {
         self.index.is_final_state(self.state)
     }
 
+    fn write_mask_into(
+        &self,
+        data_ptr: usize,
+        numel: usize,
+        element_size: usize
+    ) -> PyResult<()> {
+
+        if element_size != 4 {
+            return Err(PyErr::new::<PyValueError, _>(
+                "The data type of the Tensor must be `torch.int32`",
+            ));
+        } else if data_ptr == 0 {
+            return Err(PyErr::new::<PyValueError, _>(
+                "data_ptr cannot be null or nullptr",
+            ));
+        } else if data_ptr % 4 != 0 {
+            return Err(PyErr::new::<PyValueError, _>(
+                "data_ptr is not aligned",
+            ));
+        } else if ((self.index.0.vocab_size() + 31) / 32) != numel {
+            return Err(PyErr::new::<PyValueError, _>(
+                "Invalid buffer size. Please ensure that the length of the mask tensor is equal to ((vocab_size + 31) / 32), and in `torch.int32` precision.",
+            ));
+        }
+        unsafe {
+            std::ptr::write_bytes(data_ptr as *mut u8, 0, numel * 4);
+        }
+        if let Some(tokens) = self.index.0.allowed_tokens_iter(&self.state) {
+            let slice = unsafe { std::slice::from_raw_parts_mut(data_ptr as *mut u32, numel) };
+            for &token in tokens {
+                let bucket = (token as usize) / 32;
+                if bucket < slice.len() {
+                    slice[bucket] |= 1 << ((token as usize) % 32);
+                }
+            }
+        }
+        Ok(())
+    }
+
     fn __repr__(&self) -> String {
         format!(
             "Guide object with the state={:#?} and {:#?}",
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 43bb6e7a..8bb09924 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -148,6 +148,11 @@ impl Vocabulary {
         self.tokens.remove(&token);
     }
 
+    pub fn len(&self) -> usize {
+        // +1 for `eos_token_id`
+        self.tokens.len() + 1
+    }
+
     /// Filters out `Prepend` kind of tokenizer's normalizers.
     fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
         // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
diff --git a/tests/test_guide.py b/tests/test_guide.py
index fa77a4ed..76d6f8d3 100644
--- a/tests/test_guide.py
+++ b/tests/test_guide.py
@@ -147,4 +147,4 @@ def test_equality(index):
     # progress one of the guides, confirm different state == different guide
     guide1.advance(guide1.get_tokens()[-1])
     assert guide1 != guide2
-    assert guide3 == guide2
+    assert guide3 == guide2
\ No newline at end of file
diff --git a/tests/test_index.py b/tests/test_index.py
index 208a7cae..01129aa8 100644
--- a/tests/test_index.py
+++ b/tests/test_index.py
@@ -62,4 +62,4 @@ def test_deepcopy(index):
     is_deleted = not any(id(o) == index2_id for o in gc.get_objects())
     assert is_deleted
 
-    assert copy_index2 == index
+    assert copy_index2 == index
\ No newline at end of file

From 31a48cca8e30903b30672fe1d00f78e0f66d3b33 Mon Sep 17 00:00:00 2001
From: Nathan Hoos
Date: Tue, 25 Feb 2025 10:47:11 -0600
Subject: [PATCH 2/4] Update src/index.rs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rémi Louf
---
 src/index.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/index.rs b/src/index.rs
index 75c57943..1411464c 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -55,7 +55,7 @@ pub struct Index {
     transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
     /// The token ID reserved for the "end-of-sequence" token.
     eos_token_id: TokenId,
-    /// The size of th vocabulary used to build the index
+    /// The size of the vocabulary used to build the index
     vocab_size: usize
 }
 /// The `Index` structure is designed to efficiently map tokens from a given vocabulary

From d874306b61404cb462e53005b17033880a26bd70 Mon Sep 17 00:00:00 2001
From: Nathan Hoos
Date: Tue, 25 Feb 2025 10:47:22 -0600
Subject: [PATCH 3/4] Update python/outlines_core/outlines_core_rs.pyi
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rémi Louf
---
 python/outlines_core/outlines_core_rs.pyi | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/outlines_core/outlines_core_rs.pyi b/python/outlines_core/outlines_core_rs.pyi
index 7f96bf6d..9ba47049 100644
--- a/python/outlines_core/outlines_core_rs.pyi
+++ b/python/outlines_core/outlines_core_rs.pyi
@@ -41,7 +41,7 @@ class Guide:
         `element_size` must be 4.
 
         `data_ptr` should be the data ptr to a `torch.tensor`, or `np.ndarray`, or other
-        continuous memory array"""
+        contiguous memory array"""
         ...
 
     def __repr__(self) -> str:

From bb854f6b33cc29985ac0d540ead2f7854604e316 Mon Sep 17 00:00:00 2001
From: Nathan Hoos
Date: Tue, 25 Feb 2025 10:47:36 -0600
Subject: [PATCH 4/4] Update python/outlines_core/kernels/torch.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Rémi Louf
---
 python/outlines_core/kernels/torch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/outlines_core/kernels/torch.py b/python/outlines_core/kernels/torch.py
index 9c7befd9..cd2cc226 100644
--- a/python/outlines_core/kernels/torch.py
+++ b/python/outlines_core/kernels/torch.py
@@ -36,7 +36,7 @@ def allocate_token_bitmask(vocab_size: int) -> torch.Tensor:
 # 1k allowed tokens, and 128k logits tensor.
 # Also compiles to one graph with no graph breaks
 # Performance characteristics are:
 # - Larger the logits array ( length ), the longer the kernel takes
-# - Constant time for mask i.e. number of allowed tokens does not effect execution
+# - Constant time for mask i.e. number of allowed tokens does not affect execution
 # time
 @torch.compile(dynamic=True)
 def _apply_token_bitmask_kernel(logits, mask):
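
Usage sketch for the API added by this series: one decoding step with the torch kernels. It assumes the existing `Vocabulary.from_pretrained`, `Index`, and `Guide` constructors exported by `outlines_core`; the "gpt2" model id, the regex, and the `vocab_size` value below are illustrative placeholders rather than part of the patches.

    import torch
    from outlines_core import Guide, Index, Vocabulary
    from outlines_core.kernels.torch import (
        allocate_token_bitmask,
        apply_token_bitmask_inplace,
        fill_next_token_bitmask,
    )

    # Assumed setup: any pretrained tokenizer id and regex work here.
    vocabulary = Vocabulary.from_pretrained("gpt2")
    guide = Guide(Index(r"\d{3}", vocabulary))

    vocab_size = 50257                         # illustrative; use the model's vocab size
    mask = allocate_token_bitmask(vocab_size)  # int32, shape (1, (vocab_size + 31) // 32)
    logits = torch.randn(1, vocab_size)        # stand-in for one step of model output

    fill_next_token_bitmask(guide, mask)       # Guide.write_mask_into sets one bit per allowed token
    apply_token_bitmask_inplace(logits, mask)  # disallowed positions become -inf, in place
    next_token_id = int(torch.argmax(logits, dim=-1))
    guide.advance(next_token_id)               # advance the automaton with the chosen token

The mask stays on the CPU (a requirement enforced by fill_next_token_bitmask), while apply_token_bitmask_inplace can run on whichever device holds the logits, so in a GPU loop the mask would be copied over before masking.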