From 58fd284f05bc91700aa74e885065b9f633280de1 Mon Sep 17 00:00:00 2001 From: Teddy Reed Date: Sun, 14 Aug 2016 15:41:53 -0700 Subject: [PATCH] Improve dispatcher tests (#2358) This improves dispatcher tests by allowing units to act like component tests and use embedded std::thread-based osquery APIs. A unit may force a 'service' to run by joining the Dispatcher before deconstructing. --- include/osquery/dispatcher.h | 24 +++- osquery/dispatcher/dispatcher.cpp | 32 ++++-- osquery/dispatcher/tests/dispatcher_tests.cpp | 108 +++++++++++++++++- osquery/logger/plugins/buffered.cpp | 33 +++--- osquery/logger/plugins/buffered.h | 4 + .../logger/plugins/tests/buffered_tests.cpp | 42 +++---- 6 files changed, 182 insertions(+), 61 deletions(-) diff --git a/include/osquery/dispatcher.h b/include/osquery/dispatcher.h index d515b9f2e67..9a97ed4b73b 100644 --- a/include/osquery/dispatcher.h +++ b/include/osquery/dispatcher.h @@ -79,6 +79,10 @@ class InterruptableRunnable { /// Put the runnable into an interruptible sleep. virtual void pauseMilli(std::chrono::milliseconds milli); + private: + /// Testing only, the interruptible will bypass initial interruption check. + void mustRun() { bypass_check_ = true; } + private: /** * @brief Protect interruption checking and resource tear down. @@ -94,6 +98,19 @@ class InterruptableRunnable { /// Use an interruption point to exit a pause if the thread was interrupted. RunnerInterruptPoint point_; + + private: + /// Testing only, track the interruptible check for interruption. + bool checked_{false}; + + /// Testing only, require that the interruptible bypass the first check. + std::atomic bypass_check_{false}; + + private: + FRIEND_TEST(DispatcherTests, test_run); + FRIEND_TEST(DispatcherTests, test_independent_run); + FRIEND_TEST(DispatcherTests, test_interruption); + FRIEND_TEST(BufferedLogForwarderTests, test_async); }; class InternalRunnable : private boost::noncopyable, @@ -181,6 +198,10 @@ class Dispatcher : private boost::noncopyable { /// When a service ends, it will remove itself from the dispatcher. static void removeService(const InternalRunnable* service); + private: + /// For testing only, reset the stopping status for unittests. + void resetStopping() { stopping_ = false; } + private: /// The set of shared osquery service threads. std::vector service_threads_; @@ -209,6 +230,7 @@ class Dispatcher : private boost::noncopyable { private: friend class InternalRunnable; - friend class ExtensionsTest; + friend class ExtensionsTests; + friend class DispatcherTests; }; } diff --git a/osquery/dispatcher/dispatcher.cpp b/osquery/dispatcher/dispatcher.cpp index 32ea6c9c3fb..73a83d3060b 100644 --- a/osquery/dispatcher/dispatcher.cpp +++ b/osquery/dispatcher/dispatcher.cpp @@ -51,6 +51,11 @@ void InterruptableRunnable::interrupt() { bool InterruptableRunnable::interrupted() { WriteLock lock(stopping_); + // A small conditional to force-skip an interruption check, used in testing. + if (bypass_check_ && !checked_) { + checked_ = true; + return false; + } return interrupted_; } @@ -106,14 +111,26 @@ void Dispatcher::removeService(const InternalRunnable* service) { self.services_.end()); } +inline static void assureRun(const InternalRunnableRef& service) { + while (true) { + // Wait for each thread's entry point (start) meaning the thread context + // was allocated and (run) was called by std::thread started. + if (service->hasRun()) { + break; + } + // We only need to check if std::terminate is called very quickly after + // the std::thread is created. + sleepFor(20); + } +} + void Dispatcher::joinServices() { auto& self = instance(); DLOG(INFO) << "Thread: " << std::this_thread::get_id() << " requesting a join"; WriteLock join_lock(self.join_mutex_); + for (auto& thread : self.service_threads_) { - // Boost threads would have been interrupted, and joined using the - // provided thread instance. thread->join(); DLOG(INFO) << "Service thread: " << thread.get() << " has joined"; } @@ -133,16 +150,7 @@ void Dispatcher::stopServices() { DLOG(INFO) << "Thread: " << std::this_thread::get_id() << " requesting a stop"; for (const auto& service : self.services_) { - while (true) { - // Wait for each thread's entry point (start) meaning the thread context - // was allocated and (run) was called by std::thread started. - if (service->hasRun()) { - break; - } - // We only need to check if std::terminate is called very quickly after - // the std::thread is created. - sleepFor(20); - } + assureRun(service); service->interrupt(); DLOG(INFO) << "Service: " << service.get() << " has been interrupted"; } diff --git a/osquery/dispatcher/tests/dispatcher_tests.cpp b/osquery/dispatcher/tests/dispatcher_tests.cpp index 4f1e2c765a4..4bc576f5dae 100644 --- a/osquery/dispatcher/tests/dispatcher_tests.cpp +++ b/osquery/dispatcher/tests/dispatcher_tests.cpp @@ -15,7 +15,7 @@ namespace osquery { class DispatcherTests : public testing::Test { - void TearDown() override {} + void TearDown() override { Dispatcher::instance().resetStopping(); } }; TEST_F(DispatcherTests, test_singleton) { @@ -26,8 +26,108 @@ TEST_F(DispatcherTests, test_singleton) { class TestRunnable : public InternalRunnable { public: - int* i; - explicit TestRunnable(int* i) : i(i) {} - virtual void start() { ++*i; } + explicit TestRunnable() {} + + virtual void start() override { + WriteLock lock(mutex_); + ++i; + } + + void reset() { + WriteLock lock(mutex_); + i = 0; + } + + size_t count() { + WriteLock lock(mutex_); + return i; + } + + private: + static size_t i; + + private: + Mutex mutex_; }; + +size_t TestRunnable::i{0}; + +TEST_F(DispatcherTests, test_service_count) { + auto runnable = std::make_shared(); + + auto service_count = Dispatcher::instance().serviceCount(); + // The service exits after incrementing. + auto s = Dispatcher::addService(runnable); + EXPECT_TRUE(s); + + // Wait for the service to stop. + Dispatcher::joinServices(); + + // Make sure the service is removed. + EXPECT_EQ(service_count, Dispatcher::instance().serviceCount()); +} + +TEST_F(DispatcherTests, test_run) { + auto runnable = std::make_shared(); + runnable->mustRun(); + runnable->reset(); + + // The service exits after incrementing. + Dispatcher::addService(runnable); + Dispatcher::joinServices(); + EXPECT_EQ(1U, runnable->count()); + EXPECT_TRUE(runnable->hasRun()); + + // This runnable cannot be executed again. + auto s = Dispatcher::addService(runnable); + EXPECT_FALSE(s); + + Dispatcher::joinServices(); + EXPECT_EQ(1U, runnable->count()); +} + +TEST_F(DispatcherTests, test_independent_run) { + // Nothing stops two instances of the same service from running. + auto r1 = std::make_shared(); + auto r2 = std::make_shared(); + r1->mustRun(); + r2->mustRun(); + r1->reset(); + + Dispatcher::addService(r1); + Dispatcher::addService(r2); + Dispatcher::joinServices(); + + EXPECT_EQ(2U, r1->count()); +} + +class BlockingTestRunnable : public InternalRunnable { + public: + explicit BlockingTestRunnable() {} + + virtual void start() override { + // Wow that's a long sleep! + pauseMilli(100 * 1000); + } +}; + +TEST_F(DispatcherTests, test_interruption) { + auto r1 = std::make_shared(); + r1->mustRun(); + Dispatcher::addService(r1); + + // This service would normally wait for 100 seconds. + r1->interrupt(); + + Dispatcher::joinServices(); + EXPECT_TRUE(r1->hasRun()); +} + +TEST_F(DispatcherTests, test_stop_dispatcher) { + Dispatcher::stopServices(); + + auto r1 = std::make_shared(); + auto s = Dispatcher::addService(r1); + EXPECT_FALSE(s); +} } diff --git a/osquery/logger/plugins/buffered.cpp b/osquery/logger/plugins/buffered.cpp index 50278469e91..3cb04a74dd0 100644 --- a/osquery/logger/plugins/buffered.cpp +++ b/osquery/logger/plugins/buffered.cpp @@ -57,12 +57,12 @@ void BufferedLogForwarder::check() { // For each index, accumulate the log line into the result or status set. std::vector results, statuses; iterate(indexes, ([&results, &statuses, this](std::string& index) { - std::string value; - auto& target = isResultIndex(index) ? results : statuses; - if (getDatabaseValue(kLogs, index, value)) { - target.push_back(std::move(value)); - } - })); + std::string value; + auto& target = isResultIndex(index) ? results : statuses; + if (getDatabaseValue(kLogs, index, value)) { + target.push_back(std::move(value)); + } + })); // If any results/statuses were found in the flushed buffer, send. if (results.size() > 0) { @@ -72,11 +72,11 @@ void BufferedLogForwarder::check() { } else { // Clear the results logs once they were sent. iterate(indexes, ([this](std::string& index) { - if (!isResultIndex(index)) { - return; - } - deleteValueWithCount(kLogs, index); - })); + if (!isResultIndex(index)) { + return; + } + deleteValueWithCount(kLogs, index); + })); } } @@ -87,11 +87,11 @@ void BufferedLogForwarder::check() { } else { // Clear the status logs once they were sent. iterate(indexes, ([this](std::string& index) { - if (!isStatusIndex(index)) { - return; - } - deleteValueWithCount(kLogs, index); - })); + if (!isStatusIndex(index)) { + return; + } + deleteValueWithCount(kLogs, index); + })); } } @@ -156,7 +156,6 @@ void BufferedLogForwarder::purge() { LOG(ERROR) << "Error deleting value during buffered log purge"; } }); - } void BufferedLogForwarder::start() { diff --git a/osquery/logger/plugins/buffered.h b/osquery/logger/plugins/buffered.h index 7ffea996a13..d31ab98ee39 100644 --- a/osquery/logger/plugins/buffered.h +++ b/osquery/logger/plugins/buffered.h @@ -146,6 +146,7 @@ class BufferedLogForwarder : public InternalRunnable { protected: /// Return whether the string is a result index bool isResultIndex(const std::string& index); + /// Return whether the string is a status index bool isStatusIndex(const std::string& index); @@ -156,11 +157,13 @@ class BufferedLogForwarder : public InternalRunnable { protected: /// Generate a result index string to use with the backing store std::string genResultIndex(size_t time = 0); + /// Generate a status index string to use with the backing store std::string genStatusIndex(size_t time = 0); private: std::string genIndexPrefix(bool results); + std::string genIndex(bool results, size_t time = 0); /** @@ -170,6 +173,7 @@ class BufferedLogForwarder : public InternalRunnable { Status addValueWithCount(const std::string& domain, const std::string& key, const std::string& value); + /** * @brief Delete a database value while maintaining count * diff --git a/osquery/logger/plugins/tests/buffered_tests.cpp b/osquery/logger/plugins/tests/buffered_tests.cpp index 0e449627799..316b77e01d3 100644 --- a/osquery/logger/plugins/tests/buffered_tests.cpp +++ b/osquery/logger/plugins/tests/buffered_tests.cpp @@ -114,10 +114,9 @@ TEST_F(BufferedLogForwarderTests, test_basic) { runner.logString("baz"); EXPECT_CALL(runner, send(ElementsAre("bar", "baz"), "result")) .WillOnce(Return(Status(0))); - EXPECT_CALL( - runner, - send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status")) - .WillOnce(Return(Status(0))); + EXPECT_CALL(runner, + send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), + "status")).WillOnce(Return(Status(0))); runner.check(); // This call should not result in sending again runner.check(); @@ -143,16 +142,14 @@ TEST_F(BufferedLogForwarderTests, test_retry) { runner.logString("bar"); EXPECT_CALL(runner, send(ElementsAre("foo", "bar"), "result")) .WillOnce(Return(Status(0))); - EXPECT_CALL( - runner, - send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status")) - .WillOnce(Return(Status(1, "fail"))); + EXPECT_CALL(runner, + send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), + "status")).WillOnce(Return(Status(1, "fail"))); runner.check(); - EXPECT_CALL( - runner, - send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status")) - .WillOnce(Return(Status(0))); + EXPECT_CALL(runner, + send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), + "status")).WillOnce(Return(Status(0))); runner.check(); // This call should not send again because the previous was successful @@ -215,22 +212,14 @@ TEST_F(BufferedLogForwarderTests, test_multiple) { TEST_F(BufferedLogForwarderTests, test_async) { auto runner = std::make_shared>( "mock", kLogPeriod); - Dispatcher::addService(runner); + runner->mustRun(); EXPECT_CALL(*runner, send(ElementsAre("foo"), "result")) .WillOnce(Return(Status(0))); runner->logString("foo"); - std::this_thread::sleep_for(5 * kLogPeriod); - EXPECT_CALL(*runner, send(ElementsAre("bar"), "result")) - .Times(3) - .WillOnce(Return(Status(1, "fail"))) - .WillOnce(Return(Status(1, "fail again"))) - .WillOnce(Return(Status(0))); - runner->logString("bar"); - std::this_thread::sleep_for(15 * kLogPeriod); - - Dispatcher::stopServices(); + Dispatcher::addService(runner); + runner->interrupt(); Dispatcher::joinServices(); } @@ -325,10 +314,9 @@ TEST_F(BufferedLogForwarderTests, test_purge_max) { EXPECT_CALL(runner, send(ElementsAre("foo", "bar", "baz"), "result")) .WillOnce(Return(Status(1, "fail"))); - EXPECT_CALL( - runner, - send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status")) - .WillOnce(Return(Status(1, "fail"))); + EXPECT_CALL(runner, + send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), + "status")).WillOnce(Return(Status(1, "fail"))); runner.check(); EXPECT_CALL(runner, send(ElementsAre("baz"), "result"))