Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into talmo/check-for-tracks-before-training-id…
Browse files Browse the repository at this point in the history
…-models
talmo authored Dec 16, 2024
2 parents d233bea + 0042cc2 commit e154749
Showing 15 changed files with 191 additions and 16 deletions.
Binary file added docs/_static/bonsai-connection.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bonsai-filecapture.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bonsai-predictcentroids.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bonsai-predictposeidentities.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bonsai-predictposes.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bonsai-workflow.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
75 changes: 75 additions & 0 deletions docs/guides/bonsai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
(bonsai)=

# Using Bonsai with SLEAP

Bonsai is a visual language for reactive programming and currently supports SLEAP models.

:::{note}
Currently Bonsai supports only single instance, top-down and top-down-id SLEAP models.
:::

### Exporting a SLEAP trained model

Before we can import a trained model into Bonsai, we need to use the {code}`sleap-export` command to convert the model to a format supported by Bonsai. For example, to export a top-down-id model, the command is as follows:

```bash
sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path
```

Please refer to the {ref}`sleap-export` docs for more details on using the command.

This will generate the necessary `.pb` file and other information files required by Bonsai. In this example, these files were saved to the specified `exported/model/path` folder.

The `exported/model/path` folder will have a structure like the following:

```plaintext
exported/model/path
├── centroid_config.json
├── confmap_config.json
├── frozen_graph.pb
└── info.json
```

### Installing Bonsai and necessary packages

1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html).

2. Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information.

### Using Bonsai SLEAP modules

Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application. The workflow must have a source module `FileCapture` which can be found in the toolbox search in the workflow editor. Provide the path to the video that was used to train the SLEAP model in the `FileName` field of the module.

![Bonsai FileCapture module](../_static/bonsai-filecapture.jpg)

#### Top-down model
The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules.

The `PredictCentroids` module will predict the centroids of detections. There are two fields inside the `PredictCentroids` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictCentroids module](../_static/bonsai-predictcentroids.jpg)

The `PredictPoses` module will predict the instances of detections. Similar to the `PredictCentroid` module, there are two fields inside the `PredictPoses` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictPoses module](../_static/bonsai-predictposes.jpg)

#### Top-Down-ID model
The `PredictPoseIdentities` module will predict the instances with identities. This module has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

![Bonsai PredictPoseIdentities module](../_static/bonsai-predictposeidentities.jpg)

#### Single instance model
The `PredictSinglePose` module will predict the poses for single instance models. This module also has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder.

### Connecting the modules
Right-click on the `FileCapture` module and select **Create Connection**. Now click on the required SLEAP module to complete the connection.

![Bonsai module connection ](../_static/bonsai-connection.jpg)

Once it is done, the workflow in Bonsai will look something like the following:

![Bonsai.SLEAP workflow](../_static/bonsai-workflow.jpg)

Now you can click the green start button to run the workflow and you can add more modules to analyze and visualize the results in Bonsai.

For more documentation on various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html).
5 changes: 5 additions & 0 deletions docs/guides/index.md
Original file line number Diff line number Diff line change
@@ -30,6 +30,10 @@

{ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**.

## SLEAP with Bonsai

{ref}`bonsai` when you want to analyze the trained SLEAP model to visualize the poses, centroids and identities for further visual analysis.

```{toctree}
:hidden: true
:maxdepth: 2
@@ -44,4 +48,5 @@ proofreading
colab
custom-training
remote
bonsai
```
6 changes: 5 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
@@ -873,6 +873,8 @@ def new_instance_menu_action():
"Point Displacement (max)",
"Primary Point Displacement (sum)",
"Primary Point Displacement (max)",
"Tracking Score (mean)",
"Tracking Score (min)",
"Instance Score (sum)",
"Instance Score (min)",
"Point Score (sum)",
@@ -1406,6 +1408,8 @@ def _set_seekbar_header(self, graph_name: str):
"Point Displacement (max)": data_obj.get_point_displacement_series,
"Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series,
"Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series,
"Tracking Score (mean)": data_obj.get_tracking_score_series,
"Tracking Score (min)": data_obj.get_tracking_score_series,
"Instance Score (sum)": data_obj.get_instance_score_series,
"Instance Score (min)": data_obj.get_instance_score_series,
"Point Score (sum)": data_obj.get_point_score_series,
@@ -1419,7 +1423,7 @@ def _set_seekbar_header(self, graph_name: str):
else:
if graph_name in header_functions:
kwargs = dict(video=self.state["video"])
reduction_name = re.search("\\((sum|max|min)\\)", graph_name)
reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name)
if reduction_name is not None:
kwargs["reduction"] = reduction_name.group(1)
series = header_functions[graph_name](**kwargs)
44 changes: 38 additions & 6 deletions sleap/info/summary.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ class StatisticSeries:
are frame index and value are some numerical value for the frame.
Args:
labels: The :class:`Labels` for which to calculate series.
labels: The `Labels` for which to calculate series.
"""

labels: Labels
@@ -41,7 +41,7 @@ def get_point_score_series(
"""Get series with statistic of point scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
@@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]:
"""Get series with statistic of instance scores in each frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* sum
* min
@@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo
same track) from the closest earlier labeled frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
@@ -121,7 +121,7 @@ def get_primary_point_displacement_series(
Get sum of displacement for single node of each instance per frame.
Args:
video: The :class:`Video` for which to calculate statistic.
video: The `Video` for which to calculate statistic.
reduction: name of function applied to point scores:
* sum
* mean
@@ -226,7 +226,7 @@ def _calculate_frame_velocity(
Calculate total point displacement between two given frames.
Args:
lf: The :class:`LabeledFrame` for which we want velocity
lf: The `LabeledFrame` for which we want velocity
last_lf: The frame from which to calculate displacement.
reduce_function: Numpy function (e.g., np.sum, np.nanmean)
is applied to *point* displacement, and then those
@@ -246,3 +246,35 @@ def _calculate_frame_velocity(
inst_dist = reduce_function(point_dist)
val += inst_dist if not np.isnan(inst_dist) else 0
return val

def get_tracking_score_series(
self, video: Video, reduction: str = "min"
) -> Dict[int, float]:
"""Get series with statistic of tracking scores in each frame.
Args:
video: The `Video` for which to calculate statistic.
reduction: name of function applied to scores:
* mean
* min
Returns:
The series dictionary (see class docs for details)
"""
reduce_fn = {
"min": np.nanmin,
"mean": np.nanmean,
}[reduction]

series = dict()

for lf in self.labels.find(video):
vals = [
inst.tracking_score for inst in lf if hasattr(inst, "tracking_score")
]
if vals:
val = reduce_fn(vals)
if not np.isnan(val):
series[lf.frame_idx] = val

return series
27 changes: 24 additions & 3 deletions sleap/instance.py
Original file line number Diff line number Diff line change
@@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray:
return self.points_and_scores_array[:, 2]

@classmethod
def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
def from_instance(
cls, instance: Instance, score: float, tracking_score: float = 0.0
) -> "PredictedInstance":
"""Create a `PredictedInstance` from an `Instance`.
The fields are copied in a shallow manner with the exception of points. For each
@@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
Args:
instance: The `Instance` object to shallow copy data from.
score: The score for this instance.
tracking_score: The tracking score for this instance.
Returns:
A `PredictedInstance` for the given `Instance`.
@@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance":
)
kw_args["points"] = PredictedPointArray.from_array(instance._points)
kw_args["score"] = score
kw_args["tracking_score"] = tracking_score
return cls(**kw_args)

