Skip to content

Commit

Permalink
fix(kernel): 为 Gather 支持负的 indices 值
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 30, 2024
1 parent 9495516 commit 6630866
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/04kernel/cuda/src/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ namespace refactor::kernel::cuda {
  tid += step) {
  auto i = tid / batch,
  j = tid % batch;
- auto index = __ldg(indices + i % midSizeO);
+ auto k = __ldg(indices + i % midSizeO);
+ auto quot = k >= 0 ? i / midSizeO : i / midSizeO + 1;
  optimizedMemcpy(unit * tid + output,
- unit * (batch * (i / midSizeO * midSizeI + index) + j) + data,
+ unit * (batch * (quot * midSizeI + k) + j) + data,
  unit);
  }
  }
Expand Down
3 changes: 2 additions & 1 deletion src/04kernel/src/kernels/gather/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ namespace refactor::kernel {
  int64_t k = info.idxType == DataType::I64
  ? reinterpret_cast<int64_t const *>(inputs[1])[d.rem]
  : reinterpret_cast<int32_t const *>(inputs[1])[d.rem];
+ auto quot = k >= 0 ? d.quot : d.quot + 1;
  std::memcpy(info.postfix * i + output,
- info.postfix * (d.quot * info.midSizeI + k) + data,
+ info.postfix * (quot * info.midSizeI + k) + data,
  info.postfix);
  });
  };
Expand Down

0 comments on commit 6630866

Please sign in to comment.