Add interface to Guide object to update masks in place, and associated kernels. #183
Conversation
Awesome! Do you have some profiling results that show the time spent on each operation across the whole chain?
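(A generic `torch.profiler` sketch for producing such a per-op breakdown; the input and the profiled op below are stand-ins, not code from this PR:)

```python
import torch
from torch.profiler import ProfilerActivity, profile

logits = torch.randn(128_000)  # stand-in input

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.topk(logits, 10)  # replace with the real kernel call under test

# Per-operation timing, sorted by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```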
```python
# This takes roughly 23 microseconds per run, with a bitmask of
# 1k allowed tokens, and 128k logits tensor.
# Also compiles to one graph with no graph breaks
```
Is there any way to access the CUDA code generated by PyTorch? It might be over-engineering for now, but I'd like to get an idea of how efficient that code is and if there are gains to be had there in the future.
Seems possible - just have to find the temp directory where it dumps it: https://pytorch.org/tutorials/intermediate/inductor_debug_cpu.html
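For reference, the linked tutorial's approach boils down to setting `TORCH_COMPILE_DEBUG=1`, which makes Inductor persist its debug artifacts, including the generated code, to a `torch_compile_debug/` directory; a minimal sketch:

```python
# Minimal sketch: persist Inductor's debug output, including the generated
# Triton/C++ kernels, to ./torch_compile_debug/ instead of only a temp dir.
import os

os.environ["TORCH_COMPILE_DEBUG"] = "1"  # must be set before compilation

import torch


@torch.compile
def f(x):
    return torch.relu(x) * 2


f(torch.randn(8))  # trigger compilation, then inspect torch_compile_debug/
```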
Co-authored-by: Rémi Louf <[email protected]>
@rlouf For Rust, kernels, or both?
Towards resolving #178.
On the guide object, I added the `write_mask_into` method. This method takes 3 arguments:

- `data_ptr`: pointer to the start of the contiguous memory for the array
- `numel`: number of elements in the array
- `element_size`: size in bytes of each element in the array. This is checked to be 4, since we only support u32 arrays; if it is not 4, a `ValueError` is thrown.

In a mask array, each u32 represents the validity of 32 tokens (one per bit). Masks must also be stored in contiguous memory in order for Rust to access and modify them.
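For illustration, a minimal call-site sketch; `guide` and `vocab_size` are stand-ins, and the allocation shown is just one valid choice:

```python
# Hypothetical usage: let the Guide write the allowed-token bitmask
# directly into a contiguous torch tensor.
import torch

vocab_size = 128_000  # stand-in for the tokenizer's vocabulary size

# One bit per token, packed into u32 words: ceil(vocab_size / 32) elements.
# torch.empty allocates contiguous memory, as required.
mask = torch.empty((vocab_size + 31) // 32, dtype=torch.int32)

guide.write_mask_into(    # `guide` is a previously constructed Guide
    mask.data_ptr(),      # pointer to the start of the contiguous memory
    mask.numel(),         # number of elements in the array
    mask.element_size(),  # bytes per element; must be 4
)
```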
Currently, kernels for both `torch` and `numpy` are implemented. The `numpy` kernels require an additional dependency on `numba` in order to bring the runtime down to around 40 microseconds (1 mask, 1 logits array). The runtime for the `torch` kernel with 1 mask and 1 logits array is half that of `numpy`, at ~23 microseconds per run, mostly thanks to `torch.compile`. The form of the `numpy` kernel is not final as of now; it will be updated to have better scaling and vectorized ops. If I can do so without hurting performance (or if you all would like), I will remove the dependency on `numba`.
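For concreteness, here is a rough sketch of what a `torch.compile`d bitmask-application kernel can look like; this is an illustrative reconstruction, not the actual kernel in this PR:

```python
import torch


@torch.compile
def apply_token_bitmask_inplace(logits: torch.Tensor, mask: torch.Tensor) -> None:
    # Unpack the u32 bitmask: bit j of word i covers token 32 * i + j.
    # (int32 is used since torch has no general u32 dtype; masking with & 1
    # extracts each bit correctly regardless of the sign bit.)
    shifts = torch.arange(32, device=mask.device, dtype=mask.dtype)
    bits = (mask.unsqueeze(-1) >> shifts) & 1              # (n_words, 32)
    allowed = bits.bool().reshape(-1)[: logits.shape[-1]]  # one flag per token
    # Disallowed tokens get -inf so that softmax assigns them zero probability.
    logits.masked_fill_(~allowed, float("-inf"))
```

With 128k logits the mask is a 4,000-element i32 tensor, and since there is no data-dependent control flow, a kernel of this shape should compile without graph breaks.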
All kernels reside in the `outlines_core.kernels` submodule, with dependencies for each kernel imported dynamically in a try/except instead of being added to the package dependencies.
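As an illustration of that import pattern (a sketch; the exact error message is an assumption):

```python
# Sketch of the lazy-import pattern: numba is only required when the
# numpy kernels are actually imported, so it never becomes a hard
# dependency of the package itself.
try:
    import numba
except ImportError as e:
    raise ImportError(
        "The numpy kernels require numba; install it with `pip install numba`."
    ) from e
```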
TODO:

- `write_mask_into` method on `Guide`
- `numpy` kernels
- `torch` kernels
- `mlx` kernels

Please feel free to critique any of this.