From 0951b5b356f060b6512bda980ea83616193f3b19 Mon Sep 17 00:00:00 2001
From: YdrMaster <ydrml@hotmail.com>
Date: Mon, 13 Nov 2023 13:45:33 +0800
Subject: [PATCH] =?UTF-8?q?feat(kernel):=20=E5=AE=9E=E7=8E=B0=20slice=20?=
 =?UTF-8?q?=E7=9A=84=E9=87=8D=E6=95=B4=E5=92=8C=20cuda=20kernel=20?=
 =?UTF-8?q?=E8=B0=83=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster <ydrml@hotmail.com>
---
 .../cuda/include/kernel/cuda/slice.cuh        | 20 ++++++++
 src/04kernel/cuda/src/slice.cu                | 28 +++++++++++
 .../include/kernel/attributes/slice_info.h    |  3 ++
 src/04kernel/src/attributes/slice_info.cc     | 47 +++++++++++++++++++
 src/04kernel/src/kernels/slice/cuda_kernel.cu | 21 +++++++--
 .../test/attributes/test_slice_info.cpp       | 14 ++++++
 src/04kernel/test/kernels/slice/test_cpu.cpp  | 20 ++++++--
 7 files changed, 146 insertions(+), 7 deletions(-)
 create mode 100644 src/04kernel/cuda/include/kernel/cuda/slice.cuh
 create mode 100644 src/04kernel/cuda/src/slice.cu

diff --git a/src/04kernel/cuda/include/kernel/cuda/slice.cuh b/src/04kernel/cuda/include/kernel/cuda/slice.cuh
new file mode 100644
index 00000000..235c5dc2
--- /dev/null
+++ b/src/04kernel/cuda/include/kernel/cuda/slice.cuh
@@ -0,0 +1,20 @@
+#ifndef KERNEL_CUDA_SLICE_CUH
+#define KERNEL_CUDA_SLICE_CUH
+
+#include "threads_distributer.cuh"
+
+namespace refactor::kernel::cuda {
+
+    struct DimInfo {// one entry per slice dimension; an array of these is handed to the slice kernel
+        unsigned int countStride, sizeStart;// NOTE(review): presumably an index divisor and a start offset - confirm against SliceInfo::Dim
+        int sizeStride;// signed, unlike the two fields above
+    };
+
+    void launchSlice(// host-side entry point that launches the slice kernel
+        KernelLaunchParameters const &,// launch configuration; presumably declared in threads_distributer.cuh
+        void const *src, DimInfo const *dims, void *output,// NOTE(review): presumably all three must be device-accessible - confirm
+        unsigned int blockSize);// forwarded to the kernel as an argument; separate from the launch block size in the parameters
+
+}// namespace refactor::kernel::cuda
+
+#endif// KERNEL_CUDA_SLICE_CUH
diff --git a/src/04kernel/cuda/src/slice.cu b/src/04kernel/cuda/src/slice.cu
new file mode 100644
index 00000000..1206bf62
--- /dev/null
+++ b/src/04kernel/cuda/src/slice.cu
@@ -0,0 +1,28 @@
+#include "kernel/cuda/slice.cuh"
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    __global__ static void sliceKernel(// placeholder kernel: the body is empty, so launching it copies nothing
+        unsigned long long n,// receives params.n from launchSlice
+        uint8_t const *src, DimInfo const *dims, uint8_t *output,
+        unsigned int blockSize) {
+    }// TODO(review): the actual slicing logic is not implemented in this patch
+
+    void launchSlice(// launches sliceKernel with the configuration carried by params
+        KernelLaunchParameters const &params,
+        void const *src, DimInfo const *dims, void *output,
+        unsigned int blockSize) {
+        sliceKernel<<<
+            params.gridSize,
+            params.blockSize,
+            params.dynamicSharedBytes,
+            reinterpret_cast<cudaStream_t>(params.stream)>>>(// presumably params.stream is an opaque handle - confirm its type
+            params.n,
+            reinterpret_cast<uint8_t const *>(src),// untyped buffers are passed to the kernel as byte pointers
+            dims,
+            reinterpret_cast<uint8_t *>(output),
+            blockSize);// NOTE(review): launch status is not checked here (no cudaGetLastError)
+    }
+
+}// namespace refactor::kernel::cuda
diff --git a/src/04kernel/include/kernel/attributes/slice_info.h b/src/04kernel/include/kernel/attributes/slice_info.h
index e8c919f3..b0fbb11e 100644
--- a/src/04kernel/include/kernel/attributes/slice_info.h
+++ b/src/04kernel/include/kernel/attributes/slice_info.h
@@ -24,7 +24,10 @@ namespace refactor::kernel {
         std::vector<Dim> dims;
         dim_t blockCount, blockSize, baseOffset;
 
+        SliceInfo(std::vector<Dim>, dim_t, dim_t, dim_t) noexcept;
         SliceInfo(Dimensions const &, Tensor const &) noexcept;
+        SliceInfo reform(dim_t maxblockSize) const noexcept;
+        void reformAssign(dim_t maxblockSize) noexcept;
     };
 
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/attributes/slice_info.cc b/src/04kernel/src/attributes/slice_info.cc
index 83ed40c9..1c46302b 100644
--- a/src/04kernel/src/attributes/slice_info.cc
+++ b/src/04kernel/src/attributes/slice_info.cc
@@ -1,4 +1,5 @@
 #include "kernel/attributes/slice_info.h"
