From 5dbcd7561e4f09394e9389bf9097015fb93bb749 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 1 Jan 2024 19:02:58 +0100 Subject: [PATCH] Simplify embedding precomputation in sam training --- micro_sam/training/sam_trainer.py | 29 +++++++---------------------- micro_sam/training/trainable_sam.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 7b6471430..05ae2493b 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -195,25 +195,17 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y): # Update Masks Iteratively while Training # def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output): - # Precompute the image embeddings only once. - input_images, input_size = self.model.preprocess( - torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device) - ) - image_embeddings = self.model.image_embeddings_oft(input_images) - - # Update the input size for each input in the batch. - for i in range(len(batched_inputs)): - batched_inputs[i]["input_size"] = input_size + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0 - # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch + # This loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch. for i in range(0, num_subiter): - # we do multimasking only in the first sub-iteration as we then pass single prompt - # after the first sub-iteration, we don't do multimasking because we get multiple prompts + # We do multimasking only in the first sub-iteration as we then pass single prompt + # after the first sub-iteration, we don't do multimasking because we get multiple prompts. batched_outputs = self.model(batched_inputs, - multimask_output=multimask_output if i == 0 else False, - image_embeddings=image_embeddings) + image_embeddings=image_embeddings, + multimask_output=multimask_output if i == 0 else False) # we want to average the loss and then backprop over the net sub-iterations net_loss, net_mask_loss, net_iou_regression_loss, net_mean_model_iou = self._get_net_loss(batched_outputs, @@ -376,14 +368,7 @@ def _interactive_val_iteration(self, x, y, val_iteration): n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration) batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples) - input_images, input_size = self.model.preprocess( - torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device) - ) - image_embeddings = self.model.image_embeddings_oft(input_images) - - # Update the input size for each input in the batch. - for i in range(len(batched_inputs)): - batched_inputs[i]["input_size"] = input_size + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) batched_outputs = self.model( batched_inputs, diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 6b773244e..e5c977c09 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -51,8 +51,15 @@ def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: x = F.pad(x, (0, padw, 0, padh)) return x, input_size - def image_embeddings_oft(self, input_images): - """@private""" + def image_embeddings_oft(self, batched_inputs): + # Compute the input images. + input_images, input_size = self.model.preprocess( + torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device) + ) + # Update the input size for each input in the batch. + for i in range(len(batched_inputs)): + batched_inputs[i]["input_size"] = input_size + # Compute the image embeddings. image_embeddings = self.sam.image_encoder(input_images) return image_embeddings