Skip to content

Commit

Permalink
Minor fix to preprocessing labels
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 25, 2024
1 parent fa4017e commit d7e42c2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions micro_sam/training/sam_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _compute_iou(self, pred, true, eps=1e-7):
return iou

def preprocess_one_hot_masks(self, y_one_hot):
"""
"""Converts the masks to match the "low_res_masks" shape.
"""
# Convert the labels to "low_res_mask" shape
# First step is to use the logic from `ResizeLongestSide` to resize the longest side.
Expand All @@ -143,11 +143,11 @@ def _compute_loss(self, batched_outputs, y_one_hot):
"""
mask_loss, iou_regression_loss = 0.0, 0.0

# TODO
y_one_hot = self.preprocess_one_hot_masks(y_one_hot)

# Loop over the batch.
for batch_output, targets in zip(batched_outputs, y_one_hot):
# Let's convert the inputs to the match the expected "low_res_masks" shape.
targets = self.preprocess_one_hot_masks(targets)

predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
# Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
# We swap the axes that go into the dice loss so that the object axis
Expand Down

0 comments on commit d7e42c2

Please sign in to comment.