diff --git a/docs/_static/bonsai-connection.jpg b/docs/_static/bonsai-connection.jpg new file mode 100644 index 000000000..32b725416 Binary files /dev/null and b/docs/_static/bonsai-connection.jpg differ diff --git a/docs/_static/bonsai-filecapture.jpg b/docs/_static/bonsai-filecapture.jpg new file mode 100644 index 000000000..7a809d67a Binary files /dev/null and b/docs/_static/bonsai-filecapture.jpg differ diff --git a/docs/_static/bonsai-predictcentroids.jpg b/docs/_static/bonsai-predictcentroids.jpg new file mode 100644 index 000000000..e284f2338 Binary files /dev/null and b/docs/_static/bonsai-predictcentroids.jpg differ diff --git a/docs/_static/bonsai-predictposeidentities.jpg b/docs/_static/bonsai-predictposeidentities.jpg new file mode 100644 index 000000000..8582fd707 Binary files /dev/null and b/docs/_static/bonsai-predictposeidentities.jpg differ diff --git a/docs/_static/bonsai-predictposes.jpg b/docs/_static/bonsai-predictposes.jpg new file mode 100644 index 000000000..2e4f04a22 Binary files /dev/null and b/docs/_static/bonsai-predictposes.jpg differ diff --git a/docs/_static/bonsai-workflow.jpg b/docs/_static/bonsai-workflow.jpg new file mode 100644 index 000000000..0481c3dcf Binary files /dev/null and b/docs/_static/bonsai-workflow.jpg differ diff --git a/docs/guides/bonsai.md b/docs/guides/bonsai.md new file mode 100644 index 000000000..d262873b6 --- /dev/null +++ b/docs/guides/bonsai.md @@ -0,0 +1,75 @@ +(bonsai)= + +# Using Bonsai with SLEAP + +Bonsai is a visual language for reactive programming and currently supports SLEAP models. + +:::{note} +Currently Bonsai supports only single instance, top-down and top-down-id SLEAP models. +::: + +### Exporting a SLEAP trained model + +Before we can import a trained model into Bonsai, we need to use the {code}`sleap-export` command to convert the model to a format supported by Bonsai. For example, to export a top-down-id model, the command is as follows: + +```bash +sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path +``` + +Please refer to the {ref}`sleap-export` docs for more details on using the command. + +This will generate the necessary `.pb` file and other information files required by Bonsai. In this example, these files were saved to the specified `exported/model/path` folder. + +The `exported/model/path` folder will have a structure like the following: + +```plaintext +exported/model/path +├── centroid_config.json +├── confmap_config.json +├── frozen_graph.pb +└── info.json +``` + +### Installing Bonsai and necessary packages + +1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html). + +2. Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information. + +### Using Bonsai SLEAP modules + +Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application. The workflow must have a source module `FileCapture` which can be found in the toolbox search in the workflow editor. Provide the path to the video that was used to train the SLEAP model in the `FileName` field of the module. + +![Bonsai FileCapture module](../_static/bonsai-filecapture.jpg) + +#### Top-down model +The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules. + +The `PredictCentroids` module will predict the centroids of detections. There are two fields inside the `PredictCentroids` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + +![Bonsai PredictCentroids module](../_static/bonsai-predictcentroids.jpg) + +The `PredictPoses` module will predict the instances of detections. Similar to the `PredictCentroid` module, there are two fields inside the `PredictPoses` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + +![Bonsai PredictPoses module](../_static/bonsai-predictposes.jpg) + +#### Top-Down-ID model +The `PredictPoseIdentities` module will predict the instances with identities. This module has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + +![Bonsai PredictPoseIdentities module](../_static/bonsai-predictposeidentities.jpg) + +#### Single instance model +The `PredictSinglePose` module will predict the poses for single instance models. This module also has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + +### Connecting the modules +Right-click on the `FileCapture` module and select **Create Connection**. Now click on the required SLEAP module to complete the connection. + +![Bonsai module connection ](../_static/bonsai-connection.jpg) + +Once it is done, the workflow in Bonsai will look something like the following: + +![Bonsai.SLEAP workflow](../_static/bonsai-workflow.jpg) + +Now you can click the green start button to run the workflow and you can add more modules to analyze and visualize the results in Bonsai. + +For more documentation on various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html). diff --git a/docs/guides/index.md b/docs/guides/index.md index 7eb55b2b2..6d773d9de 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -30,6 +30,10 @@ {ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**. +## SLEAP with Bonsai + +{ref}`bonsai` when you want to analyze the trained SLEAP model to visualize the poses, centroids and identities for further visual analysis. + ```{toctree} :hidden: true :maxdepth: 2 @@ -44,4 +48,5 @@ proofreading colab custom-training remote +bonsai ``` diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 2dbceb3b7..0ead06b1e 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -873,6 +873,8 @@ def new_instance_menu_action(): "Point Displacement (max)", "Primary Point Displacement (sum)", "Primary Point Displacement (max)", + "Tracking Score (mean)", + "Tracking Score (min)", "Instance Score (sum)", "Instance Score (min)", "Point Score (sum)", @@ -1406,6 +1408,8 @@ def _set_seekbar_header(self, graph_name: str): "Point Displacement (max)": data_obj.get_point_displacement_series, "Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series, "Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series, + "Tracking Score (mean)": data_obj.get_tracking_score_series, + "Tracking Score (min)": data_obj.get_tracking_score_series, "Instance Score (sum)": data_obj.get_instance_score_series, "Instance Score (min)": data_obj.get_instance_score_series, "Point Score (sum)": data_obj.get_point_score_series, @@ -1419,7 +1423,7 @@ def _set_seekbar_header(self, graph_name: str): else: if graph_name in header_functions: kwargs = dict(video=self.state["video"]) - reduction_name = re.search("\\((sum|max|min)\\)", graph_name) + reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name) if reduction_name is not None: kwargs["reduction"] = reduction_name.group(1) series = header_functions[graph_name](**kwargs) diff --git a/sleap/info/summary.py b/sleap/info/summary.py index c6a6af60e..0cad1617e 100644 --- a/sleap/info/summary.py +++ b/sleap/info/summary.py @@ -21,7 +21,7 @@ class StatisticSeries: are frame index and value are some numerical value for the frame. Args: - labels: The :class:`Labels` for which to calculate series. + labels: The `Labels` for which to calculate series. """ labels: Labels @@ -41,7 +41,7 @@ def get_point_score_series( """Get series with statistic of point scores in each frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]: """Get series with statistic of instance scores in each frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo same track) from the closest earlier labeled frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -121,7 +121,7 @@ def get_primary_point_displacement_series( Get sum of displacement for single node of each instance per frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -226,7 +226,7 @@ def _calculate_frame_velocity( Calculate total point displacement between two given frames. Args: - lf: The :class:`LabeledFrame` for which we want velocity + lf: The `LabeledFrame` for which we want velocity last_lf: The frame from which to calculate displacement. reduce_function: Numpy function (e.g., np.sum, np.nanmean) is applied to *point* displacement, and then those @@ -246,3 +246,35 @@ def _calculate_frame_velocity( inst_dist = reduce_function(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 return val + + def get_tracking_score_series( + self, video: Video, reduction: str = "min" + ) -> Dict[int, float]: + """Get series with statistic of tracking scores in each frame. + + Args: + video: The `Video` for which to calculate statistic. + reduction: name of function applied to scores: + * mean + * min + + Returns: + The series dictionary (see class docs for details) + """ + reduce_fn = { + "min": np.nanmin, + "mean": np.nanmean, + }[reduction] + + series = dict() + + for lf in self.labels.find(video): + vals = [ + inst.tracking_score for inst in lf if hasattr(inst, "tracking_score") + ] + if vals: + val = reduce_fn(vals) + if not np.isnan(val): + series[lf.frame_idx] = val + + return series diff --git a/sleap/instance.py b/sleap/instance.py index 08a5c6ae6..382ececf2 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray: return self.points_and_scores_array[:, 2] @classmethod - def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": + def from_instance( + cls, instance: Instance, score: float, tracking_score: float = 0.0 + ) -> "PredictedInstance": """Create a `PredictedInstance` from an `Instance`. The fields are copied in a shallow manner with the exception of points. For each @@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": Args: instance: The `Instance` object to shallow copy data from. score: The score for this instance. + tracking_score: The tracking score for this instance. Returns: A `PredictedInstance` for the given `Instance`. @@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": ) kw_args["points"] = PredictedPointArray.from_array(instance._points) kw_args["score"] = score + kw_args["tracking_score"] = tracking_score return cls(**kw_args) @classmethod @@ -1080,6 +1084,7 @@ def from_arrays( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1094,6 +1099,7 @@ def from_arrays( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. @@ -1114,6 +1120,7 @@ def from_arrays( skeleton=skeleton, score=instance_score, track=track, + tracking_score=tracking_score, ) @classmethod @@ -1124,6 +1131,7 @@ def from_pointsarray( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1138,12 +1146,18 @@ def from_pointsarray( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. """ return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) @classmethod @@ -1154,6 +1168,7 @@ def from_numpy( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1168,12 +1183,18 @@ def from_numpy( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. """ return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14e0d5c6f..3f01a1c3c 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -3778,9 +3778,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=np.nanmean(confs), skeleton=skeleton, track=track, + tracking_score=np.nanmean(score), ) ) @@ -4452,18 +4453,27 @@ def _object_builder(): break # Loop over frames. - for image, video_ind, frame_ind, points, confidences, scores in zip( + for ( + image, + video_ind, + frame_ind, + centroid_vals, + points, + confidences, + scores, + ) in zip( ex["image"], ex["video_ind"], ex["frame_ind"], + ex["centroid_vals"], ex["instance_peaks"], ex["instance_peak_vals"], ex["instance_scores"], ): # Loop over instances. predicted_instances = [] - for i, (pts, confs, score) in enumerate( - zip(points, confidences, scores) + for i, (pts, centroid_val, confs, score) in enumerate( + zip(points, centroid_vals, confidences, scores) ): if np.isnan(pts).all(): continue @@ -4474,9 +4484,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=centroid_val, skeleton=skeleton, track=track, + tracking_score=score, ) ) diff --git a/tests/data/tracks/clip.predictions.slp b/tests/data/tracks/clip.predictions.slp new file mode 100644 index 000000000..652e21302 Binary files /dev/null and b/tests/data/tracks/clip.predictions.slp differ diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index ec5dfbc29..c6507caec 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -97,6 +97,20 @@ def min_tracks_2node_labels(): ) +@pytest.fixture +def min_tracks_2node_predictions(): + """ + Generated with: + ``` + sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4" + ``` + """ + return Labels.load_file( + "tests/data/tracks/clip.predictions.slp", + video_search=["tests/data/tracks/clip.mp4"], + ) + + @pytest.fixture def min_tracks_13node_labels(): return Labels.load_file( diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py index 2cf76c166..672d97e63 100644 --- a/tests/info/test_summary.py +++ b/tests/info/test_summary.py @@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions): x = stats.get_point_displacement_series(video, "max") assert len(x) == 2 - assert len(x) == 2 assert x[0] == 0 assert x[1] == 18.0 + + +def test_get_tracking_score_series(min_tracks_2node_predictions): + + stats = StatisticSeries(min_tracks_2node_predictions) + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min") + assert len(x) == 1500 + assert x[0] == 0.9999966621398926 + assert x[1000] == 0.9998022317886353 + + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean") + assert len(x) == 1500 + assert x[0] == 0.9999983310699463 + assert x[1000] == 0.9999011158943176