@classmethod
@@ -1080,6 +1084,7 @@ def from_arrays(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1094,6 +1099,7 @@ def from_arrays(
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
@@ -1114,6 +1120,7 @@ def from_arrays(
skeleton=skeleton,
score=instance_score,
track=track,
tracking_score=tracking_score,
)

@classmethod
@@ -1124,6 +1131,7 @@ def from_pointsarray(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1138,12 +1146,18 @@ def from_pointsarray(
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
"""
return cls.from_arrays(
points, point_confidences, instance_score, skeleton, track=track
points,
point_confidences,
instance_score,
skeleton,
track=track,
tracking_score=tracking_score,
)

@classmethod
@@ -1154,6 +1168,7 @@ def from_numpy(
instance_score: float,
skeleton: Skeleton,
track: Optional[Track] = None,
tracking_score: float = 0.0,
) -> "PredictedInstance":
"""Create a predicted instance from data arrays.
@@ -1168,12 +1183,18 @@ def from_numpy(
skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the
predicted instance.
track: Optional `sleap.Track` to associate with the instance.
tracking_score: Optional float representing the track matching score.
Returns:
A new `PredictedInstance`.
"""
return cls.from_arrays(
points, point_confidences, instance_score, skeleton, track=track
points,
point_confidences,
instance_score,
skeleton,
track=track,
tracking_score=tracking_score,
)


21 changes: 16 additions & 5 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
@@ -3778,9 +3778,10 @@ def _object_builder():
PredictedInstance.from_numpy(
points=pts,
point_confidences=confs,
instance_score=np.nanmean(score),
instance_score=np.nanmean(confs),
skeleton=skeleton,
track=track,
tracking_score=np.nanmean(score),
)
)

@@ -4452,18 +4453,27 @@ def _object_builder():
break

# Loop over frames.
for image, video_ind, frame_ind, points, confidences, scores in zip(
for (
image,
video_ind,
frame_ind,
centroid_vals,
points,
confidences,
scores,
) in zip(
ex["image"],
ex["video_ind"],
ex["frame_ind"],
ex["centroid_vals"],
ex["instance_peaks"],
ex["instance_peak_vals"],
ex["instance_scores"],
):
# Loop over instances.
predicted_instances = []
for i, (pts, confs, score) in enumerate(
zip(points, confidences, scores)
for i, (pts, centroid_val, confs, score) in enumerate(
zip(points, centroid_vals, confidences, scores)
):
if np.isnan(pts).all():
continue
@@ -4474,9 +4484,10 @@ def _object_builder():
PredictedInstance.from_numpy(
points=pts,
point_confidences=confs,
instance_score=np.nanmean(score),
instance_score=centroid_val,
skeleton=skeleton,
track=track,
tracking_score=score,
)
)

Binary file added tests/data/tracks/clip.predictions.slp
Binary file not shown.
14 changes: 14 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,20 @@ def min_tracks_2node_labels():
)


@pytest.fixture
def min_tracks_2node_predictions():
"""
Generated with:
```
sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4"
```
"""
return Labels.load_file(
"tests/data/tracks/clip.predictions.slp",
video_search=["tests/data/tracks/clip.mp4"],
)


@pytest.fixture
def min_tracks_13node_labels():
return Labels.load_file(
15 changes: 14 additions & 1 deletion tests/info/test_summary.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions):

x = stats.get_point_displacement_series(video, "max")
assert len(x) == 2
assert len(x) == 2
assert x[0] == 0
assert x[1] == 18.0


def test_get_tracking_score_series(min_tracks_2node_predictions):

stats = StatisticSeries(min_tracks_2node_predictions)
x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min")
assert len(x) == 1500
assert x[0] == 0.9999966621398926
assert x[1000] == 0.9998022317886353

x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean")
assert len(x) == 1500
assert x[0] == 0.9999983310699463
assert x[1000] == 0.9999011158943176

0 comments on commit e154749

Please sign in to comment.