Skip to content

Commit

Permalink
feat(kernel): 实现 concat cpu kernel
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Nov 11, 2023
1 parent 8ec8306 commit 203cc42
Showing 3 changed files with 16 additions and 16 deletions.
24 changes: 12 additions & 12 deletions src/04kernel/src/kernels/concat/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -25,18 +25,18 @@ namespace refactor::kernel {
Routine K::lower(Resources &) const noexcept {
using namespace runtime;
return [info = this->info](Resources &, void const **inputs, void **outputs) {
// auto data = reinterpret_cast<uint8_t const *>(inputs[0]);
// std::for_each_n(std::execution::par_unseq,
// natural_t(0), info.blockCount,
// [=, &info](auto i) {
// auto offset = i * info.sum;
// for (auto j : range0_(info.segments.size())) {
// auto len = info.segments[j];
// auto out = reinterpret_cast<uint8_t *>(outputs[j]);
// std::memcpy(out + i * len, data + offset, len);
// offset += len;
// }
// });
auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), info.blockCount,
[=, &info](auto i) {
auto offset = i * info.sum;
for (auto j : range0_(info.segments.size())) {
auto len = info.segments[j];
auto src = reinterpret_cast<uint8_t const *>(inputs[j]);
std::memcpy(dst + offset, src + i * len, len);
offset += len;
}
});
};
}

6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/split/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -25,15 +25,15 @@ namespace refactor::kernel {
Routine K::lower(Resources &) const noexcept {
using namespace runtime;
return [info = this->info](Resources &, void const **inputs, void **outputs) {
auto data = reinterpret_cast<uint8_t const *>(inputs[0]);
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), info.blockCount,
[=, &info](auto i) {
auto offset = i * info.sum;
for (auto j : range0_(info.segments.size())) {
auto len = info.segments[j];
auto out = reinterpret_cast<uint8_t *>(outputs[j]);
std::memcpy(out + i * len, data + offset, len);
auto dst = reinterpret_cast<uint8_t *>(outputs[j]);
std::memcpy(dst + i * len, src + offset, len);
offset += len;
}
});
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/split/cuda_kernel.cc
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ namespace refactor::kernel {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing split operation using CUDA";
return "Performing concat operation using CUDA";
}

}// namespace refactor::kernel

0 comments on commit 203cc42

Please sign in to comment.