-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add interface to Guide
object to update masks in place, and associated kernels.
#183
Open
unaidedelf8777
wants to merge
4
commits into
dottxt-ai:main
Choose a base branch
from
unaidedelf8777:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
c01e1b8
Add write_into_mask method on Guide. Add kernels for torch and numpy
unaidedelf8777 31a48cc
Update src/index.rs
unaidedelf8777 d874306
Update python/outlines_core/outlines_core_rs.pyi
unaidedelf8777 bb854f6
Update python/outlines_core/kernels/torch.py
unaidedelf8777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from outlines_core import Guide | ||
|
||
try: | ||
import numpy as np | ||
import numba | ||
except ImportError as e: | ||
missing_dep = "numba" if "numba" in str(e) else "numpy" | ||
raise ImportError( | ||
f"To use the kernels in `outlines_core.kernels.numpy`, `{missing_dep}` must be installed." | ||
) from e | ||
|
||
def allocate_token_bitmask(vocab_size: int) -> np.ndarray: | ||
return np.full( | ||
(1, (vocab_size + 31) // 32), | ||
-1, | ||
dtype=np.int32, | ||
) | ||
|
||
@numba.njit | ||
def _apply_token_bitmask_kernel(logits, mask): | ||
mask_len = mask.shape[1] | ||
cutoff = 32 * mask_len | ||
|
||
if logits.shape[1] > cutoff: | ||
logits[:, cutoff:] = -np.inf | ||
logits = logits[:, :cutoff] | ||
|
||
n_rows, n_cols = logits.shape | ||
|
||
for i in range(n_rows): | ||
for mi in range(mask_len): | ||
mval = mask[i, mi] | ||
base = mi * 32 | ||
for bit in range(32): | ||
j = base + bit | ||
|
||
if j >= n_cols: | ||
break | ||
|
||
if ((mval >> bit) & 1) == 0: | ||
logits[i, j] = -np.inf | ||
|
||
def apply_token_bitmask_inplace(logits: np.ndarray, mask: np.ndarray) -> None: | ||
if logits.ndim == 1: | ||
logits = np.expand_dims(logits, axis=0) | ||
if mask.ndim == 1: | ||
mask = np.expand_dims(mask, axis=0) | ||
|
||
if mask.dtype != np.int32: | ||
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.") | ||
elif mask.ndim != 2: | ||
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.") | ||
elif logits.ndim != 2: | ||
raise ValueError(f"Invalid logits dimensions: Expected a 2D array, but got {mask.ndim}D.") | ||
elif mask.shape[0] != logits.shape[0]: | ||
raise ValueError( | ||
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})." | ||
) | ||
_apply_token_bitmask_kernel(logits, mask) | ||
|
||
def fill_next_token_bitmask( | ||
guide: Guide, mask: np.ndarray | ||
) -> None: | ||
# timing: all checks take roughly 0.5 microseconds. | ||
if mask.dtype != np.int32: | ||
raise ValueError(f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`.") | ||
elif mask.ndim != 2: | ||
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D.") | ||
elif mask.shape[0] != 1: | ||
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.") | ||
elif not mask.flags["C_CONTIGUOUS"]: | ||
raise ValueError("Mask array must be contiguous in memory. Use `np.ascontiguousarray(mask)`.") | ||
|
||
return guide.write_mask_into( | ||
mask.ctypes.data, | ||
mask.size, | ||
mask.itemsize | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Provides kernels for masking a logits tensor, | ||
# using the write_into_mask method on the `Guide` object and the bitmask | ||
# which it writes into a tensor. | ||
# | ||
# Kernels inspired by https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/torch.py | ||
from outlines_core import Guide | ||
|
||
try: | ||
import torch | ||
except Exception as e: | ||
raise ModuleNotFoundError( | ||
"`torch` is required to use the kernels from" | ||
"`outlines_core.kernels.torch. You can install " | ||
"`torch` using the official guide at https://pytorch.org/get-started/locally/" | ||
) | ||
|
||
def allocate_token_bitmask(vocab_size: int) -> torch.Tensor: | ||
""" | ||
Allocate a token bitmask for use with the `Guide.write_into_mask` API and logits masking, | ||
based on the vocab_size. | ||
|
||
Arguments: | ||
- vocab_size: int | ||
Returns: | ||
- torch.Tensor | ||
""" | ||
return torch.full( | ||
(1, (vocab_size + 31) // 32), | ||
-1, | ||
dtype=torch.int32, | ||
pin_memory=torch.cuda.is_available(), | ||
) | ||
|
||
# This takes roughly 23 microseconds per run, with a bitmask of | ||
# 1k allowed tokens, and 128k logits tensor. | ||
# Also compiles to one graph with no graph breaks | ||
# Performance characteristics are: | ||
# - Larger the logits array ( length ), the longer the kernel takes | ||
# - Constant time for mask i.e. number of allowed tokens does not affect execution | ||
# time | ||
@torch.compile(dynamic=True) | ||
def _apply_token_bitmask_kernel(logits, mask): | ||
# This should not modify, so long as the mask | ||
# is allocated at the correct size | ||
logits = torch.where( | ||
torch.ge( | ||
torch.arange( | ||
logits.shape[1], | ||
device=logits.device | ||
), | ||
32 * mask.shape[1] | ||
), | ||
-torch.inf, | ||
logits | ||
) | ||
|
||
# Unpack each 32-bit mask value into 32 individual bits (as booleans) | ||
bit_masks = ( | ||
(torch.bitwise_right_shift( | ||
mask.unsqueeze(-1), | ||
torch.arange( | ||
32, | ||
device=mask.device, | ||
dtype=torch.int32 | ||
)) & 1 | ||
) | ||
.bool() | ||
.view(mask.shape[0], -1) | ||
.narrow(1, 0, logits.shape[1]) | ||
) | ||
|
||
# Possibly trim mask to match the logits width | ||
bit_masks = bit_masks[:, :logits.shape[1]] | ||
logits.masked_fill_(~bit_masks, -torch.inf) | ||
|
||
|
||
def apply_token_bitmask_inplace(logits: torch.Tensor, mask: torch.Tensor) -> None: | ||
if mask.dtype != torch.int32: | ||
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.") | ||
elif mask.dim() != 2: | ||
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.") | ||
elif logits.dim() != 2: | ||
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.") | ||
elif mask.shape[0] != logits.shape[0]: | ||
raise ValueError( | ||
f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})." | ||
) | ||
_apply_token_bitmask_kernel(logits, mask) | ||
|
||
def fill_next_token_bitmask(guide: Guide, mask: torch.Tensor) -> None: | ||
if mask.dtype != torch.int32: | ||
raise ValueError(f"Invalid mask dtype: Expected `torch.int32`, but got `{mask.dtype}`.") | ||
elif mask.dim() != 2: | ||
raise ValueError(f"Invalid mask dimensions: Expected a 2D array, but got {mask.dim()}D.") | ||
elif mask.shape[0] != 1: | ||
raise ValueError(f"Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}.") | ||
elif not mask.is_contiguous(): | ||
raise ValueError("Mask array must be contiguous in memory. Use `mask.contiguous()` to fix it.") | ||
elif mask.device != torch.device("cpu"): | ||
raise ValueError(f"Invalid device: Expected `mask` tensor to be on device `cpu`, but found it on `{mask.device}`.") | ||
|
||
guide.write_mask_into( | ||
mask.data_ptr(), | ||
mask.numel(), | ||
mask.element_size() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any way to access the CUDA code generated by PyTorch? It might be over-engineering for now, but I'd like to get an idea of how efficient that code is and if there are gains to be had there in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems possible - just have to find the temp directory where it dumps it: https://pytorch.org/tutorials/intermediate/inductor_debug_cpu.html