From 6b5568aae1b14fdf0562f87cc1af02785873ea1e Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Mon, 3 Feb 2025 17:51:05 -0800
Subject: [PATCH] k_norm in rope for fp8 kv cache (#3633)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/709

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3633

Zero init of dequantized kv in prefill (D68574038), prefill FA3:

GQA decode (USE_MQA_ATTN=1):
Evaluation results on task nq.0_shot: em: 0.146814 | f1: 0.282832

Triton Split-k:
Evaluation results on task nq.0_shot: em: 0.147368 | f1: 0.284228

BF16, prefill FA3:
Evaluation results on task nq.0_shot: em: 0.148753 | f1: 0.285131

Reviewed By: jianyuh

Differential Revision: D68815109

fbshipit-source-id: d1da9baaf5fd34b705b793173c2a1e6c04d7901d
---
 .../gen_ai/src/kv_cache/kv_cache.cpp        | 16 ++--
 .../gen_ai/src/kv_cache/kv_cache.cu         | 77 ++++++++++++-------
 .../include/fbgemm_gpu/utils/vec_quant.cuh  | 11 +++
 3 files changed, 72 insertions(+), 32 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
index 5d97e0e994..14b5f81a59 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -72,7 +72,8 @@ at::Tensor rope_qkv_varseq_prefill(
     double hi_freq_factor,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool write_k_back);
+    bool write_k_back,
+    bool k_rms_norm);
 
 at::Tensor rope_qkv_decoding(
     at::Tensor XQ,
@@ -95,7 +96,8 @@ at::Tensor rope_qkv_decoding(
     double lo_freq_factor,
     double hi_freq_factor,
     std::optional<at::Tensor> qparam_k,
-    std::optional<at::Tensor> qparam_v);
+    std::optional<at::Tensor> qparam_v,
+    bool k_rms_norm);
 
 at::Tensor xpos_qkv_varseq_prefill(
     at::Tensor XQ,
@@ -175,9 +177,9 @@ at::Tensor mqa_attn(
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!) XK, Tensor XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
-      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False) -> Tensor");
+      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_rms_norm=False) -> Tensor");
   m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_rms_norm=False) -> Tensor");
   m.def(
      "nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
          DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None) -> Tensor");
@@ -237,7 +239,8 @@ at::Tensor rope_qkv_varseq_prefill_meta(
     double /* hi_freq_factor */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* write_k_back */
+    bool /* write_k_back */,
+    bool /* k_rms_norm */
 ) {
   return at::empty_like(XQ);
 }
@@ -263,7 +266,8 @@ at::Tensor rope_qkv_decoding_meta(
     double /* lo_freq_factor */,
     double /* hi_freq_factor */,
     std::optional<at::Tensor> /* qparam_k */,
-    std::optional<at::Tensor> /* qparam_v */
+    std::optional<at::Tensor> /* qparam_v */,
+    bool /* k_rms_norm */
 ) {
   return at::empty_like(XQ);
 }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
index e282730ddc..cebd985712 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -210,8 +210,11 @@ __device__ void get_dst_row(
 enum class PositionEmbeddingMode { ROPE = 0, XPOS = 1 };
 enum class QKV { Q, K, V };
 
-DEVICE_INLINE void
-quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam = nullptr);
+DEVICE_INLINE void quantize_fp8_kv(
+    fx4 dst,
+    uint8_t* dst_row_q,
+    __half2* qparam = nullptr,
+    bool do_rms_norm = false);
 
 __global__ void nope_qkv_varseq_prefill_kernel(
     at::PackedTensorAccessor32
@@ -711,7 +714,8 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       old_context_len, \
       scaling_factor, \
       lo_freq_factor, \
-      hi_freq_factor) \
+      hi_freq_factor, \
+      k_rms_norm) \
   rope_xpos_qkv_varseq_prefill_kernel_ \
       <<>>( \
           XQ.packed_accessor32(), \
@@ -737,7 +741,8 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       old_context_len, \
       scaling_factor, \
       lo_freq_factor, \
-      hi_freq_factor);
+      hi_freq_factor, \
+      k_rms_norm);
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
@@ -799,7 +804,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
     int64_t old_context_len = 8192,
     double scaling_factor = 16,
     double lo_freq_factor = 1,
-    double hi_freq_factor = 32) {
+    double hi_freq_factor = 32,
+    bool k_rms_norm = false) {
   // Launch b_t_(sum(h)) warps.
   auto b_t_hh = blockIdx.x * blockDim.y +
       threadIdx.y; // Block = [kThreadsPerWarp, kWarpsPerBlock]
@@ -904,7 +910,7 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
       if (kCacheDtype == CacheLogicalDtype::FP8) {
         if (qparam_k_ptr == nullptr) {
           CUDA_KERNEL_ASSERT(D_H_q - D_H == 4);
-          quantize_fp8_kv(dst, dst_row_q);
+          quantize_fp8_kv(dst, dst_row_q, nullptr, (qkv == QKV::K && k_rms_norm));
         } else {
           __half2* qparam_row = nullptr;
           auto T = cache_K.size(1);
@@ -930,7 +936,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
             qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
           }
         }
-        quantize_fp8_kv(dst, dst_row_q, qparam_row);
+        quantize_fp8_kv(
+            dst, dst_row_q, qparam_row, (qkv == QKV::K && k_rms_norm));
       }
     } else if (kCacheDtype == CacheLogicalDtype::INT4) {
@@ -1087,7 +1094,8 @@ at::Tensor rope_qkv_varseq_prefill(
     double hi_freq_factor = 32,
     std::optional<at::Tensor> qparam_k = {},
     std::optional<at::Tensor> qparam_v = {},
-    bool write_k_back = false) {
+    bool write_k_back = false,
+    bool k_rms_norm = false) {
   auto B_T = XQ.size(0);
   auto N_H = XQ.size(1);
   auto N_KVH = XK.size(1);
@@ -1181,7 +1189,8 @@ at::Tensor rope_qkv_varseq_prefill(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        k_rms_norm);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
     throw std::runtime_error("CUDA version is older than 12.0");
@@ -1208,7 +1217,8 @@ at::Tensor rope_qkv_varseq_prefill(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
@@ -1293,8 +1303,7 @@ at::Tensor xpos_qkv_varseq_prefill(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor,
-        false);
+        hi_freq_factor);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     auto num_groups_ = num_groups ? num_groups.value() : 1;
@@ -1331,7 +1340,8 @@ at::Tensor xpos_qkv_varseq_prefill(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1359,7 +1369,8 @@ at::Tensor xpos_qkv_varseq_prefill(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
@@ -1388,7 +1399,8 @@ at::Tensor rope_qkv_decoding(
     double lo_freq_factor = 1,
     double hi_freq_factor = 32,
     std::optional<at::Tensor> qparam_k = {},
-    std::optional<at::Tensor> qparam_v = {}) {
+    std::optional<at::Tensor> qparam_v = {},
+    bool k_rms_norm = false) {
   auto B = XQ.size(0);
   auto N_H = XQ.size(1);
   auto N_KVH = XK.size(1);
@@ -1441,8 +1453,7 @@ at::Tensor rope_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor,
-        false);
+        hi_freq_factor);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     auto seqpos_ =
@@ -1478,7 +1489,8 @@ at::Tensor rope_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        k_rms_norm);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1506,7 +1518,8 @@ at::Tensor rope_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
@@ -1592,8 +1605,7 @@ at::Tensor xpos_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor,
-        false);
+        hi_freq_factor);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     auto num_groups_ = num_groups ? num_groups.value() : 1;
@@ -1629,7 +1641,8 @@ at::Tensor xpos_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
     throw std::runtime_error("CUDA version is older than 12.0");
@@ -1656,7 +1669,8 @@ at::Tensor xpos_qkv_decoding(
         old_context_len,
         scaling_factor,
         lo_freq_factor,
-        hi_freq_factor);
+        hi_freq_factor,
+        false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
@@ -1931,8 +1945,18 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   return {cache_K_dq, cache_V_dq};
 }
 
-DEVICE_INLINE void
-quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam) {
+DEVICE_INLINE void quantize_fp8_kv(
+    fx4 dst,
+    uint8_t* dst_row_q,
+    __half2* qparam,
+    bool do_rms_norm) {
+  if (do_rms_norm) {
+    float sum = fx4_dot(dst, dst);
+    // Warp reduce sum
+    sum = warpReduceSum(sum);
+    float rsqr = rsqrtf(sum / 128);
+    dst = fx4_scale(dst, rsqr);
+  }
   auto thread_min = fminf(fminf(fminf(dst.x, dst.y), dst.z), dst.w);
   auto thread_max = fmaxf(fmaxf(fmaxf(dst.x, dst.y), dst.z), dst.w);
 
@@ -1994,7 +2018,8 @@ quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam) {
 }
 #else
 DEVICE_INLINE void
-quantize_fp8_kv(fx4 dst, uint8_t* dst_row_, __half2* qparam) {}
+quantize_fp8_kv(fx4 dst, uint8_t* dst_row_, __half2* qparam, bool do_rms_norm) {
+}
 std::vector<at::Tensor> quantize_fp8_per_tensor(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
index d61285b7ff..087bd8e426 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
@@ -180,6 +180,17 @@ DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) {
   a.w += b.w;
   return a;
 }
+DEVICE_INLINE float fx4_dot(fx4 a, fx4 b) {
+  return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
+}
+
+DEVICE_INLINE fx4 fx4_scale(fx4 a, float scale) {
+  a.x *= scale;
+  a.y *= scale;
+  a.z *= scale;
+  a.w *= scale;
+  return a;
+}
 
 DEVICE_INLINE bfx4 fx4_to_bfx4(fx4 a) {
   bfx4 r;
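
The new do_rms_norm branch in quantize_fp8_kv normalizes each 128-element K head row in registers right before FP8 quantization: every lane of a warp holds four contiguous floats (an fx4), squares and sums them with fx4_dot, reduces the partial sums across the warp, and rescales by rsqrtf(sum / 128). At the op level this is opt-in through the added k_rms_norm=False argument and is applied only to K rows (qkv == QKV::K) written to the FP8 cache. The standalone sketch below illustrates the same warp-per-row pattern outside of FBGEMM; it is a minimal sketch under assumptions, not the library's implementation: HEAD_DIM, warp_reduce_sum, warp_reduce_max, k_rms_norm_rows, and the 448.0f E4M3 ceiling used for the per-row scale are hypothetical names and constants, and the sketch mirrors the patch in applying no epsilon.

// Illustrative sketch only: warp-per-row RMS norm ahead of an FP8-style
// per-row scale, loosely mirroring the do_rms_norm branch of quantize_fp8_kv.
// All identifiers below are hypothetical; none come from the FBGEMM API.
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <vector>

constexpr int HEAD_DIM = 128;                    // assumed head dim, matching `sum / 128`
constexpr int LANES = 32;                        // one warp owns one row
constexpr int ELEMS_PER_LANE = HEAD_DIM / LANES; // 4 floats per lane, like fx4

__device__ float warp_reduce_sum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xffffffff, v, offset);
  return v;
}

__device__ float warp_reduce_max(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset));
  return v;
}

// One block = one warp = one K head row. Normalizes the row to unit RMS and
// derives the per-row scale an E4M3 quantizer could use (byte packing omitted).
__global__ void k_rms_norm_rows(const float* in, float* out, float* row_scale, int rows) {
  int row = blockIdx.x;
  int lane = threadIdx.x;
  if (row >= rows) return;

  const float* src = in + row * HEAD_DIM + lane * ELEMS_PER_LANE;
  float v[ELEMS_PER_LANE];
  float sq = 0.f;
  for (int i = 0; i < ELEMS_PER_LANE; ++i) {
    v[i] = src[i];
    sq += v[i] * v[i];                 // lane-local fx4_dot(dst, dst)
  }
  sq = warp_reduce_sum(sq);            // row-wide sum of squares (cf. warpReduceSum)
  float rsqr = rsqrtf(sq / HEAD_DIM);  // 1 / RMS; no epsilon, as in the patch

  float* dst = out + row * HEAD_DIM + lane * ELEMS_PER_LANE;
  float amax = 0.f;
  for (int i = 0; i < ELEMS_PER_LANE; ++i) {
    dst[i] = v[i] * rsqr;              // fx4_scale(dst, rsqr)
    amax = fmaxf(amax, fabsf(dst[i]));
  }
  amax = warp_reduce_max(amax);
  if (lane == 0)
    row_scale[row] = amax / 448.0f;    // assumed E4M3 max; the real scale logic lives in quantize_fp8_kv
}

int main() {
  const int rows = 2;
  std::vector<float> h_in(rows * HEAD_DIM);
  for (size_t i = 0; i < h_in.size(); ++i)
    h_in[i] = 0.01f * static_cast<float>(i % 37) - 0.2f;

  float *d_in, *d_out, *d_scale;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, h_in.size() * sizeof(float));
  cudaMalloc(&d_scale, rows * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);

  k_rms_norm_rows<<<rows, LANES>>>(d_in, d_out, d_scale, rows);

  std::vector<float> h_out(rows * HEAD_DIM), h_scale(rows);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_scale.data(), d_scale, rows * sizeof(float), cudaMemcpyDeviceToHost);

  double ss = 0.0;
  for (int i = 0; i < HEAD_DIM; ++i) ss += h_out[i] * h_out[i];
  printf("row0 RMS after norm = %f (expect ~1), row0 scale = %f\n",
         std::sqrt(ss / HEAD_DIM), h_scale[0]);

  cudaFree(d_in); cudaFree(d_out); cudaFree(d_scale);
  return 0;
}

In the actual kernel the row is already resident in registers as one fx4 per lane, so the normalization costs only one extra warp reduction before the existing min/max scan that produces the FP8 scale.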