diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 5be11caff..2108207c3 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -559,11 +559,7 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   if (avg_packed_qo_len > 64 && head_dim < 256) {
     warp_layout = WarpLayout::k4x1x2;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
   } else {
-    if (avg_packed_qo_len > 16) {
-      warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
-    } else {
-      warp_layout = WarpLayout::k1x4x1;  // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1)
-    }
+    warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
   }
 
   const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 0d301fbcd..599f32d09 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -1756,11 +1756,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
   if (qo_len * group_size > 64 && HEAD_DIM < 256) {
     warp_layout = WarpLayout::k4x1x2;
   } else {
-    if (qo_len * group_size > 16) {
-      warp_layout = WarpLayout::k4x1x1;
-    } else {
-      warp_layout = WarpLayout::k1x4x1;
-    }
+    warp_layout = WarpLayout::k4x1x1;
   }
 
   DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
diff --git a/include/flashinfer/attention/warp_layout.cuh b/include/flashinfer/attention/warp_layout.cuh
index bb0103ad3..c4b467b98 100644
--- a/include/flashinfer/attention/warp_layout.cuh
+++ b/include/flashinfer/attention/warp_layout.cuh
@@ -26,7 +26,7 @@ namespace flashinfer {
 enum class WarpLayout {
   k4x1x2 = 0U,
   k4x1x1 = 1U,
-  k1x4x1 = 2U,
+  // k1x4x1 = 2U,
 };
 
 template <WarpLayout warp_layout>
@@ -44,10 +44,10 @@ constexpr uint32_t get_num_warps_x() {
   return 4;
 }
 
-template <>
-constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
-  return 1;
-}
+// template <>
+// constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
+//   return 1;
+// }
 
 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_warps_z() {
@@ -64,10 +64,10 @@ constexpr uint32_t get_num_warps_z() {
   return 1;
 }
 
-template <>
-constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
-  return 4;
-}
+// template <>
+// constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
+//   return 4;
+// }
 
 template <WarpLayout warp_layout>
 constexpr uint32_t get_num_frags_x() {
@@ -84,10 +84,10 @@ constexpr uint32_t get_num_frags_x() {
   return 1;
 }
 
-template <>
-constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
-  return 1;
-}
+// template <>
+// constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
+//   return 1;
+// }
 
 #define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...)        \
   if (warp_layout == WarpLayout::k4x1x2) {                         \
@@ -96,9 +96,6 @@ constexpr uint32_t get_num_frags_x() {
   } else if (warp_layout == WarpLayout::k4x1x1) {                  \
     constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1;         \
     __VA_ARGS__                                                    \
-  } else if (warp_layout == WarpLayout::k1x4x1) {                  \
-    constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1;         \
-    __VA_ARGS__                                                    \
   } else {                                                         \
     std::ostringstream err_msg;                                    \
     err_msg << "Unsupported warp layout: " << int(warp_layout);    \
diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py
index 0cf117a9a..45d848006 100644
--- a/python/generate_batch_paged_prefill_inst.py
+++ b/python/generate_batch_paged_prefill_inst.py
@@ -40,7 +40,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1, 2]
+    warp_layout_choice = [0, 1]
     insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithPagedKVCacheDispatched(
diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py
index 4f65db6b5..ebe4d8948 100644
--- a/python/generate_batch_ragged_prefill_inst.py
+++ b/python/generate_batch_ragged_prefill_inst.py
@@ -39,7 +39,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    warp_layout_choice = [0, 1, 2]
+    warp_layout_choice = [0, 1]
     insts = "\n".join(
         [
             """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
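
For reference, a minimal standalone sketch of the dispatch heuristic after this patch. This is not the FlashInfer source; select_warp_layout and the main driver below are hypothetical illustration only. With k1x4x1 gone, the qo_len * group_size > 16 branch disappears and every case that is not k4x1x2 falls back to k4x1x1.

// Sketch only: mirrors the simplified heuristic in handler.cuh / prefill.cuh
// from this diff. select_warp_layout is a made-up helper, not FlashInfer API.
#include <cstdint>
#include <cstdio>

enum class WarpLayout {  // k1x4x1 is removed by this patch
  k4x1x2 = 0U,           // num_warps_x = 4, num_warps_z = 1, num_frags_x = 2
  k4x1x1 = 1U,           // num_warps_x = 4, num_warps_z = 1, num_frags_x = 1
};

WarpLayout select_warp_layout(uint32_t qo_len, uint32_t group_size, uint32_t head_dim) {
  if (qo_len * group_size > 64 && head_dim < 256) {
    return WarpLayout::k4x1x2;
  }
  // Previously: qo_len * group_size > 16 ? k4x1x1 : k1x4x1.
  // Now every non-k4x1x2 case uses k4x1x1.
  return WarpLayout::k4x1x1;
}

int main() {
  // A short decode-like query (qo_len * group_size <= 16) that used to pick
  // k1x4x1 now falls through to k4x1x1.
  std::printf("%d\n", static_cast<int>(select_warp_layout(1, 8, 128)));
  return 0;
}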