Skip to content

Commit

Permalink
Add topk operator and cpu kernel.
Browse files Browse the repository at this point in the history
wendy12022 committed Mar 9, 2024
1 parent a813939 commit de18a1b
Showing 12 changed files with 528 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/04kernel/include/kernel/attributes/topk_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef KERNEL_TOPK_INFO_H
#define KERNEL_TOPK_INFO_H

#include "../tensor.h"

namespace refactor::kernel {

// Precomputed geometry for a TopK kernel: strides and element counts that
// describe how the reduction axis is laid out inside the (row-major) input.
struct TopKInfo {

// Number of largest elements kept along `axis`.
// NOTE(review): uint8_t caps k at 255 while TopKCollector stores uint32_t —
// confirm callers never pass a larger k (it would be silently truncated).
uint8_t topk;
// Axis the selection runs along (0-based, already normalized).
uint8_t axis;
// in_stride: input stride of `axis`;
// in_stride_pre_axis: input stride of the axis before it (0 when axis == 0);
// out_stride_pre_axis: same "pre-axis" stride, but for the shrunken output.
size_t in_stride, in_stride_pre_axis, out_stride_pre_axis;
// elem_size: total element count of the input; axis_elem_size: extent of `axis`.
size_t elem_size, axis_elem_size;

TopKInfo(uint8_t topk, uint8_t axis, Tensor const &input);
size_t getElementSize() const {return elem_size;}
size_t getAxisElementSize()const { return axis_elem_size;}
size_t getInStride()const{return in_stride;}
size_t getInStridePreAxis()const{return in_stride_pre_axis;}
size_t getOutStridePreAxis()const {return out_stride_pre_axis;}
};

}// namespace refactor::kernel

#endif// KERNEL_TOPK_INFO_H
20 changes: 20 additions & 0 deletions src/04kernel/include/kernel/collectors/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef KERNEL_TOPK_H
#define KERNEL_TOPK_H

#include "../collector.h"

namespace refactor::kernel {

// Collects the candidate kernels able to run a TopK(k, axis) operator on the
// configured target.
struct TopKCollector final : public InfoCollector {
// topk: number of elements kept; axis: reduction axis (non-negative).
uint32_t topk, axis;

constexpr TopKCollector(decltype(_target) target, uint32_t topk, uint32_t axis_) noexcept
: InfoCollector(target), topk(topk), axis(axis_) {}

// Returns every kernel implementation that can serve these tensors.
std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_TOPK_H
14 changes: 14 additions & 0 deletions src/04kernel/src/attributes/topk_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "kernel/attributes/topk_info.h"
#include <numeric>

namespace refactor::kernel {

    /// Precomputes the strides and element counts used by the TopK kernels.
    ///
    /// \param topk  number of elements kept along `axis`.
    /// \param axis  reduction axis (0-based, already normalized).
    /// \param input tensor the operator will read from.
    TopKInfo::TopKInfo(uint8_t topk, uint8_t axis, Tensor const &input)
        : topk(topk),
          axis(axis),
          in_stride(0),
          in_stride_pre_axis(0),
          out_stride_pre_axis(0),
          elem_size(input.elementsSize()),
          axis_elem_size(input.shape[axis]) {
        // strides() builds a fresh vector on every call; compute it once
        // instead of once per member initializer.
        auto const strides = input.strides();
        in_stride = strides[axis];
        in_stride_pre_axis = axis == 0 ? 0 : strides[axis - 1];
        // The output shrinks the axis from shape[axis] to topk, so its
        // pre-axis stride shrinks by the same factor.
        out_stride_pre_axis = in_stride_pre_axis / input.shape[axis] * topk;
    }

}// namespace refactor::kernel
30 changes: 30 additions & 0 deletions src/04kernel/src/collectors/topk.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "kernel/collectors/topk.h"
#include "../kernels/topk/cpu_kernel.hh"
#include "kernel/attributes/topk_info.h"
//#include "../kernels/topk/cuda_kernel.hh"

namespace refactor::kernel {

