diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index 1bb930e58..ed05b91f8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -424,10 +424,6 @@ inference: This tracker "shifts" instances from previous frames using optical flow before matching instances in each frame to the <i>shifted</i> instances from prior frames.' - # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int @@ -459,10 +455,12 @@ inference: none_label: Use max (non-robust) range: 0,1 default: 0.95 - # - name: tracking.save_shifted_instances - # label: Save shifted instances - # type: bool - # default: false + - name: tracking.save_shifted_instances + label: Save shifted instances + help: 'Save the flow-shifted instances between elapsed frames. It improves + instance matching at the cost of using a bit more of memory.' + type: bool + default: false - type: text text: '<b>Kalman filter-based tracking</b>:<br /> Uses the above tracking options to track instances for an initial diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 2c2617036..7a9c94358 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -733,7 +733,7 @@ def run(self): # count < 0 means there was an error and we didn't get any results. if new_counts is not None and new_counts >= 0: total_count = items_for_inference.total_frame_count - no_result_count = total_count - new_counts + no_result_count = max(0, total_count - new_counts) message = ( f"Inference ran on {total_count} frames." diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 949703020..4c2370e09 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -816,6 +816,8 @@ def __init__(self, state=None, player=None, *args, **kwargs): self.click_mode = "" self.in_zoom = False + self._down_pos = None + self.zoomFactor = 1 anchor_mode = QGraphicsView.AnchorUnderMouse self.setTransformationAnchor(anchor_mode) @@ -1039,7 +1041,7 @@ def mouseReleaseEvent(self, event): scenePos = self.mapToScene(event.pos()) # check if mouse moved during click - has_moved = event.pos() != self._down_pos + has_moved = self._down_pos is not None and event.pos() != self._down_pos if event.button() == Qt.LeftButton: if self.in_zoom: diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14e0d5c6f..0923b6979 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -21,60 +21,67 @@ function which provides a simplified interface for creating `Predictor`s. """ -import attr import argparse +import atexit +import json import logging -import warnings import os -import sys -import tempfile import platform import shutil -import atexit import subprocess -import rich.progress -import pandas as pd -from rich.pretty import pprint +import sys +import tempfile +import warnings +from abc import ABC, abstractmethod from collections import deque -import json -from time import time from datetime import datetime from pathlib import Path -import tensorflow_hub as hub -from abc import ABC, abstractmethod -from typing import Text, Optional, List, Dict, Union, Iterator, Tuple -from threading import Thread from queue import Queue +from threading import Thread +from time import time +from typing import Dict, Iterator, List, Optional, Text, Tuple, Union -import tensorflow as tf +if sys.version_info >= (3, 8): + from functools import cached_property + +else: # cached_property is defined only for python >=3.8 + cached_property = property + +import attr import numpy as np +import pandas as pd +import rich.progress +import tensorflow as tf +import tensorflow_hub as hub +from rich.pretty import pprint +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2, +) import sleap - -from sleap.nn.config import TrainingJobConfig, DataConfig -from sleap.nn.data.resizing import SizeMatcher -from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker -from sleap.nn.paf_grouping import PAFScorer +from sleap.instance import LabeledFrame, PredictedInstance +from sleap.io.dataset import Labels +from sleap.nn.config import DataConfig, TrainingJobConfig from sleap.nn.data.pipelines import ( - Provider, - Pipeline, + Batcher, + InstanceCentroidFinder, + KerasModelPredictor, LabelsReader, - VideoReader, Normalizer, - Resizer, + Pipeline, Prefetcher, - InstanceCentroidFinder, - KerasModelPredictor, + Provider, + Resizer, + VideoReader, ) +from sleap.nn.data.resizing import SizeMatcher +from sleap.nn.model import Model +from sleap.nn.paf_grouping import PAFScorer +from sleap.nn.tracking import Tracker from sleap.nn.utils import reset_input_layer -from sleap.io.dataset import Labels -from sleap.util import frame_list, make_scoped_dictionary -from sleap.instance import PredictedInstance, LabeledFrame +from sleap.util import RateColumn, frame_list, make_scoped_dictionary -from tensorflow.python.framework.convert_to_constants import ( - convert_variables_to_constants_v2, -) +logger = logging.getLogger(__name__) MOVENET_MODELS = { "lightning": { @@ -126,8 +133,6 @@ ], ) -logger = logging.getLogger(__name__) - def get_keras_model_path(path: Text) -> str: """Utility method for finding the path to a saved Keras model. @@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str: return os.path.join(path, "best_model.h5") -class RateColumn(rich.progress.ProgressColumn): - """Renders the progress rate.""" - - def render(self, task: "Task") -> rich.progress.Text: - """Show progress rate.""" - speed = task.speed - if speed is None: - return rich.progress.Text("?", style="progress.data.speed") - return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") - - @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" @@ -167,9 +161,12 @@ class Predictor(ABC): report_rate: float = attr.ib(default=2.0, kw_only=True) model_paths: List[str] = attr.ib(factory=list, kw_only=True) - @property + @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" + if self.report_rate <= 0: + logger.warning("report_rate must be positive, fallback to 1") + return 1.0 return 1.0 / self.report_rate @classmethod @@ -360,7 +357,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: ensure_rgb=(not self.is_grayscale), ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -374,6 +371,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: def _initialize_inference_model(self): pass + def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Run prediction model on batch. + + This method handles running inference on a batch and postprocessing. + + Args: + ex: a dictionary holding the input for inference. + + Returns: + The input dictionary updated with the predictions. + """ + # Skip inference if model is not loaded + if self.inference_model is None: + return ex + + # Run inference on current batch. + preds = self.inference_model.predict_on_batch(ex, numpy=True) + + # Add model outputs to the input data example. + ex.update(preds) + + # Convert to numpy arrays if not already. + if isinstance(ex["video_ind"], tf.Tensor): + ex["video_ind"] = ex["video_ind"].numpy().flatten() + if isinstance(ex["frame_ind"], tf.Tensor): + ex["frame_ind"] = ex["frame_ind"].numpy().flatten() + + # Adjust for potential SizeMatcher scaling. + offset_x = ex.get("offset_x", 0) + offset_y = ex.get("offset_y", 0) + ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) + ex["instance_peaks"] /= np.expand_dims( + np.expand_dims(ex["scale"], axis=1), axis=1 + ) + + return ex + + def _run_batch_json( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + max_length: int = 30, + ) -> Iterator[Dict[str, np.ndarray]]: + n_processed = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in examples: + # Process batch of examples. + ex = self._process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + if time() > last_report + self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + # Return results. + yield ex + + def _run_batch_rich( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + ) -> Iterator[Dict[str, np.ndarray]]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=n_total) + last_report = time() + for ex in examples: + ex = self._process_batch(ex) + + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + # Return results. + yield ex + def _predict_generator( self, data_provider: Provider ) -> Iterator[Dict[str, np.ndarray]]: @@ -395,103 +508,22 @@ def _predict_generator( if self.inference_model is None: self._initialize_inference_model() - def process_batch(ex): - # Run inference on current batch. - preds = self.inference_model.predict_on_batch(ex, numpy=True) - - # Add model outputs to the input data example. - ex.update(preds) - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - # Adjust for potential SizeMatcher scaling. - offset_x = ex.get("offset_x", 0) - offset_y = ex.get("offset_y", 0) - ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) - ex["instance_peaks"] /= np.expand_dims( - np.expand_dims(ex["scale"], axis=1), axis=1 - ) - - return ex + # Compile loop examples before starting time to improve ETA + n_total = len(data_provider) + examples = self.pipeline.make_dataset() # Loop over data batches with optional progress reporting. if self.verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Predicting...", total=len(data_provider)) - last_report = time() - for ex in self.pipeline.make_dataset(): - ex = process_batch(ex) - progress.update(task, advance=len(ex["frame_ind"])) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() - - # Return results. - yield ex + for ex in self._run_batch_rich(examples, n_total=n_total): + yield ex elif self.verbosity == "json": - n_processed = 0 - n_total = len(data_provider) - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for ex in self.pipeline.make_dataset(): - # Process batch of examples. - ex = process_batch(ex) - - # Track timing and progress. - elapsed_batch = time() - t0_batch - t0_batch = time() - n_batch = len(ex["frame_ind"]) - n_processed += n_batch - elapsed_all = time() - t0_all - - # Compute recent rate. - n_recent.append(n_batch) - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - # Report. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() - - # Return results. + for ex in self._run_batch_json(examples, n_total=n_total): yield ex + else: - for ex in self.pipeline.make_dataset(): - yield process_batch(ex) + for ex in examples: + yield self._process_batch(ex) def predict( self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True @@ -582,7 +614,7 @@ def export_model( ) + (keras_model_shape[3],) tracing_batch = np.zeros((1,) + sample_shape, dtype="uint8") - outputs = self.inference_model.predict(tracing_batch) + _ = self.inference_model.predict(tracing_batch) self.inference_model.export_model( save_path, signatures, save_traces, model_name, tensors, unrag_outputs @@ -2535,7 +2567,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: skeletons=self.confmap_config.data.labels.skeletons, ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -4387,13 +4419,13 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: if self.centroid_model is None: anchor_part = self.confmap_config.data.instance_cropping.center_on_part - pipeline += sleap.nn.data.pipelines.InstanceCentroidFinder( + pipeline += InstanceCentroidFinder( center_on_anchor_part=anchor_part is not None, anchor_part_names=anchor_part, skeletons=self.confmap_config.data.labels.skeletons, ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -4615,7 +4647,7 @@ def __init__(self, model_name="lightning"): ) def call(self, ex): - if type(ex) == dict: + if isinstance(ex, dict): img = ex["image"] else: @@ -5390,7 +5422,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: ) ) else: - provider_list.append(LabelsReader(labels)) + provider_list.append( + LabelsReader(labels, example_indices=frame_list(args.frames)) + ) data_path_list.append(file_path) @@ -5459,7 +5493,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: max_instances=args.max_instances, ) - if type(predictor) == BottomUpPredictor: + if isinstance(predictor, BottomUpPredictor): predictor.inference_model.bottomup_layer.paf_scorer.max_edge_length_ratio = ( args.max_edge_length_ratio ) @@ -5484,7 +5518,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: """ policy_args = make_scoped_dictionary(vars(args), exclude_nones=True) if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + tracker = Tracker.make_tracker_by_name( + progress_reporting=args.verbosity, + **policy_args["tracking"], + ) return tracker return None @@ -5568,7 +5605,6 @@ def main(args: Optional[list] = None): # Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run) if args.models is not None: - # Run inference on all files inputed for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)): # Setup models. @@ -5651,14 +5687,16 @@ def main(args: Optional[list] = None): data_path = data_path_list[0] # Load predictions - data_path = args.data_path print("Loading predictions...") - labels_pr = sleap.load_file(data_path) + labels_pr = sleap.load_file(data_path.as_posix()) frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + if provider.example_indices is not None: + # Convert indices to a set to search in O(1), otherwise it is much slower + index_set = set(provider.example_indices) + frames = list(filter(lambda lf: lf.frame_idx in index_set, frames)) print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) labels_pr = Labels(labeled_frames=frames) @@ -5679,7 +5717,7 @@ def main(args: Optional[list] = None): labels_pr.provenance["sleap_version"] = sleap.__version__ labels_pr.provenance["platform"] = platform.platform() labels_pr.provenance["command"] = " ".join(sys.argv) - labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["data_path"] = os.fspath(data_path) labels_pr.provenance["output_path"] = output_path labels_pr.provenance["total_elapsed"] = total_elapsed labels_pr.provenance["start_timestamp"] = start_timestamp diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..55170ea36 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,12 +1,18 @@ """Tracking tools for linking grouped instances over time.""" -from collections import deque, defaultdict import abc +import functools +import json +import logging +import sys +from collections import deque +from time import time +from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple + import attr -import numpy as np import cv2 -import functools -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +import numpy as np +import rich.progress from sleap import Track, LabeledFrame, Skeleton @@ -26,8 +32,16 @@ Match, ) from sleap.nn.tracker.kalman import BareKalmanTracker - from sleap.nn.data.normalization import ensure_int +from sleap.util import RateColumn + +if sys.version_info >= (3, 8): + from functools import cached_property + +else: # cached_property is defined only for python >=3.8 + cached_property = property + +logger = logging.getLogger(__name__) @attr.s(eq=False, slots=True, auto_attribs=True) @@ -66,7 +80,6 @@ def from_instance( shift_score: float = 0.0, with_skeleton: bool = False, ): - points_array = new_points_array if points_array is None: points_array = ref_instance.points_array @@ -511,14 +524,162 @@ def get_candidates( class BaseTracker(abc.ABC): """Abstract base class for tracker.""" + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + kw_only=True, + ) + report_rate: float = attr.ib(default=2.0, kw_only=True) + @property def is_valid(self): return False + @cached_property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + if self.report_rate <= 0: + logger.warning("report_rate must be positive, fallback to 1") + return 1.0 + return 1.0 / self.report_rate + + def run_step(self, lf: LabeledFrame) -> LabeledFrame: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, t=lf.frame_idx) + if self.uses_image: + track_args["img"] = lf.video[lf.frame_idx] + else: + track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] + + return LabeledFrame( + frame_idx=lf.frame_idx, + video=lf.video, + instances=self.track(**track_args), + ) + + def _run_tracker_json( + self, + frames: List[LabeledFrame], + max_length: int = 30, + ) -> Iterator[LabeledFrame]: + n_total = len(frames) + n_processed = 0 + n_batch = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + + for lf in frames: + new_lf = self.run_step(lf) + + # Track timing and progress + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report + if time() > last_report + self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + yield new_lf + + def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + yield new_lf + + def run_tracker( + self, + frames: List[LabeledFrame], + *, + verbosity: Optional[str] = None, + final_pass: bool = True, + ) -> List[LabeledFrame]: + """Run the tracker on a set of labeled frames. + + Args: + frames: A list of labeled frames with instances. + + Returns: + The input frames with the new tracks assigned. If the frames already had tracks, + they will be cleared if the tracker has been re-initialized. + """ + # Return original frames if we aren't retracking + if not self.is_valid: + return frames + + verbosity = verbosity or self.verbosity + + # Run tracking on every frame + if verbosity == "rich": + new_lfs = list(self._run_tracker_rich(frames)) + + elif verbosity == "json": + new_lfs = list(self._run_tracker_json(frames)) + + else: + new_lfs = list(self.run_step(lf) for lf in frames) + + # Run final_pass + if final_pass: + self.final_pass(new_lfs) + + return new_lfs + @abc.abstractmethod def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ): @@ -564,6 +725,12 @@ class Tracker(BaseTracker): use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. max_tracking: Max tracking is incorporated when this is set to true. + verbosity: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during tracking -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. """ max_tracks: int = None @@ -670,7 +837,6 @@ def track( if t is None: if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: - # Default to last timestep + 1 if available. # Here we find the track that has the most instances. track_with_max_instances = max( @@ -686,7 +852,6 @@ def track( t = 0 else: if len(self.track_matching_queue) > 0: - # Default to last timestep + 1 if available. t = self.track_matching_queue[-1].t + 1 @@ -701,7 +866,6 @@ def track( # Process untracked instances. if untracked_instances: - if self.pre_cull_function: self.pre_cull_function(untracked_instances) @@ -791,7 +955,6 @@ def spawn_for_untracked_instances( ) -> List[InstanceType]: results = [] for inst in unmatched_instances: - # Skip if this instance is too small to spawn a new track with. if inst.n_visible_points < self.min_new_track_points: continue @@ -868,6 +1031,8 @@ def make_tracker_by_name( oks_errors: Optional[list] = None, oks_score_weighting: bool = False, oks_normalization: str = "all", + progress_reporting: str = "rich", + report_rate: float = 2.0, **kwargs, ) -> BaseTracker: # Parse max_tracking arguments, only True if max_tracks is not None and > 0 @@ -942,6 +1107,8 @@ def pre_cull_function(inst_list): max_tracks=max_tracks, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, + verbosity=progress_reporting, + report_rate=report_rate, ) if target_instance_count and kf_init_frame_count: @@ -961,7 +1128,6 @@ def pre_cull_function(inst_list): @classmethod def get_by_name_factory_options(cls): - options = [] option = dict(name="tracker", default="None") @@ -1080,7 +1246,8 @@ def get_by_name_factory_options(cls): option["type"] = int option["help"] = ( "If non-zero and tracking.tracker is set to flow, save the shifted " - "instances between elapsed frames" + "instances between elapsed frames. It uses a bit more of memory but gives " + "better instance matches." ) options.append(option) @@ -1094,9 +1261,10 @@ def int_list_func(s): option = dict(name="kf_init_frame_count", default="0") option["type"] = int - option[ - "help" - ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." + option["help"] = ( + "For Kalman filter: Number of frames to track with other tracker. " + "0 means no Kalman filters will be used." + ) options.append(option) def float_list_func(s): @@ -1117,9 +1285,10 @@ def float_list_func(s): option = dict(name="oks_score_weighting", default="0") option["type"] = int option["help"] = ( - "For Object Keypoint similarity: if 0 (default), only the distance between the reference " - "and query keypoint is used to compute the similarity. If 1, each distance is weighted " - "by the prediction scores of the reference and query keypoint." + "For Object Keypoint similarity: if 0 (default), only the distance " + "between the reference and query keypoint is used to compute the " + "similarity. If 1, each distance is weighted by the prediction scores " + "of the reference and query keypoint." ) options.append(option) @@ -1127,10 +1296,10 @@ def float_list_func(s): option["type"] = str option["options"] = ["all", "ref", "union"] option["help"] = ( - "For Object Keypoint similarity: Determine how to normalize similarity score. " + "Object Keypoint similarity: Determine how to normalize similarity score. " "If 'all', similarity score is normalized by number of reference points. " - "If 'ref', similarity score is normalized by number of visible reference points. " - "If 'union', similarity score is normalized by number of points both visible " + "If 'ref', score is normalized by number of visible reference points. " + "If 'union', score is normalized by number of points both visible " "in query and reference instance." ) options.append(option) @@ -1150,11 +1319,21 @@ def add_cli_parser_args(cls, parser, arg_scope: str = ""): else: arg_name = arg["name"] - parser.add_argument( - f"--{arg_name}", - type=arg["type"], - help=help_string, - ) + if arg["name"] == "tracker": + # If default is defined for "tracking.tracker", we cannot detect + # mal-formed command line. + parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + ) + else: + parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + default=arg["default"], + ) @attr.s(auto_attribs=True) @@ -1230,7 +1409,6 @@ def add_frame_instances( # "usuable" instances—i.e., instances with the nodes that we'll track # using Kalman filters. elif frame_match.has_only_first_choice_matches: - good_instances = [ inst for inst in instances if self.is_usable_instance(inst) ] @@ -1384,6 +1562,7 @@ def cull_function(inst_list): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: @@ -1444,7 +1623,6 @@ def track( # Check whether we've been getting good results from the Kalman filters. # First, has it been a while since the filters were initialized? if self.init_done and (t - self.last_init_t) > self.re_init_cooldown: - # If it's been a while, then see if it's also been a while since # the filters successfully matched tracks to the instances. if self.kalman_tracker.last_frame_with_tracks < t - self.re_init_after: @@ -1501,47 +1679,6 @@ def run(self, frames: List[LabeledFrame]): connect_single_track_breaks(frames, self.instance_count) -def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[LabeledFrame]: - """Run a tracker on a set of labeled frames. - - Args: - frames: A list of labeled frames with instances. - tracker: An initialized Tracker. - - Returns: - The input frames with the new tracks assigned. If the frames already had tracks, - they will be cleared if the tracker has been re-initialized. - """ - # Return original frames if we aren't retracking - if not tracker.is_valid: - return frames - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if tracker.uses_image: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - track_args["img_hw"] = lf.image.shape[-3:-1] - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args), - ) - new_lfs.append(new_lf) - - return new_lfs - - def retrack(): import argparse import operator @@ -1579,8 +1716,7 @@ def retrack(): print(f"Done loading predictions in {time.time() - t0} seconds.") print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) new_labels = Labels(labeled_frames=frames) diff --git a/sleap/util.py b/sleap/util.py index bc3389b7d..e4d0c1eb7 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -1,32 +1,49 @@ -"""A miscellaneous set of utility functions. +"""A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ +from __future__ import annotations + import json import os import re import shutil from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Hashable, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse from urllib.request import url2pathname +try: + from importlib.resources import files # New in 3.9+ +except ImportError: + from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. + import attr import h5py as h5 import numpy as np import psutil import rapidjson +import rich.progress import yaml -try: - from importlib.resources import files # New in 3.9+ -except ImportError: - from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. - import sleap.version as sleap_version +if TYPE_CHECKING: + from rich.progress import Task + + +class RateColumn(rich.progress.ProgressColumn): + """Renders the progress rate.""" + + def render(self, task: Task) -> rich.progress.Text: + """Show progress rate.""" + speed = task.speed + if speed is None: + return rich.progress.Text("?", style="progress.data.speed") + return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 0c7ba2b0a..fa0cc5f51 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,29 +9,26 @@ FrameMatches, greedy_matching, ) -from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton -def tracker_by_name(frames=None, **kwargs): - t = Tracker.make_tracker_by_name(**kwargs) - print(kwargs) - print(t.candidate_maker) - if frames is None: - t.track([]) - t.final_pass([]) - return +def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs): + # Create tracker + t = Tracker.make_tracker_by_name(verbosity="none", **kwargs) + # Update img_scale + if img_scale: + if hasattr(t, "candidate_maker") and hasattr(t.candidate_maker, "img_scale"): + t.candidate_maker.img_scale = img_scale + else: + # Do not even run tracking as it can be slow + pytest.skip("img_scale is not defined for this tracker") + return - for lf in frames: - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) - t.track(**track_args, img_hw=(1, 1)) - t.final_pass(frames) + # Run tracking + new_frames = t.run_tracker(frames or []) + assert len(new_frames) == len(frames) @pytest.mark.parametrize( @@ -42,22 +39,25 @@ def tracker_by_name(frames=None, **kwargs): ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"], ) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) +@pytest.mark.parametrize("img_scale", [0, 1, 0.25]) @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name( centered_pair_predictions_sorted, tracker, similarity, match, + img_scale, count, ): # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity=similarity, match=match, + img_scale=img_scale, max_tracks=count, ) @@ -76,7 +76,7 @@ def test_oks_tracker_by_name( # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity="object_keypoint", diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..c479462f8 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -1,12 +1,12 @@ -import inspect import operator import os import time +from pathlib import Path import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components -from sleap.io.dataset import Labels, LabeledFrame +from sleap.io.dataset import Labels def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): @@ -19,7 +19,7 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): inference_cli(cli.split(" ")) labels = sleap.load_file(f"{tmpdir}/simpletracks.slp") - assert len(labels.tracks) == 27 + assert len(labels.tracks) == 8 def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): @@ -37,18 +37,19 @@ def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): # TODO: Refactor the below things into a real test suite. +# running an equivalent to `make_ground_truth` is done as a test in tests/nn/test_tracker_components.py def make_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{len(tracker.spawned_tracks)}\t{time.time()-t0}") Labels.save_file(new_labels, gt_filename) def compare_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{time.time() - t0}") does_match = check_tracks(new_labels, gt_filename) @@ -78,43 +79,6 @@ def check_tracks(labels, gt_filename, limit=None): return True -def run_tracker(frames, tracker): - sig = inspect.signature(tracker.track) - takes_img = "img" in sig.parameters - - # t0 = time.time() - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if takes_img: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args, img_hw=lf.image.shape[-3:-1]), - ) - new_lfs.append(new_lf) - - # if lf.frame_idx % 100 == 0: print(lf.frame_idx, time.time()-t0) - - # print(time.time() - t0) - - new_labels = Labels() - new_labels.extend(new_lfs) - return new_labels - - def main(f, dir): filename = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -166,7 +130,9 @@ def make_tracker( return tracker def make_filename(tracker_name, matcher_name, sim_name, scale=0): - return f"{dir}{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5" + return Path(dir).joinpath( + f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5" + ) def make_tracker_and_filename(*args, **kwargs): tracker = make_tracker(*args, **kwargs) @@ -180,7 +146,6 @@ def make_tracker_and_filename(*args, **kwargs): for tracker_name in trackers.keys(): for matcher_name in matchers.keys(): for sim_name in similarities.keys(): - if tracker_name == "flow": # If this tracker supports scale, try multiple scales for scale in scales: