Skip to content

Commit

Permalink
Merge pull request #84 from InfiniTensor/add_max_min_kernel
Browse files Browse the repository at this point in the history
add max/min kernel
  • Loading branch information
YdrMaster authored Jan 31, 2024
2 parents 6630866 + e237349 commit 35dc6c8
Show file tree
Hide file tree
Showing 13 changed files with 571 additions and 35 deletions.
26 changes: 26 additions & 0 deletions src/04kernel/include/kernel/collectors/select.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef KERNEL_SELECT_H
#define KERNEL_SELECT_H

#include "../collector.h"

namespace refactor::kernel {

/// Element-wise selection operations a select kernel can be built for.
enum class SelectType {
Max,// keep the larger operand
Min,// keep the smaller operand
};

/// Returns the textual name of `type` (the enumerator spelled as a string).
std::string_view opName(SelectType type);

/// Collects the candidate kernels able to perform a select (max/min) operation
/// on the target the collector was created for.
struct SelectCollector final : public InfoCollector {
// Which selection operation the collected kernels must implement.
SelectType selectType;

/// @param (1st) target device; forwarded to the InfoCollector base.
/// @param (2nd) selection operation stored in `selectType`.
SelectCollector(decltype(_target), SelectType) noexcept;

/// Returns the kernels that could be built for `inputs` on the stored target.
/// May be empty when no kernel accepts the inputs.
std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_SELECT_H
16 changes: 0 additions & 16 deletions src/04kernel/include/kernel/selector.h

This file was deleted.

18 changes: 11 additions & 7 deletions src/04kernel/src/attributes/broadcaster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,17 @@ namespace refactor::kernel {
}()) {}

// Computes, for the k-th output element, the linear element offset to read in
// each input and stores the `inputsCount` results in `ans`.
// NOTE(review): this text was captured from a diff view. The loop directly
// below looks like the removed pre-change body, and the `if`/`else` after it
// the added replacement — confirm against the real file before relying on the
// brace structure shown here.
void Broadcaster::locate(dim_t k, dim_t ans[]) const noexcept {
long rem = k;
std::fill_n(ans, inputsCount, 0);
for (auto i : range0_(strides.size() / (inputsCount + 1))) {
auto dim = strides.data() + (inputsCount + 1) * i;
auto div = std::div(rem, dim[inputsCount]);
for (auto j : range0_(inputsCount)) { ans[j] += dim[j] * div.quot; }
rem = div.rem;
if (!needBroadcast()) {
// Without broadcasting every input is indexed exactly like the output.
std::fill_n(ans, inputsCount, k);
} else {
long rem = k;
std::fill_n(ans, inputsCount, 0);
// `strides` is walked in groups of (inputsCount + 1) values: the first
// `inputsCount` entries scale each input's offset and the last one is the
// divisor applied to the remaining output index.
for (auto i : range0_(strides.size() / (inputsCount + 1))) {
auto dim = strides.data() + (inputsCount + 1) * i;
auto div = std::div(rem, dim[inputsCount]);
for (auto j : range0_(inputsCount)) { ans[j] += dim[j] * div.quot; }
rem = div.rem;
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions src/04kernel/src/collectors/select.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "kernel/collectors/select.h"
#include "../kernels/select/cpu_kernel.hh"
#include "../kernels/select/cuda_kernel.hh"

namespace refactor::kernel {

// Tries to build kernel type `T` from `selectType` and `inputs`, and appends
// the result to `ans` only when `T::build` returned a non-null kernel.
// `selectType`, `inputs` and `ans` must be visible at the expansion site.
#define REGISTER(T) \
if (auto ptr = T::build(selectType, inputs); ptr) { \
ans.emplace_back(std::move(ptr)); \
}

/// Returns the enumerator name of `type` as text.
///
/// @param type selection operation to name.
/// @return "Max" for SelectType::Max, "Min" for SelectType::Min.
///
/// Any other value is a programming error and hits UNREACHABLE().
/// The cases are spelled out explicitly so that no helper macro stays defined
/// for the remainder of the translation unit.
std::string_view opName(SelectType type) {
    switch (type) {
        case SelectType::Max:
            return "Max";
        case SelectType::Min:
            return "Min";
        default:
            UNREACHABLE();
    }
}

/// Creates a collector bound to `target` that gathers kernels for `type`.
SelectCollector::SelectCollector(
    decltype(_target) target,
    SelectType type) noexcept
    : InfoCollector(target),
      selectType(type) {
}

/// Gathers every select kernel that can be built for `inputs` on this
/// collector's target.
///
/// @param inputs  tensors handed to each kernel's `build`.
/// @param outputs not consulted by this collector.
/// @return the successfully built kernels; empty if the only candidate
///         for the target declined the inputs.
std::vector<KernelBox>
SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
    using TargetKind = decltype(_target);

    std::vector<KernelBox> kernels;
    // Keeps a candidate only when its build() actually produced a kernel.
    auto const keep = [&kernels](KernelBox candidate) {
        if (candidate) {
            kernels.emplace_back(std::move(candidate));
        }
    };

    switch (_target) {
        case TargetKind::Cpu:
            keep(SelectCpu::build(selectType, inputs));
            break;
        case TargetKind::Nvidia:
            keep(SelectCuda::build(selectType, inputs));
            break;
        default:
            UNREACHABLEX(void, "Unknown target");
    }
    return kernels;
}

}// namespace refactor::kernel
7 changes: 0 additions & 7 deletions src/04kernel/src/kernels/concat/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,6 @@ extern "C" __global__ void kernel(
}
auto segments = ss.str();

ss.str("");
for (auto i : range0_(inputCount)) {
ss << std::endl
<< " reinterpret_cast<char const *>(inputs[" << i << "]), ";
}
auto castInputs = ss.str();

ss.str("");
ss << "Concat_" << info.blockCount << ',' << unit;
for (auto seg : info.segments) {
Expand Down
91 changes: 91 additions & 0 deletions src/04kernel/src/kernels/select/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "cpu_kernel.hh"
#include <execution>
#include <utility>

namespace refactor::kernel {
using K = SelectCpu;
using DT = DataType;

/// Stores the kernel's configuration.
///
/// @param dataType_    element type of the tensors.
/// @param selectType_  selection operation to perform.
/// @param broadcaster_ broadcasting description; taken by value and moved into
///                     the member so the caller's temporary is not copied again.
/// @param inputsNum_   number of input tensors.
K::SelectCpu(
    decltype(dataType) dataType_,
    decltype(selectType) selectType_,
    decltype(broadcaster) broadcaster_,
    decltype(inputsNum) inputsNum_) noexcept
    : dataType(dataType_),
      selectType(selectType_),
      broadcaster(std::move(broadcaster_)),
      inputsNum(inputsNum_) {}

/// Creates a CPU select kernel for `inputs_`, or returns null when this kernel
/// cannot handle them.
///
/// @param selectType_ selection operation the kernel will perform.
/// @param inputs_     input tensors; the element type is taken from the first one.
/// @return the kernel, or nullptr when `inputs_` is empty or the first input's
///         data type is not a CPU numeric type.
auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
    // An empty input list has no first tensor to inspect; indexing it would be invalid.
    if (inputs_.size() == 0) {
        return nullptr;
    }
    auto const &first = inputs_[0].get();
    if (!first.dataType.isCpuNumberic()) {
        return nullptr;
    }
    return std::make_unique<K>(first.dataType, selectType_, Broadcaster(inputs_), inputs_.size());
}
/// Returns the identifier of this kernel type: the address of a static object
/// owned by this function, converted to an integer.
size_t K::typeId() noexcept {
    static uint8_t marker = 1;
    return reinterpret_cast<size_t>(&marker);
}

/// Instance-level accessor for the identifier returned by the static typeId().
size_t K::kernelTypeId() const noexcept {
    return K::typeId();
}
/// Returns a fixed, human-readable description of this kernel.
std::string_view K::description() const noexcept {
    return "Performing select operation on generic cpu";
}

/// Builds the CPU routine that reduces all inputs element-wise with max or min.
///
/// @tparam T          element type of every input and of the output.
/// @param selectType  which reduction to apply (Max or Min).
/// @param broadcaster maps an output index to the matching offset in each input.
/// @param inputsNum   number of input buffers the routine will read.
/// @return a routine writing `broadcaster.outputsCount` elements to `outputs[0]`.
template<class T>
auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace {
    using namespace runtime;

    // Binary reducer chosen once here, so the element loop does not re-dispatch.
    T (*op)(T, T) = nullptr;
    switch (selectType) {
        case SelectType::Max:
            op = [](T a, T b) { return std::max(a, b); };
            break;
        case SelectType::Min:
            op = [](T a, T b) { return std::min(a, b); };
            break;
        default:
            UNREACHABLE();
    }

    return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
        // Nothing to reduce: leave the output untouched.
        if (inputsNum == 0) {
            return;
        }
        auto output = reinterpret_cast<T *>(outputs[0]);
        // Scratch offsets, one per input. Allocated once for the whole run;
        // locate() overwrites every entry on each call, so reuse is safe.
        std::vector<dim_t> offsets(broadcaster.inputsCount);
        for (auto i : range0_(broadcaster.outputsCount)) {
            broadcaster.locate(i, offsets.data());
            // Reduce in a local and store once, so the output buffer is never
            // read back while the element is being computed.
            T acc = reinterpret_cast<T const *>(inputs[0])[offsets[0]];
            for (size_t inputIdx = 1; inputIdx < inputsNum; ++inputIdx) {
                acc = op(acc, reinterpret_cast<T const *>(inputs[inputIdx])[offsets[inputIdx]]);
            }
            output[i] = acc;
        }
    };
}

/// Produces the executable routine for this kernel by dispatching on the
/// stored data type to the matching `lowerTyped` instantiation.
/// Data types outside the list below are a programming error.
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
// One switch arm per supported element type; `TYPE_` is a DataType enumerator.
#define CASE(TYPE_)       \
    case DataType::TYPE_: \
        return lowerTyped<primitive<DataType::TYPE_>::type>(selectType, broadcaster, inputsNum)

    switch (dataType) {
        // floating point
        CASE(F32);
        CASE(F64);
        // signed integers
        CASE(I8);
        CASE(I16);
        CASE(I32);
        CASE(I64);
        // unsigned integers
        CASE(U8);
        CASE(U16);
        CASE(U32);
        CASE(U64);
        default:
            UNREACHABLE();
    }
}

}// namespace refactor::kernel
29 changes: 29 additions & 0 deletions src/04kernel/src/kernels/select/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef KERNEL_SELECT_CPU_KERNEL_HH
#define KERNEL_SELECT_CPU_KERNEL_HH

#include "kernel/attributes/broadcaster.h"
#include "kernel/collectors/select.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

/// CPU kernel performing an element-wise select (max/min) over its inputs.
struct SelectCpu final : public Kernel {
// Element type shared by the tensors this kernel processes.
DataType dataType;
// Selection operation to apply.
SelectType selectType;
// Maps each output index to the matching offset in every input.
Broadcaster broadcaster;
// Number of input tensors.
size_t inputsNum;

/// Stores the four configuration values in the members of the same name.
SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;

/// Creates the kernel for the given inputs, or returns null when the
/// inputs are not supported.
static KernelBox build(SelectType, TensorRefs) noexcept;
/// Identifier of this kernel type.
static size_t typeId() noexcept;

/// Same value as typeId(), reachable through the Kernel interface.
size_t kernelTypeId() const noexcept final;
/// Fixed human-readable description of the kernel.
std::string_view description() const noexcept final;
/// Builds the routine that executes this kernel.
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_SELECT_CPU_KERNEL_HH
Loading

0 comments on commit 35dc6c8

Please sign in to comment.