Skip to content

Commit

Permalink
Implement vesicle post-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 8, 2024
1 parent 224ef6c commit 451ff69
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 11 deletions.
89 changes: 83 additions & 6 deletions scripts/otoferlin/automatic_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy as np
import pandas as pd

from skimage.measure import label
from skimage.segmentation import relabel_sequential

from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances
from synapse_net.file_utils import read_mrc
from synapse_net.inference.vesicles import segment_vesicles
Expand All @@ -12,6 +15,12 @@

from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, get_adapted_model

# These are tomograms for which the sophisticated membrane processing fails.
# In this case, we just select the largest boundary piece.
SIMPLE_MEM_POSTPROCESSING = [
"Otof_TDAKO1blockA_GridN5_2_rec.mrc", "Otof_TDAKO2blockC_GridF5_1_rec.mrc", "Otof_TDAKO2blockC_GridF5_2_rec.mrc"
]


def _get_center_crop(input_):
halo_xy = (600, 600)
Expand Down Expand Up @@ -55,6 +64,13 @@ def process_vesicles(mrc_path, output_path, process_center_crop):
f.create_dataset(key, data=segmentation, compression="gzip")


def _simple_membrane_postprocessing(membrane_prediction):
    """Fallback membrane post-processing: keep only the largest connected component.

    Used for the tomograms listed in SIMPLE_MEM_POSTPROCESSING, for which the
    distance based membrane post-processing fails.

    Args:
        membrane_prediction: Binary membrane prediction volume.

    Returns:
        A uint8 mask that contains only the largest connected component of
        the prediction (all zeros if the prediction has no foreground).
    """
    seg = label(membrane_prediction)
    ids, sizes = np.unique(seg, return_counts=True)
    # Discard the background label (0) if it is present. The previous version
    # always dropped the first entry, which crashed np.argmax for an empty
    # prediction and silently dropped a real component when 0 was absent.
    foreground = ids != 0
    ids, sizes = ids[foreground], sizes[foreground]
    if ids.size == 0:
        # No foreground at all: return an empty mask instead of crashing.
        return np.zeros(seg.shape, dtype="uint8")
    return (seg == ids[np.argmax(sizes)]).astype("uint8")


def process_ribbon_structures(mrc_path, output_path, process_center_crop):
key = "segmentation/ribbon"
with h5py.File(output_path, "r") as f:
Expand All @@ -78,6 +94,12 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
return_predictions=True, n_slices_exclude=5,
)

# The distance based post-processing for membranes fails for some tomograms.
# In these cases, just choose the largest membrane piece.
fname = os.path.basename(mrc_path)
if fname in SIMPLE_MEM_POSTPROCESSING:
segmentations["membrane"] = _simple_membrane_postprocessing(predictions["membrane"])

if process_center_crop:
for name, seg in segmentations.items():
full_seg = np.zeros(full_shape, dtype=seg.dtype)
Expand All @@ -94,6 +116,49 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
f.create_dataset(f"prediction/{name}", data=predictions[name], compression="gzip")


def postprocess_vesicles(mrc_path, output_path, process_center_crop):
    """Post-process the automatic vesicle segmentation for one tomogram.

    Removes small vesicle fragments and vesicles that are too far from the
    ribbon/membrane, relabels the result sequentially, and writes it to the
    segmentation h5 file. Skips tomograms that were already processed.

    Args:
        mrc_path: Path to the tomogram (mrc), used to read the voxel size.
        output_path: Path to the h5 file with the segmentation results.
        process_center_crop: Whether to restrict processing to a center crop.
    """
    # NOTE(review): 'veiscles' is a typo, but the key must be kept as-is so
    # that code reading this dataset elsewhere keeps working.
    key = "segmentation/veiscles_postprocessed"
    with h5py.File(output_path, "r") as f:
        # Idempotency check: don't redo work if the result already exists.
        if key in f:
            return
        vesicles = f["segmentation/vesicles"][:]
        if process_center_crop:
            bb, full_shape = _get_center_crop(vesicles)
            vesicles = vesicles[bb]
        else:
            # Full-volume slice, so the same bb indexing works in both cases.
            bb = np.s_[:]

        ribbon = f["segmentation/ribbon"][bb]
        membrane = f["segmentation/membrane"][bb]

    # Filter out small vesicle fragments (fewer than min_size voxels).
    min_size = 5000
    ids, sizes = np.unique(vesicles, return_counts=True)
    # Drop the background entry (first id) from the size statistics.
    ids, sizes = ids[1:], sizes[1:]
    filter_ids = ids[sizes < min_size]
    vesicles[np.isin(vesicles, filter_ids)] = 0

    # Read the voxel size (in zyx order) to express distances in nanometers.
    # NOTE(review): input_ is loaded and cropped but not used afterwards;
    # presumably left over from debugging — candidate for removal.
    input_, voxel_size = read_mrc(mrc_path)
    voxel_size = tuple(voxel_size[ax] for ax in "zyx")
    input_ = input_[bb]

    # Filter out all vesicles farther than 120 nm from the membrane or ribbon.
    max_dist = 120
    seg = (ribbon + membrane) > 0
    distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, seg, resolution=voxel_size)
    filter_ids = seg_ids[distances > max_dist]
    vesicles[np.isin(vesicles, filter_ids)] = 0

    # Relabel so that vesicle ids are consecutive after the filtering.
    vesicles, _, _ = relabel_sequential(vesicles)

    if process_center_crop:
        # Paste the cropped result back into a full-sized volume.
        full_seg = np.zeros(full_shape, dtype=vesicles.dtype)
        full_seg[bb] = vesicles
        vesicles = full_seg
    with h5py.File(output_path, "a") as f:
        f.create_dataset(key, data=vesicles, compression="gzip")


