feat(kernel): implement expand cpu kernel
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Nov 14, 2023
1 parent 0c1be80 commit bf3f8b0
Showing 15 changed files with 401 additions and 9 deletions.
30 changes: 30 additions & 0 deletions src/04kernel/include/kernel/attributes/expand_info.h
@@ -0,0 +1,30 @@
#ifndef KERNEL_EXPAND_INFO_H
#define KERNEL_EXPAND_INFO_H

#include "../tensor.h"
#include <vector>

namespace refactor::kernel {

    /// @brief An optimized description of unidirectional broadcasting for computation.
    struct ExpandInfo {
        struct Dim {
            dim_t i, o;

            bool operator==(Dim const &) const noexcept;
            bool operator!=(Dim const &) const noexcept;
        };

        /// @brief Per-dimension strides for all inputs and outputs.
        std::vector<Dim> strides;
        dim_t blockCount, blockSize;

        ExpandInfo(std::vector<Dim>, dim_t, dim_t) noexcept;
        ExpandInfo(Tensor const &input, Tensor const &output) noexcept;
        ExpandInfo reform(dim_t maxblockSize) const noexcept;
        void reformAssign(dim_t maxblockSize) noexcept;
    };

}// namespace refactor::kernel

#endif// KERNEL_EXPAND_INFO_H
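Note: the {i, o} pairs act as a mixed-radix decoder. An output block index is successively reduced by the output strides o, and the quotients accumulate against the input strides i, which are zero for broadcast segments. A minimal standalone sketch of this mapping (hypothetical illustration only, hard-coding the stride table that the unit test later in this commit derives for expanding {3, 4, 1, 6} to {2, 3, 4, 5, 6}):

#include <cstdio>
#include <vector>

// Hypothetical standalone mirror of the index mapping used by ExpandCpu::lower.
// Values from the test case: blockCount = 120, strides = {{0, 60}, {1, 5}, {0, 1}}.
int main() {
    struct Dim { long i, o; };
    std::vector<Dim> strides{{0, 60}, {1, 5}, {0, 1}};
    for (long blk : {0L, 5L, 59L, 60L}) {
        long rem = blk, j = 0;
        for (auto const &s : strides) {
            if (s.i) { j += rem / s.o * s.i; }// copied segment contributes to the source offset
            rem %= s.o;                       // either way, keep only the remainder
        }
        std::printf("output block %ld <- input block %ld\n", blk, j);
    }
}

Output block 59 decodes to (0, 11, 4) and reads input block 11; block 60 wraps to (1, 0, 0) and reads input block 0.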
91 changes: 91 additions & 0 deletions src/04kernel/src/attributes/expand_info.cc
@@ -0,0 +1,91 @@
#include "kernel/attributes/expand_info.h"
#include <numeric>

namespace refactor::kernel {

bool ExpandInfo::Dim::operator==(Dim const &rhs) const noexcept {
return i == rhs.i && o == rhs.o;
}
bool ExpandInfo::Dim::operator!=(Dim const &rhs) const noexcept {
return !operator==(rhs);
}

ExpandInfo::ExpandInfo(
std::vector<Dim> strides_,
dim_t blockCount_,
dim_t blockSize_) noexcept
: strides(std::move(strides_)),
blockCount(blockCount_),
blockSize(blockSize_) {}

ExpandInfo::ExpandInfo(
Tensor const &input,
Tensor const &output) noexcept
: strides{{1, 1}},
blockCount(1),
blockSize(input.dataType.size()) {
ASSERT(input.rank() <= output.rank(), "Unreachable");
auto i = input.shape.rbegin(),
ei = input.shape.rend(),
o = output.shape.rbegin(),
eo = output.shape.rend();
dim_t stride = 1;
while (o != eo) {
auto i_ = i == ei ? 1 : *i++,
o_ = *o++;
if (o_ == 1) { continue; }
if (auto &it = strides.back(); i_ == 1) {
if (it.i != 0) {
strides.push_back({0, blockCount});
}
} else {
if (it.i == 0) {
strides.push_back({stride, blockCount});
}
stride *= i_;
}
blockCount *= o_;
}
if (strides.size() == 1) {
// 没有发生广播
blockSize *= blockCount;
blockCount = 1;
strides = {};
return;
}
std::reverse(strides.begin(), strides.end());
strides.pop_back();

auto tail = strides.back();
ASSERT(tail.i == 0, "Unreachable");

blockSize *= tail.o;
blockCount /= tail.o;
for (auto &s : strides) {
s.i /= tail.o;
s.o /= tail.o;
}
}

ExpandInfo ExpandInfo::reform(dim_t maxblockSize) const noexcept {
auto ans = *this;
ans.reformAssign(maxblockSize);
return ans;
}
void ExpandInfo::reformAssign(dim_t maxblockSize) noexcept {
auto blockSize_ = std::gcd(blockSize, maxblockSize);
if (blockSize_ == blockSize) { return; }
auto times = blockSize / blockSize_;
blockCount *= times;
blockSize = blockSize_;
if (!strides.empty()) {
for (auto &s : strides) {
s.i *= times;
s.o *= times;
}
strides.resize(strides.size() + 1);
strides.back() = {1, 1};
}
}

}// namespace refactor::kernel
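Tracing this constructor by hand on the shapes from the unit test (F32, {3, 4, 1, 6} expanded to {2, 3, 4, 5, 6}): scanning from the innermost dimension, 6↔6 is contiguous and folds into the block, 5↔1 opens a broadcast segment, 4↔4 and 3↔3 merge into a single copied segment of 12 blocks, and the leading 2 (input exhausted) broadcasts again. After reversing, dropping the {1, 1} seed entry, and dividing everything by the contiguous tail's extent 6, the result is blockSize = 4 × 6 = 24 bytes, blockCount = 120, and strides = {{0, 60}, {1, 5}, {0, 1}} — exactly what test_expand_info.cpp asserts.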
6 changes: 3 additions & 3 deletions src/04kernel/src/attributes/gather_info.cc
@@ -20,10 +20,10 @@ namespace refactor::kernel {
        : prefix(0), postfix(0), midSizeI(0), midSizeO(0), idxType(indices.dataType) {

        auto axisIt = data.shape.begin() + axis;
-       prefix = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies<>());
+       prefix = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies());
        midSizeI = *axisIt++;
-       postfix = std::accumulate(axisIt, data.shape.end(), data.dataType.size(), std::multiplies<>());
-       midSizeO = std::accumulate(indices.shape.begin(), indices.shape.end(), 1, std::multiplies<>());
+       postfix = std::accumulate(axisIt, data.shape.end(), data.dataType.size(), std::multiplies());
+       midSizeO = std::accumulate(indices.shape.begin(), indices.shape.end(), 1, std::multiplies());
    }

}// namespace refactor::kernel
4 changes: 2 additions & 2 deletions src/04kernel/src/attributes/softmax_info.cc
@@ -7,9 +7,9 @@ namespace refactor::kernel {
        : pre(0), mid(0), post(0), type(data.dataType) {

        auto axisIt = data.shape.begin() + axis;
-       pre = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies<>());
+       pre = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies());
        mid = *axisIt++;
-       post = std::accumulate(axisIt, data.shape.end(), 1, std::multiplies<>());
+       post = std::accumulate(axisIt, data.shape.end(), 1, std::multiplies());
    };

}// namespace refactor::kernel
4 changes: 2 additions & 2 deletions src/04kernel/src/attributes/split_info.cc
@@ -13,8 +13,8 @@ namespace refactor::kernel {
        auto eleSize = outputs[0].get().dataType.size();
        auto const &shape = outputs[0].get().shape;
        auto axisIt = shape.begin() + axis;
-       blockCount = std::accumulate(shape.begin(), axisIt, 1, std::multiplies<>());
-       auto postfix = std::accumulate(++axisIt, shape.end(), eleSize, std::multiplies<>());
+       blockCount = std::accumulate(shape.begin(), axisIt, 1, std::multiplies());
+       auto postfix = std::accumulate(++axisIt, shape.end(), eleSize, std::multiplies());
        sum *= postfix;
        std::transform(outputs.begin(), outputs.end(),
                       segments.begin(),
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc
@@ -46,7 +46,7 @@ namespace refactor::kernel {

        auto n = shape[0],
             c = shape[1],
-            dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies<>()),
+            dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies()),
             sn = c * dims,
             sc = dims;
        return [n, c, sn, sc, epsilon](Resources &, void const **inputs, void **outputs) {
48 changes: 48 additions & 0 deletions src/04kernel/src/kernels/expand/cpu_kernel.cc
@@ -0,0 +1,48 @@
#include "cpu_kernel.hh"
#include <execution>

namespace refactor::kernel {
using K = ExpandCpu;

K::ExpandCpu(ExpandInfo info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(ExpandInfo info) noexcept -> KernelBox {
return std::make_unique<K>(std::move(info));
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing expand operation on generic cpu";
}

Routine K::lower(Resources &) const noexcept {
using namespace runtime;
return [info = this->info](Resources &, void const **inputs, void **outputs) {
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), info.blockCount,
[=, &info](auto i) {
long rem = i, j = 0;
for (auto const &s : info.strides) {
if (s.i) {
auto d = std::div(rem, s.o);
j += d.quot * s.i;
rem = d.rem;
} else {
rem %= s.o;
}
}
std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize);
});
};
}

}// namespace refactor::kernel
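Note: each output block costs one integer division per collapsed segment plus a single memcpy, and all blocks are written independently, which is what makes the std::execution::par_unseq policy safe here.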
24 changes: 24 additions & 0 deletions src/04kernel/src/kernels/expand/cpu_kernel.hh
@@ -0,0 +1,24 @@
#ifndef KERNEL_EXPAND_CPU_KERNEL_HH
#define KERNEL_EXPAND_CPU_KERNEL_HH

#include "kernel/attributes/expand_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

    struct ExpandCpu final : public Kernel {
        ExpandInfo info;

        explicit ExpandCpu(ExpandInfo) noexcept;

        static KernelBox build(ExpandInfo) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
        Routine lower(Resources &) const noexcept final;
    };

}// namespace refactor::kernel

#endif// KERNEL_EXPAND_CPU_KERNEL_HH
27 changes: 27 additions & 0 deletions src/04kernel/src/kernels/expand/cuda_kernel.cc
@@ -0,0 +1,27 @@
#include "cuda_kernel.hh"

namespace refactor::kernel {
using K = ExpandCuda;

K::ExpandCuda(ExpandInfo info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(ExpandInfo info) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif
return std::make_unique<K>(std::move(info));
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing expand operation using CUDA";
}

}// namespace refactor::kernel
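Note: build() compiles the make_unique line unconditionally, but without USE_CUDA the early return hands back an empty KernelBox, so kernel selection simply skips the CUDA variant.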
12 changes: 12 additions & 0 deletions src/04kernel/src/kernels/expand/cuda_kernel.cu
@@ -0,0 +1,12 @@
#include "cuda_kernel.hh"
#include <thrust/device_vector.h>

namespace refactor::kernel {
using namespace runtime;

Routine ExpandCuda::lower(Resources &) const noexcept {
return [](Resources &res, void const **inputs, void **outputs) {
};
}

}// namespace refactor::kernel
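lower() is left as an empty stub here; only the <thrust/device_vector.h> include hints at the intended direction. Purely as a hypothetical sketch (none of the following is the commit's code — the kernel name, launch configuration, and trailing synchronization are all assumptions), the CPU routine's decode loop could be mirrored on the device like this:

// Hypothetical sketch, not part of this commit: a grid-stride CUDA kernel
// mirroring ExpandCpu::lower's block-index decode.
static __global__ void expandKernel(
    uint8_t const *src, uint8_t *dst,
    ExpandInfo::Dim const *strides, unsigned rank,
    unsigned blockSize, unsigned blockCount) {
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         tid < blockCount;
         tid += blockDim.x * gridDim.x) {
        long rem = tid, j = 0;
        for (unsigned k = 0; k < rank; ++k) {
            if (strides[k].i) { j += rem / strides[k].o * strides[k].i; }
            rem %= strides[k].o;
        }
        memcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
    }
}

Routine ExpandCuda::lower(Resources &) const noexcept {
    return [info = this->info](Resources &, void const **inputs, void **outputs) {
        // Stage the collapsed stride table on the device, launch, then wait
        // so that `strides` outlives the kernel.
        thrust::device_vector<ExpandInfo::Dim> strides(info.strides.begin(), info.strides.end());
        expandKernel<<<(info.blockCount + 255) / 256, 256>>>(
            reinterpret_cast<uint8_t const *>(inputs[0]),
            reinterpret_cast<uint8_t *>(outputs[0]),
            thrust::raw_pointer_cast(strides.data()),
            strides.size(), info.blockSize, info.blockCount);
        cudaDeviceSynchronize();
    };
}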
26 changes: 26 additions & 0 deletions src/04kernel/src/kernels/expand/cuda_kernel.hh
@@ -0,0 +1,26 @@
#ifndef KERNEL_EXPAND_CUDA_KERNEL_HH
#define KERNEL_EXPAND_CUDA_KERNEL_HH

#include "kernel/attributes/expand_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

    struct ExpandCuda final : public Kernel {
        ExpandInfo info;

        explicit ExpandCuda(ExpandInfo) noexcept;

        static KernelBox build(ExpandInfo) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        Routine lower(Resources &) const noexcept final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_EXPAND_CUDA_KERNEL_HH
2 changes: 1 addition & 1 deletion src/04kernel/src/tensor.cc
@@ -20,7 +20,7 @@ namespace refactor::kernel {
    }

    int64_t Tensor::rank() const { return shape.size(); }
-   size_t Tensor::elementsSize() const { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); }
+   size_t Tensor::elementsSize() const { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); }
    size_t Tensor::bytesSize() const { return dataType.size() * elementsSize(); }

    Strides Tensor::strides() const {
22 changes: 22 additions & 0 deletions src/04kernel/test/attributes/test_expand_info.cpp
@@ -0,0 +1,22 @@
#include "kernel/attributes/expand_info.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;

TEST(kernel, ExpandInfo) {
auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}),
output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6});
ExpandInfo info(*input, *output);
for (auto s : info.strides) {
fmt::print("({} {}) ", s.i, s.o);
}
EXPECT_EQ(info.blockSize, 24);
EXPECT_EQ(info.blockCount, 120);
EXPECT_EQ(info.strides, (std::vector<ExpandInfo::Dim>{{0, 60}, {1, 5}, {0, 1}}));

auto reformed = info.reform(16);
EXPECT_EQ(reformed.blockSize, 8);
EXPECT_EQ(reformed.blockCount, 360);
EXPECT_EQ(reformed.strides, (std::vector<ExpandInfo::Dim>{{0, 180}, {3, 15}, {0, 3}, {1, 1}}));
}
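The reform(16) expectations follow directly from reformAssign: gcd(24, 16) = 8 becomes the new blockSize, times = 24 / 8 = 3 scales blockCount to 360 and every stride pair by 3, and a unit {1, 1} pair is appended for the now-subdivided contiguous tail.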
52 changes: 52 additions & 0 deletions src/04kernel/test/kernels/expand/test_cpu.cpp
@@ -0,0 +1,52 @@
#include "../../../src/kernels/expand/cpu_kernel.hh"
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;

TEST(kernel, ExpandCpu) {
// // build routine
auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}),
output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6});
auto kernel = ExpandCpu::build(ExpandInfo(*input, *output));
ASSERT_TRUE(kernel);
auto res = runtime::Resources();
auto routine = kernel->lower(res);
// put input data
std::vector<float>
data(input->elementsSize()),
result(output->elementsSize());
std::iota(data.begin(), data.end(), 0);
// inference
{
void const *inputs[]{data.data()};
void *outputs[]{result.data()};
routine(res, inputs, outputs);
}
// check
{
auto idx = 0;
for (auto i : range0_(2)) {
for (auto j : range0_(12)) {
for (auto k : range0_(5)) {
for (auto m : range0_(6)) {
ASSERT_EQ(result[idx++], j * 6 + m);
}
}
}
}
}
// test reform
auto kernelReformed = ExpandCpu::build(ExpandInfo(*input, *output).reform(16));
ASSERT_TRUE(kernelReformed);
auto routineReformed = kernelReformed->lower(res);
std::vector<float> resultReformed(result.size());
{
void const *inputs[]{data.data()};
void *outputs[]{resultReformed.data()};
routineReformed(res, inputs, outputs);
}
// check
ASSERT_EQ(result, resultReformed);
}
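The nested check loops mirror the collapsed layout: the outer 2 and the inner 5 are broadcast extents, the 12 indexes the distinct input blocks, and the innermost 6 walks the floats within one block, so every output element must equal j * 6 + m.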