Skip to content

Commit

Permalink
feat: 开始实现 attention
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 30, 2024
1 parent 6630866 commit 69a22b7
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 92 deletions.
16 changes: 16 additions & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef KERNEL_ATTENTION_INFO_H
#define KERNEL_ATTENTION_INFO_H

#include "../tensor.h"

namespace refactor::kernel {

/// Shape and flag summary for one attention invocation, shared between the
/// collector and backend kernels.
///
/// NOTE: member order matters — call sites initialize this struct with
/// designated initializers in declaration order; do not reorder fields.
struct AttentionInfo {
    /// Element type of the Q/K/V tensors.
    DataType dataType;
    // batch    — batch size (query shape[0]).
    // nHead    — number of query heads (query shape[1]).
    // nKVHead  — number of key/value heads (key shape[1]; presumably may be
    //            smaller than nHead for grouped-query attention — TODO confirm).
    // seqLen   — current sequence length (query shape[2]).
    // headDim  — per-head embedding size (query shape[3]).
    // cacheLen — kv-cache length (cache output shape[2]); 0 when there is no cache.
    dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
    // concatCache — new k/v are concatenated onto an existing cache.
    // resetCache  — cache contents are reset/rewritten; set only in the
    //               6-input variant (see collector) — verify semantics at call sites.
    bool concatCache, resetCache;
};

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_INFO_H
3 changes: 1 addition & 2 deletions src/04kernel/include/kernel/collectors/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
namespace refactor::kernel {

struct AttentionCollector final : public InfoCollector {
dim_t maxSeqLen;

AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
AttentionCollector(decltype(_target)) noexcept;

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
Expand Down
55 changes: 37 additions & 18 deletions src/04kernel/src/collectors/attention.cc
Original file line number Diff line number Diff line change
@@ -1,38 +1,57 @@
#include "kernel/collectors/attention.h"
#include "kernel/attributes/attention_info.h"
// #include "../kernels/attention/cpu_kernel.hh"
#include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

AttentionCollector::AttentionCollector(
decltype(_target) target,
decltype(maxSeqLen) maxSeqLen_) noexcept
: InfoCollector(target),
maxSeqLen(maxSeqLen_) {}
decltype(_target) target) noexcept
: InfoCollector(target) {}

std::vector<KernelBox>
AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &query = inputs[0].get();
auto const &key = inputs[1].get();
auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];

std::vector<KernelBox> ans;
AttentionInfo info{
.dataType = query.dataType,
.batch = query.shape[0],
.nHead = query.shape[1],
.nKVHead = key.shape[1],
.seqLen = query.shape[2],
.headDim = query.shape[3],
.cacheLen = 0,
.concatCache = false,
.resetCache = false,
};
switch (outputs.size()) {
case 1:
// no kv cache
ASSERT(inputs.size() == 3, "");
break;
case 3:
switch (inputs.size()) {
case 6:
info.resetCache = true;
case 4:
info.concatCache = true;
case 3:
info.cacheLen = outputs[1].get().shape[2];
break;
default:
UNREACHABLE();
}
break;
default:
UNREACHABLE();
}

std ::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
break;
case decltype(_target)::Nvidia: {
decltype(AttentionCuda::info) info{
.dataType = query.dataType,
.batch = query.shape[0],
.nHead = query.shape[1],
.nKVHead = key.shape[1],
.pastSeqLen = static_cast<dim_t>(pastSeqLen),
.seqLen = query.shape[2],
.cacheLen = cacheLen,
.headDim = query.shape[3],
.resetCache = false,
};
if (auto ptr = AttentionCuda::build(info); ptr) {
ans.emplace_back(std::move(ptr));
}
Expand Down
78 changes: 78 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"

namespace refactor::kernel {
using K = AttentionCuda;
using namespace cublas;

/// Lowers the attention kernel to a CUDA routine plus its workspace size.
///
/// Only the simplest configuration is implemented so far: no kv-cache
/// (info.cacheLen == 0) and one kv head per query head
/// (info.nHead == info.nKVHead). Every other configuration hits TODO.
/// The returned routine itself is still a stub (TODO inside).
RoutineWorkspace K::lower(Resources &res) const {
    // Ensure a cublasLt context exists for this resource set; the handle is
    // fetched but not used yet because the routine body is unimplemented.
    auto handle = res.fetchOrStore<CublasLtContext>()->handle;

    // FP32 is hard-coded for now, regardless of info.dataType.
    auto computeType = CUBLAS_COMPUTE_32F;
    auto dataType = CUDA_R_32F;
    constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
    constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;

    if (!info.cacheLen) {
        if (info.nHead == info.nKVHead) {
            // Fold batch and head into one batched-GEMM batch dimension.
            auto batch = info.batch * info.nHead;
            size_t workspaceSize = 0;
            // NOTE(review): passes the CUDA_R_32F literal instead of the
            // `dataType` variable declared above — same value today, but the
            // two could silently drift apart.
            MatMulDescriptor
                mulDesc(computeType, CUDA_R_32F);
            MatrixDescriptor
                // Q: seqLen x headDim, row-major, one matrix per (batch, head).
                qDesc(MatrixLayout{
                    .dataType = dataType,
                    .rows = static_cast<uint64_t>(info.seqLen),
                    .cols = static_cast<uint64_t>(info.headDim),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(batch),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                }),
                // K: column-major view over the same row-major [seqLen, headDim]
                // buffer, i.e. K is consumed transposed (for Q·Kᵀ).
                kDesc(MatrixLayout{
                    .dataType = dataType,
                    .rows = static_cast<uint64_t>(info.headDim),
                    .cols = static_cast<uint64_t>(info.seqLen),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = COL_MAJOR,
                    .batchCount = static_cast<int32_t>(batch),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                }),
                // V: same layout as Q.
                vDesc(MatrixLayout{
                    .dataType = dataType,
                    .rows = static_cast<uint64_t>(info.seqLen),
                    .cols = static_cast<uint64_t>(info.headDim),
                    .majorStride = static_cast<int64_t>(info.headDim),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(batch),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                }),
                // Attention scores: seqLen x seqLen per (batch, head), kept in
                // the workspace below.
                attDesc(MatrixLayout{
                    .dataType = dataType,
                    .rows = static_cast<uint64_t>(info.seqLen),
                    .cols = static_cast<uint64_t>(info.seqLen),
                    .majorStride = static_cast<int64_t>(info.seqLen),
                    .order = ROW_MAJOR,
                    .batchCount = static_cast<int32_t>(batch),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                });
            // Workspace for the score matrix of every (batch, head) pair.
            // NOTE(review): sized with info.dataType.size() while the descriptors
            // hard-code FP32 above — confirm these stay consistent.
            workspaceSize += batch * info.seqLen * info.seqLen * info.dataType.size();
            // qk: mulDesc(qDesc * kDesc) -> attDesc
            // TODO inline mask && softmax
            // av: mulDesc(attDesc * vDesc) -> qDesc
            // Stub routine: captures a copy of `info`; body not written yet.
            auto routine = [info = this->info]//
                (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
                    // auto handle = res.fetchOrStore<CublasLtContext>()->handle;
                    // auto q = inputs[0];
                    // auto k = inputs[1];
                    // auto v = inputs[2];
                    // auto o = outputs[0];
                    TODO("");
                };
            return {std::move(routine), workspaceSize};
        }
    }
    // kv-cache and grouped-query configurations are not implemented yet.
    TODO("");
}

}// namespace refactor::kernel
8 changes: 2 additions & 6 deletions src/04kernel/src/kernels/attention/cuda_kernel.hh
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
#ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
#define KERNEL_ATTENTION_CUDA_KERNEL_HH

#include "kernel/attributes/attention_info.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct AttentionCuda final : public Kernel {
struct {
DataType dataType;
dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
bool resetCache;
} info;
AttentionInfo info;

AttentionCuda(decltype(info)) noexcept;

Expand Down
33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.cu

This file was deleted.

33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.hh

This file was deleted.

75 changes: 75 additions & 0 deletions src/04kernel/src/utilities/cuda/cublaslt_utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "common.h"
#include "cublaslt_utils.cuh"

namespace refactor::kernel::cublas {

/// Acquires a cublasLt handle for the lifetime of this resource.
/// Reports a runtime error when the library fails to initialize.
CublasLtContext::CublasLtContext() : runtime::Resource() {
    auto status = cublasLtCreate(&handle);
    if (status != CUBLAS_STATUS_SUCCESS) {
        RUNTIME_ERROR("Failed to create cublasLt handle");
    }
}
/// Releases the cublasLt handle.
/// Destructors must not throw, so a failed teardown is reported and the
/// process is aborted instead.
CublasLtContext::~CublasLtContext() {
    auto status = cublasLtDestroy(handle);
    if (status == CUBLAS_STATUS_SUCCESS) {
        return;
    }
    fmt::println("Failed to destroy cublasLt handle");
    abort();
}

/// Unique id for this resource type: the address of a function-local static
/// is unique program-wide, so it doubles as the identifier.
auto CublasLtContext::typeId() noexcept -> size_t {
    static uint8_t tag = 1;
    return reinterpret_cast<size_t>(&tag);
}
/// Factory for the resource registry; ownership transfers to the caller.
/// NOTE(review): marked noexcept, but std::make_unique can throw
/// std::bad_alloc and the constructor raises RUNTIME_ERROR on failure —
/// either would terminate here. Confirm this is intended.
auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
    return std::make_unique<CublasLtContext>();
}

/// Resource interface: forwards to the static per-type id.
auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
    return typeId();
}
/// Human-readable name of this resource type (for diagnostics).
auto CublasLtContext::description() const noexcept -> std::string_view {
    return "CublasLtContext";
}

/// RAII wrapper creation: forwards the compute type and data type to
/// cublasLtMatmulDescCreate; failures surface through CUBLASLT_ASSERT.
MatMulDescriptor::MatMulDescriptor(cublasComputeType_t compute, cudaDataType data)
    : _internal(nullptr) {
    CUBLASLT_ASSERT(cublasLtMatmulDescCreate(&_internal, compute, data));
}
/// Destroys the matmul descriptor; failures surface through CUBLASLT_ASSERT.
MatMulDescriptor::~MatMulDescriptor() {
    CUBLASLT_ASSERT(cublasLtMatmulDescDestroy(_internal));
}
/// Non-owning accessor to the underlying cublasLt descriptor handle.
cublasLtMatmulDesc_t MatMulDescriptor::get() const noexcept {
    return _internal;
}

/// Builds a cublasLt matrix layout from `layout`: creates the handle with the
/// basic dimensions/stride, then attaches the ordering and batching
/// attributes one by one.
MatrixDescriptor::MatrixDescriptor(MatrixLayout layout)
    : _internal(nullptr) {
    CUBLASLT_ASSERT(cublasLtMatrixLayoutCreate(
        &_internal,
        layout.dataType,
        layout.rows,
        layout.cols,
        layout.majorStride));
    // Attach one attribute to the freshly created layout.
    auto setAttr = [this](auto attribute, void const *value, size_t size) {
        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
            _internal, attribute, value, size));
    };
    setAttr(CUBLASLT_MATRIX_LAYOUT_ORDER,
            &layout.order, sizeof(layout.order));
    setAttr(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
            &layout.batchCount, sizeof(layout.batchCount));
    setAttr(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
            &layout.batchStride, sizeof(layout.batchStride));
}
/// Destroys the matrix layout; failures surface through CUBLASLT_ASSERT.
MatrixDescriptor::~MatrixDescriptor() {
    CUBLASLT_ASSERT(cublasLtMatrixLayoutDestroy(_internal));
}
/// Non-owning accessor to the underlying cublasLt layout handle.
cublasLtMatrixLayout_t MatrixDescriptor::get() const noexcept {
    return _internal;
}

}// namespace refactor::kernel::cublas
Loading

0 comments on commit 69a22b7

Please sign in to comment.