diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 8ba6e987..23eecba9 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -226,6 +226,52 @@ namespace refactor::kernel { auto attentionSize = info.maxAttSize(); auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize; + for (auto attLen = 0; attLen < 2048; ++attLen) { + MatrixDescriptor + q_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + k_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.headDim), + .cols = static_cast(attLen), + .majorStride = static_cast(info.headDim), + .order = COL_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + v_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(attLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + att_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(attLen), + .majorStride = static_cast(info.cacheLen), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.seqLen), + }); + tune(handle, d->mul, + q_, k_, att_, + DYNAMIC_WORKSPACE_SIZE); + tune(handle, d->mul, + att_, v_, q_, + DYNAMIC_WORKSPACE_SIZE); + } + auto routine = [d = std::move(d), info = this->info]// (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { auto handle = res.fetchOrStore()->handle;