Skip to content

Commit

Permalink
feat: add hardswish cpu/cuda kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz committed Jan 17, 2024
1 parent 8b6333d commit 3998833
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/collectors/simple_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace refactor::kernel {
Erf,
Neg,
Not,
HardSwish,
};

std::string_view unaryName(SimpleUnaryType type);
Expand Down
1 change: 1 addition & 0 deletions src/04kernel/src/collectors/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace refactor::kernel {
CASE(Erf);
CASE(Neg);
CASE(Not);
CASE(HardSwish);
default:
UNREACHABLE();
}
Expand Down
11 changes: 11 additions & 0 deletions src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace refactor::kernel {
Op::Tanh,
Op::Neg,
Op::Erf,
Op::HardSwish,
};
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
? std::make_unique<K>(op, a.dataType, a.elementsSize())
Expand Down Expand Up @@ -49,6 +50,9 @@ namespace refactor::kernel {
using M = std::conditional_t<sizeof(T) <= 4, float, double>;
return static_cast<T>(std::tanh(static_cast<M>(x)));
}
/// HardSwish activation: `x * max(0, min(1, x / 6 + 0.5))`.
///
/// All arithmetic is carried out in an intermediate type `M`: `float` when
/// `T` occupies at most 4 bytes, `double` otherwise. This keeps the 1/6 slope
/// at full double precision for 8-byte element types instead of a
/// single-precision constant, and makes the conversion back to `T` explicit.
///
/// @param x input element.
/// @return the activated value, converted back to `T`.
template<class T> auto hardswishFun(T x) noexcept -> T {
    using M = std::conditional_t<sizeof(T) <= 4, float, double>;
    auto const v = static_cast<M>(x);
    // Gate is 0 for x <= -3, 1 for x >= 3, and linear in between.
    auto const gate = std::max(M(0), std::min(M(1), v / M(6) + M(0.5)));
    return static_cast<T>(v * gate);
}
auto copyForUnsigned(size_t n) noexcept -> Routine {
return [n](runtime::Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
std::memcpy(outputs[0], inputs[0], n);
Expand Down Expand Up @@ -171,6 +175,13 @@ namespace refactor::kernel {
default:
UNREACHABLE();
}
case Op::HardSwish:
switch (dataType) {
CASE(hardswishFun, F32);
CASE(hardswishFun, F64);
default:
UNREACHABLE();
}
default:
UNREACHABLE();
}
Expand Down
7 changes: 6 additions & 1 deletion src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace refactor::kernel {
static const std::unordered_set<Op>
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
Op::Sigmoid, Op::Tanh, Op::Neg,
Op::Erf};
Op::Erf, Op::HardSwish};
#ifndef USE_CUDA
return nullptr;
#endif
Expand Down Expand Up @@ -154,6 +154,11 @@ extern "C" __global__ void kernel(
{__(Op::Erf, DT::I64 ), "erf(static_cast<double>(x))"},
{__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},
{__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},

{__(Op::HardSwish, DT::F32 ), "x * fmaxf(0.f, fminf(1.f, fmaf(1.f/6.f, x, 0.5f)))"},
{__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"},
{__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"},

};
// clang-format on

Expand Down
29 changes: 28 additions & 1 deletion src/04kernel/test/kernels/simple_unary/test_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using namespace refactor;
using namespace kernel;

using VecFloat = std::vector<float>;

static void testOp(SimpleUnaryType opType, float check(float)) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50});
Expand All @@ -12,7 +14,7 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
// put input data
std::vector<float> data(dataTensor->elementsSize());
VecFloat data(dataTensor->elementsSize());
for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; }
auto result = data;
// inference
Expand All @@ -27,9 +29,34 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
}
}

// Builds the CPU kernel for `opType` on a 2x3 F32 tensor, feeds it the inputs
// 0, 1, ..., 5 and compares every output element against `data`.
//
// `data` holds the EXPECTED outputs (not the inputs), one per tensor element
// in order; it must contain exactly as many entries as the tensor has elements.
static void testOpWithData(SimpleUnaryType opType, const VecFloat &data) {
    // build routine
    auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3});
    auto kernel = SimpleUnaryCpu::build(opType, *dataTensor);
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;
    // put input data
    VecFloat inputdata(dataTensor->elementsSize());
    // Fail fast instead of reading past the end of `data` in the check loop
    // when the caller supplies too few expected values.
    ASSERT_TRUE(data.size() == inputdata.size());
    for (auto i : range0_(inputdata.size())) { inputdata[i] = static_cast<float>(i); }
    auto result = inputdata;
    // inference: the same buffer is passed as input and output, so the
    // routine overwrites `result` in place.
    {
        void const *inputs[]{result.data()};
        void *outputs[]{result.data()};
        routine(res, nullptr, inputs, outputs);
    }
    // check
    for (auto i : range0_(inputdata.size())) {
        EXPECT_NEAR(data[i], result[i], 1e-5);
    }
}

