Skip to content

Commit

Permalink
feat(kernel): 实现 expand info 的构造和重整
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Nov 14, 2023
1 parent 0c1be80 commit e340329
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 9 deletions.
30 changes: 30 additions & 0 deletions src/04kernel/include/kernel/attributes/expand_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef KERNEL_EXPAND_INFO_H
#define KERNEL_EXPAND_INFO_H

#include "../tensor.h"
#include <vector>

namespace refactor::kernel {

/// @brief Optimized description of a unidirectional broadcast ("expand")
///        for kernel computation: both shapes are collapsed to a minimal
///        list of stride pairs plus one contiguous block.
struct ExpandInfo {
    /// @brief Strides of one collapsed dimension: `i` steps the input,
    ///        `o` steps the output. `i == 0` marks a broadcast run
    ///        (the input value is repeated along the output).
    struct Dim {
        dim_t i, o;

        bool operator==(Dim const &) const noexcept;
        bool operator!=(Dim const &) const noexcept;
    };

    /// @brief Strides of every collapsed dimension for input and output,
    ///        outermost first; the contiguous innermost run is folded
    ///        into `blockSize` instead of being listed here.
    std::vector<Dim> strides;
    // blockCount: number of contiguous blocks in the output.
    // blockSize: bytes per block (element size times fused inner extent).
    dim_t blockCount, blockSize;

    // Direct member-wise assembly (used by `reform`).
    ExpandInfo(std::vector<Dim>, dim_t, dim_t) noexcept;
    // Builds the description from `input` broadcast to `output`.
    ExpandInfo(Tensor const &input, Tensor const &output) noexcept;
    // Returns a copy re-split so blockSize becomes gcd(blockSize, maxblockSize).
    ExpandInfo reform(dim_t maxblockSize) const noexcept;
    // In-place variant of `reform`.
    void reformAssign(dim_t maxblockSize) noexcept;
};

}// namespace refactor::kernel

#endif// KERNEL_EXPAND_INFO_H
96 changes: 96 additions & 0 deletions src/04kernel/src/attributes/expand_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#include "kernel/attributes/expand_info.h"
#include <numeric>

