Skip to content

Commit

Permalink
Qt widget for loading pose datasets as napari Points layers (#253)
Browse files Browse the repository at this point in the history
* initialise napari plugin development

* Create skeleton for napari plugin with collapsible widgets (#218)

* initialise napari plugin development

* initialise napari plugin development

* create  skeleton for napari plugin with collapsible widgets

* add basic widget smoke tests and allow headless testing

* do not depend on napari from pip

* include napari option in install instructions

* make meta_widget module private

* pin atlasapi version to avoid unnecessary dependencies

* pin napari >= 0.4.19 from conda-forge

* switched to pip install of napari[all]

* seperation of concerns in widget tests

* add pytest-mock dev dependency

* initialise napari plugin development

* initialise napari plugin development

* initialise napari plugin development

* Added loader widget for poses

* update widget tests

* simplify dependency on brainglobe-utils

* consistent monospace formatting for movement in public docstrings

* get rid of code that's only relevant for displaying Tracks

* enable visibility of napari layer tooltips

* renamed widget to PosesLoader

* make cmap optional in set_color_by method

* wrote unit tests for napari convert module

* wrote unit-tests for the layer styles module

* linkcheck ignore zenodo redirects

* move _sample_colormap out of PointsStyle class

* small refactoring in the loader widget

* Expand tests for loader widget

* added comments and docstrings to napari plugin tests

* refactored all napari tests into separate unit test folder

* added napari-video to dependencies

* replaced deprecated edge_width with border_width

* got rid of widget pytest fixtures

* remove duplicate word from docstring

* remove napari-video dependency

* include napari extras in docs requirements

* add test for _on_browse_clicked method

* getOpenFileName returns tuple, not str

* simplify poses_to_napari_tracks

Co-authored-by: Chang Huan Lo <[email protected]>

* [pre-commit.ci] pre-commit autoupdate (#338)

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.7.2](astral-sh/ruff-pre-commit@v0.6.9...v0.7.2)
- [github.com/pre-commit/mirrors-mypy: v1.11.2 → v1.13.0](pre-commit/mirrors-mypy@v1.11.2...v1.13.0)
- [github.com/mgedmin/check-manifest: 0.49 → 0.50](mgedmin/check-manifest@0.49...0.50)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Implement `compute_speed` and `compute_path_length` (#280)

* implement compute_speed and compute_path_length functions

* added speed to existing kinematics unit test

* rewrote compute_path_length with various nan policies

* unit test compute_path_length across time ranges

* fixed and refactor compute_path_length and its tests

* fixed docstring for compute_path_length

* Accept suggestion on docstring wording

Co-authored-by: Chang Huan Lo <[email protected]>

* Remove print statement from test

Co-authored-by: Chang Huan Lo <[email protected]>

* Ensure nan report is printed

Co-authored-by: Chang Huan Lo <[email protected]>

* adapt warning message match in test

* change 'any' to 'all'

* uniform wording across path length docstrings

* (mostly) leave time range validation to xarray slice

* refactored parameters for test across time ranges

* simplified test for path lenght with nans

* replace drop policy with ffill

* remove B905 ruff rule

* make pre-commit happy

---------

Co-authored-by: Chang Huan Lo <[email protected]>

* initialise napari plugin development

* initialise napari plugin development

* initialise napari plugin development

* initialise napari plugin development

* initialise napari plugin development

* avoid redefining duplicate attributes in child dataclass

* modify test case to match poses_to_napari_tracks simplification

* expected_log_messages should be a subset of captured messages

Co-authored-by: Chang Huan Lo <[email protected]>

* fix typo

Co-authored-by: Chang Huan Lo <[email protected]>

* use names for Qwidgets

* reorganised test_valid_poses_to_napari_tracks

* parametrised layer style tests

* delet integration test which was reintroduced after conflict resolution

* added test about file filters

* deleted obsolete loader widget file (had snuck back in due to conflict merging)

* combine tests for button callouts

Co-authored-by: Chang Huan Lo <[email protected]>

* Simplify test_layer_style_as_kwargs

Co-authored-by: Chang Huan Lo <[email protected]>

---------

Co-authored-by: Chang Huan Lo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent 9a91fc7 commit cc4155d
Show file tree
Hide file tree
Showing 13 changed files with 711 additions and 103 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-e .
-e .[napari]
linkify-it-py
myst-parser
nbsphinx
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error
]


myst_url_schemes = {
"http": None,
"https": None,
Expand Down
32 changes: 0 additions & 32 deletions movement/napari/_loader_widget.py

This file was deleted.

161 changes: 161 additions & 0 deletions movement/napari/_loader_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Widgets for loading movement datasets from file."""

import logging
from pathlib import Path

from napari.settings import get_settings
from napari.utils.notifications import show_warning
from napari.viewer import Viewer
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QFormLayout,
QHBoxLayout,
QLineEdit,
QPushButton,
QSpinBox,
QWidget,
)

from movement.io import load_poses
from movement.napari.convert import poses_to_napari_tracks
from movement.napari.layer_styles import PointsStyle

logger = logging.getLogger(__name__)

# Allowed poses file suffixes for each supported source software
SUPPORTED_POSES_FILES = {
"DeepLabCut": ["*.h5", "*.csv"],
"LightningPose": ["*.csv"],
"SLEAP": ["*.h5", "*.slp"],
}


class PosesLoader(QWidget):
"""Widget for loading movement poses datasets from file."""

def __init__(self, napari_viewer: Viewer, parent=None):
"""Initialize the loader widget."""
super().__init__(parent=parent)
self.viewer = napari_viewer
self.setLayout(QFormLayout())
# Create widgets
self._create_source_software_widget()
self._create_fps_widget()
self._create_file_path_widget()
self._create_load_button()
# Enable layer tooltips from napari settings
self._enable_layer_tooltips()

def _create_source_software_widget(self):
"""Create a combo box for selecting the source software."""
self.source_software_combo = QComboBox()
self.source_software_combo.setObjectName("source_software_combo")
self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys())
self.layout().addRow("source software:", self.source_software_combo)

def _create_fps_widget(self):
"""Create a spinbox for selecting the frames per second (fps)."""
self.fps_spinbox = QSpinBox()
self.fps_spinbox.setObjectName("fps_spinbox")
self.fps_spinbox.setMinimum(1)
self.fps_spinbox.setMaximum(1000)
self.fps_spinbox.setValue(30)
self.layout().addRow("fps:", self.fps_spinbox)

def _create_file_path_widget(self):
"""Create a line edit and browse button for selecting the file path.
This allows the user to either browse the file system,
or type the path directly into the line edit.
"""
# File path line edit and browse button
self.file_path_edit = QLineEdit()
self.file_path_edit.setObjectName("file_path_edit")
self.browse_button = QPushButton("Browse")
self.browse_button.setObjectName("browse_button")
self.browse_button.clicked.connect(self._on_browse_clicked)
# Layout for line edit and button
self.file_path_layout = QHBoxLayout()
self.file_path_layout.addWidget(self.file_path_edit)
self.file_path_layout.addWidget(self.browse_button)
self.layout().addRow("file path:", self.file_path_layout)

def _create_load_button(self):
"""Create a button to load the file and add layers to the viewer."""
self.load_button = QPushButton("Load")
self.load_button.setObjectName("load_button")
self.load_button.clicked.connect(lambda: self._on_load_clicked())
self.layout().addRow(self.load_button)

def _on_browse_clicked(self):
"""Open a file dialog to select a file."""
file_suffixes = SUPPORTED_POSES_FILES[
self.source_software_combo.currentText()
]

file_path, _ = QFileDialog.getOpenFileName(
self,
caption="Open file containing predicted poses",
filter=f"Poses files ({' '.join(file_suffixes)})",
)

# A blank string is returned if the user cancels the dialog
if not file_path:
return

# Add the file path to the line edit (text field)
self.file_path_edit.setText(file_path)

def _on_load_clicked(self):
"""Load the file and add as a Points layer to the viewer."""
fps = self.fps_spinbox.value()
source_software = self.source_software_combo.currentText()
file_path = self.file_path_edit.text()
if file_path == "":
show_warning("No file path specified.")
return
ds = load_poses.from_file(file_path, source_software, fps)

self.data, self.props = poses_to_napari_tracks(ds)
logger.info("Converted poses dataset to a napari Tracks array.")
logger.debug(f"Tracks array shape: {self.data.shape}")

self.file_name = Path(file_path).name
self._add_points_layer()

self._set_playback_fps(fps)
logger.debug(f"Set napari playback speed to {fps} fps.")

def _add_points_layer(self):
"""Add the predicted poses to the viewer as a Points layer."""
# Style properties for the napari Points layer
points_style = PointsStyle(
name=f"poses: {self.file_name}",
properties=self.props,
)
# Color the points by individual if there are multiple individuals
# Otherwise, color by keypoint
n_individuals = len(self.props["individual"].unique())
points_style.set_color_by(
prop="individual" if n_individuals > 1 else "keypoint"
)
# Add the points layer to the viewer
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())
logger.info("Added poses dataset as a napari Points layer.")

@staticmethod
def _set_playback_fps(fps: int):
"""Set the playback speed for the napari viewer."""
settings = get_settings()
settings.application.playback_fps = fps

@staticmethod
def _enable_layer_tooltips():
"""Toggle on tooltip visibility for napari layers.
This nicely displays the layer properties as a tooltip
when hovering over the layer in the napari viewer.
"""
settings = get_settings()
settings.appearance.layer_tooltip_visibility = True
6 changes: 3 additions & 3 deletions movement/napari/_meta_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

from movement.napari._loader_widget import Loader
from movement.napari._loader_widgets import PosesLoader


class MovementMetaWidget(CollapsibleWidgetContainer):
Expand All @@ -18,9 +18,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__()

self.add_widget(
Loader(napari_viewer, parent=self),
PosesLoader(napari_viewer, parent=self),
collapsible=True,
widget_title="Load data",
widget_title="Load poses",
)

self.loader = self.collapsible_widgets[0]
Expand Down
73 changes: 73 additions & 0 deletions movement/napari/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Conversion functions from ``movement`` datasets to napari layers."""

import logging

import numpy as np
import pandas as pd
import xarray as xr

# get logger
logger = logging.getLogger(__name__)


def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame:
"""Construct a properties DataFrame from a ``movement`` dataset."""
return pd.DataFrame(
{
"individual": ds.coords["individuals"].values,
"keypoint": ds.coords["keypoints"].values,
"time": ds.coords["time"].values,
"confidence": ds["confidence"].values.flatten(),
}
)


def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]:
"""Convert poses dataset to napari Tracks array and properties.
Parameters
----------
ds : xr.Dataset
``movement`` dataset containing pose tracks, confidence scores,
and associated metadata.
Returns
-------
data : np.ndarray
napari Tracks array with shape (N, 4),
where N is n_keypoints * n_individuals * n_frames
and the 4 columns are (track_id, frame_idx, y, x).
properties : pd.DataFrame
DataFrame with properties (individual, keypoint, time, confidence).
Notes
-----
A corresponding napari Points array can be derived from the Tracks array
by taking its last 3 columns: (frame_idx, y, x). See the documentation
on the napari Tracks [1]_ and Points [2]_ layers.
References
----------
.. [1] https://napari.org/stable/howtos/layers/tracks.html
.. [2] https://napari.org/stable/howtos/layers/points.html
"""
n_frames = ds.sizes["time"]
n_individuals = ds.sizes["individuals"]
n_keypoints = ds.sizes["keypoints"]
n_tracks = n_individuals * n_keypoints
# Construct the napari Tracks array
# Reorder axes to (individuals, keypoints, frames, xy)
yx_cols = np.transpose(ds.position.values, (1, 2, 0, 3)).reshape(-1, 2)[
:, [1, 0] # swap x and y columns
]
# Each keypoint of each individual is a separate track
track_id_col = np.repeat(np.arange(n_tracks), n_frames).reshape(-1, 1)
time_col = np.tile(np.arange(n_frames), (n_tracks)).reshape(-1, 1)
data = np.hstack((track_id_col, time_col, yx_cols))
# Construct the properties DataFrame
# Stack 3 dimensions into a new single dimension named "tracks"
ds_ = ds.stack(tracks=("individuals", "keypoints", "time"))
properties = _construct_properties_dataframe(ds_)

return data, properties
64 changes: 64 additions & 0 deletions movement/napari/layer_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Dataclasses containing layer styles for napari."""

from dataclasses import dataclass, field

import numpy as np
import pandas as pd
from napari.utils.colormaps import ensure_colormap

DEFAULT_COLORMAP = "turbo"


@dataclass
class LayerStyle:
"""Base class for napari layer styles."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"

def as_kwargs(self) -> dict:
"""Return the style properties as a dictionary of kwargs."""
return self.__dict__


@dataclass
class PointsStyle(LayerStyle):
"""Style properties for a napari Points layer."""

symbol: str = "disc"
size: int = 10
border_width: int = 0
face_color: str | None = None
face_color_cycle: list[tuple] | None = None
face_colormap: str = DEFAULT_COLORMAP
text: dict = field(default_factory=lambda: {"visible": False})

def set_color_by(self, prop: str, cmap: str | None = None) -> None:
"""Set the face_color to a column in the properties DataFrame.
Parameters
----------
prop : str
The column name in the properties DataFrame to color by.
cmap : str, optional
The name of the colormap to use, otherwise use the face_colormap.
"""
if cmap is None:
cmap = self.face_colormap
self.face_color = prop
self.text["string"] = prop
n_colors = len(self.properties[prop].unique())
self.face_color_cycle = _sample_colormap(n_colors, cmap)


def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap.
This includes the endpoints of the colormap.
"""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ entry-points."napari.manifest".movement = "movement.napari:napari.yaml"

[project.optional-dependencies]
napari = [
"napari[all]>=0.4.19",
# the rest will be replaced by brainglobe-utils[qt]>=0.6 after release
"brainglobe-atlasapi>=2.0.7",
"brainglobe-utils>=0.5",
"qtpy",
"superqt",
"napari[all]>=0.5.0",
"brainglobe-utils[qt]>=0.6" # needed for collapsible widgets
]
dev = [
"pytest",
Expand Down
Loading

0 comments on commit cc4155d

Please sign in to comment.