Skip to content

Commit

Permalink
fix: fix mod kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kilinchange authored and bitzyz committed Jan 17, 2024
1 parent 3998833 commit 5978296
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ namespace refactor::kernel {
switch (dataType.internal) {
CASE_DT(std::fmod(a, b), F32);
CASE_DT(a % b, U8);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
CASE_DT(static_cast<int8_t>(std::fmod(a, b)), I8);
CASE_DT(a % b, U16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
CASE_DT(static_cast<int16_t>(std::fmod(a, b)), I16);
CASE_DT(static_cast<int32_t>(std::fmod(a, b)), I32);
CASE_DT(static_cast<int64_t>(std::fmod(a, b)), I64);
CASE_DT(std::fmod(a, b), F64);
CASE_DT(a % b, U32);
CASE_DT(a % b, U64);
Expand Down
12 changes: 8 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,18 @@ extern "C" __global__ void kernel(
case SimpleBinaryType::Fmod:
switch (dt) {
case DataType::U8:
case DataType::I8:
case DataType::U16:
case DataType::U32:
case DataType::U64:
return "a % b";
case DataType::I8:
return "static_cast<char>(fmodf(a, b))";
case DataType::I16:
return "static_cast<short>(fmodf(a, b))";
case DataType::I32:
return "static_cast<int>(fmodf(a, b))";
case DataType::I64:
case DataType::U32:
case DataType::U64:
return "a % b < 0 ? (a % b + b) : (a % b)";
return "static_cast<long long>(fmodf(a, b))";
case DataType::F32:
return "fmodf(a, b)";
case DataType::FP16:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(kernel, BinaryCpu) {
testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; });
testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; });
testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return a % b < 0 ? (a % b + b) : (a % b); });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return static_cast<int32_t>(std::fmod(a, b)); });
testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); });
}

Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace refactor::onnx {
REGISTER(And , SimpleBinary );
REGISTER(Or , SimpleBinary );
REGISTER(Xor , SimpleBinary );
REGISTER(Mod , SimpleBinary );
REGISTER(Abs , SimpleUnary );
REGISTER(Acos , SimpleUnary );
REGISTER(Acosh , SimpleUnary );
Expand Down

0 comments on commit 5978296

Please sign in to comment.