Implement micro-sam training script
constantinpape committed Jan 15, 2025
1 parent 170e449 commit 2fd0344
Showing 3 changed files with 75 additions and 1 deletion.
3 changes: 3 additions & 0 deletions scripts/training/README.md
@@ -4,6 +4,7 @@ This folder contains the scripts for training a 3D U-Net for cell segmentation i
It contains two relevant scripts:
- `check_training_data.py`, which visualizes the training data and annotations in napari.
- `train_distance_unet.py`, which trains the 3D U-Net.
- `train_micro_sam.py`, which fine-tunes a micro-sam model on the data.

Both scripts accept the argument `-i /path/to/data`, to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts will consider all tif files in the sub-folders of the root folder for training.
They will load the **image data** according to the following rules:
@@ -12,3 +13,5 @@ They will load the **image data** according to the following rules:

The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.

The script `train_micro_sam.py` works similarly to the U-Net training script. It saves the fine-tuned model for annotation with `micro_sam` to `checkpoints/`.
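
A minimal sketch of how the fine-tuned checkpoint could then be loaded for annotation or prediction (not part of this commit; it assumes the usual `torch_em` checkpoint layout `checkpoints/<run_name>/best.pt` and an illustrative run name):

```python
# Sketch: load the fine-tuned SAM model for downstream annotation / prediction.
# The checkpoint path and run name are illustrative and depend on your training run.
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b_lm",  # must match the model type used for fine-tuning
    checkpoint_path="checkpoints/cochlea_distance_unet_20250115/best.pt",
)
# The returned predictor can then be used with the micro_sam annotation tools.
```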
1 change: 0 additions & 1 deletion scripts/training/train_distance_unet.py
@@ -46,7 +46,6 @@ def select_paths(image_paths, label_paths, split, filter_empty):
    assert len(image_paths) == len(label_paths)

    n_files = len(image_paths)

    train_fraction = 0.85

    n_train = int(train_fraction * n_files)
72 changes: 72 additions & 0 deletions scripts/training/train_micro_sam.py
@@ -0,0 +1,72 @@
import argparse
from datetime import datetime

import numpy as np
import torch_em
from micro_sam.training import default_sam_loader, train_sam
from train_distance_unet import get_image_and_label_paths, select_paths

ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"


def raw_transform(x):
    # Normalize the raw intensities to the range [0, 255],
    # using the 1st and 99th percentile for robust scaling.
    x = x.astype("float32")
    min_, max_ = np.percentile(x, 1), np.percentile(x, 99)
    x -= min_
    x /= max_
    x = np.clip(x, 0, 1)
    return x * 255


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given, the current date is used."
    )
    parser.add_argument(
        "--n_objects_per_batch", "-n", type=int, default=15,
        help="The number of objects per batch to use during training. "
             "Set it to a lower value if you run out of GPU memory. The default value is 15."
    )
    args = parser.parse_args()

    root = args.root
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
    name = f"cochlea_distance_unet_{run_name}"
    n_objects_per_batch = args.n_objects_per_batch

    image_paths, label_paths = get_image_and_label_paths(root)
    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

    # Train on 2D patches (singleton z-dimension). The sampler rejects patches that do not
    # contain enough sufficiently large instances; sampling is retried up to 2500 times.
    patch_shape = (1, 256, 256)
    sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10)
    max_sampling_attempts = 2500

    train_loader = default_sam_loader(
        raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform, sampler=sampler, min_size=10,
        num_workers=6, batch_size=1, is_train=True,
        max_sampling_attempts=max_sampling_attempts,
    )
    val_loader = default_sam_loader(
        raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform, sampler=sampler, min_size=10,
        num_workers=6, batch_size=1, is_train=False,
        max_sampling_attempts=max_sampling_attempts,
    )

    # Fine-tune the light microscopy generalist model (vit_b_lm).
    train_sam(
        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
    )


if __name__ == "__main__":
    main()
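
A small sketch (not part of the commit) showing how one might inspect a few sampled training patches before launching a long fine-tuning run; it assumes the helper functions from the two training scripts are importable, that `torch_em`'s `check_loader` debugging utility is available, and `/path/to/data` is a placeholder for the training root:

```python
# Sketch: inspect a few patches produced by the SAM training loader.
from micro_sam.training import default_sam_loader
from torch_em.data.sampler import MinInstanceSampler
from torch_em.util.debug import check_loader

from train_distance_unet import get_image_and_label_paths, select_paths
from train_micro_sam import raw_transform

image_paths, label_paths = get_image_and_label_paths("/path/to/data")  # adjust the data root
train_images, train_labels = select_paths(image_paths, label_paths, split="train", filter_empty=True)

loader = default_sam_loader(
    raw_paths=train_images, raw_key=None, label_paths=train_labels, label_key=None,
    patch_shape=(1, 256, 256), with_segmentation_decoder=True,
    raw_transform=raw_transform, sampler=MinInstanceSampler(2, min_size=10), min_size=10,
    num_workers=0, batch_size=1, is_train=True,
)
check_loader(loader, 4)  # show raw data and labels for four sampled patches
```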
