diff --git a/dreem/io/config.py b/dreem/io/config.py index 53fd733..2147555 100644 --- a/dreem/io/config.py +++ b/dreem/io/config.py @@ -182,10 +182,17 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]: label_files = vid_files = None if dir_cfg: - labels_suff = dir_cfg.labels_suffix - vid_suff = dir_cfg.vid_suffix - labels_path = f"{dir_cfg.path}/*{labels_suff}" - vid_path = f"{dir_cfg.path}/*{vid_suff}" + labels_suff = dir_cfg.get("labels_suffix") + vid_suff = dir_cfg.get("vid_suffix") + if labels_suff is None or vid_suff is None: + raise KeyError( + f"Must provide a labels suffix and vid suffix to search for but found {labels_suff} and {vid_suff}!" + ) + dir_path = dir_cfg.get("path", ".") + logger.debug(f"Searching `{dir_path}` directory") + + labels_path = f"{dir_path}/*{labels_suff}" + vid_path = f"{dir_path}/*{vid_suff}" logger.debug(f"Searching for labels matching {labels_path}") label_files = glob.glob(labels_path) logger.debug(f"Searching for videos matching {vid_path}") diff --git a/tests/test_config.py b/tests/test_config.py index 74f8b36..b9d66d2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -54,7 +54,7 @@ def test_setter(base_config): assert cfg.cfg.test_config == -1 -def test_getters(base_config): +def test_getters(base_config, sleap_data_dir): """Test each getter function in the config class. Args: @@ -87,17 +87,38 @@ def test_getters(base_config): ds = cfg.get_dataset("train") assert ds.clip_length == 4 + assert len(ds.label_files) == len(ds.vid_files) == 1 ds = cfg.get_dataset("val") assert ds.clip_length == 8 ds = cfg.get_dataset("test") assert ds.clip_length == 16 + cfg.set_hparams( + { + "dataset.train_dataset.dir": { + "path": sleap_data_dir, + "labels_suffix": ".slp", + "vid_suffix": ".mp4", + } + } + ) + ds = cfg.get_dataset("train") + assert len(ds.label_files) == len(ds.vid_files) == 4 + optim = cfg.get_optimizer(model.parameters()) assert isinstance(optim, torch.optim.Adam) scheduler = cfg.get_scheduler(optim) assert isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + label_paths, data_path = cfg.get_data_paths(cfg.get("train_dataset", {})) + assert label_paths is None and data_path is None + + label_paths, data_path = cfg.get_data_paths( + {"dir": {"path": sleap_data_dir, "labels_suffix": ".slp", "vid_suffix": ".mp4"}} + ) + assert len(label_paths) == len(data_path) == 4 + def test_missing(base_config): """Test cases when keys are missing from config for expected behavior.