-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add HardSigmoid cpu/cuda kernel
- Loading branch information
Showing
15 changed files
with
464 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef KERNEL_HARD_SIGMOIG_H | ||
#define KERNEL_HARD_SIGMOIG_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct HardSigmoidCollector final : public InfoCollector { | ||
float alpha, beta; | ||
|
||
constexpr HardSigmoidCollector(decltype(_target) target, float alpha_, float beta_) noexcept | ||
: InfoCollector(target), alpha(alpha_), beta(beta_) {} | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_HARD_SIGMOIG_H | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#include "kernel/collectors/hard_sigmoid.h" | ||
#include "../kernels/hard_sigmoid/cpu_kernel.hh" | ||
#include "../kernels/hard_sigmoid/cuda_kernel.hh" | ||
|
||
namespace refactor::kernel { | ||
|
||
std::vector<KernelBox> | ||
HardSigmoidCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
auto const &a = inputs[0]; | ||
|
||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = HardSigmoidCpu::build(alpha, beta, a); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = HardSigmoidCuda::build(alpha, beta, a); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include "cpu_kernel.hh" | ||
#include <execution> | ||
|
||
namespace refactor::kernel { | ||
using K = HardSigmoidCpu; | ||
using DT = DataType; | ||
|
||
K::HardSigmoidCpu(float alpha_, float beta_, DT dataType_, size_t size_) noexcept | ||
: Kernel(), alpha(alpha_), beta(beta_), dataType(dataType_), size(size_) {} | ||
|
||
auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { | ||
if (!a.dataType.isCpuNumberic()) { | ||
return nullptr; | ||
} | ||
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize()); | ||
} | ||
|
||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing HardSigmoid using CPU"; | ||
} | ||
|
||
template<class T> | ||
static Routine lowerTyped(float alpha_, float beta_, size_t size) { | ||
using namespace runtime; | ||
|
||
return [=](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { | ||
auto x = reinterpret_cast<T const *>(inputs[0]); | ||
auto y = reinterpret_cast<T *>(outputs[0]); | ||
std::for_each_n(std::execution::par_unseq, | ||
natural_t(0), size, | ||
[&](auto i) { | ||
y[i] = std::clamp(alpha_ * x[i] + beta_, static_cast<T>(0), static_cast<T>(1)); | ||
}); | ||
}; | ||
} | ||
|
||
auto K::lower(Resources &) const noexcept -> RoutineWorkspace { | ||
switch (dataType) { | ||
case DT::F32: | ||
return lowerTyped<float>(alpha, beta, size); | ||
case DT::F64: | ||
return lowerTyped<double>(alpha, beta, size); | ||
default: | ||
UNREACHABLE(); | ||
} | ||
} | ||
}// namespace refactor::kernel | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#ifndef KERNEL_HARD_SIGMOID_CPU_KERNEL_HH | ||
#define KERNEL_HARD_SIGMOID_CPU_KERNEL_HH | ||
|
||
#include "kernel/collectors/hard_sigmoid.h" | ||
#include "kernel/tensor.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct HardSigmoidCpu final : public Kernel { | ||
float alpha, beta; | ||
DataType dataType; | ||
size_t size; | ||
|
||
explicit HardSigmoidCpu(float, float, DataType, size_t) noexcept; | ||
|
||
static KernelBox build(float, float, Tensor const &) noexcept; | ||
static size_t typeId() noexcept; | ||
|
||
size_t kernelTypeId() const noexcept final; | ||
std::string_view description() const noexcept final; | ||
RoutineWorkspace lower(Resources &) const noexcept final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_HARD_SIGMOID_CPU_KERNEL_HH | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include "cuda_kernel.hh" | ||
|
||
#ifdef USE_CUDA | ||
#include "../../generator/nvrtc_repo.h" | ||
#include "kernel/cuda/threads_distributer.cuh" | ||
#include <cuda_runtime.h> | ||
#endif | ||
|
||
namespace refactor::kernel { | ||
using K = HardSigmoidCuda; | ||
using DT = DataType; | ||
|
||
K::HardSigmoidCuda(float alpha_, float beta_, DT dt_, size_t size_) noexcept | ||
: Kernel(), alpha(alpha_), beta(beta_), dataType(dt_), size(size_) {} | ||
|
||
auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { | ||
#ifndef USE_CUDA | ||
return nullptr; | ||
#endif | ||
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize()); | ||
} | ||
|
||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
auto K::kernelTypeId() const noexcept -> size_t { | ||
return typeId(); | ||
} | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing hardsigmoid operation on Nvidia GPU"; | ||
} | ||
|
||
#ifdef USE_CUDA | ||
constexpr static const char *TEMPLATE = R"~( | ||
__device__ __forceinline__ static {0:} fn({0:} x) {{ | ||
return {1:}; | ||
}} | ||
extern "C" __global__ void kernel( | ||
{0:} *__restrict__ y, | ||
{0:} const *__restrict__ x, | ||
size_t n | ||
) {{ | ||
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, | ||
step = blockDim.x * gridDim.x; | ||
tid < n; | ||
tid += step) | ||
y[tid] = fn(x[tid]); | ||
}} | ||
)~"; | ||
auto K::lower(Resources &res) const -> RoutineWorkspace { | ||
using namespace runtime; | ||
|
||
std::string op = ""; | ||
switch (dataType) { | ||
case DT::F32: | ||
op = fmt::format("fmaxf(0.f, fminf(1.f, fmaf({}, x, {})))", alpha, beta); | ||
break; | ||
case DT::F64: | ||
op = fmt::format("fmax(0.0, fmin(1.0, fma({}, x, {})))", | ||
static_cast<double>(alpha), static_cast<double>(beta)); | ||
break; | ||
case DT::FP16: | ||
op = fmt::format("__hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, (__float2half({}) * x + __float2half({}))))", | ||
alpha, beta); | ||
break; | ||
default: | ||
UNREACHABLE(); | ||
} | ||
auto name = fmt::format("hardsigmoid_{}_{}_{}", dataType.name(), alpha, beta); | ||
auto code = fmt::format(TEMPLATE, nvrtc::dataType(dataType), op); | ||
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), | ||
params = cuda::ThreadsDistributer()(size)]( | ||
Resources &, void *, void const *const *inputs, void *const *outputs) { | ||
auto y = outputs[0]; | ||
auto x = inputs[0]; | ||
auto n = params.n; | ||
void *args[]{&y, &x, &n}; | ||
h->launch(params.gridSize, 1, 1, | ||
params.blockSize, 1, 1, | ||
0, args); | ||
}; | ||
} | ||
#endif | ||
|
||
}// namespace refactor::kernel | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#ifndef KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH | ||
#define KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH | ||
|
||
#include "kernel/collectors/hard_sigmoid.h" | ||
#include "kernel/tensor.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct HardSigmoidCuda final : public Kernel { | ||
float alpha, beta; | ||
DataType dataType; | ||
size_t size; | ||
|
||
explicit HardSigmoidCuda(float, float, DataType, size_t) noexcept; | ||
|
||
static KernelBox build(float, float, Tensor const &) noexcept; | ||
static size_t typeId() noexcept; | ||
|
||
size_t kernelTypeId() const noexcept final; | ||
std::string_view description() const noexcept final; | ||
#ifdef USE_CUDA | ||
RoutineWorkspace lower(Resources &) const final; | ||
#endif | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh" | ||
#include <gtest/gtest.h> | ||
|
||
using namespace refactor; | ||
using namespace kernel; | ||
|
||
TEST(kernel, HardSigmoidCpu) { | ||
// build routine | ||
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); | ||
float alpha = 0.2f, beta = 0.5f; | ||
auto kernel = HardSigmoidCpu::build(alpha, beta, *dataTensor); | ||
ASSERT_TRUE(kernel); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine; | ||
// put input data | ||
std::vector<float> result(dataTensor->elementsSize()); | ||
for (auto i : range0_(result.size())) { result[i] = i; } | ||
// inference | ||
{ | ||
void const *inputs[]{result.data()}; | ||
void *outputs[]{result.data()}; | ||
routine(res, nullptr, inputs, outputs); | ||
} | ||
std::vector<float> output = {0.5, 0.7, 0.9, 1., 1., 1., 1., 1., 1., | ||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., | ||
1., 1., 1., 1., 1., 1.}; | ||
// check | ||
for (auto i : range0_(result.size())) { | ||
EXPECT_FLOAT_EQ(output[i], result[i]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#ifdef USE_CUDA | ||
|
||
#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh" | ||
#include "../../../src/kernels/hard_sigmoid/cuda_kernel.hh" | ||
#include "hardware/device_manager.h" | ||
#include <gtest/gtest.h> | ||
|
||
using namespace refactor; | ||
using namespace kernel; | ||
using namespace hardware; | ||
|
||
TEST(kernel, HardSigmoidCuda) { | ||
// build routine | ||
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); | ||
float alpha = 0.2f, beta = 0.5f; | ||
auto kernel = HardSigmoidCuda::build(alpha, beta, *dataTensor); | ||
auto kCpu = HardSigmoidCpu::build(alpha, beta, *dataTensor); | ||
ASSERT_TRUE(kernel && kCpu); | ||
auto res = runtime::Resources(); | ||
auto routine = kernel->lower(res).routine, | ||
rCpu = kCpu->lower(res).routine; | ||
// malloc | ||
auto &dev = *device::init(Device::Type::Nvidia, 0, ""); | ||
auto gpuMem = dev.malloc(dataTensor->bytesSize()); | ||
// put input data | ||
std::vector<float> data(dataTensor->elementsSize()); | ||
for (auto i : range0_(data.size())) { data[i] = i; } | ||
gpuMem->copyFromHost(data.data(), dataTensor->bytesSize()); | ||
// inference | ||
{ | ||
void const *inputs[]{*gpuMem}; | ||
void *outputs[]{*gpuMem}; | ||
routine(res, nullptr, inputs, outputs); | ||
} | ||
{ | ||
void const *inputs[]{data.data()}; | ||
void *outputs[]{data.data()}; | ||
rCpu(res, nullptr, inputs, outputs); | ||
} | ||
// take output data | ||
std::vector<float> result(dataTensor->elementsSize()); | ||
gpuMem->copyToHost(result.data(), dataTensor->bytesSize()); | ||
// check | ||
for (auto i : range0_(data.size())) { | ||
EXPECT_FLOAT_EQ(data[i], result[i]); | ||
} | ||
} | ||
|
||
#endif |
23 changes: 23 additions & 0 deletions
23
src/05computation/include/computation/operators/hard_sigmoid.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef COMPUTATION_HARD_SIGMOID_H | ||
#define COMPUTATION_HARD_SIGMOID_H | ||
|
||
#include "../operator.h" | ||
|
||
namespace refactor::computation { | ||
|
||
struct HardSigmoid final : public Operator { | ||
float alpha, beta; | ||
|
||
constexpr HardSigmoid(float alpha_, float beta_) noexcept | ||
: Operator(), alpha(alpha_), beta(beta_){}; | ||
|
||
static size_t typeId() noexcept; | ||
size_t opTypeId() const noexcept final; | ||
std::string_view name() const noexcept final; | ||
kernel::CollectorBox candidateKernels(Target) const noexcept final; | ||
std::string serialize() const noexcept final; | ||
}; | ||
|
||
}// namespace refactor::computation | ||
|
||
#endif// COMPUTATION_HARD_SIGMOID_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include "computation/operators/hard_sigmoid.h" | ||
#include "kernel/collectors/hard_sigmoid.h" | ||
|
||
namespace refactor::computation { | ||
using Op = HardSigmoid; | ||
|
||
auto Op::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
auto Op::opTypeId() const noexcept -> size_t { return typeId(); } | ||
auto Op::name() const noexcept -> std::string_view { return "HardSigmoid"; } | ||
|
||
auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { | ||
using Collector_ = kernel::HardSigmoidCollector; | ||
return std::make_unique<Collector_>(target, alpha, beta); | ||
} | ||
auto Op::serialize() const noexcept -> std::string { | ||
return fmt::format("{}()", name()); | ||
} | ||
|
||
}// namespace refactor::computation | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.