diff --git a/README.md b/README.md index 8b0795ac..5d8897d5 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Please check out [the documentation](https://computational-cell-analytics.github We welcome new contributions! -If you are interested in contributing to micro-sam, please see the [contributing guide](docs/contributing.md) and [developer documentation](docs/development.md). The first step is to [discuss your idea in anew issue](https://github.com/computational-cell-analytics/micro-sam/issues/new) with the current developers. +If you are interested in contributing to micro-sam, please see the [contributing guide](doc/contributing.md) and [developer documentation](doc/development.md). The first step is to [discuss your idea in anew issue](https://github.com/computational-cell-analytics/micro-sam/issues/new) with the current developers. ## Citation diff --git a/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py b/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py index 7d0132b7..9b78d011 100644 --- a/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py +++ b/finetuning/generalists/training/electron_microscopy/obtain_em_datasets.py @@ -7,6 +7,7 @@ from skimage.segmentation import watershed from torch_em import get_data_loader +from torch_em.transform.raw import standardize from torch_em.transform.label import label_consecutive from torch_em.data import ConcatDataset, MinInstanceSampler, datasets @@ -20,8 +21,8 @@ def axondeepseg_label_trafo(labels): return seg -def raw_trafo_for_padding(raw): - desired_shape = (512, 512) +def raw_trafo_for_padding(raw, desired_shape=(512, 512)): + raw = standardize(raw) tmp_ddim = (desired_shape[0] - raw.shape[0], desired_shape[1] - raw.shape[1]) ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) raw = np.pad(raw, @@ -31,8 +32,7 @@ def raw_trafo_for_padding(raw): return raw -def label_trafo_for_padding(labels): - desired_shape = (512, 512) +def label_trafo_for_padding(labels, desired_shape=(512, 512)): labels = label(labels) labels = label_consecutive(labels) tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1]) diff --git a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py new file mode 100644 index 00000000..835417cb --- /dev/null +++ b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py @@ -0,0 +1,158 @@ +import os +import numpy as np +from math import ceil, floor +from typing import Optional, List + +from skimage import measure + +import torch +import torch.utils.data as data_util + +import torch_em +from torch_em.transform.raw import standardize +from torch_em.data import datasets, MinInstanceSampler, ConcatDataset + + +"""NOTE: test sets for in-domain histopathology evaluation + - monuseg test split + - monusac test split + - bcss test samples (split intrinsically - in the new PR) + +length of individual loaders: @all (3 channel input images) + - lizard: train - 718; val - 179 + - bcss: train - 108; val - 28 + - monuseg: train - 30; val - 7 + - monusac: train - 168; val - 41 + - pannuke: train - 1294; val - 680 +""" + + +def _get_train_val_split(ds, val_fraction: float = 0.2): + generator = torch.Generator().manual_seed(42) + train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator) + return train_ds, val_ds + + +class BCSSLabelTrafo: + def __init__(self, label_choices: Optional[List[int]] = None, 
do_connected_components: bool = False): + self.label_choices = label_choices + self.do_connected_components = do_connected_components + + def __call__(self, labels: np.ndarray) -> np.ndarray: + """Returns the transformed bcss data labels (use-case for SAM)""" + if self.label_choices is not None: + labels[~np.isin(labels, self.label_choices)] = 0 + + if self.do_connected_components: + segmentation = measure.label(labels) + else: + segmentation = label_consecutive_trafo(labels) + + return segmentation + + +def raw_padding_trafo(raw, desired_shape=(3, 512, 512)): + assert raw.shape[0] == 3, "The input shape isn't channels first, expected: (3, H, W)" + raw = standardize(raw) + tmp_ddim = (desired_shape[1] - raw.shape[1], desired_shape[2] - raw.shape[2]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + raw = np.pad( + raw, + pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode="reflect" + ) + assert raw.shape == desired_shape + return raw + + +def label_padding_trafo(labels, desired_shape=(512, 512)): + tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + labels = np.pad( + labels, + pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode="reflect" + ) + assert labels.shape == desired_shape + labels = label_consecutive_trafo(labels) + return labels + + +def label_consecutive_trafo(labels): + labels = labels.astype(int) + labels = torch_em.transform.label.label_consecutive(labels) # to ensure consecutive IDs + return labels + + +def get_concat_hp_datasets(path, patch_shape): + label_dtype = torch.int64 + sampler = MinInstanceSampler(min_num_instances=5) + + # make lizard dataset splits into fractions + lizard_ds = datasets.get_lizard_dataset( + path=os.path.join(path, "lizard"), patch_shape=patch_shape, sampler=sampler, label_dtype=label_dtype, + raw_transform=raw_padding_trafo, label_transform=label_padding_trafo + ) + lizard_train_ds, lizard_val_ds = _get_train_val_split(ds=lizard_ds) + lizard_train_ds.ndim = 2 + lizard_val_ds.ndim = 2 + + # get bcss internal splits + bcss_train_ds = datasets.get_bcss_dataset( + path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(), + label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype + ) + bcss_val_ds = datasets.get_bcss_dataset( + path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="val", sampler=MinInstanceSampler(), + label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype + ) + + # make monuseg train dataset splits into fractions + monuseg_ds = datasets.get_monuseg_dataset( + path=os.path.join(path, "monuseg"), patch_shape=patch_shape, split="train", sampler=sampler, + label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype + ) + monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds) + + # make monusac train dataset splits into fractions + monusac_ds = datasets.get_monusac_dataset( + path=os.path.join(path, "monusac"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(), + label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype + ) + monusac_train_ds, monusac_val_ds = _get_train_val_split(ds=monusac_ds) + + # out of three folds (sets of data) of provided data, we use two for training and 1 for validation + pannuke_train_ds = datasets.get_pannuke_dataset( + path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), 
sampler=sampler, folds=["fold_1", "fold_2"], + label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype + ) + pannuke_val_ds = datasets.get_pannuke_dataset( + path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_3"], + label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype + ) + + generalist_hp_train_dataset = ConcatDataset( + lizard_train_ds, bcss_train_ds, monuseg_train_ds, monusac_train_ds, pannuke_train_ds + ) + + generalist_hp_val_dataset = ConcatDataset( + lizard_val_ds, bcss_val_ds, monuseg_val_ds, monusac_val_ds, pannuke_val_ds + ) + + return generalist_hp_train_dataset, generalist_hp_val_dataset + + +def get_generalist_hp_loaders(patch_shape, data_path): + """This returns the concatenated histopathology datasets implemented in `torch_em`: + https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets + It will automatically download all the datasets + + NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits) + in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format. + i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID. + IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive. + """ + generalist_train_dataset, generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape) + train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) + val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) + return train_loader, val_loader diff --git a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py index 1fda098d..893196c4 100644 --- a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py +++ b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py @@ -1,62 +1,50 @@ import os +import argparse -import micro_sam.training as sam_training import torch -import torch_em +from torch_em.loss import DiceLoss -import torch.utils.data as data_util -from torch_em.data.datasets import get_lizard_dataset -from torch_em.data.sampler import MinInstanceSampler +import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model +from obtain_hp_datasets import get_generalist_hp_loaders -# TODO use other datasets than lizard -def get_dataloaders(patch_shape, data_path): - label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs - sampler = MinInstanceSampler(min_num_instances=5) - dataset = get_lizard_dataset( - path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform, - sampler=sampler, - ) - train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1]) - train_loader = torch_em.get_data_loader(train_ds, batch_size=1) - val_loader = torch_em.get_data_loader(val_ds, batch_size=1) - return train_loader, val_loader - -def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None): - """Example code for finetuning SAM on LiveCELL""" +def finetune_hp_generalist(args): + """Example code for finetuning SAM on multiple histopathology datasets""" + # override this 
(below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" # training settings: + model_type = args.model_type checkpoint_path = None # override this to start training from a custom checkpoint - device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu patch_shape = (512, 512) # the patch shape for training - n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled - - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path) + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = None # override this to freeze one or more of these backbones # get the trainable segment anything model - model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device) + model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts) # all the stuff we need for training optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) + train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path) # this class creates all the training data for a batch (inputs, prompts and labels) convert_inputs = sam_training.ConvertToSamInputs() - checkpoint_name = "sam-histopatho-v1" + checkpoint_name = f"generalist-hp-sam-{args.model_type}" # the trainer which performs training and validation (implemented using "torch_em") trainer = sam_training.SamTrainer( name=checkpoint_name, - save_root=save_root, + save_root=args.save_root, train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, # currently we compute loss batch-wise, else we pass channelwise True - loss=torch_em.loss.DiceLoss(channelwise=False), - metric=torch_em.loss.DiceLoss(), + loss=DiceLoss(channelwise=False), + metric=DiceLoss(), device=device, lr_scheduler=scheduler, logger=sam_training.SamLogger, @@ -67,22 +55,42 @@ def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations= n_sub_iteration=8, compile_model=False ) - trainer.fit(iterations) - if export_path is not None: + trainer.fit(iterations=args.iterations) + if args.export_path is not None: checkpoint_path = os.path.join( - "" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt" + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" ) export_custom_sam_model( checkpoint_path=checkpoint_path, model_type=model_type, - save_path=export_path, + save_path=args.export_path, ) def main(): - input_path = "/scratch-grete/projects/nim00007/data/lizard" - export_path = "./sam-vith-histopatho-v1.pth" - finetune_histopatho(input_path, export_path) + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.") + parser.add_argument( + "--input_path", "-i", default="/scratch/usr/nimanwai/data/", + help="The filepath to all the respective hp datasets. If the data does not exist yet it will be downloaded" + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." 
+ ) + parser.add_argument( + "--iterations", type=int, default=int(1e5), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + args = parser.parse_args() + finetune_hp_generalist(args) if __name__ == "__main__": diff --git a/finetuning/generalists/training/histopathology/train_model.sbatch b/finetuning/generalists/training/histopathology/train_model.sbatch index 3a588465..7834d1a2 100755 --- a/finetuning/generalists/training/histopathology/train_model.sbatch +++ b/finetuning/generalists/training/histopathology/train_model.sbatch @@ -1,11 +1,14 @@ #! /bin/bash #SBATCH -c 16 #SBATCH --mem 128G -#SBATCH -t 2800 +#SBATCH -t 7-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 #SBATCH -A nim00007 #SBATCH --constraint=80gb +#SBATCH --qos=7d +#SBATCH --job-name=sam_histopathology -source activate sam +source ~/.bashrc +mamba activate sam python train_histopathology_generalist.py $@ diff --git a/micro_sam/evaluation/livecell.py b/micro_sam/evaluation/livecell.py index ed116519..58a95e07 100644 --- a/micro_sam/evaluation/livecell.py +++ b/micro_sam/evaluation/livecell.py @@ -226,7 +226,7 @@ def run_livecell_inference() -> None: help="Pass the checkpoint-specific model name being used for inference.") # the experiment type: - # - default settings (p1-n0, p2-n4, box) + # - default settings (p1-n0, p2-n4, p4-n8, box) # - full experiment (ranges: p:1-16, n:0-16) # - automatic mask generation (auto) # if none of the two are active then the prompt setting arguments will be parsed