diff --git a/scripts/training/README.md b/scripts/training/README.md
index becdaba..d43ddad 100644
--- a/scripts/training/README.md
+++ b/scripts/training/README.md
@@ -4,6 +4,7 @@ This folder contains the scripts for training a 3D U-Net for cell segmentation i
-It contains two relevant scripts:
+It contains three relevant scripts:
 - `check_training_data.py`, which visualizes the training data and annotations in napari.
 - `train_distance_unet.py`, which trains the 3D U-Net.
+- `train_micro_sam.py`, which fine-tunes a micro-sam model on the data.
 
-Both scripts accept the argument `-i /path/to/data`, to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training.
+All scripts accept the argument `-i /path/to/data` to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training.
 The scripts will consider all tif files in the sub-folders of the root folder for training. They will load the **image data** according to the following rules:
@@ -12,3 +13,5 @@ They will load the **image data** according to the following rules:
 The training script will save the trained model in `checkpoints/cochlea_distance_unet_`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
 
 For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
+
+The script `train_micro_sam.py` works similarly to the U-Net training script. It saves the fine-tuned model for annotation with `micro_sam` to `checkpoints/cochlea_micro_sam_`.
diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py
index fbddb2d..575ed8e 100644
--- a/scripts/training/train_distance_unet.py
+++ b/scripts/training/train_distance_unet.py
@@ -46,7 +46,6 @@ def select_paths(image_paths, label_paths, split, filter_empty):
 
     assert len(image_paths) == len(label_paths)
     n_files = len(image_paths)
-
     train_fraction = 0.85
     n_train = int(train_fraction * n_files)
 
diff --git a/scripts/training/train_micro_sam.py b/scripts/training/train_micro_sam.py
new file mode 100644
index 0000000..c94b18a
--- /dev/null
+++ b/scripts/training/train_micro_sam.py
@@ -0,0 +1,71 @@
+import argparse
+from datetime import datetime
+
+import numpy as np
+from micro_sam.training import default_sam_loader, train_sam
+from train_distance_unet import get_image_and_label_paths, select_paths
+
+ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
+
+
+def raw_transform(x):
+    x = x.astype("float32")
+    min_, max_ = np.percentile(x, 1), np.percentile(x, 99)
+    x -= min_
+    x /= (max_ - min_)
+    x = np.clip(x, 0, 1)
+    return x * 255
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--root", "-i", help="The root folder with the annotated training crops.",
+        default=ROOT_CLUSTER,
+    )
+    parser.add_argument(
+        "--name", help="Optional name for the model to be trained. If not given the current date is used."
+    )
+    parser.add_argument(
+        "--n_objects_per_batch", "-n", type=int, default=15,
+        help="The number of objects to use during training. Set it to a lower value if you run out of GPU memory. "
+        "The default value is 15."
+    )
+    args = parser.parse_args()
+
+    root = args.root
+    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
+    name = f"cochlea_micro_sam_{run_name}"
+    n_objects_per_batch = args.n_objects_per_batch
+
+    image_paths, label_paths = get_image_and_label_paths(root)
+    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
+    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)
+
+    patch_shape = (1, 256, 256)
+    max_sampling_attempts = 2500
+
+    train_loader = default_sam_loader(
+        raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
+        patch_shape=patch_shape, with_segmentation_decoder=True,
+        raw_transform=raw_transform,
+        num_workers=6, batch_size=1, is_train=True,
+        max_sampling_attempts=max_sampling_attempts,
+    )
+    val_loader = default_sam_loader(
+        raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
+        patch_shape=patch_shape, with_segmentation_decoder=True,
+        raw_transform=raw_transform,
+        num_workers=6, batch_size=1, is_train=False,
+        max_sampling_attempts=max_sampling_attempts,
+    )
+
+    train_sam(
+        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
+        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
+        save_root=".",
+    )
+
+
+if __name__ == "__main__":
+    main()
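
Note on `raw_transform`: SAM-style models expect inputs in the `[0, 255]` range, so the transform clips the 1%/99% intensity tails of each crop and rescales. A minimal standalone sketch of the same normalization, using a synthetic `numpy` array as a hypothetical stand-in for a lightsheet crop:

```python
import numpy as np

# Hypothetical stand-in for a lightsheet crop, for illustration only.
volume = np.random.default_rng(0).normal(loc=400.0, scale=120.0, size=(32, 128, 128))

# Percentile normalization as in raw_transform: subtract the 1st percentile,
# divide by the 1%-99% intensity range, clip to [0, 1], rescale to [0, 255].
min_, max_ = np.percentile(volume, 1), np.percentile(volume, 99)
normalized = np.clip((volume - min_) / (max_ - min_), 0.0, 1.0) * 255

print(normalized.min(), normalized.max())  # -> 0.0 255.0
```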
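
On using the result: `train_sam` runs with `save_root="."`, so the underlying trainer should write the best checkpoint to `checkpoints/cochlea_micro_sam_<date>/best.pt`. A sketch of loading it for annotation or inference, assuming the installed `micro_sam` version supports passing `checkpoint_path` to `micro_sam.util.get_sam_model` (the date suffix below is hypothetical):

```python
from micro_sam.util import get_sam_model

# Hypothetical checkpoint path; the suffix depends on the training date or --name.
checkpoint = "checkpoints/cochlea_micro_sam_20250115/best.pt"

# model_type must match the model used for fine-tuning ("vit_b_lm" above).
predictor = get_sam_model(model_type="vit_b_lm", checkpoint_path=checkpoint)
```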