From cdba8e9ce38f5b4e85b7c280c143d5399dc545b1 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:34:15 -0800 Subject: [PATCH 1/3] Add failing test --- tests/nn/test_inference.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 0a978de0a..411905b64 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -46,6 +46,7 @@ MoveNetPredictor, MoveNetInferenceLayer, MoveNetInferenceModel, + VisualPredictor, MOVENET_MODELS, load_model, export_model, @@ -2074,3 +2075,8 @@ def test_top_down_model(min_tracks_2node_labels: Labels, min_centroid_model_path # Runs without error message predictor.predict(labels.extract(inds=[0, 1])) + + +def test_visual_predictor(min_bottomup_model_path): + model_path: str = min_bottomup_model_path + predictor = VisualPredictor.from_trained_models(model_path=model_path) From c3b12f5e59bbce74a8475f69f6317c25255b5606 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:25:19 -0800 Subject: [PATCH 2/3] Fix for bottom-up models --- sleap/nn/inference.py | 7 +++++++ tests/nn/test_inference.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14e0d5c6f..dec829931 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -611,6 +611,13 @@ def from_trained_models(cls, model_path: Text) -> "VisualPredictor": return cls(config=cfg, model=model) + def _initialize_inference_model(self): + """Initialize the inference model from the trained model and configuration.""" + pass + + def is_grayscale(self) -> bool: + return self.model.keras_model.input.shape[-1] == 1 + def head_specific_output_keys(self) -> List[Text]: keys = [] diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 411905b64..a19904d75 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -2078,5 +2078,7 @@ def test_top_down_model(min_tracks_2node_labels: Labels, min_centroid_model_path def test_visual_predictor(min_bottomup_model_path): + + # Test bottom-up model model_path: str = min_bottomup_model_path predictor = VisualPredictor.from_trained_models(model_path=model_path) From dbdc78ffcc9953cfd0b293c3957374864e1e9ec6 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:43:52 -0800 Subject: [PATCH 3/3] [WIP] Create VisualPredictorWrapper class --- sleap/gui/overlays/base.py | 13 +++-- sleap/nn/data/inference.py | 2 +- sleap/nn/inference.py | 112 +++++++++++++++++++++++++++++++++++-- tests/gui/test_overlays.py | 24 ++++++++ tests/nn/test_inference.py | 12 ++++ 5 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 tests/gui/test_overlays.py diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 879d12810..1604b0c6b 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -20,7 +20,7 @@ from sleap import Labels, Video from sleap.gui.widgets.video import QtVideoPlayer from sleap.nn.data.providers import VideoReader -from sleap.nn.inference import VisualPredictor +from sleap.nn.inference import VisualPredictorWrapper as VisualPredictor logger = logging.getLogger(__name__) @@ -69,7 +69,9 @@ def remove_from_scene(self): try: self.player.scene.removeItem(item) - except RuntimeError as e: # Internal C++ object (PySide2.QtWidgets.QGraphicsPathItem) already deleted. + except ( + RuntimeError + ) as e: # Internal C++ object (PySide2.QtWidgets.QGraphicsPathItem) already deleted. logger.debug(e) # Stop tracking the items after they been removed from the scene @@ -97,9 +99,12 @@ class ModelData(Sequence): def __getitem__(self, i: int) -> np.ndarray: """Data data for frame i from predictor.""" # Get predictions for frame i - frame_result = self.predictor.predict(VideoReader(self.video, [i])) + frame_result = self.predictor.predict( + VideoReader(self.video, [i]), make_labels=False + ) # We just want the single image results + print("results key = ", self.result_key) frame_result = frame_result[0][self.result_key] if self.adjust_vals: @@ -160,7 +165,7 @@ def _add( @classmethod def make_predictor(cls, filename: str) -> VisualPredictor: - return VisualPredictor.from_trained_models(filename) + return VisualPredictor.from_model_paths(filename) @classmethod def from_model(cls, filename: str, *args, **kwargs): diff --git a/sleap/nn/data/inference.py b/sleap/nn/data/inference.py index 772ac3f8b..16ec9ff48 100644 --- a/sleap/nn/data/inference.py +++ b/sleap/nn/data/inference.py @@ -35,7 +35,7 @@ def output_keys(self) -> List[Text]: def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: test_ex = next(iter(input_ds)) - input_shapes = [test_ex[k].shape for k in self.model_input_keys] + input_shapes = [test_ex[k].shape[1:] for k in self.model_input_keys] input_layers = [tf.keras.layers.Input(shape) for shape in input_shapes] keras_model = tf.keras.Model(input_layers, self.keras_model(input_layers)) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 4281b5e27..80a25bd81 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -21,6 +21,8 @@ function which provides a simplified interface for creating `Predictor`s. """ +from __future__ import annotations + import attr import argparse import logging @@ -589,6 +591,108 @@ def export_model( ) +class VisualPredictorWrapper: + + def __init__(self, predictor: Predictor): + + self.predictor = predictor + self.wrap_make_pipeline() + + def __getattr__(self, name): + """Pass through all unfound attributes to the wrapped predictor.""" + + # Check if the attribute is a method of the wrapped predictor. + return getattr(self.predictor, name) + + @property + def model(self) -> Model: + model_by_predictor = { + BottomUpPredictor: "bottomup_model", + } + if type(self.predictor) in model_by_predictor: + return getattr( + self.predictor, model_by_predictor.get(type(self.predictor), "model") + ) + + raise ValueError( + f"Predictor type {type(self.predictor)} not yet supported by " + "VisualPredictor. Please select a different predictor type." + ) + + @property + def confidence_maps_key_name(self) -> str | None: + + key_names = { + SingleInstancePredictor: "predicted_confidence_maps", + BottomUpPredictor: "predicted_confidence_maps", + TopDownPredictor: "predicted_centroid_confidence_maps", + } + + if type(self.predictor) in key_names: + return key_names.get(type(self.predictor), None) + + raise ValueError( + f"Predictor type {type(self.predictor)} not yet supported by " + "VisualPredictor. Please select a different predictor type." + ) + + @property + def part_affinity_fields_key_name(self) -> str | None: + + if isinstance(self.predictor, BottomUpPredictor): + return "predicted_part_affinity_fields" + + raise ValueError( + f"Predictor type {type(self.predictor)} cannot display Part Affinity Fields." + ) + + def head_specific_output_keys(self) -> List[Text]: + keys = [] + + key = self.confidence_maps_key_name + if key: + keys.append(key) + + key = self.part_affinity_fields_key_name + if key: + keys.append(key) + + return keys + + def wrap_make_pipeline(self): + """Wrap the `make_pipeline` method of the predictor to add additional logic.""" + original_make_pipeline = self.predictor.make_pipeline + + def wrapped_method(*args, **kwargs): + pipeline = original_make_pipeline(*args, **kwargs) + pipeline = self.extend_pipeline(pipeline) + return pipeline + + self.predictor.make_pipeline = wrapped_method + + def extend_pipeline(self, pipeline: Pipeline): + """Extend the data pipeline for the predictor.""" + + pipeline += KerasModelPredictor( + keras_model=self.model.keras_model, + model_input_keys="image", + model_output_keys=self.head_specific_output_keys(), + ) + self.predictor.pipeline = pipeline + return pipeline + + @classmethod + def from_model_paths(cls, model_path: str) -> VisualPredictorWrapper: + """Create the appropriate `Predictos` subclass from a list of model paths. + + Args: + model_path: A single or list of trained model paths. Special cases of + non-SLEAP models include "movenet-thunder" and "movenet-lightning". + """ + predictor = Predictor.from_model_paths(model_path) + return cls(predictor=predictor) + + # TODO: Rewrite this class. @attr.s(auto_attribs=True) class VisualPredictor(Predictor): @@ -633,10 +737,8 @@ def head_specific_output_keys(self) -> List[Text]: @property def confidence_maps_key_name(self) -> Optional[Text]: - head_key = self.config.model.heads.which_oneof_attrib_name() - if head_key in ("multi_instance", "single_instance"): - return "predicted_confidence_maps" + return "predicted_confidence_maps" if head_key == "centroid": return "predicted_centroid_confidence_maps" @@ -647,10 +749,8 @@ def confidence_maps_key_name(self) -> Optional[Text]: @property def part_affinity_fields_key_name(self) -> Optional[Text]: - head_key = self.config.model.heads.which_oneof_attrib_name() - if head_key == "multi_instance": - return "predicted_part_affinity_fields" + return "predicted_part_affinity_fields" return None diff --git a/tests/gui/test_overlays.py b/tests/gui/test_overlays.py new file mode 100644 index 000000000..f269bfba5 --- /dev/null +++ b/tests/gui/test_overlays.py @@ -0,0 +1,24 @@ +"""Module to test all overlays in the sleap/gui/overlays/base.py.""" + +from sleap.gui.overlays.base import DataOverlay, ModelData + + +def test_data_overlay(qtbot, min_bottomup_model_path, centered_pair_vid): + """Test the data overlay.""" + + model_path = min_bottomup_model_path + video = centered_pair_vid + + predictor = DataOverlay.make_predictor(filename=model_path) + + overlay = DataOverlay.from_predictor( + predictor=predictor, + video=video, + show_pafs=True, + ) + + +if __name__ == "__main__": + import pytest + + pytest.main([f"{__file__}::test_data_overlay"]) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index a19904d75..498b64dd9 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -47,6 +47,7 @@ MoveNetInferenceLayer, MoveNetInferenceModel, VisualPredictor, + VisualPredictorWrapper, MOVENET_MODELS, load_model, export_model, @@ -2082,3 +2083,14 @@ def test_visual_predictor(min_bottomup_model_path): # Test bottom-up model model_path: str = min_bottomup_model_path predictor = VisualPredictor.from_trained_models(model_path=model_path) + + +def test_visual_predictor_wrapper(min_bottomup_model_path): + + # Test bottom-up model + model_path: str = min_bottomup_model_path + predictor = VisualPredictorWrapper.from_model_paths(model_path=model_path) + + +if __name__ == "__main__": + pytest.main([f"{__file__}::test_visual_predictor_wrapper"])