def measure_distances(mrc_path, seg_path, output_folder):
result_folder = os.path.join(output_folder, "distances")
if os.path.exists(result_folder):
Expand Down Expand Up @@ -171,20 +236,32 @@ def process_tomogram(mrc_path):

process_vesicles(mrc_path, output_path, process_center_crop)
process_ribbon_structures(mrc_path, output_path, process_center_crop)
return
# TODO vesicle post-processing:
# snap to boundaries?
# remove vesicles in ribbon
postprocess_vesicles(mrc_path, output_path, process_center_crop)

measure_distances(mrc_path, output_path, output_folder)
assign_vesicle_pools(output_folder)
# We don't need to do the analysis of the auto segmentation here; it only
# makes sense to do this after segmentation. I am leaving this here for
# now, to move it to the files for analysis later.

# measure_distances(mrc_path, output_path, output_folder)
# assign_vesicle_pools(output_folder)


def main():
    """Run the full automatic processing pipeline for every tomogram."""
    tomograms = get_all_tomograms()
    for tomo in tqdm(tomograms, desc="Process tomograms"):
        process_tomogram(tomo)

    # One-off fix: redo the membrane post-processing for the tomograms where
    # the sophisticated processing failed (see SIMPLE_MEM_POSTPROCESSING).
    # for tomo in tqdm(tomograms, desc="Fix membrane postprocessing"):
    #     if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING:
    #         continue
    #     seg_path = get_seg_path(tomo)
    #     with h5py.File(seg_path, "r") as f:
    #         pred = f["prediction/membrane"][:]
    #     seg = _simple_membrane_postprocessing(pred)
    #     with h5py.File(seg_path, "a") as f:
    #         f["segmentation/membrane"][:] = seg


# Run the processing only when the script is executed directly.
# The previous guard `if __name__:` is always truthy (__name__ is a non-empty
# string), so main() would also have run when this module was imported.
if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion scripts/otoferlin/check_automatic_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def main():
enumerate(tomograms), total=len(tomograms), desc="Visualize automatic segmentation results"
):
print("Checking tomogram", tomogram)
check_automatic_result(tomogram, version)
# check_automatic_result(tomogram, version, segmentation_group="vesicles")
check_automatic_result(tomogram, version, segmentation_group="prediction")
# check_automatic_result(tomogram, version, segmentation_group="prediction")


if __name__:
Expand Down
7 changes: 4 additions & 3 deletions scripts/otoferlin/check_structure_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ def check_structure_postprocessing(mrc_path, center_crop=True):
g = f["segmentation"]
for name in STRUCTURE_NAMES:
segmentations[f"seg/{name}"] = g[name][bb]
colormaps[name] = get_colormaps().get(name, None)

g = f["prediction"]
for name in STRUCTURE_NAMES:
predictions[f"pred/{name}"] = g[name][bb]
colormaps[name] = get_colormaps().get(name, None)

v = napari.Viewer()
v.add_image(tomogram)
for name, seg in segmentations.items():
v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]))
for name, seg in predictions.items():
v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
for name, pred in predictions.items():
v.add_labels(pred, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
v.title = os.path.basename(mrc_path)
napari.run()

Expand Down
5 changes: 4 additions & 1 deletion synapse_net/ground_truth/shape_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def refine_individual_vesicle_shapes(
edge_map: np.ndarray,
foreground_erosion: int = 4,
background_erosion: int = 8,
compactness: float = 0.5,
) -> np.ndarray:
"""Refine vesicle shapes by fitting vesicles to a boundary map.
Expand All @@ -215,6 +216,8 @@ def refine_individual_vesicle_shapes(
You can use `edge_filter` to compute this based on the tomogram.
foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
background_erosion: By how many pixels the background should be eroded in the seeds.
compactness: The compactness parameter passed to the watershed function.
Higher compactness leads to more regular sized vesicles.
Returns:
The refined vesicles.
"""
Expand Down Expand Up @@ -250,7 +253,7 @@ def fit_vesicle(prop):

# Run seeded watershed to fit the shapes.
seeds = fg_seed + 2 * bg_seed
seg[z] = watershed(hmap[z], seeds) == 1
seg[z] = watershed(hmap[z], seeds, compactness=compactness) == 1

# import napari
# v = napari.Viewer()
Expand Down

0 comments on commit 451ff69

Please sign in to comment.