diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 2c2617036..bc26d826c 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -637,6 +637,20 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen ) return items_for_inference + def _validate_id_model(self) -> bool: + """Make sure we have instances with tracks set for ID models.""" + if not self.labels.tracks: + message = "Cannot run ID model training without tracks." + return False + + found_tracks = False + for inst in self.labels.instances(): + if type(inst) == sleap.Instance and inst.track is not None: + found_tracks = True + break + + return found_tracks + def _validate_pipeline(self): can_run = True message = "" @@ -655,6 +669,15 @@ def _validate_pipeline(self): f"({', '.join(untrained)})." ) + # Make sure we have instances with tracks set for ID models. + if self.mode == "training" and self.current_pipeline in ( + "top-down-id", + "bottom-up-id", + ): + can_run = self.validate_id_model() + if not can_run: + message = "Cannot run ID model training without tracks." + # Make sure skeleton will be valid for bottom-up inference. if self.mode == "training" and self.current_pipeline == "bottom-up": skeleton = self.labels.skeletons[0] diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 3d77c891f..389bb48a3 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -7,6 +7,7 @@ import pytest from qtpy import QtWidgets +import sleap from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget from sleap.gui.learning.configs import ( TrainingConfigFilesWidget, @@ -429,3 +430,22 @@ def test_immutablilty_of_trained_config_info( # saving multiple configs from one config info. ld.save(output_dir=tmpdir) ld.save(output_dir=tmpdir) + + +def test_validate_id_model(qtbot, min_labels_slp, min_labels_slp_path): + app = MainWindow(no_usage_data=True) + ld = LearningDialog( + mode="training", + labels_filename=Path(min_labels_slp_path), + labels=min_labels_slp, + ) + assert not ld._validate_id_model() + + # Add track but don't assign it to instances + new_track = sleap.Track(name="new_track") + min_labels_slp.tracks.append(new_track) + assert not ld._validate_id_model() + + # Assign track to instances + min_labels_slp[0][0].track = new_track + assert ld._validate_id_model()