Skip to content

Commit

Permalink
torch/monitor: merge Interval and FixedCount stats (pytorch#72009)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#72009

This simplifies the Stats interface by merging IntervalStat and FixedCountStat into a single Stat w/ a specific window size duration and an optional max samples per window. This allows for the original intention of having comparably sized windows (for statistical purposes) while also having a consistent output bandwidth.

Test Plan:
```
buck test //caffe2/test:monitor //caffe2/test/cpp/monitor:monitor
```

Reviewed By: kiukchung

Differential Revision: D33822956

fbshipit-source-id: a74782492421be613a1a8b14341b6fb2e8eeb8b4
(cherry picked from commit 293b94e)
  • Loading branch information
d4l3k authored and pytorchmergebot committed Jan 30, 2022
1 parent a18cfb7 commit 6208c28
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 217 deletions.
7 changes: 0 additions & 7 deletions docs/source/monitor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,6 @@ API Reference

.. autoclass:: torch.monitor.Stat
:members:

.. autoclass:: torch.monitor.IntervalStat
:members: +add, count, name
:special-members: __init__

.. autoclass:: torch.monitor.FixedCountStat
:members: +add, count, name
:special-members: __init__

.. autoclass:: torch.monitor.data_value_t
Expand Down
128 changes: 53 additions & 75 deletions test/cpp/monitor/test_counters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
using namespace torch::monitor;

TEST(MonitorTest, CounterDouble) {
FixedCountStat<double> a{
Stat<double> a{
"a",
{Aggregation::MEAN, Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
a.add(5.0);
Expand All @@ -27,9 +28,10 @@ TEST(MonitorTest, CounterDouble) {
}

TEST(MonitorTest, CounterInt64Sum) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::SUM},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
Expand All @@ -42,9 +44,10 @@ TEST(MonitorTest, CounterInt64Sum) {
}

TEST(MonitorTest, CounterInt64Value) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::VALUE},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
Expand All @@ -57,9 +60,10 @@ TEST(MonitorTest, CounterInt64Value) {
}

TEST(MonitorTest, CounterInt64Mean) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::MEAN},
std::chrono::milliseconds(100000),
2,
};
{
Expand All @@ -84,9 +88,10 @@ TEST(MonitorTest, CounterInt64Mean) {
}

TEST(MonitorTest, CounterInt64Count) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
ASSERT_EQ(a.count(), 0);
Expand All @@ -103,9 +108,10 @@ TEST(MonitorTest, CounterInt64Count) {
}

TEST(MonitorTest, CounterInt64MinMax) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::MIN, Aggregation::MAX},
std::chrono::milliseconds(100000),
6,
};
{
Expand Down Expand Up @@ -134,9 +140,10 @@ TEST(MonitorTest, CounterInt64MinMax) {
}

TEST(MonitorTest, CounterInt64WindowSize) {
FixedCountStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(100000),
/*windowSize=*/3,
};
a.add(1);
Expand All @@ -145,8 +152,34 @@ TEST(MonitorTest, CounterInt64WindowSize) {
a.add(3);
ASSERT_EQ(a.count(), 0);

// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 1);
ASSERT_EQ(a.count(), 0);

auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 3},
{Aggregation::SUM, 6},
};
ASSERT_EQ(stats, want);
}

TEST(MonitorTest, CounterInt64WindowSizeHuge) {
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(24 * 365 * 10), // 10 years
/*windowSize=*/3,
};
a.add(1);
a.add(2);
ASSERT_EQ(a.count(), 2);
a.add(3);
ASSERT_EQ(a.count(), 0);

// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 0);

auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
Expand All @@ -157,14 +190,15 @@ TEST(MonitorTest, CounterInt64WindowSize) {
}