+#include <numeric>
 
 namespace refactor::kernel {
 
@@ -11,6 +12,16 @@ namespace refactor::kernel {
         return !operator==(rhs);
     }
 
+    SliceInfo::SliceInfo(// member-wise constructor: stores the four fields exactly as given
+        std::vector<Dim> dims_,
+        dim_t blockCount_,
+        dim_t blockSize_,
+        dim_t baseOffset_) noexcept
+        : dims(std::move(dims_)),// listed in declaration order (dims is declared first) to avoid -Wreorder
+          blockCount(blockCount_),
+          blockSize(blockSize_),
+          baseOffset(baseOffset_) {}
+
     SliceInfo::SliceInfo(Dimensions const &dims_, Tensor const &input) noexcept
         : blockCount(1),
           blockSize(input.dataType.size()),
@@ -53,4 +64,40 @@ namespace refactor::kernel {
         dims.shrink_to_fit();
     }
 
+    SliceInfo SliceInfo::reform(dim_t maxblockSize) const noexcept {// returns a copy re-expressed with a smaller block size
+        auto blockSize_ = std::gcd(blockSize, maxblockSize);// new block size divides both the old size and the limit
+        if (blockSize_ == blockSize) { return *this; }// old size already divides the limit: return an unchanged copy
+        auto times = blockSize / blockSize_;// number of new blocks that make up one old block
+        SliceInfo ans{
+            std::vector<Dim>(dims.size() + 1),// room for one extra trailing dimension
+            blockCount * times,
+            blockSize_,
+            baseOffset,
+        };
+        for (auto i : range0_(dims.size())) {
+            auto const &d = dims[i];
+            ans.dims[i] = {
+                d.countStride * times,// only the count stride is rescaled; the other two fields are copied
+                d.sizeStart,
+                d.sizeStride,
+            };
+        }
+        ans.dims.back() = {1, 0, static_cast<sdim_t>(blockSize_)};// appended entry: countStride 1, sizeStart 0, sizeStride = new block size
+        return ans;
+    }
+
+    void SliceInfo::reformAssign(dim_t maxblockSize) noexcept {// in-place counterpart of reform
+        auto blockSize_ = std::gcd(blockSize, maxblockSize);// new block size divides both the old size and the limit
+        if (blockSize_ == blockSize) { return; }// old size already divides the limit: leave the object untouched
+        auto times = blockSize / blockSize_;// number of new blocks that make up one old block
+        blockCount *= times;
+        blockSize = blockSize_;
+        for (auto &d : dims) {
+            d.countStride *= times;// only the count stride is rescaled
+        }
+        dims.resize(dims.size() + 1);
+        dims.back() = {1, 0, static_cast<sdim_t>(blockSize_)};// appended entry: countStride 1, sizeStart 0, sizeStride = new block size
+    }
+
+
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/slice/cuda_kernel.cu b/src/04kernel/src/kernels/slice/cuda_kernel.cu
index 27b6fff9..b5d4628d 100644
--- a/src/04kernel/src/kernels/slice/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/slice/cuda_kernel.cu
@@ -1,14 +1,27 @@
 #include "cuda_kernel.hh"
-#include "kernel/cuda/split.cuh"
-#include "mem_manager/foreign_blob.hh"
-#include "runtime/mem_manager.hh"
+#include "kernel/cuda/slice.cuh"
 #include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
 
 namespace refactor::kernel {
     using namespace runtime;
 
     Routine SliceCuda::lower(Resources &) const noexcept {
-        return [](Resources &, void const **inputs, void **outputs) {
+        auto reformed = info.reform(16);// re-express the slice so that the block size divides 16
+        thrust::host_vector<cuda::DimInfo> dims(reformed.dims.size());// must mirror reformed, not info: reform may rescale strides and append a dimension
+        std::transform(reformed.dims.begin(), reformed.dims.end(),
+                       dims.begin(),
+                       [](auto const &d) { return cuda::DimInfo{
+                                               d.countStride,
+                                               d.sizeStart,
+                                               d.sizeStride,
+                                           }; });
+        return [dims = thrust::device_vector<cuda::DimInfo>(dims),// copied into a device vector once, when the routine is created
+                params = cuda::ThreadsDistributer()(reformed.blockCount),
+                blockSize = reformed.blockSize,
+                baseOffset = reformed.baseOffset](Resources &, void const **inputs, void **outputs) {
+            auto src = reinterpret_cast<uint8_t const *>(inputs[0]) + baseOffset;// apply the base offset on a byte pointer
+            cuda::launchSlice(params, src, dims.data().get(), outputs[0], blockSize);
         };
     }
 
diff --git a/src/04kernel/test/attributes/test_slice_info.cpp b/src/04kernel/test/attributes/test_slice_info.cpp
index 2621d8dd..39f224eb 100644
--- a/src/04kernel/test/attributes/test_slice_info.cpp
+++ b/src/04kernel/test/attributes/test_slice_info.cpp
@@ -26,4 +26,18 @@ TEST(kernel, SliceInfo) {
               })
               // clang-format on
     );
+
+    auto reformed = info.reform(16);
+    EXPECT_EQ(reformed.blockCount, 36);
+    EXPECT_EQ(reformed.blockSize, 16);
+    EXPECT_EQ(reformed.baseOffset, 24);
+    EXPECT_EQ(reformed.dims,
+              // clang-format off
+              (decltype(reformed.dims){
+                  {48 / 24 * 6, 900 * 4, -360 * 4},
+                  {24 / 24 * 6,  60 * 4,   90 * 4},
+                  {          1,       0,       16},
+              })
+              // clang-format on
+    );
 }
