Skip to content

Commit

Permalink
refactor: reduce binary size by making kv_layout an argument instead of template parameter (flashinfer-ai#370)
Browse files Browse the repository at this point in the history

This PR reduces binary size by half, by moving `kv_layout` from template
parameter to input argument.

This PR also adds `stride_n` and `stride_h` fields to `tensor_info_t`
and `paged_kv_t`, thus making it possible to support non-contiguous
inputs (flashinfer-ai#311 ), however, I'll leave it for another PR.
  • Loading branch information
yzh119 authored Jul 12, 2024
1 parent c69cfab commit 024a79f
Show file tree
Hide file tree
Showing 35 changed files with 865 additions and 1,042 deletions.
205 changes: 96 additions & 109 deletions CMakeLists.txt

Large diffs are not rendered by default.

54 changes: 25 additions & 29 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
/*!
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request
* \tparam logits_post_hook The logits post hook used in the kernel
* \tparam kv_layout The layout of k/v matrices (NHD or HND)
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
Expand All @@ -207,14 +206,12 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* of "theta" used in RoPE (Rotary Positional Embeddings)
* \param kv_chunk_size A integer indicates the kv-chunk size
*/
template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_kv,
PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem, uint32_t tile_size_per_bdx,
uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ,
typename DTypeKV, typename DTypeOut>
template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float* __restrict__ lse, tensor_info_t info,
float logits_soft_cap, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta,
uint32_t kv_chunk_size) {
Expand Down Expand Up @@ -386,11 +383,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
*/
template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, QKVLayout kv_layout,
typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
Expand Down Expand Up @@ -619,20 +616,19 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
* \param num_kv_heads A integer indicates the number of heads of key and value
* \param seq_len A integer indicates the sequence length
* \param head_dim A integer indicates the head dimension
* \param kv_layout The layout of q/k/v matrices
* \param pos_encoding_mode The positional encoding mode
* \param rope_scale The scaling factor used in RoPE Interpolation
* \param rope_theta The theta used in RoPE
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode POS_ENCODING_MODE,
typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t seq_len,
float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta,
QKVLayout kv_layout, float logits_soft_cap,
float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
Expand All @@ -645,17 +641,17 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
constexpr uint32_t num_threads =
std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
tensor_info_t<KV_LAYOUT, HEAD_DIM> info(1, seq_len, num_qo_heads, num_kv_heads);
tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U;
const uint32_t smem_size =
2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
auto kernel =
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/false,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

Expand All @@ -677,9 +673,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
} else {
// use partition-kv kernel
auto kernel =
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/true,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

Expand Down Expand Up @@ -722,10 +718,10 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
}

template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
Expand Down Expand Up @@ -754,8 +750,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, page_storage, kv_layout,
DTypeQ, DTypeKV, DTypeOut, IdType>;
vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
Expand All @@ -775,8 +771,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
auto partition_kv_kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true,
POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, page_storage, kv_layout,
DTypeQ, DTypeKV, DTypeOut, IdType>;
vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
Expand Down
23 changes: 12 additions & 11 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ namespace flashinfer {

template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, QKVLayout kv_layout,
typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
float sm_scale, float rope_rcp_scale, float rope_rcp_theta);
Expand Down Expand Up @@ -138,8 +138,8 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr_h, const uint32_t num_qo_heads,
Expand All @@ -161,7 +161,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK,
/*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
Expand Down Expand Up @@ -308,8 +308,8 @@ class BatchDecodeHandler {
bool* GetBlockValidMask() const { return block_valid_mask_; }

template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
typename DTypeKV, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut,
typename IdType>
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr_h,
IdType* last_page_len_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads,
Expand All @@ -318,9 +318,10 @@ class BatchDecodeHandler {
bool split_kv;
uint32_t max_grid_size, max_num_pages_per_batch, new_batch_size;
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
LOGITS_POST_HOOK, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
work_estimation_func(split_kv, max_grid_size, max_num_pages_per_batch, new_batch_size,
batch_size, indptr_h, num_qo_heads, page_size,
Expand Down
Loading

0 comments on commit 024a79f

Please sign in to comment.