Skip to content

Commit

Permalink
Add check for instances with track assigned before training ID models
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Dec 16, 2024
1 parent 089b48e commit d233bea
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
23 changes: 23 additions & 0 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,20 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen
)
return items_for_inference

def _validate_id_model(self) -> bool:
"""Make sure we have instances with tracks set for ID models."""
if not self.labels.tracks:
message = "Cannot run ID model training without tracks."
return False

found_tracks = False
for inst in self.labels.instances():
if type(inst) == sleap.Instance and inst.track is not None:
found_tracks = True
break

return found_tracks

def _validate_pipeline(self):
can_run = True
message = ""
Expand All @@ -655,6 +669,15 @@ def _validate_pipeline(self):
f"({', '.join(untrained)})."
)

# Make sure we have instances with tracks set for ID models.
if self.mode == "training" and self.current_pipeline in (
"top-down-id",
"bottom-up-id",
):
can_run = self.validate_id_model()
if not can_run:
message = "Cannot run ID model training without tracks."

# Make sure skeleton will be valid for bottom-up inference.
if self.mode == "training" and self.current_pipeline == "bottom-up":
skeleton = self.labels.skeletons[0]
Expand Down
20 changes: 20 additions & 0 deletions tests/gui/learning/test_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from qtpy import QtWidgets

import sleap
from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget
from sleap.gui.learning.configs import (
TrainingConfigFilesWidget,
Expand Down Expand Up @@ -429,3 +430,22 @@ def test_immutablilty_of_trained_config_info(
# saving multiple configs from one config info.
ld.save(output_dir=tmpdir)
ld.save(output_dir=tmpdir)


def test_validate_id_model(qtbot, min_labels_slp, min_labels_slp_path):
app = MainWindow(no_usage_data=True)
ld = LearningDialog(
mode="training",
labels_filename=Path(min_labels_slp_path),
labels=min_labels_slp,
)
assert not ld._validate_id_model()

# Add track but don't assign it to instances
new_track = sleap.Track(name="new_track")
min_labels_slp.tracks.append(new_track)
assert not ld._validate_id_model()

# Assign track to instances
min_labels_slp[0][0].track = new_track
assert ld._validate_id_model()

0 comments on commit d233bea

Please sign in to comment.