diff --git a/Detectors/TPC/workflow/readers/include/TPCReaderWorkflow/TPCSectorCompletionPolicy.h b/Detectors/TPC/workflow/readers/include/TPCReaderWorkflow/TPCSectorCompletionPolicy.h index 33abded35624a..9f2a8b31dbdba 100644 --- a/Detectors/TPC/workflow/readers/include/TPCReaderWorkflow/TPCSectorCompletionPolicy.h +++ b/Detectors/TPC/workflow/readers/include/TPCReaderWorkflow/TPCSectorCompletionPolicy.h @@ -91,7 +91,7 @@ class TPCSectorCompletionPolicy return std::regex_match(device.name.begin(), device.name.end(), std::regex(expression.c_str())); }; - auto callback = [bRequireAll = mRequireAll, inputMatchers = mInputMatchers, externalInputMatchers = mExternalInputMatchers, pTpcSectorMask = mTpcSectorMask, orderCheck = mOrderCheck](framework::InputSpan const& inputs) -> framework::CompletionPolicy::CompletionOp { + auto callback = [bRequireAll = mRequireAll, inputMatchers = mInputMatchers, externalInputMatchers = mExternalInputMatchers, pTpcSectorMask = mTpcSectorMask, orderCheck = mOrderCheck](framework::InputSpan const& inputs, auto const&, auto&) -> framework::CompletionPolicy::CompletionOp { unsigned long tpcSectorMask = pTpcSectorMask ? *pTpcSectorMask : 0xFFFFFFFFF; std::bitset validSectors = 0; bool haveMatchedInput = false; diff --git a/Framework/Core/include/Framework/CompletionPolicy.h b/Framework/Core/include/Framework/CompletionPolicy.h index eda45bd315471..55d3014166956 100644 --- a/Framework/Core/include/Framework/CompletionPolicy.h +++ b/Framework/Core/include/Framework/CompletionPolicy.h @@ -64,26 +64,21 @@ struct CompletionPolicy { using Matcher = std::function; using InputSetElement = DataRef; - using Callback = std::function; using CallbackFull = std::function const&, ServiceRegistryRef&)>; using CallbackConfigureRelayer = std::function; /// Constructor CompletionPolicy() - : name{}, matcher{}, callback{} {} + : name{}, matcher{}, callbackFull{} {} /// Constructor for emplace_back - CompletionPolicy(std::string _name, Matcher _matcher, Callback _callback, bool _balanceChannels = true) - : name(std::move(_name)), matcher(std::move(_matcher)), callback(std::move(_callback)), callbackFull{nullptr}, balanceChannels{_balanceChannels} {} CompletionPolicy(std::string _name, Matcher _matcher, CallbackFull _callback, bool _balanceChannels = true) - : name(std::move(_name)), matcher(std::move(_matcher)), callback(nullptr), callbackFull{std::move(_callback)}, balanceChannels{_balanceChannels} {} + : name(std::move(_name)), matcher(std::move(_matcher)), callbackFull{std::move(_callback)}, balanceChannels{_balanceChannels} {} /// Name of the policy itself. std::string name = ""; /// Callback to be used to understand if the policy should apply /// to the given device. Matcher matcher = nullptr; - /// Actual policy which decides what to do with a partial InputRecord. - Callback callback = nullptr; /// Actual policy which decides what to do with a partial InputRecord, extended version CallbackFull callbackFull = nullptr; /// A callback which allows you to configure the behavior of the data relayer associated diff --git a/Framework/Core/src/CompletionPolicyHelpers.cxx b/Framework/Core/src/CompletionPolicyHelpers.cxx index 34fc002428e12..b3e0621bf984e 100644 --- a/Framework/Core/src/CompletionPolicyHelpers.cxx +++ b/Framework/Core/src/CompletionPolicyHelpers.cxx @@ -21,9 +21,6 @@ #include #include -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wpedantic" - namespace o2::framework { @@ -35,7 +32,7 @@ CompletionPolicy CompletionPolicyHelpers::defineByNameOrigin(std::string const& auto originReceived = std::make_shared>(); - auto callback = [originReceived, origin, op](InputSpan const& inputRefs) -> CompletionPolicy::CompletionOp { + auto callback = [originReceived, origin, op](InputSpan const& inputRefs, std::vector const&, ServiceRegistryRef&) -> CompletionPolicy::CompletionOp { // update list of the start times of inputs with origin @origin for (auto& ref : inputRefs) { if (ref.header != nullptr) { @@ -77,7 +74,7 @@ CompletionPolicy CompletionPolicyHelpers::defineByName(std::string const& name, auto matcher = [name](DeviceSpec const& device) -> bool { return std::regex_match(device.name.begin(), device.name.end(), std::regex(name)); }; - auto callback = [op](InputSpan const&) -> CompletionPolicy::CompletionOp { + auto callback = [op](InputSpan const&, std::vector const& specs, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { return op; }; switch (op) { @@ -108,7 +105,8 @@ CompletionPolicy CompletionPolicyHelpers::defineByName(std::string const& name, CompletionPolicy CompletionPolicyHelpers::consumeWhenAll(const char* name, CompletionPolicy::Matcher matcher) { - auto callback = [](InputSpan const& inputs) -> CompletionPolicy::CompletionOp { + auto callback = [](InputSpan const& inputs, std::vector const& specs, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { + assert(inputs.size() == specs.size()); for (auto& input : inputs) { if (input.header == nullptr) { return CompletionPolicy::CompletionOp::Wait; @@ -123,7 +121,7 @@ CompletionPolicy CompletionPolicyHelpers::consumeWhenAllOrdered(const char* name { auto callbackFull = [](InputSpan const& inputs, std::vector const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { auto& decongestionService = ref.get(); - decongestionService.orderedCompletionPolicyActive = 1; + decongestionService.orderedCompletionPolicyActive = true; for (auto& input : inputs) { if (input.header == nullptr) { return CompletionPolicy::CompletionOp::Wait; @@ -199,7 +197,7 @@ CompletionPolicy CompletionPolicyHelpers::consumeExistingWhenAny(const char* nam CompletionPolicy CompletionPolicyHelpers::consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher) { - auto callback = [](InputSpan const& inputs) -> CompletionPolicy::CompletionOp { + auto callback = [](InputSpan const& inputs, std::vector const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { for (auto& input : inputs) { if (input.header != nullptr) { return CompletionPolicy::CompletionOp::Consume; @@ -289,7 +287,7 @@ CompletionPolicy CompletionPolicyHelpers::consumeWhenAnyWithAllConditions(std::s CompletionPolicy CompletionPolicyHelpers::processWhenAny(const char* name, CompletionPolicy::Matcher matcher) { - auto callback = [](InputSpan const& inputs) -> CompletionPolicy::CompletionOp { + auto callback = [](InputSpan const& inputs, std::vector const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp { size_t present = 0; for (auto& input : inputs) { if (input.header != nullptr) { @@ -307,4 +305,3 @@ CompletionPolicy CompletionPolicyHelpers::processWhenAny(const char* name, Compl } } // namespace o2::framework -#pragma GCC diagnostic pop diff --git a/Framework/Core/src/DataRelayer.cxx b/Framework/Core/src/DataRelayer.cxx index e9591a9cc76ea..b3d20e19852fa 100644 --- a/Framework/Core/src/DataRelayer.cxx +++ b/Framework/Core/src/DataRelayer.cxx @@ -675,6 +675,9 @@ void DataRelayer::getReadyToProcess(std::vector& comp notDirty++; continue; } + if (!mCompletionPolicy.callbackFull) { + throw runtime_error_f("Completion police %s has no callback set", mCompletionPolicy.name.c_str()); + } auto partial = getPartialRecord(li); // TODO: get the data ref from message model auto getter = [&partial](size_t idx, size_t part) { @@ -692,14 +695,8 @@ void DataRelayer::getReadyToProcess(std::vector& comp return partial[idx].size(); }; InputSpan span{getter, nPartsGetter, static_cast(partial.size())}; - CompletionPolicy::CompletionOp action; - if (mCompletionPolicy.callback) { - action = mCompletionPolicy.callback(span); - } else if (mCompletionPolicy.callbackFull) { - action = mCompletionPolicy.callbackFull(span, mInputs, mContext); - } else { - throw runtime_error_f("Completion police %s has no callback set", mCompletionPolicy.name.c_str()); - } + CompletionPolicy::CompletionOp action = mCompletionPolicy.callbackFull(span, mInputs, mContext); + auto& variables = mTimesliceIndex.getVariablesForSlot(slot); auto timeslice = std::get_if(&variables.get(0)); switch (action) { diff --git a/Framework/Core/test/test_CompletionPolicy.cxx b/Framework/Core/test/test_CompletionPolicy.cxx index 3bb2dc12caff0..059f20b352b3d 100644 --- a/Framework/Core/test/test_CompletionPolicy.cxx +++ b/Framework/Core/test/test_CompletionPolicy.cxx @@ -12,6 +12,7 @@ #include #include "Framework/CompletionPolicy.h" #include "Framework/CompletionPolicyHelpers.h" +#include "Framework/ServiceRegistry.h" #include "Headers/DataHeader.h" #include "Headers/NameHeader.h" #include "Framework/CompletionPolicy.h" @@ -39,7 +40,9 @@ TEST_CASE("TestCompletionPolicy_callback") return true; }; - auto callback = [&stack](InputSpan const& inputRefs) { + ServiceRegistry services; + + auto callback = [&stack](InputSpan const& inputRefs, std::vector const&, ServiceRegistryRef&) { for (auto const& ref : inputRefs) { auto const* header = CompletionPolicyHelpers::getHeader(ref); REQUIRE(header == reinterpret_cast(stack.data())); @@ -53,7 +56,9 @@ TEST_CASE("TestCompletionPolicy_callback") {"test", matcher, callback}}; CompletionPolicy::InputSetElement ref{nullptr, reinterpret_cast(stack.data()), nullptr}; InputSpan const& inputs{[&ref](size_t) { return ref; }, 1}; + std::vector specs; + ServiceRegistryRef servicesRef{services}; for (auto& policy : policies) { - policy.callback(inputs); + policy.callbackFull(inputs, specs, servicesRef); } } diff --git a/Framework/Core/test/test_StaggeringWorkflow.cxx b/Framework/Core/test/test_StaggeringWorkflow.cxx index edc07d598465b..590ce83cef467 100644 --- a/Framework/Core/test/test_StaggeringWorkflow.cxx +++ b/Framework/Core/test/test_StaggeringWorkflow.cxx @@ -53,7 +53,7 @@ void customize(std::vector& policies) // search for spec names starting with "processor" return spec.name.find("processor") == 0; }, - [](auto const&) { return o2::framework::CompletionPolicy::CompletionOp::Consume; }}); + [](auto const&, auto const&, auto &) { return o2::framework::CompletionPolicy::CompletionOp::Consume; }}); } #include "Framework/runDataProcessing.h"