diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py index 08ab8c39..135ea7de 100644 --- a/micro_sam/training/joint_sam_trainer.py +++ b/micro_sam/training/joint_sam_trainer.py @@ -4,7 +4,6 @@ from collections import OrderedDict import torch -from torchvision.utils import make_grid from .sam_trainer import SamTrainer @@ -85,8 +84,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop): with forward_context(): # 1. train for the interactive segmentation - (loss, mask_loss, iou_regression_loss, model_iou, - sampled_binary_y) = self._interactive_train_iteration(x, labels_instances) + (loss, mask_loss, iou_regression_loss, + model_iou) = self._interactive_train_iteration(x, labels_instances) backprop(loss) @@ -100,10 +99,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop): if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None self.logger.log_train( - self._iteration, loss, lr, x, labels_instances, samples, - mask_loss, iou_regression_loss, model_iou, unetr_loss + self._iteration, loss, lr, x, labels_instances, mask_loss, + iou_regression_loss, model_iou, unetr_loss ) self._iteration += 1 @@ -133,7 +131,7 @@ def _validate_impl(self, forward_context): with forward_context(): # 1. validate for the interactive segmentation (loss, mask_loss, iou_regression_loss, model_iou, - sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration) + metric) = self._interactive_val_iteration(x, labels_instances, val_iteration) with forward_context(): # 2. 
validate for the automatic instance segmentation @@ -150,7 +148,7 @@ def _validate_impl(self, forward_context): if self.logger is not None: self.logger.log_validation( - self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y, + self._iteration, metric_val, loss_val, x, labels_instances, mask_loss, iou_regression_loss, model_iou_val, unetr_loss ) @@ -161,25 +159,22 @@ class JointSamLogger(TorchEmLogger): """@private""" def __init__(self, trainer, save_root, **unused_kwargs): super().__init__(trainer, save_root) - self.log_dir = f"./logs/{trainer.name}" if save_root is None else\ - os.path.join(save_root, "logs", trainer.name) + self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name) os.makedirs(self.log_dir, exist_ok=True) self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) self.log_image_interval = trainer.log_image_interval - def add_image(self, x, y, samples, name, step): + def add_image(self, x, y, name, step): selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] image = normalize_im(x[selection].cpu()) self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step) self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step) - sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4) - self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step) def log_train( - self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss + self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss ): self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step) @@ -188,10 +183,10 @@ def log_train( self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step) self.tb.add_scalar(tag="train/learning_rate", 
scalar_value=lr, global_step=step) if step % self.log_image_interval == 0: - self.add_image(x, y, samples, "train", step) + self.add_image(x, y, "train", step) def log_validation( - self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss + self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss ): self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step) @@ -199,4 +194,4 @@ def log_validation( self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step) self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step) self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) - self.add_image(x, y, samples, "validation", step) + self.add_image(x, y, "validation", step) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index c251e749..594cdab0 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -5,10 +5,11 @@ from typing import Optional import numpy as np + import torch -import torch_em +from torch.nn import functional as F -from torchvision.utils import make_grid +import torch_em from torch_em.trainer.logger_base import TorchEmLogger from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator @@ -117,6 +118,24 @@ def _compute_iou(self, pred, true, eps=1e-7): iou = overlap / (union + eps) return iou + def preprocess_one_hot_masks(self, y_one_hot): + """Converts the masks to match the "low_res_masks" shape. + """ + # Convert the labels to "low_res_mask" shape + # First step is to use the logic from `ResizeLongestSide` to resize the longest side. 
+ target_length = self.model.transform.target_length + target_shape = self.model.transform.get_preprocess_shape(y_one_hot.shape[2], y_one_hot.shape[3], target_length) + y_one_hot = F.interpolate(input=y_one_hot, size=target_shape) + # Next, we pad the remaining region to (1024, 1024) + h, w = y_one_hot.shape[-2:] + padh = self.model.sam.image_encoder.img_size - h + padw = self.model.sam.image_encoder.img_size - w + y_one_hot = F.pad(input=y_one_hot, pad=(0, padw, 0, padh)) + # Finally, let's resize the labels to the desired shape (i.e. (256, 256)) + y_one_hot = F.interpolate(input=y_one_hot, size=(256, 256)) + + return y_one_hot + def _compute_loss(self, batched_outputs, y_one_hot): """Compute the loss for one iteration. The loss is made up of two components: - The mask loss: dice score between the predicted masks and targets. @@ -126,8 +145,10 @@ def _compute_loss(self, batched_outputs, y_one_hot): # Loop over the batch. for batch_output, targets in zip(batched_outputs, y_one_hot): + # Let's convert the inputs to match the expected "low_res_masks" shape. + targets = self.preprocess_one_hot_masks(targets) - predicted_objects = torch.sigmoid(batch_output["masks"]) + predicted_objects = torch.sigmoid(batch_output["low_res_masks"]) # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop). # We swap the axes that go into the dice loss so that the object axis # corresponds to the channel axes. This ensures that the dice is computed @@ -160,10 +181,15 @@ def _compute_loss(self, batched_outputs, y_one_hot): # Functionality for iterative prompting loss # - def _get_best_masks(self, batched_outputs, batched_iou_predictions): + def _get_best_masks(self, batched_outputs, batched_iou_predictions, input_size, original_size): # Batched mask and logit (low-res mask) predictions. 
- masks = torch.stack([m["masks"] for m in batched_outputs]) logits = torch.stack([m["low_res_masks"] for m in batched_outputs]) + # masks = torch.stack([m["masks"] for m in batched_outputs]) + + masks = torch.stack([ + self.model.sam.postprocess_masks(log, input_size=input_size, original_size=original_size) + for log in logits + ]) # Determine the best IOU across the multi-object prediction axis # and turn this into a mask we can use for indexing. @@ -220,7 +246,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim with torch.no_grad(): # Get the mask and logit predictions corresponding to the predicted object # (per actual object) with the best IOU. - masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) + masks, logits = self._get_best_masks( + batched_outputs, batched_iou_predictions, + input_size=batched_inputs[0]["input_size"], original_size=batched_inputs[0]["original_size"] + ) batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) loss = loss / num_subiter @@ -300,7 +329,7 @@ def _interactive_train_iteration(self, x, y): batched_inputs, y_one_hot, num_subiter=self.n_sub_iteration, multimask_output=multimask_output ) - return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot + return loss, mask_loss, iou_regression_loss, model_iou def _check_input_normalization(self, x, input_check_done): # The expected data range of the SAM model is 8bit (0-255). 
@@ -329,22 +358,29 @@ def _train_epoch_impl(self, progress, forward_context, backprop): n_iter = 0 t_per_iter = time.time() + + import json + bench_iters = 5 + tbs = [] + tfs = [] + for x, y in self.train_loader: input_check_done = self._check_input_normalization(x, input_check_done) self.optimizer.zero_grad() + tf = time.time() with forward_context(): - (loss, mask_loss, iou_regression_loss, model_iou, - sampled_binary_y) = self._interactive_train_iteration(x, y) + (loss, mask_loss, iou_regression_loss, model_iou) = self._interactive_train_iteration(x, y) + tf = time.time() - tf + tb = time.time() backprop(loss) + tb = time.time() - tb if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None - self.logger.log_train(self._iteration, loss, lr, x, y, samples, - mask_loss, iou_regression_loss, model_iou) + self.logger.log_train(self._iteration, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou) self._iteration += 1 n_iter += 1 @@ -352,6 +388,16 @@ def _train_epoch_impl(self, progress, forward_context, backprop): break progress.update(1) + tfs.append(tf) + tbs.append(tb) + + if n_iter >= bench_iters: + with open("training_times.json", "w") as f: + json.dump({ + "forward_pass": tfs, "backward_pass": tbs + }, f) + quit() + t_per_iter = (time.time() - t_per_iter) / n_iter return t_per_iter @@ -374,7 +420,7 @@ def _interactive_val_iteration(self, x, y, val_iteration): metric = mask_loss model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs])) - return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric + return loss, mask_loss, iou_regression_loss, model_iou, metric def _validate_impl(self, forward_context): self.model.eval() @@ -389,8 +435,8 @@ def _validate_impl(self, forward_context): input_check_done = self._check_input_normalization(x, input_check_done) with forward_context(): - (loss, mask_loss, 
iou_regression_loss, model_iou, - sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration) + (loss, mask_loss, iou_regression_loss, + model_iou, metric) = self._interactive_val_iteration(x, y, val_iteration) loss_val += loss.item() metric_val += metric.item() @@ -405,8 +451,7 @@ def _validate_impl(self, forward_context): if self.logger is not None: self.logger.log_validation( - self._iteration, metric_val, loss_val, x, y, - sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val + self._iteration, metric_val, loss_val, x, y, mask_loss, iou_regression_loss, model_iou_val ) return metric_val @@ -423,25 +468,23 @@ def __init__(self, trainer, save_root, **unused_kwargs): self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) self.log_image_interval = trainer.log_image_interval - def add_image(self, x, y, samples, name, step): + def add_image(self, x, y, name, step): self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step) self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step) - sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4) - self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step) - def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou): + def log_train(self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou): self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step) self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step) self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step) self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) if step % self.log_image_interval == 0: - self.add_image(x, y, samples, "train", step) + self.add_image(x, y, "train", step) - def log_validation(self, step, metric, 
loss, x, y, samples, mask_loss, iou_regression_loss, model_iou): + def log_validation(self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou): self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step) self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step) self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step) self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) - self.add_image(x, y, samples, "validation", step) + self.add_image(x, y, "validation", step) diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 81c7d8ca..42b12025 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -20,11 +20,13 @@ def __init__( self, sam: Sam, device: Union[str, torch.device], + upsample_masks: bool = False, ) -> None: super().__init__() self.sam = sam self.device = device self.transform = ResizeLongestSide(sam.image_encoder.img_size) + self.upsample_masks = upsample_masks def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: """Resize, normalize pixel values and pad to a square input. 
@@ -111,18 +113,14 @@ def forward( multimask_output=multimask_output, ) - masks = self.sam.postprocess_masks( - low_res_masks, - input_size=image_record["input_size"], - original_size=image_record["original_size"], - ) - - outputs.append( - { - "low_res_masks": low_res_masks, - "masks": masks, - "iou_predictions": iou_predictions - } - ) + this_output = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions} + if self.upsample_masks: + masks = self.sam.postprocess_masks( + low_res_masks, + input_size=image_record["input_size"], + original_size=image_record["original_size"], + ) + this_output["masks"] = masks + outputs.append(this_output) return outputs diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 9e2bbafb..a0893622 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -147,6 +147,7 @@ def train_sam( scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, + **kwargs, ) -> None: """Run training for a SAM model. @@ -260,7 +261,7 @@ def train_sam( lr_scheduler=scheduler, logger=trainers.SamLogger, log_image_interval=100, - mixed_precision=True, + mixed_precision=False, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration,