From 6208c2800e52d3f60b7a70adbc0ce0cc70433301 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sun, 30 Jan 2022 15:17:44 -0800 Subject: [PATCH] torch/monitor: merge Interval and FixedCount stats (#72009) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72009 This simplifies the Stats interface by merging IntervalStat and FixedCountStat into a single Stat w/ a specific window size duration and an optional max samples per window. This allows for the original intention of having comparably sized windows (for statistical purposes) while also having a consistent output bandwidth. Test Plan: ``` buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor ``` Reviewed By: kiukchung Differential Revision: D33822956 fbshipit-source-id: a74782492421be613a1a8b14341b6fb2e8eeb8b4 (cherry picked from commit 293b94e0b4646521ffe047e5222c4bba7e688464) --- docs/source/monitor.rst | 7 -- test/cpp/monitor/test_counters.cpp | 128 ++++++++++++--------------- test/test_monitor.py | 18 ++-- torch/_C/_monitor.pyi | 17 +--- torch/csrc/monitor/counters.h | 133 ++++++++++++++--------------- torch/csrc/monitor/python_init.cpp | 72 +++++++--------- 6 files changed, 158 insertions(+), 217 deletions(-) diff --git a/docs/source/monitor.rst b/docs/source/monitor.rst index 8c4b8216b1d8d8..7952586da9c12e 100644 --- a/docs/source/monitor.rst +++ b/docs/source/monitor.rst @@ -30,13 +30,6 @@ API Reference .. autoclass:: torch.monitor.Stat :members: - -.. autoclass:: torch.monitor.IntervalStat - :members: +add, count, name - :special-members: __init__ - -.. autoclass:: torch.monitor.FixedCountStat - :members: +add, count, name :special-members: __init__ .. autoclass:: torch.monitor.data_value_t diff --git a/test/cpp/monitor/test_counters.cpp b/test/cpp/monitor/test_counters.cpp index 45f7b240801dce..9104e9bb251bad 100644 --- a/test/cpp/monitor/test_counters.cpp +++ b/test/cpp/monitor/test_counters.cpp @@ -8,9 +8,10 @@ using namespace torch::monitor; TEST(MonitorTest, CounterDouble) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::MEAN, Aggregation::COUNT}, + std::chrono::milliseconds(100000), 2, }; a.add(5.0); @@ -27,9 +28,10 @@ TEST(MonitorTest, CounterDouble) { } TEST(MonitorTest, CounterInt64Sum) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::SUM}, + std::chrono::milliseconds(100000), 2, }; a.add(5); @@ -42,9 +44,10 @@ TEST(MonitorTest, CounterInt64Sum) { } TEST(MonitorTest, CounterInt64Value) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::VALUE}, + std::chrono::milliseconds(100000), 2, }; a.add(5); @@ -57,9 +60,10 @@ TEST(MonitorTest, CounterInt64Value) { } TEST(MonitorTest, CounterInt64Mean) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::MEAN}, + std::chrono::milliseconds(100000), 2, }; { @@ -84,9 +88,10 @@ TEST(MonitorTest, CounterInt64Mean) { } TEST(MonitorTest, CounterInt64Count) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::COUNT}, + std::chrono::milliseconds(100000), 2, }; ASSERT_EQ(a.count(), 0); @@ -103,9 +108,10 @@ TEST(MonitorTest, CounterInt64Count) { } TEST(MonitorTest, CounterInt64MinMax) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::MIN, Aggregation::MAX}, + std::chrono::milliseconds(100000), 6, }; { @@ -134,9 +140,10 @@ TEST(MonitorTest, CounterInt64MinMax) { } TEST(MonitorTest, CounterInt64WindowSize) { - FixedCountStat a{ + Stat a{ "a", {Aggregation::COUNT, Aggregation::SUM}, + std::chrono::milliseconds(100000), /*windowSize=*/3, }; a.add(1); @@ -145,8 +152,34 @@ TEST(MonitorTest, CounterInt64WindowSize) { a.add(3); ASSERT_EQ(a.count(), 0); + // after logging max for window, should be zero a.add(4); - ASSERT_EQ(a.count(), 1); + ASSERT_EQ(a.count(), 0); + + auto stats = a.get(); + std::unordered_map want = { + {Aggregation::COUNT, 3}, + {Aggregation::SUM, 6}, + }; + ASSERT_EQ(stats, want); +} + +TEST(MonitorTest, CounterInt64WindowSizeHuge) { + Stat a{ + "a", + {Aggregation::COUNT, Aggregation::SUM}, + std::chrono::hours(24 * 365 * 10), // 10 years + /*windowSize=*/3, + }; + a.add(1); + a.add(2); + ASSERT_EQ(a.count(), 2); + a.add(3); + ASSERT_EQ(a.count(), 0); + + // after logging max for window, should be zero + a.add(4); + ASSERT_EQ(a.count(), 0); auto stats = a.get(); std::unordered_map want = { @@ -157,14 +190,15 @@ TEST(MonitorTest, CounterInt64WindowSize) { } template -struct TestIntervalStat : public IntervalStat { - uint64_t mockWindowId{0}; +struct TestStat : public Stat { + uint64_t mockWindowId{1}; - TestIntervalStat( + TestStat( std::string name, std::initializer_list aggregations, - std::chrono::milliseconds windowSize) - : IntervalStat(name, aggregations, windowSize) {} + std::chrono::milliseconds windowSize, + int64_t maxSamples = std::numeric_limits::max()) + : Stat(name, aggregations, windowSize, maxSamples) {} uint64_t currentWindowId() const override { return mockWindowId; @@ -192,10 +226,10 @@ struct HandlerGuard { } }; -TEST(MonitorTest, IntervalStat) { +TEST(MonitorTest, Stat) { HandlerGuard guard; - IntervalStat a{ + Stat a{ "a", {Aggregation::COUNT, Aggregation::SUM}, std::chrono::milliseconds(1), @@ -213,10 +247,10 @@ TEST(MonitorTest, IntervalStat) { ASSERT_LE(guard.handler->events.size(), 2); } -TEST(MonitorTest, IntervalStatEvent) { +TEST(MonitorTest, StatEvent) { HandlerGuard guard; - TestIntervalStat a{ + TestStat a{ "a", {Aggregation::COUNT, Aggregation::SUM}, std::chrono::milliseconds(1), @@ -245,11 +279,11 @@ TEST(MonitorTest, IntervalStatEvent) { ASSERT_EQ(e.data, data); } -TEST(MonitorTest, IntervalStatEventDestruction) { +TEST(MonitorTest, StatEventDestruction) { HandlerGuard guard; { - TestIntervalStat a{ + TestStat a{ "a", {Aggregation::COUNT, Aggregation::SUM}, std::chrono::hours(10), @@ -269,59 +303,3 @@ TEST(MonitorTest, IntervalStatEventDestruction) { }; ASSERT_EQ(e.data, data); } - -TEST(MonitorTest, FixedCountStatEvent) { - HandlerGuard guard; - - FixedCountStat a{ - "a", - {Aggregation::COUNT, Aggregation::SUM}, - 3, - }; - ASSERT_EQ(guard.handler->events.size(), 0); - - a.add(1); - ASSERT_EQ(a.count(), 1); - a.add(2); - ASSERT_EQ(a.count(), 2); - ASSERT_EQ(guard.handler->events.size(), 0); - - a.add(1); - ASSERT_EQ(a.count(), 0); - ASSERT_EQ(guard.handler->events.size(), 1); - - Event e = guard.handler->events.at(0); - ASSERT_EQ(e.name, "torch.monitor.Stat"); - ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); - std::unordered_map data{ - {"a.sum", 4L}, - {"a.count", 3L}, - }; - ASSERT_EQ(e.data, data); -} - -TEST(MonitorTest, FixedCountStatEventDestruction) { - HandlerGuard guard; - - { - FixedCountStat a{ - "a", - {Aggregation::COUNT, Aggregation::SUM}, - 3, - }; - ASSERT_EQ(guard.handler->events.size(), 0); - a.add(1); - ASSERT_EQ(a.count(), 1); - ASSERT_EQ(guard.handler->events.size(), 0); - } - ASSERT_EQ(guard.handler->events.size(), 1); - - Event e = guard.handler->events.at(0); - ASSERT_EQ(e.name, "torch.monitor.Stat"); - ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{}); - std::unordered_map data{ - {"a.sum", 1L}, - {"a.count", 1L}, - }; - ASSERT_EQ(e.data, data); -} diff --git a/test/test_monitor.py b/test/test_monitor.py index 0c3bda47f93635..8a9aabfcc7ad03 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -10,8 +10,6 @@ from torch.monitor import ( Aggregation, - FixedCountStat, - IntervalStat, Event, log_event, register_event_handler, @@ -28,12 +26,11 @@ def handler(event): events.append(event) handle = register_event_handler(handler) - s = IntervalStat( + s = Stat( "asdf", (Aggregation.SUM, Aggregation.COUNT), timedelta(milliseconds=1), ) - self.assertIsInstance(s, Stat) self.assertEqual(s.name, "asdf") s.add(2) @@ -48,12 +45,12 @@ def handler(event): unregister_event_handler(handle) def test_fixed_count_stat(self) -> None: - s = FixedCountStat( + s = Stat( "asdf", (Aggregation.SUM, Aggregation.COUNT), + timedelta(hours=100), 3, ) - self.assertIsInstance(s, Stat) s.add(1) s.add(2) name = s.name @@ -126,10 +123,11 @@ def test_event_handler(self): with self.create_summary_writer() as w: handle = register_event_handler(TensorboardEventHandler(w)) - s = FixedCountStat( + s = Stat( "asdf", (Aggregation.SUM, Aggregation.COUNT), - 2, + timedelta(hours=1), + 5, ) for i in range(10): s.add(i) @@ -150,8 +148,8 @@ def test_event_handler(self): tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items() } self.assertEqual(scalars, { - "asdf.sum": [1, 5, 9, 13, 17], - "asdf.count": [2, 2, 2, 2, 2], + "asdf.sum": [10], + "asdf.count": [5], }) diff --git a/torch/_C/_monitor.pyi b/torch/_C/_monitor.pyi index cac98e034f3390..47771f180ac6d3 100644 --- a/torch/_C/_monitor.pyi +++ b/torch/_C/_monitor.pyi @@ -15,21 +15,12 @@ class Aggregation(Enum): class Stat: name: str count: int - def add(self, v: float) -> None: ... - def get(self) -> Dict[Aggregation, float]: ... - -class IntervalStat(Stat): - def __init__( - self, - name: str, - aggregations: List[Aggregation], - window_size: datetime.timedelta, - ) -> None: ... - -class FixedCountStat(Stat): def __init__( - self, name: str, aggregations: List[Aggregation], window_size: int + self, name: str, aggregations: List[Aggregation], window_size: int, + max_samples: int = -1, ) -> None: ... + def add(self, v: float) -> None: ... + def get(self) -> Dict[Aggregation, float]: ... class Event: name: str diff --git a/torch/csrc/monitor/counters.h b/torch/csrc/monitor/counters.h index e2a6d60861eac9..5ef83270a2a4b3 100644 --- a/torch/csrc/monitor/counters.h +++ b/torch/csrc/monitor/counters.h @@ -69,10 +69,19 @@ void TORCH_API unregisterStat(Stat* stat); void TORCH_API unregisterStat(Stat* stat); } // namespace detail -// Stat is a base class for stats. These stats are used to compute summary -// statistics in a performant way over repeating intervals. When the window -// closes the stats are logged via the event handlers as a `torch.monitor.Stat` -// event. +// Stat is used to compute summary statistics in a performant way over fixed +// intervals. Stat logs the statistics as an Event once every `windowSize` +// duration. When the window closes the stats are logged via the event handlers +// as a `torch.monitor.Stat` event. +// +// `windowSize` should be set to something relatively high to avoid a huge +// number of events being logged. Ex: 60s. Stat uses millisecond precision. +// +// If maxSamples is set, the stat will cap the number of samples per window by +// discarding `add` calls once `maxSamples` adds have occurred. If it's not set, +// all `add` calls during the window will be included. +// This is an optional field to make aggregations more directly comparable +// across windows when the number of samples might vary. // // Stats support double and int64_t data types depending on what needs to be // logged and needs to be templatized with one of them. @@ -91,8 +100,27 @@ class Stat { }; public: - Stat(std::string name, std::vector aggregations) - : name_(std::move(name)), aggregations_(merge(aggregations)) { + Stat( + std::string name, + std::initializer_list aggregations, + std::chrono::milliseconds windowSize, + int64_t maxSamples = std::numeric_limits::max()) + : name_(std::move(name)), + aggregations_(merge(aggregations)), + windowSize_(windowSize), + maxSamples_(maxSamples) { + detail::registerStat(this); + } + + Stat( + std::string name, + std::vector aggregations, + std::chrono::milliseconds windowSize, + int64_t maxSamples = std::numeric_limits::max()) + : name_(std::move(name)), + aggregations_(merge(aggregations)), + windowSize_(windowSize), + maxSamples_(maxSamples) { detail::registerStat(this); } @@ -110,6 +138,10 @@ class Stat { std::lock_guard guard(mu_); maybeLogLocked(); + if (alreadyLogged()) { + return; + } + if (aggregations_.test(static_cast(Aggregation::VALUE))) { current_.value = v; } @@ -150,7 +182,29 @@ class Stat { } protected: - virtual void maybeLogLocked() = 0; + virtual uint64_t currentWindowId() const { + std::chrono::milliseconds now = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()); + + // always returns a currentWindowId of at least 1 to avoid 0 window issues + return (now / windowSize_) + 1; + } + + private: + bool alreadyLogged() { + return lastLoggedWindowId_ == currentWindowId(); + } + + void maybeLogLocked() { + auto windowId = currentWindowId(); + bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_; + if (shouldLog && !alreadyLogged()) { + logLocked(); + lastLoggedWindowId_ = windowId_; + windowId_ = windowId; + } + } void logLocked() { prev_ = current_; @@ -215,72 +269,11 @@ class Stat { std::mutex mu_; Values current_; Values prev_; -}; - -// IntervalStat is a Stat that logs the stat once every `windowSize` duration. -// This should be set to something relatively high to avoid a huge number of -// events being logged. Ex: 60s. -template -class IntervalStat : public Stat { - public: - IntervalStat( - std::string name, - std::initializer_list aggregations, - std::chrono::milliseconds windowSize) - : Stat(std::move(name), aggregations), windowSize_(windowSize) {} - - IntervalStat( - std::string name, - std::vector aggregations, - std::chrono::milliseconds windowSize) - : Stat(std::move(name), aggregations), windowSize_(windowSize) {} - - protected: - virtual uint64_t currentWindowId() const { - auto now = std::chrono::steady_clock::now().time_since_epoch(); - return now / windowSize_; - } - - private: - void maybeLogLocked() override { - auto windowId = currentWindowId(); - if (windowId_ != windowId) { - Stat::logLocked(); - windowId_ = windowId; - } - } uint64_t windowId_{0}; + uint64_t lastLoggedWindowId_{0}; const std::chrono::milliseconds windowSize_; -}; - -// FixedCountStat is a Stat that logs the stat every `windowSize` number of add -// calls. For high performance stats this window size should be fairly large to -// ensure that the event logging frequency is in the range of 1s to 60s under -// normal usage. Core stats should error on the side of less frequent. -template -class FixedCountStat : public Stat { - public: - FixedCountStat( - std::string name, - std::initializer_list aggregations, - int64_t windowSize) - : Stat(std::move(name), aggregations), windowSize_(windowSize) {} - - FixedCountStat( - std::string name, - std::vector aggregations, - int64_t windowSize) - : Stat(std::move(name), aggregations), windowSize_(windowSize) {} - - private: - void maybeLogLocked() override { - if (Stat::current_.count >= windowSize_) { - Stat::logLocked(); - } - } - - const int64_t windowSize_; + const int64_t maxSamples_; }; } // namespace monitor } // namespace torch diff --git a/torch/csrc/monitor/python_init.cpp b/torch/csrc/monitor/python_init.cpp index 1d4efadb43e883..c27816321900d3 100644 --- a/torch/csrc/monitor/python_init.cpp +++ b/torch/csrc/monitor/python_init.cpp @@ -133,8 +133,37 @@ void initMonitorBindings(PyObject* module) { m, "Stat", R"DOC( - Parent class for all aggregating stat implementations. + Stat is used to compute summary statistics in a performant way over + fixed intervals. Stat logs the statistics as an Event once every + ``window_size`` duration. When the window closes the stats are logged + via the event handlers as a ``torch.monitor.Stat`` event. + + ``window_size`` should be set to something relatively high to avoid a + huge number of events being logged. Ex: 60s. Stat uses millisecond + precision. + + If ``max_samples`` is set, the stat will cap the number of samples per + window by discarding `add` calls once ``max_samples`` adds have + occurred. If it's not set, all ``add`` calls during the window will be + included. This is an optional field to make aggregations more directly + comparable across windows when the number of samples might vary. + + When the Stat is destructed it will log any remaining data even if the + window hasn't elapsed. )DOC") + .def( + py::init< + std::string, + std::vector, + std::chrono::milliseconds, + int64_t>(), + py::arg("name"), + py::arg("aggregations"), + py::arg("window_size"), + py::arg("max_samples") = std::numeric_limits::max(), + R"DOC( + Constructs the ``Stat``. + )DOC") .def( "add", &Stat::add, @@ -165,47 +194,6 @@ void initMonitorBindings(PyObject* module) { once the event has been logged. )DOC"); - py::class_, Stat>( - m, - "IntervalStat", - R"DOC( - IntervalStat is a Stat that logs once every ``window_size`` duration. This - should be set to something relatively high to avoid a huge number of - events being logged. Ex: 60s. - The stat will be logged as an event on the next ``add`` call after the - window ends. - )DOC") - .def( - py::init< - std::string, - std::vector, - std::chrono::milliseconds>(), - py::arg("name"), - py::arg("aggregations"), - py::arg("window_size"), - R"DOC( - Constructs the ``IntervalStat``. - )DOC"); - - py::class_, Stat>( - m, - "FixedCountStat", - R"DOC( - FixedCountStat is a Stat that logs every ``window_size`` number of - ``add`` calls. For high performance stats this window size should be - fairly large to ensure that the event logging frequency is in the range - of 1s to 60s under normal usage. Core stats should error on the side of - logging less frequently. - )DOC") - .def( - py::init, int64_t>(), - py::arg("name"), - py::arg("aggregations"), - py::arg("window_size"), - R"DOC( - Constructs the ``FixedCountStat``. - )DOC"); - py::class_( m, "Event",