namespace refactor::kernel {

/// @brief Two collapsed dimensions are equal iff both strides match.
bool ExpandInfo::Dim::operator==(Dim const &rhs) const noexcept {
    return rhs.i == i && rhs.o == o;
}
/// @brief Negation of `operator==`.
bool ExpandInfo::Dim::operator!=(Dim const &rhs) const noexcept {
    return !(*this == rhs);
}

/// @brief Assembles an ExpandInfo directly from precomputed members;
///        used by `reform` to build the re-split copy without going
///        through shape analysis again.
/// @param strides_    collapsed per-dimension strides, outermost first.
/// @param blockCount_ number of contiguous output blocks.
/// @param blockSize_  bytes per block.
ExpandInfo::ExpandInfo(
    std::vector<Dim> strides_,
    dim_t blockCount_,
    dim_t blockSize_) noexcept
    : strides(std::move(strides_)),
      blockCount(blockCount_),
      blockSize(blockSize_) {}

/// @brief Builds the expand description for `input` unidirectionally
///        broadcast to `output`.
///
/// Walks both shapes from the trailing (innermost) dimension, fusing
/// adjacent contiguous dimensions and adjacent broadcast dimensions so
/// `strides` stays minimal; the innermost contiguous run common to both
/// tensors is folded into `blockSize` instead of being recorded.
/// @param input  source tensor; rank must not exceed `output`'s.
/// @param output broadcast target tensor.
ExpandInfo::ExpandInfo(
    Tensor const &input,
    Tensor const &output) noexcept
    : strides{{1, 1}},
      blockCount(1),
      blockSize(input.dataType.size()) {
    ASSERT(input.rank() <= output.rank(), "Unreachable");
    auto i = input.shape.rbegin(),
         end = input.shape.rend(),
         o = output.shape.rbegin();
    dim_t stride = 1;
    while (i != end) {
        auto i_ = *i++,
             o_ = *o++;
        // Size-1 output dimensions contribute nothing.
        if (o_ == 1) { continue; }
        if (auto &it = strides.back(); i_ == 1) {
            // Broadcast dimension (input extent 1): extend the current
            // broadcast run (it.i == 0) or open a new one.
            if (it.i == 0) {
                it.o *= o_;
            } else {
                strides.push_back({0, it.o * o_});
            }
        } else {
            // Real dimension: extend the current contiguous run, or open
            // one right after a broadcast run.
            stride *= i_;
            if (it.i == 0) {
                strides.push_back({stride, it.o * o_});
            } else {
                it.i = stride;
                it.o *= o_;
            }
        }
    }
    // The innermost entry holds the contiguous element count shared by
    // input and output; it is folded into the block below.
    ASSERT(strides[0].i == strides[0].o, "Unreachable");
    auto elementCount = strides[0].i;

    // Outermost dimension first, and drop the innermost (fused) entry.
    std::reverse(strides.begin(), strides.end());
    strides.pop_back();
    strides.shrink_to_fit();

    blockSize *= elementCount;
    for (auto &s : strides) {
        s.i /= elementCount;
        s.o /= elementCount;
    }
    // Total blocks = outermost recorded output stride times the extents
    // of the output dimensions beyond the input's rank. When every
    // dimension fused into the block (no broadcast dims survived),
    // `strides` is empty and indexing `strides[0]` would be out of
    // bounds — seed the product with 1 instead.
    blockCount = std::accumulate(
        o, output.shape.rend(),
        strides.empty() ? dim_t(1) : strides[0].o,
        std::multiplies());
}

/// @brief Returns a copy whose block size is shrunk to
///        `gcd(blockSize, maxblockSize)`, pushing the split-off factor
///        into a new contiguous innermost {1, 1} dimension.
/// @param maxblockSize upper bound (in bytes) the caller can process per block.
/// @return the re-split description, or `*this` when no split is needed.
ExpandInfo ExpandInfo::reform(dim_t maxblockSize) const noexcept {
    auto blockSize_ = std::gcd(blockSize, maxblockSize);
    if (blockSize_ == blockSize) { return *this; }
    auto times = blockSize / blockSize_;
    ExpandInfo ans(
        std::vector<Dim>(strides.size() + 1),
        blockCount * times,
        blockSize_);
    for (auto i : range0_(strides.size())) {
        // Copy the existing strides rescaled by the split factor.
        // (The previous code multiplied the value-initialized zeros with
        // `*= times`, wiping out every stride; mirror `reformAssign`.)
        ans.strides[i].i = strides[i].i * times;
        ans.strides[i].o = strides[i].o * times;
    }
    // New innermost dimension that walks the split-off blocks.
    ans.strides.back() = {1, 1};
    return ans;
}
/// @brief In-place variant of `reform`: shrinks the block size to
///        `gcd(blockSize, maxblockSize)` and appends a {1, 1} innermost
///        dimension that walks the split-off factor.
/// @param maxblockSize upper bound (in bytes) the caller can process per block.
void ExpandInfo::reformAssign(dim_t maxblockSize) noexcept {
    auto newBlockSize = std::gcd(blockSize, maxblockSize);
    if (newBlockSize == blockSize) { return; }
    auto times = blockSize / newBlockSize;
    blockSize = newBlockSize;
    blockCount *= times;
    // Every existing stride now counts in units of the smaller block.
    for (auto &dim : strides) {
        dim.i *= times;
        dim.o *= times;
    }
    strides.push_back({1, 1});
}

}// namespace refactor::kernel
6 changes: 3 additions & 3 deletions src/04kernel/src/attributes/gather_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ namespace refactor::kernel {
: prefix(0), postfix(0), midSizeI(0), midSizeO(0), idxType(indices.dataType) {

auto axisIt = data.shape.begin() + axis;
prefix = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies<>());
prefix = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies());
midSizeI = *axisIt++;
postfix = std::accumulate(axisIt, data.shape.end(), data.dataType.size(), std::multiplies<>());
midSizeO = std::accumulate(indices.shape.begin(), indices.shape.end(), 1, std::multiplies<>());
postfix = std::accumulate(axisIt, data.shape.end(), data.dataType.size(), std::multiplies());
midSizeO = std::accumulate(indices.shape.begin(), indices.shape.end(), 1, std::multiplies());
}

}// namespace refactor::kernel
4 changes: 2 additions & 2 deletions src/04kernel/src/attributes/softmax_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace refactor::kernel {
: pre(0), mid(0), post(0), type(data.dataType) {

auto axisIt = data.shape.begin() + axis;
pre = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies<>());
pre = std::accumulate(data.shape.begin(), axisIt, 1, std::multiplies());
mid = *axisIt++;
post = std::accumulate(axisIt, data.shape.end(), 1, std::multiplies<>());
post = std::accumulate(axisIt, data.shape.end(), 1, std::multiplies());
};

}// namespace refactor::kernel
4 changes: 2 additions & 2 deletions src/04kernel/src/attributes/split_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace refactor::kernel {
auto eleSize = outputs[0].get().dataType.size();
auto const &shape = outputs[0].get().shape;
auto axisIt = shape.begin() + axis;
blockCount = std::accumulate(shape.begin(), axisIt, 1, std::multiplies<>());
auto postfix = std::accumulate(++axisIt, shape.end(), eleSize, std::multiplies<>());
blockCount = std::accumulate(shape.begin(), axisIt, 1, std::multiplies());
auto postfix = std::accumulate(++axisIt, shape.end(), eleSize, std::multiplies());
sum *= postfix;
std::transform(outputs.begin(), outputs.end(),
segments.begin(),
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/batch_normalization/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace refactor::kernel {

auto n = shape[0],
c = shape[1],
dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies<>()),
dims = std::accumulate(shape.begin() + 2, shape.end(), 1u, std::multiplies()),
sn = c * dims,
sc = dims;
return [n, c, sn, sc, epsilon](Resources &, void const **inputs, void **outputs) {
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace refactor::kernel {
}

int64_t Tensor::rank() const { return shape.size(); }
size_t Tensor::elementsSize() const { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); }
size_t Tensor::elementsSize() const { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); }
size_t Tensor::bytesSize() const { return dataType.size() * elementsSize(); }

Strides Tensor::strides() const {
Expand Down
14 changes: 14 additions & 0 deletions src/04kernel/test/attributes/test_expand_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "kernel/attributes/expand_info.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;

// Input {3, 4, 1, 6} broadcast to output {2, 3, 4, 5, 6}:
// - trailing 6 is contiguous in both tensors -> folded into the block,
//   blockSize = 6 * sizeof(float) = 24 bytes;
// - {3, 4} fuse to one real dimension -> stride pair {12, 60};
// - the size-1 input dim against output's 5 -> broadcast pair {0, 5};
// - blockCount = output elements / 6 = (2*3*4*5*6) / 6 = 120.
TEST(kernel, ExpandInfo) {
    auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}),
         output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6});
    ExpandInfo info(*input, *output);
    EXPECT_EQ(info.blockSize, 24);
    EXPECT_EQ(info.blockCount, 120);
    EXPECT_EQ(info.strides, (std::vector<ExpandInfo::Dim>{{12, 60}, {0, 5}}));
}

0 comments on commit e340329

Please sign in to comment.