diff --git a/python/cirron/cirron.py b/python/cirron/cirron.py index 86f9242..56efd5b 100644 --- a/python/cirron/cirron.py +++ b/python/cirron/cirron.py @@ -44,6 +44,9 @@ def __repr__(self): ) +overhead = {} + + class Collector: cirron_lib = CDLL(lib_path) cirron_lib.start.argtypes = None @@ -55,6 +58,23 @@ def __init__(self): self._fd = None self.counters = Counter() + # We try to estimate what the overhead of the collector is, taking the minimum + # of 10 runs. + global overhead + if not overhead: + collector = Collector() + for _ in range(10): + with Collector() as collector: + pass + + for field, _ in Counter._fields_: + if field not in overhead: + overhead[field] = getattr(collector.counters, field) + else: + overhead[field] = min( + overhead[field], getattr(collector.counters, field) + ) + def __enter__(self): ret_val = Collector.cirron_lib.start() if ret_val == -1: @@ -80,21 +100,3 @@ def __exit__(self, exc_type, exc_value, traceback): ) else: setattr(self.counters, field, 0) - - -# We try to estimate what the overhead of the collector is, taking the minimum -# of 10 runs. -overhead = {} -collector = Collector() -o = {} -for _ in range(10): - with Collector() as collector: - pass - - for field, _ in Counter._fields_: - if field not in overhead: - o[field] = getattr(collector.counters, field) - else: - o[field] = min(overhead[field], getattr(collector.counters, field)) -overhead = o -del collector diff --git a/python/tests/tests.py b/python/tests/tests.py index 8445e12..da24896 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -1,7 +1,9 @@ import unittest import os + from cirron import Tracer, Collector + class Test(unittest.TestCase): def test_tracer(self): with Tracer() as t: @@ -9,7 +11,10 @@ def test_tracer(self): self.assertEqual(len(t.trace), 3) - @unittest.skipIf("GITHUB_ACTIONS" in os.environ, "As of 02/07/2024, GitHub Actions does not support perf_event_open.") + @unittest.skipIf( + "GITHUB_ACTIONS" in os.environ, + "As of 02/07/2024, GitHub Actions does not support perf_event_open.", + ) def test_collector(self): with Collector() as c: print(0)