diff --git a/README.md b/README.md index 4b167807c0..c76ca499fc 100644 --- a/README.md +++ b/README.md @@ -64,23 +64,26 @@ Mistal.rs supports several model categories: ## Description **Fast**: -- Quantized model support: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit for faster inference and optimized memory usage. +- Apple silicon support with the Metal framework. +- CPU inference with `mkl`, `accelerate` support and optimized backend. +- CUDA support with flash attention and cuDNN. - Continuous batching and PagedAttention support. - Prefix caching. - [Device mapping](docs/DEVICE_MAPPING.md): load and run some layers on the device and the rest on the CPU. -**Accelerator support**: -- Apple silicon support with the Metal framework. -- CPU inference with `mkl`, `accelerate` support and optimized backend. -- CUDA support with flash attention and cuDNN. +**Quantization**: +- [Details](docs/QUANTS.md) +- GGML: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit, with ISQ support. +- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit +- HQQ: 4-bit and 8 bit, with ISQ support +- [ISQ](docs/ISQ.md) (In situ quantization): run `.safetensors` models directly from Hugging Face Hub by quantizing them after loading instead of creating a GGUF file. + - This loads the ISQ-able weights on CPU before quantizing with ISQ and then moving to the device to avoid memory spikes. + - Extremely fast due to working in parallel **Easy**: - Lightweight OpenAI API compatible HTTP server. - Python API. - Grammar support with Regex and Yacc. -- [ISQ](docs/ISQ.md) (In situ quantization): run `.safetensors` models directly from Hugging Face Hub by quantizing them after loading instead of creating a GGUF file. - - This loads the ISQ-able weights on CPU before quantizing with ISQ and then moving to the device to avoid memory spikes. - - Extremely fast due to working in parallel **Powerful**: - Fast LoRA support with weight merging. @@ -98,7 +101,6 @@ Mistal.rs supports several model categories: - Please suggest more by raising an issue! - Tool calling: [docs](docs/TOOL_CALLING.md) - Prompt chunking (only without PagedAttention for now): handle larger prompts where the activation size would cause an OOM by sending chunks -- Various quantizations (GGUF, GPTQ, ISQ): [docs](docs/QUANTS.md) This is a demo of interactive mode with streaming running Phi 3 128k mini with quantization via ISQ to Q4K. diff --git a/docs/ISQ.md b/docs/ISQ.md index e6b5de5c7b..cf9ac47e0d 100644 --- a/docs/ISQ.md +++ b/docs/ISQ.md @@ -15,9 +15,12 @@ Possible values for ISQ quantization: - Q5K - Q6K - Q8K +- HQQ4 +- HQQ8 When using ISQ, it will automatically load ISQ-able weights into CPU memory before applying ISQ. The ISQ application process moves the weights to device memory. This process is implemented to avoid memory spikes from loading the model in full precision. +**Fallback rules for GGUF quantization** If a tensor cannot be quantized, the fallback process is as follows: 1) If using a `K` quant, fallback to a similar `Q` quant. 2) If that is not possible, use `F32` as the data type. diff --git a/docs/QUANTS.md b/docs/QUANTS.md index 5c5cb16664..7daa93a1c6 100644 --- a/docs/QUANTS.md +++ b/docs/QUANTS.md @@ -4,16 +4,22 @@ Mistral.rs supports the following quantization: - GGUF/GGML - Q, K type - Supported in GGUF/GGML and GGUF/GGML adapter models + - Supported in all plain and adapter models - I quants coming! 
- CPU, CUDA, Metal (all supported devices) + - 2, 3, 4, 5, 6, 8 bit - GPTQ - Supported in all plain and adapter models - CUDA only + - 2, 3, 4, 8 bit +- HQQ + - Supported in all plain and adapter models via ISQ + - CUDA and CPU only + - 4, 8 bit - ISQ - Q, K type GGUF quants - Supported in all plain and adapter models - - I quants coming! - - GPTQ quants coming! + - HQQ quants - CPU, CUDA, Metal (all supported devices) ## Using a GGUF quantized model diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs index ae1d62fe8b..f41841016b 100644 --- a/mistralrs-core/src/pipeline/isq.rs +++ b/mistralrs-core/src/pipeline/isq.rs @@ -23,6 +23,11 @@ use crate::device_map::DeviceMapper; /// - `Q5K` /// - `Q6K` /// - `Q8K` +/// - `HQQ1` +/// - `HQQ2` +/// - `HQQ3` +/// - `HQQ4` +/// - `HQQ8` pub fn parse_isq_value(s: &str) -> Result { match s.to_lowercase().as_str() { "q4_0" => Ok(IsqType::Q4_0), @@ -37,7 +42,12 @@ pub fn parse_isq_value(s: &str) -> Result { "q5k" => Ok(IsqType::Q5K), "q6k" => Ok(IsqType::Q6K), "q8k" => Ok(IsqType::Q8K), - _ => Err(format!("GGML type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`.")), + "hqq8" => Ok(IsqType::HQQ8), + "hqq4" => Ok(IsqType::HQQ4), + // "hqq3" => Ok(IsqType::HQQ3), + // "hqq2" => Ok(IsqType::HQQ2), + // "hqq1" => Ok(IsqType::HQQ1), + _ => Err(format!("GGML type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`.")), } } diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index a043b38d2f..942c666a6b 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -7,7 +7,11 @@ fn main() { use std::{path::PathBuf, vec}; println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); - let lib_files = vec!["kernels/gptq/q_gemm.cu"]; + let lib_files = vec![ + "kernels/gptq/q_gemm.cu", + "kernels/hqq/hqq.cu", + "kernels/ops/ops.cu", + ]; for lib_file in lib_files.iter() { println!("cargo:rerun-if-changed={lib_file}"); } diff --git a/mistralrs-quant/kernels/hqq/hqq.cu b/mistralrs-quant/kernels/hqq/hqq.cu new file mode 100644 index 0000000000..8ce4ef7353 --- /dev/null +++ b/mistralrs-quant/kernels/hqq/hqq.cu @@ -0,0 +1,318 @@ +// https://github.com/mobiusml/hqq/blob/master/hqq/kernels/hqq_aten_cuda_kernel.cu + +#include +#include +#include + +//#if __CUDA_ARCH__ >= 630 +#include "cuda_fp16.h" +//#endif +//#if __CUDA_ARCH__ >= 800 +#include "cuda_bf16.h" +//#endif + +inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;} +#define BLOCK_SIZE 256 //~256 +#define SHARED_SIZE 512 //~512 + +/*******************************************************************************************************************************************/ +/************* 8-bit *************/ +/*******************************************************************************************************************************************/ +//Simple +template +__global__ void dequantize_8bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + int n = h*w; + if(i>=n) return; + + int j = i % w; + W_r[i] = ((T)(Wq_packed[i]) - zero[j])*scale[j]; +} + +extern "C" void dequantize_8bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, 
scale, zero, W_r, h, w); +} + +//#if __CUDA_ARCH__ >= 630 +extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +//#if __CUDA_ARCH__ >= 800 +extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + + +/*******************************************************************************************************************************************/ +/************* 4-bit *************/ +/*******************************************************************************************************************************************/ + +//Simple +/*__global__ void unpack_4bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + if(i>=n) return; + + Wq_unpacked[i] = (Wq_packed[i] & 0xF0) >> 4; //First chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x0F); //Second chunk +}*/ + +//Simple +template +__global__ void dequantize_4bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + int n = h*w; + if(i>=n) return; + + int j = i % w; + //W_r[i] = (T)((Wq_packed[i] & 0xF0) >> 4);//((T)((Wq_packed[i] & 0xF0) >> 4) - zero[j])*scale[j]; //First chunk + //W_r[i + n] = (T)((Wq_packed[i] & 0x0F)) + (T)10000;//((T)((Wq_packed[i] & 0x0F)) - zero[j])*scale[j]; //Second chunk + W_r[i] = ((T)((Wq_packed[i] & 0xF0) >> 4) - zero[j])*scale[j]; //First chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x0F)) - zero[j])*scale[j]; //Second chunk +} + +extern "C" void dequantize_4bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} + +//#if __CUDA_ARCH__ >= 630 +extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +//#if __CUDA_ARCH__ >= 800 +extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +/*******************************************************************************************************************************************/ +/************* 2-bit *************/ +/*******************************************************************************************************************************************/ + +//Simple +/*__global__ void unpack_2bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + if(i>=n) return; + + Wq_unpacked[i] = (Wq_packed[i] & 0xC0) >> 6; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x30) >> 4; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x0C) >> 2; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x03); //4th chunk +}*/ + + +//Simple +template +__global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, 
T* scale, T* zero, T* W_r, int h, int w) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + int n = h*w; + if(i>=n) return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0xC0) >> 6) - zero[j])*scale[j]; //1st chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x30) >> 4) - zero[j])*scale[j]; //2nd chunk + W_r[i + n*2] = ((T)((Wq_packed[i] & 0x0C) >> 2) - zero[j])*scale[j]; //3rd chunk + W_r[i + n*3] = ((T)((Wq_packed[i] & 0x03)) - zero[j])*scale[j]; //4th chunk +} + +extern "C" void dequantize_2bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} + +//#if __CUDA_ARCH__ >= 630 +extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +//#if __CUDA_ARCH__ >= 800 +extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + + +// //Shared +// template +// __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) { +// int i = blockIdx.x*blockDim.x + threadIdx.x; +// int n = h*w; +// int s = threadIdx.x; + +// if(i>=n) return; + +// __shared__ unsigned char shared[BLOCK_SIZE]; +// __shared__ scalar_t shared_meta[BLOCK_SIZE][2]; + +// int j = i % w; +// shared[s] = Wq_packed[i]; +// shared_meta[s][0] = zero[j]; +// shared_meta[s][1] = scale[j]; +// __syncthreads(); + + +// W_r[i] = (scalar_t((shared[s] & 0xC0) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk +// W_r[i + n] = (scalar_t((shared[s] & 0x30) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk +// W_r[i + n*2] = (scalar_t((shared[s] & 0x0C) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk +// W_r[i + n*3] = (scalar_t((shared[s] & 0x03)) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk +// } + + + +/*******************************************************************************************************************************************/ +/************* 1-bit *************/ +/*******************************************************************************************************************************************/ + +//Simple +/*__global__ void unpack_1bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + if(i>=n) return; + + Wq_unpacked[i] = (Wq_packed[i] & 0x80) >> 7; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x40) >> 6; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x20) >> 5; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x10) >> 4; //4th chunk + Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x08) >> 3; //5th chunk + Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x04) >> 2; //6th chunk + Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x02) >> 1; //7th chunk + Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x01); //8th chunk +}*/ + +//Simple +template +__global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + int n = h*w; + if(i>=n) return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0x80) >> 7) - zero[j])*scale[j]; 
//1st chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x40) >> 6) - zero[j])*scale[j]; //2nd chunk + W_r[i + n*2] = ((T)((Wq_packed[i] & 0x20) >> 5) - zero[j])*scale[j]; //3rd chunk + W_r[i + n*3] = ((T)((Wq_packed[i] & 0x10) >> 4) - zero[j])*scale[j]; //4th chunk + W_r[i + n*4] = ((T)((Wq_packed[i] & 0x08) >> 3) - zero[j])*scale[j]; //5th chunk + W_r[i + n*5] = ((T)((Wq_packed[i] & 0x04) >> 2) - zero[j])*scale[j]; //6th chunk + W_r[i + n*6] = ((T)((Wq_packed[i] & 0x02) >> 1) - zero[j])*scale[j]; //7th chunk + W_r[i + n*7] = ((T)((Wq_packed[i] & 0x01)) - zero[j])*scale[j]; //8th chunk +} + +extern "C" void dequantize_1bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} + +//#if __CUDA_ARCH__ >= 630 +extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +//#if __CUDA_ARCH__ >= 800 +extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +// //Shared +// template +// __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) { +// int i = blockIdx.x*blockDim.x + threadIdx.x; +// int s = threadIdx.x; +// int n = h*w; +// if(i>=n) return; + +// __shared__ unsigned char shared[BLOCK_SIZE]; +// __shared__ scalar_t shared_meta[BLOCK_SIZE][2]; + +// int j = i % w; +// shared[s] = Wq_packed[i]; +// shared_meta[s][0] = zero[j]; +// shared_meta[s][1] = scale[j]; +// __syncthreads(); + +// W_r[i] = (scalar_t((shared[s] & 0x80) >> 7) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk +// W_r[i + n] = (scalar_t((shared[s] & 0x40) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk +// W_r[i + n*2] = (scalar_t((shared[s] & 0x20) >> 5) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk +// W_r[i + n*3] = (scalar_t((shared[s] & 0x10) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk +// W_r[i + n*4] = (scalar_t((shared[s] & 0x08) >> 3) - shared_meta[s][0])*shared_meta[s][1]; //5th chunk +// W_r[i + n*5] = (scalar_t((shared[s] & 0x04) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //6th chunk +// W_r[i + n*6] = (scalar_t((shared[s] & 0x02) >> 1) - shared_meta[s][0])*shared_meta[s][1]; //7th chunk +// W_r[i + n*7] = (scalar_t((shared[s] & 0x01)) - shared_meta[s][0])*shared_meta[s][1]; //8th chunk +// } + + +/*******************************************************************************************************************************************/ +/************* 3-bit *************/ +/*******************************************************************************************************************************************/ + +//Simple +/*__global__ void unpack_3bit_32_kernel(int32_t* Wq_packed, unsigned char* Wq_unpacked, int n) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + if(i>=n) return; + + Wq_unpacked[i] = (Wq_packed[i] & 0x38000000) >> 27; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x07000000) >> 24; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x00E00000) >> 21; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x001C0000) >> 18; //4th 
chunk + Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x00038000) >> 15; //5th chunk + Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x00007000) >> 12; //6th chunk + Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x00000E00) >> 9; //7th chunk + Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x000001C0) >> 6; //8th chunk + Wq_unpacked[i + n*8] = (Wq_packed[i] & 0x00000038) >> 3; //9th chunk + Wq_unpacked[i + n*9] = (Wq_packed[i] & 0x00000007); //10th chunk +}*/ + + +//Simple +template +__global__ void dequantize_3bit_32_kernel(int32_t* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { + int i = blockIdx.x*blockDim.x + threadIdx.x; + int n = h*w; + if(i>=n) return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0x38000000) >> 27) - zero[j])*scale[j]; //1st chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x07000000) >> 24) - zero[j])*scale[j]; //2nd chunk + W_r[i + n*2] = ((T)((Wq_packed[i] & 0x00E00000) >> 21) - zero[j])*scale[j]; //3rd chunk + W_r[i + n*3] = ((T)((Wq_packed[i] & 0x001C0000) >> 18) - zero[j])*scale[j]; //4th chunk + W_r[i + n*4] = ((T)((Wq_packed[i] & 0x00038000) >> 15) - zero[j])*scale[j]; //5th chunk + W_r[i + n*5] = ((T)((Wq_packed[i] & 0x00007000) >> 12) - zero[j])*scale[j]; //6th chunk + W_r[i + n*6] = ((T)((Wq_packed[i] & 0x00000E00) >> 9) - zero[j])*scale[j]; //7th chunk + W_r[i + n*7] = ((T)((Wq_packed[i] & 0x000001C0) >> 6) - zero[j])*scale[j]; //8th chunk + W_r[i + n*8] = ((T)((Wq_packed[i] & 0x00000038) >> 3) - zero[j])*scale[j]; //9th chunk + W_r[i + n*9] = ((T)((Wq_packed[i] & 0x00000007)) - zero[j])*scale[j]; //10th chunk +} + +extern "C" void dequantize_3bit_32_kernel_f32(int32_t* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} + +//#if __CUDA_ARCH__ >= 630 +extern "C" void dequantize_3bit_32_kernel_f16(int32_t* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif + +//#if __CUDA_ARCH__ >= 800 +extern "C" void dequantize_3bit_32_kernel_bf16(int32_t* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { + int blocks = cdiv(h*w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +} +//#endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/ops/ops.cu b/mistralrs-quant/kernels/ops/ops.cu new file mode 100644 index 0000000000..b9d026790f --- /dev/null +++ b/mistralrs-quant/kernels/ops/ops.cu @@ -0,0 +1,76 @@ +// Get inspiration from +// https://github.com/pytorch/pytorch/blob/65aa16f968af2cd18ff8c25cc657e7abda594bfc/aten/src/ATen/native/cuda/Nonzero.cu +#include +#include + +int mq_next_power_of_2(const uint32_t num_nonzero) { + int result = 1; + while (result < num_nonzero) { + result <<= 1; + } + return result; +} + +template +__global__ void bitwise_or__kernel(const T *d_in1, const T *d_in2, T *d_out, + const uint32_t N) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + d_out[idx] = d_in1[idx] | d_in2[idx]; + } +} + + +template +void bitwise_or(const T *d_in1, const T *d_in2, T *d_out, int N) { + int nthreads = mq_next_power_of_2(N); + if (nthreads > 1024) { + nthreads = 1024; + } + const int nblocks = (N + nthreads - 1) / nthreads; + bitwise_or__kernel<<>>(d_in1, d_in2, d_out, N); + cudaDeviceSynchronize(); +} + +#define BITWISE_OP(TYPENAME, RUST_NAME) \ + extern "C" void mq_bitwise_or_##RUST_NAME(const TYPENAME *d_in1, \ 
+ const TYPENAME *d_in2, \ + TYPENAME *d_out, uint32_t N) { \ + bitwise_or(d_in1, d_in2, d_out, N); \ + } + +BITWISE_OP(uint8_t, u8) +BITWISE_OP(uint32_t, u32) +BITWISE_OP(int64_t, i64) +BITWISE_OP(int32_t, i32) + +template +__global__ void leftshift_kernel(const T *d_in1, T *d_out, + const uint32_t N, const int32_t k) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + d_out[idx] = d_in1[idx] << k; + } +} + +template +void leftshift(const T *d_in1, T *d_out, int N, const int32_t k) { + int nthreads = mq_next_power_of_2(N); + if (nthreads > 1024) { + nthreads = 1024; + } + const int nblocks = (N + nthreads - 1) / nthreads; + leftshift_kernel<<>>(d_in1, d_out, N, k); + cudaDeviceSynchronize(); +} + +#define LEFTSHIFT_OP(TYPENAME, RUST_NAME) \ + extern "C" void mq_leftshift_##RUST_NAME(const TYPENAME *d_in1, \ + TYPENAME *d_out, uint32_t N, int32_t k) { \ + leftshift(d_in1, d_out, N, k); \ + } + +LEFTSHIFT_OP(uint8_t, u8) +LEFTSHIFT_OP(int32_t, i32) +LEFTSHIFT_OP(uint32_t, u32) +LEFTSHIFT_OP(int64_t, i64) diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs index 1c74d8628d..1d2a7e6dac 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -27,7 +27,9 @@ impl QuantMethod for GgufMatMul { w: QMatMul::from_arc(q_weight)?, b, }), - QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Unquantized(_) => unreachable!(), + QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Hqq { .. } => unreachable!(), } } diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs index b59403872e..b4324409a5 100644 --- a/mistralrs-quant/src/gptq/gptq_cpu.rs +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -23,7 +23,9 @@ impl QuantMethod for GptqLayer { g_idx: _, bias: _, } => candle_core::bail!("GPTQ is only supported on CUDA."), - QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) => { + QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Hqq { .. } => { unreachable!() } } diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index b9b26f1257..e89a7866cd 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -237,7 +237,9 @@ impl QuantMethod for GptqLayer { bias, }) } - QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) => { + QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Hqq { .. } => { unreachable!() } } diff --git a/mistralrs-quant/src/hqq/ffi.rs b/mistralrs-quant/src/hqq/ffi.rs new file mode 100644 index 0000000000..d1f1ef335f --- /dev/null +++ b/mistralrs-quant/src/hqq/ffi.rs @@ -0,0 +1,74 @@ +macro_rules! dequant_kernel { + ($wq:ty, $scalar:ty, $postfix:tt) => { + paste! 
{ + pub(crate) fn [< dequantize_ $postfix >]( + wq_packed: *const $wq, + scale: *const $scalar, + zero: *const $scalar, + out: *const $scalar, + h: i32, + w: i32 + ); + } + }; +} + +pub mod eight_bit { + use half::{bf16, f16}; + use paste::paste; + + #[allow(dead_code)] + extern "C" { + dequant_kernel!(u8, f32, 8bit_u8_kernel_f32); + dequant_kernel!(u8, f16, 8bit_u8_kernel_f16); + dequant_kernel!(u8, bf16, 8bit_u8_kernel_bf16); + } +} + +pub mod four_bit { + use half::{bf16, f16}; + use paste::paste; + + #[allow(dead_code)] + extern "C" { + dequant_kernel!(u8, f32, 4bit_u8_kernel_f32); + dequant_kernel!(u8, f16, 4bit_u8_kernel_f16); + dequant_kernel!(u8, bf16, 4bit_u8_kernel_bf16); + } +} + +pub mod three_bit { + use half::{bf16, f16}; + use paste::paste; + + #[allow(dead_code)] + extern "C" { + dequant_kernel!(i32, f32, 3bit_32_kernel_f32); + dequant_kernel!(i32, f16, 3bit_32_kernel_f16); + dequant_kernel!(i32, bf16, 3bit_32_kernel_bf16); + } +} + +pub mod two_bit { + use half::{bf16, f16}; + use paste::paste; + + #[allow(dead_code)] + extern "C" { + dequant_kernel!(u8, f32, 2bit_u8_kernel_f32); + dequant_kernel!(u8, f16, 2bit_u8_kernel_f16); + dequant_kernel!(u8, bf16, 2bit_u8_kernel_bf16); + } +} + +pub mod one_bit { + use half::{bf16, f16}; + use paste::paste; + + #[allow(dead_code)] + extern "C" { + dequant_kernel!(u8, f32, 1bit_u8_kernel_f32); + dequant_kernel!(u8, f16, 1bit_u8_kernel_f16); + dequant_kernel!(u8, bf16, 1bit_u8_kernel_bf16); + } +} diff --git a/mistralrs-quant/src/hqq/hqq_cpu.rs b/mistralrs-quant/src/hqq/hqq_cpu.rs new file mode 100644 index 0000000000..e524eff15e --- /dev/null +++ b/mistralrs-quant/src/hqq/hqq_cpu.rs @@ -0,0 +1,313 @@ +use candle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType}; + +/* + 8 bit +*/ +pub(crate) struct Dequant8Bit { + pub(crate) h: usize, + pub(crate) w: usize, +} + +impl Dequant8Bit { + fn dequantize(&self, w: &[u8], s: &[T], z: &[T]) -> Vec { + let mut out = Vec::with_capacity(w.len()); + for (i, w) in w.iter().enumerate() { + let j = i % self.w; + out[i] = (T::from_f64(*w as f64) - z[j]) * s[j]; + } + out + } +} + +impl CustomOp3 for Dequant8Bit { + fn name(&self) -> &'static str { + "dequant-hqq-8bit" + } + fn cpu_fwd( + &self, + w: &CpuStorage, + l_w: &Layout, + s: &CpuStorage, + l_s: &Layout, + z: &CpuStorage, + l_z: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let CpuStorage::U8(w_slice) = w else { + candle_core::bail!("Weight must be u8, HQQ dequant 8-bit"); + }; + if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { + candle_core::bail!("All inputs must be contiguous"); + } + match (s, z) { + (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( + CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[self.h, self.w]), + )), + (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok(( + CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[self.h, self.w]), + )), + (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok(( + CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[self.h, self.w]), + )), + (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + } + } +} + +/* + 4 bit +*/ +pub(crate) struct Dequant4Bit { + pub(crate) h: usize, + pub(crate) w: usize, +} + +impl Dequant4Bit { + fn dequantize(&self, w: &[u8], s: &[T], z: &[T]) -> Vec { + let mut out = Vec::with_capacity(w.len()); + for (i, w) in w.iter().enumerate() { + let j = i % self.w; + let nrows = self.h 
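+            // `nrows` is the number of packed bytes (h * w); the high nibble of each byte is written at index `i`, the low nibble at `i + nrows`.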
* self.w; + out[i] = (T::from_f64(((*w & 0xF0) >> 4) as f64) - z[j]) * s[j]; + out[i + nrows] = (T::from_f64((*w & 0x0F) as f64) - z[j]) * s[j]; + } + out + } +} + +impl CustomOp3 for Dequant4Bit { + fn name(&self) -> &'static str { + "dequant-hqq-4bit" + } + fn cpu_fwd( + &self, + w: &CpuStorage, + l_w: &Layout, + s: &CpuStorage, + l_s: &Layout, + z: &CpuStorage, + l_z: &Layout, + ) -> Result<(CpuStorage, Shape)> { + const PACK_FACTOR: usize = 2; + + let CpuStorage::U8(w_slice) = w else { + candle_core::bail!("Weight must be u8, HQQ dequant 4-bit"); + }; + if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { + candle_core::bail!("All inputs must be contiguous"); + } + match (s, z) { + (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( + CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok(( + CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok(( + CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + } + } +} + +/* + 2 bit +*/ +pub(crate) struct Dequant2Bit { + pub(crate) h: usize, + pub(crate) w: usize, +} + +impl Dequant2Bit { + fn dequantize(&self, w: &[u8], s: &[T], z: &[T]) -> Vec { + let mut out = Vec::with_capacity(w.len()); + for (i, w) in w.iter().enumerate() { + let j = i % self.w; + let nrows = self.h * self.w; + out[i] = (T::from_f64(((*w & 0xC0) >> 6) as f64) - z[j]) * s[j]; + out[i + nrows] = (T::from_f64(((*w & 0x30) >> 4) as f64) - z[j]) * s[j]; + out[i + nrows * 2] = (T::from_f64(((*w & 0x0C) >> 2) as f64) - z[j]) * s[j]; + out[i + nrows * 3] = (T::from_f64((*w & 0x03) as f64) - z[j]) * s[j]; + } + out + } +} + +impl CustomOp3 for Dequant2Bit { + fn name(&self) -> &'static str { + "dequant-hqq-2bit" + } + fn cpu_fwd( + &self, + w: &CpuStorage, + l_w: &Layout, + s: &CpuStorage, + l_s: &Layout, + z: &CpuStorage, + l_z: &Layout, + ) -> Result<(CpuStorage, Shape)> { + const PACK_FACTOR: usize = 4; + + let CpuStorage::U8(w_slice) = w else { + candle_core::bail!("Weight must be u8, HQQ dequant 2-bit"); + }; + if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { + candle_core::bail!("All inputs must be contiguous"); + } + match (s, z) { + (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( + CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok(( + CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok(( + CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + } + } +} + +/* + 1 bit +*/ +pub(crate) struct Dequant1Bit { + pub(crate) h: usize, + pub(crate) w: usize, +} + +impl Dequant1Bit { + fn dequantize(&self, w: &[u8], s: &[T], z: &[T]) -> Vec { + let mut out = Vec::with_capacity(w.len()); + for (i, w) in w.iter().enumerate() { + let j = i % self.w; + let nrows = self.h * self.w; + out[i] = (T::from_f64(((*w & 0x80) 
>> 7) as f64) - z[j]) * s[j]; + out[i + nrows] = (T::from_f64(((*w & 0x40) >> 6) as f64) - z[j]) * s[j]; + out[i + nrows * 2] = (T::from_f64(((*w & 0x20) >> 5) as f64) - z[j]) * s[j]; + out[i + nrows * 3] = (T::from_f64(((*w & 0x10) >> 4) as f64) - z[j]) * s[j]; + out[i + nrows * 4] = (T::from_f64(((*w & 0x08) >> 3) as f64) - z[j]) * s[j]; + out[i + nrows * 5] = (T::from_f64(((*w & 0x04) >> 2) as f64) - z[j]) * s[j]; + out[i + nrows * 6] = (T::from_f64(((*w & 0x02) >> 1) as f64) - z[j]) * s[j]; + out[i + nrows * 7] = (T::from_f64((*w & 0x01) as f64) - z[j]) * s[j]; + } + out + } +} + +impl CustomOp3 for Dequant1Bit { + fn name(&self) -> &'static str { + "dequant-hqq-1bit" + } + fn cpu_fwd( + &self, + w: &CpuStorage, + l_w: &Layout, + s: &CpuStorage, + l_s: &Layout, + z: &CpuStorage, + l_z: &Layout, + ) -> Result<(CpuStorage, Shape)> { + const PACK_FACTOR: usize = 8; + + let CpuStorage::U8(w_slice) = w else { + candle_core::bail!("Weight must be u8, HQQ dequant 1-bit"); + }; + if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { + candle_core::bail!("All inputs must be contiguous"); + } + match (s, z) { + (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( + CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok(( + CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok(( + CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + } + } +} + +/* + 3 bit +*/ +pub(crate) struct Dequant3Bit { + pub(crate) h: usize, + pub(crate) w: usize, +} + +impl Dequant3Bit { + fn dequantize(&self, w: &[i32], s: &[T], z: &[T]) -> Vec { + let mut out = Vec::with_capacity(w.len()); + for (i, w) in w.iter().enumerate() { + let j = i % self.w; + let nrows = self.h * self.w; + out[i] = (T::from_f64(((*w & 0x38000000) >> 27) as f64) - z[j]) * s[j]; + out[i + nrows] = (T::from_f64(((*w & 0x07000000) >> 24) as f64) - z[j]) * s[j]; + out[i + nrows * 2] = (T::from_f64(((*w & 0x00E00000) >> 21) as f64) - z[j]) * s[j]; + out[i + nrows * 3] = (T::from_f64(((*w & 0x001C0000) >> 18) as f64) - z[j]) * s[j]; + out[i + nrows * 4] = (T::from_f64(((*w & 0x00038000) >> 15) as f64) - z[j]) * s[j]; + out[i + nrows * 5] = (T::from_f64(((*w & 0x00007000) >> 12) as f64) - z[j]) * s[j]; + out[i + nrows * 6] = (T::from_f64(((*w & 0x00000E00) >> 9) as f64) - z[j]) * s[j]; + out[i + nrows * 7] = (T::from_f64(((*w & 0x000001C0) >> 6) as f64) - z[j]) * s[j]; + out[i + nrows * 8] = (T::from_f64(((*w & 0x00000038) >> 3) as f64) - z[j]) * s[j]; + out[i + nrows * 9] = (T::from_f64((*w & 0x00000007) as f64) - z[j]) * s[j]; + } + out + } +} + +impl CustomOp3 for Dequant3Bit { + fn name(&self) -> &'static str { + "dequant-hqq-3bit" + } + fn cpu_fwd( + &self, + w: &CpuStorage, + l_w: &Layout, + s: &CpuStorage, + l_s: &Layout, + z: &CpuStorage, + l_z: &Layout, + ) -> Result<(CpuStorage, Shape)> { + const PACK_FACTOR: usize = 10; + + let CpuStorage::I32(w_slice) = w else { + candle_core::bail!("Weight must be i32, HQQ dequant 3-bit"); + }; + if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { + candle_core::bail!("All inputs must be contiguous"); + } + match (s, z) { + (CpuStorage::F32(s_slice), 
CpuStorage::F32(z_slice)) => Ok(( + CpuStorage::F32(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::F16(s_slice), CpuStorage::F16(z_slice)) => Ok(( + CpuStorage::F16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (CpuStorage::BF16(s_slice), CpuStorage::BF16(z_slice)) => Ok(( + CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), + Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), + )), + (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + } + } +} diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs new file mode 100644 index 0000000000..9af345184f --- /dev/null +++ b/mistralrs-quant/src/hqq/mod.rs @@ -0,0 +1,596 @@ +use candle_core::{DType, Device, Result, Shape, Tensor}; + +#[cfg(feature = "cuda")] +use candle_core::{ + cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr}, + from_storage_no_op, CudaStorage, Storage, +}; + +#[cfg(feature = "cuda")] +use half::{bf16, f16}; +use std::{ + num::NonZeroUsize, + sync::{atomic::AtomicUsize, Arc}, +}; + +use crate::{ + utils::{BitWiseOp, LeftshiftOp}, + IsqType, QuantMethod, QuantMethodConfig, +}; + +#[cfg(feature = "cuda")] +use crate::utils::{get_cuda_device, get_cuda_slice}; + +#[cfg(feature = "cuda")] +use ffi::{eight_bit, four_bit, one_bit, three_bit, two_bit}; + +#[cfg(feature = "cuda")] +mod ffi; + +#[cfg(not(feature = "cuda"))] +mod hqq_cpu; + +mod optimize; +mod quantize; + +pub(crate) const ISQ_HQQ_GROUP_SIZE: usize = 64; +pub(crate) const ISQ_HQQ_DEFAULT_OPT_STEPS: Option = Some(10); +pub(crate) const OPTIMIZER_HQQ_DEFAULT_STEPS: usize = 20; + +#[cfg(feature = "cuda")] +macro_rules! dequant_for_dtype { + ($this:expr, w=$wq_t:ty, sz=$scale_t:ty, $dtype:ident, pack=$pack:expr, $dev:expr, $bit_thing:ident, $postfix:tt) => {{ + paste::paste! { + let w_slice = get_cuda_slice::<$wq_t>(&$this.w_q)?; + let scale_slice = get_cuda_slice::<$scale_t>(&$this.scales)?; + let zero_slice = get_cuda_slice::<$scale_t>(&$this.zeros)?; + + let (h, w) = $this.w_q.dims2()?; + let num_packed_elems = $pack; + let out_shape = Shape::from_dims(&[num_packed_elems * h, w]); + + let out = unsafe { $dev.alloc::<$scale_t>(out_shape.elem_count()).w()? }; + let out_ptr = *out.device_ptr() as *mut $scale_t; + unsafe { + $bit_thing::[< dequantize_ $postfix >]( + w_slice, + scale_slice, + zero_slice, + out_ptr, + h as i32, + w as i32, + ); + } + + let storage = CudaStorage { + slice: CudaStorageSlice::$dtype(out), + device: $dev.clone(), + }; + let storage = Storage::Cuda(storage); + + from_storage_no_op(storage, out_shape, false) + } + }}; +} + +#[derive(Debug, Clone, Copy)] +pub enum HqqAxis { + Zero = 0, + One = 1, +} + +#[derive(Debug, Clone, Copy)] +pub enum HqqBits { + Eight = 8, + Four = 4, + Three = 3, + Two = 2, + One = 1, +} + +impl HqqBits { + // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/bitpack.py#L10 + pub(crate) fn bitpack_type(&self) -> impl Fn(Tensor) -> Result { + match self { + Self::Eight => |wq: Tensor| wq.to_dtype(DType::U8), + Self::Four => |wq: Tensor| { + let wq = wq.to_dtype(DType::U8)?; + let step = (wq.dims()[0] as f64 / 2.) as usize; + + let a = wq.narrow(0, 0, step)?; + let b = wq.narrow(0, step, step)?; + a.leftshift(4)?.bitwise_or(&b) + }, + Self::Two => |wq: Tensor| { + let wq = wq.to_dtype(DType::U8)?; + let step = (wq.dims()[0] as f64 / 4.) 
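+                // Four 2-bit values share one byte, so the rows are split into four equal chunks before packing.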
as usize; + + let a = wq.narrow(0, 0, step)?; + let b = wq.narrow(0, step, step)?; + let c = wq.narrow(0, step * 2, step)?; + let d = wq.narrow(0, step * 3, step)?; + + a.leftshift(6)? + .bitwise_or(&b.leftshift(4)?)? + .bitwise_or(&c.leftshift(2)?)? + .bitwise_or(&d) + }, + Self::Three => |wq_in: Tensor| { + let wq = Tensor::zeros( + ( + (10. * (wq_in.dims()[0] as f64 / 10.).ceil()) as usize, + wq_in.dims()[1], + ), + DType::I32, + wq_in.device(), + )?; + let wq = + wq.slice_assign(&[&(..wq_in.dims()[0]), &..], &wq_in.to_dtype(DType::I32)?)?; + let step = (wq.dims()[0] as f64 / 10.) as usize; + + let a = wq.narrow(0, 0, step)?; + let b = wq.narrow(0, step, step)?; + let c = wq.narrow(0, step * 2, step)?; + let d = wq.narrow(0, step * 3, step)?; + let e = wq.narrow(0, step * 4, step)?; + let f = wq.narrow(0, step * 5, step)?; + let g = wq.narrow(0, step * 6, step)?; + let h = wq.narrow(0, step * 7, step)?; + let i = wq.narrow(0, step * 8, step)?; + let j = wq.narrow(0, step * 9, step)?; + + a.leftshift(27)? + .bitwise_or(&b.leftshift(24)?)? + .bitwise_or(&c.leftshift(21)?)? + .bitwise_or(&d.leftshift(18)?)? + .bitwise_or(&e.leftshift(15)?)? + .bitwise_or(&f.leftshift(12)?)? + .bitwise_or(&g.leftshift(9)?)? + .bitwise_or(&h.leftshift(6)?)? + .bitwise_or(&i.leftshift(3)?)? + .bitwise_or(&j) + }, + Self::One => |wq: Tensor| { + let wq = wq.to_dtype(DType::U8)?; + let step = (wq.dims()[0] as f64 / 8.) as usize; + + let a = wq.narrow(0, 0, step)?; + let b = wq.narrow(0, step, step)?; + let c = wq.narrow(0, step * 2, step)?; + let d = wq.narrow(0, step * 3, step)?; + let e = wq.narrow(0, step * 4, step)?; + let f = wq.narrow(0, step * 5, step)?; + let g = wq.narrow(0, step * 6, step)?; + let h = wq.narrow(0, step * 7, step)?; + + a.leftshift(7)? + .bitwise_or(&b.leftshift(6)?)? + .bitwise_or(&c.leftshift(5)?)? + .bitwise_or(&d.leftshift(4)?)? + .bitwise_or(&e.leftshift(3)?)? + .bitwise_or(&f.leftshift(2)?)? + .bitwise_or(&g.leftshift(1)?)? + .bitwise_or(&h) + }, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct HqqConfig { + pub bits: HqqBits, + pub group_size: NonZeroUsize, + pub axis: HqqAxis, + pub optimization_steps: Option, + pub round_zeros: bool, // default false + pub channel_wise: bool, // default true +} + +#[derive(Debug)] +pub struct HqqLayer { + pub(crate) w_q: Tensor, + pub(crate) zeros: Tensor, + pub(crate) scales: Tensor, + pub(crate) bias: Option, + pub(crate) w_shape: Shape, + pub(crate) cfg: HqqConfig, +} + +impl HqqLayer { + /// Dequantize `self` into a tensor of shape `scales` or `zeros`. + #[cfg(not(feature = "cuda"))] + fn dequantize(&self) -> Result { + use crate::hqq::hqq_cpu::{ + Dequant1Bit, Dequant2Bit, Dequant3Bit, Dequant4Bit, Dequant8Bit, + }; + + match (self.scales.dtype(), self.zeros.dtype()) { + (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (), + (a, b) => { + candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") + } + } + if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous()) + { + candle_core::bail!("All tensors must be contiguous!"); + } + if self.cfg.axis as usize != 0 { + candle_core::bail!( + "CPU HQQ dequantization requires axis == 0, got {}.", + self.cfg.axis as usize + ); + } + let (h, w) = self.w_q.dims2()?; + + match self.cfg.bits as usize { + 8 => self + .w_q + .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant8Bit { h, w })? + .reshape(&self.w_shape), + 4 => self + .w_q + .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant4Bit { h, w })? 
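+                // The 4-bit op returns a (2 * h, w) matrix (two values per packed byte); restore the original weight shape.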
+ .reshape(&self.w_shape), + 3 => self + .w_q + .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant3Bit { h, w })? + .reshape(&self.w_shape), + 2 => self + .w_q + .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant2Bit { h, w })? + .reshape(&self.w_shape), + 1 => self + .w_q + .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })? + .reshape(&self.w_shape), + b => candle_core::bail!("Unreachable bits {b}"), + } + } + + /// Dequantize `self` into a tensor of shape `scales` or `zeros`. + #[cfg(feature = "cuda")] + fn dequantize(&self) -> Result { + match (self.scales.dtype(), self.zeros.dtype()) { + (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (), + (a, b) => { + candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") + } + } + if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous()) + { + candle_core::bail!("All tensors must be contiguous!"); + } + if self.cfg.axis as usize != 0 { + candle_core::bail!( + "CUDA HQQ dequantization requires axis == 0, got {}.", + self.cfg.axis as usize + ); + } + let dev = get_cuda_device(&self.w_q)?; + + let inner = match (self.cfg.bits as usize, self.scales.dtype()) { + // 8 bits + (8, DType::F32) => { + dequant_for_dtype!( + self, + w = u8, + sz = f32, + F32, + pack = 1, + dev, + eight_bit, + 8bit_u8_kernel_f32 + ) + } + (8, DType::F16) => { + dequant_for_dtype!( + self, + w = u8, + sz = f16, + F16, + pack = 1, + dev, + eight_bit, + 8bit_u8_kernel_f16 + ) + } + (8, DType::BF16) => { + dequant_for_dtype!( + self, + w = u8, + sz = bf16, + BF16, + pack = 1, + dev, + eight_bit, + 8bit_u8_kernel_bf16 + ) + } + + // 4 bits + (4, DType::F32) => { + dequant_for_dtype!( + self, + w = u8, + sz = f32, + F32, + pack = 2, + dev, + four_bit, + 4bit_u8_kernel_f32 + ) + } + (4, DType::F16) => { + dequant_for_dtype!( + self, + w = u8, + sz = f16, + F16, + pack = 2, + dev, + four_bit, + 4bit_u8_kernel_f16 + ) + } + (4, DType::BF16) => { + dequant_for_dtype!( + self, + w = u8, + sz = bf16, + BF16, + pack = 2, + dev, + four_bit, + 4bit_u8_kernel_bf16 + ) + } + + // 3 bits + // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/kernels/hqq_aten_cuda.cpp#L42-L45 + (3, DType::F32) => { + let res = dequant_for_dtype!( + self, + w = i32, + sz = f32, + F32, + pack = 10, + dev, + three_bit, + 3bit_32_kernel_f32 + ); + res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())? + } + (3, DType::F16) => { + let res = dequant_for_dtype!( + self, + w = i32, + sz = f16, + F16, + pack = 10, + dev, + three_bit, + 3bit_32_kernel_f16 + ); + res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())? + } + (3, DType::BF16) => { + let res = dequant_for_dtype!( + self, + w = i32, + sz = bf16, + BF16, + pack = 10, + dev, + three_bit, + 3bit_32_kernel_bf16 + ); + res.narrow(self.cfg.axis as usize, 0, self.cfg.group_size.into())? 
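+                // As in the other 3-bit arms: trim off the rows added to pad the packing to a multiple of 10.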
+ } + + // 2 bits + (2, DType::F32) => { + dequant_for_dtype!( + self, + w = u8, + sz = f32, + F32, + pack = 4, + dev, + two_bit, + 2bit_u8_kernel_f32 + ) + } + (2, DType::F16) => { + dequant_for_dtype!( + self, + w = u8, + sz = f16, + F16, + pack = 4, + dev, + two_bit, + 2bit_u8_kernel_f16 + ) + } + (2, DType::BF16) => { + dequant_for_dtype!( + self, + w = u8, + sz = bf16, + BF16, + pack = 4, + dev, + two_bit, + 2bit_u8_kernel_bf16 + ) + } + + // 1 bit + (1, DType::F32) => { + dequant_for_dtype!( + self, + w = u8, + sz = f32, + F32, + pack = 8, + dev, + one_bit, + 1bit_u8_kernel_f32 + ) + } + (1, DType::F16) => { + dequant_for_dtype!( + self, + w = u8, + sz = f16, + F16, + pack = 8, + dev, + one_bit, + 1bit_u8_kernel_f16 + ) + } + (1, DType::BF16) => { + dequant_for_dtype!( + self, + w = u8, + sz = bf16, + BF16, + pack = 8, + dev, + one_bit, + 1bit_u8_kernel_bf16 + ) + } + (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"), + }; + inner.reshape(&self.w_shape) + } + + fn dequantize_matmul(&self, xs: &Tensor) -> Result { + let w = self.dequantize()?; + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + let res = xs.matmul(&w)?; + if let Some(ref bias) = self.bias { + res + bias + } else { + Ok(res) + } + } + + pub fn with_bias(mut self, bias: Tensor) -> Self { + self.bias = Some(bias); + self + } +} + +impl QuantMethod for HqqLayer { + fn new(method: QuantMethodConfig) -> Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Unquantized(_) + | QuantMethodConfig::Gptq { .. } => { + unreachable!() + } + QuantMethodConfig::Hqq { + tensor, + bits, + group_size, + axis, + optimization_steps, + round_zeros, + channel_wise, + bias, + } => { + let cfg = HqqConfig { + bits, + group_size, + axis, + optimization_steps, + round_zeros: round_zeros.unwrap_or(false), + channel_wise: channel_wise.unwrap_or(true), + }; + + let this = Self::quantize(&tensor, tensor.device(), cfg)?; + if let Some(bias) = bias { + Ok(this.with_bias(bias)) + } else { + Ok(this) + } + } + } + } + + fn forward(&self, a: &Tensor) -> Result { + /* + if self.cfg.force_dequantize { + self.dequantize_matmul(a) + } else { + todo!() + } */ + self.dequantize_matmul(a) + } + + fn quantized_act_type(&self) -> Option { + Some(self.scales.dtype()) + } + + fn add_delta_w(&self, _delta: &Tensor) -> Result> { + candle_core::bail!("HQQ quantization does not support adding weight delta.") + } + + fn dtype_and_device(&self) -> (DType, Device) { + (self.scales.dtype(), self.scales.device().clone()) + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + self.bias.as_mut() + } + + fn apply_isq( + self: Arc, + dtype: IsqType, + device: Device, + n_quantized: &AtomicUsize, + ) -> Result> { + n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let bits = match dtype { + IsqType::HQQ8 => HqqBits::Eight, + IsqType::HQQ4 => HqqBits::Four, + // IsqType::HQQ3 => HqqBits::Three, + // IsqType::HQQ2 => HqqBits::Two, + // IsqType::HQQ1 => HqqBits::One, + _ => candle_core::bail!("Expected HQQ ISQ type."), + }; + let cfg = HqqConfig { + bits, + group_size: ISQ_HQQ_GROUP_SIZE.try_into()?, + axis: HqqAxis::Zero, + optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS, + round_zeros: false, + channel_wise: true, + }; + let dequant = self.dequantize()?; + let res = Self::quantize(&dequant, &device, cfg)?; + if let Some(ref bias) = self.bias { + let bias = bias + 
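+                // Carry the bias to the target device and match the dtype of the re-quantized weights.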
.to_device(&device)? + .to_dtype(res.dtype_and_device().0)?; + Ok(Arc::new(res.with_bias(bias))) + } else { + Ok(Arc::new(res)) + } + } + + fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option { + // Use 1 because we quantize on the GPU + Some(1.try_into().unwrap()) + } +} diff --git a/mistralrs-quant/src/hqq/optimize.rs b/mistralrs-quant/src/hqq/optimize.rs new file mode 100644 index 0000000000..c3770e32f3 --- /dev/null +++ b/mistralrs-quant/src/hqq/optimize.rs @@ -0,0 +1,95 @@ +use candle_core::{DType, Result, Tensor}; + +use super::{HqqAxis, HqqLayer, OPTIMIZER_HQQ_DEFAULT_STEPS}; + +pub(crate) struct OptParams { + pub(crate) lp_norm: f64, + pub(crate) beta: f64, + pub(crate) kappa: f64, + pub(crate) iters: usize, +} + +impl OptParams { + pub(crate) fn default(optimization_steps: Option) -> Self { + Self { + lp_norm: 0.7, + beta: 1e1, + kappa: 1.01, + iters: optimization_steps.unwrap_or(OPTIMIZER_HQQ_DEFAULT_STEPS), + } + } +} + +pub(crate) struct OptResults { + pub(crate) wq: Tensor, + pub(crate) scale: Tensor, + pub(crate) zero: Tensor, +} + +fn shrink_lp_op(x: &Tensor, beta: f64, lp_norm: f64) -> Result { + if lp_norm == 1. { + x.sign()?.broadcast_mul(&(x.abs()? - 1. / beta)?.relu()?) + } else { + let pow_exp = Tensor::new(lp_norm - 1., x.device())? + .broadcast_as(x.shape().clone())? + .to_dtype(x.dtype())?; + x.sign()? + .broadcast_mul(&(x.abs()? - ((1. / beta) * x.abs()?.pow(&pow_exp)?))?.relu()?) + } +} + +impl HqqLayer { + // https://github.com/mobiusml/hqq/blob/306e30d9400629523c8e0af70101d8d7073cb3d5/hqq/core/optimize.py#L194 + pub(crate) fn optimize_weights_proximal_legacy( + tensor: &Tensor, + scale: &Tensor, + zero: Tensor, + min: f64, + max: f64, + axis: HqqAxis, + opt_params: OptParams, + ) -> Result { + let OptParams { + lp_norm, + mut beta, + kappa, + iters, + } = opt_params; + + let wf = tensor.clone(); + let scale = scale.to_dtype(wf.dtype())?; + let mut zero = zero.to_dtype(wf.dtype())?; + + let mut best_error = 1e4; + for _ in 0..iters { + let wq = wf + .broadcast_mul(&scale)? + .broadcast_add(&zero)? + .round()? + .clamp(min, max)?; + let wr = wq.broadcast_sub(&zero)?.broadcast_div(&scale)?; + let we = shrink_lp_op(&(&wf - &wr)?, beta, lp_norm)?; + + zero = (wq - (&wf - we)?.broadcast_mul(&scale)?)?.mean_keepdim(axis as usize)?; + beta *= kappa; + + let current_error = (&wf - wr)? + .abs()? + .mean_all()? + .to_dtype(DType::F32)? + .to_scalar::()?; + if current_error < best_error { + best_error = current_error; + } else { + break; + } + } + + let wq = tensor + .broadcast_mul(&scale)? + .broadcast_add(&zero)? + .round()? 
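+            // Clamp the rounded codes to the valid quantization range [min, max].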
+ .clamp(min, max)?; + Ok(OptResults { wq, scale, zero }) + } +} diff --git a/mistralrs-quant/src/hqq/quantize.rs b/mistralrs-quant/src/hqq/quantize.rs new file mode 100644 index 0000000000..0fd7ba5ff1 --- /dev/null +++ b/mistralrs-quant/src/hqq/quantize.rs @@ -0,0 +1,131 @@ +use candle_core::{Device, Result, Tensor}; + +use crate::hqq::optimize::OptResults; + +use super::{optimize::OptParams, HqqAxis, HqqConfig, HqqLayer}; + +impl HqqLayer { + /// Quantize the model into HQQ + pub fn quantize(input: &Tensor, device: &Device, cfg: HqqConfig) -> Result { + let group_size: usize = cfg.group_size.into(); + if input.elem_count() % group_size != 0 { + candle_core::bail!("`group_size` should be divisible by the tensor number of elements, which are {}, got a group size of {group_size}.", input.elem_count()); + } + + let mut w = input.clone(); + + // Reshape for grouping + w = if cfg.channel_wise { + match cfg.axis { + HqqAxis::One => w.reshape(((), group_size))?, + HqqAxis::Zero => w.reshape((group_size, ()))?, + } + } else { + w + }; + + // Get min and max valyes + let (min, max) = if !cfg.channel_wise { + // TODO we need min_all + let mut min = w.min(0)?; + let mut max = w.max(0)?; + while !min.dims().is_empty() { + min = min.min(0)?; + max = max.max(0)?; + } + (min, max) + } else { + ( + w.min_keepdim(cfg.axis as usize)?, + w.max_keepdim(cfg.axis as usize)?, + ) + }; + + let max_v = (2f64.powf(cfg.bits as usize as f64) - 1.).round(); + + // Note: here using the inverse of the scale to avoid division, quantize via W * scale + zero, scale is inverted later! + // Clamp to avoid half precision problems + let scale = (max_v / (max - &min)?)?.clamp(0., 2e4)?; + let mut zero = (min.neg()? * &scale)?; + + if cfg.round_zeros { + zero = zero.round()?; + } + + // We only support using optimization! + /*let (quant_w, scale, zero) = if let Some(optimization_steps) = cfg.optimization_steps { + let result = Self::optimize_weights_proximal_legacy( + &w, + &scale, + zero, + 0., + max_v, + cfg.axis, + OptParams::default(optimization_steps), + )?; + (result.wq, result.scale, result.zero) + } else { + ( + w.broadcast_mul(&scale)? + .broadcast_add(&zero)? 
+ .clamp(0., max_v)?, + scale, + zero, + ) + };*/ + let OptResults { wq, scale, zero } = Self::optimize_weights_proximal_legacy( + &w, + &scale, + zero, + 0., + max_v, + cfg.axis, + OptParams::default(cfg.optimization_steps), + )?; + + let quant_w = cfg.bits.bitpack_type()(wq)?.to_device(device)?; + + let this = Self { + w_q: quant_w, + zeros: zero.to_device(device)?, + scales: (1.0 / scale)?.to_device(device)?, + bias: None, + w_shape: input.shape().clone(), + cfg, + }; + Ok(this) + } +} + +mod test { + #[cfg(all(feature = "cuda", test))] + use candle_core::{Device, Result, Tensor}; + + #[cfg(all(feature = "cuda", test))] + #[test] + fn test_quantize_hqq() -> Result<()> { + use candle_core::DType; + + use crate::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; + + let dev = Device::new_cuda(0)?; + let data = Tensor::rand(0., 1., (3072, 3072), &dev)?.to_dtype(DType::F32)?; + let hqq = HqqLayer::quantize( + &data, + &dev, + HqqConfig { + bits: HqqBits::Two, + group_size: 64.try_into()?, + axis: HqqAxis::Zero, + optimization_steps: None, + round_zeros: false, + channel_wise: true, + }, + )?; + + let dequant = hqq.dequantize()?; + + dbg!(&(&dequant - &data)?.abs()?.mean_all()?); + Ok(()) + } +} diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 84bfa38bbb..d25aed363b 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -11,11 +11,13 @@ use candle_core::{ mod gguf; mod gptq; +mod hqq; mod unquantized; mod utils; pub use gguf::GgufMatMul; pub use gptq::GptqLayer; +pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; pub use unquantized::UnquantLinear; use candle_nn::{Linear, VarBuilder}; @@ -59,6 +61,16 @@ pub enum QuantMethodConfig { b: Option, }, Unquantized(Linear), + Hqq { + tensor: Tensor, + bits: HqqBits, + group_size: NonZeroUsize, + axis: HqqAxis, + optimization_steps: Option, + round_zeros: Option, + channel_wise: Option, + bias: Option, + }, } #[derive(Clone, Copy, Debug, PartialEq)] @@ -75,6 +87,11 @@ pub enum IsqType { Q5K, Q6K, Q8K, + HQQ8, + HQQ4, + // HQQ3, + // HQQ2, + // HQQ1, } impl TryFrom for GgmlDType { @@ -94,6 +111,7 @@ impl TryFrom for GgmlDType { IsqType::Q8K => Ok(Self::Q8K), IsqType::Q8_0 => Ok(Self::Q8_0), IsqType::Q8_1 => Ok(Self::Q8_1), + _ => candle_core::bail!("Expected valid GGML ISQ type."), } } } diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index 7a155c7c94..c1c85d34e5 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -6,7 +6,11 @@ use std::{ use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor}; use candle_nn::{Linear, Module}; -use crate::{generate_isq, GgufMatMul, IsqType, QuantMethod, QuantMethodConfig}; +use crate::{ + generate_isq, + hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE}, + GgufMatMul, IsqType, QuantMethod, QuantMethodConfig, +}; #[derive(Debug)] pub struct UnquantLinear(Linear); @@ -17,7 +21,9 @@ impl QuantMethod for UnquantLinear { Self: Sized, { match method { - QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Gptq { .. } => unreachable!(), + QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Hqq { .. 
} => unreachable!(), QuantMethodConfig::Unquantized(l) => Ok(Self(l)), } } @@ -52,6 +58,35 @@ impl QuantMethod for UnquantLinear { n_quantized: &AtomicUsize, ) -> Result> { match dtype { + /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */ + IsqType::HQQ4 | IsqType::HQQ8 => { + n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let bits = match dtype { + IsqType::HQQ8 => HqqBits::Eight, + IsqType::HQQ4 => HqqBits::Four, + // IsqType::HQQ3 => HqqBits::Three, + // IsqType::HQQ2 => HqqBits::Two, + // IsqType::HQQ1 => HqqBits::One, + _ => unreachable!(), + }; + let cfg = HqqConfig { + bits, + group_size: ISQ_HQQ_GROUP_SIZE.try_into()?, + axis: HqqAxis::Zero, + optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS, + round_zeros: false, + channel_wise: true, + }; + let res = HqqLayer::quantize(&self.0.weight().to_device(&device)?, &device, cfg)?; + if let Some(bias) = self.0.bias() { + let bias = bias + .to_device(&device)? + .to_dtype(res.dtype_and_device().0)?; + Ok(Arc::new(res.with_bias(bias))) + } else { + Ok(Arc::new(res)) + } + } IsqType::Q2K | IsqType::Q3K | IsqType::Q4K @@ -80,6 +115,11 @@ impl QuantMethod for UnquantLinear { fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option { match dtype { + /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */ + IsqType::HQQ4 | IsqType::HQQ8 => { + // Use 1 because our HQQ quantizes on the GPU + Some(1.try_into().unwrap()) + } IsqType::Q2K | IsqType::Q3K | IsqType::Q4K diff --git a/mistralrs-quant/src/utils/mod.rs b/mistralrs-quant/src/utils/mod.rs index 152a05343f..f871d8cdc5 100644 --- a/mistralrs-quant/src/utils/mod.rs +++ b/mistralrs-quant/src/utils/mod.rs @@ -1,6 +1,9 @@ #[cfg(feature = "cuda")] mod ffi; pub(crate) mod isq; +mod ops; + +pub use ops::{BitWiseOp, LeftshiftOp}; #[cfg(feature = "cuda")] use candle_core::{ diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs new file mode 100644 index 0000000000..c01e7303d5 --- /dev/null +++ b/mistralrs-quant/src/utils/ops.rs @@ -0,0 +1,385 @@ +use candle_core::{ + backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape, + Tensor, WithDType, +}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +use std::ops::{BitOr, Shl}; + +#[cfg(feature = "cuda")] +use crate::utils::ffi; +#[cfg(feature = "cuda")] +use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr}; +#[cfg(feature = "cuda")] +use std::ffi::c_void; + +struct BitWiseOr; + +impl BitWiseOr { + fn bitwise>(&self, vs1: &[T], vs2: &[T]) -> Vec { + vs1.into_par_iter() + .zip_eq(vs2) + .map(|(v1, v2)| *v1 | *v2) + .collect() + } +} + +impl CustomOp2 for BitWiseOr { + fn name(&self) -> &'static str { + "bitwise-or" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)> { + if l1.shape() != l2.shape() || l1.stride() != l2.stride() { + return Err(Error::ShapeMismatchBinaryOp { + lhs: l1.shape().clone(), + rhs: l2.shape().clone(), + op: "bitwise-or", + }); + } + if s1.dtype() != s2.dtype() { + return Err(Error::DTypeMismatchBinaryOp { + lhs: s1.dtype(), + rhs: s2.dtype(), + op: "bitwise-or", + }); + } + match s1 { + CpuStorage::U8(vs1) => { + let vs2 = &s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::U8(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or")), + CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, 
"bitwise-or")), + CpuStorage::I32(vs1) => { + let vs2 = &s2.as_slice::().unwrap(); + let result = self.bitwise(vs1, vs2); + let result = CpuStorage::I32(result); + Ok((result, l1.shape().clone())) + } + CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")), + CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")), + CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")), + CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")), + } + } + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &CudaStorage, + l1: &Layout, + s2: &CudaStorage, + l2: &Layout, + ) -> Result<(CudaStorage, Shape)> { + if l1.shape() != l2.shape() || l1.stride() != l2.stride() { + return Err(Error::ShapeMismatchBinaryOp { + lhs: l1.shape().clone(), + rhs: l2.shape().clone(), + op: "bitwise-or", + }); + } + if s1.dtype() != s2.dtype() { + return Err(Error::DTypeMismatchBinaryOp { + lhs: s1.dtype(), + rhs: s2.dtype(), + op: "bitwise-or", + }); + } + let dev = s1.device().clone(); + let (d_in1_ptr, d_in2_ptr, elem_count) = match s1.dtype() { + DType::U8 => { + let d_in1_ptr = *s1 + .as_cuda_slice::()? + .slice(l1.start_offset()..) + .device_ptr() as *const c_void; + let d_in2_ptr = *s2 + .as_cuda_slice::()? + .slice(l2.start_offset()..) + .device_ptr() as *const c_void; + let elem_count = l1.shape().elem_count(); + (d_in1_ptr, d_in2_ptr, elem_count) + } + DType::U32 => { + return Err(Error::UnsupportedDTypeForOp(DType::U32, "bitwise-or")); + } + DType::I64 => { + return Err(Error::UnsupportedDTypeForOp(DType::I64, "bitwise-or")); + } + DType::I32 => { + let d_in1_ptr = *s1 + .as_cuda_slice::()? + .slice(l1.start_offset()..) + .device_ptr() as *const c_void; + let d_in2_ptr = *s2 + .as_cuda_slice::()? + .slice(l2.start_offset()..) + .device_ptr() as *const c_void; + let elem_count = l1.shape().elem_count(); + (d_in1_ptr, d_in2_ptr, elem_count) + } + DType::BF16 => { + return Err(Error::UnsupportedDTypeForOp(DType::BF16, "bitwise-or")); + } + DType::F16 => { + return Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")); + } + DType::F32 => { + return Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")); + } + DType::F64 => { + return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")); + } + }; + let dst = match s1.dtype() { + DType::U8 => { + let d_out = unsafe { dev.alloc::(elem_count) }.w()?; + let d_out_ptr = *d_out.device_ptr() as *mut c_void; + unsafe { + ffi::mq_bitwise_or_u8( + d_in1_ptr, + d_in2_ptr, + d_out_ptr, + u32::try_from(elem_count)?, + ) + }; + CudaStorage::wrap_cuda_slice(d_out, dev) + } + DType::I32 => { + let d_out = unsafe { dev.alloc::(elem_count) }.w()?; + let d_out_ptr = *d_out.device_ptr() as *mut c_void; + unsafe { + ffi::mq_bitwise_or_i32( + d_in1_ptr, + d_in2_ptr, + d_out_ptr, + u32::try_from(elem_count)?, + ) + }; + CudaStorage::wrap_cuda_slice(d_out, dev) + } + _ => unreachable!(), + }; + Ok((dst, l1.shape().clone())) + } +} + +#[allow(dead_code)] +pub trait BitWiseOp { + fn bitwise_or(&self, rhs: &Tensor) -> Result; +} + +impl BitWiseOp for Tensor { + #[cfg(feature = "metal")] + fn bitwise_or(&self, rhs: &Tensor) -> Result { + let original_device = rhs.device(); + self.to_device(&candle_core::Device::Cpu)? + .apply_op2_no_bwd(&rhs.to_device(&candle_core::Device::Cpu)?, &BitWiseOr)? 
+            .to_device(original_device)
+    }
+    #[cfg(not(feature = "metal"))]
+    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
+        self.apply_op2_no_bwd(rhs, &BitWiseOr)
+    }
+}
+struct Leftshift(usize);
+
+impl Leftshift {
+    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
+        let offset = T::from_f64(self.0 as f64);
+        vs.into_par_iter().map(|v| *v << offset).collect()
+    }
+}
+
+impl CustomOp1 for Leftshift {
+    fn name(&self) -> &'static str {
+        "left"
+    }
+
+    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
+        match s1 {
+            CpuStorage::U8(vs1) => {
+                let result = self.leftshift(vs1);
+                let result = CpuStorage::U8(result);
+                Ok((result, l1.shape().clone()))
+            }
+            CpuStorage::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift")),
+            CpuStorage::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift")),
+            CpuStorage::I32(vs1) => {
+                let result = self.leftshift(vs1);
+                let result = CpuStorage::I32(result);
+                Ok((result, l1.shape().clone()))
+            }
+            CpuStorage::BF16(_) => Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift")),
+            CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift")),
+            CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift")),
+            CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift")),
+        }
+    }
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
+        let dev = s1.device().clone();
+        let (d_in1_ptr, elem_count) = match s1.dtype() {
+            DType::U8 => {
+                let d_in1_ptr = *s1
+                    .as_cuda_slice::<u8>()?
+                    .slice(l1.start_offset()..)
+                    .device_ptr() as *const c_void;
+                let elem_count = l1.shape().elem_count();
+                (d_in1_ptr, elem_count)
+            }
+            DType::U32 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::U32, "leftshift"));
+            }
+            DType::I64 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::I64, "leftshift"));
+            }
+            DType::I32 => {
+                let d_in1_ptr = *s1
+                    .as_cuda_slice::<i32>()?
+                    .slice(l1.start_offset()..)
+                    .device_ptr() as *const c_void;
+                let elem_count = l1.shape().elem_count();
+                (d_in1_ptr, elem_count)
+            }
+            DType::BF16 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::BF16, "leftshift"));
+            }
+            DType::F16 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshift"));
+            }
+            DType::F32 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshift"));
+            }
+            DType::F64 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
+            }
+        };
+        let dst = match s1.dtype() {
+            DType::U8 => {
+                let d_out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
+                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
+                unsafe {
+                    ffi::mq_leftshift_u8(
+                        d_in1_ptr,
+                        d_out_ptr,
+                        u32::try_from(elem_count)?,
+                        self.0 as i32,
+                    )
+                };
+                CudaStorage::wrap_cuda_slice(d_out, dev)
+            }
+            DType::I32 => {
+                let d_out = unsafe { dev.alloc::<i32>(elem_count) }.w()?;
+                let d_out_ptr = *d_out.device_ptr() as *mut c_void;
+                unsafe {
+                    ffi::mq_leftshift_i32(
+                        d_in1_ptr,
+                        d_out_ptr,
+                        u32::try_from(elem_count)?,
+                        self.0 as i32,
+                    )
+                };
+                CudaStorage::wrap_cuda_slice(d_out, dev)
+            }
+            _ => unreachable!(),
+        };
+        Ok((dst, l1.shape().clone()))
+    }
+}
+
+#[allow(dead_code)]
+pub trait LeftshiftOp {
+    fn leftshift(&self, n: usize) -> Result<Tensor>;
+}
+
+impl LeftshiftOp for Tensor {
+    #[cfg(feature = "metal")]
+    fn leftshift(&self, n: usize) -> Result<Tensor> {
+        let original_device = self.device();
+        self.to_device(&candle_core::Device::Cpu)?
+            .apply_op1_no_bwd(&Leftshift(n))?
+            .to_device(original_device)
+    }
+    #[cfg(not(feature = "metal"))]
+    fn leftshift(&self, n: usize) -> Result<Tensor> {
+        self.apply_op1_no_bwd(&Leftshift(n))
+    }
+}
+
+mod tests {
+    #[test]
+    fn test_bitwise_or_cpu() {
+        use crate::utils::ops::BitWiseOp;
+        use candle_core::Tensor;
+        let device = candle_core::Device::Cpu;
+        let a =
+            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
+        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
+        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
+        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
+    }
+
+    #[cfg(feature = "cuda")]
+    #[test]
+    fn test_bitwise_or_cuda() {
+        use crate::utils::ops::BitWiseOp;
+        use candle_core::Tensor;
+        let device = candle_core::Device::new_cuda(0).unwrap();
+        let a =
+            Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
+        let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
+        let c = a.bitwise_or(&b).unwrap().to_vec2::<i32>().unwrap();
+        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
+    }
+
+    #[test]
+    fn test_leftshift_cpu() {
+        use crate::utils::ops::LeftshiftOp;
+        use candle_core::Tensor;
+        let device = candle_core::Device::Cpu;
+        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
+        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
+        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
+    }
+
+    #[cfg(feature = "cuda")]
+    #[test]
+    fn test_leftshift_cuda() {
+        use crate::utils::ops::LeftshiftOp;
+        use candle_core::Tensor;
+        let device = candle_core::Device::new_cuda(0).unwrap();
+        let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
+        let c = a.leftshift(2).unwrap().to_vec2::<i32>().unwrap();
+        assert_eq!(c, [[4, 8], [12, 16], [20, 24]]);
+    }
+
+    #[cfg(feature = "cuda")]
+    #[test]
+    fn test_bitwise_or_and_leftshift_cuda() {
+        use crate::utils::{ops::BitWiseOp, LeftshiftOp};
+        use candle_core::Tensor;
+        let device = candle_core::Device::new_cuda(0).unwrap();
+        let a = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
+        let b = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap();
+        let c = a
+            .leftshift(4)
+            .unwrap()
+            .bitwise_or(&b)
+            .unwrap()
+            .to_vec1::<u8>()
+            .unwrap();
+        let av = a.to_vec1::<u8>().unwrap();
+        let bv = b.to_vec1::<u8>().unwrap();
+        assert_eq!(av, [0b00001111]);
+        assert_eq!(bv, [0b00001111]);
+        assert_eq!(c, [0b11111111]);
+    }
+}
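
For context, here is a minimal in-crate usage sketch of the two new tensor extension traits. The `pack_nibbles` helper is illustrative only and not part of this diff; it simply combines the shift-and-OR primitive exercised by the last test above, which is the kind of bit packing the HQQ code needs these ops for.

```rust
use candle_core::{Device, Result, Tensor};

// Same in-crate paths the tests above use.
use crate::utils::ops::{BitWiseOp, LeftshiftOp};

/// Hypothetical helper: pack two 4-bit values (0..16) per element into one u8
/// by shifting the high nibble left by 4 and OR-ing in the low nibble.
fn pack_nibbles(hi: &Tensor, lo: &Tensor) -> Result<Tensor> {
    hi.leftshift(4)?.bitwise_or(lo)
}

fn demo() -> Result<()> {
    let device = Device::Cpu;
    let hi = Tensor::from_vec(vec![0b0000_1111u8], (1,), &device)?;
    let lo = Tensor::from_vec(vec![0b0000_1111u8], (1,), &device)?;
    // 0b0000_1111 << 4 = 0b1111_0000; OR with 0b0000_1111 gives 0b1111_1111.
    assert_eq!(pack_nibbles(&hi, &lo)?.to_vec1::<u8>()?, [0b1111_1111]);
    Ok(())
}
```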