Skip to content

Commit

Permalink
add an only_predicted_instances option to include user-defined instan…
Browse files Browse the repository at this point in the history
…ces in the tracking
  • Loading branch information
getzze committed Dec 6, 2024
1 parent a520942 commit b9471f7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
1 change: 1 addition & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def make_predict_cli_call(
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
"tracking.only_predicted_instances",
)

for key in bool_items_as_ints:
Expand Down
19 changes: 19 additions & 0 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,25 @@ def from_numpy(
)


def convert_to_predicted_instance(
inst: Union[Instance, PredictedInstance],
*,
score: float = 1.0,
tracking_score: float = 0.0,
) -> PredictedInstance:
"""Convert an Instance to a PredictedInstance, if it's not one already.
Score is by default 1.0, like a user-defined instance.
"""
if isinstance(inst, PredictedInstance):
return inst

kwargs = attr.asdict(inst)
kwargs["score"] = score
kwargs["tracking_score"] = tracking_score
return PredictedInstance(**kwargs)


def make_instance_cattr() -> cattr.Converter:
"""Create a cattr converter for Lists of Instances/PredictedInstances.
Expand Down
17 changes: 13 additions & 4 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple

from sleap import Track, LabeledFrame, Skeleton
from sleap.instance import convert_to_predicted_instance

from sleap.nn.tracker.components import (
factory_object_keypoint_similarity,
Expand All @@ -22,7 +23,6 @@
cull_frame_instances,
connect_single_track_breaks,
InstanceType,
PredictedInstance,
FrameMatches,
Match,
)
Expand Down Expand Up @@ -581,6 +581,7 @@ class Tracker(BaseTracker):
robust_best_instance: float = 1.0

min_new_track_points: int = 0
only_predicted_instances: bool = True

track_matching_queue: Deque[MatchedFrameInstances] = attr.ib()

Expand Down Expand Up @@ -843,6 +844,7 @@ def make_tracker_by_name(
robust: float = 1.0,
min_new_track_points: int = 0,
min_match_points: int = 0,
only_predicted_instances: bool = True,
# Optical flow options
img_scale: float = 1.0,
of_window_size: int = 21,
Expand Down Expand Up @@ -941,6 +943,7 @@ def pre_cull_function(inst_list):
max_tracks=max_tracks,
target_instance_count=target_instance_count,
post_connect_single_breaks=post_connect_single_breaks,
only_predicted_instances=only_predicted_instances,
)

if target_instance_count and kf_init_frame_count:
Expand Down Expand Up @@ -1057,6 +1060,11 @@ def get_by_name_factory_options(cls):
option["help"] = "Minimum points for match candidates"
options.append(option)

option = dict(name="only_predicted_instances", default=1)
option["type"] = int
option["help"] = "Track only predicted instances, not user-defined instances."
options.append(option)

option = dict(name="img_scale", default=1.0)
option["type"] = float
option["help"] = "For optical-flow: Image scale"
Expand Down Expand Up @@ -1518,9 +1526,10 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele
# Run tracking on every frame
for lf in frames:
# Use only the predicted instances
instances = [
inst for inst in lf.instances if isinstance(inst, PredictedInstance)
]
if tracker.only_predicted_instances:
instances = lf.predicted_instances
else:
instances = [convert_to_predicted_instance(inst) for inst in lf.instances]

# Clear the tracks
for inst in instances:
Expand Down

0 comments on commit b9471f7

Please sign in to comment.