Commit

Simplify embedding precomputation in sam training
constantinpape committed Jan 1, 2024
1 parent de3c5bc commit 5dbcd75
Showing 2 changed files with 16 additions and 24 deletions.
29 changes: 7 additions & 22 deletions micro_sam/training/sam_trainer.py
@@ -195,25 +195,17 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
     # Update Masks Iteratively while Training
     #
     def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_subiter, multimask_output):
-        # Precompute the image embeddings only once.
-        input_images, input_size = self.model.preprocess(
-            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device)
-        )
-        image_embeddings = self.model.image_embeddings_oft(input_images)
-
-        # Update the input size for each input in the batch.
-        for i in range(len(batched_inputs)):
-            batched_inputs[i]["input_size"] = input_size
+        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
 
         loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
 
-        # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
+        # This loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch.
         for i in range(0, num_subiter):
-            # we do multimasking only in the first sub-iteration as we then pass single prompt
-            # after the first sub-iteration, we don't do multimasking because we get multiple prompts
+            # We do multimasking only in the first sub-iteration, as we then pass a single prompt;
+            # after the first sub-iteration we don't do multimasking, because we then get multiple prompts.
             batched_outputs = self.model(batched_inputs,
-                                         multimask_output=multimask_output if i == 0 else False,
-                                         image_embeddings=image_embeddings)
+                                         image_embeddings=image_embeddings,
+                                         multimask_output=multimask_output if i == 0 else False)
 
         # we want to average the loss and then backprop over the net sub-iterations
         net_loss, net_mask_loss, net_iou_regression_loss, net_mean_model_iou = self._get_net_loss(batched_outputs,
@@ -376,14 +368,7 @@ def _interactive_val_iteration(self, x, y, val_iteration):
         n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
         batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, n_samples)
 
-        input_images, input_size = self.model.preprocess(
-            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device)
-        )
-        image_embeddings = self.model.image_embeddings_oft(input_images)
-
-        # Update the input size for each input in the batch.
-        for i in range(len(batched_inputs)):
-            batched_inputs[i]["input_size"] = input_size
+        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
 
         batched_outputs = self.model(
             batched_inputs,
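The net effect in the trainer: the heavy image encoder now runs exactly once per batch, and its embeddings are reused across all sub-iterations, so only the prompt encoder and mask decoder run repeatedly. A minimal, self-contained sketch of that precompute-once pattern (toy modules and hypothetical names, not the micro_sam API):

    import torch
    import torch.nn as nn

    def embeddings_once(encoder, batched_inputs, device="cpu"):
        # Stack the per-sample images into one batch and run the heavy encoder a single time.
        images = torch.stack([x["image"] for x in batched_inputs], dim=0).to(device)
        embeddings = encoder(images)
        # Stamp the input size on every sample, mirroring what image_embeddings_oft does.
        for sample in batched_inputs:
            sample["input_size"] = tuple(images.shape[-2:])
        return embeddings, batched_inputs

    encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for sam.image_encoder
    decoder = nn.Conv2d(8, 1, kernel_size=1)             # stand-in for the mask decoder
    batched_inputs = [{"image": torch.rand(3, 64, 64)} for _ in range(2)]

    embeddings, batched_inputs = embeddings_once(encoder, batched_inputs)
    for subiter in range(3):
        # Each sub-iteration reuses the same embeddings; only the decoder runs again.
        masks = decoder(embeddings)
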
11 changes: 9 additions & 2 deletions micro_sam/training/trainable_sam.py
@@ -51,8 +51,15 @@ def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
         x = F.pad(x, (0, padw, 0, padh))
         return x, input_size
 
-    def image_embeddings_oft(self, input_images):
+    def image_embeddings_oft(self, batched_inputs):
         """@private"""
+        # Compute the input images.
+        input_images, input_size = self.preprocess(
+            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device)
+        )
+        # Update the input size for each input in the batch.
+        for i in range(len(batched_inputs)):
+            batched_inputs[i]["input_size"] = input_size
+        # Compute the image embeddings.
         image_embeddings = self.sam.image_encoder(input_images)
-        return image_embeddings
+        return image_embeddings, batched_inputs

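The context lines in the preprocess hunk (`x = F.pad(x, (0, padw, 0, padh))`, `return x, input_size`) are the tail of SAM-style preprocessing: the resized image is padded on the right and bottom to the encoder's square input, and the pre-padding size is returned so it can be stamped onto each sample as "input_size". A hedged sketch of that padding step, assuming the standard SAM encoder size of 1024 pixels (a common SAM default, not stated in this diff):

    import torch
    import torch.nn.functional as F

    def pad_to_square(x: torch.Tensor, img_size: int = 1024):
        # x is a (C, H, W) image whose longest side has already been resized to img_size.
        input_size = tuple(x.shape[-2:])    # remembered as "input_size" for un-padding later
        padh = img_size - x.shape[-2]
        padw = img_size - x.shape[-1]
        x = F.pad(x, (0, padw, 0, padh))    # pad right and bottom only
        return x, input_size

    x, input_size = pad_to_square(torch.rand(3, 768, 1024))
    assert x.shape == (3, 1024, 1024) and input_size == (768, 1024)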