// Exercises the CPU simple-unary kernel for each op covered by this test file.
TEST(kernel, SimpleUnaryCpu) {
    // Each of these ops is paired with the <cmath> function handed to
    // `testOp` as its `check` callback.
    testOp(SimpleUnaryType::Abs, std::abs);
    testOp(SimpleUnaryType::Sqrt, std::sqrt);
    testOp(SimpleUnaryType::Tanh, std::tanh);
    testOp(SimpleUnaryType::Erf, std::erf);

    // HardSwish is checked against a fixed table of expected outputs.
    VecFloat const hardSwishExpected{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000};
    testOpWithData(SimpleUnaryType::HardSwish, hardSwishExpected);
}
1 change: 1 addition & 0 deletions src/04kernel/test/kernels/simple_unary/test_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ TEST(kernel, SimpleUnaryCuda) {
testOp(SimpleUnaryType::Sigmoid);
testOp(SimpleUnaryType::Tanh);
testOp(SimpleUnaryType::Erf);
testOp(SimpleUnaryType::HardSwish);
}

#endif
6 changes: 6 additions & 0 deletions src/05computation/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ namespace refactor::computation {
static uint8_t ID = 19;
return reinterpret_cast<size_t>(&ID);
}
case SimpleUnaryType::HardSwish: {
static uint8_t ID = 20;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -128,6 +132,8 @@ namespace refactor::computation {
return "Neg";
case SimpleUnaryType::Not:
return "Not";
case SimpleUnaryType::HardSwish:
return "HardSwish";
default:
UNREACHABLE();
}
Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ namespace refactor::onnx {
REGISTER(Not , SimpleUnary );
REGISTER(Neg , SimpleUnary );
REGISTER(Identity , SimpleUnary );
REGISTER(HardSwish , SimpleUnary );
REGISTER(Slice , Slice );
REGISTER(Softmax , Softmax );
REGISTER(Split , Split );
Expand Down
9 changes: 8 additions & 1 deletion src/07onnx/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace refactor::onnx {
opType == "onnx::Not" ? Ty::Not :
opType == "onnx::Neg" ? Ty::Neg :
opType == "onnx::Identity"? Ty::Identity:
opType == "onnx::HardSwish" ? Ty::HardSwish :
UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType);
// clang-format on

Expand Down Expand Up @@ -129,6 +130,10 @@ namespace refactor::onnx {
static uint8_t ID = 21;
return reinterpret_cast<size_t>(&ID);
}
case Ty::HardSwish: {
static uint8_t ID = 22;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -159,6 +164,7 @@ namespace refactor::onnx {
case Ty::Not : return "onnx::Not";
case Ty::Neg : return "onnx::Neg";
case Ty::Identity : return "onnx::Identity";
case Ty::HardSwish : return "onnx::HardSwish";
default: UNREACHABLE();
}
// clang-format on
Expand Down Expand Up @@ -187,7 +193,7 @@ namespace refactor::onnx {
Ty::Atan, Ty::Atanh,
Ty::Cos, Ty::Cosh,
Ty::Sin, Ty::Sinh,
Ty::Tan},
Ty::Tan, Ty::HardSwish},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log},
{Ty::Neg},
{Ty::Identity}};
Expand Down Expand Up @@ -287,6 +293,7 @@ namespace refactor::onnx {
case Ty::Not : type_ = Ty_::Not ; break;
case Ty::Neg : type_ = Ty_::Neg ; break;
case Ty::Identity : return std::make_unique<computation::Identity>();
case Ty::HardSwish : type_ = Ty_::HardSwish ; break;
default: UNREACHABLE();
}
// clang-format on
Expand Down
17 changes: 9 additions & 8 deletions src/07onnx/src/operators/simple_unary.hh
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ namespace refactor::onnx {
Atanh,
Cos,
Cosh,
Sin,
Sinh,
Tan,
Tanh,
Relu,
Sqrt,
Sigmoid,
Erf,
HardSwish,
Identity,
Log,
Not,
Neg,
Identity,
Relu,
Sin,
Sinh,
Sqrt,
Sigmoid,
Tan,
Tanh,
};

struct SimpleUnary final : public Operator {
Expand Down
14 changes: 14 additions & 0 deletions src/07onnx/test/test_simple_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,18 @@ TEST(infer, SimpleUnary) {
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
{
// HardSwish Test
auto edges = Edges{
{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""},
};
count_t inputs[]{0};
auto infered = SimpleUnary(SimpleUnaryType::HardSwish).infer(TensorRefs(edges, inputs), {true});
ASSERT_TRUE(infered.isOk());
auto outputs = std::move(infered.unwrap());
ASSERT_EQ(outputs.size(), 1);
auto y = std::move(outputs[0]);
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
}

0 comments on commit 3998833

Please sign in to comment.