Skip to content

Commit

Permalink
feat(kernel): 使用等号替换 memcpy 以触发指令级优化
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Nov 16, 2023
1 parent 3f4d976 commit 784bea6
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 65 deletions.
3 changes: 2 additions & 1 deletion src/04kernel/cuda/src/concat.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/concat.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -15,7 +16,7 @@ namespace refactor::kernel::cuda {
tid += step) {
auto i = tid % sum, j = i * sub, k = 0u;
while (j >= segments[k]) { j -= segments[k++]; }
memcpy(output + tid * sub, inputs[k] + (tid / sum) * segments[k] + j, sub);
MEMCPY(output + tid * sub, inputs[k] + (tid / sum) * segments[k] + j, sub);
}
}

Expand Down
17 changes: 10 additions & 7 deletions src/04kernel/cuda/src/expand.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/expand.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -8,20 +9,22 @@ namespace refactor::kernel::cuda {
uint8_t const *data, expand::DimStride const *strides, uint8_t *output,
unsigned int rank,
unsigned int eleSize) {
extern __shared__ expand::DimStride shared[];
for (auto i = threadIdx.x; i < rank; i += blockDim.x) {
shared[i] = strides[i];
}
__syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, i = 0;
for (auto j = 0; j < rank; ++j) {
auto const &s = strides[j];
if (s.i) {
i += rem / s.o * s.i;
}
auto s = shared[j];
i += rem / s.o * s.i;
rem %= s.o;
}

memcpy(output + tid * eleSize, data + i * eleSize, eleSize);
MEMCPY(output + tid * eleSize, data + i * eleSize, eleSize);
}
}

Expand All @@ -33,7 +36,7 @@ namespace refactor::kernel::cuda {
expandKernel<<<
params.gridSize,
params.blockSize,
0,
rank * sizeof(expand::DimStride),
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(data),
Expand Down
3 changes: 2 additions & 1 deletion src/04kernel/cuda/src/gather.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/gather.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -24,7 +25,7 @@ namespace refactor::kernel::cuda {
tid += step) {
auto i = tid / batch,
j = tid % batch;
memcpy(unit * tid + output,
MEMCPY(unit * tid + output,
unit * (batch * (i / midSizeO * midSizeI + shared[i % midSizeO]) + j) + data,
unit);
}
Expand Down
22 changes: 22 additions & 0 deletions src/04kernel/cuda/src/macro.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,26 @@
cudaGetErrorString(status))); \
}

#define MEMCPY(DST, SRC, ELE_SIZE) \
switch (ELE_SIZE) { \
case 1: \
*reinterpret_cast<uint8_t *>(DST) = *reinterpret_cast<uint8_t const *>(SRC); \
break; \
case 2: \
*reinterpret_cast<uint16_t *>(DST) = *reinterpret_cast<uint16_t const *>(SRC); \
break; \
case 4: \
*reinterpret_cast<float *>(DST) = *reinterpret_cast<float const *>(SRC); \
break; \
case 8: \
*reinterpret_cast<float2 *>(DST) = *reinterpret_cast<float2 const *>(SRC); \
break; \
case 16: \
*reinterpret_cast<float4 *>(DST) = *reinterpret_cast<float4 const *>(SRC); \
break; \
default: \
memcpy((DST), (SRC), (ELE_SIZE)); \
break; \
}

#endif// KERNEL_CUDA_MACRO_CUH
4 changes: 2 additions & 2 deletions src/04kernel/cuda/src/slice.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "kernel/cuda/slice.cuh"
#include "macro.cuh"
#include <cstdint>
#include <cstdio>

