KV Cache Quantization #971

Open
dinerburger opened this issue Dec 4, 2024 · 7 comments
Labels
new feature (New feature or request)

Comments

@dinerburger

Both exllamav2 and llama.cpp support a quantized KV cache, which allows pretty large context lengths on consumer hardware. It would be a great addition to mistral.rs; I've been very interested in trying it, but I'm limited to 24 GB of VRAM, which forces me, for example, to send the KV cache to system RAM instead of keeping it on the card (only possible with llama.cpp, to my knowledge).

@dinerburger added the "new feature" label Dec 4, 2024
@EricLBuehler
Owner

Hi @dinerburger! After some recent work on the KV cache, I think we now have the infrastructure for this! I'll take another look and will probably merge some initial support.

@dinerburger
Author

Obviously there are a number of ways to implement KV cache quantization, but I'd be interested to know which implementation you're considering.

@EricLBuehler
Owner

EricLBuehler commented Dec 7, 2024

I'm considering two options; the 8-bit cache using FP8 might be easier to implement first.

  • 4-bit cache: something similar to what exllamav2 does, where we apply a Hadamard transform to reduce outliers (paper: https://arxiv.org/pdf/2404.00456) and can then use a Q4 cache.
  • 8-bit cache: use FP8, which might be easier and quicker initially.
    • We should probably consider E5M2 rather than E4M3: E5M2 trades mantissa precision for a wider exponent range, which handles large-magnitude outliers better (rough sketch below).
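To make the E5M2 vs. E4M3 trade-off concrete, here is a minimal standalone sketch (illustration only, not the planned mistral.rs kernel) that rounds an f32 onto an FP8-like grid given a mantissa width and the format's max finite value (57344 for E5M2, 448 for E4M3 per the OCP FP8 spec); subnormals and NaN are ignored:

```rust
// Illustrative only: round x onto an FP8-like grid with `man_bits` mantissa bits,
// saturating at the format's largest finite value. Subnormals/NaN are ignored.
fn fp8_round(x: f32, man_bits: i32, max_val: f32) -> f32 {
    if x == 0.0 {
        return 0.0;
    }
    let sign = x.signum();
    let mag = x.abs().min(max_val); // saturate instead of overflowing to infinity
    let e = mag.log2().floor() as i32; // unbiased exponent of the value
    let step = 2.0f32.powi(e - man_bits); // spacing of representable values here
    sign * (mag / step).round() * step
}

fn main() {
    // E5M2 keeps large-magnitude outliers (coarsely); E4M3 clips them at 448
    // but resolves small values more finely.
    for &x in &[0.07_f32, 1.3, 96.0, 3000.0] {
        println!(
            "x = {x:8.2}   E5M2 -> {:9.4}   E4M3 -> {:9.4}",
            fp8_round(x, 2, 57344.0),
            fp8_round(x, 3, 448.0)
        );
    }
}
```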

@sammcj I saw that your recent PR adding KV cache quantization to Ollama was merged - congrats! What approach did you take (did you do anything special to quantize the K/V blocks)?

@dinerburger
Author

dinerburger commented Dec 7, 2024

Perfect, yeah, I was going to recommend the Hadamard transform approach - it's easy and effective. I followed that PR pretty closely; @sammcj piggy-backed on llama.cpp's implementation, using either the q4_0 or q8_0 quant types provided by llama.cpp. Technically, llama.cpp is capable of using many of its quant types for KV cache quantization, and these types can be mixed assuming you build llama.cpp yourself with the GGML_CUDA_FA_ALL_QUANTS define. Depending on your implementation of the base llama quants, this may be appropriate.
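A bare-bones sketch of the transform itself (my own illustration, not exllamav2's kernel or what mistral.rs would ship): an in-place fast Walsh-Hadamard transform applied to each head-dim vector before quantization spreads a per-channel outlier across the whole vector, and since the normalized transform is orthonormal and self-inverse, applying it again after dequantization restores the original basis.

```rust
// Illustrative only: in-place fast Walsh-Hadamard transform over one K/V head
// vector (head_dim assumed to be a power of two, e.g. 128). With the 1/sqrt(n)
// normalization the transform is orthonormal and self-inverse, so the same
// function undoes it after dequantization.
fn walsh_hadamard_inplace(v: &mut [f32]) {
    let n = v.len();
    assert!(n.is_power_of_two());
    let mut h = 1;
    while h < n {
        for block in (0..n).step_by(2 * h) {
            for j in block..block + h {
                let (a, b) = (v[j], v[j + h]);
                v[j] = a + b;
                v[j + h] = a - b;
            }
        }
        h *= 2;
    }
    let scale = 1.0 / (n as f32).sqrt();
    for x in v.iter_mut() {
        *x *= scale;
    }
}

fn main() {
    // Rotate a toy "K vector" with one large outlier, then rotate back.
    let mut k = vec![0.1_f32; 8];
    k[3] = 12.0; // outlier that would otherwise dominate the quantization scale
    walsh_hadamard_inplace(&mut k); // energy is now spread across all 8 lanes
    // ... quantize `k` here (e.g. to Q4), dequantize ...
    walsh_hadamard_inplace(&mut k); // self-inverse: recovers the original vector
    println!("{k:?}");
}
```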

@sammcj
Contributor

sammcj commented Dec 7, 2024

Thanks @EricLBuehler! It was simple compared to the effort you'll be putting in, I'm sure, since llama.cpp does the heavy lifting of performing the quantisation. The changes to Ollama were mainly around parameterising the Ollama components to make use of it, some memory management for their layer estimation/placement, and a lot of, shall we say, 'soft skills' to get it across the line 😅

You can see the initial changes (bundled with FA support) in llama.cpp here: ggerganov/llama.cpp#7527

While 4-bit works well for exllamav2's KV cache, the quantisation that works well with llama.cpp/GGUF is Q8_0, which is approximately 8.5 bpw.
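For anyone following along, the 8.5 bpw figure comes from the Q8_0 block layout as I read it (not the actual ggml code): each block of 32 values stores one f16 scale plus 32 signed 8-bit quants, i.e. (16 + 32 × 8) / 32 = 8.5 bits per weight. A rough Rust sketch of that scheme:

```rust
// Rough sketch of Q8_0-style block quantization (my reading of the format, not
// the ggml implementation). One block = 32 values sharing a single scale:
//   (16-bit scale + 32 * 8-bit quants) / 32 values = 8.5 bits per weight.
#[allow(non_camel_case_types)]
struct BlockQ8_0 {
    scale: f32, // stored as f16 on disk; kept as f32 here for simplicity
    quants: [i8; 32],
}

fn quantize_q8_0(block: &[f32; 32]) -> BlockQ8_0 {
    // Map the largest magnitude in the block to 127.
    let amax = block.iter().fold(0.0_f32, |m, &x| m.max(x.abs()));
    let scale = if amax == 0.0 { 1.0 } else { amax / 127.0 };
    let mut quants = [0_i8; 32];
    for (q, &x) in quants.iter_mut().zip(block.iter()) {
        *q = (x / scale).round().clamp(-127.0, 127.0) as i8;
    }
    BlockQ8_0 { scale, quants }
}

fn dequantize_q8_0(b: &BlockQ8_0) -> [f32; 32] {
    let mut out = [0.0_f32; 32];
    for (o, &q) in out.iter_mut().zip(b.quants.iter()) {
        *o = q as f32 * b.scale;
    }
    out
}

fn main() {
    let mut vals = [0.0_f32; 32];
    vals[0] = 1.5;
    vals[1] = -0.75;
    let q = quantize_q8_0(&vals);
    let back = dequantize_q8_0(&q);
    println!("scale = {}, round-trip of vals[0] = {}", q.scale, back[0]);
}
```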

I've published an F16 vs Q8_0 KV cache perplexity measurement here (I might add q4_0 and another dataset variant as well in the next day or two).

Forgive my ignorance here - when you say int4/int8 - are you talking about quantising down to 4/8bit integers, or simply rounding to them?

The reason I ask is that int4/int8 models tend to be quite a bit lower quality than their quantised counterparts such as Q4_K_M/Q8_0.

@EricLBuehler
Owner

@sammcj @dinerburger sorry for the late reply! I've begun work in #988.

Forgive my ignorance here - when you say int4/int8 - are you talking about quantising down to 4/8bit integers, or simply rounding to them?

I'll be using Q4_K_M and Q8_0 plus a Hadamard transform in the kernel for better value distributions, not int4/int8.

Technically llama.cpp is capable of using many of its quant types for KV cache quantization

Sounds like an interesting idea. I'm curious if we can do something similar after the initial KV cache quantization support is merged.

@dinerburger
Author

Yeah, you can see the supported quant types here: https://github.com/ggerganov/llama.cpp/blob/26a8406ba9198eb6fdd8329fa717555b4f77f05f/common/common.cpp#L1018. One note if you want to experiment: compile llama.cpp with GGML_CUDA_FA_ALL_QUANTS, or else you'll be limited to Q4_0 and Q8_0 with no mixing. With that flag you can mix and match different K and V types, which can be nice since, in my experience, the K cache is far more sensitive to quantization than the V cache, especially for models leveraging GQA.
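Purely as an illustration of that mix-and-match idea (hypothetical names, not mistral.rs's or llama.cpp's actual API), a config that keeps the more sensitive K cache at higher precision than the V cache could look something like:

```rust
// Hypothetical illustration only: separate quantization types for the K and V
// caches, since the K cache tends to degrade faster under quantization.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug)]
enum KvQuantType {
    F16,
    Q8_0,
    Q4_0,
}

#[derive(Clone, Copy, Debug)]
struct KvCacheConfig {
    k_type: KvQuantType, // keep K at higher precision...
    v_type: KvQuantType, // ...and quantize V more aggressively
}

fn main() {
    let cfg = KvCacheConfig {
        k_type: KvQuantType::Q8_0,
        v_type: KvQuantType::Q4_0,
    };
    println!("{cfg:?}");
}
```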
