From e4490ce08f0dbd3db96e6e43c3e6d5ba717595f6 Mon Sep 17 00:00:00 2001 From: unex <63149623+UNEXENU@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:21:54 +0100 Subject: [PATCH] initial RuntimeInformation support for service queries --- src/engine/Operation.cpp | 20 ++-- src/engine/Operation.h | 5 + src/engine/RuntimeInformation.cpp | 81 ++++++++++++++++ src/engine/RuntimeInformation.h | 5 + src/engine/Service.cpp | 33 ++++++- src/engine/Service.h | 44 +++++++-- src/util/LazyJsonParser.h | 2 +- src/util/http/HttpClient.cpp | 55 ++++++++++- src/util/http/HttpClient.h | 20 +++- src/util/http/beast.h | 1 + test/ConcurrentCacheTest.cpp | 9 ++ test/RuntimeInformationTest.cpp | 48 +++++++++- test/ServiceTest.cpp | 148 ++++++++++++++++++------------ 13 files changed, 386 insertions(+), 85 deletions(-) diff --git a/src/engine/Operation.cpp b/src/engine/Operation.cpp index 0f94ff7886..f422f910ed 100644 --- a/src/engine/Operation.cpp +++ b/src/engine/Operation.cpp @@ -413,10 +413,8 @@ void Operation::updateRuntimeInformationOnSuccess( // Therefore, for each child of this operation the correct runtime is // available. _runtimeInfo->children_.clear(); - for (auto* child : getChildren()) { - AD_CONTRACT_CHECK(child); - _runtimeInfo->children_.push_back( - child->getRootOperation()->getRuntimeInfoPointer()); + for (auto child : getRuntimeInfoChildren()) { + _runtimeInfo->children_.push_back(child); } } signalQueryUpdate(); @@ -470,8 +468,8 @@ void Operation::updateRuntimeInformationWhenOptimizedOut( // _______________________________________________________________________ void Operation::updateRuntimeInformationOnFailure(Milliseconds duration) { _runtimeInfo->children_.clear(); - for (auto child : getChildren()) { - _runtimeInfo->children_.push_back(child->getRootOperation()->_runtimeInfo); + for (auto child : getRuntimeInfoChildren()) { + _runtimeInfo->children_.push_back(child); } _runtimeInfo->totalTime_ = duration; @@ -598,7 +596,6 @@ const vector& Operation::getResultSortedOn() const { } // _____________________________________________________________________________ - void Operation::signalQueryUpdate() const { if (_executionContext && _executionContext->areWebsocketUpdatesEnabled()) { _executionContext->signalQueryUpdate(*_rootRuntimeInfo); @@ -625,3 +622,12 @@ uint64_t Operation::getSizeEstimate() { return getSizeEstimateBeforeLimit(); } } + +// _____________________________________________________________________________ +cppcoro::generator> +Operation::getRuntimeInfoChildren() { + for (auto child : getChildren()) { + AD_CONTRACT_CHECK(child); + co_yield child->getRootOperation()->getRuntimeInfoPointer(); + } +} diff --git a/src/engine/Operation.h b/src/engine/Operation.h index 1a3f68b83d..b7123a3f23 100644 --- a/src/engine/Operation.h +++ b/src/engine/Operation.h @@ -121,6 +121,11 @@ class Operation { return {interm.begin(), interm.end()}; } + // Get access to the children's RuntimeInfo. Required for the `Service`, as + // it's children can't be accessed using `getChildren()` above. + virtual cppcoro::generator> + getRuntimeInfoChildren(); + // recursively collect all Warnings generated by all descendants vector collectWarnings() const; diff --git a/src/engine/RuntimeInformation.cpp b/src/engine/RuntimeInformation.cpp index f9f2d851c8..59288e1ae3 100644 --- a/src/engine/RuntimeInformation.cpp +++ b/src/engine/RuntimeInformation.cpp @@ -185,6 +185,29 @@ std::string_view RuntimeInformation::toString(Status status) { AD_FAIL(); } +// __________________________________________________________________________ +RuntimeInformation::Status RuntimeInformation::fromString( + std::string_view str) { + if (str == "fully materialized") { + return fullyMaterialized; + } else if (str == "lazily materialized") { + return lazilyMaterialized; + } else if (str == "in progress") { + return inProgress; + } else if (str == "not started") { + return notStarted; + } else if (str == "optimized out") { + return optimizedOut; + } else if (str == "failed") { + return failed; + } else if (str == "failed because child failed") { + return failedBecauseChildFailed; + } else if (str == "cancelled") { + return cancelled; + } + AD_FAIL(); +} + // ________________________________________________________________________________________________________________ void to_json(nlohmann::ordered_json& j, const std::shared_ptr& rti) { @@ -220,6 +243,64 @@ void to_json(nlohmann::ordered_json& j, {"time_query_planning", rti.timeQueryPlanning.count()}}; } +// __________________________________________________________________________ +void from_json(const nlohmann::json& j, RuntimeInformation& rti) { + // Helper lambdas to ignore missing key or invalid value. + auto tryGet = [&j](T& dst, std::string_view key) { + try { + j.at(key).get_to(dst); + } catch (const nlohmann::json::exception& e) { + } + }; + using namespace std::chrono; + auto tryGetTime = [&j](microseconds& dst, std::string_view key) { + try { + dst = + duration_cast(milliseconds(j.at(key).get())); + } catch (const nlohmann::json::exception& e) { + } + }; + + auto cacheStatusFromString = [](std::string_view str) { + using ad_utility::CacheStatus; + if (str == "cached_not_pinned") { + return CacheStatus::cachedNotPinned; + } else if (str == "cached_pinned") { + return CacheStatus::cachedPinned; + } else if (str == "computed") { + return CacheStatus::computed; + } else if (str == "not_in_cache_not_computed") { + return CacheStatus::notInCacheAndNotComputed; + } else { + throw std::runtime_error( + "Unknown string value was encountered in `fromString(CacheStatus)`"); + } + }; + + tryGet(rti.descriptor_, "description"); + tryGet(rti.numRows_, "result_rows"); + tryGet(rti.numCols_, "result_cols"); + tryGet(rti.columnNames_, "column_names"); + tryGetTime(rti.totalTime_, "total_time"); + tryGetTime(rti.originalTotalTime_, "original_total_time"); + tryGetTime(rti.originalOperationTime_, "original_operation_time"); + if (auto it = j.find("cache_status"); it != j.end()) { + rti.cacheStatus_ = cacheStatusFromString(it->get()); + } + tryGet(rti.details_, "details"); + tryGet(rti.costEstimate_, "estimated_total_cost"); + tryGet(rti.multiplicityEstimates_, "estimated_column_multiplicities"); + tryGet(rti.sizeEstimate_, "estimated_size"); + if (auto it = j.find("status"); it != j.end()) { + rti.status_ = RuntimeInformation::fromString(it->get()); + } + if (auto it = j.find("children"); it != j.end()) { + for (const auto& child : *it) { + rti.children_.push_back(std::make_shared(child)); + } + } +} + // __________________________________________________________________________ void RuntimeInformation::addLimitOffsetRow(const LimitOffsetClause& l, bool fullResultIsNotCached) { diff --git a/src/engine/RuntimeInformation.h b/src/engine/RuntimeInformation.h index ff4cdc1488..89452ccc5f 100644 --- a/src/engine/RuntimeInformation.h +++ b/src/engine/RuntimeInformation.h @@ -94,6 +94,9 @@ class RuntimeInformation { /// library to allow for implicit conversion. friend void to_json(nlohmann::ordered_json& j, const RuntimeInformation& rti); + // Import from json. Missing keys or invalid values are ignored. + friend void from_json(const nlohmann::json& j, RuntimeInformation& rti); + /// Set `columnNames_` from a `VariableToColumnMap`. The former is a vector /// (convenient for this class), the latter is a hash map (appropriate for /// the rest of the code). @@ -138,6 +141,8 @@ class RuntimeInformation { static std::string_view toString(Status status); + static Status fromString(std::string_view str); + // A helper function for printing the details as a string. static void formatDetailValue(std::ostream& out, std::string_view key, const nlohmann::json& value); diff --git a/src/engine/Service.cpp b/src/engine/Service.cpp index a2791055c0..bc57f76de0 100644 --- a/src/engine/Service.cpp +++ b/src/engine/Service.cpp @@ -24,10 +24,10 @@ // ____________________________________________________________________________ Service::Service(QueryExecutionContext* qec, parsedQuery::Service parsedServiceClause, - GetResultFunction getResultFunction) + NetworkFunctions networkFunctions) : Operation{qec}, parsedServiceClause_{std::move(parsedServiceClause)}, - getResultFunction_{std::move(getResultFunction)} {} + networkFunctions_{std::move(networkFunctions)} {} // ____________________________________________________________________________ std::string Service::getCacheKeyImpl() const { @@ -125,6 +125,31 @@ ProtoResult Service::computeResultImpl([[maybe_unused]] bool requestLaziness) { ad_utility::httpUtils::Url serviceUrl{ asStringViewUnsafe(parsedServiceClause_.serviceIri_.getContent())}; + // Receive updates about the RuntimeInformation from the service endpoint. + const std::string queryId = ad_utility::UuidGenerator()(); + auto updateRuntimeInformation = [&]() { + try { + const std::string target = absl::StrCat("/watch/", queryId); + for (const auto& msg : + networkFunctions_.getRuntimeInfoFunction_(serviceUrl, target)) { + childRuntimeInformation_ = + std::make_shared(nlohmann::json::parse(msg)); + } + } catch (const boost::beast::system_error& se) { + // If the endpoint closes the connection we have received all messages + // -> ignore the error. + if (se.code() != boost::beast::websocket::error::closed) { + LOG(ERROR) << "SERVICE Websocket client: " << se.what() << '\n'; + } + } catch (std::exception& e) { + LOG(ERROR) << "SERVICE Websocket client: " << e.what() << '\n'; + } + }; + if (!runtimeInfoThread_) { + runtimeInfoThread_ = + std::make_unique(updateRuntimeInformation); + } + // Construct the query to be sent to the SPARQL endpoint. std::string variablesForSelectClause = absl::StrJoin( parsedServiceClause_.visibleVariables_, " ", Variable::AbslFormatter); @@ -138,10 +163,10 @@ ProtoResult Service::computeResultImpl([[maybe_unused]] bool requestLaziness) { << ", target: " << serviceUrl.target() << ")" << std::endl << serviceQuery << std::endl; - HttpOrHttpsResponse response = getResultFunction_( + HttpOrHttpsResponse response = networkFunctions_.getResultFunction_( serviceUrl, cancellationHandle_, boost::beast::http::verb::post, serviceQuery, "application/sparql-query", - "application/sparql-results+json"); + "application/sparql-results+json", {{"Query-Id"sv, queryId}}); auto throwErrorWithContext = [this, &response](std::string_view sv) { std::string ctx; diff --git a/src/engine/Service.h b/src/engine/Service.h index 8fef6f5d0e..b0a34c01fb 100644 --- a/src/engine/Service.h +++ b/src/engine/Service.h @@ -32,10 +32,14 @@ class Service : public Operation { public: // The type of the function used to obtain the results, see below. using GetResultFunction = std::function; + std::string_view, + const std::unordered_map&)>; + + // The type of the function used to obtain the RuntimeInformation. + using GetRuntimeInfoFunction = std::function( + const ad_utility::httpUtils::Url&, std::string_view)>; // Information on a Sibling operation. struct SiblingInfo { @@ -44,16 +48,29 @@ class Service : public Operation { std::string cacheKey_; }; + struct NetworkFunctions { + GetResultFunction getResultFunction_; + GetRuntimeInfoFunction getRuntimeInfoFunction_; + }; + private: // The parsed SERVICE clause. parsedQuery::Service parsedServiceClause_; - // The function used to obtain the result from the remote endpoint. - GetResultFunction getResultFunction_; + // The functions used to obtain the result and runtime information from the + // remote endpoint. + NetworkFunctions networkFunctions_; // Optional sibling information to be used in `getSiblingValuesClause`. std::optional siblingInfo_; + // RuntimeInformation of the service-query computation on the endpoint. + std::shared_ptr childRuntimeInformation_; + + // Thread for the websocket-client retrieving `childRuntimeInformation_` from + // the endpoint. + std::unique_ptr runtimeInfoThread_; + public: // Construct from parsed Service clause. // @@ -62,7 +79,15 @@ class Service : public Operation { // but in our tests (`ServiceTest`) we use a mock function that does not // require a running `HttpServer`. Service(QueryExecutionContext* qec, parsedQuery::Service parsedServiceClause, - GetResultFunction getResultFunction = sendHttpOrHttpsRequest); + NetworkFunctions networkFunctions = { + .getResultFunction_ = sendHttpOrHttpsRequest, + .getRuntimeInfoFunction_ = readHttpOrHttpsWebsocketStream}); + + ~Service() { + if (runtimeInfoThread_) { + runtimeInfoThread_->join(); + } + } // Methods inherited from base class `Operation`. std::string getDescriptor() const override; @@ -83,6 +108,13 @@ class Service : public Operation { // A SERVICE clause has no children. vector getChildren() override { return {}; } + cppcoro::generator> + getRuntimeInfoChildren() final { + if (childRuntimeInformation_) { + co_yield childRuntimeInformation_; + } + } + // Convert the given binding to TripleComponent. TripleComponent bindingToTripleComponent( const nlohmann::json& binding, diff --git a/src/util/LazyJsonParser.h b/src/util/LazyJsonParser.h index 23e99fa2d8..78cb60a448 100644 --- a/src/util/LazyJsonParser.h +++ b/src/util/LazyJsonParser.h @@ -66,7 +66,7 @@ class LazyJsonParser { // Context for the 3 parsing sections. struct BeforeArrayPath { - // Indices of the latest parsed literal, used to add keys to the curPath_. + // Indices of the latest parsed literal, used to add keys to the `curPath_`. struct LiteralView { size_t start_{0}; size_t length_{0}; diff --git a/src/util/http/HttpClient.cpp b/src/util/http/HttpClient.cpp index e09808c3c3..2e54455052 100644 --- a/src/util/http/HttpClient.cpp +++ b/src/util/http/HttpClient.cpp @@ -62,6 +62,14 @@ HttpClientImpl::HttpClientImpl(std::string_view host, // ____________________________________________________________________________ template HttpClientImpl::~HttpClientImpl() { + // If the stream was upgraded to a websocket connection, try to close it. + if (ws_) { + if (ws_->is_open()) { + ws_->close(beast::websocket::close_code::normal); + } + return; + } + // We are closing the HTTP connection and destroying the client. So it is // neither required nor possible in a safe way to report errors from a // destructor and we can simply ignore the error codes. @@ -87,7 +95,9 @@ HttpOrHttpsResponse HttpClientImpl::sendRequest( const boost::beast::http::verb& method, std::string_view host, std::string_view target, ad_utility::SharedCancellationHandle handle, std::string_view requestBody, std::string_view contentTypeHeader, - std::string_view acceptHeader) { + std::string_view acceptHeader, + const std::unordered_map& + customHeaders) { // Check that the client pointer is valid. AD_CORRECTNESS_CHECK(client); // Check that we have a stream (created in the constructor). @@ -102,6 +112,10 @@ HttpOrHttpsResponse HttpClientImpl::sendRequest( request.set(http::field::accept, acceptHeader); request.set(http::field::content_type, contentTypeHeader); request.set(http::field::content_length, std::to_string(requestBody.size())); + for (const auto& h : customHeaders) { + request.set(h.first, h.second); + } + request.body() = requestBody; auto wait = [&client, &handle]( @@ -197,12 +211,14 @@ HttpOrHttpsResponse sendHttpOrHttpsRequest( const ad_utility::httpUtils::Url& url, ad_utility::SharedCancellationHandle handle, const boost::beast::http::verb& method, std::string_view requestData, - std::string_view contentTypeHeader, std::string_view acceptHeader) { + std::string_view contentTypeHeader, std::string_view acceptHeader, + const std::unordered_map& + customHeaders) { auto sendRequest = [&]() -> HttpOrHttpsResponse { auto client = std::make_unique(url.host(), url.port()); return Client::sendRequest(std::move(client), method, url.host(), url.target(), std::move(handle), requestData, - contentTypeHeader, acceptHeader); + contentTypeHeader, acceptHeader, customHeaders); }; if (url.protocol() == Url::Protocol::HTTP) { return sendRequest.operator()(); @@ -211,3 +227,36 @@ HttpOrHttpsResponse sendHttpOrHttpsRequest( return sendRequest.operator()(); } } + +// ____________________________________________________________________________ +template +cppcoro::generator HttpClientImpl::readWebSocketStream( + std::unique_ptr> client, std::string_view host, + std::string_view target) { + AD_CORRECTNESS_CHECK(client->stream_); + client->ws_ = std::make_unique>( + std::move(*(client->stream_))); + client->ws_->handshake(host, target); + + beast::flat_buffer buffer; + for (;;) { + client->ws_->read(buffer); + co_yield beast::buffers_to_string(buffer.data()); + buffer.consume(buffer.size()); + } +} + +// ____________________________________________________________________________ +cppcoro::generator readHttpOrHttpsWebsocketStream( + const ad_utility::httpUtils::Url& url, std::string_view target) { + auto listen = [&]() -> cppcoro::generator { + auto client = std::make_unique(url.host(), url.port()); + return Client::readWebSocketStream(std::move(client), url.host(), target); + }; + if (url.protocol() == Url::Protocol::HTTP) { + return listen.operator()(); + } else { + AD_CORRECTNESS_CHECK(url.protocol() == Url::Protocol::HTTPS); + return listen.operator()(); + } +} diff --git a/src/util/http/HttpClient.h b/src/util/http/HttpClient.h index c4838bf651..24134fff6c 100644 --- a/src/util/http/HttpClient.h +++ b/src/util/http/HttpClient.h @@ -58,13 +58,21 @@ class HttpClientImpl { std::string_view target, ad_utility::SharedCancellationHandle handle, std::string_view requestBody = "", std::string_view contentTypeHeader = "text/plain", - std::string_view acceptHeader = "text/plain"); + std::string_view acceptHeader = "text/plain", + const std::unordered_map& + customHeaders = {}); // Simple way to establish a websocket connection boost::beast::http::response sendWebSocketHandshake(const boost::beast::http::verb& method, std::string_view host, std::string_view target); + // Upgrade the client to a websocket connection and return a generator of + // messages received from the server. + static cppcoro::generator readWebSocketStream( + std::unique_ptr client, std::string_view host, + std::string_view target); + private: // The connection stream and associated objects. See the implementation of // `openStream` for why we need all of them, and not just `stream_`. @@ -75,6 +83,7 @@ class HttpClientImpl { workGuard_ = boost::asio::make_work_guard(ioContext_); std::unique_ptr ssl_context_; std::unique_ptr stream_; + std::unique_ptr> ws_; }; // Instantiation for HTTP. @@ -94,4 +103,11 @@ HttpOrHttpsResponse sendHttpOrHttpsRequest( const boost::beast::http::verb& method = boost::beast::http::verb::get, std::string_view postData = "", std::string_view contentTypeHeader = "text/plain", - std::string_view acceptHeader = "text/plain"); + std::string_view acceptHeader = "text/plain", + const std::unordered_map& + customHeaders = {}); + +// Global convenience function to create a websocket connection to the given URL +// and return a generator of messages received from the server. +cppcoro::generator readHttpOrHttpsWebsocketStream( + const ad_utility::httpUtils::Url& url, std::string_view target); diff --git a/src/util/http/beast.h b/src/util/http/beast.h index 10a0e86ce2..30c931bd25 100644 --- a/src/util/http/beast.h +++ b/src/util/http/beast.h @@ -39,6 +39,7 @@ #include #include #include +#include // For boost versions prior to 1.81 this should be no-op #if defined BOOST_BEAST_VERSION && BOOST_BEAST_VERSION < 345 diff --git a/test/ConcurrentCacheTest.cpp b/test/ConcurrentCacheTest.cpp index 9dbfbde509..30cdf4916a 100644 --- a/test/ConcurrentCacheTest.cpp +++ b/test/ConcurrentCacheTest.cpp @@ -112,6 +112,7 @@ TEST(ConcurrentCache, sequentialComputation) { ASSERT_TRUE(a.getStorage().wlock()->_inProgress.empty()); } +// _____________________________________________________________________________ TEST(ConcurrentCache, sequentialPinnedComputation) { SimpleConcurrentLruCache a{3ul}; ad_utility::Timer t{ad_utility::Timer::Started}; @@ -143,6 +144,7 @@ TEST(ConcurrentCache, sequentialPinnedComputation) { ASSERT_TRUE(a.getStorage().wlock()->_inProgress.empty()); } +// _____________________________________________________________________________ TEST(ConcurrentCache, sequentialPinnedUpgradeComputation) { SimpleConcurrentLruCache a{3ul}; ad_utility::Timer t{ad_utility::Timer::Started}; @@ -175,6 +177,7 @@ TEST(ConcurrentCache, sequentialPinnedUpgradeComputation) { ASSERT_TRUE(a.getStorage().wlock()->_inProgress.empty()); } +// _____________________________________________________________________________ TEST(ConcurrentCache, concurrentComputation) { auto a = SimpleConcurrentLruCache(3ul); StartStopSignal signal; @@ -204,6 +207,7 @@ TEST(ConcurrentCache, concurrentComputation) { ASSERT_EQ(result2._cacheStatus, ad_utility::CacheStatus::computed); } +// _____________________________________________________________________________ TEST(ConcurrentCache, concurrentPinnedComputation) { auto a = SimpleConcurrentLruCache(3ul); StartStopSignal signal; @@ -235,6 +239,7 @@ TEST(ConcurrentCache, concurrentPinnedComputation) { ASSERT_EQ(result2._cacheStatus, ad_utility::CacheStatus::computed); } +// _____________________________________________________________________________ TEST(ConcurrentCache, concurrentPinnedUpgradeComputation) { auto a = SimpleConcurrentLruCache(3ul); StartStopSignal signal; @@ -267,6 +272,7 @@ TEST(ConcurrentCache, concurrentPinnedUpgradeComputation) { ASSERT_EQ(result._cacheStatus, ad_utility::CacheStatus::computed); } +// _____________________________________________________________________________ TEST(ConcurrentCache, abort) { auto a = SimpleConcurrentLruCache(3ul); StartStopSignal signal; @@ -293,6 +299,7 @@ TEST(ConcurrentCache, abort) { ASSERT_THROW(fut.get(), std::runtime_error); } +// _____________________________________________________________________________ TEST(ConcurrentCache, abortPinned) { auto a = SimpleConcurrentLruCache(3ul); StartStopSignal signal; @@ -318,6 +325,7 @@ TEST(ConcurrentCache, abortPinned) { ASSERT_THROW(fut.get(), std::runtime_error); } +// _____________________________________________________________________________ TEST(ConcurrentCache, cacheStatusToString) { using enum ad_utility::CacheStatus; EXPECT_EQ(toString(cachedNotPinned), "cached_not_pinned"); @@ -531,6 +539,7 @@ TEST(ConcurrentCache, testTryInsertIfNotPresentDoesWorkCorrectly) { expectContainsSingleElementAtKey0(true, "jkl"); } +// _____________________________________________________________________________ TEST(ConcurrentCache, computeButDontStore) { SimpleConcurrentLruCache cache{}; diff --git a/test/RuntimeInformationTest.cpp b/test/RuntimeInformationTest.cpp index 4a32c6f62f..0c16590d2e 100644 --- a/test/RuntimeInformationTest.cpp +++ b/test/RuntimeInformationTest.cpp @@ -102,6 +102,20 @@ TEST(RuntimeInformation, statusToString) { EXPECT_ANY_THROW(R::toString(static_cast(72))); } +// ________________________________________________________________ +TEST(RuntimeInformation, statusFromString) { + using enum RuntimeInformation::Status; + using R = RuntimeInformation; + EXPECT_EQ(R::fromString("fully materialized"), fullyMaterialized); + EXPECT_EQ(R::fromString("lazily materialized"), lazilyMaterialized); + EXPECT_EQ(R::fromString("not started"), notStarted); + EXPECT_EQ(R::fromString("optimized out"), optimizedOut); + EXPECT_EQ(R::fromString("failed"), failed); + EXPECT_EQ(R::fromString("failed because child failed"), + failedBecauseChildFailed); + EXPECT_ANY_THROW(R::fromString("")); +} + // ________________________________________________________________ TEST(RuntimeInformation, formatDetailValue) { std::ostringstream s; @@ -133,7 +147,7 @@ TEST(RuntimeInformation, formatDetailValue) { } // ________________________________________________________________ -TEST(RuntimeInformation, toStringAndJson) { +TEST(RuntimeInformation, stringAndJsonConversion) { RuntimeInformation child; child.descriptor_ = "child"; child.numCols_ = 2; @@ -227,4 +241,36 @@ TEST(RuntimeInformation, toStringAndJson) { } )"; ASSERT_EQ(j, nlohmann::ordered_json::parse(expectedJson)); + + // Check conversion from JSON to `RuntimeInformation`. + auto rtiFieldsCheck = [](const RuntimeInformation& a, + const RuntimeInformation& b) { + ASSERT_EQ(a.descriptor_, b.descriptor_); + ASSERT_EQ(a.numCols_, b.numCols_); + ASSERT_EQ(a.numRows_, b.numRows_); + ASSERT_EQ(a.columnNames_, b.columnNames_); + ASSERT_EQ(a.totalTime_, b.totalTime_); + ASSERT_EQ(a.details_, b.details_); + ASSERT_EQ(a.cacheStatus_, b.cacheStatus_); + ASSERT_EQ(a.status_, b.status_); + ASSERT_EQ(a.children_.size(), b.children_.size()); + }; + + auto rtiEqual = [rtiFieldsCheck](const RuntimeInformation& a, + const RuntimeInformation& b) { + rtiFieldsCheck(a, b); + for (size_t i = 0; i < a.children_.size(); ++i) { + rtiFieldsCheck(*a.children_[i], *b.children_[i]); + } + }; + + // Check 1: Normal RuntimeInformation. + RuntimeInformation rti; + from_json(j, rti); + rtiEqual(rti, parent); + + // Check 2: Missing keys or values with wrong type -> ignore and use defaults. + RuntimeInformation rti2; + from_json({"description", 42}, rti2); + ASSERT_NO_THROW(rtiEqual(rti2, RuntimeInformation())); } diff --git a/test/ServiceTest.cpp b/test/ServiceTest.cpp index 7f53202dbb..112191e584 100644 --- a/test/ServiceTest.cpp +++ b/test/ServiceTest.cpp @@ -42,7 +42,7 @@ class ServiceTest : public ::testing::Test { // // 1. It tests that the request method is POST, the content-type header is // `application/sparql-query`, and the accept header is - // `text/tab-separated-values` (our `Service` always does this). + // `application/sparql-results+json` (our `Service` always does this). // // 2. It tests that the host and port are as expected. // @@ -52,61 +52,74 @@ class ServiceTest : public ::testing::Test { // // NOTE: In a previous version of this test, we set up an actual test server. // The code can be found in the history of this PR. - static auto constexpr getResultFunctionFactory = + static auto constexpr getNetworkFunctionsFactory = [](std::string_view expectedUrl, std::string_view expectedSparqlQuery, std::string predefinedResult, boost::beast::http::status status = boost::beast::http::status::ok, std::string contentType = "application/sparql-results+json", std::exception_ptr mockException = - nullptr) -> Service::GetResultFunction { - return [=](const ad_utility::httpUtils::Url& url, - ad_utility::SharedCancellationHandle, - const boost::beast::http::verb& method, - std::string_view postData, std::string_view contentTypeHeader, - std::string_view acceptHeader) { - // Check that the request parameters are as expected. - // - // NOTE: The first three are hard-coded in `Service::computeResult`, but - // the host and port of the endpoint are derived from the IRI, so the last - // two checks are non-trivial. - EXPECT_EQ(method, boost::beast::http::verb::post); - EXPECT_EQ(contentTypeHeader, "application/sparql-query"); - EXPECT_EQ(acceptHeader, "application/sparql-results+json"); - EXPECT_EQ(url.asString(), expectedUrl); - - // Check that the whitespace-normalized POST data is the expected query. - // - // NOTE: a SERVICE clause specifies only the body of a SPARQL query, from - // which `Service::computeResult` has to construct a full SPARQL query by - // adding `SELECT ... WHERE`, so this checks something non-trivial. - std::string whitespaceNormalizedPostData = - std::regex_replace(std::string{postData}, std::regex{"\\s+"}, " "); - EXPECT_EQ(whitespaceNormalizedPostData, expectedSparqlQuery); - - if (mockException) { - std::rethrow_exception(mockException); - } + nullptr) -> Service::NetworkFunctions { + return { + .getResultFunction_ = + [=](const ad_utility::httpUtils::Url& url, + ad_utility::SharedCancellationHandle, + const boost::beast::http::verb& method, + std::string_view postData, std::string_view contentTypeHeader, + std::string_view acceptHeader, + const std::unordered_map& + customHeaders) { + // Check that the request parameters are as expected. + // + // NOTE: The first three are hard-coded in + // `Service::computeResult`, but the host and port of the endpoint + // are derived from the IRI, so the last two checks are + // non-trivial. + EXPECT_EQ(method, boost::beast::http::verb::post); + EXPECT_EQ(contentTypeHeader, "application/sparql-query"); + EXPECT_EQ(acceptHeader, "application/sparql-results+json"); + EXPECT_EQ(url.asString(), expectedUrl); + + // Check that the whitespace-normalized POST data is the expected + // query. + // + // NOTE: a SERVICE clause specifies only the body of a SPARQL + // query, from which `Service::computeResult` has to construct a + // full SPARQL query by adding `SELECT ... WHERE`, so this checks + // something non-trivial. + std::string whitespaceNormalizedPostData = std::regex_replace( + std::string{postData}, std::regex{"\\s+"}, " "); + EXPECT_EQ(whitespaceNormalizedPostData, expectedSparqlQuery); + + if (mockException) { + std::rethrow_exception(mockException); + } - auto body = - [](std::string result) -> cppcoro::generator> { - // Randomly slice the string to make tests more robust. - std::mt19937 rng{std::random_device{}()}; - - const std::string resultStr = result; - std::uniform_int_distribution distribution{ - 0, resultStr.length() / 2}; - - for (size_t start = 0; start < resultStr.length();) { - size_t size = distribution(rng); - std::string resultCopy{resultStr.substr(start, size)}; - co_yield std::as_writable_bytes(std::span{resultCopy}); - start += size; - } - }; - return (HttpOrHttpsResponse){.status_ = status, - .contentType_ = contentType, - .body_ = body(predefinedResult)}; - }; + auto body = [](std::string result) + -> cppcoro::generator> { + // Randomly slice the string to make tests more robust. + std::mt19937 rng{std::random_device{}()}; + + const std::string resultStr = result; + std::uniform_int_distribution distribution{ + 0, resultStr.length() / 2}; + + for (size_t start = 0; start < resultStr.length();) { + size_t size = distribution(rng); + std::string resultCopy{resultStr.substr(start, size)}; + co_yield std::as_writable_bytes(std::span{resultCopy}); + start += size; + } + }; + return (HttpOrHttpsResponse){.status_ = status, + .contentType_ = contentType, + .body_ = body(predefinedResult)}; + }, + .getRuntimeInfoFunction_ = + [=](const ad_utility::httpUtils::Url& url, + std::string_view target) -> cppcoro::generator { + EXPECT_EQ(url.asString(), expectedUrl); + co_yield "{}"; + }}; }; // The following method generates a JSON result from variables and rows for @@ -207,8 +220,8 @@ TEST_F(ServiceTest, computeResult) { bool silent = false) -> Result { Service s{testQec, silent ? parsedServiceClauseSilent : parsedServiceClause, - getResultFunctionFactory(expectedUrl, expectedSparqlQuery, - result, status, contentType)}; + getNetworkFunctionsFactory(expectedUrl, expectedSparqlQuery, + result, status, contentType)}; return s.computeResultOnlyForTesting(); }; @@ -328,7 +341,7 @@ TEST_F(ServiceTest, computeResult) { // CHECK 1b: Even if the SILENT-keyword is set, throw local errors. Service serviceSilent{ testQec, parsedServiceClauseSilent, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery, "{}", boost::beast::http::status::ok, "application/sparql-results+json", std::make_exception_ptr( @@ -341,7 +354,7 @@ TEST_F(ServiceTest, computeResult) { Service serviceSilent2{ testQec, parsedServiceClauseSilent, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery, "{}", boost::beast::http::status::ok, "application/sparql-results+json", std::make_exception_ptr( @@ -442,7 +455,7 @@ TEST_F(ServiceTest, computeResult) { Service serviceOperation5{ testQec, parsedServiceClause5, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery5, genJsonResult({"x", "y", "z2"}, {{"x", "y", "y"}, {"bla", "bli", "y"}, @@ -455,7 +468,7 @@ TEST_F(ServiceTest, computeResult) { // Check 7: Lazy computation Service lazyService{ testQec, parsedServiceClause, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery, genJsonResult({"x", "y"}, {{"bla", "bli"}, {"blu", "bla"}, {"bli", "blu"}}), @@ -467,7 +480,7 @@ TEST_F(ServiceTest, computeResult) { // Check 8: LazyJsonParser Error Service service8{ testQec, parsedServiceClause, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery, std::string(1'000'000, '0'), boost::beast::http::status::ok, "application/sparql-results+json")}; AD_EXPECT_THROW_WITH_MESSAGE( @@ -479,7 +492,7 @@ TEST_F(ServiceTest, computeResult) { Service service8b{ testQec, parsedServiceClause, - getResultFunctionFactory( + getNetworkFunctionsFactory( expectedUrl, expectedSparqlQuery, R"({"head": {"vars": ["a"]}, "results": {"bindings": [{"a": break}]}})", boost::beast::http::status::ok, "application/sparql-results+json")}; @@ -503,7 +516,7 @@ TEST_F(ServiceTest, getCacheKey) { Service service( testQec, parsedServiceClause, - getResultFunctionFactory( + getNetworkFunctionsFactory( "http://localhorst:80/api", "PREFIX doof: SELECT ?x ?y WHERE { }", genJsonResult( @@ -656,13 +669,26 @@ TEST_F(ServiceTest, precomputeSiblingResult) { "PREFIX doof: ", "{ }", true}, - getResultFunctionFactory( + getNetworkFunctionsFactory( "http://localhorst:80/api", "PREFIX doof: SELECT ?x ?y WHERE { }", genJsonResult({"x", "y"}, {{"a", "b"}}), boost::beast::http::status::ok, "application/sparql-results+json")); - auto service2 = std::make_shared(*service); + // auto service2 = std::make_shared(*service); + auto service2 = std::make_shared( + testQec, + parsedQuery::Service{ + {Variable{"?x"}, Variable{"?y"}}, + TripleComponent::Iri::fromIriref(""), + "PREFIX doof: ", + "{ }", + true}, + getNetworkFunctionsFactory( + "http://localhorst:80/api", + "PREFIX doof: SELECT ?x ?y WHERE { }", + genJsonResult({"x", "y"}, {{"a", "b"}}), + boost::beast::http::status::ok, "application/sparql-results+json")); // Adaptation of the Values class, allowing to compute lazy Results. class MockValues : public Values {