namespace refactor::kernel::cuda {

Expand All @@ -26,7 +26,7 @@ namespace refactor::kernel::cuda {
src_ += rem / dim.countStride * dim.sizeStride + dim.sizeStart;
rem %= dim.countStride;
}
memcpy(dst_, src_, blockSize);
MEMCPY(dst_, src_, blockSize);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/04kernel/cuda/src/split.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/split.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -20,7 +21,7 @@ namespace refactor::kernel::cuda {
tid += step) {
auto i = tid % sum, j = i * sub, k = 0u;
while (j >= shared[k]) { j -= shared[k++]; }
memcpy(outputs[k] + (tid / sum) * shared[k] + j, data + tid * sub, sub);
MEMCPY(outputs[k] + (tid / sum) * shared[k] + j, data + tid * sub, sub);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/04kernel/cuda/src/transpose.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/transpose.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -19,7 +20,7 @@ namespace refactor::kernel::cuda {
rem %= d.o;
}

memcpy(output + tid * eleSize, data + j * eleSize, eleSize);
MEMCPY(output + tid * eleSize, data + j * eleSize, eleSize);
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/04kernel/cuda/src/where.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/cuda/where.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {
Expand All @@ -25,7 +26,8 @@ namespace refactor::kernel::cuda {
ix += quot * dim[1];
iy += quot * dim[2];
}
memcpy(output + tid * eleSize,

MEMCPY(output + tid * eleSize,
c[ic]
? x + ix * eleSize
: y + iy * eleSize,
Expand Down
5 changes: 3 additions & 2 deletions src/04kernel/src/attributes/expand_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ namespace refactor::kernel {
s.i *= times;
s.o *= times;
}
strides.resize(strides.size() + 1);
strides.back() = {1, 1};
strides.resize(strides.size() + 2);
strides.rbegin()[1] = {times, times};
strides.rbegin()[0] = {1, 1};
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/src/kernels/expand/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ namespace refactor::kernel {
using K = ExpandCuda;

K::ExpandCuda(ExpandInfo info_) noexcept
: Kernel(), info(std::move(info_)) {}
: Kernel(), info(info_.reform(16)) {}

auto K::build(ExpandInfo info) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif

return std::make_unique<K>(info.reform(16));
return std::make_unique<K>(std::move(info));
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
Expand Down
44 changes: 41 additions & 3 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ namespace refactor::kernel {
using K = MatMulCublas;
using DT = DataType;

K::MatMulCublas(decltype(info) info_) noexcept
: Kernel(), info(std::move(info_)) {}
K::MatMulCublas(decltype(info) info_, decltype(biasExpand) biasExpand_) noexcept
: Kernel(), info(std::move(info_)), biasExpand(std::move(biasExpand_)) {}

auto K::build(Tensor const &a, Tensor const &b, Tensor const &y, MatMulInfo info) noexcept -> KernelBox {
static const std::unordered_set<decltype(DT::internal)> TYPE{DT::F32, DT::F64, DT::FP16};
Expand All @@ -20,7 +20,45 @@ namespace refactor::kernel {
return nullptr;
}

return std::make_unique<K>(std::move(info));
dim_t inputs[2];
switch (info.biasType) {
case BiasType::NoBias:
return std::make_unique<K>(std::move(info), std::nullopt);
case BiasType::Scalar:
inputs[0] = 1;
inputs[1] = 1;
break;
case BiasType::RowVector:
inputs[0] = 1;
inputs[1] = info.n;
break;
case BiasType::ColVector:
inputs[0] = info.m;
inputs[1] = 1;
break;
case BiasType::Matrix:
inputs[0] = info.m;
inputs[1] = info.n;
break;
default:
break;
}

std::vector<dim_t> outputShape(std::max(a.rank(), b.rank()));
for (auto i : range0_(outputShape.size() - 2)) {
auto a_ = i < a.rank() ? a.shape[i] : 1;
auto b_ = i < b.rank() ? b.shape[i] : 1;
outputShape[i] = std::max(a_, b_);
}
outputShape.rbegin()[1] = info.m;
outputShape.rbegin()[0] = info.n;

return std::make_unique<K>(
std::move(info),
std::make_optional(ExpandInfo(
dataType,
slice(inputs, 2),
slice(outputShape.data(), outputShape.size()))));
}

auto K::typeId() noexcept -> size_t {
Expand Down
53 changes: 12 additions & 41 deletions src/04kernel/src/kernels/mat_mul/cublas_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "../../utilities/cuda/cublas_context.hh"
#include "../expand/cuda_kernel.hh"
#include "cublas_kernel.hh"
#include <cublas_v2.h>
#include <thrust/execution_policy.h>
Expand All @@ -9,27 +10,10 @@ namespace refactor::kernel {
using namespace cublas;

template<class T>
struct MatMulBroadcastBiasFunctor {
T const *src;
size_t const n, strideC0, strideC1;

__device__ T operator()(size_t i) const noexcept {
return src[i / n * strideC0 + i % n * strideC1];
}
};

template<class T>
struct MatMulCopyBiasFunctor {
T const *src;
size_t blockSize;

__device__ T operator()(size_t i) const noexcept {
return src[i % blockSize];
}
};

template<class T>
Routine lowerTyped(cudaDataType_t cudaDataType, MatMulInfo info, size_t strideC0, size_t strideC1) noexcept {
Routine lowerTyped(cudaDataType_t cudaDataType, MatMulInfo info, Resources &res, std::optional<ExpandInfo> biasExpand, size_t strideC0, size_t strideC1) noexcept {
auto biasEx = biasExpand
? std::make_optional(ExpandCuda(*biasExpand).lower(res))
: std::nullopt;
return [cudaDataType,
alpha = static_cast<T>(info.alpha),
beta = static_cast<T>(info.biasType != BiasType::NoBias ? info.beta : 0.0f),
Expand All @@ -42,29 +26,16 @@ namespace refactor::kernel {
strideC0, strideC1,
lda = info.transA ? info.m : info.k,
ldb = info.transB ? info.k : info.n,
biasEx,
broadcaster = info.broadcaster](Resources &res, void const **inputs, void **outputs) {
auto a = reinterpret_cast<T const *>(inputs[0]);
auto b = reinterpret_cast<T const *>(inputs[1]);
auto y = reinterpret_cast<T *>(outputs[0]);

if (beta != (T) 0) {
// Expand bias to 2D and store in final output Y
{
auto c = reinterpret_cast<T const *>(inputs[2]);
thrust::tabulate(
thrust::device,
y,
y + strideY,
MatMulBroadcastBiasFunctor<T>{c, n, strideC0, strideC1});
}
// Copy 2D bias to each batch
if (broadcaster.outputsCount > 1) {
thrust::tabulate(
thrust::device,
y + strideY,
y + strideY * broadcaster.outputsCount,
MatMulCopyBiasFunctor<T>{y, strideY});
}
void const *inputs_[]{inputs[2]};
void *outputs_[]{outputs[0]};
(*biasEx)(res, inputs_, outputs_);
}

auto handle = res.fetchOrStore<CublasContext>()->handle;
Expand Down Expand Up @@ -102,11 +73,11 @@ namespace refactor::kernel {
res.fetchOrStore<CublasContext>();
switch (info.dataType) {
case DataType::F32:
return lowerTyped<float>(CUDA_R_32F, info, strideC0, strideC1);
return lowerTyped<float>(CUDA_R_32F, info, res, biasExpand, strideC0, strideC1);
case DataType::F64:
return lowerTyped<double>(CUDA_R_64F, info, strideC0, strideC1);
return lowerTyped<double>(CUDA_R_64F, info, res, biasExpand, strideC0, strideC1);
case DataType::FP16:
return lowerTyped<half>(CUDA_R_16F, info, strideC0, strideC1);
return lowerTyped<half>(CUDA_R_16F, info, res, biasExpand, strideC0, strideC1);
default:
UNREACHABLE();
}
Expand Down
5 changes: 4 additions & 1 deletion src/04kernel/src/kernels/mat_mul/cublas_kernel.hh
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
#ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH
#define KERNEL_MATMUL_CUBLAS_KERNEL_HH

#include "kernel/attributes/expand_info.h"
#include "kernel/attributes/matmul_info.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
#include <optional>

namespace refactor::kernel {

struct MatMulCublas final : public Kernel {
MatMulInfo info;
std::optional<ExpandInfo> biasExpand;

explicit MatMulCublas(MatMulInfo) noexcept;
explicit MatMulCublas(MatMulInfo, std::optional<ExpandInfo>) noexcept;

static KernelBox build(Tensor const &, Tensor const &, Tensor const &, MatMulInfo) noexcept;
static size_t typeId() noexcept;
Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/src/kernels/slice/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ namespace refactor::kernel {
using K = SliceCuda;

K::SliceCuda(SliceInfo info_) noexcept
: Kernel(), info(std::move(info_)) {}
: Kernel(), info(info_.reform(16)) {}

auto K::build(SliceInfo info) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif

return std::make_unique<K>(info.reform(16));
return std::make_unique<K>(std::move(info));
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
Expand Down

0 comments on commit 784bea6

Please sign in to comment.