template <typename T>
struct TestIntervalStat : public IntervalStat<T> {
uint64_t mockWindowId{0};
struct TestStat : public Stat<T> {
uint64_t mockWindowId{1};

TestIntervalStat(
TestStat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize)
: IntervalStat<T>(name, aggregations, windowSize) {}
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: Stat<T>(name, aggregations, windowSize, maxSamples) {}

uint64_t currentWindowId() const override {
return mockWindowId;
Expand Down Expand Up @@ -192,10 +226,10 @@ struct HandlerGuard {
}
};

TEST(MonitorTest, IntervalStat) {
TEST(MonitorTest, Stat) {
HandlerGuard<AggregatingEventHandler> guard;

IntervalStat<int64_t> a{
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
Expand All @@ -213,10 +247,10 @@ TEST(MonitorTest, IntervalStat) {
ASSERT_LE(guard.handler->events.size(), 2);
}

TEST(MonitorTest, IntervalStatEvent) {
TEST(MonitorTest, StatEvent) {
HandlerGuard<AggregatingEventHandler> guard;

TestIntervalStat<int64_t> a{
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
Expand Down Expand Up @@ -245,11 +279,11 @@ TEST(MonitorTest, IntervalStatEvent) {
ASSERT_EQ(e.data, data);
}

TEST(MonitorTest, IntervalStatEventDestruction) {
TEST(MonitorTest, StatEventDestruction) {
HandlerGuard<AggregatingEventHandler> guard;

{
TestIntervalStat<int64_t> a{
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(10),
Expand All @@ -269,59 +303,3 @@ TEST(MonitorTest, IntervalStatEventDestruction) {
};
ASSERT_EQ(e.data, data);
}

TEST(MonitorTest, FixedCountStatEvent) {
HandlerGuard<AggregatingEventHandler> guard;

FixedCountStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);

a.add(1);
ASSERT_EQ(a.count(), 1);
a.add(2);
ASSERT_EQ(a.count(), 2);
ASSERT_EQ(guard.handler->events.size(), 0);

a.add(1);
ASSERT_EQ(a.count(), 0);
ASSERT_EQ(guard.handler->events.size(), 1);

Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 4L},
{"a.count", 3L},
};
ASSERT_EQ(e.data, data);
}

TEST(MonitorTest, FixedCountStatEventDestruction) {
HandlerGuard<AggregatingEventHandler> guard;

{
FixedCountStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
3,
};
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_EQ(a.count(), 1);
ASSERT_EQ(guard.handler->events.size(), 0);
}
ASSERT_EQ(guard.handler->events.size(), 1);

Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 1L},
{"a.count", 1L},
};
ASSERT_EQ(e.data, data);
}
18 changes: 8 additions & 10 deletions test/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from torch.monitor import (
Aggregation,
FixedCountStat,
IntervalStat,
Event,
log_event,
register_event_handler,
Expand All @@ -28,12 +26,11 @@ def handler(event):
events.append(event)

handle = register_event_handler(handler)
s = IntervalStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(milliseconds=1),
)
self.assertIsInstance(s, Stat)
self.assertEqual(s.name, "asdf")

s.add(2)
Expand All @@ -48,12 +45,12 @@ def handler(event):
unregister_event_handler(handle)

def test_fixed_count_stat(self) -> None:
s = FixedCountStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
timedelta(hours=100),
3,
)
self.assertIsInstance(s, Stat)
s.add(1)
s.add(2)
name = s.name
Expand Down Expand Up @@ -126,10 +123,11 @@ def test_event_handler(self):
with self.create_summary_writer() as w:
handle = register_event_handler(TensorboardEventHandler(w))

s = FixedCountStat(
s = Stat(
"asdf",
(Aggregation.SUM, Aggregation.COUNT),
2,
timedelta(hours=1),
5,
)
for i in range(10):
s.add(i)
Expand All @@ -150,8 +148,8 @@ def test_event_handler(self):
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
}
self.assertEqual(scalars, {
"asdf.sum": [1, 5, 9, 13, 17],
"asdf.count": [2, 2, 2, 2, 2],
"asdf.sum": [10],
"asdf.count": [5],
})


Expand Down
17 changes: 4 additions & 13 deletions torch/_C/_monitor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,12 @@ class Aggregation(Enum):
class Stat:
name: str
count: int
def add(self, v: float) -> None: ...
def get(self) -> Dict[Aggregation, float]: ...

class IntervalStat(Stat):
def __init__(
self,
name: str,
aggregations: List[Aggregation],
window_size: datetime.timedelta,
) -> None: ...

class FixedCountStat(Stat):
def __init__(
self, name: str, aggregations: List[Aggregation], window_size: int
self, name: str, aggregations: List[Aggregation], window_size: int,
max_samples: int = -1,
) -> None: ...
def add(self, v: float) -> None: ...
def get(self) -> Dict[Aggregation, float]: ...

class Event:
name: str
Expand Down
Loading

0 comments on commit 6208c28

Please sign in to comment.