From bb9c4d013bf3b1fb42139504cdf9e69f2ed91e09 Mon Sep 17 00:00:00 2001 From: 4kangjc <1158035520@qq.com> Date: Fri, 14 Apr 2023 23:55:58 +0800 Subject: [PATCH] scheduler: Support moved tast captures / arguments --- CMakeLists.txt | 1 + include/marl/move_only_function.h | 188 ++++++++++++++++++++++++++++++ include/marl/task.h | 23 +--- src/move_only_function_test.cpp | 148 +++++++++++++++++++++++ src/scheduler_test.cpp | 30 +++++ 5 files changed, 369 insertions(+), 21 deletions(-) create mode 100644 include/marl/move_only_function.h create mode 100644 src/move_only_function_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 78ee4a3..d31af89 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -340,6 +340,7 @@ if(MARL_BUILD_TESTS) ${MARL_SRC_DIR}/marl_test.cpp ${MARL_SRC_DIR}/marl_test.h ${MARL_SRC_DIR}/memory_test.cpp + ${MARL_SRC_DIR}/move_only_function_test.cpp ${MARL_SRC_DIR}/osfiber_test.cpp ${MARL_SRC_DIR}/parallelize_test.cpp ${MARL_SRC_DIR}/pool_test.cpp diff --git a/include/marl/move_only_function.h b/include/marl/move_only_function.h new file mode 100644 index 0000000..a6e671c --- /dev/null +++ b/include/marl/move_only_function.h @@ -0,0 +1,188 @@ +// Copyright 2023 The Marl Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef marl_move_only_function_h +#define marl_move_only_function_h + +#include "memory.h" + +#include +#include +#include +#include +#include + +namespace marl { + +#if __cplusplus > 201402L + +using std::invoke; + +#else +// std::invoke for C++11 and C++14 +template +typename std::enable_if< + std::is_member_pointer::type>::value, + typename std::result_of::type>::type +invoke(Functor&& f, Args&&... args) { + return std::mem_fn(f)(std::forward(args)...); +} + +template +typename std::enable_if< + !std::is_member_pointer::type>::value, + typename std::result_of::type>::type +invoke(Functor&& f, Args&&... args) { + return std::forward(f)(std::forward(args)...); +} + +#endif // __cplusplus > 201402L + +template +class move_only_function; + +#if __cplusplus > 201402L && (defined(__GNUC__) || defined(__clang__)) +template +class move_only_function { +#else +template +class move_only_function { + static const bool kNoexcept = false; // private +#endif + + public: + constexpr move_only_function() = default; + move_only_function(std::nullptr_t) {} + + template ::type, move_only_function>::value>::type> + move_only_function(F&& function) { + if (sizeof(typename std::decay::type) <= kMaximumOptimizableSize) { + ops_ = EraseCopySmall(&object_, std::forward(function)); + } else { + ops_ = EraseCopyLarge(&object_, std::forward(function)); + } + } + + move_only_function(move_only_function&& function) noexcept { + ops_ = function.ops_; + function.ops_ = nullptr; + if (ops_) { + ops_->manager(&object_, &function.object_); + } + } + ~move_only_function() { + if (ops_) { + ops_->manager(&object_, nullptr); + } + } + + template ::type, move_only_function>::value>::type> + move_only_function& operator=(F&& function) { + this->~move_only_function(); + new (this) move_only_function(std::forward(function)); + return *this; + } + + move_only_function& operator=(move_only_function&& function) noexcept { + if (&function != this) { + this->~move_only_function(); + new (this) move_only_function(std::move(function)); + } + return *this; + } + + move_only_function& operator=(std::nullptr_t) { + ops_->manager(&object_, nullptr); + ops_ = nullptr; + return *this; + } + + // The behavior is undefined if `*this == nullptr` holds. + R operator()(Args... args) const noexcept(kNoexcept) { + return ops_->invoker(&object_, std::forward(args)...); + } + + constexpr explicit operator bool() const { return ops_; } + + private: + static constexpr std::size_t kMaximumOptimizableSize = 3 * sizeof(void*); + + struct TypeOps { + using Invoker = R (*)(void* object, Args&&... args); + using Manager = void (*)(void* dest, void* src); + + Invoker invoker; + Manager manager; + }; + + template + static R Invoke(F&& f, Args&&... args) { + return invoke(std::forward(f), std::forward(args)...); + } + + template + const TypeOps* EraseCopySmall(void* buffer, T&& obejct) { + using Decayed = typename std::decay::type; + + static const TypeOps ops = { + // Invoker + [](void* object, Args&&... args) -> R { + return Invoke(*static_cast(object), + std::forward(args)...); + }, + // Manager + [](void* dest, void* src) { + if (src) { + new (dest) Decayed(std::move(*static_cast(src))); + static_cast(src)->~Decayed(); + } else { + static_cast(dest)->~Decayed(); + } + }}; + + new (buffer) Decayed(std::forward(obejct)); + return &ops; + } + + template + const TypeOps* EraseCopyLarge(void* buffer, T&& object) { + using Decayed = typename std::decay::type; + using Stored = Decayed*; + + static const TypeOps ops = { + /* invoker */ + [](void* object, Args&&... args) -> R { + return Invoke(**static_cast(object), + std::forward(args)...); + }, + /* Manager */ + [](void* dest, void* src) { + if (src) { + new (dest) Stored(*static_cast(src)); + } else { + delete *static_cast(dest); + } + }, + }; + new (buffer) Stored(new Decayed(std::forward(object))); + return &ops; + } + + mutable marl::aligned_storage::type object_; + const TypeOps* ops_ = nullptr; +}; + +} // namespace marl + +#endif diff --git a/include/marl/task.h b/include/marl/task.h index 1e7d3f4..0242ebf 100644 --- a/include/marl/task.h +++ b/include/marl/task.h @@ -16,15 +16,14 @@ #define marl_task_h #include "export.h" - -#include +#include "move_only_function.h" namespace marl { // Task is a unit of work for the scheduler. class Task { public: - using Function = std::function; + using Function = move_only_function; enum class Flags { None = 0, @@ -37,14 +36,9 @@ class Task { }; MARL_NO_EXPORT inline Task(); - MARL_NO_EXPORT inline Task(const Task&); MARL_NO_EXPORT inline Task(Task&&); - MARL_NO_EXPORT inline Task(const Function& function, - Flags flags = Flags::None); MARL_NO_EXPORT inline Task(Function&& function, Flags flags = Flags::None); - MARL_NO_EXPORT inline Task& operator=(const Task&); MARL_NO_EXPORT inline Task& operator=(Task&&); - MARL_NO_EXPORT inline Task& operator=(const Function&); MARL_NO_EXPORT inline Task& operator=(Function&&); // operator bool() returns true if the Task has a valid function. @@ -62,28 +56,15 @@ class Task { }; Task::Task() {} -Task::Task(const Task& o) : function(o.function), flags(o.flags) {} Task::Task(Task&& o) : function(std::move(o.function)), flags(o.flags) {} -Task::Task(const Function& function_, Flags flags_ /* = Flags::None */) - : function(function_), flags(flags_) {} Task::Task(Function&& function_, Flags flags_ /* = Flags::None */) : function(std::move(function_)), flags(flags_) {} -Task& Task::operator=(const Task& o) { - function = o.function; - flags = o.flags; - return *this; -} Task& Task::operator=(Task&& o) { function = std::move(o.function); flags = o.flags; return *this; } -Task& Task::operator=(const Function& f) { - function = f; - flags = Flags::None; - return *this; -} Task& Task::operator=(Function&& f) { function = std::move(f); flags = Flags::None; diff --git a/src/move_only_function_test.cpp b/src/move_only_function_test.cpp new file mode 100644 index 0000000..bea5ef7 --- /dev/null +++ b/src/move_only_function_test.cpp @@ -0,0 +1,148 @@ +// Copyright 2023 The Marl Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "marl/move_only_function.h" + +#include + +#include "marl_test.h" + +class MoveOnlyFunctionTest : public testing::Test { +}; + +int minus(int a, int b) { return a - b; } + +int product(int a, int b) { return a * b; } + +struct divides { + int operator()(int a, int b) const { return a / b; } +}; + +// noncopyable plus +struct plus { + int operator()(int a, int b) const { return a + b; } + plus(const plus&) = delete; + plus() = default; + plus(plus&&) = default; +}; + +template +void multiplication(T& a, const T& b) { + a *= b; +} + +TEST_F(MoveOnlyFunctionTest, Empty) { + marl::move_only_function f; + EXPECT_FALSE(f); +} + +TEST_F(MoveOnlyFunctionTest, Lambda) { + marl::move_only_function f1 = [](int a, int b) { return a + b; }; + EXPECT_EQ(f1(1, 2), 3); + + int a = 0; + marl::move_only_function f2 = [&a]() { ++a; }; + f2(); + EXPECT_EQ(a, 1); + + plus p; + marl::move_only_function f3([](plus&& p) { return p(2, 5); }); + EXPECT_EQ(f3(std::move(p)), 7); + + divides d; + marl::move_only_function f4([d](int a, int b) { return d(a, b); }); + EXPECT_EQ(f4(20, 5), 4); + +#if __cplusplus >= 201402L + std::unique_ptr uq(new int(3)); + marl::move_only_function f5 = [uni_ptr = std::move(uq)]() { return *uni_ptr; }; + EXPECT_EQ(f5(), 3); +#endif + + std::array payload; + payload.back() = 5; + marl::move_only_function f6 = [payload] { return payload.back(); }; + EXPECT_EQ(f6(), 5); +} + +TEST_F(MoveOnlyFunctionTest, MemberMethod) { + struct Pii { + int a, b; + int hash_func() { return std::hash()(a) ^ std::hash()(b); } + int sum(int c = 0) const { return a + b + c; } + }; + + plus p; + marl::move_only_function f1(std::move(p)); + EXPECT_EQ(f1(2, 5), 7); + + divides div; + marl::move_only_function f2(div); + EXPECT_EQ(f2(30, 5), 6); + + Pii pii{4, 5}; + marl::move_only_function f3(&Pii::hash_func); + EXPECT_EQ(f3(&pii), pii.hash_func()); + + marl::move_only_function f4(&Pii::sum); + EXPECT_EQ(f4(&pii, 1), 10); +} + +TEST_F(MoveOnlyFunctionTest, Fucntion) { + marl::move_only_function f1(minus); + EXPECT_EQ(f1(10, 5), 5); + + marl::move_only_function f2 = &product; + EXPECT_EQ(f2(2, 3), 6); + + marl::move_only_function f3(&multiplication); + double a = 4.0f; + f3(a, 5); + EXPECT_EQ(a, 20); +} + +TEST_F(MoveOnlyFunctionTest, Move) { + struct OnlyCopyable { + OnlyCopyable() : v(new std::vector()) {} + OnlyCopyable(const OnlyCopyable& oc) : v(new std::vector(*oc.v)) {} + ~OnlyCopyable() { delete v; } + std::vector* v; + }; + marl::move_only_function f, f2; + OnlyCopyable payload; + + payload.v->resize(100, 12); + + // BE SURE THAT THE LAMBDA IS NOT LARGER THAN kMaximumOptimizableSize. + f = [payload] { return payload.v->back(); }; + f2 = std::move(f); + EXPECT_EQ(12, f2()); +} + +TEST_F(MoveOnlyFunctionTest, LargeFunctor) { + marl::move_only_function f, f2; + std::array, 100> payload; + + payload.back().resize(10, 12); + f = [payload] { return payload.back().back(); }; + f2 = std::move(f); + EXPECT_EQ(12, f2()); +} + +TEST_F(MoveOnlyFunctionTest, Clear) { + marl::move_only_function f([]{}); + EXPECT_TRUE(f); + f = nullptr; + EXPECT_FALSE(f); +} diff --git a/src/scheduler_test.cpp b/src/scheduler_test.cpp index 64cf995..6cf064d 100644 --- a/src/scheduler_test.cpp +++ b/src/scheduler_test.cpp @@ -108,6 +108,36 @@ TEST_P(WithBoundScheduler, ScheduleWithArgs) { ASSERT_EQ(got, "s: 'a string', i: 42, b: true"); } +TEST_P(WithBoundScheduler, ScheduleWithMovedCapture) { +#if __cplusplus >= 201402L // C++14 or greater + std::unique_ptr move_me(new std::string("move me")); + std::string got; + marl::WaitGroup wg(1); + marl::schedule([moved = std::move(move_me), wg, &got]() { + got = *moved; + wg.done(); + }); + wg.wait(); + ASSERT_EQ(got, "move me"); +#else + GTEST_SKIP() << "Test requires c++14 or greater"; +#endif +} + +TEST_P(WithBoundScheduler, ScheduleWithMovedArg) { + std::unique_ptr move_me(new std::string("move me")); + std::string got; + marl::WaitGroup wg(1); + marl::schedule( + [wg, &got](std::unique_ptr& str) { + got = *str; + wg.done(); + }, + std::move(move_me)); + wg.wait(); + ASSERT_EQ(got, "move me"); +} + TEST_P(WithBoundScheduler, FibersResumeOnSameThread) { marl::WaitGroup fence(1); marl::WaitGroup wg(1000);