diff --git a/nvidia/common/src/paged_attention.cuh b/nvidia/common/src/paged_attention.cuh
new file mode 100644
index 00000000..ab062ccf
--- /dev/null
+++ b/nvidia/common/src/paged_attention.cuh
@@ -0,0 +1,176 @@
+#pragma once
+
+#include <float.h>
+#include <stdint.h>
+
+// FP16 vector types for Q, K, V.
+template <>
+struct Vec<uint16_t, 1> {
+    using Type = uint16_t;
+};
+template <>
+struct Vec<uint16_t, 2> {
+    using Type = uint32_t;
+};
+template <>
+struct Vec<uint16_t, 4> {
+    using Type = uint2;
+};
+template <>
+struct Vec<uint16_t, 8> {
+    using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template <>
+struct FloatVec<uint16_t> {
+    using Type = float;
+};
+template <>
+struct FloatVec<uint32_t> {
+    using Type = float2;
+};
+template <>
+struct FloatVec<uint2> {
+    using Type = Float4_;
+};
+template <>
+struct FloatVec<uint4> {
+    using Type = Float8_;
+};
+
+// From float32 to float16.
+inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); }
+
+inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }
+
+inline __device__ void from_float(uint2& dst, Float4_ src) {
+    dst.x = float2_to_half2(src.x);
+    dst.y = float2_to_half2(src.y);
+}
+
+inline __device__ void from_float(uint4& dst, Float8_ src) {
+    dst.x = float2_to_half2(src.x);
+    dst.y = float2_to_half2(src.y);
+    dst.z = float2_to_half2(src.z);
+    dst.w = float2_to_half2(src.w);
+}
+
+// Vector fused multiply-add on packed fp16.
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+    uint32_t d;
+    // `v_pk_fma_f16` is AMD GCN assembly; on NVIDIA GPUs the packed-half FMA is PTX `fma.rn.f16x2`.
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+    return d;
+}
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); }
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
+    uint2 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    return d;
+}
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
+    uint32_t s = h0_h0(a);
+    uint2 d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    return d;
+}
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
+    uint4 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    d.z = fma(a.z, b.z, c.z);
+    d.w = fma(a.w, b.w, c.w);
+    return d;
+}
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
+    uint32_t s = h0_h0(a);
+    uint4 d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    d.z = fma(s, b.z, c.z);
+    d.w = fma(s, b.w, c.w);
+    return d;
+}
+
+inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
+    float fa = half_to_float(a);
+    float fb = half_to_float(b);
+    return fa * fb + fc;
+}
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
+    float2 fa = half2_to_float2(a);
+    float2 fb = half2_to_float2(b);
+    return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); }
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
+    Float4_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    return fd;
+}
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
+    uint32_t s = h0_h0(a);
+    Float4_ fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    return fd;
+}
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
+    Float8_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    fd.z = fma(a.z, b.z, fc.z);
+    fd.w = fma(a.w, b.w, fc.w);
+    return fd;
+}
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
+    uint32_t s = h0_h0(a);
+    Float8_ fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    fd.z = fma(s, b.z, fc.z);
+    fd.w = fma(s, b.w, fc.w);
+    return fd;
+}
+
+// Q*K^T operation.
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+    using A_vec = typename FloatVec<Vec>::Type;
+    // Compute the parallel products for Q*K^T (treat vector lanes separately).
+    A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+    for (int ii = 1; ii < N; ++ii) {
+        qk_vec = fma(q[ii], k[ii], qk_vec);
+    }
+
+    // Finalize the reduction across lanes of the thread group.
+    float qk = sum(qk_vec);
+#pragma unroll
+    for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+        qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+    }
+    return qk;
+}
+
+template <typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+    template <typename Vec, int N>
+    static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+        return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+    }
+};
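
Note: the header above leans on shared attention helpers that are not part of this diff: the primary `Vec` / `FloatVec` templates, the `Float4_` / `Float8_` accumulator structs, the half conversions (`float_to_half`, `float2_to_half2`, `half_to_float`, `half2_to_float2`, `h0_h0`), a `float2` overload of `fma`, the generic `mul` / `sum` / `dot` / `zero` utilities, and the `block_sum` reduction used by the generated kernel below. The sketch that follows shows one way these can look on NVIDIA GPUs; names mirror the usage above, and it is an illustration only, not the project's actual shared header.

    #include <stdint.h>

    #ifndef WARP_SIZE
    #define WARP_SIZE 32
    #endif

    // Primary templates specialized in paged_attention.cuh.
    template <typename T, int VEC_SIZE> struct Vec {};
    template <typename T> struct FloatVec {};

    // FP32 accumulators for packed fp16 vectors.
    struct Float4_ { float2 x, y; };
    struct Float8_ { float2 x, y, z, w; };

    // Scalar half <-> float conversions via PTX.
    inline __device__ uint16_t float_to_half(float f) {
        uint16_t h;
        asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h) : "f"(f));
        return h;
    }
    inline __device__ float half_to_float(uint16_t h) {
        float f;
        asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
        return f;
    }

    // Packed half2 <-> float2 conversions, plus a broadcast of one half to both lanes.
    inline __device__ uint32_t float2_to_half2(float2 f) {
        union { uint32_t u32; uint16_t u16[2]; } t;
        t.u16[0] = float_to_half(f.x);
        t.u16[1] = float_to_half(f.y);
        return t.u32;
    }
    inline __device__ float2 half2_to_float2(uint32_t v) {
        union { uint32_t u32; uint16_t u16[2]; } t;
        t.u32 = v;
        return make_float2(half_to_float(t.u16[0]), half_to_float(t.u16[1]));
    }
    inline __device__ uint32_t h0_h0(uint16_t a) {
        uint32_t b;
        asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
        return b;
    }

    // float2 FMA used by the mixed-precision overloads in the header.
    inline __device__ float2 fma(float2 a, float2 b, float2 c) {
        return make_float2(a.x * b.x + c.x, a.y * b.y + c.y);
    }

    // Block-wide sum: reduce within each warp, stage per-warp sums in shared memory,
    // reduce those, then broadcast lane 0's result to the whole block.
    template <int NUM_WARPS>
    inline __device__ float block_sum(float* red_smem, float sum) {
        const int warp = threadIdx.x / WARP_SIZE;
        const int lane = threadIdx.x % WARP_SIZE;
    #pragma unroll
        for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
            sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
        }
        if (lane == 0) {
            red_smem[warp] = sum;
        }
        __syncthreads();
        if (lane < NUM_WARPS) {
            sum = red_smem[lane];
        }
    #pragma unroll
        for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
            sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
        }
        return __shfl_sync(uint32_t(-1), sum, 0);
    }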
diff --git a/nvidia/common/src/paged_attention.rs b/nvidia/common/src/paged_attention.rs
new file mode 100644
index 00000000..14e4d349
--- /dev/null
+++ b/nvidia/common/src/paged_attention.rs
@@ -0,0 +1,366 @@
+use cuda::{
+    bindings::CUdeviceptr, AsRaw, ContextGuard, ContextResource, ContextSpore, CudaDataType,
+    DevSlice, ModuleSpore, Ptx, Stream,
+};
+use log::warn;
+use std::{
+    ffi::{c_int, c_uint, c_void, CString},
+    ops::{Deref, DerefMut},
+};
+use tensor::{udim, Tensor};
+
+/// Warp width on NVIDIA GPUs; mirrored by `WARP_SIZE` in the generated kernel source.
+const WARP_SIZE: usize = 32;
+
+pub struct PagedAttention {
+    module: ModuleSpore,
+    kernel: CString,
+    block_size: c_uint,
+    num_threads: c_uint,
+}
+
+impl PagedAttention {
+    pub fn new(
+        head_size: usize,
+        block_size: usize,
+        num_threads: usize,
+        scalar_t: CudaDataType,
+        cache_t: CudaDataType,
+        ctx: &ContextGuard,
+    ) -> Self {
+        let scalar_t = scalar_t.name();
+        let cache_t = cache_t.name();
+
+        const PAGED_ATTENTION: &str = include_str!("paged_attention.cuh");
+        let code = format!(
+            r#"{PAGED_ATTENTION}
+
+#define WARP_SIZE 32
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+extern "C" __global__ void paged_attention_kernel(
+    {scalar_t}* __restrict__ out, const {scalar_t}* __restrict__ q,
+    const {cache_t}* __restrict__ k_cache, const {cache_t}* __restrict__ v_cache,
+    const int num_kv_heads, const float scale, const int* __restrict__ block_tables,
+    const int* __restrict__ past_seq_lens, const int max_num_blocks_per_seq,
+    const int q_stride, const int kv_block_stride, const int kv_head_stride) {{
+    const int seq_idx = blockIdx.y;
+    const int partition_idx = blockIdx.z;
+    const int max_num_partitions = gridDim.z;
+    const int past_seq_len = past_seq_lens[seq_idx];
+
+    const int num_seq_blocks = DIVIDE_ROUND_UP(past_seq_len, {block_size});
+
+    // No sequence partitioning yet: every thread block covers the whole sequence.
+    const int start_block_idx = 0;
+    const int end_block_idx = num_seq_blocks;
+    const int start_token_idx = start_block_idx * {block_size};
+    const int num_tokens = past_seq_len - start_token_idx;
+
+    constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / {block_size}, 1);
+    constexpr int NUM_THREAD_GROUPS = {num_threads} / THREAD_GROUP_SIZE;
+    static_assert({num_threads} % THREAD_GROUP_SIZE == 0, "num_threads must be a multiple of THREAD_GROUP_SIZE");
+    constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP({block_size}, WARP_SIZE);
+    constexpr int NUM_WARPS = {num_threads} / WARP_SIZE;
+    const int thread_idx = threadIdx.x;
+    const int warp_idx = thread_idx / WARP_SIZE;
+    const int lane = thread_idx % WARP_SIZE;
+
+    const int head_idx = blockIdx.x;
+    const int num_heads = gridDim.x;
+    const int num_queries_per_kv = num_heads / num_kv_heads;
+    const int kv_head_idx = head_idx / num_queries_per_kv;
+
+    constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof({scalar_t})), 1);
+    using K_vec = typename Vec<{scalar_t}, VEC_SIZE>::Type;
+    using Q_vec = typename Vec<{scalar_t}, VEC_SIZE>::Type;
+
+    constexpr int NUM_ELEMS_PER_THREAD = {head_size} / THREAD_GROUP_SIZE;
+    constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+    const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+    const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+    // Load this sequence's query head into shared memory, one vector per thread group.
+    const {scalar_t}* q_ptr = q + seq_idx * q_stride + head_idx * {head_size};
+
+    __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+
+    #pragma unroll
+    for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {{
+        const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+        q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    }}
+    __syncthreads();
+
+    // Shared memory for the softmax logits and the cross-warp reductions.
+    extern __shared__ char shared_mem[];
+    float* logits = reinterpret_cast<float*>(shared_mem);
+    __shared__ float red_smem[2 * NUM_WARPS];
+
+    // The key cache packs the head dimension in chunks of x elements (16 bytes each).
+    constexpr int x = 16 / sizeof({cache_t});
+    float qk_max = -FLT_MAX;
+
+    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {{
+        const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {{
+            const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % {block_size};
+            const int token_idx = block_idx * {block_size} + physical_block_offset;
+            K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+            #pragma unroll
+            for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {{
+                const {cache_t}* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                       + kv_head_idx * kv_head_stride + physical_block_offset * x;
+                const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+                const int offset1 = (vec_idx * VEC_SIZE) / x;
+                const int offset2 = (vec_idx * VEC_SIZE) % x;
+                k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * {block_size} * x + offset2);
+            }}
+
+            // QK: per-thread-group dot product of the query with this key.
+            float qk = scale * Qk_dot<{scalar_t}, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+
+            // Softmax: store the raw logit and track the running maximum.
+            if (thread_group_offset == 0) {{
+                const bool mask = token_idx >= past_seq_len;
+                logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+                qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+            }}
+        }}
+    }}
+
+    // Reduce qk_max within each warp, then across warps via shared memory.
+    #pragma unroll
+    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {{
+        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    }}
+    if (lane == 0) {{
+        red_smem[warp_idx] = qk_max;
+    }}
+    __syncthreads();
+
+    qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+    #pragma unroll
+    for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {{
+        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    }}
+    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+
+    // Exponentiate and normalize the logits.
+    float exp_sum = 0.f;
+    for (int i = thread_idx; i < num_tokens; i += {num_threads}) {{
+        float val = __expf(logits[i] - qk_max);
+        logits[i] = val;
+        exp_sum += val;
+    }}
+    exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+    const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+    for (int i = thread_idx; i < num_tokens; i += {num_threads}) {{
+        logits[i] *= inv_sum;
+    }}
+    __syncthreads();
+
+    // Value: each thread accumulates a few rows of the output head.
+    constexpr int V_VEC_SIZE = MIN(16 / sizeof({scalar_t}), {block_size});
+    using V_vec = typename Vec<{scalar_t}, V_VEC_SIZE>::Type;
+    using L_vec = typename Vec<{scalar_t}, V_VEC_SIZE>::Type;
+    using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+    constexpr int NUM_V_VECS_PER_ROW = {block_size} / V_VEC_SIZE;
+    constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+    constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP({head_size}, NUM_ROWS_PER_ITER);
+
+    float accs[NUM_ROWS_PER_THREAD];
+    #pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {{
+        accs[i] = 0.f;
+    }}
+
+    {scalar_t} zero_value;
+    zero(zero_value);
+    for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {{
+        const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+        const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+        const int token_idx = block_idx * {block_size} + physical_block_offset;
+        L_vec logits_vec;
+        from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+
+        const {cache_t}* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride;
+
+        #pragma unroll
+        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {{
+            const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+            if (row_idx < {head_size}) {{
+                const int offset = row_idx * {block_size} + physical_block_offset;
+                V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+                if (block_idx == num_seq_blocks - 1) {{
+                    // Zero out values that lie beyond the end of the sequence.
+                    {scalar_t}* v_vec_ptr = reinterpret_cast<{scalar_t}*>(&v_vec);
+                    #pragma unroll
+                    for (int j = 0; j < V_VEC_SIZE; j++) {{
+                        v_vec_ptr[j] = token_idx + j < past_seq_len ? v_vec_ptr[j] : zero_value;
+                    }}
+                }}
+                accs[i] += dot(logits_vec, v_vec);
+            }}
+        }}
+    }}
+
+    // LV: reduce the partial sums within each row across the warp.
+    #pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {{
+        float acc = accs[i];
+        #pragma unroll
+        for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {{
+            acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+        }}
+        accs[i] = acc;
+    }}
+
+    __syncthreads();
+
+    // Tree-reduce the per-warp outputs through shared memory.
+    float* out_smem = reinterpret_cast<float*>(shared_mem);
+    #pragma unroll
+    for (int i = NUM_WARPS; i > 1; i /= 2) {{
+        int mid = i / 2;
+        // The upper half of the warps writes its rows to shared memory.
+        if (warp_idx >= mid && warp_idx < i) {{
+            float* dst = &out_smem[(warp_idx - mid) * {head_size}];
+            #pragma unroll
+            for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) {{
+                const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER;
+                if (row_idx < {head_size} && lane % NUM_V_VECS_PER_ROW == 0) {{
+                    dst[row_idx] = accs[j];
+                }}
+            }}
+        }}
+        __syncthreads();
+
+        // The lower half accumulates what the upper half wrote.
+        if (warp_idx < mid) {{
+            const float* src = &out_smem[warp_idx * {head_size}];
+            #pragma unroll
+            for (int j = 0; j < NUM_ROWS_PER_THREAD; j++) {{
+                const int row_idx = lane / NUM_V_VECS_PER_ROW + j * NUM_ROWS_PER_ITER;
+                if (row_idx < {head_size} && lane % NUM_V_VECS_PER_ROW == 0) {{
+                    accs[j] += src[row_idx];
+                }}
+            }}
+        }}
+        __syncthreads();
+    }}
+
+    // Output: warp 0 now holds the final rows and writes them back.
+    if (warp_idx == 0) {{
+        {scalar_t}* out_ptr = out + seq_idx * num_heads * max_num_partitions * {head_size}
+                            + head_idx * max_num_partitions * {head_size} + partition_idx * {head_size};
+        #pragma unroll
+        for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {{
+            const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+            if (row_idx < {head_size} && lane % NUM_V_VECS_PER_ROW == 0) {{
+                from_float(*(out_ptr + row_idx), accs[i]);
+            }}
+        }}
+    }}
+}}
+"#
+        );
+
+        let (ptx, log) = Ptx::compile(code);
+        if !log.is_empty() {
+            warn!("{log}");
+        }
+        Self {
+            module: ctx.load(&ptx.unwrap()).sporulate(),
+            kernel: CString::new("paged_attention_kernel").unwrap(),
+            block_size: block_size as _,
+            num_threads: num_threads as _,
+        }
+    }
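
The k/v indexing inside the generated kernel above assumes a vLLM-style paged KV cache: k_cache laid out as [num_blocks, num_kv_heads, head_size / x, block_size, x] with x = 16 / sizeof(cache_t), and v_cache as [num_blocks, num_kv_heads, head_size, block_size]. The sketch below (plain Rust, illustrative names, not part of the diff) spells out the flat element index that the kernel's `offset1` / `offset2` arithmetic computes, under that assumed layout.

    // Illustrative only: flat index of element `dim` (0..head_size) of the key for the
    // token in slot `slot` (0..block_size) of paged block `block`, kv head `kv_head`.
    // Mirrors `k_ptr + offset1 * block_size * x + offset2` in the kernel.
    fn k_cache_index(
        block: usize, kv_head: usize, dim: usize, slot: usize,
        num_kv_heads: usize, head_size: usize, block_size: usize,
        elem_bytes: usize, // sizeof(cache_t), e.g. 2 for fp16
    ) -> usize {
        let x = 16 / elem_bytes; // elements per 16-byte chunk of the head dimension
        let kv_head_stride = head_size * block_size;
        let kv_block_stride = num_kv_heads * kv_head_stride;
        block * kv_block_stride
            + kv_head * kv_head_stride
            + (dim / x) * block_size * x // offset1 * block_size * x
            + slot * x                   // physical_block_offset * x
            + dim % x                    // offset2
    }

    // The value cache is not chunked: v_cache[block][kv_head][dim][slot].
    fn v_cache_index(
        block: usize, kv_head: usize, dim: usize, slot: usize,
        num_kv_heads: usize, head_size: usize, block_size: usize,
    ) -> usize {
        let kv_head_stride = head_size * block_size;
        (block * num_kv_heads + kv_head) * kv_head_stride + dim * block_size + slot
    }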
+
+    pub fn launch<OUT_T, Q_T, K_T, V_T, BLOCK_TABLES_T, SEQ_LENS_T>(
+        &self,
+        out: &mut Tensor<OUT_T>,
+        query: &Tensor<Q_T>,
+        key_cache: &Tensor<K_T>,
+        value_cache: &Tensor<V_T>,
+        num_kv_heads: c_int,
+        scale: f32,
+        block_tables: &Tensor<BLOCK_TABLES_T>,
+        seq_lens: &Tensor<SEQ_LENS_T>,
+        max_seq_len: usize,
+        stream: &Stream,
+    ) where
+        OUT_T: DerefMut<Target = DevSlice>,
+        Q_T: Deref<Target = DevSlice>,
+        K_T: Deref<Target = DevSlice>,
+        V_T: Deref<Target = DevSlice>,
+        BLOCK_TABLES_T: Deref<Target = DevSlice>,
+        SEQ_LENS_T: Deref<Target = DevSlice>,
+    {
+        let num_seqs = query.shape()[0];
+        let num_heads = query.shape()[1];
+        let head_size = query.shape()[2] as usize;
+        let max_num_blocks_per_seq = block_tables.shape()[1] as c_int;
+        let q_stride = query.stride()[0] as c_int;
+        let kv_block_stride = key_cache.stride()[0] as c_int;
+        let kv_head_stride = key_cache.stride()[1] as c_int;
+
+        let block_size = self.block_size as usize;
+        let num_threads = self.num_threads as usize;
+        let thread_group_size = (WARP_SIZE / block_size).max(1);
+        assert_eq!(head_size % thread_group_size, 0);
+
+        let out_ptr = (out.physical().as_ptr() as isize + out.bytes_offset()) as CUdeviceptr;
+        let query_ptr = (query.physical().as_ptr() as isize + query.bytes_offset()) as CUdeviceptr;
+        let key_cache_ptr = (key_cache.physical().as_ptr() as isize + key_cache.bytes_offset()) as CUdeviceptr;
+        let value_cache_ptr = (value_cache.physical().as_ptr() as isize + value_cache.bytes_offset()) as CUdeviceptr;
+        let block_tables_ptr = (block_tables.physical().as_ptr() as isize + block_tables.bytes_offset()) as CUdeviceptr;
+        let seq_lens_ptr = (seq_lens.physical().as_ptr() as isize + seq_lens.bytes_offset()) as CUdeviceptr;
+
+        // Dynamic shared memory: the larger of the softmax logits buffer and the
+        // cross-warp output staging buffer.
+        let num_warps = num_threads / WARP_SIZE;
+        let padded_max_seq_len = (max_seq_len + block_size - 1) / block_size * block_size;
+        let logits_size = padded_max_seq_len * std::mem::size_of::<f32>();
+        let outputs_size = num_warps / 2 * head_size * std::mem::size_of::<f32>();
+        let shared_mem_size = logits_size.max(outputs_size);
+
+        let grid = (num_heads, num_seqs, 1);
+        let block = (self.num_threads, 1, 1);
+
+        // Parameter order must match the kernel signature generated in `new`.
+        let params: [*const c_void; 12] = [
+            (&out_ptr) as *const _ as _,
+            (&query_ptr) as *const _ as _,
+            (&key_cache_ptr) as *const _ as _,
+            (&value_cache_ptr) as *const _ as _,
+            (&num_kv_heads) as *const _ as _,
+            (&scale) as *const _ as _,
+            (&block_tables_ptr) as *const _ as _,
+            (&seq_lens_ptr) as *const _ as _,
+            (&max_num_blocks_per_seq) as *const _ as _,
+            (&q_stride) as *const _ as _,
+            (&kv_block_stride) as *const _ as _,
+            (&kv_head_stride) as *const _ as _,
+        ];
+
+        // To use more than 48 KiB of shared memory, the kernel must take it as dynamic
+        // shared memory and CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES must be
+        // raised on the function beforehand (not done here yet).
+        let module = unsafe { self.module.sprout(stream.ctx()) };
+        let kernel = module.get_kernel(&self.kernel);
+        kernel.launch(grid, block, params.as_ptr(), shared_mem_size, Some(stream));
+
+        // TODO: dispatch on head_size (64 / 80 / 96 / 112 / 128 / 256) as vLLM's
+        // LAUNCH_PAGED_ATTENTION_V1 does, once kernels for several head sizes are built.
+    }
+
+    #[inline]
+    pub fn kill(&mut self, ctx: &ContextGuard) {
+        unsafe { self.module.kill(ctx) };
+    }
+}
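
For a concrete feel for the constants the generated kernel derives, here is the arithmetic for one assumed configuration (fp16 cache, head_size = 128, block_size = 16, num_threads = 128); the numbers are just the MAX / MIN / DIVIDE_ROUND_UP expressions above evaluated by hand, not additional functionality.

    fn main() {
        // Assumed configuration: fp16 elements, 16-token blocks, 128 threads per CTA.
        let (warp_size, block_size, num_threads, head_size, elem_bytes) = (32usize, 16, 128, 128, 2);

        let thread_group_size = (warp_size / block_size).max(1);                // 2 threads per group
        let vec_size = (16 / (thread_group_size * elem_bytes)).max(1);          // 4 halves (8 B) per load
        let num_elems_per_thread = head_size / thread_group_size;               // 64
        let num_vecs_per_thread = num_elems_per_thread / vec_size;              // 16
        let num_thread_groups = num_threads / thread_group_size;                // 64
        let tokens_per_thread_group = (block_size + warp_size - 1) / warp_size; // 1
        let v_vec_size = (16 / elem_bytes).min(block_size);                     // 8
        let num_v_vecs_per_row = block_size / v_vec_size;                       // 2
        let num_rows_per_iter = warp_size / num_v_vecs_per_row;                 // 16
        let num_rows_per_thread = (head_size + num_rows_per_iter - 1) / num_rows_per_iter; // 8

        println!(
            "group={thread_group_size} vec={vec_size} vecs/thread={num_vecs_per_thread} \
             groups={num_thread_groups} tok/group={tokens_per_thread_group} \
             v_vec={v_vec_size} rows/thread={num_rows_per_thread}"
        );
    }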