fix(frontend): fix some bugs to adapt to torchvision models
bitzyz committed Jan 30, 2024
1 parent 6630866 commit 4039358
Showing 1 changed file with 29 additions and 34 deletions.
src/07onnx/src/operators/gather.cc (63 changes: 29 additions & 34 deletions)
@@ -1,6 +1,8 @@
 #include "computation/operators/gather.h"
 #include "common.h"
 #include "gather.hh"
+#include "kernel/collectors/gather.h"
+#include "runtime/resource.h"
 #include <execution>
 
 namespace refactor::onnx {
@@ -42,41 +44,34 @@ namespace refactor::onnx {
         if (!options.shouldCalculate(inputs, {*ans})) {
             return Ok(Tensors{std::move(ans)});
         }
+        {
+            using Shape = kernel::Shape;
+            using Tensor = kernel::Tensor;
+            using LayoutType = kernel::LayoutType;
 
-        std::for_each_n(std::execution::unseq, natural_t(0), ans->elementsSize(),
-                        [&data, &indices, &output,
-                         axis_,
-                         q = indices.shape.size(),
-                         ssz = output.size(),
-                         src = data.data->get<uint8_t>(),
-                         dst = reinterpret_cast<uint8_t *>(ans->malloc()),
-                         eleSize = data.dataType.size()](auto const i) {
-                            auto indices_ = locateN(output, i);
-                            int64_t k;
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range0_(q).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= indices.shape[j].value();
-                                }
-                                k = indices.dataType == DataType::I64
-                                        ? indices.data->get<int64_t>()[ii]
-                                        : indices.data->get<int32_t>()[ii];
-                            }
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range(static_cast<decltype(q)>(axis_) + q, ssz).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j - q + 1].value();
-                                }
-                                ii += k * mul;
-                                for (auto j : range0_(axis_).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j].value();
-                                }
-                                std::memcpy(dst + i * eleSize, src + ii * eleSize, eleSize);
-                            }
-                        });
+            Shape t1Shape(data.shape.size(), 1);
+            Shape t2Shape(indices.shape.size(), 1);
+            Shape oShape(ans->shape.size(), 1);
+            std::transform(std::execution::unseq,
+                           data.shape.begin(), data.shape.end(), t1Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            std::transform(std::execution::unseq,
+                           indices.shape.begin(), indices.shape.end(), t2Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto t1 = Tensor::share(data.dataType, t1Shape, LayoutType::Others, data.data);
+            auto t2 = Tensor::share(indices.dataType, t2Shape, LayoutType::Others, indices.data);
+            std::transform(std::execution::unseq,
+                           ans->shape.begin(), ans->shape.end(), oShape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto o = Tensor::share(data.dataType, oShape, LayoutType::Others);
+            runtime::Resources res;
+            static const auto collector = kernel::GatherCollector(computation::Target::Cpu, axis_);
+            auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
+            void const *inputsCpu[]{*t1->data, *t2->data};
+            void *outputsCpu[]{o->malloc()};
+            routine(res, nullptr, inputsCpu, outputsCpu);
+            ans->data = o->data;
+        }
 
         return Ok(Tensors{std::move(ans)});
     }
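Note on the change: the removed lines computed Gather element by element with a hand-rolled index calculation; the added lines appear to delegate the same constant-folding work to the shared CPU kernel (kernel::GatherCollector filtered for the input/output tensors, lowered to a routine, and executed once), so the frontend reuses the kernel code path instead of re-implementing the indexing. For reference, the Gather semantics involved can be sketched in standalone C++. The sketch below is illustrative only and not part of the repository: the function name gather, the float element type, the flat index list, and the row-major shape handling are assumptions, and negative indices are not handled.

// Standalone sketch of ONNX Gather (illustrative, not repository code).
// Output shape = data.shape[0:axis] ++ indices.shape ++ data.shape[axis+1:];
// indices are passed as a flat list, which is enough to show the copy pattern.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

std::vector<float> gather(std::vector<float> const &data,
                          std::vector<std::size_t> const &dataShape,
                          std::vector<std::int64_t> const &indices,
                          std::size_t axis) {
    std::size_t outer = 1, inner = 1, axisLen = dataShape[axis];
    for (std::size_t i = 0; i < axis; ++i) { outer *= dataShape[i]; }
    for (std::size_t i = axis + 1; i < dataShape.size(); ++i) { inner *= dataShape[i]; }

    std::vector<float> out(outer * indices.size() * inner);
    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t j = 0; j < indices.size(); ++j) {
            // Each gathered element selects one contiguous slice of length
            // `inner` along `axis`; assumes non-negative indices.
            auto k = static_cast<std::size_t>(indices[j]);
            std::memcpy(out.data() + (o * indices.size() + j) * inner,
                        data.data() + (o * axisLen + k) * inner,
                        inner * sizeof(float));
        }
    }
    return out;
}

int main() {
    // data is 3x2, indices = {2, 0}, axis = 0 -> output is 2x2: rows 2 and 0.
    std::vector<float> data{0, 1, 2, 3, 4, 5};
    auto out = gather(data, {3, 2}, {2, 0}, 0);
    for (auto v : out) { std::cout << v << ' '; } // prints: 4 5 0 1
    std::cout << '\n';
}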
