Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small improvements to tracking #1894

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
refactor tracking progress
getzze committed Oct 25, 2024
commit 1d4b92a25bb43e6b8bf3bb1a8195b19bf756683f
152 changes: 80 additions & 72 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import sys
from collections import deque
from time import time
from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple

import attr
import cv2
@@ -549,6 +549,82 @@ def run_step(self, lf: LabeledFrame) -> LabeledFrame:
instances=self.track(**track_args),
)

def _run_tracker_json(
self,
frames: List[LabeledFrame],
max_length: int = 30,
) -> Iterator[LabeledFrame]:
n_total = len(frames)
n_processed = 0
n_batch = 0
n_recent = deque(maxlen=max_length)
elapsed_recent = deque(maxlen=max_length)
last_report = time()
t0_all = time()
t0_batch = time()

for lf in frames:
new_lf = self.run_step(lf)

# Track timing and progress
elapsed_all = time() - t0_all
n_processed += 1
n_batch += 1

# Report
if time() > last_report + self.report_period:
elapsed_batch = time() - t0_batch
t0_batch = time()

# Compute recent rate
n_recent.append(n_batch)
n_batch = 0
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()

yield new_lf

def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]:
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Tracking...", total=len(frames))
last_report = time()
for lf in frames:
new_lf = self.run_step(lf)

progress.update(task, advance=1)

# Handle refreshing manually to support notebooks.
if time() > last_report + self.report_period:
progress.refresh()
last_report = time()

yield new_lf

def run_tracker(
self,
frames: List[LabeledFrame],
@@ -570,84 +646,16 @@ def run_tracker(
return frames

verbosity = verbosity or self.verbosity
new_lfs = []

# Run tracking on every frame
if verbosity == "rich":
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Tracking...", total=len(frames))
last_report = time()
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)

progress.update(task, advance=1)

# Handle refreshing manually to support notebooks.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
progress.refresh()
new_lfs = list(self._run_tracker_rich(frames))

elif verbosity == "json":
n_total = len(frames)
n_processed = 0
n_batch = 0
elapsed_all = 0
n_recent = deque(maxlen=30)
elapsed_recent = deque(maxlen=30)
last_report = time()
t0_all = time()
t0_batch = time()
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)

# Track timing and progress.
elapsed_all = time() - t0_all
n_processed += 1
n_batch += 1

# Report.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
elapsed_batch = time() - t0_batch
t0_batch = time()

# Compute recent rate.
n_recent.append(n_batch)
n_batch = 0
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()
new_lfs = list(self._run_tracker_json(frames))

else:
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)
new_lfs = list(self.run_step(lf) for lf in frames)

# Run final_pass
if final_pass: