From da83cf5fe362746a240ca8992431ce8b0f540e7c Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Fri, 21 Jun 2024 15:36:51 -0700
Subject: [PATCH] Bugfix: bugfix to #322 (#325)

Some of the last commits for the bugfix in #322 were missing.
---
 include/flashinfer/attention/handler.cuh | 37 ++++++++++++------------
 python/csrc/batch_prefill.cu             | 13 ++++-----
 python/csrc/flashinfer_ops.h             |  6 ++--
 python/flashinfer/decode.py              |  1 -
 python/flashinfer/prefill.py             |  1 -
 src/bench_batch_decode.cu                |  7 ++---
 src/bench_cascade.cu                     |  6 ++--
 src/test_batch_prefill.cu                | 21 ++++++--------
 src/test_cascade.cu                      | 16 +++++-----
 src/tvm_wrapper.cu                       |  8 ++---
 10 files changed, 51 insertions(+), 65 deletions(-)

diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 534f9852d..5be11caff 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -92,7 +92,7 @@ inline std::tuple PrefillBinarySearchKVChunkSize(
     const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) {
   int64_t low = min_kv_chunk_size, high = 0;
   int64_t batch_size = packed_qo_len_arr.size();
-  int64_t max_kv_len;
+  int64_t max_kv_len = 0;
   for (const int64_t& kv_len : kv_len_arr) {
     max_kv_len = std::max(max_kv_len, kv_len);
   }
@@ -174,9 +174,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     new_batch_size = batch_size;
   } else {
     // compute max_num_pages_per_batch and new_batch_size
-    std::vector<IdType> page_indptr_h(batch_size + 1), num_pages(batch_size);
+    std::vector<IdType> num_pages(batch_size);
     for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-      num_pages[batch_idx] = page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx];
+      num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
     }
     std::tie(max_num_pages_per_batch, new_batch_size) =
         PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
@@ -517,14 +517,16 @@ class BatchDecodeHandler {
 };
 
 template <typename IdType>
-cudaError_t PrefillSplitQOKVIndptr(
-    bool& split_kv, uint32_t& split_max_batch_size, uint32_t& total_num_tiles_q,
-    uint32_t& new_batch_size, WarpLayout& warp_layout, uint32_t& kv_chunk_size,
-    uint32_t& total_num_rows, std::vector<IdType>& request_indices,
-    std::vector<IdType>& qo_tile_indices, std::vector<IdType>& kv_tile_indices,
-    std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr, IdType* qo_indptr_h,
-    IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
-    uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, cudaStream_t stream = nullptr) {
+cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_size,
+                                   uint32_t& total_num_tiles_q, uint32_t& new_batch_size,
+                                   WarpLayout& warp_layout, uint32_t& kv_chunk_size,
+                                   uint32_t& total_num_rows, std::vector<IdType>& request_indices,
+                                   std::vector<IdType>& qo_tile_indices,
+                                   std::vector<IdType>& kv_tile_indices,
+                                   std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr,
+                                   IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
+                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
+                                   uint32_t page_size) {
   request_indices.clear();
   qo_tile_indices.clear();
   kv_tile_indices.clear();
@@ -536,8 +538,6 @@ cudaError_t PrefillSplitQOKVIndptr(
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
   total_num_rows = qo_indptr_h[batch_size];
 
-  bool has_kv_last_page_len = kv_last_page_len_h != nullptr;
-
   // step 0: get the number of SMs
   int num_sm = 0;
   int dev_id = 0;
@@ -570,7 +570,7 @@
   // step 2: determine kv_chunk_size
   std::tie(split_kv, kv_chunk_size, new_batch_size) =
       PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr,
-                                     qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size));
+                                     qo_chunk_size, /*min_kv_chunk_size=*/(512 / page_size));
 
   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
@@ -656,9 +656,8 @@ class BatchPrefillHandler {
 
   template
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h,
-                           IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size,
-                           uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                           uint32_t page_size) {
+                           IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads,
+                           uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) {
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
       err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -672,8 +671,8 @@ class BatchPrefillHandler {
     FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr(
         split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, warp_layout_,
         kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec,
-        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h,
-        kv_last_page_len_h, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, stream_));
+        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h, batch_size,
+        num_qo_heads, num_kv_heads, head_dim, page_size));
     const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_);
 
     if (IsCUDAGraphEnabled()) {
diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu
index e5efd2a56..158f96b92 100644
--- a/python/csrc/batch_prefill.cu
+++ b/python/csrc/batch_prefill.cu
@@ -22,9 +22,8 @@ using namespace flashinfer;
 
 void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
-    torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
-    unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
-    torch::Tensor empty_q_data) {
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -33,7 +32,6 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
   CHECK_DIM(1, workspace_buffer);
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
-  paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kCPU).to(torch::kInt32);
 
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -43,9 +41,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     cudaError_t status = handler_->BeginForward(
         static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()),
-        static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
-        static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads,
-        num_kv_heads, head_dim, page_size);
+        static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
+        head_dim, page_size);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
                 cudaGetErrorString(status));
     return true;
