fix: disable other warp layout because of large binary size (flashinfer-ai#326)

Disable flashinfer-ai#322 for the v0.0.6 release because the binary size is too large. v0.0.6 will only include bugfixes at the moment.
yzh119 authored Jun 21, 2024
1 parent da83cf5 commit c146e06
Showing 5 changed files with 17 additions and 28 deletions.
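
The motivation is combinatorial: every WarpLayout value that remains dispatchable multiplies the number of explicitly instantiated prefill kernels, so removing one of three layouts drops roughly a third of those instantiations. A rough illustration with hypothetical dispatch-dimension counts (not the project's actual dispatch table):

// Illustration only: kernel count grows multiplicatively with every dispatched
// dimension, so removing one WarpLayout value shrinks the compiled binary.
#include <cstdio>

int main() {
  // Hypothetical dispatch-dimension sizes, chosen for illustration.
  const int warp_layouts_before = 3, warp_layouts_after = 2;
  const int head_dims = 4, pos_encoding_modes = 3, mask_modes = 3, dtype_combos = 4;
  const int per_layout = head_dims * pos_encoding_modes * mask_modes * dtype_combos;
  std::printf("prefill instantiations: before=%d after=%d\n",
              warp_layouts_before * per_layout, warp_layouts_after * per_layout);
  return 0;
}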
6 changes: 1 addition & 5 deletions include/flashinfer/attention/handler.cuh
@@ -559,11 +559,7 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   if (avg_packed_qo_len > 64 && head_dim < 256) {
     warp_layout = WarpLayout::k4x1x2;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
   } else {
-    if (avg_packed_qo_len > 16) {
-      warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
-    } else {
-      warp_layout = WarpLayout::k1x4x1;  // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1)
-    }
+    warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
   }
   const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);

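A minimal standalone sketch of the selection logic after the change above (not the full handler code; the thresholds are taken from the diff): short queries that previously picked k1x4x1 now fall through to k4x1x1.

// Standalone sketch of the simplified warp-layout choice; only two layouts
// remain selectable at runtime after this commit.
#include <cstdint>

enum class WarpLayout { k4x1x2 = 0U, k4x1x1 = 1U };

inline WarpLayout SelectWarpLayout(uint32_t avg_packed_qo_len, uint32_t head_dim) {
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return WarpLayout::k4x1x2;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
  }
  // Everything else, including the short-query case that used to pick k1x4x1.
  return WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
}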
6 changes: 1 addition & 5 deletions include/flashinfer/attention/prefill.cuh
@@ -1756,11 +1756,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
   if (qo_len * group_size > 64 && HEAD_DIM < 256) {
     warp_layout = WarpLayout::k4x1x2;
   } else {
-    if (qo_len * group_size > 16) {
-      warp_layout = WarpLayout::k4x1x1;
-    } else {
-      warp_layout = WarpLayout::k1x4x1;
-    }
+    warp_layout = WarpLayout::k4x1x1;
   }

   DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
29 changes: 13 additions & 16 deletions include/flashinfer/attention/warp_layout.cuh
@@ -26,7 +26,7 @@ namespace flashinfer {
 enum class WarpLayout {
   k4x1x2 = 0U,
   k4x1x1 = 1U,
-  k1x4x1 = 2U,
+  // k1x4x1 = 2U,
 };

 template <WarpLayout warp_layout>
@@ -44,10 +44,10 @@ constexpr uint32_t get_num_warps_x<WarpLayout::k4x1x1>() {
   return 4;
 }

-template <>
-constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
-  return 1;
-}
+// template <>
+// constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
+//   return 1;
+// }

 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_warps_z() {
@@ -64,10 +64,10 @@ constexpr uint32_t get_num_warps_z<WarpLayout::k4x1x1>() {
   return 1;
 }

-template <>
-constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
-  return 4;
-}
+// template <>
+// constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
+//   return 4;
+// }

 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_frags_x() {
@@ -84,10 +84,10 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
   return 1;
 }

-template <>
-constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
-  return 1;
-}
+// template <>
+// constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
+//   return 1;
+// }

 #define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...) \
   if (warp_layout == WarpLayout::k4x1x2) { \
@@ -96,9 +96,6 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
   } else if (warp_layout == WarpLayout::k4x1x1) { \
     constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1; \
     __VA_ARGS__ \
-  } else if (warp_layout == WarpLayout::k1x4x1) { \
-    constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1; \
-    __VA_ARGS__ \
   } else { \
     std::ostringstream err_msg; \
     err_msg << "Unsupported warp layout: " << int(warp_layout); \
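The macro edit above removes the k1x4x1 branch from DISPATCH_WARP_LAYOUT. A self-contained sketch of the underlying pattern, with RunKernel as a hypothetical stand-in for the real dispatched kernels: each enum value kept in the dispatcher is another set of template instantiations in the binary.

// Sketch of runtime-enum-to-compile-time-constant dispatch; every branch kept
// here forces the compiler to emit another RunKernel instantiation.
#include <sstream>
#include <stdexcept>

enum class WarpLayout { k4x1x2 = 0U, k4x1x1 = 1U };

template <WarpLayout WARP_LAYOUT>
void RunKernel() { /* one compiled instantiation per dispatched layout */ }

void Dispatch(WarpLayout warp_layout) {
  if (warp_layout == WarpLayout::k4x1x2) {
    RunKernel<WarpLayout::k4x1x2>();
  } else if (warp_layout == WarpLayout::k4x1x1) {
    RunKernel<WarpLayout::k4x1x1>();
  } else {
    std::ostringstream err_msg;
    err_msg << "Unsupported warp layout: " << static_cast<int>(warp_layout);
    throw std::invalid_argument(err_msg.str());
  }
}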
2 changes: 1 addition & 1 deletion python/generate_batch_paged_prefill_inst.py
@@ -40,7 +40,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1, 2]
+    warp_layout_choice = [0, 1]
     insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<page_storage, {warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
2 changes: 1 addition & 1 deletion python/generate_batch_ragged_prefill_inst.py
@@ -39,7 +39,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1, 2]
+    warp_layout_choice = [0, 1]
     insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
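For context on the two generator edits above: the integers in warp_layout_choice are the numeric values of the WarpLayout enumerators, so dropping 2 from the list stops emitting instantiations for the now-disabled k1x4x1 layout. A minimal check of that mapping, using only the enum as it reads in warp_layout.cuh after this commit:

// The generator scripts' warp_layout_choice entries correspond to these
// enumerator values; 2 (k1x4x1) is no longer emitted.
#include <cassert>

enum class WarpLayout { k4x1x2 = 0U, k4x1x1 = 1U };  // k1x4x1 = 2U is disabled

int main() {
  assert(static_cast<int>(WarpLayout::k4x1x2) == 0);
  assert(static_cast<int>(WarpLayout::k4x1x1) == 1);
  return 0;
}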
