Skip to content

Commit

Permalink
patch get_datapaths to use dict.get (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Oct 17, 2024
1 parent 89599bc commit a90b2be
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
15 changes: 11 additions & 4 deletions dreem/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,17 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
label_files = vid_files = None

if dir_cfg:
labels_suff = dir_cfg.labels_suffix
vid_suff = dir_cfg.vid_suffix
labels_path = f"{dir_cfg.path}/*{labels_suff}"
vid_path = f"{dir_cfg.path}/*{vid_suff}"
labels_suff = dir_cfg.get("labels_suffix")
vid_suff = dir_cfg.get("vid_suffix")
if labels_suff is None or vid_suff is None:
raise KeyError(
f"Must provide a labels suffix and vid suffix to search for but found {labels_suff} and {vid_suff}!"
)
dir_path = dir_cfg.get("path", ".")
logger.debug(f"Searching `{dir_path}` directory")

labels_path = f"{dir_path}/*{labels_suff}"
vid_path = f"{dir_path}/*{vid_suff}"
logger.debug(f"Searching for labels matching {labels_path}")
label_files = glob.glob(labels_path)
logger.debug(f"Searching for videos matching {vid_path}")
Expand Down
23 changes: 22 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_setter(base_config):
assert cfg.cfg.test_config == -1


def test_getters(base_config):
def test_getters(base_config, sleap_data_dir):
"""Test each getter function in the config class.
Args:
Expand Down Expand Up @@ -87,17 +87,38 @@ def test_getters(base_config):

ds = cfg.get_dataset("train")
assert ds.clip_length == 4
assert len(ds.label_files) == len(ds.vid_files) == 1
ds = cfg.get_dataset("val")
assert ds.clip_length == 8
ds = cfg.get_dataset("test")
assert ds.clip_length == 16

cfg.set_hparams(
{
"dataset.train_dataset.dir": {
"path": sleap_data_dir,
"labels_suffix": ".slp",
"vid_suffix": ".mp4",
}
}
)
ds = cfg.get_dataset("train")
assert len(ds.label_files) == len(ds.vid_files) == 4

optim = cfg.get_optimizer(model.parameters())
assert isinstance(optim, torch.optim.Adam)

scheduler = cfg.get_scheduler(optim)
assert isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

label_paths, data_path = cfg.get_data_paths(cfg.get("train_dataset", {}))
assert label_paths is None and data_path is None

label_paths, data_path = cfg.get_data_paths(
{"dir": {"path": sleap_data_dir, "labels_suffix": ".slp", "vid_suffix": ".mp4"}}
)
assert len(label_paths) == len(data_path) == 4


def test_missing(base_config):
"""Test cases when keys are missing from config for expected behavior.
Expand Down

0 comments on commit a90b2be

Please sign in to comment.