@@ -285,7 +282,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
     cudaError_t status = handler_->BeginForward(
         static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
-        /*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+        batch_size, num_qo_heads, num_kv_heads, head_dim,
         /*page_size=*/1);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
                 cudaGetErrorString(status));
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 2a11ac49e..9d9f1f3de 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -112,9 +112,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    torch::Tensor page_kv_indptr, torch::Tensor page_kv_last_page_len,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-                    unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data);
+                    torch::Tensor page_kv_indptr, unsigned int batch_size,
+                    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+                    unsigned page_size, torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index 429868a19..d5aad88ae 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -730,7 +730,6 @@ def begin_forward(
                 self._workspace_buffer,
                 self._qo_indptr_buf,
                 indptr,
-                last_page_len,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 3061d637f..1a6c3a1d2 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -773,7 +773,6 @@ def begin_forward(
             self._workspace_buffer,
             qo_indptr,
             paged_kv_indptr,
-            paged_kv_last_page_len,
             batch_size,
             num_qo_heads,
             num_kv_heads,
diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu
index 1af7793f9..eefeee41c 100644
--- a/src/bench_batch_decode.cu
+++ b/src/bench_batch_decode.cu
@@ -149,10 +149,9 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
 
   size_t workspace_size_in_bytes = 128 * 1024 * 1024;
   thrust::device_vector<char> buffer(workspace_size_in_bytes);
-  handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()),
-                       workspace_size_in_bytes, qo_indptr_h.data(),
-                       kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size,
-                       num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
+      kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
     cudaError_t status =
diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu
index 953c1464f..36db5d895 100644
--- a/src/bench_cascade.cu
+++ b/src/bench_cascade.cu
@@ -248,8 +248,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
     thrust::device_vector<char> buffer(workspace_size_in_bytes);
     cascade_handler.BeginForward(
         (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
-        kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads,
-        num_kv_heads, head_dim, page_size);
+        kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
     state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
       timer.start();
       cudaError_t status = SinglePrefillWithKVCache(
@@ -305,8 +304,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
     thrust::device_vector<char> buffer(workspace_size_in_bytes);
     baseline_handler.BeginForward(
         (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
-        kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads,
-        num_kv_heads, head_dim, page_size);
+        kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
     state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
       timer.start();
       cudaError_t status =
diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu
index 3b29c2c72..76bea74de 100644
--- a/src/test_batch_prefill.cu
+++ b/src/test_batch_prefill.cu
@@ -104,8 +104,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
 
   handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
                        q_indptr.data(), kv_indptr.data(),
-                       kv_last_page_len.data(), batch_size, num_qo_heads,
-                       num_kv_heads, head_dim, page_size);
+                       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) {
     auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper
   thrust::device_vector<int32_t> append_indptr_device(append_indptr);
   thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
 
-  handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()),
-                       workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
-                       /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads,
-                       num_kv_heads, head_dim, /*page_size=*/1);
+  handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
+      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1);
 
   auto status = BatchPrefillWithRaggedKVCacheWrapper(
       &handler, thrust::raw_pointer_cast(queries_device.data()),
@@ -321,8 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
 
   handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
                        append_indptr.data(), kv_indptr.data(),
-                       kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads,
-                       head_dim, page_size);
+                       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   auto status = BatchPrefillWithPagedKVCacheWrapper(
@@ -416,10 +413,10 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
 
   size_t workspace_size_in_bytes = 32 * 1024 * 1024;
   thrust::device_vector<char> buffer(workspace_size_in_bytes);
-  handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(), kv_last_page_len.data(),
-      /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()),
+                       workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
+                       /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim,
+                       page_size);
 
   auto status = BatchPrefillWithPagedKVCacheWrapper(
diff --git a/src/test_cascade.cu b/src/test_cascade.cu
index ecddd7535..a7b1b9d2c 100644
--- a/src/test_cascade.cu
+++ b/src/test_cascade.cu
@@ -409,14 +409,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
 
   thrust::device_vector<char> buffer_baseline(workspace_size_in_bytes),
       buffer_cascade(workspace_size_in_bytes);
-  baseline_handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
-      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
-  cascade_handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size,
-      num_qo_heads, num_kv_heads, head_dim, page_size);
+  baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_baseline.data()),
+                                workspace_size_in_bytes, qo_indptr_h.data(),
+                                kv_indptr_combined_h.data(), batch_size, num_qo_heads,
+                                num_kv_heads, head_dim, page_size);
+  cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_cascade.data()),
+                               workspace_size_in_bytes, qo_indptr_h.data(),
+                               kv_indptr_unique_h.data(), batch_size, num_qo_heads,
+                               num_kv_heads, head_dim, page_size);
 
   cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
       &baseline_handler, thrust::raw_pointer_cast(q_d.data()),
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index 430bff06b..d9318bb7d 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -272,8 +272,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 
 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
-    DLTensor* kv_last_page_len, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
-    int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) {
+    int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
+    int64_t page_size, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
@@ -290,8 +290,6 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
       static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
       static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
       static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
-      static_cast<dtype_idx*>(kv_last_page_len->data) +
-          kv_last_page_len->byte_offset / sizeof(dtype_idx),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
   if (status != cudaSuccess) {
     LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status);
   }
@@ -568,7 +566,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
      static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
      static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
      static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
-      /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+      batch_size, num_qo_heads, num_kv_heads, head_dim,
       /*page_size=*/1);
   if (status != cudaSuccess) {
     LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache BeginForward error "