diff --git a/src/07onnx/src/operators/gather.cc b/src/07onnx/src/operators/gather.cc
index 3aa08890..6373bda2 100644
--- a/src/07onnx/src/operators/gather.cc
+++ b/src/07onnx/src/operators/gather.cc
@@ -1,6 +1,8 @@
 #include "computation/operators/gather.h"
 #include "common.h"
 #include "gather.hh"
+#include "kernel/collectors/gather.h"
+#include "runtime/resource.h"
 #include <execution>
 
 namespace refactor::onnx {
@@ -42,41 +44,34 @@ namespace refactor::onnx {
         if (!options.shouldCalculate(inputs, {*ans})) {
             return Ok(Tensors{std::move(ans)});
         }
+        {
+            using Shape = kernel::Shape;
+            using Tensor = kernel::Tensor;
+            using LayoutType = kernel::LayoutType;
 
-        std::for_each_n(std::execution::unseq, natural_t(0), ans->elementsSize(),
-                        [&data, &indices, &output,
-                         axis_,
-                         q = indices.shape.size(),
-                         ssz = output.size(),
-                         src = data.data->get<uint8_t>(),
-                         dst = reinterpret_cast<uint8_t *>(ans->malloc()),
-                         eleSize = data.dataType.size()](auto const i) {
-                            auto indices_ = locateN(output, i);
-                            int64_t k;
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range0_(q).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= indices.shape[j].value();
-                                }
-                                k = indices.dataType == DataType::I64
-                                        ? indices.data->get<int64_t>()[ii]
-                                        : indices.data->get<int32_t>()[ii];
-                            }
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range(static_cast<size_t>(axis_) + q, ssz).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j - q + 1].value();
-                                }
-                                ii += k * mul;
-                                for (auto j : range0_(axis_).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j].value();
-                                }
-                                std::memcpy(dst + i * eleSize, src + ii * eleSize, eleSize);
-                            }
-                        });
+            Shape t1Shape(data.shape.size(), 1);
+            Shape t2Shape(indices.shape.size(), 1);
+            Shape oShape(ans->shape.size(), 1);
+            std::transform(std::execution::unseq,
+                           data.shape.begin(), data.shape.end(), t1Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            std::transform(std::execution::unseq,
+                           indices.shape.begin(), indices.shape.end(), t2Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto t1 = Tensor::share(data.dataType, t1Shape, LayoutType::Others, data.data);
+            auto t2 = Tensor::share(indices.dataType, t2Shape, LayoutType::Others, indices.data);
+            std::transform(std::execution::unseq,
+                           ans->shape.begin(), ans->shape.end(), oShape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto o = Tensor::share(data.dataType, oShape, LayoutType::Others);
+            runtime::Resources res;
+            const auto collector = kernel::GatherCollector(computation::Target::Cpu, axis_);
+            auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
+            void const *inputsCpu[]{*t1->data, *t2->data};
+            void *outputsCpu[]{o->malloc()};
+            routine(res, nullptr, inputsCpu, outputsCpu);
+            ans->data = o->data;
+        }
 
         return Ok(Tensors{std::move(ans)});
     }
diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc
index 2db99bdd..bccc99ad 100644
--- a/src/07onnx/src/operators/simple_binary.cc
+++ b/src/07onnx/src/operators/simple_binary.cc
@@ -1,6 +1,9 @@
 #include "simple_binary.hh"
 #include "common.h"
 #include "computation/operators/simple_binary.h"
+#include "kernel/collectors/simple_binary.h"
+#include "runtime/resource.h"
+#include <execution>
 
 namespace refactor::onnx {
     using Op = SimpleBinary;
@@ -10,7 +13,7 @@ namespace refactor::onnx {
         : Operator(), type(type_) {}
 
     auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
-        auto fmod = attributes.getOrInsert( "fmod", {0}).int_();
+        auto fmod = attributes.getOrInsert("fmod", {0}).int_();
         // clang-format off
         auto type =
             opType == "onnx::Add" ? Ty::Add :
@@ -93,30 +96,6 @@ namespace refactor::onnx {
         // clang-format on
     }
 
-    template<decltype(DataType::internal) T>
-    void calculate(Ty ty, void *dst, void const *a, void const *b) {
-        using T_ = typename primitive<T>::type;
-        auto a_ = *reinterpret_cast<T_ const *>(a);
-        auto b_ = *reinterpret_cast<T_ const *>(b);
-        auto dst_ = reinterpret_cast<T_ *>(dst);
-        switch (ty) {
-            case Ty::Add:
-                *dst_ = a_ + b_;
-                break;
-            case Ty::Sub:
-                *dst_ = a_ - b_;
-                break;
-            case Ty::Mul:
-                *dst_ = a_ * b_;
-                break;
-            case Ty::Div:
-                *dst_ = a_ / b_;
-                break;
-            default:
-                UNREACHABLE();
-        }
-    }
-
     auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
         EXPECT_SIZE(2)
 
@@ -139,35 +118,36 @@ namespace refactor::onnx {
             return Ok(Tensors{std::move(ans)});
         }
 
-        auto eleSize = dataType.size();
-        auto dst = reinterpret_cast<uint8_t *>(ans->malloc());
-        for (auto i : range0_(ans->elementsSize())) {
-            auto indices = locateN(ans->shape, i);
-            auto a_ = locate1(a, indices),
-                 b_ = locate1(b, indices);
-            auto dst_ = dst + i * eleSize;
-            //-------------------------------------
-#define CASE(T)                                     \
-    case DataType::T:                               \
-        calculate<DataType::T>(type, dst_, a_, b_); \
-        break
-            //-------------------------------------
-            switch (dataType.internal) {
-                CASE(F32);
-                CASE(F64);
-                CASE(I32);
-                CASE(I64);
-                CASE(I8);
-                CASE(I16);
-                CASE(U8);
-                CASE(U16);
-                CASE(U32);
-                CASE(U64);
-                default:
-                    ans->free();
-                    break;
-            }
+        {
+            using Shape = kernel::Shape;
+            using Tensor = kernel::Tensor;
+            using LayoutType = kernel::LayoutType;
+
+            Shape t1Shape(a.shape.size(), 1);
+            Shape t2Shape(b.shape.size(), 1);
+            Shape oShape(ans->shape.size(), 1);
+            std::transform(std::execution::unseq,
+                           a.shape.begin(), a.shape.end(), t1Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            std::transform(std::execution::unseq,
+                           b.shape.begin(), b.shape.end(), t2Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto t1 = Tensor::share(a.dataType, t1Shape, LayoutType::Others, a.data);
+            auto t2 = Tensor::share(b.dataType, t2Shape, LayoutType::Others, b.data);
+            std::transform(std::execution::unseq,
+                           ans->shape.begin(), ans->shape.end(), oShape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto o = Tensor::share(a.dataType, oShape, LayoutType::Others);
+            runtime::Resources res;
+            auto type_ = static_cast<kernel::SimpleBinaryType>(type);
+            const auto collector = kernel::SimpleBinaryCollector(computation::Target::Cpu, type_);
+            auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
+            void const *inputsCpu[]{*t1->data, *t2->data};
+            void *outputsCpu[]{o->malloc()};
+            routine(res, nullptr, inputsCpu, outputsCpu);
+            ans->data = o->data;
+        }
+
         return Ok(Tensors{std::move(ans)});
     }