Update trainers to compute loss over low_res_mask #669

Draft · wants to merge 5 commits into master
29 changes: 12 additions & 17 deletions micro_sam/training/joint_sam_trainer.py
@@ -4,7 +4,6 @@
from collections import OrderedDict

import torch
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

@@ -85,8 +84,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop):

with forward_context():
# 1. train for the interactive segmentation
(loss, mask_loss, iou_regression_loss, model_iou,
sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
(loss, mask_loss, iou_regression_loss,
model_iou) = self._interactive_train_iteration(x, labels_instances)

backprop(loss)

@@ -100,10 +99,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop):

if self.logger is not None:
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
self.logger.log_train(
self._iteration, loss, lr, x, labels_instances, samples,
mask_loss, iou_regression_loss, model_iou, unetr_loss
self._iteration, loss, lr, x, labels_instances, mask_loss,
iou_regression_loss, model_iou, unetr_loss
)

self._iteration += 1
@@ -133,7 +131,7 @@ def _validate_impl(self, forward_context):
with forward_context():
# 1. validate for the interactive segmentation
(loss, mask_loss, iou_regression_loss, model_iou,
sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

with forward_context():
# 2. validate for the automatic instance segmentation
@@ -150,7 +148,7 @@

if self.logger is not None:
self.logger.log_validation(
self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
self._iteration, metric_val, loss_val, x, labels_instances,
mask_loss, iou_regression_loss, model_iou_val, unetr_loss
)

@@ -161,25 +159,22 @@ class JointSamLogger(TorchEmLogger):
"""@private"""
def __init__(self, trainer, save_root, **unused_kwargs):
super().__init__(trainer, save_root)
self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
os.path.join(save_root, "logs", trainer.name)
self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
os.makedirs(self.log_dir, exist_ok=True)

self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
self.log_image_interval = trainer.log_image_interval

def add_image(self, x, y, samples, name, step):
def add_image(self, x, y, name, step):
selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

image = normalize_im(x[selection].cpu())

self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

def log_train(
self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
):
self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
@@ -188,15 +183,15 @@ def log_train(
self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
if step % self.log_image_interval == 0:
self.add_image(x, y, samples, "train", step)
self.add_image(x, y, "train", step)

def log_validation(
self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
):
self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
self.add_image(x, y, samples, "validation", step)
self.add_image(x, y, "validation", step)
91 changes: 67 additions & 24 deletions micro_sam/training/sam_trainer.py
@@ -5,10 +5,11 @@
from typing import Optional

import numpy as np

import torch
import torch_em
from torch.nn import functional as F

from torchvision.utils import make_grid
import torch_em
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
@@ -117,6 +118,24 @@ def _compute_iou(self, pred, true, eps=1e-7):
iou = overlap / (union + eps)
return iou

def preprocess_one_hot_masks(self, y_one_hot):
"""Converts the masks to match the "low_res_masks" shape.
"""
# Convert the labels to "low_res_mask" shape
# First step is to use the logic from `ResizeLongestSide` to resize the longest side.
target_length = self.model.transform.target_length
target_shape = self.model.transform.get_preprocess_shape(y_one_hot.shape[2], y_one_hot.shape[3], target_length)
y_one_hot = F.interpolate(input=y_one_hot, size=target_shape)
# Next, we pad the remaining region to (1024, 1024)
h, w = y_one_hot.shape[-2:]
padh = self.model.sam.image_encoder.img_size - h
padw = self.model.sam.image_encoder.img_size - w
y_one_hot = F.pad(input=y_one_hot, pad=(0, padw, 0, padh))
Contributor:
Isn't this just ResizeLongestSide up to this point? Can't we use it instead of duplicating the code?

Contributor (Author):
It's because of one parameter that I have to reimplement this: they set antialias to True, which for the labels here we need to set to False. https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/transforms.py#L63-L65

Contributor (Author):
Also to add: with ResizeLongestSide it would still be a two-stage procedure (taking LIVECell, with images of shape (520, 704), as an example).

Contributor:
I see, makes sense!

# Finally, let's resize the labels to the desired shape (i.e. (256, 256))
y_one_hot = F.interpolate(input=y_one_hot, size=(256, 256))

return y_one_hot

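To make the antialias point from the discussion above concrete, here is a minimal sketch (not part of the PR; the (520, 704) shape follows the LIVECell example, and (756, 1024) is what get_preprocess_shape returns for it with a target length of 1024). Resizing a binary label mask with bilinear interpolation and antialias=True produces fractional values at object borders, while the default nearest-neighbour interpolation used in preprocess_one_hot_masks keeps the mask binary:

import torch
from torch.nn import functional as F

# Hypothetical binary label mask for a LIVECell-sized (520, 704) image.
mask = torch.zeros(1, 1, 520, 704)
mask[..., 100:200, 100:300] = 1.0

# ResizeLongestSide-style resize of the longest side to 1024 -> (756, 1024).
smooth = F.interpolate(mask, size=(756, 1024), mode="bilinear", antialias=True)
crisp = F.interpolate(mask, size=(756, 1024))  # default "nearest", no antialiasing

print(torch.unique(crisp))           # tensor([0., 1.]) -> still a valid binary mask
print(torch.unique(smooth).numel())  # > 2 -> blurred values at the object borders
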
def _compute_loss(self, batched_outputs, y_one_hot):
"""Compute the loss for one iteration. The loss is made up of two components:
- The mask loss: dice score between the predicted masks and targets.
@@ -126,8 +145,10 @@ def _compute_loss(self, batched_outputs, y_one_hot):

# Loop over the batch.
for batch_output, targets in zip(batched_outputs, y_one_hot):
# Let's convert the inputs to match the expected "low_res_masks" shape.
targets = self.preprocess_one_hot_masks(targets)

predicted_objects = torch.sigmoid(batch_output["masks"])
predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
# Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
# We swap the axes that go into the dice loss so that the object axis
# corresponds to the channel axes. This ensures that the dice is computed
@@ -160,10 +181,15 @@
# Functionality for iterative prompting loss
#

def _get_best_masks(self, batched_outputs, batched_iou_predictions):
def _get_best_masks(self, batched_outputs, batched_iou_predictions, input_size, original_size):
# Batched mask and logit (low-res mask) predictions.
masks = torch.stack([m["masks"] for m in batched_outputs])
logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
# masks = torch.stack([m["masks"] for m in batched_outputs])

masks = torch.stack([
self.model.sam.postprocess_masks(log, input_size=input_size, original_size=original_size)
for log in logits
])

# Determine the best IOU across the multi-object prediction axis
# and turn this into a mask we can use for indexing.
@@ -220,7 +246,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
with torch.no_grad():
# Get the mask and logit predictions corresponding to the predicted object
# (per actual object) with the best IOU.
masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
masks, logits = self._get_best_masks(
batched_outputs, batched_iou_predictions,
input_size=batched_inputs[0]["input_size"], original_size=batched_inputs[0]["original_size"]
)
batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)

loss = loss / num_subiter
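
Since the forward pass no longer returns the upsampled masks by default, _get_best_masks now recovers them from the low-res logits via postprocess_masks before the prompts are updated. For orientation, a rough sketch of what that upsampling does, assumed to mirror segment-anything's Sam.postprocess_masks (upsample to the padded encoder resolution, crop the padding, resize to the original image size):

import torch
from torch.nn import functional as F

def upsample_low_res_masks(low_res_masks: torch.Tensor, input_size, original_size, img_size: int = 1024):
    # (B, C, 256, 256) logits -> padded encoder resolution, e.g. (1024, 1024).
    masks = F.interpolate(low_res_masks, (img_size, img_size), mode="bilinear", align_corners=False)
    # Crop away the padding that was added to make the input square.
    masks = masks[..., : input_size[0], : input_size[1]]
    # Resize back to the original image size, e.g. (520, 704).
    return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
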
@@ -300,7 +329,7 @@ def _interactive_train_iteration(self, x, y):
batched_inputs, y_one_hot,
num_subiter=self.n_sub_iteration, multimask_output=multimask_output
)
return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
return loss, mask_loss, iou_regression_loss, model_iou

def _check_input_normalization(self, x, input_check_done):
# The expected data range of the SAM model is 8bit (0-255).
@@ -329,29 +358,46 @@ def _train_epoch_impl(self, progress, forward_context, backprop):

n_iter = 0
t_per_iter = time.time()

import json
bench_iters = 5
tbs = []
tfs = []

for x, y in self.train_loader:
input_check_done = self._check_input_normalization(x, input_check_done)

self.optimizer.zero_grad()

tf = time.time()
with forward_context():
(loss, mask_loss, iou_regression_loss, model_iou,
sampled_binary_y) = self._interactive_train_iteration(x, y)
(loss, mask_loss, iou_regression_loss, model_iou) = self._interactive_train_iteration(x, y)
tf = time.time() - tf

tb = time.time()
backprop(loss)
tb = time.time() - tb

if self.logger is not None:
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
self.logger.log_train(self._iteration, loss, lr, x, y, samples,
mask_loss, iou_regression_loss, model_iou)
self.logger.log_train(self._iteration, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou)

self._iteration += 1
n_iter += 1
if self._iteration >= self.max_iteration:
break
progress.update(1)

tfs.append(tf)
tbs.append(tb)

if n_iter >= bench_iters:
with open("training_times.json", "w") as f:
json.dump({
"forward_pass": tfs, "backward_pass": tbs
}, f)
quit()

t_per_iter = (time.time() - t_per_iter) / n_iter
return t_per_iter

@@ -374,7 +420,7 @@ def _interactive_val_iteration(self, x, y, val_iteration):
metric = mask_loss
model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
return loss, mask_loss, iou_regression_loss, model_iou, metric

def _validate_impl(self, forward_context):
self.model.eval()
@@ -389,8 +435,8 @@ def _validate_impl(self, forward_context):
input_check_done = self._check_input_normalization(x, input_check_done)

with forward_context():
(loss, mask_loss, iou_regression_loss, model_iou,
sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
(loss, mask_loss, iou_regression_loss,
model_iou, metric) = self._interactive_val_iteration(x, y, val_iteration)

loss_val += loss.item()
metric_val += metric.item()
@@ -405,8 +451,7 @@

if self.logger is not None:
self.logger.log_validation(
self._iteration, metric_val, loss_val, x, y,
sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
self._iteration, metric_val, loss_val, x, y, mask_loss, iou_regression_loss, model_iou_val
)

return metric_val
@@ -423,25 +468,23 @@ def __init__(self, trainer, save_root, **unused_kwargs):
self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
self.log_image_interval = trainer.log_image_interval

def add_image(self, x, y, samples, name, step):
def add_image(self, x, y, name, step):
self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
def log_train(self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou):
self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
if step % self.log_image_interval == 0:
self.add_image(x, y, samples, "train", step)
self.add_image(x, y, "train", step)

def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
def log_validation(self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou):
self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
self.add_image(x, y, samples, "validation", step)
self.add_image(x, y, "validation", step)
24 changes: 11 additions & 13 deletions micro_sam/training/trainable_sam.py
@@ -20,11 +20,13 @@ def __init__(
self,
sam: Sam,
device: Union[str, torch.device],
upsample_masks: bool = False,
) -> None:
super().__init__()
self.sam = sam
self.device = device
self.transform = ResizeLongestSide(sam.image_encoder.img_size)
self.upsample_masks = upsample_masks

def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""Resize, normalize pixel values and pad to a square input.
@@ -111,18 +113,14 @@ def forward(
multimask_output=multimask_output,
)

masks = self.sam.postprocess_masks(
low_res_masks,
input_size=image_record["input_size"],
original_size=image_record["original_size"],
)

outputs.append(
{
"low_res_masks": low_res_masks,
"masks": masks,
"iou_predictions": iou_predictions
}
)
this_output = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
if self.upsample_masks:
masks = self.sam.postprocess_masks(
low_res_masks,
input_size=image_record["input_size"],
original_size=image_record["original_size"],
)
this_output["masks"] = masks
outputs.append(this_output)

return outputs
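
A short usage sketch of the new flag (the constructor and forward signatures follow this diff; treat the construction of sam and batched_inputs as assumptions that follow the segment-anything conventions used elsewhere in micro_sam). With upsample_masks left at its default of False, the outputs only contain the low-res logits and IOU predictions, which is all the updated trainers need:

# Minimal sketch: `sam` is a loaded segment-anything model and `batched_inputs` is the
# usual list of per-image dicts expected by TrainableSAM.forward.
trainable_sam = TrainableSAM(sam=sam, device="cuda", upsample_masks=False)
outputs = trainable_sam(batched_inputs, multimask_output=False)

low_res_masks = outputs[0]["low_res_masks"]      # logits, roughly (num_objects, 1, 256, 256)
iou_predictions = outputs[0]["iou_predictions"]
assert "masks" not in outputs[0]                 # full-resolution masks only with upsample_masks=True
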
3 changes: 2 additions & 1 deletion micro_sam/training/training.py
@@ -147,6 +147,7 @@ def train_sam(
scheduler_kwargs: Optional[Dict[str, Any]] = None,
save_every_kth_epoch: Optional[int] = None,
pbar_signals: Optional[QObject] = None,
**kwargs,
) -> None:
"""Run training for a SAM model.

@@ -260,7 +261,7 @@ def train_sam(
lr_scheduler=scheduler,
logger=trainers.SamLogger,
log_image_interval=100,
mixed_precision=True,
mixed_precision=False,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=n_sub_iteration,