feat(kernel): try putting common data into shared memory
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Nov 15, 2023
1 parent 321516c commit 435e5c9
Showing 2 changed files with 15 additions and 5 deletions.
11 changes: 8 additions & 3 deletions src/04kernel/cuda/src/gather.cu
@@ -13,14 +13,19 @@ namespace refactor::kernel::cuda {
                                  unsigned int unit,
                                  unsigned int midSizeI,
                                  unsigned int midSizeO) {
+        extern __shared__ uint32_t shared[];
+        for (auto i = threadIdx.x; i < midSizeO; i += blockDim.x) {
+            shared[i] = indices[i];
+        }
+        __syncthreads();
         for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                   step = blockDim.x * gridDim.x;
              tid < n;
              tid += step) {
             auto i = tid / batch,
                  j = tid % batch;
             memcpy(unit * tid + output,
-                   unit * (batch * (i / midSizeO * midSizeI + indices[i % midSizeO]) + j) + data,
+                   unit * (batch * (i / midSizeO * midSizeI + shared[i % midSizeO]) + j) + data,
                    unit);
         }
     }
@@ -37,7 +42,7 @@ namespace refactor::kernel::cuda {
         gatherKernel<<<
             params.gridSize,
             params.blockSize,
-            params.dynamicSharedBytes,
+            midSizeO * sizeof(uint32_t),
             reinterpret_cast<cudaStream_t>(params.stream)>>>(
             params.n,
             reinterpret_cast<uint8_t const *>(data),
@@ -51,7 +56,7 @@ namespace refactor::kernel::cuda {
         gatherKernel<<<
             params.gridSize,
             params.blockSize,
-            params.dynamicSharedBytes,
+            midSizeO * sizeof(uint32_t),
             reinterpret_cast<cudaStream_t>(params.stream)>>>(
             params.n,
             reinterpret_cast<uint8_t const *>(data),
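The change above follows a simple staging pattern: each block cooperatively copies the gather indices from global memory into dynamically allocated shared memory, synchronizes, and then reads the indices from the shared copy inside the hot loop. A minimal, self-contained sketch of the same pattern is shown below; the kernel name, the toy data, and names such as gatherLikeKernel and cached are assumptions made for illustration, not the project's actual code.

    // gather_shared_sketch.cu -- hypothetical illustration of the staging pattern
    // used in gather.cu above; all names here are invented for this sketch.
    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void gatherLikeKernel(uint32_t const *table, float const *src, float *dst,
                                     unsigned int tableLen, unsigned int n) {
        extern __shared__ uint32_t cached[];          // sized by the 3rd launch parameter
        for (auto i = threadIdx.x; i < tableLen; i += blockDim.x) {
            cached[i] = table[i];                     // cooperative copy, once per block
        }
        __syncthreads();                              // the copy must finish before use
        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                  step = blockDim.x * gridDim.x;
             tid < n; tid += step) {
            dst[tid] = src[cached[tid % tableLen]];   // hot loop reads shared, not global, memory
        }
    }

    int main() {
        unsigned int const tableLen = 4, n = 16;
        uint32_t table[tableLen] = {3, 2, 1, 0};
        float src[n], dst[n];
        for (unsigned int i = 0; i < n; ++i) src[i] = float(i);

        uint32_t *dTable;
        float *dSrc, *dDst;
        cudaMalloc(&dTable, sizeof(table));
        cudaMalloc(&dSrc, sizeof(src));
        cudaMalloc(&dDst, sizeof(dst));
        cudaMemcpy(dTable, table, sizeof(table), cudaMemcpyHostToDevice);
        cudaMemcpy(dSrc, src, sizeof(src), cudaMemcpyHostToDevice);

        // The third launch parameter is the dynamic shared memory size in bytes,
        // matching what the kernel declares with `extern __shared__`.
        gatherLikeKernel<<<1, 32, tableLen * sizeof(uint32_t)>>>(dTable, dSrc, dDst, tableLen, n);
        cudaMemcpy(dst, dDst, sizeof(dst), cudaMemcpyDeviceToHost);
        printf("dst[0] = %f\n", dst[0]);              // expect src[table[0]] = 3.0

        cudaFree(dTable); cudaFree(dSrc); cudaFree(dDst);
        return 0;
    }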
9 changes: 7 additions & 2 deletions src/04kernel/cuda/src/slice.cu
@@ -9,6 +9,11 @@ namespace refactor::kernel::cuda {
                 uint8_t const *src, DimInfo const *dims, uint8_t *output,
                 unsigned int rank,
                 unsigned int blockSize) {
+        extern __shared__ DimInfo dimInfo[];
+        for (auto i = threadIdx.x; i < rank; i += blockDim.x) {
+            dimInfo[i] = dims[i];
+        }
+        __syncthreads();
         for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
                   step = blockDim.x * gridDim.x;
              tid < n;
@@ -17,7 +22,7 @@ namespace refactor::kernel::cuda {
             auto src_ = src;
             auto dst_ = output + rem * blockSize;
             for (auto i = 0; i < rank; ++i) {
-                auto const &dim = dims[i];
+                auto const &dim = dimInfo[i];
                 src_ += rem / dim.countStride * dim.sizeStride + dim.sizeStart;
                 rem %= dim.countStride;
             }
@@ -33,7 +38,7 @@ namespace refactor::kernel::cuda {
         sliceKernel<<<
             params.gridSize,
             params.blockSize,
-            params.dynamicSharedBytes,
+            rank * sizeof(DimInfo),
             reinterpret_cast<cudaStream_t>(params.stream)>>>(
             params.n,
             reinterpret_cast<uint8_t const *>(src),
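A side effect of this change on the launch side: the dynamic shared memory size is now computed from the data being staged (midSizeO * sizeof(uint32_t) for gather, rank * sizeof(DimInfo) for slice) instead of coming from params.dynamicSharedBytes, so it must fit within the device's per-block shared memory limit. Below is a small, hypothetical host-side guard for that; the DimInfo stand-in and the checkSharedFits helper are assumptions for illustration and are not part of this commit.

    // shared_size_guard.cu -- hypothetical sketch; checks that a data-derived
    // dynamic shared memory request fits the device limit before launching.
    #include <cstddef>
    #include <cstdio>
    #include <cuda_runtime.h>

    struct DimInfo {                      // stand-in for the project's DimInfo
        unsigned int countStride, sizeStride, sizeStart;
    };

    static bool checkSharedFits(size_t requestedBytes, int device = 0) {
        int limitBytes = 0;
        cudaDeviceGetAttribute(&limitBytes, cudaDevAttrMaxSharedMemoryPerBlock, device);
        return requestedBytes <= static_cast<size_t>(limitBytes);
    }

    int main() {
        unsigned int rank = 8;
        size_t bytes = rank * sizeof(DimInfo);
        if (checkSharedFits(bytes)) {
            printf("ok: %zu bytes of dynamic shared memory fit the device limit\n", bytes);
            // sliceKernel<<<grid, block, bytes, stream>>>(...);   // launch as in the diff
        } else {
            printf("too large: fall back to reading dims from global memory\n");
        }
        return 0;
    }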