diff --git a/src/04kernel/test/kernels/slice/test_cpu.cpp b/src/04kernel/test/kernels/slice/test_cpu.cpp
index e554c16f..9d4d471d 100644
--- a/src/04kernel/test/kernels/slice/test_cpu.cpp
+++ b/src/04kernel/test/kernels/slice/test_cpu.cpp
@@ -27,9 +27,11 @@ TEST(kernel, SliceCpu) {
         result(output->elementsSize());
     std::iota(data.begin(), data.end(), 0);
     // inference
-    void const *inputs[]{data.data()};
-    void *outputs[]{result.data()};
-    routine(res, inputs, outputs);
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{result.data()};
+        routine(res, inputs, outputs);
+    }
     // check
     dim_t
         di[]{5, 3, 1},
@@ -49,4 +51,16 @@ TEST(kernel, SliceCpu) {
             }
         }
     }
+    // test reform
+    auto kernelReformed = SliceCpu::build(SliceInfo(dims, *input).reform(16));
+    ASSERT_TRUE(kernelReformed);
+    auto routineReformed = kernelReformed->lower(res);
+    std::vector<float> resultReformed(result.size());
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{resultReformed.data()};
+        routineReformed(res, inputs, outputs);
+    }
+    // check
+    ASSERT_EQ(result, resultReformed);
 }