Domestic hardware adaptation #63

Open
wants to merge 9 commits into dev
feat: add erf/mod/cast/clip/gather/scatternd operators for the Cambricon platform
Chamberlain0w0 authored and YdrMaster committed Jan 31, 2024
commit beda02971cd50cca84626b657db803982e9b40f6
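Reviewer note: every collector touched below receives the same addition — an Mlu case that tries to build the new CNNL kernel and appends it to the candidate list, falling back to nothing when build() returns nullptr. A minimal sketch of that shared dispatch shape, using cast as the example; the Cpu/Nvidia enumerator names and case bodies are abbreviated assumptions about the existing code, not part of this commit:

    switch (_target) {
        case decltype(_target)::Cpu:
        case decltype(_target)::Nvidia:
            // existing CPU/CUDA branches, unchanged by this commit
            break;
        case decltype(_target)::Mlu:
            // new: try the CNNL kernel; skip it if build() returns nullptr
            if (auto ptr = CastCnnl::build(from, to); ptr) {
                ans.emplace_back(std::move(ptr));
            }
            break;
        default:
            UNREACHABLEX(void, "Unknown target");
    }

The clip, gather, and scatter_nd collectors in the hunks below follow this pattern with their own build() arguments.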
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/cast.cc
@@ -1,6 +1,7 @@
#include "kernel/collectors/cast.h"
#include "../kernels/cast/cpu_kernel.hh"
#include "../kernels/cast/cuda_kernel.hh"
#include "../kernels/cast/cnnl_kernel.hh"

namespace refactor::kernel {

@@ -24,6 +25,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = CastCnnl::build(from, to); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/clip.cc
@@ -1,6 +1,7 @@
#include "kernel/collectors/clip.h"
#include "../kernels/clip/cpu_kernel.hh"
#include "../kernels/clip/cuda_kernel.hh"
#include "../kernels/clip/cnnl_kernel.hh"

namespace refactor::kernel {

@@ -24,6 +25,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = ClipCnnl::build(data, hasMax); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/gather.cc
@@ -1,4 +1,5 @@
#include "kernel/collectors/gather.h"
#include "../kernels/gather/cnnl_kernel.hh"
#include "../kernels/gather/cpu_kernel.hh"
#include "../kernels/gather/cuda_kernel.hh"

@@ -20,6 +21,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = GatherCnnl::build(axis, inputs[0].get(), inputs[1].get(), outputs[0].get()); ptr != nullptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/scatter_nd.cc
@@ -1,6 +1,7 @@
#include "kernel/collectors/scatter_nd.h"
#include "../kernels/scatter_nd/cpu_kernel.hh"
#include "../kernels/scatter_nd/cuda_kernel.hh"
#include "../kernels/scatter_nd/cnnl_kernel.hh"

namespace refactor::kernel {

@@ -23,6 +24,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = ScatterNDCnnl::build(inputs, outputs); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
226 changes: 226 additions & 0 deletions src/04kernel/src/kernels/cast/cnnl_kernel.cc
@@ -0,0 +1,226 @@
#include "cnnl_kernel.hh"

#ifdef USE_BANG
#include "../../utilities/bang/cnnl_context.hh"
#include "../../utilities/bang/cnnl_functions.h"
#endif


namespace refactor::kernel {
using K = CastCnnl;
using DT = DataType;

K::CastCnnl(decltype(from) from_,
decltype(to) to_,
decltype(shape) shape_) noexcept
: from(from_), to(to_), shape(shape_) {}

auto K::build(Tensor const &from, Tensor const &to) noexcept -> KernelBox {
#ifndef USE_BANG
return nullptr;
#endif

return std::make_unique<K>(from.dataType, to.dataType,
std::vector<int>(from.shape.begin(), from.shape.end()));
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing cast operation using CNNL";
}

#ifdef USE_BANG

static cnnlCastDataType_t castType(DataType from, DataType to);

auto K::lower(Resources &res) const -> RoutineWorkspace {
using namespace cnnl;
using namespace runtime;

struct Descriptors {
cnnlTensorDescriptor_t inDesc, outDesc;
cnnlCastDataType_t cast;

Descriptors() : inDesc(nullptr), outDesc(nullptr) {
CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc));
}
~Descriptors() noexcept(false) {
CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc));
}
};
auto d = std::make_shared<Descriptors>();
d->cast = castType(from, to);
setCnnlTensor(d->inDesc, from, slice(shape.data(), shape.size()));
setCnnlTensor(d->outDesc, to, slice(shape.data(), shape.size()));

res.fetchOrStore<CnnlContext>();
return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
CNNL_ASSERT(cnnlCastDataType(res.fetchOrStore<CnnlContext>()->handle,
d->inDesc, inputs[0], d->cast, d->outDesc, outputs[0]));
// BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
};
}

static cnnlCastDataType_t castType(DataType from, DataType to) {
switch (from) {
case DT::F32:
switch (to) {
case DT::F64:
return CNNL_CAST_FLOAT_TO_DOUBLE;
case DT::FP16:
return CNNL_CAST_FLOAT_TO_HALF;
case DT::I64:
return CNNL_CAST_FLOAT_TO_INT64;
case DT::I32:
return CNNL_CAST_FLOAT_TO_INT32;
case DT::I16:
return CNNL_CAST_FLOAT_TO_INT16;
case DT::I8:
return CNNL_CAST_FLOAT_TO_INT8;
case DT::U8:
return CNNL_CAST_FLOAT_TO_UINT8;
// case DT::BF16:
// return CNNL_CAST_FLOAT_TO_BFLOAT16;
case DT::Bool:
return CNNL_CAST_FLOAT_TO_BOOL;
default:
UNREACHABLE();
}
case DT::FP16:
switch (to) {
case DT::F32:
return CNNL_CAST_HALF_TO_FLOAT;
case DT::I64:
return CNNL_CAST_HALF_TO_INT64;
case DT::I32:
return CNNL_CAST_HALF_TO_INT32;
case DT::I16:
return CNNL_CAST_HALF_TO_INT16;
case DT::I8:
return CNNL_CAST_HALF_TO_INT8;
case DT::U8:
return CNNL_CAST_HALF_TO_UINT8;
case DT::Bool:
return CNNL_CAST_HALF_TO_BOOL;
default:
UNREACHABLE();
}
case DT::I32:
switch (to) {
case DT::F32:
return CNNL_CAST_INT32_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_INT32_TO_HALF;
case DT::I64:
return CNNL_CAST_INT32_TO_INT64;
case DT::I16:
return CNNL_CAST_INT32_TO_INT16;
case DT::I8:
return CNNL_CAST_INT32_TO_INT8;
case DT::Bool:
return CNNL_CAST_INT32_TO_BOOL;
default:
UNREACHABLE();
}
case DT::I16:
switch (to) {
case DT::F32:
return CNNL_CAST_INT16_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_INT16_TO_HALF;
case DT::I32:
return CNNL_CAST_INT16_TO_INT32;
// case DT::I8:
// return CNNL_CAST_INT16_TO_INT8;
default:
UNREACHABLE();
}
case DT::I8:
switch (to) {
case DT::F32:
return CNNL_CAST_INT8_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_INT8_TO_HALF;
case DT::I32:
return CNNL_CAST_INT8_TO_INT32;
case DT::I16:
return CNNL_CAST_INT8_TO_INT16;
default:
UNREACHABLE();
}
case DT::U8:
switch (to) {
case DT::F32:
return CNNL_CAST_UINT8_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_UINT8_TO_HALF;
case DT::I64:
return CNNL_CAST_UINT8_TO_INT64;
case DT::I32:
return CNNL_CAST_UINT8_TO_INT32;
default:
UNREACHABLE();
}
case DT::Bool:
switch (to) {
case DT::F32:
return CNNL_CAST_BOOL_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_BOOL_TO_HALF;
case DT::I32:
return CNNL_CAST_BOOL_TO_INT32;
default:
UNREACHABLE();
}
case DT::I64:
switch (to) {
case DT::F32:
return CNNL_CAST_INT64_TO_FLOAT;
case DT::FP16:
return CNNL_CAST_INT64_TO_HALF;
case DT::I32:
return CNNL_CAST_INT64_TO_INT32;
case DT::U32:
return CNNL_CAST_INT64_TO_UINT32;
default:
UNREACHABLE();
}
case DT::U32:
switch (to) {
case DT::I64:
return CNNL_CAST_UINT32_TO_INT64;
case DT::U64:
return CNNL_CAST_UINT32_TO_UINT64;
default:
UNREACHABLE();
}
case DT::F64:
switch (to) {
case DT::F32:
return CNNL_CAST_DOUBLE_TO_FLOAT;
default:
UNREACHABLE();
}
case DT::BF16:
switch (to) {
// case DT::F32:
// return CNNL_CAST_BF16_TO_FLOAT;
default:
UNREACHABLE();
}
default:
UNREACHABLE();
}
}

#endif

}// namespace refactor::kernel
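For context on how the routine returned by lower() is consumed: the runtime later calls it with a resource bundle, an optional workspace pointer, and raw device pointer arrays, matching the lambda signature above. A hedged sketch of that call site, assuming the RoutineWorkspace member is named routine and that src/dst are MLU buffers already allocated by the caller (none of these names come from this diff):

    // Sketch only: driving the lowered cast routine by hand.
    auto kernel = CastCnnl::build(fromTensor, toTensor);   // nullptr unless built with USE_BANG
    auto rw = kernel->lower(res);                          // res: refactor::kernel::Resources
    void const *inputs[]{src};                             // source buffer on the MLU
    void *outputs[]{dst};                                  // destination buffer on the MLU
    rw.routine(res, nullptr, inputs, outputs);             // cast requests no extra workspace here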
27 changes: 27 additions & 0 deletions src/04kernel/src/kernels/cast/cnnl_kernel.hh
@@ -0,0 +1,27 @@
#ifndef KERNEL_CAST_CNNL_KERNEL_HH
#define KERNEL_CAST_CNNL_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct CastCnnl final : public Kernel {
DataType from, to;
std::vector<int> shape;

CastCnnl(decltype(from), decltype(to), decltype(shape)) noexcept;

static KernelBox build(Tensor const &, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_BANG
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_CAST_CNNL_KERNEL_HH
66 changes: 66 additions & 0 deletions src/04kernel/src/kernels/clip/cnnl_kernel.cc
@@ -0,0 +1,66 @@
#include "cnnl_kernel.hh"

#ifdef USE_BANG
#include "../../utilities/bang/cnnl_context.hh"
#include "../../utilities/bang/cnnl_functions.h"
#endif

namespace refactor::kernel {
using K = ClipCnnl;

K::ClipCnnl(decltype(dataType) dt,
decltype(shape) shape_,
decltype(hasMax) hasMax_) noexcept
: dataType(dt), shape(shape_), hasMax(hasMax_) {
}

auto K::build(Tensor const &data, bool hasMax) noexcept -> KernelBox {
return data.dataType.isCpuNumberic()
? std::make_unique<K>(data.dataType,
std::vector<int>(data.shape.begin(), data.shape.end()),
hasMax)
: nullptr;
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing clip operation using CNNL";
}

#ifdef USE_BANG
auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
using namespace cnnl;
using namespace runtime;

struct Descriptors {
cnnlTensorDescriptor_t t;

Descriptors() : t(nullptr) {
CNNL_ASSERT(cnnlCreateTensorDescriptor(&t));
}
~Descriptors() noexcept(false) {
CNNL_ASSERT(cnnlDestroyTensorDescriptor(t));
}
};
auto d = std::make_shared<Descriptors>();
setCnnlTensor(d->t, dataType, slice(shape.data(), shape.size()));

res.fetchOrStore<CnnlContext>();
return [d = std::move(d), hasMax = this->hasMax](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
CNNL_ASSERT(cnnlClip_v2(res.fetchOrStore<CnnlContext>()->handle,
CNNL_POINTER_MODE_DEVICE, d->t,
inputs[0], inputs[1], hasMax ? inputs[2] : nullptr,
d->t, outputs[0]));
BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
};
}

#endif

}// namespace refactor::kernel
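One point worth flagging in the clip routine above: cnnlClip_v2 is called with CNNL_POINTER_MODE_DEVICE, so the min and max bounds are read from device memory rather than passed as host scalars, and the runtime must supply them as extra inputs in the order the lambda indexes them. A hedged sketch of that expected input layout; the pointer names are illustrative only:

    // Sketch only: inputs expected by the lowered clip routine.
    void const *inputs[]{
        data_mlu,   // inputs[0]: tensor to clip, shape/dtype per the descriptor above
        min_mlu,    // inputs[1]: device scalar holding the lower bound
        max_mlu,    // inputs[2]: device scalar holding the upper bound, only read when hasMax
    };
    void *outputs[]{out_mlu};
    routine(res, nullptr, inputs, outputs);   // runs cnnlClip_v2, then cnrtQueueSync on the CNNL queue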