k_norm in rope for fp8 kv cache (#3633)
Summary:
X-link: facebookresearch/FBGEMM#709

Pull Request resolved: #3633

Zero-init of dequantized KV in prefill (D68574038), prefill FA3:
GQA decode (USE_MQA_ATTN=1): Evaluation results on task nq.0_shot: em: 0.146814 | f1: 0.282832
Triton Split-k: Evaluation results on task nq.0_shot: em: 0.147368 | f1: 0.284228

BF16, prefill FA3:
Evaluation results on task nq.0_shot: em: 0.148753 | f1: 0.285131

Reviewed By: jianyuh

Differential Revision: D68815109

fbshipit-source-id: d1da9baaf5fd34b705b793173c2a1e6c04d7901d
ayaIbrah authored and facebook-github-bot committed Feb 4, 2025
1 parent f18ec58 commit 6b5568a
Showing 3 changed files with 72 additions and 32 deletions.
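
In brief, the diff below threads a new k_rms_norm flag through the RoPE prefill/decoding ops so that, when the KV cache is FP8, each K head row is RMS-normalized inside quantize_fp8_kv before being quantized. The following is a minimal host-side C++ reference of that normalization step only, assuming a head dimension of 128 (the divisor hard-coded in the kernel) and, as in the diff, no epsilon term and no learned scale; the function name and shapes here are illustrative, not code from this commit.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference for the k_rms_norm step added to quantize_fp8_kv: scale a K row
// by 1/sqrt(mean(k^2)) over the head dimension before FP8 quantization.
// Assumes D_H == 128, matching the rsqrtf(sum / 128) in the kernel.
void k_rms_norm_row(std::vector<float>& k_row) {
  float sum_sq = 0.0f;
  for (float v : k_row) {
    sum_sq += v * v;  // the kernel accumulates this via fx4_dot + warpReduceSum
  }
  const float rsqr = 1.0f / std::sqrt(sum_sq / 128.0f);  // rsqrtf(sum / 128)
  for (float& v : k_row) {
    v *= rsqr;  // the kernel applies this per thread via fx4_scale
  }
}

int main() {
  std::vector<float> k(128, 0.5f);
  k_rms_norm_row(k);
  std::printf("k[0] after RMS norm: %f\n", k[0]);  // 1.0: mean(k^2) is now 1
  return 0;
}

In the kernel itself each thread holds four elements of the row in an fx4 register, so the sum of squares is formed with the new fx4_dot helper, reduced across the warp, and then applied per thread with fx4_scale.
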
16 changes: 10 additions & 6 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -72,7 +72,8 @@ at::Tensor rope_qkv_varseq_prefill(
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v,
bool write_k_back);
bool write_k_back,
bool k_rms_norm);

at::Tensor rope_qkv_decoding(
at::Tensor XQ,
@@ -95,7 +96,8 @@ at::Tensor rope_qkv_decoding(
double lo_freq_factor,
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);
std::optional<at::Tensor> qparam_v,
bool k_rms_norm);

at::Tensor xpos_qkv_varseq_prefill(
at::Tensor XQ,
@@ -175,9 +177,9 @@ at::Tensor mqa_attn(
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!) XK, Tensor XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False) -> Tensor");
", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_rms_norm=False) -> Tensor");
m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_rms_norm=False) -> Tensor");
m.def(
"nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None) -> Tensor");
@@ -237,7 +239,8 @@ at::Tensor rope_qkv_varseq_prefill_meta(
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */,
bool /* write_k_back */
bool /* write_k_back */,
bool /* k_rms_norm */
) {
return at::empty_like(XQ);
}
@@ -263,7 +266,8 @@ at::Tensor rope_qkv_decoding_meta(
double /* lo_freq_factor */,
double /* hi_freq_factor */,
std::optional<at::Tensor> /* qparam_k */,
std::optional<at::Tensor> /* qparam_v */
std::optional<at::Tensor> /* qparam_v */,
bool /* k_rms_norm */
) {
return at::empty_like(XQ);
}
77 changes: 51 additions & 26 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -210,8 +210,11 @@ __device__ void get_dst_row(

enum class PositionEmbeddingMode { ROPE = 0, XPOS = 1 };
enum class QKV { Q, K, V };
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam = nullptr);
DEVICE_INLINE void quantize_fp8_kv(
fx4 dst,
uint8_t* dst_row_q,
__half2* qparam = nullptr,
bool do_rms_norm = false);

__global__ void nope_qkv_varseq_prefill_kernel(
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
@@ -711,7 +714,8 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
old_context_len, \
scaling_factor, \
lo_freq_factor, \
hi_freq_factor) \
hi_freq_factor, \
k_rms_norm) \
rope_xpos_qkv_varseq_prefill_kernel_<EMB_MODE, DTYPE, NUM_GROUPS> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(), \
@@ -737,7 +741,8 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
old_context_len, \
scaling_factor, \
lo_freq_factor, \
hi_freq_factor);
hi_freq_factor, \
k_rms_norm);

#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
(defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
@@ -799,7 +804,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
int64_t old_context_len = 8192,
double scaling_factor = 16,
double lo_freq_factor = 1,
double hi_freq_factor = 32) {
double hi_freq_factor = 32,
bool k_rms_norm = false) {
// Launch b_t_(sum(h)) warps.
auto b_t_hh = blockIdx.x * blockDim.y +
threadIdx.y; // Block = [kThreadsPerWarp, kWarpsPerBlock]
@@ -904,7 +910,7 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
if (kCacheDtype == CacheLogicalDtype::FP8) {
if (qparam_k_ptr == nullptr) {
CUDA_KERNEL_ASSERT(D_H_q - D_H == 4);
quantize_fp8_kv(dst, dst_row_q);
quantize_fp8_kv(dst, dst_row_q, nullptr, (qkv == QKV::K && k_rms_norm));
} else {
__half2* qparam_row = nullptr;
auto T = cache_K.size(1);
@@ -930,7 +936,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
}
quantize_fp8_kv(dst, dst_row_q, qparam_row);
quantize_fp8_kv(
dst, dst_row_q, qparam_row, (qkv == QKV::K && k_rms_norm));
}

} else if (kCacheDtype == CacheLogicalDtype::INT4) {
@@ -1087,7 +1094,8 @@ at::Tensor rope_qkv_varseq_prefill(
double hi_freq_factor = 32,
std::optional<at::Tensor> qparam_k = {},
std::optional<at::Tensor> qparam_v = {},
bool write_k_back = false) {
bool write_k_back = false,
bool k_rms_norm = false) {
auto B_T = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = XK.size(1);
@@ -1181,7 +1189,8 @@ at::Tensor rope_qkv_varseq_prefill(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
k_rms_norm);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
throw std::runtime_error("CUDA version is older than 12.0");
@@ -1208,7 +1217,8 @@ at::Tensor rope_qkv_varseq_prefill(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -1293,8 +1303,7 @@ at::Tensor xpos_qkv_varseq_prefill(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false);
hi_freq_factor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
auto num_groups_ = num_groups ? num_groups.value() : 1;
@@ -1331,7 +1340,8 @@ at::Tensor xpos_qkv_varseq_prefill(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
@@ -1359,7 +1369,8 @@ at::Tensor xpos_qkv_varseq_prefill(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -1388,7 +1399,8 @@ at::Tensor rope_qkv_decoding(
double lo_freq_factor = 1,
double hi_freq_factor = 32,
std::optional<at::Tensor> qparam_k = {},
std::optional<at::Tensor> qparam_v = {}) {
std::optional<at::Tensor> qparam_v = {},
bool k_rms_norm = false) {
auto B = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = XK.size(1);
@@ -1441,8 +1453,7 @@ at::Tensor rope_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false);
hi_freq_factor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
auto seqpos_ =
@@ -1478,7 +1489,8 @@ at::Tensor rope_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
k_rms_norm);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
@@ -1506,7 +1518,8 @@ at::Tensor rope_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -1592,8 +1605,7 @@ at::Tensor xpos_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false);
hi_freq_factor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
auto num_groups_ = num_groups ? num_groups.value() : 1;
@@ -1629,7 +1641,8 @@ at::Tensor xpos_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
throw std::runtime_error("CUDA version is older than 12.0");
Expand All @@ -1656,7 +1669,8 @@ at::Tensor xpos_qkv_decoding(
old_context_len,
scaling_factor,
lo_freq_factor,
hi_freq_factor);
hi_freq_factor,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
@@ -1931,8 +1945,18 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
return {cache_K_dq, cache_V_dq};
}
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam) {
DEVICE_INLINE void quantize_fp8_kv(
fx4 dst,
uint8_t* dst_row_q,
__half2* qparam,
bool do_rms_norm) {
if (do_rms_norm) {
float sum = fx4_dot(dst, dst);
// Warp reduce sum
sum = warpReduceSum(sum);
float rsqr = rsqrtf(sum / 128);
dst = fx4_scale(dst, rsqr);
}
auto thread_min = fminf(fminf(fminf(dst.x, dst.y), dst.z), dst.w);
auto thread_max = fmaxf(fmaxf(fmaxf(dst.x, dst.y), dst.z), dst.w);
@@ -1994,7 +2018,8 @@ quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam) {
}
#else
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, uint8_t* dst_row_, __half2* qparam) {}
quantize_fp8_kv(fx4 dst, uint8_t* dst_row_, __half2* qparam, bool do_rms_norm) {
}
std::vector<at::Tensor> quantize_fp8_per_tensor(
at::Tensor input,
std::optional<at::Tensor> bs, // batch size
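
The new RMS-norm path above calls a warpReduceSum helper that is defined elsewhere in fbgemm_gpu and is not part of this diff. For orientation only, a hypothetical butterfly-shuffle version is sketched below: it returns the full-warp total to every lane, which is the property the kernel relies on, since every thread then scales its own fx4 registers by the same rsqrtf(sum / 128). The helper name, launch shape, and implementation here are assumptions, not the library's actual code.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical warp-wide sum via XOR shuffles: after the loop every lane of
// the 32-thread warp holds the same total (a shfl_down reduction would leave
// the full result only in lane 0).
__device__ float warp_reduce_sum_sketch(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffffu, val, offset);
  }
  return val;
}

__global__ void demo_kernel(float* out) {
  // Each of the 32 lanes contributes 1.0, so every lane should see 32.0.
  float total = warp_reduce_sum_sketch(1.0f);
  if (threadIdx.x == 0) {
    *out = total;
  }
}

int main() {
  float* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));
  demo_kernel<<<1, 32>>>(d_out);
  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %f\n", h_out);  // expected: 32.000000
  cudaFree(d_out);
  return 0;
}
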
11 changes: 11 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
@@ -180,6 +180,17 @@ DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) {
a.w += b.w;
return a;
}
DEVICE_INLINE float fx4_dot(fx4 a, fx4 b) {
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

DEVICE_INLINE fx4 fx4_scale(fx4 a, float scale) {
a.x *= scale;
a.y *= scale;
a.z *= scale;
a.w *= scale;
return a;
}

DEVICE_INLINE bfx4 fx4_to_bfx4(fx4 a) {
bfx4 r;
