
Commit

Merge pull request #2 from computational-cell-analytics/clean-up
Clean up
SarahMuth authored Jan 31, 2025
2 parents 80faf54 + 6fc36b7 commit f0ba5c0
Showing 12 changed files with 143 additions and 85 deletions.
8 changes: 4 additions & 4 deletions inference/run_protein_detection.py
@@ -20,7 +20,7 @@ def get_volume(input_path):
return input_volume

def run_protein_detection(input_path, output_path, model_path):

tiling = parse_tiling(tile_shape=None, halo=None) #TODO implement tiling and halo choices
print(f"using tiling {tiling}")

@@ -82,14 +82,14 @@ def main():
args = parser.parse_args()

file = args.file

if file:
run_protein_detection(args.input_path, args.output_path, args.model_path)
else:
process_folder(args)


print("Finished segmenting!")

if __name__ == "__main__":
main()
main()
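
The TODO above in run_protein_detection.py asks for configurable tiling and halo. A minimal sketch of how that could be exposed on the command line, with hypothetical --tile_shape/--halo option names; only parse_tiling itself comes from the repository:

# Hypothetical CLI wiring for the tiling TODO; option names are illustrative.
import argparse

from utils.tiling_helper import parse_tiling

parser = argparse.ArgumentParser()
parser.add_argument("--tile_shape", type=int, nargs=3, default=None,
                    help="Tile shape in ZYX order, e.g. 48 256 256.")
parser.add_argument("--halo", type=int, nargs=3, default=None,
                    help="Halo in ZYX order, e.g. 8 32 32.")
args = parser.parse_args()

# parse_tiling presumably falls back to the VRAM-dependent default when both are None.
tiling = parse_tiling(tile_shape=args.tile_shape, halo=args.halo)
print(f"using tiling {tiling}")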
48 changes: 28 additions & 20 deletions training/train.py
@@ -1,31 +1,39 @@
import os
from glob import glob
# from glob import glob
import argparse
import sys
sys.path.append("/user/muth9/u12095/czii-protein-challenge")

from utils.dataset_splits import get_paths
from utils.training import supervised_training
from utils.dataset_splits import get_paths # noqa
from utils.training import supervised_training # noqa

TRAIN_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/static/"
LABEL_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/overlay/"
TRAIN_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/static/"
LABEL_ROOT = "/scratch-grete/projects/nim00007/cryo-et/challenge-data/train/overlay/"
OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/cryo-et/czii_challenge/training"

def train(key, ignore_label = None, training_2D = False, testset = True, extension="zarr"):

datasets = [
"ExperimentRuns"
]
model_name="protein_detection_czii_v1"
def train(key, ignore_label=None, training_2D=False, testset=True, extension="zarr"):

output_path = os.path.join(OUTPUT_ROOT, model_name)
datasets = ["ExperimentRuns"]
model_name = "protein_detection_czii_v1"

output_path = os.path.join(OUTPUT_ROOT, model_name)
os.makedirs(output_path, exist_ok=True)

train_paths, train_label_paths = get_paths("train", datasets=datasets, train_root=TRAIN_ROOT, output_root=output_path, testset=testset, label_root=LABEL_ROOT)
val_paths, val_label_paths = get_paths("val", datasets=datasets, train_root=TRAIN_ROOT, output_root=output_path, testset=testset, label_root=LABEL_ROOT)
train_paths, train_label_paths = get_paths(
"train", datasets=datasets, train_root=TRAIN_ROOT,
output_root=output_path, testset=testset, label_root=LABEL_ROOT
)
val_paths, val_label_paths = get_paths(
"val", datasets=datasets, train_root=TRAIN_ROOT,
output_root=output_path, testset=testset, label_root=LABEL_ROOT
)

if testset:
test_paths, test_label_paths = get_paths("test", datasets=datasets, train_root=TRAIN_ROOT, output_root=output_path, testset=testset, label_root=LABEL_ROOT)
test_paths, test_label_paths = get_paths(
"test", datasets=datasets, train_root=TRAIN_ROOT,
output_root=output_path, testset=testset, label_root=LABEL_ROOT
)
else:
test_paths, test_label_paths = None, None

@@ -35,24 +43,24 @@ def train(key, ignore_label = None, training_2D = False, testset = True, extensi

patch_shape = [48, 256, 256]

batch_size = 2
batch_size = 2
check = False

#TODO do we want n_samples_train and n_samples_val in the supervised training?
# TODO do we want n_samples_train and n_samples_val in the supervised training?
supervised_training(
name=model_name,
train_paths=train_paths,
train_label_paths = train_label_paths,
train_label_paths=train_label_paths,
val_paths=val_paths,
val_label_paths = val_label_paths,
val_label_paths=val_label_paths,
patch_shape=patch_shape, batch_size=batch_size,
check=check,
lr=1e-4,
n_iterations=1e3,
out_channels=1,
augmentations=None,
eps=1e-5,
sigma=None,
eps=1e-5,
sigma=None,
lower_bound=None,
upper_bound=None,
test_paths=test_paths,
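
On the n_samples TODO above: HeatmapDataset already accepts an n_samples argument that decouples the epoch length from the number of volumes, and create_data_loader (next file) now exposes n_samples_train and n_samples_val. A minimal sketch of threading those through, reusing the path variables returned by get_paths above; the concrete sample counts are illustrative:

# Sketch only: n_samples_train / n_samples_val set how many random patches make up
# one epoch; 1000 / 100 are illustrative values, not the repository's settings.
from utils.data_loader import create_data_loader

train_loader, val_loader, test_loader = create_data_loader(
    train_paths, train_label_paths,
    val_paths, val_label_paths,
    test_paths, test_label_paths,
    raw_transform=None, transform=None,  # placeholders for the real transforms
    patch_shape=patch_shape, num_workers=4, batch_size=batch_size,
    n_samples_train=1000, n_samples_val=100,
)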
6 changes: 3 additions & 3 deletions utils/__init__.py
@@ -1,7 +1,7 @@
from .data_loader import CreateDataLoader
from .data_loader import create_data_loader
from .dataset_splits import get_paths
from .heatmap_loader import HeatmapLoader
from .heatmap_dataset import HeatmapDataset
from .training import supervised_training
from .prediction import get_prediction_torch_em
from .protein_detection import protein_detection
from .tiling_helper import parse_tiling
from .tiling_helper import parse_tiling
Binary file removed utils/__pycache__/__init__.cpython-311.pyc
Binary file removed utils/__pycache__/data_loader.cpython-311.pyc
48 changes: 35 additions & 13 deletions utils/data_loader.py
@@ -1,24 +1,46 @@
from torch.utils.data import DataLoader
from .heatmap_loader import HeatmapLoader
from .heatmap_dataset import HeatmapDataset

def CreateDataLoader(train_images, train_labels, val_images, val_labels, test_images, test_labels, raw_transform, transform, patch_shape, num_workers, batch_size, eps=0.00001, sigma=None, lower_bound=None, upper_bound=None):


train_set = HeatmapLoader(train_images, train_labels, patch_shape, raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma, lower_bound=lower_bound, upper_bound=upper_bound)
val_set = HeatmapLoader(val_images, val_labels, patch_shape, raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma, lower_bound=lower_bound, upper_bound=upper_bound)
test_set = HeatmapLoader(test_images, test_labels, patch_shape, raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma, lower_bound=lower_bound, upper_bound=upper_bound)

def create_data_loader(
train_images, train_labels,
val_images, val_labels,
test_images,
test_labels,
raw_transform, transform,
patch_shape, num_workers, batch_size,
eps=0.00001, sigma=None,
lower_bound=None, upper_bound=None,
dataset_class=HeatmapDataset,
n_samples_train=None,
n_samples_val=None,
):
train_set = dataset_class(
train_images, train_labels, patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=n_samples_train,
)
val_set = dataset_class(
val_images, val_labels, patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound, n_samples=n_samples_val,
)
test_set = dataset_class(
test_images, test_labels, patch_shape,
raw_transform=raw_transform, transform=transform, eps=eps, sigma=sigma,
lower_bound=lower_bound, upper_bound=upper_bound
)

# put into DataLoader
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True,
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True,
num_workers=num_workers)

train_dataloader.shuffle = True
val_dataloader.shuffle = True
test_dataloader.shuffle = True

return train_dataloader, val_dataloader, test_dataloader
return train_dataloader, val_dataloader, test_dataloader
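
For reference, each returned loader yields (raw, heatmap) batches produced by HeatmapDataset; a short consumption sketch, assuming loaders built as in the sketch after train.py above:

# Batch shapes follow from batch_size and patch_shape, e.g. (2, 1, 48, 256, 256)
# for batch_size=2 and a single-channel 48x256x256 patch.
for raw_batch, heatmap_batch in train_loader:
    print(raw_batch.shape, heatmap_batch.shape)
    break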
6 changes: 3 additions & 3 deletions utils/dataset_splits.py
@@ -23,7 +23,7 @@ def _train_val_test_split(names):
continue

ds_path = os.path.join(train_root, ds)
#need to check if the dataset contains files or folders (like eg for zarr)
# need to check if the dataset contains files or folders (like eg for zarr)
if any(os.path.isfile(os.path.join(ds_path, f)) for f in os.listdir(ds_path)):
# If the dataset contains files
file_paths = sorted(glob(os.path.join(ds_path, f"*.{extension}")))
@@ -39,7 +39,7 @@ def _train_val_test_split(names):


def _require_train_val_split(datasets, train_root, output_root, extension):
train_ratio, val_ratio = 0.8, 0.2
train_ratio, val_ratio = 0.8, 0.2 # noqa

def _train_val_split(names):
train, val = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
@@ -52,7 +52,7 @@ def _train_val_split(names):
continue

ds_path = os.path.join(train_root, ds)
#need to check if the dataset contains files or folders (like eg for zarr)
# need to check if the dataset contains files or folders (like eg for zarr)
if any(os.path.isfile(os.path.join(ds_path, f)) for f in os.listdir(ds_path)):
# If the dataset contains files
file_paths = sorted(glob(os.path.join(ds_path, f"*.{extension}")))
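
The split helpers above delegate to sklearn's train_test_split; a standalone sketch of the two-stage train/val/test split pattern, with hypothetical run names and illustrative ratios (the repository's exact test ratio is not visible in this diff):

# Two-stage split sketch: first carve off train, then split the remainder into val/test.
from sklearn.model_selection import train_test_split

names = [f"TS_{i}" for i in range(10)]  # hypothetical run names
train, rest = train_test_split(names, test_size=0.2, shuffle=True)
val, test = train_test_split(rest, test_size=0.5, shuffle=True)
print(len(train), len(val), len(test))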
40 changes: 27 additions & 13 deletions utils/heatmap_loader.py → utils/heatmap_dataset.py
@@ -1,14 +1,15 @@
import torch
from torch.utils.data import Dataset
from torch_em.util import ensure_spatial_array, ensure_tensor_with_channels
# from torch.utils.data import Dataset
from torch_em.util import ensure_tensor_with_channels
import numpy as np

import zarr
import os

from data_processing.create_heatmap import process_tomogram

class HeatmapLoader(torch.utils.data.Dataset):

class HeatmapDataset(torch.utils.data.Dataset):
max_sampling_attempts = 500

def __init__(
@@ -49,6 +50,8 @@ def __init__(
self.lower_bound = lower_bound
self.upper_bound = upper_bound

# TODO: n_samples should be derived from how often the bounding box fits
# rather than the number of volumes (as in SegmentationDataset)
if n_samples is None:
self._len = len(self.raw_images)
self.sample_random_index = False
@@ -73,18 +76,31 @@ def _sample_bounding_box(self, shape):
for sh, psh in zip(shape, self.patch_shape)
]
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

def _get_sample(self, index):
if self.sample_random_index:
index = np.random.randint(0, len(self.raw_images))
raw, label = self.raw_images[index], self.label_images[index]

#TODO this is specific for challenge zarr files now, maybe need to generalize in the future
#zarr_file = zarr.open(f"{raw}/VoxelSpacing10.000/denoised.zarr/0", mode='r')
# TODO this is specific for challenge zarr files now, maybe need to generalize in the future
# Yes ;). We can discuss this soon.

# zarr_file = zarr.open(f"{raw}/VoxelSpacing10.000/denoised.zarr/0", mode='r')
zarr_file = zarr.open(os.path.join(raw, "VoxelSpacing10.000", "denoised.zarr", "0"), mode='r')
raw = zarr_file[:]
#sigma is not really used in my process_tomogram ... TODO ?
label = process_tomogram(label, raw.shape, eps=self.eps, sigma=self.sigma, lower_bound=self.lower_bound, upper_bound=self.upper_bound)

# This was very inefficient!
# You first load the full data from zarr and then later load the bounding box.
# raw = zarr_file[:]
# Instead, you can just load the bounding box from the zarr
raw = zarr_file

# This is also quite inefficient.
# You compute the labels for the full tomogram, and then sub-sample.
# sigma is not really used in my process_tomogram ... TODO ?
label = process_tomogram(
label, raw.shape, eps=self.eps, sigma=self.sigma,
lower_bound=self.lower_bound, upper_bound=self.upper_bound
)

have_raw_channels = raw.ndim == 4 # 3D with channels
have_label_channels = label.ndim == 4
@@ -117,12 +133,11 @@ def _get_sample(self, index):
if have_raw_channels and len(prefix_box) == 0:
raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width


return raw_patch, label_patch

def __getitem__(self, index):
raw, labels = self._get_sample(index)
initial_label_dtype = labels.dtype
# initial_label_dtype = labels.dtype

if self.raw_transform is not None:
raw = self.raw_transform(raw)
Expand All @@ -133,7 +148,6 @@ def __getitem__(self, index):
if self.transform is not None:
raw, labels = self.transform(raw, labels)


raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
return raw, labels
return raw, labels
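
The review comment in this hunk suggests slicing the patch straight from the zarr array rather than materializing the whole tomogram. A minimal sketch of that idea, with a hypothetical run directory and a fixed bounding box standing in for _sample_bounding_box:

# Lazy patch loading from zarr: indexing the array reads only the chunks covering bb.
import os

import zarr

run_dir = "/path/to/ExperimentRuns/TS_5_4"  # hypothetical run directory
zarr_array = zarr.open(os.path.join(run_dir, "VoxelSpacing10.000", "denoised.zarr", "0"), mode="r")

bb = (slice(0, 48), slice(0, 256), slice(0, 256))  # would come from _sample_bounding_box
raw_patch = zarr_array[bb]
print(raw_patch.shape)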
7 changes: 4 additions & 3 deletions utils/prediction.py
@@ -1,15 +1,16 @@
import os
import time
import warnings
from glob import glob
from typing import Dict, Optional, Tuple
# from glob import glob
from typing import Dict # , Optional, Tuple

import numpy as np
import torch
import torch_em

from torch_em.util.prediction import predict_with_halo


def get_prediction_torch_em(
input_volume: np.ndarray, # [z, y, x]
tiling: Dict[str, Dict[str, int]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
@@ -55,7 +56,7 @@ def get_prediction_torch_em(
# Suppress warning when loading the model.
with warnings.catch_warnings():
warnings.simplefilter("ignore")

if os.path.isdir(model_path): # Load the model from a torch_em checkpoint.
model = torch_em.util.load_model(checkpoint=model_path, device=device)
else: # Load the model directly from a serialized pytorch model.
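
Only the input_volume and tiling parameters of get_prediction_torch_em are visible in this diff; a hedged usage sketch, where the model_path keyword is an assumption based on the checkpoint-loading code shown above:

# Illustrative call; the model_path keyword name is an assumption, the
# directory-vs-file handling follows the branch shown in the hunk above.
import numpy as np

from utils.prediction import get_prediction_torch_em
from utils.tiling_helper import parse_tiling

input_volume = np.zeros((184, 630, 630), dtype="float32")  # dummy tomogram
tiling = parse_tiling(tile_shape=None, halo=None)
prediction = get_prediction_torch_em(input_volume, tiling, model_path="/path/to/checkpoint")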
6 changes: 3 additions & 3 deletions utils/protein_detection.py
@@ -2,7 +2,8 @@
from scipy.ndimage import gaussian_laplace, label, find_objects, center_of_mass
from scipy.ndimage.measurements import variance

def protein_detection(heatmap): #TODO do this properly

def protein_detection(heatmap): # TODO do this properly
"""
Detects local maxima and estimates sizes of Gaussians in a 3D heatmap.
@@ -45,10 +46,9 @@ def protein_detection(heatmap): #TODO do this properly
starts = [r.start for r in region] # Dynamically handle dimensions
com_global = tuple(com[i] + starts[i] for i in range(len(starts)))


detections.append({
'coordinates': com_global,
'size': size
})

return detections
return detections
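
protein_detection builds on scipy's LoG filter and labeling utilities; a generic sketch of that detection pattern, with an illustrative sigma and threshold rather than the repository's settings:

# LoG-style blob detection sketch: filter, threshold, label, then report
# per-component centers of mass (in global coordinates) and voxel counts.
import numpy as np
from scipy.ndimage import center_of_mass, find_objects, gaussian_laplace, label

heatmap = np.random.rand(48, 128, 128).astype("float32")  # dummy network output
response = -gaussian_laplace(heatmap, sigma=2.0)  # blobs become positive peaks
mask = response > 0.1                             # illustrative threshold
labeled, n_objects = label(mask)
regions = find_objects(labeled)

detections = []
for i, region in enumerate(regions, start=1):
    com = center_of_mass(response[region], labeled[region], i)
    starts = [sl.start for sl in region]
    com_global = tuple(c + s for c, s in zip(com, starts))
    size = int((labeled[region] == i).sum())      # voxel count as a size proxy
    detections.append({"coordinates": com_global, "size": size})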
3 changes: 2 additions & 1 deletion utils/tiling_helper.py
@@ -1,5 +1,6 @@
import torch


def get_default_tiling():
"""Determine the tile shape and halo depending on the available VRAM.
"""
@@ -64,4 +65,4 @@ def parse_tiling(tile_shape, halo):
halo = dict(zip("zyx", halo))

tiling = {"tile": tile_shape, "halo": halo}
return tiling
return tiling
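
For reference, the dictionary returned by parse_tiling has the structure documented in utils/prediction.py; the numbers below are illustrative:

# {"tile": {...}, "halo": {...}} with ZYX keys, as per the type hint in get_prediction_torch_em.
tiling = {
    "tile": {"z": 128, "y": 256, "x": 256},
    "halo": {"z": 16, "y": 32, "x": 32},
}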