Skip to content

Commit

Permalink
fix(kernel): 双目运算不能交换
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 18, 2024
1 parent ec39fd7 commit 20788a3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
41 changes: 30 additions & 11 deletions src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,28 @@ extern "C" __global__ void kernel(
}}
)~";

constexpr static const char *SCALAR = R"~(
// NVRTC source template for the broadcast case in which the scalar is the
// LEFT operand of the binary operator. The text is expanded with fmt::format
// before compilation:
//   {0:} is substituted wherever an element type is spelled,
//   {1:} is the expression over `a` and `b` that fn returns.
// Literal braces are doubled because fmt treats single braces as fields.
//
// Kernel arguments: y = output, s = pointer to the single scalar value,
// v = the vector operand, n = number of elements to write.
// Each element is computed as y[i] = fn(*s, v[i]); the scalar always reaches
// fn as `a`, so non-commutative operators keep their operand order.
//
// The loop index and stride are computed in size_t. The built-in
// blockIdx/blockDim/gridDim fields are 32-bit unsigned, so an unwidened
// counter would wrap around before reaching a large n and never terminate.
constexpr static const char *SCALAR_A = R"~(
__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
    return {1:};
}}
extern "C" __global__ void kernel(
    {0:} *__restrict__ y,
    {0:} const *__restrict__ s,
    {0:} const *__restrict__ v,
    size_t n
) {{
    auto const num = *s;
    size_t const step = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         tid < n;
         tid += step) {{
        y[tid] = fn(num, v[tid]);
    }}
}}
)~";

constexpr static const char *SCALAR_B = R"~(
__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{
return {1:};
}}
Expand Down Expand Up @@ -209,19 +230,17 @@ extern "C" __global__ void kernel(

} else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) {
static const std::vector<dim_t> S0{0, 1, 1}, S1{1, 0, 1};
auto name = fmt::format("binaryScalar{}", postfix);
auto code = fmt::format(SCALAR, dt_, op_);
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
// clang-format off
scalar = broadcaster.strides == S0 ? 0
: broadcaster.strides == S1 ? 1
: UNREACHABLEX(int, "Unreachable")]// clang-format on
auto scalar_a = broadcaster.strides == S0;
auto name = fmt::format("binaryScalar{}{}", postfix, scalar_a ? "A" : "B");
auto code = scalar_a ? fmt::format(SCALAR_A, dt_, op_)
: fmt::format(SCALAR_B, dt_, op_);
return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]//
(Resources &, void *, void const *const *inputs, void *const *outputs) {
auto c = outputs[0];
auto s = inputs[scalar],
v = inputs[1 - scalar];
auto a = inputs[0],
b = inputs[1];
auto n = params.n;
void *args[]{&c, &v, &s, &n};
void *args[]{&c, &a, &b, &n};
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
Expand Down
9 changes: 8 additions & 1 deletion src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ TEST(kernel, BinaryCudaFmodF32) {
}

// Broadcast coverage: every case pairs a six-dimensional shape with an empty
// Shape{}.
// NOTE(review): the three shape arguments are presumably operand a, operand b
// and the output, in that order; confirm against testBinaryCuda.
TEST(kernel, BinaryCudaBroadcast) {
    // Yields a fresh temporary on every call, exactly like a spelled-out literal.
    auto const full = [] { return Shape{1, 2, 3, 4, 5, 6}; };

    // Empty shape in the second position.
    testBinaryCuda<DataType::I8>(SimpleBinaryType::Add, full(), Shape{}, full());
    // Sub is not commutative, so a swapped operand order would change the result.
    testBinaryCuda<DataType::F32>(SimpleBinaryType::Sub, full(), Shape{}, full());
    // Empty shape in the first position, again with a non-commutative operator.
    testBinaryCuda<DataType::F32>(SimpleBinaryType::Div, Shape{}, full(), full());
}

#endif

0 comments on commit 20788a3

Please sign in to comment.