Skip to content

Commit

Permalink
feat(kernel): 实现 concat cpu kernel
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Nov 11, 2023
1 parent 8ec8306 commit 4bc0daf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
24 changes: 12 additions & 12 deletions src/04kernel/src/kernels/concat/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ namespace refactor::kernel {
// Lowers the concat kernel to a CPU routine.
//
// The returned routine writes into `outputs[0]`. For every block index in
// [0, info.blockCount) it copies, for each entry `k` of `info.segments`,
// `info.segments[k]` bytes from `inputs[k]` into the output, placing the
// segments of one block back to back starting at byte `block * info.sum`.
// NOTE(review): presumably `info.sum` is the total of `info.segments`, so that
// consecutive blocks do not overlap in the output — confirm against the
// definition of the info struct.
Routine K::lower(Resources &) const noexcept {
    using namespace runtime;
    return [info = this->info](Resources &, void const **inputs, void **outputs) {
        auto output = reinterpret_cast<uint8_t *>(outputs[0]);

        // Copies all segments belonging to one block into the output buffer.
        auto copyBlock = [=, &info](auto block) {
            // Write position inside the output for the current segment.
            auto cursor = output + block * info.sum;
            for (auto k : range0_(info.segments.size())) {
                auto bytes = info.segments[k];
                auto input = reinterpret_cast<uint8_t const *>(inputs[k]);
                // Within input k, the data for this block starts at `block * bytes`.
                std::memcpy(cursor, input + block * bytes, bytes);
                cursor += bytes;
            }
        };

        // Each block index is handled by its own invocation of copyBlock under
        // the parallel-unsequenced execution policy.
        std::for_each_n(std::execution::par_unseq,
                        natural_t(0), info.blockCount,
                        copyBlock);
    };
}

Expand Down
6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/split/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ namespace refactor::kernel {
Routine K::lower(Resources &) const noexcept {
using namespace runtime;
return [info = this->info](Resources &, void const **inputs, void **outputs) {
auto data = reinterpret_cast<uint8_t const *>(inputs[0]);
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), info.blockCount,
[=, &info](auto i) {
auto offset = i * info.sum;
for (auto j : range0_(info.segments.size())) {
auto len = info.segments[j];
auto out = reinterpret_cast<uint8_t *>(outputs[j]);
std::memcpy(out + i * len, data + offset, len);
auto dst = reinterpret_cast<uint8_t *>(outputs[j]);
std::memcpy(dst + i * len, src + offset, len);
offset += len;
}
});
Expand Down

0 comments on commit 4bc0daf

Please sign in to comment.