From 0bc8305d49f9848655631c4326fd7e0f5f3be858 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Mon, 8 Jan 2024 18:35:43 +0800
Subject: [PATCH 1/8] refactor: try calling the CUDA runtime API directly
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 src/02hardware/CMakeLists.txt                 |  8 +++---
 src/02hardware/include/hardware/device.h      |  2 +-
 .../include/hardware/devices/nvidia.h         |  2 +-
 src/02hardware/src/device.cc                  |  2 +-
 src/02hardware/src/devices/nvidia/device.cc   | 25 +++++++++++++------
 .../src/devices/nvidia/functions.cu           | 19 --------------
 .../src/devices/nvidia/functions.cuh          | 24 ------------------
 .../devices/nvidia/{memory.cu => memory.cc}   | 15 +++++++++--
 .../devices/nvidia/{memory.cuh => memory.hh}  |  0
 9 files changed, 38 insertions(+), 59 deletions(-)
 delete mode 100644 src/02hardware/src/devices/nvidia/functions.cu
 delete mode 100644 src/02hardware/src/devices/nvidia/functions.cuh
 rename src/02hardware/src/devices/nvidia/{memory.cu => memory.cc} (69%)
 rename src/02hardware/src/devices/nvidia/{memory.cuh => memory.hh} (100%)

diff --git a/src/02hardware/CMakeLists.txt b/src/02hardware/CMakeLists.txt
index ece758395..a6bd999f0 100644
--- a/src/02hardware/CMakeLists.txt
+++ b/src/02hardware/CMakeLists.txt
@@ -2,15 +2,15 @@ cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
 project(hardware VERSION 0.0.0 LANGUAGES CXX)
 message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
 
-if(USE_CUDA)
-    file(GLOB_RECURSE HARDWARE_CUDA_SRC src/*.cu)
-endif()
-
 file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp)
 add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC})
 target_link_libraries(hardware PUBLIC common)
 target_include_directories(hardware PUBLIC include)
 
+if(USE_CUDA)
+    target_include_directories(hardware PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+endif()
+
 file(GLOB_RECURSE HARDWARE_TEST test/*.cpp)
 if(HARDWARE_TEST)
     add_executable(hardware_test ${HARDWARE_TEST})
diff --git a/src/02hardware/include/hardware/device.h b/src/02hardware/include/hardware/device.h
index 5809fcf94..335c759bf 100644
--- a/src/02hardware/include/hardware/device.h
+++ b/src/02hardware/include/hardware/device.h
@@ -50,7 +50,7 @@ namespace refactor::hardware {
         virtual ~Device() = default;
         virtual Type type() const noexcept = 0;
 
-        virtual void setContext() const noexcept;
+        virtual void setContext() const;
         Arc<Blob> malloc(size_t);
         Arc<Blob> absorb(Arc<Blob> &&);
diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h
index 1facba2c3..d19dd3152 100644
--- a/src/02hardware/include/hardware/devices/nvidia.h
+++ b/src/02hardware/include/hardware/devices/nvidia.h
@@ -8,7 +8,7 @@ namespace refactor::hardware {
     class Nvidia final : public Device {
     public:
         explicit Nvidia(int32_t card);
-        void setContext() const noexcept final;
+        void setContext() const final;
         Type type() const noexcept final {
             return Type::Nvidia;
         }
diff --git a/src/02hardware/src/device.cc b/src/02hardware/src/device.cc
index 08c094994..29ac122e0 100644
--- a/src/02hardware/src/device.cc
+++ b/src/02hardware/src/device.cc
@@ -56,7 +56,7 @@ namespace refactor::hardware {
     Device::Device(decltype(_card) card, decltype(_mem) mem)
         : _card(card), _mem(std::move(mem)) {}
 
-    void Device::setContext() const noexcept {}
+    void Device::setContext() const {}
     auto Device::malloc(size_t size) -> Arc<Blob> {
         return Arc<Blob>(new Blob(this, size));
     }
diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc
index 1ae5b2244..67b255807 100644
--- a/src/02hardware/src/devices/nvidia/device.cc
+++ b/src/02hardware/src/devices/nvidia/device.cc
@@ -1,17 +1,28 @@
 #include "hardware/devices/nvidia.h"
 #include "hardware/mem_pool.h"
+
 #ifdef USE_CUDA
-#include "functions.cuh"
-#include "memory.cuh"
+#include "memory.hh"
+#include <cuda_runtime.h>
+
+#define CUDA_ASSERT(STATUS)                                                          \
+    if (auto status = (STATUS); status != cudaSuccess) {                             \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status));        \
+    }
 #endif
 
 namespace refactor::hardware {
 
     static Arc<Memory> cudaMemory(int32_t card) {
 #ifdef USE_CUDA
-        ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
-        setDevice(card);
-        auto [free, total] = getMemInfo();
+        int deviceCount;
+        CUDA_ASSERT(cudaGetDeviceCount(&deviceCount));
+        ASSERT(0 <= card && card < deviceCount, "Invalid card id: {}", card);
+        CUDA_ASSERT(cudaSetDevice(card));
+
+        size_t free, total;
+        CUDA_ASSERT(cudaMemGetInfo(&free, &total));
         auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
         fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
                      card, free, total, size);
@@ -26,9 +37,9 @@ namespace refactor::hardware {
 
     Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}
 
-    void Nvidia::setContext() const noexcept {
+    void Nvidia::setContext() const {
 #ifdef USE_CUDA
-        setDevice(_card);
+        CUDA_ASSERT(cudaSetDevice(_card));
 #endif
     }
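Note on the pool size chosen above: std::min(free, std::max(5ul << 30, total * 4 / 5)) reserves four fifths of the card's memory, but never less than 5 GiB and never more than what is currently free. For example, with total = 24 GiB and free = 20 GiB, total * 4 / 5 = 19.2 GiB, so the pool gets min(20, 19.2) = 19.2 GiB; on a small card with total = 4 GiB and free = 3 GiB, the 5 GiB floor wins the max, and the min against free then caps the pool at 3 GiB.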
diff --git a/src/02hardware/src/devices/nvidia/functions.cu b/src/02hardware/src/devices/nvidia/functions.cu
deleted file mode 100644
index 844ef388c..000000000
--- a/src/02hardware/src/devices/nvidia/functions.cu
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "functions.cuh"
-
-namespace refactor::hardware {
-
-    int getDeviceCount() {
-        int deviceCount;
-        CUDA_ASSERT(cudaGetDeviceCount(&deviceCount));
-        return deviceCount;
-    }
-    void setDevice(int device) {
-        CUDA_ASSERT(cudaSetDevice(device));
-    }
-    MemInfo getMemInfo() {
-        MemInfo memInfo;
-        CUDA_ASSERT(cudaMemGetInfo(&memInfo.free, &memInfo.total));
-        return memInfo;
-    }
-
-}// namespace refactor::hardware
diff --git a/src/02hardware/src/devices/nvidia/functions.cuh b/src/02hardware/src/devices/nvidia/functions.cuh
deleted file mode 100644
index 0a47d4492..000000000
--- a/src/02hardware/src/devices/nvidia/functions.cuh
+++ /dev/null
@@ -1,24 +0,0 @@
-#ifndef HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH
-#define HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH
-
-#include "common.h"
-
-#define CUDA_ASSERT(STATUS)                                                          \
-    if (auto status = (STATUS); status != cudaSuccess) {                             \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status));        \
-    }
-
-namespace refactor::hardware {
-
-    struct MemInfo {
-        size_t free, total;
-    };
-
-    int getDeviceCount();
-    void setDevice(int device);
-    MemInfo getMemInfo();
-
-}// namespace refactor::hardware
-
-#endif// HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH
diff --git a/src/02hardware/src/devices/nvidia/memory.cu b/src/02hardware/src/devices/nvidia/memory.cc
similarity index 69%
rename from src/02hardware/src/devices/nvidia/memory.cu
rename to src/02hardware/src/devices/nvidia/memory.cc
index b3c5fe3d3..42310196c 100644
--- a/src/02hardware/src/devices/nvidia/memory.cu
+++ b/src/02hardware/src/devices/nvidia/memory.cc
@@ -1,5 +1,14 @@
-#include "functions.cuh"
-#include "memory.cuh"
+#ifdef USE_CUDA
+
+#include "memory.hh"
+#include "common.h"
+#include <cuda_runtime.h>
+
+#define CUDA_ASSERT(STATUS)                                                          \
+    if (auto status = (STATUS); status != cudaSuccess) {                             \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status));        \
+    }
 
 namespace refactor::hardware {
     using M = NvidiaMemory;
@@ -29,3 +38,5 @@ namespace refactor::hardware {
     }
 
 }// namespace refactor::hardware
+
+#endif
diff --git a/src/02hardware/src/devices/nvidia/memory.cuh b/src/02hardware/src/devices/nvidia/memory.hh
similarity index 100%
rename from src/02hardware/src/devices/nvidia/memory.cuh
rename to src/02hardware/src/devices/nvidia/memory.hh

From 322835d8220146fc5a542cdd0893d1293d044bc0 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Tue, 9 Jan 2024 10:45:10 +0800
Subject: [PATCH 2/8] feat(device.cc): get the device alignment from
 cudaDeviceProp::textureAlignment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/02hardware/src/devices/nvidia/device.cc | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc
index 67b255807..4c5e3a479 100644
--- a/src/02hardware/src/devices/nvidia/device.cc
+++ b/src/02hardware/src/devices/nvidia/device.cc
@@ -24,12 +24,15 @@ namespace refactor::hardware {
         size_t free, total;
         CUDA_ASSERT(cudaMemGetInfo(&free, &total));
         auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
-        fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
-                     card, free, total, size);
+        cudaDeviceProp prop;
+        CUDA_ASSERT(cudaGetDeviceProperties(&prop, 0));
+        size_t alignment = prop.textureAlignment;
+        fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}, alignment {}",
+                     card, free, total, size, alignment);
         return std::make_shared<MemPool>(
             std::make_shared<NvidiaMemory>(),
             size,
-            256ul);
+            alignment);
 #else
         RUNTIME_ERROR("CUDA is not enabled");
 #endif
@@ -37,7 +40,7 @@ namespace refactor::hardware {
 
     Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}
 
-    void Nvidia::setContext() const { 
+    void Nvidia::setContext() const {
 #ifdef USE_CUDA
         CUDA_ASSERT(cudaSetDevice(_card));
 #endif
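The textureAlignment picked up here is what the pool rounds allocation offsets to. A minimal sketch of the usual align-up step (assuming MemPool rounds offsets this way internally; the helper name is hypothetical, and the alignment is assumed to be a power of two, which textureAlignment is in practice, e.g. 512 bytes):

    // Round `offset` up to the next multiple of `alignment` (a power of two).
    constexpr size_t alignUp(size_t offset, size_t alignment) {
        return (offset + alignment - 1) & ~(alignment - 1);
    }
    // alignUp(1000, 512) == 1024

Reading the value from the device properties instead of hard-coding 256 keeps pool offsets valid for texture binding on every card.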
From 5046a53bf16d4d8f787d9f75046be707eb237aef Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Tue, 9 Jan 2024 12:37:15 +0800
Subject: [PATCH 3/8] fix(frontend): simplify the lower-to-computation logic;
 do not skip mapping for any edge
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 src/06frontend/src/graph.cc | 50 ++++++++++++++-----------------------
 1 file changed, 19 insertions(+), 31 deletions(-)

diff --git a/src/06frontend/src/graph.cc b/src/06frontend/src/graph.cc
index 713624ac5..083822e87 100644
--- a/src/06frontend/src/graph.cc
+++ b/src/06frontend/src/graph.cc
@@ -175,42 +175,30 @@ namespace refactor::frontend {
         std::vector<computation::Node> nodes(_internal.nodes.size());
         std::vector<computation::Edge> edges(_internal.edges.size());
+
+        auto fn = [&edges, this](auto i) {
+            if (edges[i].tensor) {
+                return;
+            }
+            auto const &[tensor, name] = _internal.edges[i];
+            computation::Shape shape(tensor->shape.size());
+            std::transform(std::execution::unseq,
+                           tensor->shape.begin(), tensor->shape.end(), shape.begin(),
+                           [](auto const &dim) { return dim.value(); });
+            auto layout = shape.size() == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
+            edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data);
+            edges[i].name = name;
+        };
+
         std::transform(_internal.topology.begin(), _internal.topology.end(), nodes.begin(),
-                       [&edges, this](auto const &nodeRef) {
+                       [&fn, this](auto const &nodeRef) {
                            auto const &[op, name] = _internal.nodes[nodeRef.idx];
+                           std::for_each(nodeRef.inputs.begin(), nodeRef.inputs.end(), fn);
+                           std::for_each(nodeRef.outputs.begin(), nodeRef.outputs.end(), fn);
                            auto constant = std::all_of(std::execution::unseq,
                                                        nodeRef.outputs.begin(), nodeRef.outputs.end(),
                                                        [this](auto i) { return _internal.edges[i].tensor->data; });
-                           if (constant) {
-                               return computation::Node{nullptr, name};
-                           }
-                           auto fn = [&edges, &nodeRef, this](auto i) {
-                               if (edges[i].tensor) {
-                                   return;
-                               }
-                               auto const &[tensor, name] = _internal.edges[i];
-                               computation::Shape shape(tensor->shape.size());
-                               std::transform(std::execution::unseq,
-                                              tensor->shape.begin(), tensor->shape.end(), shape.begin(),
-                                              [](auto const &dim) { return dim.value(); });
-                               auto layout = shape.size() == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others;
-                               edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data);
-                               edges[i].name = name;
-                           };
-                           auto op_ = op->lower(TensorRefs(_internal.edges, nodeRef.inputs));
-                           auto valueDependentInputs = op->valueDependentInputs();
-                           auto it = valueDependentInputs.begin();
-                           for (auto i : range0_(nodeRef.inputs.size())) {
-                               auto input = nodeRef.inputs[i];
-                               if (it != valueDependentInputs.end() && i == *it) {
-                                   edges[input].name = _internal.edges[input].name;
-                                   ++it;
-                                   continue;
-                               }
-                               fn(input);
-                           }
-                           std::for_each(std::execution::unseq, nodeRef.outputs.begin(), nodeRef.outputs.end(), fn);
-                           return computation::Node{std::move(op_), name};
+                           return computation::Node{constant ? nullptr : op->lower(TensorRefs(_internal.edges, nodeRef.inputs)), name};
                        });
 
         auto const endTime = high_resolution_clock::now();

From cc8a86dc3f11fc248fdc7fc9addbb337465a1c4e Mon Sep 17 00:00:00 2001
From: zhangyunze
Date: Tue, 9 Jan 2024 16:09:44 +0800
Subject: [PATCH 4/8] feat: add Erf cpu/cuda kernel

---
 .../src/kernels/simple_unary/cpu_kernel.cc    | 16 ++++++++++++
 .../src/kernels/simple_unary/cuda_kernel.cc   | 16 +++++++++++-
 .../test/kernels/simple_unary/test_cpu.cpp    |  1 +
 .../test/kernels/simple_unary/test_cuda.cpp   |  1 +
 src/07onnx/test/test_simple_unary.cpp         | 25 +++++++++++++++++++
 5 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 src/07onnx/test/test_simple_unary.cpp

diff --git a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
index d34528569..5e83938df 100644
--- a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
+++ b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
@@ -18,6 +18,7 @@ namespace refactor::kernel {
             Op::Sigmoid,
             Op::Tanh,
             Op::Neg,
+            Op::Erf,
         };
         return supportedOp.contains(op) && a.dataType.isCpuNumberic()
                    ? std::make_unique<K>(op, a.dataType, a.elementsSize())
                   : nullptr;
@@ -155,6 +156,21 @@ namespace refactor::kernel {
                 default:
                     UNREACHABLE();
             }
+            case Op::Erf:
+                switch (dataType) {
+                    CASE(std::erf, F32);
+                    CASE(std::erf, F64);
+                    CASE(std::erf, I8);
+                    CASE(std::erf, I16);
+                    CASE(std::erf, I32);
+                    CASE(std::erf, I64);
+                    CASE(std::erf, U8);
+                    CASE(std::erf, U16);
+                    CASE(std::erf, U32);
+                    CASE(std::erf, U64);
+                    default:
+                        UNREACHABLE();
+                }
             default:
                 UNREACHABLE();
         }
diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
index e3c260dbc..7403b0dde 100644
--- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
@@ -18,7 +18,8 @@ namespace refactor::kernel {
     auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
         static const std::unordered_set<Op>
             supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
-                        Op::Sigmoid, Op::Tanh, Op::Neg};
+                        Op::Sigmoid, Op::Tanh, Op::Neg,
+                        Op::Erf};
 #ifndef USE_CUDA
         return nullptr;
 #endif
@@ -140,6 +141,19 @@ extern "C" __global__ void kernel(
         {__(Op::Neg, DT::BF16), "-x"},
         {__(Op::Neg, DT::F32 ), "-x"},
         {__(Op::Neg, DT::F64 ), "-x"},
+
+        {__(Op::Erf, DT::F32 ), "erff(x)"},
+        {__(Op::Erf, DT::F64 ), "erf(x)"},
+        {__(Op::Erf, DT::U8  ), "erff(static_cast<float>(x))"},
+        {__(Op::Erf, DT::I8  ), "erff(static_cast<float>(x))"},
+        {__(Op::Erf, DT::U16 ), "erff(static_cast<float>(x))"},
+        {__(Op::Erf, DT::I16 ), "erff(static_cast<float>(x))"},
+        {__(Op::Erf, DT::U32 ), "erf(static_cast<double>(x))"},
+        {__(Op::Erf, DT::I32 ), "erf(static_cast<double>(x))"},
+        {__(Op::Erf, DT::U64 ), "erf(static_cast<double>(x))"},
+        {__(Op::Erf, DT::I64 ), "erf(static_cast<double>(x))"},
+        {__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},
+        {__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},
     };
     // clang-format on
diff --git a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
index da1cb6f83..e24d2091f 100644
--- a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
+++ b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp
@@ -31,4 +31,5 @@ TEST(kernel, SimpleUnaryCpu) {
     testOp(SimpleUnaryType::Abs, std::abs);
     testOp(SimpleUnaryType::Sqrt, std::sqrt);
     testOp(SimpleUnaryType::Tanh, std::tanh);
+    testOp(SimpleUnaryType::Erf, std::erf);
 }
diff --git a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
index 6ff5d798b..ce8d66f8c 100644
--- a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
+++ b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp
@@ -51,6 +51,7 @@ TEST(kernel, SimpleUnaryCuda) {
     testOp(SimpleUnaryType::Sqrt);
     testOp(SimpleUnaryType::Sigmoid);
     testOp(SimpleUnaryType::Tanh);
+    testOp(SimpleUnaryType::Erf);
 }
 
 #endif
diff --git a/src/07onnx/test/test_simple_unary.cpp b/src/07onnx/test/test_simple_unary.cpp
new file mode 100644
index 000000000..12529a1a5
--- /dev/null
+++ b/src/07onnx/test/test_simple_unary.cpp
@@ -0,0 +1,25 @@
+#include "../src/operators/simple_unary.hh"
+#include "onnx/operators.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace onnx;
+
+TEST(infer, SimpleUnary) {
+    onnx::register_();
+
+    {
+        // Erf Test
+        auto edges = Edges{
+            {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""},
+        };
+        count_t inputs[]{0};
+        auto infered = SimpleUnary(SimpleUnaryType::Erf).infer(TensorRefs(edges, inputs), {true});
+        ASSERT_TRUE(infered.isOk());
+        auto outputs = std::move(infered.unwrap());
+        ASSERT_EQ(outputs.size(), 1);
+        auto y = std::move(outputs[0]);
+        ASSERT_EQ(y->dataType, DataType::F32);
+        ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
+    }
+}

From 0f9fde58b255c6c9a2685191a9107740899cda22 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Wed, 10 Jan 2024 11:00:15 +0800
Subject: [PATCH 5/8] feat(python_ffi): support reading and writing
 Device::Blob through the executor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 src/03runtime/include/runtime/stream.h |  6 ++++--
 src/03runtime/src/stream.cc           | 12 ++++++++---
 src/09python_ffi/src/executor.cc      | 28 ++++++++++++++++++--------
 src/09python_ffi/src/executor.h       |  2 ++
 src/09python_ffi/src/main.cpp         |  3 +++
 5 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h
index 24d9fe73e..8a4b4f51d 100644
--- a/src/03runtime/include/runtime/stream.h
+++ b/src/03runtime/include/runtime/stream.h
@@ -42,9 +42,11 @@ namespace refactor::runtime {
                decltype(_device));
 
         decltype(_graph) const &graph() const noexcept { return _graph; }
-        void setData(count_t, void const *, size_t);
+        auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
         void setData(count_t, Arc<hardware::Device::Blob>);
-        bool getData(count_t, void *, size_t) const;
+        auto getData(count_t) -> Arc<hardware::Device::Blob> const;
+        void setData(count_t, void const *, size_t);
+        bool copyData(count_t, void *, size_t) const;
         void run();
         auto bench(void (*sync)()) -> std::vector<std::chrono::nanoseconds>;
         void trace(std::function<void(count_t, void const *const *, void const *const *)>);
diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc
index 569769c1d..a02234a5f 100644
--- a/src/03runtime/src/stream.cc
+++ b/src/03runtime/src/stream.cc
@@ -18,15 +18,21 @@ namespace refactor::runtime {
               std::move(edges),
           } {}
 
+    auto Stream::setData(count_t i, size_t size) -> Arc<hardware::Device::Blob> {
+        return _graph.edges[i].blob = _device->malloc(size);
+    }
+    void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
+        _graph.edges[i].blob = std::move(blob);
+    }
     void Stream::setData(count_t i, void const *data, size_t size) {
         auto blob = _device->malloc(size);
         blob->copyFromHost(data, size);
         _graph.edges[i].blob = std::move(blob);
     }
-    void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
-        _graph.edges[i].blob = std::move(blob);
+    auto Stream::getData(count_t i) -> Arc<hardware::Device::Blob> const {
+        return _graph.edges[i].blob;
     }
-    bool Stream::getData(count_t i, void *data, size_t size) const {
+    bool Stream::copyData(count_t i, void *data, size_t size) const {
         if (!_graph.edges[i].blob) { return false; }
         _graph.edges[i].blob->copyToHost(data, size);
         return true;
diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc
index 92b83087b..9b8a6d6da 100644
--- a/src/09python_ffi/src/executor.cc
+++ b/src/09python_ffi/src/executor.cc
@@ -26,7 +26,7 @@ namespace refactor::python_ffi {
         for (auto i : graph.topology.globalInputs()) {
             auto size = graph.edges[i].tensor->bytesSize();
             buffer.resize(size);
-            if (stream.getData(i, buffer.data(), size)) {
+            if (stream.copyData(i, buffer.data(), size)) {
                 _stream.setData(i, buffer.data(), size);
             }
         }
@@ -35,9 +35,7 @@ namespace refactor::python_ffi {
     void Executor::setInput(count_t i, pybind11::array data) {
         i = _stream.graph().topology.globalInputs().at(i);
 
-        auto const &name = _stream.graph().edges[i].name;
-        auto const &edges = _graph.internal().contiguous().edges;
-        auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
+        auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
         ASSERT(tensor.bytesSize() == static_cast<size_t>(data.nbytes()), "input size mismatch");
         _stream.setData(i, data.data(), data.nbytes());
     }
 
@@ -45,14 +43,28 @@ namespace refactor::python_ffi {
     auto Executor::getOutput(count_t i) -> pybind11::array {
         i = _stream.graph().topology.globalOutputs().at(i);
 
-        auto const &name = _stream.graph().edges[i].name;
-        auto const &edges = _graph.internal().contiguous().edges;
-        auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor;
+        auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
 
         auto ans = pybind11::array(buildNumpyDType(tensor.dataType), std::move(tensor.shape));
-        _stream.getData(i, ans.mutable_data(), ans.nbytes());
+        _stream.copyData(i, ans.mutable_data(), ans.nbytes());
         return ans;
     }
 
+    auto Executor::pin(count_t i) -> Arc<hardware::Device::Blob> {
+        i = _stream.graph().topology.globalInputs().at(i);
+
+        if (auto pinned = _stream.getData(i); pinned) {
+            return pinned;
+        } else {
+            auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
+            return _stream.setData(i, tensor.bytesSize());
+        }
+    }
+    void Executor::setPinned(count_t i, Arc<hardware::Device::Blob> pinned) {
+        i = _stream.graph().topology.globalInputs().at(i);
+
+        _stream.setData(i, std::move(pinned));
+    }
+
     void Executor::run() {
         _stream.run();
     }
diff --git a/src/09python_ffi/src/executor.h b/src/09python_ffi/src/executor.h
index 5174cc744..ff70bafcc 100644
--- a/src/09python_ffi/src/executor.h
+++ b/src/09python_ffi/src/executor.h
@@ -16,6 +16,8 @@ namespace refactor::python_ffi {
         void dispatch(Arc<hardware::Device>, std::string allocator);
         void setInput(count_t, pybind11::array);
         auto getOutput(count_t) -> pybind11::array;
+        auto pin(count_t) -> Arc<hardware::Device::Blob>;
+        void setPinned(count_t, Arc<hardware::Device::Blob>);
         void run();
         void bench(bool sync);
         void trace(std::string path, std::string format);
diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp
index 54b92ee38..3093f7275 100644
--- a/src/09python_ffi/src/main.cpp
+++ b/src/09python_ffi/src/main.cpp
@@ -21,6 +21,7 @@ namespace refactor::python_ffi {
         py::class_<Tensor  , Arc<Tensor  >>(m, "Tensor"  );
         py::class_<Operator, Arc<Operator>>(m, "Operator");
         py::class_<Device  , Arc<Device  >>(m, "Device"  );
+        py::class_<hardware::Device::Blob, Arc<hardware::Device::Blob>>(m, "Pinned");
         m   .def("config_log"    , &configLog      , return_::automatic )
             .def("find_device"   , &findDevice     , return_::move      )
@@ -44,6 +45,8 @@ namespace refactor::python_ffi {
             .def("dispatch"      , &Executor::dispatch   , return_::automatic )
             .def("set_input"     , &Executor::setInput   , return_::automatic )
             .def("get_output"    , &Executor::getOutput  , return_::move      )
+            .def("pin"           , &Executor::pin        , return_::move      )
+            .def("set_pinned"    , &Executor::setPinned  , return_::automatic )
            .def("run"           , &Executor::run        , return_::automatic )
            .def("bench"         , &Executor::bench      , return_::automatic )
            .def("trace"         , &Executor::trace      , return_::automatic )

From dc21249bf221916c4d7b26b400fa9260eb537e31 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Thu, 11 Jan 2024 15:42:56 +0800
Subject: [PATCH 6/8] feat: add mod cpu/cuda kernel and test

---
 .../include/kernel/collectors/simple_binary.h |  2 +
 src/04kernel/src/collectors/simple_binary.cc  |  2 +
 .../src/kernels/simple_binary/cpu_kernel.cc   | 35 +++++++++-
 .../src/kernels/simple_binary/cuda_kernel.cc  | 36 +++++++++-
 .../kernels/simple_binary/test_binary_cpu.cpp | 50 ++++++++++++++
 .../simple_binary/test_binary_cuda.cpp        | 68 ++++++++++++-------
 .../src/operators/simple_binary.cc            |  6 ++
 src/07onnx/src/operators/simple_binary.cc     | 21 +++++-
 src/07onnx/src/operators/simple_binary.hh     |  2 +
 9 files changed, 195 insertions(+), 27 deletions(-)
diff --git a/src/04kernel/include/kernel/collectors/simple_binary.h b/src/04kernel/include/kernel/collectors/simple_binary.h
index 87d1b5f6e..7423ee703 100644
--- a/src/04kernel/include/kernel/collectors/simple_binary.h
+++ b/src/04kernel/include/kernel/collectors/simple_binary.h
@@ -14,6 +14,8 @@ namespace refactor::kernel {
         And,
         Or,
         Xor,
+        Mod,
+        Fmod,
     };
 
     std::string_view opName(SimpleBinaryType type);
diff --git a/src/04kernel/src/collectors/simple_binary.cc b/src/04kernel/src/collectors/simple_binary.cc
index e2c001ff7..53ae6723c 100644
--- a/src/04kernel/src/collectors/simple_binary.cc
+++ b/src/04kernel/src/collectors/simple_binary.cc
@@ -19,6 +19,8 @@ namespace refactor::kernel {
             CASE(And);
             CASE(Or);
            CASE(Xor);
+            CASE(Mod);
+            CASE(Fmod);
            default:
                UNREACHABLE();
        }
diff --git a/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
index 6a737f3ae..ed8941d87 100644
--- a/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
+++ b/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
@@ -1,4 +1,5 @@
 #include "cpu_kernel.hh"
+#include <cmath>
 #include <execution>
 
 namespace refactor::kernel {
@@ -118,8 +119,38 @@ namespace refactor::kernel {
                     UNREACHABLE();
             }
         }
-        default:
-            UNREACHABLE();
+        case Op::Mod: {
+            switch (dataType.internal) {
+                CASE_DT(a % b, U8);
+                CASE_DT(a % b, I8);
+                CASE_DT(a % b, U16);
+                CASE_DT(a % b, I16);
+                CASE_DT(a % b, I32);
+                CASE_DT(a % b, I64);
+                CASE_DT(a % b, U32);
+                CASE_DT(a % b, U64);
+                default:
+                    UNREACHABLE();
+            }
+        }
+        case Op::Fmod: {
+            switch (dataType.internal) {
+                CASE_DT(std::fmod(a, b), F32);
+                CASE_DT(a % b, U8);
+                CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
+                CASE_DT(a % b, U16);
+                CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
+                CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
+                CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
+                CASE_DT(std::fmod(a, b), F64);
+                CASE_DT(a % b, U32);
+                CASE_DT(a % b, U64);
+                default:
+                    UNREACHABLE();
+            }
+        default:
+            UNREACHABLE();
+        }
     }
 }
diff --git a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
index 58d5f677e..97c9975ca 100644
--- a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
+++ b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
@@ -135,12 +135,46 @@ extern "C" __global__ void kernel(
             case DataType::F32:
                 return "powf(a, b)";
             case DataType::FP16:
-                return "__float2half(__powf(__half2float(a), __half2float(b)))";
+                return "__float2half(powf(__half2float(a), __half2float(b)))";
             case DataType::BF16:
                 return "__float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b)))";
             default:
                 return "pow(a, b)";
         }
+        case SimpleBinaryType::Mod:
+            switch (dt) {
+                case DataType::U8:
+                case DataType::I8:
+                case DataType::U16:
+                case DataType::I16:
+                case DataType::I32:
+                case DataType::I64:
+                case DataType::U32:
+                case DataType::U64:
+                    return "a % b";
+                default:
+                    UNREACHABLE();
+            }
+        case SimpleBinaryType::Fmod:
+            switch (dt) {
+                case DataType::U8:
+                case DataType::I8:
+                case DataType::U16:
+                case DataType::I16:
+                case DataType::I32:
+                case DataType::I64:
+                case DataType::U32:
+                case DataType::U64:
+                    return "a % b < 0 ? (a % b + b) : (a % b)";
(a % b + b) : (a % b)"; + case DataType::F32: + return "fmodf(a, b)"; + case DataType::FP16: + return "__float2half(fmodf(__half2float(a), __half2float(b)))"; + case DataType::BF16: + return "__float2bfloat16(fmodf(__bfloat162float(a), __bfloat162float(b)))"; + default: + UNREACHABLE(); + } default: UNREACHABLE(); } diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp index 0247a7f39..d97ac1e66 100644 --- a/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp +++ b/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp @@ -1,4 +1,5 @@ #include "../src/kernels/simple_binary/cpu_kernel.hh" +#include #include using namespace refactor; @@ -27,11 +28,60 @@ void testBinaryCPU(SimpleBinaryType binaryOPT, std::function operation) { + // Create Tensor and build kernels + auto aTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor); + ASSERT_TRUE(cpuKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector a(aTensor->elementsSize(), -3); + std::vector b(bTensor->elementsSize(), 2); + std::vector c(cTensor->elementsSize()); + // Compute + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + // Compare + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], operation(a[i], b[i])); + } +} + +void testFmodWithI32CPU(SimpleBinaryType binaryOPT, std::function operation) { + // Create Tensor and build kernels + auto aTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor); + ASSERT_TRUE(cpuKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector a(aTensor->elementsSize(), -3); + std::vector b(bTensor->elementsSize(), 2); + std::vector c(cTensor->elementsSize()); + // Compute + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + // Compare + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], operation(a[i], b[i])); + } +} + TEST(kernel, BinaryCpu) { testBinaryCPU(SimpleBinaryType::Add, [](float a, float b) { return a + b; }); testBinaryCPU(SimpleBinaryType::Sub, [](float a, float b) { return a - b; }); testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; }); testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; }); + testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; }); + testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return a % b < 0 ? 
+    testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); });
 }
 
 TEST(kernel, BinaryCpuBroadcast) {
diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp
index 901af265b..ed48426bf 100644
--- a/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp
+++ b/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp
@@ -9,12 +9,13 @@ using namespace refactor;
 using namespace kernel;
 using namespace hardware;
 
+template<decltype(DataType::internal) T>
 void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
     // Create Tensor and build kernels
-    using T_ = primitive<DataType::I8>::type;
-    auto aTensor = Tensor::share(DataType::I8, dimA, LayoutType::NCHW);
-    auto bTensor = Tensor::share(DataType::I8, dimB, LayoutType::NCHW);
-    auto cTensor = Tensor::share(DataType::I8, dimC, LayoutType::NCHW);
+    using T_ = primitive<T>::type;
+    auto aTensor = Tensor::share(T, dimA, LayoutType::NCHW);
+    auto bTensor = Tensor::share(T, dimB, LayoutType::NCHW);
+    auto cTensor = Tensor::share(T, dimC, LayoutType::NCHW);
     auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor),
          cudaKernel = BinaryCuda::build(binaryOPT, *aTensor, *bTensor);
@@ -22,10 +23,10 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
     auto res = runtime::Resources();
     auto cpuRoutine = cpuKernel->lower(res).routine;
     auto cudaRoutine = cudaKernel->lower(res).routine;
 
     // Init inputs and outputs
-    std::vector<T_> a(aTensor->elementsSize(), 3.0f);
-    std::vector<T_> b(bTensor->elementsSize(), 2.0f);
+    std::vector<T_> a(aTensor->elementsSize(), 3);
+    std::vector<T_> b(bTensor->elementsSize(), 2);
     std::vector<T_> c(cTensor->elementsSize());
     auto &dev = *device::init(Device::Type::Nvidia, 0, "");
     auto aGPU = dev.malloc(aTensor->bytesSize()),
@@ -53,35 +54,56 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
 }
 
 TEST(kernel, BinaryCudaAdd) {
-    testBinaryCuda(SimpleBinaryType::Add,
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4});
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Add,
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4});
 }
 
 TEST(kernel, BinaryCudaMul) {
-    testBinaryCuda(SimpleBinaryType::Mul,
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4});
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Mul,
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4});
 }
 
 TEST(kernel, BinaryCudaSub) {
-    testBinaryCuda(SimpleBinaryType::Sub,
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4});
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Sub,
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4});
 }
 
 TEST(kernel, BinaryCudaDiv) {
-    testBinaryCuda(SimpleBinaryType::Div,
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4},
-                   Shape{2, 5, 10, 20, 3, 4});
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Div,
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCudaMod) {
+    testBinaryCuda<DataType::I32>(SimpleBinaryType::Mod,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCudaFmodI8) {
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Fmod,
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4},
+                                 Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCudaFmodF32) {
+    testBinaryCuda<DataType::F32>(SimpleBinaryType::Fmod,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
 }
 
 TEST(kernel, BinaryCudaBroadcast) {
-    testBinaryCuda(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
+    testBinaryCuda<DataType::I8>(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
 }
 
 #endif
diff --git a/src/05computation/src/operators/simple_binary.cc b/src/05computation/src/operators/simple_binary.cc
index 31831e7c4..a9bcde0b9 100644
--- a/src/05computation/src/operators/simple_binary.cc
+++ b/src/05computation/src/operators/simple_binary.cc
@@ -39,6 +39,10 @@ namespace refactor::computation {
             static uint8_t ID = 8;
             return reinterpret_cast<size_t>(&ID);
         }
+        case Ty::Mod: {
+            static uint8_t ID = 9;
+            return reinterpret_cast<size_t>(&ID);
+        }
         default:
             UNREACHABLE();
     }
@@ -64,6 +68,8 @@ namespace refactor::computation {
             return "Or";
         case Ty::Xor:
             return "Xor";
+        case Ty::Mod:
+            return "Mod";
         default:
             UNREACHABLE();
     }
diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc
index a1bd5b24d..fed2a979c 100644
--- a/src/07onnx/src/operators/simple_binary.cc
+++ b/src/07onnx/src/operators/simple_binary.cc
@@ -10,7 +10,7 @@ namespace refactor::onnx {
         : Operator(), type(type_) {}
 
     auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
-        ASSERT(attributes.empty(), "Simple binary operator should not have attributes");
+        auto fmod = defaultOr(attributes, "fmod", {0}).int_();
         // clang-format off
         auto type =
             opType == "onnx::Add" ? Ty::Add :
@@ -21,6 +21,7 @@ namespace refactor::onnx {
             opType == "onnx::And" ? Ty::And :
             opType == "onnx::Or"  ? Ty::Or  :
             opType == "onnx::Xor" ? Ty::Xor :
+            opType == "onnx::Mod" ? (fmod == 0 ? Ty::Mod : Ty::Fmod) :
             UNREACHABLEX(Ty, "Unsupported binary operator: {}", opType);
         // clang-format on
         return OpBox(std::make_unique<Op>(type));
@@ -48,6 +49,22 @@ namespace refactor::onnx {
             static uint8_t ID = 5;
             return reinterpret_cast<size_t>(&ID);
         }
+        case Ty::And: {
+            static uint8_t ID = 6;
+            return reinterpret_cast<size_t>(&ID);
+        }
+        case Ty::Or: {
+            static uint8_t ID = 7;
+            return reinterpret_cast<size_t>(&ID);
+        }
+        case Ty::Xor: {
+            static uint8_t ID = 8;
+            return reinterpret_cast<size_t>(&ID);
+        }
+        case Ty::Mod: {
+            static uint8_t ID = 9;
+            return reinterpret_cast<size_t>(&ID);
+        }
         default:
             UNREACHABLE();
     }
@@ -65,6 +82,7 @@ namespace refactor::onnx {
         case Ty::And: return "onnx::And";
         case Ty::Or : return "onnx::Or" ;
         case Ty::Xor: return "onnx::Xor";
+        case Ty::Mod: return "onnx::Mod";
         default: UNREACHABLE();
     }
     // clang-format on
@@ -162,6 +180,7 @@ namespace refactor::onnx {
             case Ty::And : type_ = Ty_::And; break;
             case Ty::Or  : type_ = Ty_::Or ; break;
             case Ty::Xor : type_ = Ty_::Xor; break;
+            case Ty::Mod : type_ = Ty_::Mod; break;
             default: UNREACHABLE();
         }
         // clang-format on
diff --git a/src/07onnx/src/operators/simple_binary.hh b/src/07onnx/src/operators/simple_binary.hh
index dfcacc17d..4c948f5fc 100644
--- a/src/07onnx/src/operators/simple_binary.hh
+++ b/src/07onnx/src/operators/simple_binary.hh
@@ -15,6 +15,8 @@ namespace refactor::onnx {
         And,
         Or,
         Xor,
+        Mod,
+        Fmod,
     };
 
     struct SimpleBinary final : public Operator {
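A worked example of the two integer behaviors added above, using the values the new CPU tests fill in (a = -3, b = 2): the Mod kernel computes plain C `a % b`, so -3 % 2 = -1 and the result takes the dividend's sign, while the integer Fmod path adds b back whenever the remainder is negative, so (-3 % 2) + 2 = 1 and the result takes the divisor's sign. The floating-point Fmod path defers to std::fmod / fmodf, so fmod(-3.0, 2.0) = -1.0. This is also why `onnx::Mod` maps to two internal operator types: the ONNX `fmod` attribute selects which of the two remainder conventions applies.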
From 23ce52281c61891a7f7b32d580159de68d414587 Mon Sep 17 00:00:00 2001
From: kilinchange
Date: Thu, 11 Jan 2024 19:08:25 +0800
Subject: [PATCH 7/8] fix: add fmod

---
 src/05computation/src/operators/simple_binary.cc | 6 ++++++
 src/07onnx/src/operators/simple_binary.cc        | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/src/05computation/src/operators/simple_binary.cc b/src/05computation/src/operators/simple_binary.cc
index a9bcde0b9..90f7ac028 100644
--- a/src/05computation/src/operators/simple_binary.cc
+++ b/src/05computation/src/operators/simple_binary.cc
@@ -43,6 +43,10 @@ namespace refactor::computation {
             static uint8_t ID = 9;
             return reinterpret_cast<size_t>(&ID);
         }
+        case Ty::Fmod: {
+            static uint8_t ID = 10;
+            return reinterpret_cast<size_t>(&ID);
+        }
         default:
             UNREACHABLE();
     }
@@ -70,6 +74,8 @@ namespace refactor::computation {
             return "Xor";
         case Ty::Mod:
             return "Mod";
+        case Ty::Fmod:
+            return "Fmod";
         default:
             UNREACHABLE();
     }
diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc
index fed2a979c..ffa42cae3 100644
--- a/src/07onnx/src/operators/simple_binary.cc
+++ b/src/07onnx/src/operators/simple_binary.cc
@@ -65,6 +65,10 @@ namespace refactor::onnx {
             static uint8_t ID = 9;
             return reinterpret_cast<size_t>(&ID);
         }
+        case Ty::Fmod: {
+            static uint8_t ID = 10;
+            return reinterpret_cast<size_t>(&ID);
+        }
         default:
             UNREACHABLE();
     }
@@ -83,6 +87,7 @@ namespace refactor::onnx {
         case Ty::Or  : return "onnx::Or" ;
         case Ty::Xor : return "onnx::Xor";
         case Ty::Mod : return "onnx::Mod";
+        case Ty::Fmod: return "onnx::Mod";
         default: UNREACHABLE();
     }
     // clang-format on
@@ -181,6 +186,7 @@ namespace refactor::onnx {
             case Ty::Or  : type_ = Ty_::Or ; break;
             case Ty::Xor : type_ = Ty_::Xor; break;
             case Ty::Mod : type_ = Ty_::Mod; break;
+            case Ty::Fmod: type_ = Ty_::Fmod; break;
             default: UNREACHABLE();
         }
         // clang-format on

From 54c2f7e292a0304324b8f27e7b902526881e1aaa Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Fri, 12 Jan 2024 13:03:52 +0800
Subject: [PATCH 8/8] refactor(python_ffi): change the interface for reading
 and writing data blobs through the executor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 src/03runtime/include/runtime/stream.h |  2 +-
 src/03runtime/src/stream.cc           |  2 +-
 src/09python_ffi/src/executor.cc      | 26 ++++++++++++--------------
 src/09python_ffi/src/executor.h       |  6 +++---
 src/09python_ffi/src/main.cpp         |  4 ++--
 5 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h
index 8a4b4f51d..4f9e34f6c 100644
--- a/src/03runtime/include/runtime/stream.h
+++ b/src/03runtime/include/runtime/stream.h
@@ -44,7 +44,7 @@ namespace refactor::runtime {
         decltype(_graph) const &graph() const noexcept { return _graph; }
         auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
         void setData(count_t, Arc<hardware::Device::Blob>);
-        auto getData(count_t) -> Arc<hardware::Device::Blob> const;
+        auto getData(count_t) const -> Arc<hardware::Device::Blob>;
         void setData(count_t, void const *, size_t);
         bool copyData(count_t, void *, size_t) const;
         void run();
diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc
index a02234a5f..570563bc0 100644
--- a/src/03runtime/src/stream.cc
+++ b/src/03runtime/src/stream.cc
@@ -29,7 +29,7 @@ namespace refactor::runtime {
         blob->copyFromHost(data, size);
         _graph.edges[i].blob = std::move(blob);
     }
-    auto Stream::getData(count_t i) -> Arc<hardware::Device::Blob> const {
+    auto Stream::getData(count_t i) const -> Arc<hardware::Device::Blob> {
         return _graph.edges[i].blob;
     }
     bool Stream::copyData(count_t i, void *data, size_t size) const {
diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc
index 9b8a6d6da..d402de9f0 100644
--- a/src/09python_ffi/src/executor.cc
+++ b/src/09python_ffi/src/executor.cc
@@ -40,7 +40,15 @@ namespace refactor::python_ffi {
         _stream.setData(i, data.data(), data.nbytes());
     }
 
+    void Executor::setInputBlob(count_t i, Arc<hardware::Device::Blob> blob) {
+        i = _stream.graph().topology.globalInputs().at(i);
+
+        auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
+        ASSERT(tensor.bytesSize() == blob->size(), "input size mismatch");
+        _stream.setData(i, std::move(blob));
+    }
+
-    auto Executor::getOutput(count_t i) -> pybind11::array {
+    auto Executor::getOutput(count_t i) const -> pybind11::array {
         i = _stream.graph().topology.globalOutputs().at(i);
 
         auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
@@ -49,20 +57,10 @@ namespace refactor::python_ffi {
         return ans;
     }
 
-    auto Executor::pin(count_t i) -> Arc<hardware::Device::Blob> {
-        i = _stream.graph().topology.globalInputs().at(i);
-
-        if (auto pinned = _stream.getData(i); pinned) {
-            return pinned;
-        } else {
-            auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
-            return _stream.setData(i, tensor.bytesSize());
-        }
-    }
-    void Executor::setPinned(count_t i, Arc<hardware::Device::Blob> pinned) {
-        i = _stream.graph().topology.globalInputs().at(i);
-
-        _stream.setData(i, std::move(pinned));
+    auto Executor::getOutputBlob(count_t i) const -> Arc<hardware::Device::Blob> {
+        i = _stream.graph().topology.globalOutputs().at(i);
+
+        return _stream.getData(i);
     }
 
     void Executor::run() {
         _stream.run();
     }
diff --git a/src/09python_ffi/src/executor.h b/src/09python_ffi/src/executor.h
index ff70bafcc..004b9e63d 100644
--- a/src/09python_ffi/src/executor.h
+++ b/src/09python_ffi/src/executor.h
@@ -15,9 +15,9 @@ namespace refactor::python_ffi {
         Executor(computation::Graph, runtime::Stream);
         void dispatch(Arc<hardware::Device>, std::string allocator);
         void setInput(count_t, pybind11::array);
-        auto getOutput(count_t) -> pybind11::array;
-        auto pin(count_t) -> Arc<hardware::Device::Blob>;
-        void setPinned(count_t, Arc<hardware::Device::Blob>);
+        void setInputBlob(count_t, Arc<hardware::Device::Blob>);
+        auto getOutput(count_t) const -> pybind11::array;
+        auto getOutputBlob(count_t) const -> Arc<hardware::Device::Blob>;
         void run();
         void bench(bool sync);
         void trace(std::string path, std::string format);
diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp
index 3093f7275..8a95aa8d1 100644
--- a/src/09python_ffi/src/main.cpp
+++ b/src/09python_ffi/src/main.cpp
@@ -44,9 +44,9 @@ namespace refactor::python_ffi {
         py::class_<Executor, Arc<Executor>>(m, "Executor")
             .def("dispatch"       , &Executor::dispatch     , return_::automatic )
             .def("set_input"      , &Executor::setInput     , return_::automatic )
+            .def("set_input_blob" , &Executor::setInputBlob , return_::automatic )
             .def("get_output"     , &Executor::getOutput    , return_::move      )
-            .def("pin"            , &Executor::pin          , return_::move      )
-            .def("set_pinned"     , &Executor::setPinned    , return_::automatic )
+            .def("get_output_blob", &Executor::getOutputBlob, return_::move      )
            .def("run"            , &Executor::run          , return_::automatic )
            .def("bench"          , &Executor::bench        , return_::automatic )
            .def("trace"          , &Executor::trace        , return_::automatic )
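A minimal sketch of how the reworked blob interface can chain two executors without a host round-trip (assuming both executors were dispatched to the same device and the connected edges have matching byte sizes; the `prefill` and `decode` names are hypothetical):

    // Run the first executor, then hand its output blob directly to the
    // second executor's input -- the data never leaves the device.
    prefill->run();
    Arc<hardware::Device::Blob> blob = prefill->getOutputBlob(0);
    decode->setInputBlob(0, std::move(blob));
    decode->run();

Compared with the pin/setPinned pair from PATCH 5, which implicitly created an input blob on first access, the explicit setInputBlob/getOutputBlob pair makes the data-flow direction visible at each call site and lets setInputBlob validate the blob size against the edge's tensor before installing it.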