
Commit
temp
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 21, 2024
1 parent bfa8e9f commit a686d2f
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -226,6 +226,52 @@ namespace refactor::kernel {
        auto attentionSize = info.maxAttSize();
        auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
        // Pre-tune the two batched matmuls (att = Q·Kᵀ and out = att·V)
        // for every attention length that may occur at run time.
        for (auto attLen = 0; attLen < 2048; ++attLen) {
            MatrixDescriptor
                // Q: [seqLen × headDim] per (batch, head), packed contiguously.
                q_(MatrixLayout{
                    .dataType = dataTypeConvert(info.dataType),
                    .rows = static_cast<uint64_t>(info.seqLen),
                    .cols = static_cast<uint64_t>(info.headDim),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                }),
                // K: the cache stores [cacheLen × headDim] row-major; reading the
                // first attLen rows column-major yields Kᵀ as [headDim × attLen].
                k_(MatrixLayout{
                    .dataType = dataTypeConvert(info.dataType),
                    .rows = static_cast<uint64_t>(info.headDim),
                    .cols = static_cast<uint64_t>(attLen),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = COL_MAJOR,
                    .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                    .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
                }),
                // V: the first attLen rows of the cache, [attLen × headDim] row-major.
                v_(MatrixLayout{
                    .dataType = dataTypeConvert(info.dataType),
                    .rows = static_cast<uint64_t>(attLen),
                    .cols = static_cast<uint64_t>(info.headDim),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                    .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
                }),
                // att: [seqLen × attLen] scores; each row reserves cacheLen slots.
                att_(MatrixLayout{
                    .dataType = dataTypeConvert(info.dataType),
                    .rows = static_cast<uint64_t>(info.seqLen),
                    .cols = static_cast<uint64_t>(attLen),
                    .majorStride = static_cast<int64_t>(info.cacheLen),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                    .batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
                });
            // att = Q · Kᵀ
            tune(handle, d->mul,
                 q_, k_, att_,
                 DYNAMIC_WORKSPACE_SIZE);
            // out = att · V, written back with Q's layout ([seqLen × headDim]).
            tune(handle, d->mul,
                 att_, v_, q_,
                 DYNAMIC_WORKSPACE_SIZE);
        }
        auto routine = [d = std::move(d), info = this->info]//
            (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
                auto handle = res.fetchOrStore<CublasLtContext>()->handle;
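For reference, below is a minimal plain-C++ sketch of the two batched GEMMs that the layouts above describe: att = Q·Kᵀ over the KV cache, followed by out = att·V. The function name and parameters are illustrative assumptions, not part of the commit, and the softmax/masking applied between the two products in the actual kernel is omitted; the sketch only illustrates the shapes and strides encoded in the MatrixLayouts.

```cpp
#include <cstddef>

// Reference loops for the two batched products the tuned descriptors encode.
// Indexing mirrors the layouts above: Q is [seqLen x headDim] per (batch, head),
// K and V live in a cache of cacheLen rows of headDim, and att rows are padded
// to cacheLen elements. All names here are hypothetical.
void attentionReference(float const *q, float const *k, float const *v,
                        float *att, float *out,
                        size_t batchHeads, size_t seqLen, size_t attLen,
                        size_t cacheLen, size_t headDim) {
    for (size_t b = 0; b < batchHeads; ++b) {
        auto q_ = q + b * seqLen * headDim;      // batchStride of Q
        auto k_ = k + b * cacheLen * headDim;    // batchStride of K (cache)
        auto v_ = v + b * cacheLen * headDim;    // batchStride of V (cache)
        auto att_ = att + b * cacheLen * seqLen; // batchStride of att
        auto out_ = out + b * seqLen * headDim;  // output reuses Q's layout
        // att[s][a] = sum_d Q[s][d] * K[a][d]   (Q · Kᵀ, K read column-major)
        for (size_t s = 0; s < seqLen; ++s)
            for (size_t a = 0; a < attLen; ++a) {
                float acc = 0;
                for (size_t d = 0; d < headDim; ++d)
                    acc += q_[s * headDim + d] * k_[a * headDim + d];
                att_[s * cacheLen + a] = acc;    // row stride = cacheLen
            }
        // out[s][d] = sum_a att[s][a] * V[a][d] (att · V)
        for (size_t s = 0; s < seqLen; ++s)
            for (size_t d = 0; d < headDim; ++d) {
                float acc = 0;
                for (size_t a = 0; a < attLen; ++a)
                    acc += att_[s * cacheLen + a] * v_[a * headDim + d];
                out_[s * headDim + d] = acc;
            }
    }
}
```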