-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 parent
a813939
commit de18a1b
Showing
12 changed files
with
528 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef KERNEL_TOPK_INFO_H
#define KERNEL_TOPK_INFO_H

#include "../tensor.h"

namespace refactor::kernel {

    /// Precomputed geometry for the TopK kernel: strides and element
    /// counts describing how to walk the input/output along `axis`.
    struct TopKInfo {

        // NOTE(review): uint8_t caps k and axis at 255, while the
        // collector/operator carry them as uint32_t — confirm the
        // narrowing is intended.
        uint8_t topk;
        uint8_t axis;
        size_t in_stride, in_stride_pre_axis, out_stride_pre_axis;
        size_t elem_size, axis_elem_size;

        TopKInfo(uint8_t topk, uint8_t axis, Tensor const &input);
        size_t getElementSize() const { return elem_size; }
        size_t getAxisElementSize() const { return axis_elem_size; }
        size_t getInStride() const { return in_stride; }
        size_t getInStridePreAxis() const { return in_stride_pre_axis; }
        size_t getOutStridePreAxis() const { return out_stride_pre_axis; }
    };

}// namespace refactor::kernel

#endif// KERNEL_TOPK_INFO_H
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef KERNEL_TOPK_H
#define KERNEL_TOPK_H

#include "../collector.h"

namespace refactor::kernel {

    /// Collects candidate kernel implementations of TopK for a target
    /// device; `topk` is k, `axis` the reduction axis.
    struct TopKCollector final : public InfoCollector {
        uint32_t topk, axis;

        constexpr TopKCollector(decltype(_target) target, uint32_t topk_, uint32_t axis_) noexcept
            : InfoCollector(target), topk(topk_), axis(axis_) {}

        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_TOPK_H
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#include "kernel/attributes/topk_info.h" | ||
#include <numeric> | ||
|
||
namespace refactor::kernel { | ||
|
||
TopKInfo::TopKInfo(uint8_t topk, uint8_t axis, Tensor const &input):topk(topk), | ||
axis(axis), | ||
in_stride(input.strides()[axis]), | ||
in_stride_pre_axis(axis == 0 ? 0 : input.strides()[axis - 1]), | ||
out_stride_pre_axis(in_stride_pre_axis/input.shape[axis]*topk), | ||
elem_size(input.elementsSize()), | ||
axis_elem_size(input.shape[axis]){} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "kernel/collectors/topk.h" | ||
#include "../kernels/topk/cpu_kernel.hh" | ||
#include "kernel/attributes/topk_info.h" | ||
//#include "../kernels/topk/cuda_kernel.hh" | ||
|
||
namespace refactor::kernel { | ||
|
||
std::vector<KernelBox> | ||
TopKCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
TopKInfo info(topk, axis, inputs[0]); | ||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = TopKCpu::build(info); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
//todo :暂时用cpu的实现 | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = TopKCpu::build(info); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include "cpu_kernel.hh"
#include <algorithm>
#include <execution>
#include <list>
#include <utility>
#include <vector>
|
||
namespace refactor::kernel { | ||
using K = TopKCpu; | ||
|
||
K::TopKCpu(TopKInfo info) noexcept | ||
: Kernel(), info(std::move(info)) {} | ||
|
||
auto K::build(TopKInfo info) noexcept -> KernelBox { | ||
return std::make_unique<K>(std::move(info)); | ||
} | ||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto K::kernelTypeId() const noexcept -> size_t { | ||
return typeId(); | ||
} | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing topk operation on generic cpu"; | ||
} | ||
|
||
auto K::lower(Resources &) const noexcept -> RoutineWorkspace { | ||
using namespace runtime; | ||
return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { | ||
auto src = reinterpret_cast<float const *>(inputs[0]); | ||
|
||
auto dstVal = reinterpret_cast<float*>(outputs[0]);//T | ||
auto dstIndex = reinterpret_cast<uint32_t*>(outputs[1]); | ||
|
||
|
||
size_t M = info.getElementSize() / info.getAxisElementSize(); | ||
size_t N = info.getAxisElementSize(); | ||
auto inStride1 = info.getInStridePreAxis(); | ||
auto inStride2 = info.getInStride(); | ||
auto outStride1 = info.getOutStridePreAxis(); | ||
auto outStride2 = inStride2; | ||
|
||
for(size_t m = 0; m < M; m ++){ | ||
using PairType = std::pair<float, uint8_t>; | ||
std::list<PairType> list; | ||
for(size_t n = 0; n < N; n++){ | ||
auto srcIdx = m /inStride2 * inStride1 + m % inStride2 + n * inStride2; | ||
list.push_back({src[srcIdx],n}); | ||
} | ||
list.sort([](const PairType &a, const PairType &b)->bool{return a.first > b.first;}); | ||
|
||
size_t offset = m /inStride2 * outStride1 + m % inStride2; | ||
std::for_each_n(list.begin(), (uint32_t)info.topk, | ||
[&](auto &elem) { | ||
dstVal[offset] = elem.first; | ||
dstIndex[offset] = elem.second; | ||
offset += outStride2; | ||
}); | ||
} | ||
}; | ||
} | ||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef KERNEL_TOPK_CPU_KERNEL_HH
#define KERNEL_TOPK_CPU_KERNEL_HH

#include "kernel/attributes/topk_info.h"
#include "kernel/kernel.h"

namespace refactor::kernel {

