diff --git a/src/04kernel/cuda/include/kernel/cuda/pad.cuh b/src/04kernel/cuda/include/kernel/cuda/pad.cuh
new file mode 100644
index 00000000..79d36cdd
--- /dev/null
+++ b/src/04kernel/cuda/include/kernel/cuda/pad.cuh
@@ -0,0 +1,22 @@
+#ifndef KERNEL_CUDA_PAD_CUH
+#define KERNEL_CUDA_PAD_CUH
+
+#include "threads_distributer.cuh"
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    struct DimInfo {
+        unsigned int strideI, strideO, padS, dimI;
+    };
+
+    void launchPad(
+        KernelLaunchParameters const &,
+        uint8_t const *src, uint8_t const *src_const,
+        DimInfo const *dims, void *output,
+        unsigned int rank,
+        unsigned int blockSize);
+
+}// namespace refactor::kernel::cuda
+
+#endif// KERNEL_CUDA_PAD_CUH
diff --git a/src/04kernel/cuda/src/pad.cu b/src/04kernel/cuda/src/pad.cu
new file mode 100644
index 00000000..f66d1479
--- /dev/null
+++ b/src/04kernel/cuda/src/pad.cu
@@ -0,0 +1,64 @@
+#include "kernel/cuda/pad.cuh"
+#include "macro.cuh"
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    __global__ static void padKernel(
+        unsigned long long n,
+        uint8_t const *__restrict__ src,
+        uint8_t const *__restrict__ src_const,
+        DimInfo const *__restrict__ dims,
+        uint8_t *__restrict__ dst,
+        unsigned int rank,
+        unsigned int blockSize) {
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            long rem = tid, j = 0;
+            bool flag = false;
+            for (auto i = 0; i < rank; ++i) {
+                auto strideO = __ldg(&(dims[i].strideO));
+                auto strideI = __ldg(&(dims[i].strideI));
+                auto padS = __ldg(&(dims[i].padS));
+                auto dimI = __ldg(&(dims[i].dimI));
+                auto pos = rem / strideO - padS;
+                if (pos < 0 || pos >= dimI) {
+                    flag = true;
+                    break;
+                }
+                j += pos * strideI;
+                rem %= strideO;
+            }
+            if (flag) {
+                optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
+            } else {
+                optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
+            }
+        }
+    }
+
+    void launchPad(
+        KernelLaunchParameters const &params,
+        uint8_t const *src, uint8_t const *src_const,
+        DimInfo const *dims, void *output,
+        unsigned int rank,
+        unsigned int blockSize) {
+
+
+        padKernel<<<
+            params.gridSize,
+            params.blockSize,
+            0,
+            reinterpret_cast<cudaStream_t>(params.stream)>>>(
+            params.n,
+            src,
+            src_const,
+            dims,
+            reinterpret_cast<uint8_t *>(output),
+            rank,
+            blockSize);
+    }
+
+}// namespace refactor::kernel::cuda
diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cu b/src/04kernel/src/kernels/pad/cuda_kernel.cu
new file mode 100644
index 00000000..b246ab4f
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cuda_kernel.cu
@@ -0,0 +1,39 @@
+#include "cuda_kernel.hh"
+#include "kernel/cuda/pad.cuh"
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+
+namespace refactor::kernel {
+    using namespace runtime;
+
+    auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
+        thrust::host_vector<cuda::DimInfo> dims(info.dims.size());
+        std::transform(info.dims.begin(), info.dims.end(),
+                       dims.begin(),
+                       [](auto const &d) {
+                           return cuda::DimInfo{
+                               d.strideI,
+                               d.strideO,
+                               d.padS,
+                               d.dimI,
+                           };
+                       });
+        return [dims = thrust::device_vector<cuda::DimInfo>(dims),
+                params = cuda::ThreadsDistributer()(info.blockCount),
+                blockSize = info.blockSize,
+                value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
+            thrust::device_vector<uint8_t> defaultValue(blockSize, 0);
+            if (value != 0) {
+                auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
+                for (auto i : range0_(blockSize / value)) {
+                    cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice);
+                }
+            }
+            cuda::launchPad(params, src, defaultValue.data().get(), dims.data().get(), outputs[0],
+                            dims.size(),
+                            blockSize);
+        };
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.hh b/src/04kernel/src/kernels/pad/cuda_kernel.hh
index 9e8c48cf..b0f915a5 100644
--- a/src/04kernel/src/kernels/pad/cuda_kernel.hh
+++ b/src/04kernel/src/kernels/pad/cuda_kernel.hh
@@ -17,9 +17,9 @@ namespace refactor::kernel {
         size_t kernelTypeId() const noexcept final;
         std::string_view description() const noexcept final;
 
-        // #ifdef USE_CUDA
-        // RoutineWorkspace lower(Resources &) const noexcept final;
-        // #endif
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
     };
 
 }// namespace refactor::kernel
diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp
new file mode 100644
index 00000000..4c490cd8
--- /dev/null
+++ b/src/04kernel/test/kernels/pad/test_cuda.cpp
@@ -0,0 +1,66 @@
+#ifdef USE_CUDA
+
+#include "../../../src/kernels/pad/cpu_kernel.hh"
+#include "../../../src/kernels/pad/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, PadCuda) {
+    PadDimension dims{
+        {2, 4, 1},
+        {3, 5, 1},
+        {1, 1, 0},
+        {4, 8, 2},
+    };
+    // build routine
+    auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
+    auto t2Tensor = Tensor::share(DataType::I64, Shape{8});
+    auto t3Tensor = Tensor::share(DataType::F32, Shape{});
+    auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8});
+    PadType type = PadType::Constant;
+    auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+    auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+    ASSERT_TRUE(kernel && kCpu);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine,
+         rCpu = kCpu->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+    auto gpuIn = dev.malloc(t1Tensor->bytesSize()),
+         gpuIn2 = dev.malloc(t2Tensor->bytesSize()),
+         gpuIn3 = dev.malloc(t3Tensor->bytesSize()),
+         gpuOut = dev.malloc(yTensor->bytesSize());
+    // put input data
+    std::vector<float> data(t1Tensor->elementsSize()),
+        constvalue(1, 1.2f),
+        cpuOut(yTensor->elementsSize());
+    std::vector<int64_t> pads{1, 1, 0, 2, 1, 1, 0, 2};
+    for (auto i : range0_(data.size())) { data[i] = i; }
+    gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize());
+    gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize());
+    gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3};
+        void *outputs[]{*gpuOut};
+        routine(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{data.data(), pads.data(), constvalue.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // take output data
+    std::vector<float> result(yTensor->elementsSize());
+    gpuOut->copyToHost(result.data(), yTensor->bytesSize());
+    // check
+    for (auto i : range0_(cpuOut.size())) {
+        EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+    }
+}
+
+#endif
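
Reviewer note (illustration only, not part of the diff): padKernel treats each output element as blockSize raw bytes and, per element, walks the per-axis DimInfo table to decide whether those bytes come from the input tensor or from the constant pad value. The host-side sketch below mirrors that index mapping; the struct name Dim and the helper padSourceIndex are hypothetical names introduced here for illustration, and strideI/strideO are assumed to be the usual row-major strides in units of element blocks, with padS the number of leading pads on each axis.

#include <cstdint>
#include <cstdio>
#include <vector>

struct Dim {
    unsigned int strideI, strideO, padS, dimI;// mirrors cuda::DimInfo
};

// Map an output element index to the corresponding input element index,
// or return -1 when the element lies in the padded border.
inline long long padSourceIndex(long long tid, std::vector<Dim> const &dims) {
    long long rem = tid, j = 0;
    for (auto const &d : dims) {
        auto pos = rem / d.strideO - d.padS;// coordinate on this axis, shifted back by the leading pad
        if (pos < 0 || pos >= static_cast<long long>(d.dimI)) {
            return -1;// outside the original tensor: take the constant value
        }
        j += pos * d.strideI;// accumulate the input linear index
        rem %= d.strideO;// continue with the remaining axes
    }
    return j;
}

int main() {
    // 1-D example: an input of length 4 padded to 8 with 2 leading pads
    // (the last axis of the test above).
    std::vector<Dim> dims{{1, 1, 2, 4}};
    for (long long tid = 0; tid < 8; ++tid) {
        std::printf("out[%lld] <- %lld\n", tid, padSourceIndex(tid, dims));
    }// prints -1, -1, 0, 1, 2, 3, -1, -1
}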