Skip to content

Commit

Permalink
refactor(kernel): 整理 cublas kernel (tidy up the cuBLAS kernel)
Browse files: browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Nov 15, 2023
1 parent da47faf commit 321516c
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY : build install-python reconfig clean clean-log format test

TYPE ?= debug
TYPE ?= Debug
CUDA ?= OFF

FORMAT_ORIGIN ?=
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/include/kernel/attributes/expand_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace refactor::kernel {
std::vector<Dim> strides;
dim_t blockCount, blockSize;

ExpandInfo(std::vector<Dim>, dim_t, dim_t) noexcept;
ExpandInfo(DataType, slice_t<dim_t> input, slice_t<dim_t> output) noexcept;
ExpandInfo(Tensor const &input, Tensor const &output) noexcept;
ExpandInfo reform(dim_t maxblockSize) const noexcept;
void reformAssign(dim_t maxblockSize) noexcept;
Expand Down
44 changes: 19 additions & 25 deletions src/04kernel/src/attributes/expand_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,24 @@ namespace refactor::kernel {
}

ExpandInfo::ExpandInfo(
std::vector<Dim> strides_,
dim_t blockCount_,
dim_t blockSize_) noexcept
: strides(std::move(strides_)),
blockCount(blockCount_),
blockSize(blockSize_) {}

ExpandInfo::ExpandInfo(
Tensor const &input,
Tensor const &output) noexcept
DataType dataType,
slice_t<dim_t> input,
slice_t<dim_t> output) noexcept
: strides{{1, 1}},
blockCount(1),
blockSize(input.dataType.size()) {
ASSERT(input.rank() <= output.rank(), "Unreachable");
auto i = input.shape.rbegin(),
ei = input.shape.rend(),
o = output.shape.rbegin(),
eo = output.shape.rend();
blockSize(dataType.size()) {
ASSERT(input.size() <= output.size(), "Unreachable");
dim_t stride = 1;
while (o != eo) {
auto i_ = i == ei ? 1 : *i++,
o_ = *o++;
for (auto i = input.end_,
o = output.end_;
o != output.begin_;) {
auto i_ = i == input.begin_ ? 1 : *--i,
o_ = *--o;
if (o_ == 1) { continue; }
if (auto &it = strides.back(); i_ == 1) {
if (it.i != 0) {
strides.push_back({0, blockCount});
}
if (it.i) { strides.push_back({0, blockCount}); }
} else {
if (it.i == 0) {
strides.push_back({stride, blockCount});
}
if (!it.i) { strides.push_back({stride, blockCount}); }
stride *= i_;
}
blockCount *= o_;
Expand All @@ -67,6 +54,13 @@ namespace refactor::kernel {
}
}

// Tensor-based convenience constructor.
// Delegates to the (DataType, slice_t<dim_t>, slice_t<dim_t>) overload, passing
// the input tensor's data type and, for each tensor, a slice built from its
// shape storage (`shape.data()`) and its `rank()`.
// NOTE(review): presumably `slice(ptr, len)` yields a non-owning view, so the
// shapes are not copied here — confirm against the `slice` helper.
ExpandInfo::ExpandInfo(
Tensor const &input,
Tensor const &output) noexcept
: ExpandInfo(input.dataType,
slice(input.shape.data(), input.rank()),
slice(output.shape.data(), output.rank())) {}

ExpandInfo ExpandInfo::reform(dim_t maxblockSize) const noexcept {
auto ans = *this;
ans.reformAssign(maxblockSize);
Expand Down
83 changes: 42 additions & 41 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,77 +2,78 @@
#include "cublas_kernel.hh"
#include <cublas_v2.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/tabulate.h>

namespace refactor::kernel {
using namespace runtime;
using namespace cublas;
using DT = DataType;

template<typename T>
template<class T>
struct MatMulBroadcastBiasFunctor {
T const *C;
T *Y;
size_t const n;
size_t const strideC0, strideC1;
T const *src;
size_t const n, strideC0, strideC1;

__device__ void operator()(size_t t) const noexcept {
size_t i = t / n;
size_t j = t % n;
memcpy(Y + t, C + i * strideC0 + j * strideC1, sizeof(T));
__device__ T operator()(size_t i) const noexcept {
return src[i / n * strideC0 + i % n * strideC1];
}
};

template<typename T>
template<class T>
struct MatMulCopyBiasFunctor {
T *dst;
T const *src;
size_t stride;
size_t blockSize;

__device__ void operator()(size_t i) const noexcept {
memcpy(dst + i * stride, src, stride * sizeof(T));
__device__ T operator()(size_t i) const noexcept {
return src[i % blockSize];
}
};

template<class T, cudaDataType_t cudaDataType>
Routine lowerTyped(MatMulInfo info, size_t strideC0, size_t strideC1) noexcept {
return [alpha = static_cast<T>(info.alpha),
template<class T>
Routine lowerTyped(cudaDataType_t cudaDataType, MatMulInfo info, size_t strideC0, size_t strideC1) noexcept {
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasType != BiasType::NoBias ? info.beta : 0.0f),
tA = info.transA ? CUBLAS_OP_T : CUBLAS_OP_N,
tB = info.transB ? CUBLAS_OP_T : CUBLAS_OP_N,
m = info.m, n = info.n, k = info.k, batch = info.batch(),
m = info.m, n = info.n, k = info.k,
strideY = info.m * info.n,
strideA = info.m * info.k,
strideB = info.k * info.n,
strideC0, strideC1,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
broadcaster = info.broadcaster](Resources &res, void const **inputs, void **outputs) {
auto A = reinterpret_cast<T const *>(inputs[0]);
auto B = reinterpret_cast<T const *>(inputs[1]);
auto Y = reinterpret_cast<T *>(outputs[0]);
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);

if (beta != T{}) {
auto C = reinterpret_cast<T const *>(inputs[2]);
if (beta != (T) 0) {
// Expand bias to 2D and store in final output Y
thrust::for_each_n(thrust::device,
thrust::counting_iterator<size_t>(0), strideY,
MatMulBroadcastBiasFunctor<T>{C, Y, n, strideC0, strideC1});
{
auto c = reinterpret_cast<T const *>(inputs[2]);
thrust::tabulate(
thrust::device,
y,
y + strideY,
MatMulBroadcastBiasFunctor<T>{c, n, strideC0, strideC1});
}
// Copy 2D bias to each batch
if (batch > 1) {
thrust::for_each_n(thrust::device,
thrust::counting_iterator<size_t>(1), batch,
MatMulCopyBiasFunctor<T>{Y, Y, strideY});
if (broadcaster.outputsCount > 1) {
thrust::tabulate(
thrust::device,
y + strideY,
y + strideY * broadcaster.outputsCount,
MatMulCopyBiasFunctor<T>{y, strideY});
}
}

auto handle = res.fetchOrStore<CublasContext>()->handle;
uint32_t offset[2];
for (size_t i = 0; i < batch; i++) {
for (auto i : range0_(broadcaster.outputsCount)) {
broadcaster.locate(i, offset);
auto stat = cublasGemmEx(
handle, tB, tA, n, m, k, &alpha, B + strideB * offset[1],
cudaDataType, ldb, A + strideA * offset[0], cudaDataType, lda, &beta, Y + strideY * i,
handle, tB, tA, n, m, k, &alpha, b + strideB * offset[1],
cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, &beta, y + strideY * i,
cudaDataType, n, cudaDataType, CUBLAS_GEMM_DEFAULT);
}
};
Expand Down Expand Up @@ -100,12 +101,12 @@ namespace refactor::kernel {

res.fetchOrStore<CublasContext>();
switch (info.dataType) {
case DT::F32:
return lowerTyped<float, CUDA_R_32F>(info, strideC0, strideC1);
case DT::F64:
return lowerTyped<double, CUDA_R_64F>(info, strideC0, strideC1);
case DT::FP16:
return lowerTyped<fp16_t, CUDA_R_16F>(info, strideC0, strideC1);
case DataType::F32:
return lowerTyped<float>(CUDA_R_32F, info, strideC0, strideC1);
case DataType::F64:
return lowerTyped<double>(CUDA_R_64F, info, strideC0, strideC1);
case DataType::FP16:
return lowerTyped<half>(CUDA_R_16F, info, strideC0, strideC1);
default:
UNREACHABLE();
}
Expand Down
1 change: 0 additions & 1 deletion src/04kernel/src/kernels/simple_unary/cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ namespace refactor::kernel {
CASE(FUNC, U32); \
CASE(FUNC, U64)


Routine K::lower(Resources &) const noexcept {
switch (opType) {
case Op::Abs:
Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/src/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ namespace refactor::kernel {
return std::memcpy(dst, src, bytes);
}
};
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(4ul << 30, sizeof(uint64_t), BasicCpuMemManager::instance());
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(5ul << 30, sizeof(uint64_t), BasicCpuMemManager::instance());
return memPool;
}
#ifdef USE_CUDA
case NvidiaGpu: {
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(4ul << 30, 256, cuda::BasicCudaMemManager::instance());
static Arc<mem_manager::MemManager> memPool = std::make_shared<mem_manager::MemPool>(5ul << 30, 256, cuda::BasicCudaMemManager::instance());
return memPool;
}
#endif
Expand Down
4 changes: 1 addition & 3 deletions src/04kernel/test/attributes/test_expand_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ using namespace kernel;
TEST(kernel, ExpandInfo) {
auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}),
output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6});

ExpandInfo info(*input, *output);
for (auto s : info.strides) {
fmt::print("({} {}) ", s.i, s.o);
}
EXPECT_EQ(info.blockSize, 24);
EXPECT_EQ(info.blockCount, 120);
EXPECT_EQ(info.strides, (std::vector<ExpandInfo::Dim>{{0, 60}, {1, 5}, {0, 1}}));
Expand Down

0 comments on commit 321516c

Please sign in to comment.