    /// Generic CPU implementation of the TopK kernel.
    struct TopKCpu final : public Kernel {
        TopKInfo info;
        explicit TopKCpu(TopKInfo info) noexcept;

        static KernelBox build(TopKInfo info) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
        RoutineWorkspace lower(Resources &) const noexcept final;
    };

}// namespace refactor::kernel

#endif// KERNEL_TOPK_CPU_KERNEL_HH
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
#include "../../../src/kernels/topk/cpu_kernel.hh" | ||
#include <gtest/gtest.h> | ||
#include <numeric> | ||
|
||
using namespace refactor; | ||
using namespace kernel; | ||
|
||
TEST(kernel, TopKCpu) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{3, 4}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{3, 3}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{3, 3}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(3,1, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
|
||
// check | ||
std::vector<float> expectVal = {3,2,1,7,6,5,11,10,9}; | ||
std::vector<uint32_t> expectIdx = {3,2,1,3,2,1,3,2,1}; | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
TEST(kernel, TopKCpu1) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 4, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 2}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 2}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(3,1, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {6,7,4,5,2,3,14,15,12,13,10,11}; | ||
std::vector<uint32_t> expectIdx = {3,3,2,2,1,1,3,3,2,2,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
TEST(kernel, TopKCpu2) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 4, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{1, 4, 2}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{1, 4, 2}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(1,0, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {8,9,10,11,12,13,14,15}; | ||
std::vector<uint32_t> expectIdx = {1,1,1,1,1,1,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
|
||
TEST(kernel, TopKCpu3) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{1, 3, 2, 2}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{1, 3, 2, 2}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(1,0, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {12, 13, 14, 15, 16, 17, 18, 19, 20,21, 22,23}; | ||
std::vector<uint32_t> expectIdx = {1,1,1,1,1,1,1,1,1,1,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
TEST(kernel, TopKCpu4) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 2, 2, 2}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 2, 2, 2}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(2,1, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {8, 9, 10, 11,4,5,6,7,20,21,22,23,16,17,18,19}; | ||
std::vector<uint32_t> expectIdx = {2,2,2,2,1,1,1,1,2,2,2,2,1,1,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
|
||
TEST(kernel, TopKCpu5) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 1, 2}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 1, 2}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(1,2, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {2,3,6,7,10,11,14,15,18,19,22,23}; | ||
std::vector<uint32_t> expectIdx = {1,1,1,1,1,1,1,1,1,1,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} | ||
|
||
TEST(kernel, TopKCpu6) { | ||
// build routine | ||
auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2}); | ||
auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 2, 1}); | ||
auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 2, 1}); | ||
|
||
auto kernel = TopKCpu::build(TopKInfo(1,3, *inputTensor)); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> ins(inputTensor->elementsSize()); | ||
std::vector<float> out0(outputTensor0->elementsSize()); | ||
std::vector<uint32_t> out1(outputTensor1->elementsSize()); | ||
|
||
std::iota(ins.begin(), ins.end(), 0); | ||
// inference | ||
void const *inputs[]{ins.data()}; | ||
void *outputs[]{out0.data(), out1.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
std::for_each(out0.begin(), out0.end(),[](const float &val){std::cout<<val<<" ";}); | ||
|
||
// check | ||
std::vector<float> expectVal = {1,3,5,7,9,11,13,15,17,19,21,23}; | ||
std::vector<uint32_t> expectIdx = {1,1,1,1,1,1,1,1,1,1,1,1}; | ||
|
||
|
||
for(size_t i=0;i< expectVal.size(); ++i){ | ||
EXPECT_EQ(expectVal[i], out0[i]); | ||
EXPECT_EQ(expectIdx[i], out1[i]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef COMPUTATION_TOPK_H
#define COMPUTATION_TOPK_H

#include "../operator.h"

namespace refactor::computation {

    /// Computation-graph TopK operator: keeps the `topk` largest
    /// elements along `axis`.
    struct TopK final : public Operator {
        uint32_t topk, axis;
        constexpr TopK(uint32_t topk_, uint32_t axis_) noexcept
            : topk(topk_), axis(axis_) {}

        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
        kernel::CollectorBox candidateKernels(Target) const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_TOPK_H
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include "computation/operators/topk.h" | ||
#include "kernel/collectors/topk.h" | ||
|
||
namespace refactor::computation { | ||
|
||
size_t TopK::typeId() noexcept { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
size_t TopK::opTypeId() const noexcept { return typeId(); } | ||
std::string_view TopK::name() const noexcept { return "TopK"; } | ||
auto TopK::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { | ||
using Collector_ = kernel::TopKCollector; | ||
return std::make_unique<Collector_>(target, topk, axis); | ||
} | ||
|
||
}// namespace refactor::computation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#include "common.h" | ||
#include "topk.hh" | ||
#include "computation/operators/topk.h" | ||
#include <execution> | ||
|
||
namespace refactor::onnx { | ||
using Op = TopK; | ||
|
||
Op::TopK(Int topk, Int axis):topk(topk), axis(axis){} | ||
|
||
auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { | ||
auto axis = attributes["axis"].int_(); | ||
auto topk = attributes["topk"].int_(); | ||
return OpBox(std::make_unique<Op>(topk, axis)); | ||
} | ||
|
||
auto Op::typeId() -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto Op::opTypeId() const -> size_t { return typeId(); } | ||
auto Op::opTypeName() const -> std::string_view { return "TopK"; } | ||
|
||
auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { | ||
if (inputs.empty() || inputs.size() >= 2) { | ||
return Err(InferError(ERROR_MSG("Input size error"))); | ||
} | ||
auto const &input = inputs[0]; | ||
auto rank = input.rank(); | ||
auto axis_ = axis < 0 ? axis + rank : axis; | ||
if (rank <= axis_) { | ||
return Err(InferError(ERROR_MSG("axis error"))); | ||
} | ||
if (topk < 0 || topk > input.shape[axis_].value()){ | ||
return Err(InferError(ERROR_MSG("topk error"))); | ||
} | ||
|
||
Tensors ans(2, nullptr); | ||
auto dependencies = extractDependency(inputs); | ||
ans[0] = Tensor::share(input.dataType, input.shape, dependencies); | ||
ans[0]->shape[axis_] = DimExpr(topk); | ||
ans[1] = Tensor::share(input.dataType, input.shape, dependencies); | ||
ans[1]->shape[axis_] = DimExpr(topk); | ||
return Ok(Tensors{std::move(ans)}); | ||
} | ||
|
||
auto Op::lower(TensorRefs inputs) const -> computation::OpBox { | ||
using Op_ = computation::TopK; | ||
auto rank = inputs[0].rank(); | ||
auto axis_ = axis < 0 ? axis + rank : axis; | ||
return std::make_unique<Op_>(topk, axis_); | ||
} | ||
|
||
}// namespace refactor::onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef ONNX_TOPK_HH
#define ONNX_TOPK_HH

#include "frontend/operator.h"

namespace refactor::onnx {
    using namespace frontend;

    /// ONNX frontend TopK operator; k and axis are taken from node
    /// attributes at build time.
    struct TopK final : public Operator {
        Int topk, axis;
        TopK(Int topk, Int axis);

        static size_t typeId();
        static OpBox build(ModelContext const &, std::string_view, Attributes);
        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        InferResult infer(TensorRefs, InferOptions const &) const final;
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::onnx

#endif// ONNX_TOPK_HH