feat: add HardSigmoid cpu/cuda kernel
bitzyz committed Jan 17, 2024
1 parent 54c2f7e commit 0b314ff
Showing 15 changed files with 457 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/04kernel/include/kernel/collectors/hard_sigmoid.h
@@ -0,0 +1,19 @@
#ifndef KERNEL_HARD_SIGMOID_H
#define KERNEL_HARD_SIGMOID_H

#include "../collector.h"

namespace refactor::kernel {

struct HardSigmoidCollector final : public InfoCollector {
float alpha, beta;

constexpr HardSigmoidCollector(decltype(_target) target, float alpha_, float beta_) noexcept
: InfoCollector(target), alpha(alpha_), beta(beta_) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};
}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_H
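
Note: alpha and beta are the parameters of the HardSigmoid activation as defined by ONNX,

    y = max(0, min(1, alpha * x + beta)),

with ONNX defaults alpha = 0.2 and beta = 0.5. Both kernels below implement exactly this clamped affine form.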
29 changes: 29 additions & 0 deletions src/04kernel/src/collectors/hard_sigmoid.cc
@@ -0,0 +1,29 @@
#include "kernel/collectors/hard_sigmoid.h"
#include "../kernels/hard_sigmoid/cpu_kernel.hh"
#include "../kernels/hard_sigmoid/cuda_kernel.hh"

namespace refactor::kernel {

std::vector<KernelBox>
HardSigmoidCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &a = inputs[0];

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = HardSigmoidCpu::build(alpha, beta, a); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = HardSigmoidCuda::build(alpha, beta, a); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel
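
The collector follows the pattern used elsewhere in this repository: each backend's build() returns nullptr when it cannot handle the tensor (for the CPU path, a non-numeric dtype), and filter() only collects the kernels that were actually constructed, so an empty vector signals that no backend supports this input.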
53 changes: 53 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.cc
@@ -0,0 +1,53 @@
#include "cpu_kernel.hh"
#include <execution>

namespace refactor::kernel {
using K = HardSigmoidCpu;
using DT = DataType;

K::HardSigmoidCpu(float alpha_, float beta_, DT dataType_, size_t size_) noexcept
: Kernel(), alpha(alpha_), beta(beta_), dataType(dataType_), size(size_) {}

auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox {
if (!a.dataType.isCpuNumberic()) {
return nullptr;
}
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize());
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing HardSigmoid using CPU";
}

template<class T>
static Routine lowerTyped(float alpha_, float beta_, size_t size) {
using namespace runtime;

return [=](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto x = reinterpret_cast<T const *>(inputs[0]);
auto y = reinterpret_cast<T *>(outputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), size,
[&](auto i) {
y[i] = std::max(static_cast<T>(0), std::min(static_cast<T>(1), alpha_ * x[i] + beta_));
});
};
}

auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
switch (dataType) {
case DT::F32:
return lowerTyped<float>(alpha, beta, size);
case DT::F64:
return lowerTyped<double>(alpha, beta, size);
default:
UNREACHABLE();
}
}
}// namespace refactor::kernel
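
Two details worth noting in the CPU lowering: natural_t(0) evidently provides a counting iterator, so for_each_n parallelizes over indices rather than elements, and std::execution::par_unseq requests a vectorized parallel run from the standard library (with libstdc++ this typically requires linking against TBB). A minimal serial equivalent of the float path, as a self-contained sketch (hard_sigmoid_ref is a hypothetical name, not part of this commit):

#include <algorithm>
#include <cstddef>

// Hedged sketch: the same clamped affine transform that lowerTyped<float>
// performs, written as a plain serial loop for clarity.
static void hard_sigmoid_ref(float const *x, float *y, std::size_t n,
                             float alpha, float beta) {
    for (std::size_t i = 0; i < n; ++i) {
        y[i] = std::max(0.0f, std::min(1.0f, alpha * x[i] + beta));
    }
}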
26 changes: 26 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh
@@ -0,0 +1,26 @@
#ifndef KERNEL_HARD_SIGMOID_CPU_KERNEL_HH
#define KERNEL_HARD_SIGMOID_CPU_KERNEL_HH

#include "kernel/collectors/hard_sigmoid.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct HardSigmoidCpu final : public Kernel {
float alpha, beta;
DataType dataType;
size_t size;

explicit HardSigmoidCpu(float, float, DataType, size_t) noexcept;

static KernelBox build(float, float, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_CPU_KERNEL_HH
87 changes: 87 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.cc
@@ -0,0 +1,87 @@
#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#include <cuda_runtime.h>
#endif

namespace refactor::kernel {
using K = HardSigmoidCuda;
using DT = DataType;

K::HardSigmoidCuda(float alpha_, float beta_, DT dt_, size_t size_) noexcept
: Kernel(), alpha(alpha_), beta(beta_), dataType(dt_), size(size_) {}

auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize());
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing hardsigmoid operation on Nvidia GPU";
}

#ifdef USE_CUDA
constexpr static const char *TEMPLATE = R"~(
__device__ __forceinline__ static {0:} fn({0:} x) {{
return {1:};
}}
extern "C" __global__ void kernel(
{0:} *__restrict__ y,
{0:} const *__restrict__ x,
size_t n
) {{
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step)
y[tid] = fn(x[tid]);
}}
)~";
auto K::lower(Resources &res) const -> RoutineWorkspace {
using namespace runtime;

std::string op = "";
switch (dataType) {
case DT::F32:
op = fmt::format("fmaxf(0.f, fminf(1.f, fmaf({}, x, {})))", alpha, beta);
break;
case DT::F64:
op = fmt::format("fmax(0.0, fmin(1.0, fma({}, x, {})))",
static_cast<double>(alpha), static_cast<double>(beta));
break;
case DT::FP16:
op = fmt::format("__hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, (__float2half({}) * x + __float2half({}))))",
alpha, beta);
break;
default:
UNREACHABLE();
}
auto name = fmt::format("hardsigmoid_{}_{}_{}", dataType.name(), alpha, beta);
auto code = fmt::format(TEMPLATE, nvrtc::dataType(dataType), op);
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
params = cuda::ThreadsDistributer()(size)](
Resources &, void *, void const *const *inputs, void *const *outputs) {
auto y = outputs[0];
auto x = inputs[0];
auto n = params.n;
void *args[]{&y, &x, &n};
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}
#endif

}// namespace refactor::kernel
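
The kernel source is generated at runtime: fmt::format splices the element type into {0:} and the type-specific clamp expression into {1:}, and nvrtc::Handler::compile JIT-compiles the result. Since the kernel name at line 225 embeds the dtype, alpha, and beta, each parameter combination gets its own compiled kernel. For illustration, assuming nvrtc::dataType(DT::F32) yields "float", the source handed to NVRTC for alpha = 0.2, beta = 0.5 would expand to roughly:

__device__ __forceinline__ static float fn(float x) {
    return fmaxf(0.f, fminf(1.f, fmaf(0.2, x, 0.5)));
}
extern "C" __global__ void kernel(
    float *__restrict__ y,
    float const *__restrict__ x,
    size_t n
) {
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < n;
         tid += step)
        y[tid] = fn(x[tid]);
}

The body is a standard grid-stride loop, so the launch configuration chosen by ThreadsDistributer affects occupancy, not correctness.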
28 changes: 28 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh
@@ -0,0 +1,28 @@
#ifndef KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH
#define KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH

#include "kernel/collectors/hard_sigmoid.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct HardSigmoidCuda final : public Kernel {
float alpha, beta;
DataType dataType;
size_t size;

explicit HardSigmoidCuda(float, float, DataType, size_t) noexcept;

static KernelBox build(float, float, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH
31 changes: 31 additions & 0 deletions src/04kernel/test/kernels/hard_sigmoid/test_cpu.cpp
@@ -0,0 +1,31 @@
#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;

TEST(kernel, HardSigmoidCpu) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5});
float alpha = 0.2f, beta = 0.5f;
auto kernel = HardSigmoidCpu::build(alpha, beta, *dataTensor);
ASSERT_TRUE(kernel);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
// put input data
std::vector<float> result(dataTensor->elementsSize());
for (auto i : range0_(result.size())) { result[i] = i; }
// inference
{
void const *inputs[]{result.data()};
void *outputs[]{result.data()};
routine(res, nullptr, inputs, outputs);
}
std::vector<float> output = {0.5, 0.7, 0.9, 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1.};
// check
for (auto i : range0_(result.size())) {
EXPECT_FLOAT_EQ(output[i], result[i]);
}
}
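
The expected values check out by hand: with alpha = 0.2, beta = 0.5 and x[i] = i, y[i] = clamp(0.2 * i + 0.5, 0, 1) gives 0.5, 0.7, 0.9 for i = 0, 1, 2 and saturates at 1 for every i >= 3, matching the output vector above. Note the routine runs in place here: inputs and outputs point at the same buffer.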
49 changes: 49 additions & 0 deletions src/04kernel/test/kernels/hard_sigmoid/test_cuda.cpp
@@ -0,0 +1,49 @@
#ifdef USE_CUDA

#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh"
#include "../../../src/kernels/hard_sigmoid/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;
using namespace hardware;

TEST(kernel, HardSigmoidCuda) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5});
float alpha = 0.2f, beta = 0.5f;
auto kernel = HardSigmoidCuda::build(alpha, beta, *dataTensor);
auto kCpu = HardSigmoidCpu::build(alpha, beta, *dataTensor);
ASSERT_TRUE(kernel && kCpu);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine,
rCpu = kCpu->lower(res).routine;
// malloc
auto &dev = *device::init(Device::Type::Nvidia, 0, "");
auto gpuMem = dev.malloc(dataTensor->bytesSize());
// put input data
std::vector<float> data(dataTensor->elementsSize());
for (auto i : range0_(data.size())) { data[i] = i; }
gpuMem->copyFromHost(data.data(), dataTensor->bytesSize());
// inference
{
void const *inputs[]{*gpuMem};
void *outputs[]{*gpuMem};
routine(res, nullptr, inputs, outputs);
}
{
void const *inputs[]{data.data()};
void *outputs[]{data.data()};
rCpu(res, nullptr, inputs, outputs);
}
// take output data
std::vector<float> result(dataTensor->elementsSize());
gpuMem->copyToHost(result.data(), dataTensor->bytesSize());
// check
for (auto i : range0_(data.size())) {
EXPECT_FLOAT_EQ(data[i], result[i]);
}
}

#endif
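
As in the CPU test, both routines run in place over a single buffer, which is safe because the operation is purely elementwise; the CPU result (data) then serves as the reference against which the GPU output copied back into result is compared.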
23 changes: 23 additions & 0 deletions src/05computation/include/computation/operators/hard_sigmoid.h
@@ -0,0 +1,23 @@
#ifndef COMPUTATION_HARD_SIGMOID_H
#define COMPUTATION_HARD_SIGMOID_H

#include "../operator.h"

namespace refactor::computation {

struct HardSigmoid final : public Operator {
float alpha, beta;

constexpr HardSigmoid(float alpha_, float beta_) noexcept
: Operator(), alpha(alpha_), beta(beta_){};

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const noexcept final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_HARD_SIGMOID_H
22 changes: 22 additions & 0 deletions src/05computation/src/operators/hard_sigmoid.cc
@@ -0,0 +1,22 @@
#include "computation/operators/hard_sigmoid.h"
#include "kernel/collectors/hard_sigmoid.h"

namespace refactor::computation {
using Op = HardSigmoid;

auto Op::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "HardSigmoid"; }

auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
using Collector_ = kernel::HardSigmoidCollector;
return std::make_unique<Collector_>(target, alpha, beta);
}
auto Op::serialize() const noexcept -> std::string {
return fmt::format("{}()", name());
}

}// namespace refactor::computation
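
Note that serialize() emits only the operator name, "HardSigmoid()"; alpha and beta are not part of the serialized text, although both are carried by the operator and forwarded to the kernel collector.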
2 changes: 2 additions & 0 deletions src/07onnx/src/operators.cpp
@@ -17,6 +17,7 @@
#include "operators/gather_elements.hh"
#include "operators/gemm.hh"
#include "operators/global_pool.hh"
#include "operators/hard_sigmoid.hh"
#include "operators/mat_mul.hh"
#include "operators/mat_mul_integer.hh"
#include "operators/pool.hh"
@@ -124,6 +125,7 @@ namespace refactor::onnx {
REGISTER(Transpose , Transpose );
REGISTER(Unsqueeze , Unsqueeze );
REGISTER(Where , Where );
REGISTER(HardSigmoid , HardSigmoid );
#undef REGISTER
// clang-format on
}
