diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 108c2a65e..6a92c2e3b 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -25,15 +25,15 @@ Tell us a little about the system you're using. Please include information about how you installed. --> -- OS: +- OS: <!-- [e.g. ubuntu 20.04, macOS 11.0] --> -- Version(s): -<!-- e.g. [SLEAP v1.4.1a2, python 3.8] ---> -- SLEAP installation method (listed [here](https://sleap.ai/installation.html#)): - - [ ] [Conda from package](https://sleap.ai/installation.html#conda-package) - - [ ] [Conda from source](https://sleap.ai/installation.html#conda-from-source) - - [ ] [pip package](https://sleap.ai/installation.html#pip-package) - - [ ] [Apple Silicon Macs](https://sleap.ai/installation.html#apple-silicon-macs) +- Version(s): +<!-- e.g. [SLEAP v1.4.1, python 3.8] ---> +- SLEAP installation method (listed [here](https://sleap.ai/installation.html#)): + - [ ] [Conda from package](https://sleap.ai/installation.html#conda-package) + - [ ] [Conda from source](https://sleap.ai/installation.html#conda-from-source) + - [ ] [pip package](https://sleap.ai/installation.html#pip-package) + - [ ] [Apple Silicon Macs](https://sleap.ai/installation.html#apple-silicon-macs) <details><summary>Environment packages</summary> <!-- For reproduction, it's useful to have the full environment. For example, the output of `pip freeze` or `conda list` ---> diff --git a/.github/workflows/build_conda_ci.yml b/.github/workflows/build_conda_ci.yml index 0d5980730..3fd3d2b92 100644 --- a/.github/workflows/build_conda_ci.yml +++ b/.github/workflows/build_conda_ci.yml @@ -11,7 +11,7 @@ on: - "requirements.txt" - "dev_requirements.txt" - "environment_build.yml" - - ".github/workflows/build_conda_ci.yml" + - ".github/workflows/build_conda_ci.yml" # Run! # If RUN_BUILD_JOB is set to true, then RUN_ID will be overwritten to the current run id env: diff --git a/.github/workflows/build_pypi_ci.yml b/.github/workflows/build_pypi_ci.yml index c22cc5a69..68142b288 100644 --- a/.github/workflows/build_pypi_ci.yml +++ b/.github/workflows/build_pypi_ci.yml @@ -11,7 +11,7 @@ on: - "jupyter_requirements.txt" - "pypi_requirements.txt" - "environment_build.yml" - - ".github/workflows/build_pypi_ci.yml" + - ".github/workflows/build_pypi_ci.yml" # Run! jobs: build: diff --git a/.github/workflows/website.yml b/.github/workflows/website.yml index ede9eef9f..36c1d6ad7 100644 --- a/.github/workflows/website.yml +++ b/.github/workflows/website.yml @@ -7,8 +7,8 @@ on: branches: # 'main' triggers updates to 'sleap.ai', all others to 'sleap.ai/develop' - main - - develop - - liezl/update-intallation-docs-1.4.1 # again! + - develop # Run + - liezl/bump-to-1.4.1 paths: - "docs/**" - "README.rst" diff --git a/README.rst b/README.rst index 7cc9b27a3..f7a5acd6c 100644 --- a/README.rst +++ b/README.rst @@ -69,7 +69,7 @@ Quick install .. 
code-block:: bash - conda create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap + conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap `pip` **(any OS except Apple silicon)**: diff --git a/docs/_static/bonsai-connection.jpg b/docs/_static/bonsai-connection.jpg new file mode 100644 index 000000000..32b725416 Binary files /dev/null and b/docs/_static/bonsai-connection.jpg differ diff --git a/docs/_static/bonsai-filecapture.jpg b/docs/_static/bonsai-filecapture.jpg new file mode 100644 index 000000000..7a809d67a Binary files /dev/null and b/docs/_static/bonsai-filecapture.jpg differ diff --git a/docs/_static/bonsai-predictcentroids.jpg b/docs/_static/bonsai-predictcentroids.jpg new file mode 100644 index 000000000..e284f2338 Binary files /dev/null and b/docs/_static/bonsai-predictcentroids.jpg differ diff --git a/docs/_static/bonsai-predictposeidentities.jpg b/docs/_static/bonsai-predictposeidentities.jpg new file mode 100644 index 000000000..8582fd707 Binary files /dev/null and b/docs/_static/bonsai-predictposeidentities.jpg differ diff --git a/docs/_static/bonsai-predictposes.jpg b/docs/_static/bonsai-predictposes.jpg new file mode 100644 index 000000000..2e4f04a22 Binary files /dev/null and b/docs/_static/bonsai-predictposes.jpg differ diff --git a/docs/_static/bonsai-workflow.jpg b/docs/_static/bonsai-workflow.jpg new file mode 100644 index 000000000..0481c3dcf Binary files /dev/null and b/docs/_static/bonsai-workflow.jpg differ diff --git a/docs/conf.py b/docs/conf.py index 796497f6b..074869903 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,7 +28,7 @@ copyright = f"2019–{date.today().year}, Talmo Lab" # The short X.Y version -version = "1.4.1a2" +version = "1.4.1" # Get the sleap version # with open("../sleap/version.py") as f: @@ -36,7 +36,7 @@ # version = re.search("\d.+(?=['\"])", version_file).group(0) # Release should be the full branch name -release = "v1.4.1a2" +release = "v1.4.1" html_title = f"SLEAP ({release})" html_short_title = "SLEAP" diff --git a/docs/guides/bonsai.md b/docs/guides/bonsai.md new file mode 100644 index 000000000..d262873b6 --- /dev/null +++ b/docs/guides/bonsai.md @@ -0,0 +1,75 @@ +(bonsai)= + +# Using Bonsai with SLEAP + +Bonsai is a visual language for reactive programming and currently supports SLEAP models. + +:::{note} +Currently Bonsai supports only single instance, top-down and top-down-id SLEAP models. +::: + +### Exporting a SLEAP trained model + +Before we can import a trained model into Bonsai, we need to use the {code}`sleap-export` command to convert the model to a format supported by Bonsai. For example, to export a top-down-id model, the command is as follows: + +```bash +sleap-export -m centroid/model/folder/path -m top_down_id/model/folder/path -e exported/model/path +``` + +Please refer to the {ref}`sleap-export` docs for more details on using the command. + +This will generate the necessary `.pb` file and other information files required by Bonsai. In this example, these files were saved to the specified `exported/model/path` folder. + +The `exported/model/path` folder will have a structure like the following: + +```plaintext +exported/model/path +├── centroid_config.json +├── confmap_config.json +├── frozen_graph.pb +└── info.json +``` + +### Installing Bonsai and necessary packages + +1. Install Bonsai. See the [Bonsai installation instructions](https://bonsai-rx.org/docs/articles/installation.html). + +2. 
Download and add the necessary packages for Bonsai to run with SLEAP. See the official [Bonsai SLEAP documentation](https://github.com/bonsai-rx/sleap?tab=readme-ov-file#bonsai---sleap) for more information. + +### Using Bonsai SLEAP modules + +Once you have Bonsai installed with the required packages, you should be able to open the Bonsai application. The workflow must have a source module `FileCapture` which can be found in the toolbox search in the workflow editor. Provide the path to the video that was used to train the SLEAP model in the `FileName` field of the module. + + + +#### Top-down model +The top-down model requires both the `PredictCentroids` and the `PredictPoses` modules. + +The `PredictCentroids` module will predict the centroids of detections. There are two fields inside the `PredictCentroids` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centroid model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + + + +The `PredictPoses` module will predict the instances of detections. Similar to the `PredictCentroid` module, there are two fields inside the `PredictPoses` module: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the centered instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + + + +#### Top-Down-ID model +The `PredictPoseIdentities` module will predict the instances with identities. This module has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the top-down-id model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + + + +#### Single instance model +The `PredictSinglePose` module will predict the poses for single instance models. This module also has two fields: the `ModelFileName` field and the `TrainingConfig` field. The `TrainingConfig` field expects the path to the training config JSON file for the single instance model. The `ModelFileName` field expects the path to the `frozen_graph.pb` file in the `exported/model/path` folder. + +### Connecting the modules +Right-click on the `FileCapture` module and select **Create Connection**. Now click on the required SLEAP module to complete the connection. + + + +Once it is done, the workflow in Bonsai will look something like the following: + + + +Now you can click the green start button to run the workflow and you can add more modules to analyze and visualize the results in Bonsai. + +For more documentation on various modules and workflows, please refer to the [official Bonsai docs](https://bonsai-rx.org/docs/articles/editor.html). diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 134461c60..339c5405b 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -230,7 +230,7 @@ optional arguments: --tracking.kf_node_indices TRACKING.KF_NODE_INDICES For Kalman filter: Indices of nodes to track. (default: ) --tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT - For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) + For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. 
(default: 0) Kalman filters require TRACKING.KF_NODE_INDICES, TRACKING.MAX_TRACKING and TRACKING.MAX_TRACKS or TRACKING.TARGET_INSTANCE_COUNT, TRACKING.TRACKER to be simple or simplemaxtracks, and TRACKING.SIMILARITY to not be normalized_instance. ``` #### Examples: @@ -285,6 +285,12 @@ sleap-track --gpu 1 ... sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4" ``` +**9. Use Kalman tracker (not recommended since flow is preferred):** + +```none +sleap-track -m "models/my_model" --tracking.similarity instance --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 --tracking.kf_init_frame_count 10 --tracking.kf_node_indices 0,1 -o "output_predictions.slp" "input_video.mp4" +``` + ## Dataset files (sleap-convert)= diff --git a/docs/guides/index.md b/docs/guides/index.md index 7eb55b2b2..6d773d9de 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -30,6 +30,10 @@ {ref}`remote-inference` when you trained models and you want to run inference on a different machine using a **command-line interface**. +## SLEAP with Bonsai + +{ref}`bonsai` when you want to analyze the trained SLEAP model to visualize the poses, centroids and identities for further visual analysis. + ```{toctree} :hidden: true :maxdepth: 2 @@ -44,4 +48,5 @@ proofreading colab custom-training remote +bonsai ``` diff --git a/docs/installation.md b/docs/installation.md index 4799a0893..2c1ef41be 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -5,12 +5,12 @@ SLEAP can be installed as a Python package on Windows, Linux, and Mac OS. For qu ````{tabs} ```{group-tab} Windows and Linux ```bash - conda create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap=1.4.1a2 + conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap=1.4.1 ``` ``` ```{group-tab} Mac OS ```bash - conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1a2 + conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1 ``` ``` ```` @@ -27,7 +27,7 @@ local: Installation requires entering commands in a terminal. To open one: ````{tabs} ```{tab} Windows - Open the *Start menu* and search for the *Anaconda Prompt* (if using Miniconda) or the *Command Prompt* if not. + Open the *Start menu* and search for the *Anaconda Prompt* (if using Miniconda) or the *Command Prompt* if not. ```{note} On Windows, our personal preference is to use alternative terminal apps like [Cmder](https://cmder.net) or [Windows Terminal](https://aka.ms/terminal). ``` @@ -66,7 +66,6 @@ If you don't have a `conda` package manager installation, here are some quick in Miniforge is a minimal installer for conda that includes the `conda` package manager and is maintained by the [conda-forge](https://conda-forge.org) community. The only difference between Miniforge and Miniconda is that Miniforge uses the `conda-forge` channel by default, which provides a much wider selection of community-maintained packages. - ````{tabs} ```{group-tab} Windows Open a new PowerShell terminal (does not need to be admin) and enter: @@ -135,20 +134,20 @@ This is a minimal installer for conda that includes the `conda` package manager See the [Miniconda website](https://docs.anaconda.com/free/miniconda/) for up-to-date installation instructions if the above instructions don't work for your system. 
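A quick sanity check after installing Miniforge or Miniconda, before moving on to the installation methods below (a minimal sketch; it assumes the installer was allowed to initialize your shell):

```bash
# Open a new terminal so the installer's PATH changes take effect,
# then confirm the conda package manager responds.
conda --version
conda info --envs   # the "base" environment should be listed
```

If the command is not found, re-open the terminal or re-run the installer's `conda init` step and try again.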
- (installation-methods)= + ## Installation methods SLEAP can be installed three different ways: via {ref}`conda package<condapackage>`, {ref}`conda from source<condasource>`, or {ref}`pip package<pippackage>`. Select one of the methods below to install SLEAP. We recommend {ref}`conda package<condapackage>`. -````{tabs} +`````{tabs} ```{tab} conda package **This is the recommended installation method**. ````{tabs} ```{group-tab} Windows and Linux ```bash - conda create -y -n sleap -c conda-forge -c nvidia -c sleap -c anaconda sleap=1.4.1a2 - ``` + conda create -y -n sleap -c conda-forge -c nvidia -c sleap/label/dev -c sleap -c anaconda sleap=1.4.1 + ``` ```{note} - This comes with CUDA to enable GPU support. All you need is to have an NVIDIA GPU and [updated drivers](https://nvidia.com/drivers). - If you already have CUDA installed on your system, this will not conflict with it. @@ -157,7 +156,7 @@ SLEAP can be installed three different ways: via {ref}`conda package<condapackag ``` ```{group-tab} Mac OS ```bash - conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1a2 + conda create -y -n sleap -c conda-forge -c anaconda -c sleap sleap=1.4.1 ``` ```{note} This will also work in CPU mode if you don't have a GPU on your machine. @@ -222,7 +221,7 @@ SLEAP can be installed three different ways: via {ref}`conda package<condapackag ````{tabs} ```{group-tab} NVIDIA GPU ```bash - conda create --name sleap pip python=3.7.12 cudatoolkit=11.3 cudnn=8.2 + conda create --name sleap pip python=3.7.12 cudatoolkit=11.3 cudnn=8.2 -c conda-forge -c nvidia ``` ``` ```{group-tab} CPU or other GPU @@ -240,7 +239,7 @@ SLEAP can be installed three different ways: via {ref}`conda package<condapackag ``` 3. Finally, we can perform the `pip install`: ```bash - pip install sleap[pypi]==1.4.1a2 + pip install sleap[pypi]==1.4.1 ``` ```{note} The pypi distributed package of SLEAP ships with the following extras: @@ -256,7 +255,7 @@ SLEAP can be installed three different ways: via {ref}`conda package<condapackag ``` ```` ``` -```` +````` ## Testing that things are working diff --git a/sleap/config/frame_range_form.yaml b/sleap/config/frame_range_form.yaml new file mode 100644 index 000000000..3f01eade4 --- /dev/null +++ b/sleap/config/frame_range_form.yaml @@ -0,0 +1,13 @@ +main: + + - name: min_frame_idx + label: Minimum frame index + type: int + range: 1,1000000 + default: 1 + + - name: max_frame_idx + label: Maximum frame index + type: int + range: 1,1000000 + default: 1000 \ No newline at end of file diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 2dbceb3b7..8b711c806 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -656,17 +656,18 @@ def prev_vid(): key="edge style", ) + # XXX add_submenu_choices( menu=viewMenu, title="Node Marker Size", - options=(1, 2, 4, 6, 8, 12), + options=prefs["node marker sizes"], key="marker size", ) add_submenu_choices( menu=viewMenu, title="Node Label Size", - options=(6, 12, 18, 24, 36), + options=prefs["node label sizes"], key="node label size", ) @@ -804,6 +805,12 @@ def new_instance_menu_action(): "Delete Predictions beyond Max Instances...", self.commands.deleteInstanceLimitPredictions, ) + add_menu_item( + labelMenu, + "delete frame limit predictions", + "Delete Predictions beyond Frame Limit...", + self.commands.deleteFrameLimitPredictions, + ) ### Tracks Menu ### @@ -873,6 +880,8 @@ def new_instance_menu_action(): "Point Displacement (max)", "Primary Point Displacement (sum)", "Primary Point Displacement (max)", + "Tracking Score (mean)", + 
"Tracking Score (min)", "Instance Score (sum)", "Instance Score (min)", "Point Score (sum)", @@ -1331,7 +1340,7 @@ def updateStatusMessage(self, message: Optional[str] = None): message += f" [Hidden] Press '{hide_key}' to toggle." self.statusBar().setStyleSheet("color: red") else: - self.statusBar().setStyleSheet("color: black") + self.statusBar().setStyleSheet("") self.statusBar().showMessage(message) @@ -1406,6 +1415,8 @@ def _set_seekbar_header(self, graph_name: str): "Point Displacement (max)": data_obj.get_point_displacement_series, "Primary Point Displacement (sum)": data_obj.get_primary_point_displacement_series, "Primary Point Displacement (max)": data_obj.get_primary_point_displacement_series, + "Tracking Score (mean)": data_obj.get_tracking_score_series, + "Tracking Score (min)": data_obj.get_tracking_score_series, "Instance Score (sum)": data_obj.get_instance_score_series, "Instance Score (min)": data_obj.get_instance_score_series, "Point Score (sum)": data_obj.get_point_score_series, @@ -1419,7 +1430,7 @@ def _set_seekbar_header(self, graph_name: str): else: if graph_name in header_functions: kwargs = dict(video=self.state["video"]) - reduction_name = re.search("\\((sum|max|min)\\)", graph_name) + reduction_name = re.search("\\((sum|max|min|mean)\\)", graph_name) if reduction_name is not None: kwargs["reduction"] = reduction_name.group(1) series = header_functions[graph_name](**kwargs) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index dfc0dbad8..fca982327 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -49,6 +49,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.dialogs.merge import MergeDialog, ReplaceSkeletonTableDialog from sleap.gui.dialogs.message import MessageDialog from sleap.gui.dialogs.missingfiles import MissingFilesDialog +from sleap.gui.dialogs.frame_range import FrameRangeDialog from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track @@ -494,6 +495,10 @@ def deleteInstanceLimitPredictions(self): """Gui for deleting instances beyond some number in each frame.""" self.execute(DeleteInstanceLimitPredictions) + def deleteFrameLimitPredictions(self): + """Gui for deleting instances beyond some frame number.""" + self.execute(DeleteFrameLimitPredictions) + def completeInstanceNodes(self, instance: Instance): """Adds missing nodes to given instance.""" self.execute(AddMissingInstanceNodes, instance=instance) @@ -2472,6 +2477,36 @@ def ask(cls, context: CommandContext, params: dict) -> bool: return super().ask(context, params) +class DeleteFrameLimitPredictions(InstanceDeleteCommand): + @staticmethod + def get_frame_instance_list(context: CommandContext, params: Dict): + """Called from the parent `InstanceDeleteCommand.ask` method. + + Returns: + List of instances to be deleted. 
+ """ + instances = [] + # Select the instances to be deleted + for lf in context.labels.labeled_frames: + if lf.frame_idx < (params["min_frame_idx"] - 1) or lf.frame_idx > ( + params["max_frame_idx"] - 1 + ): + instances.extend([(lf, inst) for inst in lf.instances]) + return instances + + @classmethod + def ask(cls, context: CommandContext, params: Dict) -> bool: + current_video = context.state["video"] + dialog = FrameRangeDialog( + title="Delete Instances in Frame Range...", max_frame_idx=len(current_video) + ) + results = dialog.get_results() + if results: + params["min_frame_idx"] = results["min_frame_idx"] + params["max_frame_idx"] = results["max_frame_idx"] + return super().ask(context, params) + + class TransposeInstances(EditCommand): topics = [UpdateTopic.project_instances, UpdateTopic.tracks] diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index f68dc0180..721bdc321 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -15,20 +15,17 @@ """ -from qtpy import QtCore, QtWidgets, QtGui - -import numpy as np import os - from operator import itemgetter +from pathlib import Path +from typing import Any, Callable, List, Optional -from typing import Any, Callable, Dict, List, Optional, Type +import numpy as np +from qtpy import QtCore, QtGui, QtWidgets -from sleap.gui.state import GuiState from sleap.gui.commands import CommandContext -from sleap.gui.color import ColorManager -from sleap.io.dataset import Labels -from sleap.instance import LabeledFrame, Instance +from sleap.gui.state import GuiState +from sleap.instance import LabeledFrame from sleap.skeleton import Skeleton @@ -386,10 +383,25 @@ def getSelectedRowItem(self) -> Any: class VideosTableModel(GenericTableModel): - properties = ("filename", "frames", "height", "width", "channels") - - def item_to_data(self, obj, item): - return {key: getattr(item, key) for key in self.properties} + properties = ( + "name", + "filepath", + "frames", + "height", + "width", + "channels", + ) + + def item_to_data(self, obj, item: "Video"): + data = {} + for property in self.properties: + if property == "name": + data[property] = Path(item.filename).name + elif property == "filepath": + data[property] = str(Path(item.filename).parent) + else: + data[property] = getattr(item, property) + return data class SkeletonNodesTableModel(GenericTableModel): diff --git a/sleap/gui/dialogs/frame_range.py b/sleap/gui/dialogs/frame_range.py new file mode 100644 index 000000000..7165dd939 --- /dev/null +++ b/sleap/gui/dialogs/frame_range.py @@ -0,0 +1,42 @@ +"""Frame range dialog.""" +from qtpy import QtWidgets +from sleap.gui.dialogs.formbuilder import FormBuilderModalDialog +from typing import Optional + + +class FrameRangeDialog(FormBuilderModalDialog): + def __init__(self, max_frame_idx: Optional[int] = None, title: str = "Frame Range"): + + super().__init__(form_name="frame_range_form") + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + if max_frame_idx is not None: + min_frame_idx_field.setRange(1, max_frame_idx) + min_frame_idx_field.setValue(1) + + max_frame_idx_field.setRange(1, max_frame_idx) + max_frame_idx_field.setValue(max_frame_idx) + + min_frame_idx_field.valueChanged.connect(self._update_max_frame_range) + max_frame_idx_field.valueChanged.connect(self._update_min_frame_range) + + self.setWindowTitle(title) + + def _update_max_frame_range(self, value): + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field 
= self.form_widget.fields["max_frame_idx"] + + max_frame_idx_field.setRange(value, max_frame_idx_field.maximum()) + + def _update_min_frame_range(self, value): + min_frame_idx_field = self.form_widget.fields["min_frame_idx"] + max_frame_idx_field = self.form_widget.fields["max_frame_idx"] + + min_frame_idx_field.setRange(min_frame_idx_field.minimum(), value) + + +if __name__ == "__main__": + app = QtWidgets.QApplication([]) + dialog = FrameRangeDialog(max_frame_idx=100) + print(dialog.get_results()) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 2c2617036..bc26d826c 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -637,6 +637,20 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) return items_for_inference + def _validate_id_model(self) -> bool: + """Make sure we have instances with tracks set for ID models.""" + if not self.labels.tracks: + message = "Cannot run ID model training without tracks." + return False + + found_tracks = False + for inst in self.labels.instances(): + if type(inst) == sleap.Instance and inst.track is not None: + found_tracks = True + break + + return found_tracks + def _validate_pipeline(self): can_run = True message = "" @@ -655,6 +669,15 @@ def _validate_pipeline(self): f"({', '.join(untrained)})." ) + # Make sure we have instances with tracks set for ID models. + if self.mode == "training" and self.current_pipeline in ( + "top-down-id", + "bottom-up-id", + ): + can_run = self._validate_id_model() + if not can_run: + message = "Cannot run ID model training without tracks." + # Make sure skeleton will be valid for bottom-up inference. if self.mode == "training" and self.current_pipeline == "bottom-up": skeleton = self.labels.skeletons[0] diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 949703020..08ee5bf36 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -62,6 +62,7 @@ QShortcut, QVBoxLayout, QWidget, + QPinchGesture, ) import sleap @@ -823,6 +824,8 @@ def __init__(self, state=None, player=None, *args, **kwargs): # Set icon as default background.
self.setImage(QImage(sleap.util.get_package_file("gui/background.png"))) + self.grabGesture(Qt.GestureType.PinchGesture) + def dragEnterEvent(self, event): if self.parentWidget(): self.parentWidget().dragEnterEvent(event) @@ -1189,6 +1192,23 @@ def keyReleaseEvent(self, event): """Custom event hander, disables default QGraphicsView behavior.""" event.ignore() # Kicks the event up to parent + def event(self, event): + if event.type() == QtCore.QEvent.Gesture: + return self.handleGestureEvent(event) + return super().event(event) + + def handleGestureEvent(self, event): + gesture = event.gesture(Qt.GestureType.PinchGesture) + if gesture: + self.handlePinchGesture(gesture) + return True + + def handlePinchGesture(self, gesture: QPinchGesture): + if gesture.state() == Qt.GestureState.GestureUpdated: + factor = gesture.scaleFactor() + self.zoomFactor = max(factor * self.zoomFactor, 1) + self.updateViewer() + class QtNodeLabel(QGraphicsTextItem): """ @@ -1570,7 +1590,6 @@ def mousePressEvent(self, event): def mouseMoveEvent(self, event): """Custom event handler for mouse move.""" - # print(event) if self.dragParent: self.parentObject().mouseMoveEvent(event) else: @@ -1581,7 +1600,6 @@ def mouseMoveEvent(self, event): def mouseReleaseEvent(self, event): """Custom event handler for mouse release.""" - # print(event) self.unsetCursor() if self.dragParent: self.parentObject().mouseReleaseEvent(event) @@ -1610,6 +1628,10 @@ def mouseDoubleClickEvent(self, event: QMouseEvent): view = scene.views()[0] view.instanceDoubleClicked.emit(self.parentObject().instance, event) + def hoverEnterEvent(self, event): + """Custom event handler for mouse hover enter.""" + return super().hoverEnterEvent(event) + class QtEdge(QGraphicsPolygonItem): """ @@ -1809,6 +1831,7 @@ def __init__( self.labels = {} self.labels_shown = True self._selected = False + self._is_hovering = False self._bounding_rect = QRectF() # Show predicted instances behind non-predicted ones @@ -1830,6 +1853,7 @@ def __init__( box_pen.setStyle(Qt.DashLine) box_pen.setCosmetic(True) self.box.setPen(box_pen) + self.setAcceptHoverEvents(True) # Add label for highlighted instance self.highlight_label = QtTextWithBackground(parent=self) @@ -1991,7 +2015,12 @@ def updateBox(self, *args, **kwargs): select this instance. """ # Only show box if instance is selected - op = 0.7 if self._selected else 0 + op = 0 + if self._selected: + op = 0.8 + elif self._is_hovering: + op = 0.4 + self.box.setOpacity(op) # Update the position for the box rect = self.getPointsBoundingRect() @@ -2085,6 +2114,16 @@ def paint(self, painter, option, widget=None): """Method required by Qt.""" pass + def hoverEnterEvent(self, event): + self._is_hovering = True + self.updateBox() + return super().hoverEnterEvent(event) + + def hoverLeaveEvent(self, event): + self._is_hovering = False + self.updateBox() + return super().hoverLeaveEvent(event) + class VisibleBoundingBox(QtWidgets.QGraphicsRectItem): """QGraphicsRectItem for user instance bounding boxes. @@ -2275,7 +2314,7 @@ def mouseReleaseEvent(self, event): self.parent.nodes[node_key].setPos(new_x, new_y) # Update the instance - self.parent.updatePoints(complete=True, user_change=True) + self.parent.updatePoints(complete=False, user_change=True) self.resizing = None diff --git a/sleap/info/summary.py b/sleap/info/summary.py index c6a6af60e..0cad1617e 100644 --- a/sleap/info/summary.py +++ b/sleap/info/summary.py @@ -21,7 +21,7 @@ class StatisticSeries: are frame index and value are some numerical value for the frame. 
Args: - labels: The :class:`Labels` for which to calculate series. + labels: The `Labels` for which to calculate series. """ labels: Labels @@ -41,7 +41,7 @@ def get_point_score_series( """Get series with statistic of point scores in each frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -67,7 +67,7 @@ def get_instance_score_series(self, video, reduction="sum") -> Dict[int, float]: """Get series with statistic of instance scores in each frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to scores: * sum * min @@ -93,7 +93,7 @@ def get_point_displacement_series(self, video, reduction="sum") -> Dict[int, flo same track) from the closest earlier labeled frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -121,7 +121,7 @@ def get_primary_point_displacement_series( Get sum of displacement for single node of each instance per frame. Args: - video: The :class:`Video` for which to calculate statistic. + video: The `Video` for which to calculate statistic. reduction: name of function applied to point scores: * sum * mean @@ -226,7 +226,7 @@ def _calculate_frame_velocity( Calculate total point displacement between two given frames. Args: - lf: The :class:`LabeledFrame` for which we want velocity + lf: The `LabeledFrame` for which we want velocity last_lf: The frame from which to calculate displacement. reduce_function: Numpy function (e.g., np.sum, np.nanmean) is applied to *point* displacement, and then those @@ -246,3 +246,35 @@ def _calculate_frame_velocity( inst_dist = reduce_function(point_dist) val += inst_dist if not np.isnan(inst_dist) else 0 return val + + def get_tracking_score_series( + self, video: Video, reduction: str = "min" + ) -> Dict[int, float]: + """Get series with statistic of tracking scores in each frame. + + Args: + video: The `Video` for which to calculate statistic. + reduction: name of function applied to scores: + * mean + * min + + Returns: + The series dictionary (see class docs for details) + """ + reduce_fn = { + "min": np.nanmin, + "mean": np.nanmean, + }[reduction] + + series = dict() + + for lf in self.labels.find(video): + vals = [ + inst.tracking_score for inst in lf if hasattr(inst, "tracking_score") + ] + if vals: + val = reduce_fn(vals) + if not np.isnan(val): + series[lf.frame_idx] = val + + return series diff --git a/sleap/instance.py b/sleap/instance.py index 08a5c6ae6..382ececf2 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1049,7 +1049,9 @@ def scores(self) -> np.ndarray: return self.points_and_scores_array[:, 2] @classmethod - def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": + def from_instance( + cls, instance: Instance, score: float, tracking_score: float = 0.0 + ) -> "PredictedInstance": """Create a `PredictedInstance` from an `Instance`. The fields are copied in a shallow manner with the exception of points. For each @@ -1059,6 +1061,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": Args: instance: The `Instance` object to shallow copy data from. score: The score for this instance. + tracking_score: The tracking score for this instance. 
Returns: A `PredictedInstance` for the given `Instance`. @@ -1070,6 +1073,7 @@ def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": ) kw_args["points"] = PredictedPointArray.from_array(instance._points) kw_args["score"] = score + kw_args["tracking_score"] = tracking_score return cls(**kw_args) @classmethod @@ -1080,6 +1084,7 @@ def from_arrays( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1094,6 +1099,7 @@ def from_arrays( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. @@ -1114,6 +1120,7 @@ def from_arrays( skeleton=skeleton, score=instance_score, track=track, + tracking_score=tracking_score, ) @classmethod @@ -1124,6 +1131,7 @@ def from_pointsarray( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1138,12 +1146,18 @@ def from_pointsarray( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. """ return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) @classmethod @@ -1154,6 +1168,7 @@ def from_numpy( instance_score: float, skeleton: Skeleton, track: Optional[Track] = None, + tracking_score: float = 0.0, ) -> "PredictedInstance": """Create a predicted instance from data arrays. @@ -1168,12 +1183,18 @@ def from_numpy( skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. track: Optional `sleap.Track` to associate with the instance. + tracking_score: Optional float representing the track matching score. Returns: A new `PredictedInstance`. 
""" return cls.from_arrays( - points, point_confidences, instance_score, skeleton, track=track + points, + point_confidences, + instance_score, + skeleton, + track=track, + tracking_score=tracking_score, ) diff --git a/sleap/io/format/coco.py b/sleap/io/format/coco.py index 25122e4d0..44e7fb84a 100644 --- a/sleap/io/format/coco.py +++ b/sleap/io/format/coco.py @@ -180,6 +180,9 @@ def read( if flag == 0: # node not labeled for this instance + if (x, y) != (0, 0): + # If labeled but invisible, place the node at the coord + points[node] = Point(x, y, False) continue is_visible = flag == 2 diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index df061a289..f33af8e73 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -1140,9 +1140,11 @@ def export_model( info["predicted_tensors"] = tensors full_model = tf.function( - lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False) - if unrag_outputs - else model(x) + lambda x: ( + sleap.nn.data.utils.unrag_example(model(x), numpy=False) + if unrag_outputs + else model(x) + ) ) full_model = full_model.get_concrete_function( @@ -3818,9 +3820,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=np.nanmean(confs), skeleton=skeleton, track=track, + tracking_score=np.nanmean(score), ) ) @@ -4502,18 +4505,27 @@ def _object_builder(): break # Loop over frames. - for image, video_ind, frame_ind, points, confidences, scores in zip( + for ( + image, + video_ind, + frame_ind, + centroid_vals, + points, + confidences, + scores, + ) in zip( ex["image"], ex["video_ind"], ex["frame_ind"], + ex["centroid_vals"], ex["instance_peaks"], ex["instance_peak_vals"], ex["instance_scores"], ): # Loop over instances. predicted_instances = [] - for i, (pts, confs, score) in enumerate( - zip(points, confidences, scores) + for i, (pts, centroid_val, confs, score) in enumerate( + zip(points, centroid_vals, confidences, scores) ): if np.isnan(pts).all(): continue @@ -4524,9 +4536,10 @@ def _object_builder(): PredictedInstance.from_numpy( points=pts, point_confidences=confs, - instance_score=np.nanmean(score), + instance_score=centroid_val, skeleton=skeleton, track=track, + tracking_score=score, ) ) @@ -5756,3 +5769,7 @@ def main(args: Optional[list] = None): "To retrack on predictions, must specify tracker. " "Use \"sleap-track --tracking.tracker ...' to specify tracker to use." ) + + +if __name__ == "__main__": + main() diff --git a/sleap/nn/tracker/kalman.py b/sleap/nn/tracker/kalman.py index 2b0343927..774a4634e 100644 --- a/sleap/nn/tracker/kalman.py +++ b/sleap/nn/tracker/kalman.py @@ -608,7 +608,7 @@ def remove_second_bests_from_cost_matrix( cost matrix with invalid matches set to specified invalid value. """ - valid_match_mask = np.full_like(cost_matrix, True, dtype=np.bool) + valid_match_mask = np.full_like(cost_matrix, True, dtype=bool) rows, columns = cost_matrix.shape diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..231b004f5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -574,7 +574,7 @@ class Tracker(BaseTracker): max_tracking: bool = False # To enable maximum tracking. 
cleaner: Optional[Callable] = None # TODO: deprecate - target_instance_count: int = 0 + target_instance_count: int = 0 # TODO: deprecate pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False robust_best_instance: float = 1.0 @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]): # "tracking." # ) self.cleaner.run(frames) - elif self.target_instance_count and self.post_connect_single_breaks: + elif ( + self.target_instance_count or self.max_tracks + ) and self.post_connect_single_breaks: + if not self.target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + self.target_instance_count = self.max_tracks connect_single_track_breaks(frames, self.target_instance_count) + print("Connecting single track breaks.") def get_name(self): tracker_name = self.candidate_maker.__class__.__name__ @@ -850,7 +857,7 @@ def make_tracker_by_name( of_max_levels: int = 3, save_shifted_instances: bool = False, # Pre-tracking options to cull instances - target_instance_count: int = 0, + target_instance_count: int = 0, # TODO: deprecate target_instance_count pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks @@ -921,6 +928,7 @@ def make_tracker_by_name( pre_cull_function = None if target_instance_count and pre_cull_to_target: + # Right now this is not accessible from the GUI def pre_cull_function(inst_list): cull_frame_instances( @@ -940,11 +948,34 @@ def pre_cull_function(inst_list): pre_cull_function=pre_cull_function, max_tracking=max_tracking, max_tracks=max_tracks, - target_instance_count=target_instance_count, + target_instance_count=target_instance_count, # TODO: deprecate target_instance_count post_connect_single_breaks=post_connect_single_breaks, ) - if target_instance_count and kf_init_frame_count: + # Kalman filter requires deprecated target_instance_count + if (max_tracks or target_instance_count) and kf_init_frame_count: + if not kf_node_indices: + raise ValueError( + "Kalman filter requires node indices for instance tracking." + ) + + if tracker == "flow" or tracker == "flowmaxtracks": + # Tracking with Kalman filter requires initial tracker object to be simple + raise ValueError( + "Kalman filter requires simple tracker for initial tracking." + ) + + if similarity == "normalized_instance": + # Kalman filter doesnot support normalized_instance_similarity + raise ValueError( + "Kalman filter does not support normalized_instance_similarity." + ) + + if not target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + target_instance_count = max_tracks + kalman_obj = KalmanTracker.make_tracker( init_tracker=tracker_obj, init_frame_count=kf_init_frame_count, @@ -954,8 +985,10 @@ def pre_cull_function(inst_list): ) return kalman_obj - elif kf_init_frame_count and not target_instance_count: - raise ValueError("Kalman filter requires target instance count.") + elif kf_init_frame_count and not (max_tracks or target_instance_count): + raise ValueError( + "Kalman filter requires max tracks or target instance count." + ) else: return tracker_obj @@ -1369,6 +1402,10 @@ def cull_function(inst_list): if init_tracker.pre_cull_function is None: init_tracker.pre_cull_function = cull_function + print( + f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters." 
+ ) + return cls( init_tracker=init_tracker, kalman_tracker=kalman_tracker, @@ -1386,6 +1423,7 @@ def track( untracked_instances: List[InstanceType], img: Optional[np.ndarray] = None, t: int = None, + **kwargs, ) -> List[InstanceType]: """Tracks individual frame, using Kalman filters if possible.""" @@ -1420,7 +1458,7 @@ def track( # Initialize the Kalman filters self.kalman_tracker.init_filters(self.init_set.instances) - # print(f"Kalman filters initialized (frame {t})") + print(f"Kalman filters initialized (frame {t})") # Clear the data used to init filters, so that if the filters # stop tracking and we need to re-init, we won't re-use the diff --git a/sleap/prefs.py b/sleap/prefs.py index 8790f1d3f..e043afc44 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -28,6 +28,8 @@ class Preferences(object): "node label size": 12, "show non-visible nodes": True, "share usage data": True, + "node marker sizes": (1, 2, 3, 4, 6, 8, 12), + "node label sizes": (6, 9, 12, 18, 24, 36), } _filename = "preferences.yaml" @@ -43,14 +45,14 @@ def load_(self): """Load preferences from file (regardless of whether loaded already).""" try: self._prefs = util.get_config_yaml(self._filename) - if not hasattr(self._prefs, "get"): - self._prefs = self._defaults - else: - self._prefs["trail length"] = self._prefs.get( - "trail length", self._defaults["trail length"] - ) except FileNotFoundError: - self._prefs = self._defaults + pass + + self._prefs = self._prefs or {} + + for k, v in self._defaults.items(): + if k not in self._prefs: + self._prefs[k] = v def save(self): """Save preferences to file.""" diff --git a/sleap/version.py b/sleap/version.py index 7711477cb..698710132 100644 --- a/sleap/version.py +++ b/sleap/version.py @@ -11,7 +11,7 @@ Must be a semver string, "aN" should be appended for alpha releases. """ -__version__ = "1.4.1a2" +__version__ = "1.4.1" def versions(): diff --git a/tests/data/tracks/clip.predictions.slp b/tests/data/tracks/clip.predictions.slp new file mode 100644 index 000000000..652e21302 Binary files /dev/null and b/tests/data/tracks/clip.predictions.slp differ diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index ec5dfbc29..c6507caec 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -97,6 +97,20 @@ def min_tracks_2node_labels(): ) +@pytest.fixture +def min_tracks_2node_predictions(): + """ + Generated with: + ``` + sleap-track -m "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" "tests/data/tracks/clip.mp4" + ``` + """ + return Labels.load_file( + "tests/data/tracks/clip.predictions.slp", + video_search=["tests/data/tracks/clip.mp4"], + ) + + @pytest.fixture def min_tracks_13node_labels(): return Labels.load_file( diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 3d77c891f..389bb48a3 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -7,6 +7,7 @@ import pytest from qtpy import QtWidgets +import sleap from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget from sleap.gui.learning.configs import ( TrainingConfigFilesWidget, @@ -429,3 +430,22 @@ def test_immutablilty_of_trained_config_info( # saving multiple configs from one config info. 
ld.save(output_dir=tmpdir) ld.save(output_dir=tmpdir) + + +def test_validate_id_model(qtbot, min_labels_slp, min_labels_slp_path): + app = MainWindow(no_usage_data=True) + ld = LearningDialog( + mode="training", + labels_filename=Path(min_labels_slp_path), + labels=min_labels_slp, + ) + assert not ld._validate_id_model() + + # Add track but don't assign it to instances + new_track = sleap.Track(name="new_track") + min_labels_slp.tracks.append(new_track) + assert not ld._validate_id_model() + + # Assign track to instances + min_labels_slp[0][0].track = new_track + assert ld._validate_id_model() diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index ffd382ab1..e19e00236 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -20,6 +20,7 @@ ReplaceVideo, OpenSkeleton, SaveProjectAs, + DeleteFrameLimitPredictions, get_new_version_filename, ) from sleap.instance import Instance, LabeledFrame @@ -851,6 +852,26 @@ def load_and_assert_changes(new_video_path: Path): shutil.move(new_video_path, expected_video_path) +def test_DeleteFrameLimitPredictions( + centered_pair_predictions: Labels, centered_pair_vid: Video +): + """Test deleting instances beyond a certain frame limit.""" + labels = centered_pair_predictions + + # Set-up command context + context = CommandContext.from_labels(labels) + context.state["video"] = centered_pair_vid + + # Set-up params for the command + params = {"min_frame_idx": 900, "max_frame_idx": 1000} + + instances_to_delete = DeleteFrameLimitPredictions.get_frame_instance_list( + context, params + ) + + assert len(instances_to_delete) == 2070 + + @pytest.mark.parametrize("export_extension", [".json.zip", ".slp"]) def test_exportLabelsPackage(export_extension, centered_pair_labels: Labels, tmpdir): def assert_loaded_package_similar(path_to_pkg: Path, sugg=False, pred=False): diff --git a/tests/info/test_summary.py b/tests/info/test_summary.py index 2cf76c166..672d97e63 100644 --- a/tests/info/test_summary.py +++ b/tests/info/test_summary.py @@ -37,6 +37,19 @@ def test_frame_statistics(simple_predictions): x = stats.get_point_displacement_series(video, "max") assert len(x) == 2 - assert len(x) == 2 assert x[0] == 0 assert x[1] == 18.0 + + +def test_get_tracking_score_series(min_tracks_2node_predictions): + + stats = StatisticSeries(min_tracks_2node_predictions) + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "min") + assert len(x) == 1500 + assert x[0] == 0.9999966621398926 + assert x[1000] == 0.9998022317886353 + + x = stats.get_tracking_score_series(min_tracks_2node_predictions.video, "mean") + assert len(x) == 1500 + assert x[0] == 0.9999983310699463 + assert x[1000] == 0.9999011158943176 diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..4a601ac00 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -2,13 +2,205 @@ import operator import os import time - +import pytest import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components from sleap.io.dataset import Labels, LabeledFrame +similarity_args = [ + "instance", + "normalized_instance", + "object_keypoint", + "centroid", + "iou", +] +match_args = ["hungarian", "greedy"] + + +@pytest.mark.parametrize( + "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"] +) +@pytest.mark.parametrize("similarity", similarity_args) +@pytest.mark.parametrize("match", match_args) +def test_kalman_tracker( + tmpdir, 
centered_pair_predictions_slp_path, tracker_name, similarity, match +): + + if tracker_name == "flow" or tracker_name == "flowmaxtracks": + # Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker + with pytest.raises( + ValueError, + match="Kalman filter requires simple tracker for initial tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + else: + # For simple or simplemaxtracks, continue with other tests + # Check for ValueError when similarity is "normalized_instance" + if similarity == "normalized_instance": + with pytest.raises( + ValueError, + match="Kalman filter does not support normalized_instance_similarity.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + return + + # Check for ValueError when kf_node_indices is None which is the default + with pytest.raises( + ValueError, + match="Kalman filter requires node indices for instance tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test for missing max_tracks and target_instance_count with kf_init_frame_count + with pytest.raises( + ValueError, + match="Kalman filter requires max tracks or target instance count.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test with target_instance_count and without max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o {tmpdir}/{tracker_name}_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp") + assert len(labels.tracks) == 2 + + # Test with target_instance_count and with max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o 
{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_iou_threshold", "0.8" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_iou_threshold 0.8 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_to_target", "1" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_to_target 1 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp" + ) + assert len(labels.tracks) == 2 + + # Test with 'tracking.post_connect_single_breaks': 0 + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.post_connect_single_breaks 0 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp" + ) + assert len(labels.tracks) == 2 + + def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): cli = ( "--tracking.tracker simple "