Skip to content

Commit

Permalink
- support multiple crop sizes given as list in configs for multiple data paths
Browse files Browse the repository at this point in the history

- handle the case where a single crop size is given as an int (vs. a list) for multiple data paths
- exception handling
- added unit tests
  • Loading branch information
shaikh58 committed Nov 25, 2024
1 parent 580fd2b commit 1547c7d
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
3 changes: 2 additions & 1 deletion dreem/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dreem.datasets import data_utils
from dreem.io import Frame
from torch.utils.data import Dataset
from typing import Union
import numpy as np
import torch

Expand All @@ -15,7 +16,7 @@ def __init__(
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: int,
crop_size: Union[int, list[int]],
chunk: bool,
clip_length: int,
mode: str,
Expand Down
41 changes: 36 additions & 5 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import sleap_io as sio
import random
from pathlib import Path
import logging
from typing import Union, List
from dreem.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from torchvision.transforms import functional as tvf
Expand All @@ -21,8 +23,9 @@ def __init__(
self,
slp_files: list[str],
video_files: list[str],
data_dirs: list[str] = [],
padding: int = 5,
crop_size: int = 128,
crop_size: Union[int, list[int]] = 128,
anchors: int | list[str] | str = "",
chunk: bool = True,
clip_length: int = 500,
Expand Down Expand Up @@ -79,6 +82,7 @@ def __init__(
)

self.slp_files = slp_files
self.data_dirs = data_dirs # empty list, list of paths, or string of single path
self.video_files = video_files
self.padding = padding
self.crop_size = crop_size
Expand All @@ -95,7 +99,24 @@ def __init__(
self.anchors = [anchors]
else:
self.anchors = anchors



if not isinstance(self.data_dirs, list):
self.data_dirs = [self.data_dirs]

if not isinstance(self.crop_size, list):
# make a list so its handled consistently if multiple crops are used
if len(self.data_dirs) > 0: # for test mode, data_dirs is []
self.crop_size = [self.crop_size] * len(self.data_dirs)
else:
self.crop_size = [self.crop_size]


if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
raise ValueError(f"If a list of crop sizes or data directories are given,"
f"they must have the same length but got {len(self.crop_size)} "
f"and {len(self.data_dirs)}")

if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
Expand Down Expand Up @@ -140,6 +161,16 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

video_name = self.video_files[label_idx]

# get the correct crop size based on the video
video_par_path = Path(video_name).parent
if len(self.data_dirs) > 0:
for j, data_dir in enumerate(self.data_dirs):
if Path(data_dir) == video_par_path:
crop_size = self.crop_size[j]
break
else:
crop_size = self.crop_size[0]

vid_reader = self.videos[label_idx]

# img = vid_reader.get_data(0)
Expand Down Expand Up @@ -316,15 +347,15 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
data_utils.get_bbox(centroid, crop_size),
padding=self.padding,
)

if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
Expand Down
10 changes: 9 additions & 1 deletion dreem/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
lists of labels file paths and video file paths respectively
"""
dir_cfg = data_cfg.pop("dir", None)

label_files = vid_files = None

if dir_cfg:
Expand Down Expand Up @@ -253,6 +252,12 @@ def get_dataset(
raise ValueError(
"`mode` must be one of ['train', 'val','test'], not '{mode}'"
)

if "dir" in dataset_params:
self.data_dirs = dataset_params["dir"]["path"]
else:
self.data_dirs = []

if label_files is None or vid_files is None:
label_files, vid_files = self.get_data_paths(dataset_params)
# todo: handle this better
Expand All @@ -261,6 +266,9 @@ def get_dataset(
dataset_params["slp_files"] = label_files
if vid_files is not None:
dataset_params["video_files"] = vid_files

dataset_params["data_dirs"] = self.data_dirs

return SleapDataset(**dataset_params)

elif "tracks" in dataset_params or "source" in dataset_params:
Expand Down
51 changes: 51 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,57 @@ def test_sleap_dataset(two_flies):

assert len(train_ds) == ds_length

train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
crop_size=128,
chunk=True,
clip_length=clip_length,
n_chunks=ds_length + 10000,
)

with pytest.raises(Exception):
train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs="./data/sleap",
crop_size=[128, 128],
chunk=True,
clip_length=clip_length,
n_chunks=ds_length + 10000,
)

with pytest.raises(Exception):
train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs=["./data/sleap", "./data/microscopy"],
crop_size=[128],
chunk=True,
clip_length=clip_length,
n_chunks=ds_length + 10000,
)

train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs=["./data/sleap"],
crop_size=128,
chunk=True,
clip_length=clip_length,
n_chunks=ds_length + 10000,
)

train_ds = SleapDataset(
slp_files=[two_flies[0]],
video_files=[two_flies[1]],
data_dirs=["./data/sleap", "./data/microscopy"],
crop_size=128,
chunk=True,
clip_length=clip_length,
n_chunks=ds_length + 10000,
)


def test_icy_dataset(ten_icy_particles):
"""Test icy dataset logic.
Expand Down

0 comments on commit 1547c7d

Please sign in to comment.