Merge branch 'dev' into new-instance-seg
anwai98 committed Jan 1, 2024
2 parents 908e765 + d6f1432 commit 8de882a
Showing 1 changed file with 54 additions and 18 deletions.
finetuning/livecell_finetuning.py
@@ -1,11 +1,15 @@
-import argparse
 import os
+import argparse
 
-import micro_sam.training as sam_training
 import torch
-import torch_em
 
+import torch_em
+from torch_em.model import UNETR
+from torch_em.loss import DiceBasedDistanceLoss
 from torch_em.data.datasets import get_livecell_loader
+from torch_em.transform.label import PerObjectDistanceTransform
 
+import micro_sam.training as sam_training
 from micro_sam.util import export_custom_sam_model


@@ -20,14 +24,20 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
     I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.
     Important: the ID 0 is reserved for background, and the IDs must be consecutive
     """
-    label_transform = torch_em.transform.label.label_consecutive  # to ensure consecutive IDs
+    label_transform = PerObjectDistanceTransform(
+        distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
+    )
     raw_transform = sam_training.identity  # the current workflow avoids rescaling the inputs to [-1, 1]
-    train_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="train", batch_size=2,
-                                       num_workers=16, cell_types=cell_type, download=True, shuffle=True,
-                                       label_transform=label_transform, raw_transform=raw_transform)
-    val_loader = get_livecell_loader(path=data_path, patch_shape=patch_shape, split="val", batch_size=1,
-                                     num_workers=16, cell_types=cell_type, download=True, shuffle=True,
-                                     label_transform=label_transform, raw_transform=raw_transform)
+    train_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32
+    )
+    val_loader = get_livecell_loader(
+        path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
+        cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
+        raw_transform=raw_transform, label_dtype=torch.float32
+    )
 
     return train_loader, val_loader
 
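Note on the new label transform: `PerObjectDistanceTransform` replaces the plain consecutive-ID relabeling and converts each instance mask into the distance-based targets (foreground plus per-object center and boundary distances) that the new decoder regresses. A minimal sketch of applying it to a toy label image, assuming torch_em label transforms are callable on numpy label arrays; the toy labels below are illustrative, not from the commit:

    import numpy as np
    from torch_em.transform.label import PerObjectDistanceTransform

    # Toy label image: 0 is background, two objects carry IDs 1 and 2 (illustrative only).
    labels = np.zeros((64, 64), dtype="int64")
    labels[8:24, 8:24] = 1
    labels[40:56, 32:60] = 2

    # Same configuration as in the diff above.
    label_transform = PerObjectDistanceTransform(
        distances=True, boundary_distances=True, directed_distances=False,
        foreground=True, instances=True, min_size=25
    )
    target = label_transform(labels)
    # Assumed output layout: stacked channels (foreground, center distances,
    # boundary distances, instance labels), each with spatial shape (64, 64).
    print(target.shape)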
@@ -51,18 +61,37 @@ def finetune_livecell(args):
         checkpoint_path=checkpoint_path,
         freeze=freeze_parts
     )
+    model.to(device)
 
+    # let's get the UNETR model for automatic instance segmentation pipeline
+    unetr = UNETR(
+        backbone="sam",
+        encoder=model.sam.image_encoder,
+        out_channels=3,
+        use_sam_stats=True,
+        final_activation="Sigmoid",
+        use_skip_connection=False
+    )
+    unetr.to(device)
+
+    # let's get the parameters for SAM and the decoder from UNETR
+    joint_model_params = [params for params in model.parameters()]  # sam parameters
+    for name, params in unetr.named_parameters():  # unetr's decoder parameters
+        if not name.startswith("encoder"):
+            joint_model_params.append(params)
+
     # all the stuff we need for training
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
 
     # this class creates all the training data for a batch (inputs, prompts and labels)
     convert_inputs = sam_training.ConvertToSamInputs()
 
-    checkpoint_name = "livecell_sam"
-    # the trainer which performs training and validation (implemented using "torch_em")
-    trainer = sam_training.SamTrainer(
+    checkpoint_name = f"{args.model_type}/livecell_sam"
+
+    # the trainer which performs the joint training and validation (implemented using "torch_em")
+    trainer = sam_training.JointSamTrainer(
         name=checkpoint_name,
         save_root=args.save_root,
         train_loader=train_loader,
@@ -74,16 +103,19 @@ def finetune_livecell(args):
         metric=torch_em.loss.DiceLoss(),
         device=device,
         lr_scheduler=scheduler,
-        logger=sam_training.SamLogger,
+        logger=sam_training.JointSamLogger,
         log_image_interval=100,
         mixed_precision=True,
         convert_inputs=convert_inputs,
         n_objects_per_batch=n_objects_per_batch,
         n_sub_iteration=8,
         compile_model=False,
-        mask_prob=0.5  # (optional) overwrite to provide the probability of using mask inputs while training
+        mask_prob=0.5,  # (optional) overwrite to provide the probability of using mask inputs while training
+        unetr=unetr,
+        instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
+        instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
     )
-    trainer.fit(args.iterations)
+    trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch)
     if args.export_path is not None:
         checkpoint_path = os.path.join(
             "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
@@ -98,7 +130,7 @@ def main():
 def main():
     parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
     parser.add_argument(
-        "--input_path", "-i", default="/scratch/projects/nim00007/data/LiveCELL/",
+        "--input_path", "-i", default="/scratch/usr/nimanwai/data/livecell/",
         help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
     )
     parser.add_argument(
@@ -121,6 +153,10 @@ def main():
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
parser.add_argument(
"--save_every_kth_epoch", type=int, default=None,
help="To save every kth epoch while fine-tuning. Expects an integer value."
)
args = parser.parse_args()
finetune_livecell(args)

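Note on the decoder: the `unetr` built in the diff wraps SAM's ViT image encoder (`encoder=model.sam.image_encoder`) and adds a convolutional decoder with three output channels, matching the three distance-transform target channels. A hypothetical standalone smoke test, assuming torch_em's UNETR also accepts an encoder name string and returns predictions at input resolution:

    import torch
    from torch_em.model import UNETR

    # Built like in the diff, but with a freshly initialized ViT-b encoder instead
    # of the trainable SAM model's encoder (an assumption, for self-containment).
    unetr = UNETR(
        backbone="sam", encoder="vit_b", out_channels=3,
        use_sam_stats=True, final_activation="Sigmoid", use_skip_connection=False
    )
    with torch.no_grad():
        x = torch.randn(1, 3, 1024, 1024)  # SAM's expected input size (assumption)
        pred = unetr(x)
    print(pred.shape)  # expected: torch.Size([1, 3, 1024, 1024])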
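Note on export: the `--export_path` branch (partly collapsed in this view) saves the best checkpoint in the standard SAM format via `export_custom_sam_model`, which is imported at the top of the file. A sketch of the call with hypothetical paths, following the `f"{args.model_type}/livecell_sam"` checkpoint layout from the diff:

    from micro_sam.util import export_custom_sam_model

    # Hypothetical paths; the script derives these from its CLI arguments.
    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b/livecell_sam/best.pt",
        model_type="vit_b",
        save_path="./livecell_sam_vit_b.pth",
    )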