Implement micro-sam training for cochlea data #18

Merged · 2 commits · Jan 16, 2025
3 changes: 3 additions & 0 deletions scripts/training/README.md
@@ -4,6 +4,7 @@ This folder contains the scripts for training a 3D U-Net for cell segmentation i
It contains three relevant scripts:
- `check_training_data.py`, which visualizes the training data and annotations in napari.
- `train_distance_unet.py`, which trains the 3D U-Net.
- `train_micro_sam.py`, which fine-tunes a micro-sam model on the data.

All three scripts accept the argument `-i /path/to/data` to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts consider all tif files in the sub-folders of the root folder.
They will load the **image data** according to the following rules:
@@ -12,3 +13,5 @@ They will load the **image data** according to the following rules:

The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
For further options for the scripts, run `python check_training_data.py -h` / `python train_distance_unet.py -h`.

The script `train_micro_sam.py` works similarly to the U-Net training script. It saves the finetuned model for annotation with `micro_sam` to `checkpoints/cochlea_micro_sam_<CURRENT_DATE>`.
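For example, run `python train_micro_sam.py -i /path/to/data`; `python train_micro_sam.py -h` lists the further options.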
1 change: 0 additions & 1 deletion scripts/training/train_distance_unet.py
@@ -46,7 +46,6 @@ def select_paths(image_paths, label_paths, split, filter_empty):
assert len(image_paths) == len(label_paths)

n_files = len(image_paths)

train_fraction = 0.85

n_train = int(train_fraction * n_files)
71 changes: 71 additions & 0 deletions scripts/training/train_micro_sam.py
@@ -0,0 +1,71 @@
import argparse
from datetime import datetime

import numpy as np
from micro_sam.training import default_sam_loader, train_sam
from train_distance_unet import get_image_and_label_paths, select_paths

ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"


def raw_transform(x):
    # Normalize to [0, 1] based on the 1st and 99th percentile (robust to outliers),
    # then rescale to the value range [0, 255] expected by SAM.
    x = x.astype("float32")
    min_, max_ = np.percentile(x, 1), np.percentile(x, 99)
    x -= min_
    x /= (max_ - min_)
    x = np.clip(x, 0, 1)
    return x * 255


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given, the current date is used."
    )
    parser.add_argument(
        "--n_objects_per_batch", "-n", type=int, default=15,
        help="The number of objects to use during training. Set it to a lower value if you run out of GPU memory. "
        "The default value is 15."
    )
    args = parser.parse_args()

    root = args.root
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
    name = f"cochlea_micro_sam_{run_name}"
    n_objects_per_batch = args.n_objects_per_batch

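    # Collect the annotated crops and split them with the same logic as the U-Net training script.
    # filter_empty=True drops crops without annotations, since SAM training needs objects in every sample.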
    image_paths, label_paths = get_image_and_label_paths(root)
    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

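    # Train on 2D slices: the singleton z-dimension yields 256 x 256 patches sampled from the volumes.
    # max_sampling_attempts bounds how often the loader re-samples to find a patch that contains objects.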
    patch_shape = (1, 256, 256)
    max_sampling_attempts = 2500

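    # The loaders read the tif files directly (raw_key=None / label_key=None) and also provide
    # the targets for the extra segmentation decoder (with_segmentation_decoder=True).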
    train_loader = default_sam_loader(
        raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform,
        num_workers=6, batch_size=1, is_train=True,
        max_sampling_attempts=max_sampling_attempts,
    )
    val_loader = default_sam_loader(
        raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform,
        num_workers=6, batch_size=1, is_train=False,
        max_sampling_attempts=max_sampling_attempts,
    )

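    # Fine-tune vit_b_lm, the base SAM model for light microscopy, for 50 epochs.
    # With save_root=".", checkpoints are written under ./checkpoints/<name>.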
    train_sam(
        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
        save_root=".",
    )


if __name__ == "__main__":
    main()
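The resulting checkpoint can then be loaded for annotation. A minimal sketch, assuming micro_sam's `get_sam_model` utility and torch_em's checkpoint layout (`best.pt` inside the checkpoint folder); the date suffix in the path is a placeholder:

```python
from micro_sam.util import get_sam_model

# Load the finetuned model for use with the micro_sam annotation tools.
# The checkpoint path is hypothetical; the date suffix depends on the training run.
predictor = get_sam_model(
    model_type="vit_b_lm",
    checkpoint_path="checkpoints/cochlea_micro_sam_20250116/best.pt",
)
```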