
Add interface to Guide object to update masks in place, and associated kernels. #183

Open · wants to merge 4 commits into base: main

Changes from 1 commit
1 change: 0 additions & 1 deletion Cargo.toml
@@ -21,7 +21,6 @@ hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.3", features = ["http"] }
rustc-hash = "2.1.0"
regex-automata = "0.4.9"

[features]
python-bindings = ["pyo3", "serde-pyobject"]

2 changes: 2 additions & 0 deletions python/outlines_core/__init__.py
@@ -3,6 +3,8 @@

from .outlines_core_rs import Guide, Index, Vocabulary

from .kernels import torch

try:
__version__ = version("outlines_core")
except PackageNotFoundError:
python/outlines_core/kernels/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions python/outlines_core/kernels/numpy.py
@@ -0,0 +1,78 @@
from outlines_core import Guide

try:
import numpy as np
import numba
except ImportError as e:
missing_dep = "numba" if "numba" in str(e) else "numpy"
raise ImportError(
f"To use the kernels in `outlines_core.kernels.numpy`, `{missing_dep}` must be installed."
) from e

def allocate_token_bitmask(vocab_size: int) -> np.ndarray:
    """Allocate a token bitmask for use with `Guide.write_mask_into` and the
    logits-masking kernels, sized according to `vocab_size`."""
    return np.full(
        (1, (vocab_size + 31) // 32),
        -1,
        dtype=np.int32,
    )
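Note: each `int32` word packs 32 tokens, one bit per token, and the `-1` fill sets every bit, i.e. every token starts out allowed. A minimal sketch of how a token id maps onto the mask (`token_allowed` is illustrative, not part of the API):

import numpy as np

def token_allowed(mask: np.ndarray, token_id: int) -> bool:
    # Bit `token_id % 32` of word `token_id // 32` encodes this token.
    word = int(mask[0, token_id // 32])
    return bool((word >> (token_id % 32)) & 1)

mask = allocate_token_bitmask(vocab_size=50_257)
assert mask.shape == (1, (50_257 + 31) // 32)
assert token_allowed(mask, 0)  # -1 fill means "all tokens allowed"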

@numba.njit
def _apply_token_bitmask_kernel(logits, mask):
mask_len = mask.shape[1]
cutoff = 32 * mask_len

if logits.shape[1] > cutoff:
logits[:, cutoff:] = -np.inf
logits = logits[:, :cutoff]

n_rows, n_cols = logits.shape

for i in range(n_rows):
for mi in range(mask_len):
mval = mask[i, mi]
base = mi * 32
for bit in range(32):
j = base + bit

if j >= n_cols:
break

if ((mval >> bit) & 1) == 0:
logits[i, j] = -np.inf

def apply_token_bitmask_inplace(logits: np.ndarray, mask: np.ndarray) -> None:
if logits.ndim == 1:
logits = np.expand_dims(logits, axis=0)
if mask.ndim == 1:
mask = np.expand_dims(mask, axis=0)

if mask.dtype != np.int32:
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
elif mask.ndim != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
    elif logits.ndim != 2:
        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.ndim}D.")
elif mask.shape[0] != logits.shape[0]:
raise ValueError(
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
)
_apply_token_bitmask_kernel(logits, mask)

def fill_next_token_bitmask(
guide: Guide, mask: np.ndarray
) -> None:
# timing: all checks take roughly 0.5 microseconds.
if mask.dtype != np.int32:
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.")
elif mask.ndim != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.")
elif mask.shape[0] != 1:
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
elif not mask.flags["C_CONTIGUOUS"]:
raise ValueError("Mask array must be contiguous in memory. Use `np.ascontiguousarray(mask)`.")

return guide.write_mask_into(
mask.ctypes.data,
mask.size,
mask.itemsize
)
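Taken together, a typical decode-step loop with these kernels looks roughly like the sketch below. The `Vocabulary`/`Index`/`Guide` construction follows the library's public API; the model name, regex, and vocabulary size are illustrative assumptions:

import numpy as np
from outlines_core import Guide, Index, Vocabulary
from outlines_core.kernels.numpy import (
    allocate_token_bitmask,
    apply_token_bitmask_inplace,
    fill_next_token_bitmask,
)

vocabulary = Vocabulary.from_pretrained("gpt2")  # assumes the gpt2 tokenizer
index = Index(r"\d{3}", vocabulary)              # constrain output to three digits
guide = Guide(index)

vocab_size = 50_257                              # assumed gpt2 vocabulary size
mask = allocate_token_bitmask(vocab_size)
logits = np.random.randn(1, vocab_size).astype(np.float32)  # stand-in for model output

fill_next_token_bitmask(guide, mask)       # write allowed tokens for the current state
apply_token_bitmask_inplace(logits, mask)  # disallowed logits become -inf
guide.advance(int(np.argmax(logits)))      # pick a token and step the automaton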
106 changes: 106 additions & 0 deletions python/outlines_core/kernels/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Provides kernels for masking a logits tensor, using the
# `write_mask_into` method on the `Guide` object, which writes the
# bitmask into a caller-provided tensor.
#
# Kernels inspired by https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/torch.py
from outlines_core import Guide

try:
    import torch
except ImportError as e:
    raise ModuleNotFoundError(
        "`torch` is required to use the kernels from "
        "`outlines_core.kernels.torch`. You can install "
        "`torch` using the official guide at https://pytorch.org/get-started/locally/"
    ) from e

def allocate_token_bitmask(vocab_size: int) -> torch.Tensor:
    """
    Allocate a token bitmask for use with the `Guide.write_mask_into` API
    and the logits-masking kernels, sized according to `vocab_size`.

    Arguments:
        - vocab_size: int
    Returns:
        - torch.Tensor
    """
return torch.full(
(1, (vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)
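The `pin_memory=torch.cuda.is_available()` choice keeps the mask in page-locked host memory when a GPU is present, so the copy to the device can be asynchronous. A sketch of the intended pattern, assuming a CUDA device and `guide`, `vocab_size`, and on-device `logits_gpu` already in scope (the mask itself must stay on the CPU for `fill_next_token_bitmask`):

mask = allocate_token_bitmask(vocab_size)          # CPU-side, pinned when CUDA is available
fill_next_token_bitmask(guide, mask)               # the Guide writes into CPU memory
mask_gpu = mask.to("cuda", non_blocking=True)      # async H2D copy, thanks to pinning
apply_token_bitmask_inplace(logits_gpu, mask_gpu)  # mask the on-device logits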

# This takes roughly 23 microseconds per run, with a bitmask of
# 1k allowed tokens, and 128k logits tensor.
# Also compiles to one graph with no graph breaks
Review comment (Member):

Is there any way to access the CUDA code generated by PyTorch? It might be over-engineering for now, but I'd like to get an idea of how efficient that code is and if there are gains to be had there in the future.

Reply (@unaidedelf8777, Contributor, Author, Feb 25, 2025):

Seems possible - just have to find the temp directory where it dumps it: https://pytorch.org/tutorials/intermediate/inductor_debug_cpu.html
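For reference, a sketch of one way to surface the generated code, per that tutorial (assumes the `TORCH_COMPILE_DEBUG` environment variable is honored by the installed torch version; it must be set before `torch` is imported):

import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # set before importing torch

import torch
# After the first compiled call, inductor writes the generated Triton/C++
# kernels and other debug artifacts under ./torch_compile_debug/.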

# Performance characteristics:
# - The larger the logits array (its length), the longer the kernel takes.
# - Constant time in the mask: the number of allowed tokens does not affect
#   execution time.
@torch.compile(dynamic=True)
def _apply_token_bitmask_kernel(logits, mask):
    # Number of logits columns covered by the bitmask.
    cutoff = 32 * mask.shape[1]

    # Mask any logits the bitmask does not cover. This is done in place:
    # reassigning `logits` to the output of `torch.where` would only
    # modify a local copy, not the caller's tensor.
    if logits.shape[1] > cutoff:
        logits[:, cutoff:] = -torch.inf
    n_cols = min(logits.shape[1], cutoff)

    # Unpack each 32-bit mask value into 32 individual bits (as booleans),
    # trimmed to the logits width.
    bit_masks = (
        (torch.bitwise_right_shift(
            mask.unsqueeze(-1),
            torch.arange(
                32,
                device=mask.device,
                dtype=torch.int32,
            ),
        ) & 1)
        .bool()
        .view(mask.shape[0], -1)
        .narrow(1, 0, n_cols)
    )
    logits[:, :n_cols].masked_fill_(~bit_masks, -torch.inf)


def apply_token_bitmask_inplace(logits: torch.Tensor, mask: torch.Tensor) -> None:
if mask.dtype != torch.int32:
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
elif mask.dim() != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
    elif logits.dim() != 2:
        raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {logits.dim()}D.")
elif mask.shape[0] != logits.shape[0]:
raise ValueError(
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
)
_apply_token_bitmask_kernel(logits, mask)

def fill_next_token_bitmask(guide: Guide, mask: torch.Tensor) -> None:
if mask.dtype != torch.int32:
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.")
elif mask.dim() != 2:
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.")
elif mask.shape[0] != 1:
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.")
elif not mask.is_contiguous():
raise ValueError("Mask array must be contiguous in memory. Use `mask.contiguous()` to fix it.")
elif mask.device != torch.device("cpu"):
raise ValueError(f"Invalid device: Expected `mask` tensor to be on device `cpu`, but found it on `{mask.device}`.")

guide.write_mask_into(
mask.data_ptr(),
mask.numel(),
mask.element_size()
)
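As with the numpy variant, the intended per-step flow is: fill the mask from the guide, then apply it to the logits. A minimal sketch, assuming `guide`, `vocab_size`, and a `(1, vocab_size)` float logits tensor are already in scope:

mask = allocate_token_bitmask(vocab_size)

def process_logits(guide: Guide, logits: torch.Tensor) -> torch.Tensor:
    # Write the allowed-token bitmask for the guide's current state,
    # then mask the logits in place.
    fill_next_token_bitmask(guide, mask)
    apply_token_bitmask_inplace(logits, mask.to(logits.device))
    return logits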
9 changes: 9 additions & 0 deletions python/outlines_core/outlines_core_rs.pyi
@@ -35,6 +35,15 @@ class Guide:
def is_finished(self) -> bool:
"""Checks if the automaton is in a final state."""
...
    def write_mask_into(self, data_ptr: int, numel: int, element_size: int) -> None:
        """Write the mask of allowed tokens into the memory pointed to by `data_ptr`.

        The size of the destination buffer is given by `numel` (number of elements)
        and `element_size` (bytes per element); `element_size` must be 4.

        `data_ptr` should be the data pointer of a `torch.Tensor`, an `np.ndarray`,
        or another contiguous 32-bit integer buffer."""
        ...

def __repr__(self) -> str:
"""Gets the debug string representation of the guide."""
...
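For illustration, the raw call that the numpy wrapper makes under the hood (a sketch; `guide` and `vocab_size` are assumed to be in scope, and the buffer must satisfy `numel == (vocab_size + 31) // 32` with `element_size == 4`):

import numpy as np

mask = np.full((1, (vocab_size + 31) // 32), -1, dtype=np.int32)
guide.write_mask_into(mask.ctypes.data, mask.size, mask.itemsize)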
12 changes: 12 additions & 0 deletions src/index.rs
@@ -55,6 +55,8 @@ pub struct Index {
transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
/// The token ID reserved for the "end-of-sequence" token.
eos_token_id: TokenId,
    /// The size of the vocabulary used to build the index.
    vocab_size: usize,
}
/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
/// to state transitions within a finite-state automaton.
@@ -99,6 +101,7 @@ pub struct Index {
impl Index {
/// Builds an `Index` from regular expression and vocabulary tokens.
pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
let vocab_size = vocabulary.len();
let eos_token_id = vocabulary.eos_token_id();
let dfa = DFA::new(regex).map_err(Box::new)?;
let start_state = match dfa.universal_start_state(Anchored::Yes) {
@@ -160,6 +163,7 @@ impl Index {
final_states,
transitions,
eos_token_id,
            vocab_size,
})
}

@@ -190,13 +194,21 @@ impl Index {
.map(|res| res.keys().cloned().collect())
}

    /// Returns an iterator over the allowed token ids for the given state, if the state exists.
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
self.transitions.get(state).map(|map| map.keys())
}

/// Returns transition state for a given state and token id or `None` otherwise.
pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
if token_id == &self.eos_token_id {
return None;
}
Some(*self.transitions.get(state)?.get(token_id)?)
}

    /// Returns the size of the vocabulary used to build this index.
    pub fn vocab_size(&self) -> usize {
self.vocab_size
}
}

impl std::fmt::Display for Index {
39 changes: 39 additions & 0 deletions src/python_bindings/mod.rs
@@ -70,6 +70,45 @@ impl PyGuide {
self.index.is_final_state(self.state)
}

fn write_mask_into(
&self,
data_ptr: usize,
numel: usize,
element_size: usize
) -> PyResult<()> {

if element_size != 4 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"The data type of the Tensor must be `torch.int32`",
));
} else if data_ptr == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"data_ptr cannot be null or nullptr",
));
} else if data_ptr % 4 != 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"data_ptr is not aligned",
));
        } else if ((self.index.0.vocab_size() + 31) / 32) != numel {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Invalid buffer size. Please ensure that the mask has ((vocab_size + 31) / 32) elements of 32-bit integer precision.",
            ));
}
unsafe {
std::ptr::write_bytes(data_ptr as *mut u8, 0, numel * 4);
}
if let Some(tokens) = self.index.0.allowed_tokens_iter(&self.state) {
let slice = unsafe { std::slice::from_raw_parts_mut(data_ptr as *mut u32, numel) };
for &token in tokens {
let bucket = (token as usize) / 32;
if bucket < slice.len() {
slice[bucket] |= 1 << ((token as usize) % 32);
}
}
}
Ok(())
}
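The bit layout written here is: token `t` is allowed iff bit `t % 32` of 32-bit word `t / 32` is set. A Python sketch of the same logic, for reference (`allowed_tokens` is an assumed iterable of token ids):

import numpy as np

def write_mask_into_py(mask: np.ndarray, allowed_tokens) -> None:
    mask[...] = 0                             # clear the buffer first
    words = mask.reshape(-1).view(np.uint32)  # do the bit ops on unsigned words
    for token in allowed_tokens:
        words[token // 32] |= np.uint32(1) << np.uint32(token % 32)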

fn __repr__(&self) -> String {
format!(
"Guide object with the state={:#?} and {:#?}",
5 changes: 5 additions & 0 deletions src/vocabulary/mod.rs
@@ -148,6 +148,11 @@ impl Vocabulary {
self.tokens.remove(&token);
}

    /// Returns the number of tokens in the vocabulary, including the EOS token.
    pub fn len(&self) -> usize {
        // +1 for `eos_token_id`, which is stored separately from `tokens`
        self.tokens.len() + 1
}

/// Filters out `Prepend` kind of tokenizer's normalizers.
fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
// Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
2 changes: 1 addition & 1 deletion tests/test_guide.py
@@ -147,4 +147,4 @@ def test_equality(index):
# progress one of the guides, confirm different state == different guide
guide1.advance(guide1.get_tokens()[-1])
assert guide1 != guide2
assert guide3 == guide2
assert guide3 == guide2
2 changes: 1 addition & 1 deletion tests/test_index.py
@@ -62,4 +62,4 @@ def test_deepcopy(index):
is_deleted = not any(id(o) == index2_id for o in gc.get_objects())
assert is_deleted

assert copy_index2 == index
assert copy_index2 == index