    // Enumerates the kernels able to execute this TopK on the chosen target.
    std::vector<KernelBox>
    TopKCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        TopKInfo info(topk, axis, inputs[0]);
        std::vector<KernelBox> ans;
        switch (_target) {
            case decltype(_target)::Cpu:
            // TODO: temporarily fall back to the CPU kernel on Nvidia as well,
            // until a CUDA implementation lands.
            case decltype(_target)::Nvidia:
                if (auto ptr = TopKCpu::build(info); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
61 changes: 61 additions & 0 deletions src/04kernel/src/kernels/topk/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "cpu_kernel.hh"

#include <algorithm>
#include <cstdint>
#include <execution>
#include <list>
#include <utility>
#include <vector>

namespace refactor::kernel {
    using K = TopKCpu;

    K::TopKCpu(TopKInfo info) noexcept
        : Kernel(), info(std::move(info)) {}

    auto K::build(TopKInfo info) noexcept -> KernelBox {
        return std::make_unique<K>(std::move(info));
    }
    auto K::typeId() noexcept -> size_t {
        // The address of a function-local static is unique per kernel type.
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing topk operation on generic cpu";
    }

    /// Builds the CPU routine.
    ///
    /// Viewing the input as [pre, axis, post], each m in [0, M) addresses one
    /// (pre, post) "column"; elements along the axis are inStride2 apart.
    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        using namespace runtime;
        return [info = this->info](Resources &, void *, void const *const *inputs, void *const *outputs) {
            auto src = reinterpret_cast<float const *>(inputs[0]);
            auto dstVal = reinterpret_cast<float *>(outputs[0]);
            auto dstIndex = reinterpret_cast<uint32_t *>(outputs[1]);

            // BUGFIX: the working pair previously held the index as uint8_t,
            // silently truncating axis positions > 255 even though the output
            // buffer is uint32.
            using Pair = std::pair<float, uint32_t>;

            size_t N = info.getAxisElementSize();// extent of the reduction axis
            size_t M = info.getElementSize() / N;// number of independent columns
            auto inStride1 = info.getInStridePreAxis();
            auto inStride2 = info.getInStride();
            auto outStride1 = info.getOutStridePreAxis();
            auto outStride2 = inStride2;
            auto k = std::min(static_cast<size_t>(info.topk), N);

            std::vector<Pair> column(N);// reused across columns to avoid reallocation
            for (size_t m = 0; m < M; ++m) {
                auto srcBase = m / inStride2 * inStride1 + m % inStride2;
                for (size_t n = 0; n < N; ++n) {
                    column[n] = {src[srcBase + n * inStride2], static_cast<uint32_t>(n)};
                }
                // Only the k largest entries are needed, so a partial sort
                // beats sorting the whole axis. The index tie-break reproduces
                // the previous stable descending sort (smaller index first).
                std::partial_sort(column.begin(), column.begin() + k, column.end(),
                                  [](Pair const &a, Pair const &b) {
                                      return a.first != b.first ? a.first > b.first : a.second < b.second;
                                  });
                auto offset = m / inStride2 * outStride1 + m % inStride2;
                for (size_t i = 0; i < k; ++i) {
                    dstVal[offset] = column[i].first;
                    dstIndex[offset] = column[i].second;
                    offset += outStride2;
                }
            }
        };
    }
}// namespace refactor::kernel
23 changes: 23 additions & 0 deletions src/04kernel/src/kernels/topk/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef KERNEL_TOPK_CPU_KERNEL_HH
#define KERNEL_TOPK_CPU_KERNEL_HH

#include "kernel/attributes/topk_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

// Generic-CPU implementation of the TopK operator.
struct TopKCpu final : public Kernel {
// Precomputed strides/sizes describing the reduction axis.
TopKInfo info;
explicit TopKCpu(TopKInfo info) noexcept;

// Wraps a new instance in a KernelBox.
static KernelBox build(TopKInfo info) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
// Produces the routine: reads f32 input, writes f32 values and u32 indices.
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_TOPK_CPU_KERNEL_HH
238 changes: 238 additions & 0 deletions src/04kernel/test/kernels/topk/test_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#include "../../../src/kernels/topk/cpu_kernel.hh"
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;

TEST(kernel, TopKCpu) {
    // top-3 along axis 1 of a 3x4 tensor
    auto input = Tensor::share(DataType::F32, Shape{3, 4});
    auto outValues = Tensor::share(DataType::F32, Shape{3, 3});
    auto outIndices = Tensor::share(DataType::U32, Shape{3, 3});

    auto kernel = TopKCpu::build(TopKInfo(3, 1, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    // input is 0..11, so each row's ranking is known in advance
    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);

    // each row keeps its 3 largest elements, descending
    std::vector<float> expectVal = {3, 2, 1, 7, 6, 5, 11, 10, 9};
    std::vector<uint32_t> expectIdx = {3, 2, 1, 3, 2, 1, 3, 2, 1};
    for (auto v : values) { std::cout << v << " "; }

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}

TEST(kernel, TopKCpu1) {
    // top-3 along the middle axis of a 2x4x2 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 4, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{2, 3, 2});
    auto outIndices = Tensor::share(DataType::U32, Shape{2, 3, 2});

    auto kernel = TopKCpu::build(TopKInfo(3, 1, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // expected values/indices for iota input along axis 1
    std::vector<float> expectVal = {6, 7, 4, 5, 2, 3, 14, 15, 12, 13, 10, 11};
    std::vector<uint32_t> expectIdx = {3, 3, 2, 2, 1, 1, 3, 3, 2, 2, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}

TEST(kernel, TopKCpu2) {
    // top-1 along the outermost axis of a 2x4x2 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 4, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{1, 4, 2});
    auto outIndices = Tensor::share(DataType::U32, Shape{1, 4, 2});

    auto kernel = TopKCpu::build(TopKInfo(1, 0, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // the second (larger) slab wins everywhere
    std::vector<float> expectVal = {8, 9, 10, 11, 12, 13, 14, 15};
    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}


TEST(kernel, TopKCpu3) {
    // top-1 along axis 0 of a rank-4 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{1, 3, 2, 2});
    auto outIndices = Tensor::share(DataType::U32, Shape{1, 3, 2, 2});

    auto kernel = TopKCpu::build(TopKInfo(1, 0, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // the second slab (elements 12..23) dominates the first
    std::vector<float> expectVal = {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}

TEST(kernel, TopKCpu4) {
    // top-2 along axis 1 of a rank-4 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{2, 2, 2, 2});
    auto outIndices = Tensor::share(DataType::U32, Shape{2, 2, 2, 2});

    auto kernel = TopKCpu::build(TopKInfo(2, 1, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // per batch: axis-1 slice 2, then slice 1, in descending order
    std::vector<float> expectVal = {8, 9, 10, 11, 4, 5, 6, 7, 20, 21, 22, 23, 16, 17, 18, 19};
    std::vector<uint32_t> expectIdx = {2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}


TEST(kernel, TopKCpu5) {
    // top-1 along axis 2 of a rank-4 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{2, 3, 1, 2});
    auto outIndices = Tensor::share(DataType::U32, Shape{2, 3, 1, 2});

    auto kernel = TopKCpu::build(TopKInfo(1, 2, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // the second axis-2 slice always holds the larger pair
    std::vector<float> expectVal = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23};
    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}

TEST(kernel, TopKCpu6) {
    // top-1 along the innermost axis of a rank-4 tensor
    auto input = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
    auto outValues = Tensor::share(DataType::F32, Shape{2, 3, 2, 1});
    auto outIndices = Tensor::share(DataType::U32, Shape{2, 3, 2, 1});

    auto kernel = TopKCpu::build(TopKInfo(1, 3, *input));
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;

    std::vector<float> data(input->elementsSize());
    std::vector<float> values(outValues->elementsSize());
    std::vector<uint32_t> indices(outIndices->elementsSize());
    std::iota(data.begin(), data.end(), 0);

    void const *inputs[]{data.data()};
    void *outputs[]{values.data(), indices.data()};
    routine(res, nullptr, inputs, outputs);
    for (auto v : values) { std::cout << v << " "; }

    // every innermost pair keeps its odd (larger) element
    std::vector<float> expectVal = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

    for (size_t i = 0; i < expectVal.size(); ++i) {
        EXPECT_EQ(expectVal[i], values[i]);
        EXPECT_EQ(expectIdx[i], indices[i]);
    }
}
20 changes: 20 additions & 0 deletions src/05computation/include/computation/operators/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef COMPUTATION_TOPK_H
#define COMPUTATION_TOPK_H

#include "../operator.h"

namespace refactor::computation {

// Computation-layer TopK: keeps the `topk` largest elements along `axis`.
struct TopK final : public Operator {
// axis is expected to be already normalized (non-negative) by the frontend.
uint32_t topk,axis;
constexpr TopK(uint32_t topk, uint32_t axis) noexcept : topk(topk), axis(axis){}

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
// Returns the collector that enumerates runnable kernels for `target`.
kernel::CollectorBox candidateKernels(Target) const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_TOPK_H
17 changes: 17 additions & 0 deletions src/05computation/src/operators/topk.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "computation/operators/topk.h"
#include "kernel/collectors/topk.h"

namespace refactor::computation {

    // The address of a function-local static serves as a process-unique id.
    auto TopK::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto TopK::opTypeId() const noexcept -> size_t { return typeId(); }

    auto TopK::name() const noexcept -> std::string_view { return "TopK"; }

    // Hands back the collector that finds concrete TopK kernels for `target`.
    auto TopK::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
        return std::make_unique<kernel::TopKCollector>(target, topk, axis);
    }

}// namespace refactor::computation
2 changes: 2 additions & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
#include "operators/transpose.hh"
#include "operators/unsqueeze.hh"
#include "operators/where.hh"
#include "operators/topk.hh"

namespace refactor::onnx {

@@ -131,6 +132,7 @@ namespace refactor::onnx {
REGISTER(Where , Where );
REGISTER(HardSigmoid , HardSigmoid );
REGISTER(Pad , Pad );
REGISTER(TopK , TopK );
// clang-format on
#undef REGISTER
}
55 changes: 55 additions & 0 deletions src/07onnx/src/operators/topk.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "common.h"
#include "topk.hh"
#include "computation/operators/topk.h"
#include <execution>

namespace refactor::onnx {
using Op = TopK;

// topk: number of elements to keep; axis: selection axis (may be negative).
Op::TopK(Int topk, Int axis):topk(topk), axis(axis){}

// Builds the operator from ONNX attributes.
// NOTE(review): the ONNX spec names this attribute "k", not "topk", and since
// opset 10 K arrives as the second input instead of an attribute — confirm
// which convention the importer actually produces here.
auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
auto axis = attributes["axis"].int_();
auto topk = attributes["topk"].int_();
return OpBox(std::make_unique<Op>(topk, axis));
}

auto Op::typeId() -> size_t {
// Address of a function-local static: unique per operator type.
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto Op::opTypeId() const -> size_t { return typeId(); }
auto Op::opTypeName() const -> std::string_view { return "TopK"; }

// Shape inference: both outputs copy the input shape with shape[axis] set to k.
auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
// Exactly one input is accepted; an opset-10-style TopK (K as a second input)
// would be rejected by this check.
if (inputs.empty() || inputs.size() >= 2) {
return Err(InferError(ERROR_MSG("Input size error")));
}
auto const &input = inputs[0];
auto rank = input.rank();
// Normalize a negative axis into [0, rank).
auto axis_ = axis < 0 ? axis + rank : axis;
if (rank <= axis_) {
return Err(InferError(ERROR_MSG("axis error")));
}
// NOTE(review): shape[axis_].value() presumably requires a static dimension —
// confirm dynamic dims are rejected upstream before this call.
if (topk < 0 || topk > input.shape[axis_].value()){
return Err(InferError(ERROR_MSG("topk error")));
}

Tensors ans(2, nullptr);
auto dependencies = extractDependency(inputs);
ans[0] = Tensor::share(input.dataType, input.shape, dependencies);
ans[0]->shape[axis_] = DimExpr(topk);
// NOTE(review): output 1 (indices) inherits the input dtype here, but ONNX
// specifies int64 indices and the CPU kernel writes uint32 — verify intended.
ans[1] = Tensor::share(input.dataType, input.shape, dependencies);
ans[1]->shape[axis_] = DimExpr(topk);
return Ok(Tensors{std::move(ans)});
}

// Lowers to the computation-layer TopK with the axis already normalized.
auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
using Op_ = computation::TopK;
auto rank = inputs[0].rank();
auto axis_ = axis < 0 ? axis + rank : axis;
return std::make_unique<Op_>(topk, axis_);
}

}// namespace refactor::onnx
23 changes: 23 additions & 0 deletions src/07onnx/src/operators/topk.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef ONNX_TOPK_HH
#define ONNX_TOPK_HH

#include "frontend/operator.h"

namespace refactor::onnx {
using namespace frontend;

struct TopK final : public Operator {
Int topk, axis;
TopK(Int topk, Int axis);

static size_t typeId();
static OpBox build(ModelContext const &, std::string_view, Attributes);
size_t opTypeId() const final;
std::string_view opTypeName() const final;
InferResult infer(TensorRefs, InferOptions const &) const final;
computation::OpBox lower(TensorRefs) const final;
};

}// namespace refactor::onnx

#endif// ONNX_WHERE_HH

0 comments on commit de18a1b

Please sign in to comment.