fix n_sample and preload zarr stuff

SarahMuth committed Feb 11, 2025
1 parent c490326 commit 36bca53
Showing 5 changed files with 255 additions and 93 deletions.
6 changes: 6 additions & 0 deletions training/train.py
@@ -46,13 +46,19 @@ def train(key, ignore_label=None, training_2D=False, testset=True, extension="za
batch_size = 2
check = False

# Add the zarr file path ending to each path.
train_paths = [os.path.join(path, "VoxelSpacing10.000", "denoised.zarr") for path in train_paths]
val_paths = [os.path.join(path, "VoxelSpacing10.000", "denoised.zarr") for path in val_paths]
test_paths = [os.path.join(path, "VoxelSpacing10.000", "denoised.zarr") for path in test_paths]

# TODO do we want n_samples_train and n_samples_val in the supervised training?
supervised_training(
name=model_name,
train_paths=train_paths,
train_label_paths=train_label_paths,
val_paths=val_paths,
val_label_paths=val_label_paths,
raw_key="0",
patch_shape=patch_shape, batch_size=batch_size,
check=check,
lr=1e-4,
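
A minimal sketch of the path handling added above, with hypothetical run directories; the "VoxelSpacing10.000/denoised.zarr" layout and raw_key="0" are taken from the diff:

import os

run_dirs = ["/data/run_001", "/data/run_002"]  # hypothetical tomogram run folders
zarr_paths = [os.path.join(p, "VoxelSpacing10.000", "denoised.zarr") for p in run_dirs]
# -> ["/data/run_001/VoxelSpacing10.000/denoised.zarr", "/data/run_002/VoxelSpacing10.000/denoised.zarr"]
# raw_key="0" then selects the internal dataset (scale level "0") inside each zarr container.
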
96 changes: 85 additions & 11 deletions utils/image.py
@@ -1,32 +1,106 @@
import os
from typing import Optional, Sequence, Union

import imageio.v3 as imageio
import numpy as np
from elf.io import open_file
from numpy.typing import ArrayLike

try:
import tifffile
except ImportError:
tifffile = None

TIF_EXTS = (".tif", ".tiff")


def supports_memmap(image_path):
"""@private
"""
if tifffile is None:
return False
ext = os.path.splitext(image_path)[1]
if ext.lower() not in TIF_EXTS:
return False
try:
tifffile.memmap(image_path, mode="r")
except ValueError:
return False
return True


def load_image(image_path, memmap=True):
"""@private
"""
if supports_memmap(image_path) and memmap:
return tifffile.memmap(image_path, mode="r")
elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
return tifffile.imread(image_path)
elif os.path.splitext(image_path)[1].lower() == ".nrrd":
import nrrd
return nrrd.read(image_path)[0]
elif os.path.splitext(image_path)[1].lower() == ".mha":
import SimpleITK as sitk
image = sitk.ReadImage(image_path)
return sitk.GetArrayFromImage(image)
else:
return imageio.imread(image_path)
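
A hypothetical usage example for load_image (the file path is a placeholder): tif files are memory-mapped when tifffile supports it, other extensions fall back to format-specific readers or imageio.

img = load_image("/data/example.tif")                          # memory-mapped read if possible
img_in_memory = load_image("/data/example.tif", memmap=False)  # force a full in-memory read
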


class MultiDatasetWrapper:
"""@private
"""
def __init__(self, *file_datasets):
# Make sure we have the same shapes.
reference_shape = file_datasets[0].shape
assert all(reference_shape == ds.shape for ds in file_datasets)
self.file_datasets = file_datasets

self.shape = (len(self.file_datasets),) + reference_shape

def __getitem__(self, index):
channel_index, spatial_index = index[:1], index[1:]
data = []
for ds in self.file_datasets:
ds_data = ds[spatial_index]
data.append(ds_data)
data = np.stack(data)
data = data[channel_index]
return data
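
A small self-contained check of the wrapper's indexing behaviour, using in-memory numpy arrays as stand-ins for the file datasets:

import numpy as np

a = np.zeros((4, 8, 8), dtype="float32")
b = np.ones((4, 8, 8), dtype="float32")
wrapper = MultiDatasetWrapper(a, b)
print(wrapper.shape)             # (2, 4, 8, 8): one channel per wrapped dataset
patch = wrapper[:, 1:3, :4, :4]
print(patch.shape)               # (2, 2, 4, 4): channels first, then the spatial crop
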



def load_data(
path: Union[str, Sequence[str]],
key: Optional[Union[str, Sequence[str]]] = None,
mode: str = "r",
) -> ArrayLike:
"""Load data from a file or multiple files.
Supports loading regular image formats, such as tif or jpg, or container data formats, such as hdf5, n5 or zarr.
For the latter case, specify the name of the internal dataset to load via the `key` argument.
Args:
path: The file path or paths to the data.
key: The key or keys to the internal datasets.
mode: The mode for reading datasets.
Returns:
The loaded data.
"""

have_single_file = isinstance(path, str)
have_single_key = isinstance(key, str)

    if key is None:
        if have_single_file:
            return load_image(path)
        else:
            return np.stack([load_image(p) for p in path])

if have_single_key and have_single_file:
return open_file(path, mode=mode)[key]
elif have_single_key and not have_single_file:
return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])
elif not have_single_key and have_single_file:
return MultiDatasetWrapper(*[open_file(path, mode=mode)[k] for k in key])
    else:  # have multiple keys and multiple files
return MultiDatasetWrapper(*[open_file(p, mode=mode)[k] for k in key for p in path])
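
A hypothetical usage sketch of load_data; the container paths are placeholders and follow the "VoxelSpacing10.000/denoised.zarr" layout used in train.py above:

path_a = "/data/run_001/VoxelSpacing10.000/denoised.zarr"
path_b = "/data/run_002/VoxelSpacing10.000/denoised.zarr"

volume = load_data(path_a, key="0")             # one container -> the internal dataset "0"
stacked = load_data([path_a, path_b], key="0")  # several containers -> MultiDatasetWrapper, one channel each
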
72 changes: 63 additions & 9 deletions utils/training/data_loader.py
@@ -1,34 +1,88 @@
from torch.utils.data import DataLoader
from .heatmap_dataset import HeatmapDataset
from torch_em.data.concat_dataset import ConcatDataset

def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
"""@private
"""
assert split in ("balanced", "uniform")
n_datasets = len(raw_paths)
if split == "uniform":
# even distribution of samples to datasets
samples_per_ds = n_samples // n_datasets
divider = n_samples % n_datasets
return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
else:
# distribution of samples to datasets based on the dataset lengths
raise NotImplementedError
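
A quick check of the uniform split (the paths only determine how many datasets there are; raw_key is unused by the uniform branch):

print(samples_to_datasets(10, ["run_a", "run_b", "run_c"], raw_key="0"))  # -> [4, 3, 3]
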

def _load_dataset(
raw_paths,
label_paths,
raw_transform, transform,
patch_shape,
raw_key=None,
eps=0.00001, sigma=None,
lower_bound=None, upper_bound=None,
dataset_class=HeatmapDataset,
n_samples=None,
):
print(f"in _load_dataset raw_paths {raw_paths}")

if isinstance(raw_paths, str):
print(f"in isinstance raw_paths {raw_paths}")
ds = dataset_class(
raw_path=raw_paths, raw_key=raw_key, label_path=label_paths, patch_shape=patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=n_samples,
)
else:
assert len(raw_paths) > 0

samples_per_ds = (
[None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
)
ds = []
for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
print(f"in else raw_path {raw_path}")

dset = dataset_class(
raw_path=raw_path, raw_key=raw_key, label_path=label_path, patch_shape=patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=samples_per_ds[i],
)

ds.append(dset)
ds = ConcatDataset(*ds)
return ds

def create_data_loader(
train_images, train_labels,
val_images, val_labels,
    test_images, test_labels,
raw_transform, transform,
patch_shape, num_workers, batch_size,
raw_key=None,
eps=0.00001, sigma=None,
lower_bound=None, upper_bound=None,
dataset_class=HeatmapDataset,
n_samples_train=None,
n_samples_val=None,
):
train_set = _load_dataset(
raw_paths=train_images, raw_key=raw_key, label_paths=train_labels, patch_shape=patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=n_samples_train,
)
val_set = _load_dataset(
raw_paths=val_images, raw_key=raw_key, label_paths=val_labels, patch_shape=patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=n_samples_val,
)
test_set = _load_dataset(
raw_paths=test_images, raw_key=raw_key, label_paths=test_labels, patch_shape=patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound,
)

# put into DataLoader
