Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix VisualPredictor class #2048

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions sleap/gui/overlays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sleap import Labels, Video
from sleap.gui.widgets.video import QtVideoPlayer
from sleap.nn.data.providers import VideoReader
from sleap.nn.inference import VisualPredictor
from sleap.nn.inference import VisualPredictorWrapper as VisualPredictor

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,7 +69,9 @@ def remove_from_scene(self):
try:
self.player.scene.removeItem(item)

except RuntimeError as e: # Internal C++ object (PySide2.QtWidgets.QGraphicsPathItem) already deleted.
except (
RuntimeError
) as e: # Internal C++ object (PySide2.QtWidgets.QGraphicsPathItem) already deleted.
logger.debug(e)

# Stop tracking the items after they been removed from the scene
Expand Down Expand Up @@ -97,9 +99,12 @@ class ModelData(Sequence):
def __getitem__(self, i: int) -> np.ndarray:
"""Data data for frame i from predictor."""
# Get predictions for frame i
frame_result = self.predictor.predict(VideoReader(self.video, [i]))
frame_result = self.predictor.predict(
VideoReader(self.video, [i]), make_labels=False
)

# We just want the single image results
print("results key = ", self.result_key)
frame_result = frame_result[0][self.result_key]

if self.adjust_vals:
Expand Down Expand Up @@ -160,7 +165,7 @@ def _add(

@classmethod
def make_predictor(cls, filename: str) -> VisualPredictor:
return VisualPredictor.from_trained_models(filename)
return VisualPredictor.from_model_paths(filename)

@classmethod
def from_model(cls, filename: str, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/data/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def output_keys(self) -> List[Text]:

def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
test_ex = next(iter(input_ds))
input_shapes = [test_ex[k].shape for k in self.model_input_keys]
input_shapes = [test_ex[k].shape[1:] for k in self.model_input_keys]
input_layers = [tf.keras.layers.Input(shape) for shape in input_shapes]
keras_model = tf.keras.Model(input_layers, self.keras_model(input_layers))

Expand Down
119 changes: 113 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
function which provides a simplified interface for creating `Predictor`s.
"""

from __future__ import annotations

import attr
import argparse
import logging
Expand Down Expand Up @@ -589,6 +591,108 @@ def export_model(
)


class VisualPredictorWrapper:

def __init__(self, predictor: Predictor):

self.predictor = predictor
self.wrap_make_pipeline()

def __getattr__(self, name):
"""Pass through all unfound attributes to the wrapped predictor."""

# Check if the attribute is a method of the wrapped predictor.
return getattr(self.predictor, name)

@property
def model(self) -> Model:
model_by_predictor = {
BottomUpPredictor: "bottomup_model",
}
if type(self.predictor) in model_by_predictor:
return getattr(
self.predictor, model_by_predictor.get(type(self.predictor), "model")
)

raise ValueError(
f"Predictor type {type(self.predictor)} not yet supported by "
"VisualPredictor. Please select a different predictor type."
)

@property
def confidence_maps_key_name(self) -> str | None:

key_names = {
SingleInstancePredictor: "predicted_confidence_maps",
BottomUpPredictor: "predicted_confidence_maps",
TopDownPredictor: "predicted_centroid_confidence_maps",
}

if type(self.predictor) in key_names:
return key_names.get(type(self.predictor), None)

raise ValueError(
f"Predictor type {type(self.predictor)} not yet supported by "
"VisualPredictor. Please select a different predictor type."
)

@property
def part_affinity_fields_key_name(self) -> str | None:

if isinstance(self.predictor, BottomUpPredictor):
return "predicted_part_affinity_fields"

raise ValueError(
f"Predictor type {type(self.predictor)} cannot display Part Affinity Fields."
)

def head_specific_output_keys(self) -> List[Text]:
keys = []

key = self.confidence_maps_key_name
if key:
keys.append(key)

key = self.part_affinity_fields_key_name
if key:
keys.append(key)

return keys

def wrap_make_pipeline(self):
"""Wrap the `make_pipeline` method of the predictor to add additional logic."""
original_make_pipeline = self.predictor.make_pipeline

def wrapped_method(*args, **kwargs):
pipeline = original_make_pipeline(*args, **kwargs)
pipeline = self.extend_pipeline(pipeline)
return pipeline

self.predictor.make_pipeline = wrapped_method

def extend_pipeline(self, pipeline: Pipeline):
"""Extend the data pipeline for the predictor."""

pipeline += KerasModelPredictor(
keras_model=self.model.keras_model,
model_input_keys="image",
model_output_keys=self.head_specific_output_keys(),
)
self.predictor.pipeline = pipeline
return pipeline

@classmethod
def from_model_paths(cls, model_path: str) -> VisualPredictorWrapper:
"""Create the appropriate `Predictos` subclass from a list of model paths.

Args:
model_path: A single or list of trained model paths. Special cases of
non-SLEAP models include "movenet-thunder" and "movenet-lightning".
"""
predictor = Predictor.from_model_paths(model_path)
return cls(predictor=predictor)


# TODO: Rewrite this class.
@attr.s(auto_attribs=True)
class VisualPredictor(Predictor):
Expand All @@ -611,6 +715,13 @@ def from_trained_models(cls, model_path: Text) -> "VisualPredictor":

return cls(config=cfg, model=model)

def _initialize_inference_model(self):
"""Initialize the inference model from the trained model and configuration."""
pass

def is_grayscale(self) -> bool:
return self.model.keras_model.input.shape[-1] == 1

def head_specific_output_keys(self) -> List[Text]:
keys = []

Expand All @@ -626,10 +737,8 @@ def head_specific_output_keys(self) -> List[Text]:

@property
def confidence_maps_key_name(self) -> Optional[Text]:
head_key = self.config.model.heads.which_oneof_attrib_name()

if head_key in ("multi_instance", "single_instance"):
return "predicted_confidence_maps"
return "predicted_confidence_maps"

if head_key == "centroid":
return "predicted_centroid_confidence_maps"
Expand All @@ -640,10 +749,8 @@ def confidence_maps_key_name(self) -> Optional[Text]:

@property
def part_affinity_fields_key_name(self) -> Optional[Text]:
head_key = self.config.model.heads.which_oneof_attrib_name()

if head_key == "multi_instance":
return "predicted_part_affinity_fields"
return "predicted_part_affinity_fields"

return None

Expand Down
24 changes: 24 additions & 0 deletions tests/gui/test_overlays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Module to test all overlays in the sleap/gui/overlays/base.py."""

from sleap.gui.overlays.base import DataOverlay, ModelData


def test_data_overlay(qtbot, min_bottomup_model_path, centered_pair_vid):
"""Test the data overlay."""

model_path = min_bottomup_model_path
video = centered_pair_vid

predictor = DataOverlay.make_predictor(filename=model_path)

overlay = DataOverlay.from_predictor(
predictor=predictor,
video=video,
show_pafs=True,
)


if __name__ == "__main__":
import pytest

pytest.main([f"{__file__}::test_data_overlay"])
20 changes: 20 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
MoveNetPredictor,
MoveNetInferenceLayer,
MoveNetInferenceModel,
VisualPredictor,
VisualPredictorWrapper,
MOVENET_MODELS,
load_model,
export_model,
Expand Down Expand Up @@ -2074,3 +2076,21 @@ def test_top_down_model(min_tracks_2node_labels: Labels, min_centroid_model_path

# Runs without error message
predictor.predict(labels.extract(inds=[0, 1]))


def test_visual_predictor(min_bottomup_model_path):

# Test bottom-up model
model_path: str = min_bottomup_model_path
predictor = VisualPredictor.from_trained_models(model_path=model_path)


def test_visual_predictor_wrapper(min_bottomup_model_path):

# Test bottom-up model
model_path: str = min_bottomup_model_path
predictor = VisualPredictorWrapper.from_model_paths(model_path=model_path)


if __name__ == "__main__":
pytest.main([f"{__file__}::test_visual_predictor_wrapper"])
Loading