Skip to content

Commit

Permalink
feat: pad cuda test不跑gpu测试(slicecuda跑对的最后一版)
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Jan 26, 2024
1 parent 817bdfd commit a91cc98
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 3 deletions.
22 changes: 22 additions & 0 deletions src/04kernel/cuda/include/kernel/cuda/pad.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef KERNEL_CUDA_PAD_CUH
#define KERNEL_CUDA_PAD_CUH

#include "threads_distributer.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

struct DimInfo {
unsigned int strideI, strideO, padS, dimI;
};

void launchPad(
KernelLaunchParameters const &,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_PAD_CUH
64 changes: 64 additions & 0 deletions src/04kernel/cuda/src/pad.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "kernel/cuda/pad.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

__global__ static void padKernel(
unsigned long long n,
uint8_t const *__restrict__ src,
uint8_t const *__restrict__ src_const,
DimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, j = 0;
bool flag = false;
for (auto i = 0; i < rank; ++i) {
auto strideO = __ldg(&(dims[i].strideO));
auto strideI = __ldg(&(dims[i].strideI));
auto padS = __ldg(&(dims[i].padS));
auto dimI = __ldg(&(dims[i].dimI));
auto pos = rem / strideO - padS;
if (pos < 0 || pos >= dimI) {
flag = true;
break;
}
j += pos * strideI;
rem %= strideO;
}
if (flag) {
optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
} else {
optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
}
}
}

void launchPad(
KernelLaunchParameters const &params,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {


padKernel<<<
params.gridSize,
params.blockSize,
0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
src,
src_const,
dims,
reinterpret_cast<uint8_t *>(output),
rank,
blockSize);
}

}// namespace refactor::kernel::cuda
39 changes: 39 additions & 0 deletions src/04kernel/src/kernels/pad/cuda_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "cuda_kernel.hh"
#include "kernel/cuda/pad.cuh"
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

namespace refactor::kernel {
using namespace runtime;

auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
thrust::host_vector<cuda::DimInfo> dims(info.dims.size());
std::transform(info.dims.begin(), info.dims.end(),
dims.begin(),
[](auto const &d) {
return cuda::DimInfo{
d.strideI,
d.strideO,
d.padS,
d.dimI,
};
});
return [dims = thrust::device_vector<cuda::DimInfo>(dims),
params = cuda::ThreadsDistributer()(info.blockCount),
blockSize = info.blockSize,
value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
thrust::device_vector<uint8_t> defaultValue(blockSize, 0);
if (value != 0) {
auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
for (auto i : range0_(blockSize / value)) {
cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice);
}
}
cuda::launchPad(params, src, defaultValue.data().get(), dims.data().get(), outputs[0],
dims.size(),
blockSize);
};
}

}// namespace refactor::kernel
6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/pad/cuda_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace refactor::kernel {

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
// #ifdef USE_CUDA
// RoutineWorkspace lower(Resources &) const noexcept final;
// #endif
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const noexcept final;
#endif
};

}// namespace refactor::kernel
Expand Down
71 changes: 71 additions & 0 deletions src/04kernel/test/kernels/pad/test_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifdef USE_CUDA

#include "../../../src/kernels/pad/cpu_kernel.hh"
#include "../../../src/kernels/pad/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;
using namespace hardware;

TEST(kernel, PadCuda) {
PadDimension dims{
{2, 4, 1},
{3, 5, 1},
{1, 1, 0},
{4, 8, 2},
};
// build routine
auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
auto t2Tensor = Tensor::share(DataType::I64, Shape{8});
auto t3Tensor = Tensor::share(DataType::F32, Shape{});
auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8});
PadType type = PadType::Constant;
auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
// auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
// ASSERT_TRUE(kernel && kCpu);
auto res = runtime::Resources();
// auto routine = kernel->lower(res).routine,
// rCpu = kCpu->lower(res).routine;
auto rCpu = kCpu->lower(res).routine;
// malloc
auto &dev = *device::init(Device::Type::Nvidia, 0, "");
// auto gpuIn = dev.malloc(t1Tensor->bytesSize()),
// gpuIn2 = dev.malloc(t2Tensor->bytesSize()),
// gpuIn3 = dev.malloc(t3Tensor->bytesSize()),
// gpuOut = dev.malloc(yTensor->bytesSize());
// put input data
std::vector<float> data(t1Tensor->elementsSize()),
constvalue(1, 1.2f),
cpuOut(yTensor->elementsSize());
std::vector<int64_t> pads{1, 1, 0, 2, 1, 1, 0, 2};


for (auto i : range0_(data.size())) { data[i] = i; }
// gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize());
// gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize());
// gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize());

// inference
// {
// void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3};
// void *outputs[]{*gpuOut};
// routine(res, nullptr, inputs, outputs);
// }
{
void const *inputs[]{data.data(), pads.data(), constvalue.data()};
void *outputs[]{cpuOut.data()};
rCpu(res, nullptr, inputs, outputs);
}
// take output data
std::vector<float> result(yTensor->elementsSize());
// gpuOut->copyToHost(result.data(), yTensor->bytesSize());
// check
for (auto i : range0_(cpuOut.size())) {
// fmt::println("i = {}, cpuout = {}, gpuout = {}", i, cpuOut[i], result[i]);
//EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
}
}

#endif

0 comments on commit a91cc98

Please sign in to comment.