From c236df3edbc4f4766b5e4e358dd09a81d2e8b973 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Wed, 15 Nov 2023 20:02:27 +0100
Subject: [PATCH] Experiments with Mask Inputs Probability (#265)

Use mask prompts stochastically in finetuning
---
 .../evaluation/iterative_prompting.py |  9 +++++++--
 finetuning/livecell_finetuning.py     |  7 ++++---
 micro_sam/training/sam_trainer.py     | 20 +++++++++++++++----
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/finetuning/livecell/evaluation/iterative_prompting.py b/finetuning/livecell/evaluation/iterative_prompting.py
index d188240e..d156afd1 100644
--- a/finetuning/livecell/evaluation/iterative_prompting.py
+++ b/finetuning/livecell/evaluation/iterative_prompting.py
@@ -5,7 +5,7 @@
 from micro_sam.evaluation import inference
 from micro_sam.evaluation.evaluation import run_evaluation
 
-from util import get_paths, get_checkpoint
+from util import get_paths, get_checkpoint, MODELS
 
 LIVECELL_GT_ROOT = "/scratch/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
 PREDICTION_ROOT = "/scratch/projects/nim00007/sam/iterative_evaluation"
@@ -78,7 +78,11 @@ def main(args):
     prediction_root = get_prediction_root(start_with_box_prompt, model_description)
 
     # get the model checkpoints and desired model name to initialize the predictor
-    checkpoint, model_type = get_checkpoint(model_description)
+    if args.checkpoint is None and model_description in MODELS.keys():
+        checkpoint, model_type = get_checkpoint(model_description)
+    else:
+        checkpoint = args.checkpoint
+        model_type = model_description[:5]
 
     # get the predictor to perform inference
     predictor = inference.get_predictor(checkpoint, model_type)
@@ -94,5 +98,6 @@ def main(args):
         "-m", "--model", type=str,  # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
         help="Provide the model type to initialize the predictor"
     )
+    parser.add_argument("-c", "--checkpoint", type=str, default=None)
     args = parser.parse_args()
     main(args)
diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index 9c533391..686a6fae 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -42,7 +42,7 @@ def finetune_livecell(args):
     n_objects_per_batch = 25  # this is the number of objects per batch that will be sampled
 
     # get the trainable segment anything model
-    model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
+    model = sam_training.get_trainable_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)
 
     # all the stuff we need for training
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
@@ -72,7 +72,8 @@
         convert_inputs=convert_inputs,
         n_objects_per_batch=n_objects_per_batch,
         n_sub_iteration=8,
-        compile_model=False
+        compile_model=False,
+        mask_prob=0.5  # (optional) overwrite to provide the probability of using mask inputs while training
     )
     trainer.fit(args.iterations)
     if args.export_path is not None:
@@ -89,7 +90,7 @@ def main():
     parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
     parser.add_argument(
-        "--input_path", "-i", default="",
+        "--input_path", "-i", default="/scratch/projects/nim00007/data/LiveCELL/",
         help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
     )
     parser.add_argument(
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 85d8f02c..1c0a290c 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -1,5 +1,6 @@
 import os
 import time
+import random
 from typing import Optional
 
 import numpy as np
@@ -9,7 +10,7 @@
 from torchvision.utils import make_grid
 from torch_em.trainer.logger_base import TorchEmLogger
 
-from ..prompt_generators import IterativePromptGenerator
+from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 
 
 class SamTrainer(torch_em.trainer.DefaultTrainer):
@@ -20,7 +21,7 @@ class SamTrainer(torch_em.trainer.DefaultTrainer):
     for details on its usage and implementation.
 
     Args:
-        convert_inputs: Class that converts the output of the dataloader to the expected input format of SAM.
+        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
             The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
         n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
             In each sub-iteration new point prompts are sampled where the model was wrong.
@@ -28,6 +29,8 @@ class SamTrainer(torch_em.trainer.DefaultTrainer):
             Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
         mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
         sigmoid: The activation function for normalizing the model output.
+        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
+        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
         **kwargs: The keyword arguments of the DefaultTrainer super class.
     """
 
@@ -38,7 +41,8 @@ def __init__(
         self,
         convert_inputs,
         n_sub_iteration: int,
         n_objects_per_batch: Optional[int] = None,
         mse_loss: torch.nn.Module = torch.nn.MSELoss(),
         _sigmoid: torch.nn.Module = torch.nn.Sigmoid(),
-        prompt_generator=IterativePromptGenerator(),
+        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
+        mask_prob: float = 0.5,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -48,6 +52,7 @@ def __init__(
         self.n_objects_per_batch = n_objects_per_batch
         self.n_sub_iteration = n_sub_iteration
         self.prompt_generator = prompt_generator
+        self.mask_prob = mask_prob
         self._kwargs = kwargs
 
     def _get_prompt_and_multimasking_choices(self, current_iteration):
@@ -250,7 +255,14 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
 
             _inp["point_coords"] = updated_point_coords
             _inp["point_labels"] = updated_point_labels
-            _inp["mask_inputs"] = logits
+
+            if self.mask_prob > 0:
+                # using mask inputs for iterative prompting while training, with a probability
+                use_mask_inputs = (random.random() < self.mask_prob)
+                if use_mask_inputs:
+                    _inp["mask_inputs"] = logits
+                else:  # remove previously existing mask inputs to avoid using them in next sub-iteration
+                    _inp.pop("mask_inputs", None)
 
     #
     # Training Loop
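
Note on the behavior introduced above: during the iterative prompting sub-iterations, the trainer now compares a uniform random draw against `mask_prob` to decide whether the previous low-resolution mask logits are fed back to SAM as a mask prompt. A minimal standalone sketch of that decision follows; the helper name is hypothetical and not part of the patch.

    import random

    def choose_mask_prompt(batched_input, logits, mask_prob=0.5):
        # Mirrors the logic added to SamTrainer._get_updated_points_per_mask_per_subiter:
        # with probability `mask_prob` the previous low-res logits are passed back to SAM
        # as a mask prompt for the next sub-iteration; otherwise any stale mask input is
        # dropped so only the freshly sampled point prompts are used.
        if mask_prob > 0:
            use_mask_inputs = random.random() < mask_prob
            if use_mask_inputs:
                batched_input["mask_inputs"] = logits
            else:
                batched_input.pop("mask_inputs", None)
        return batched_input

With the `mask_prob=0.5` used in `livecell_finetuning.py`, roughly half of the sub-iterations receive a mask prompt in addition to the sampled point prompts, while `mask_prob=0` falls back to point-only iterative prompting.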