diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 21409237..81c7d8ca 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -15,18 +15,15 @@ class TrainableSAM(nn.Module): Args: sam: The SegmentAnything Model. device: The device for training. - upsampled_masks: Whether to return the output masks in the original input shape. """ def __init__( self, sam: Sam, device: Union[str, torch.device], - upsampled_masks: bool = True, ) -> None: super().__init__() self.sam = sam self.device = device - self.upsampled_masks = upsampled_masks self.transform = ResizeLongestSide(sam.image_encoder.img_size) def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -114,16 +111,18 @@ def forward( multimask_output=multimask_output, ) - curr_outputs = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions} - - if self.upsampled_masks: - masks = self.sam.postprocess_masks( - low_res_masks, - input_size=image_record["input_size"], - original_size=image_record["original_size"], - ) - curr_outputs["masks"] = masks + masks = self.sam.postprocess_masks( + low_res_masks, + input_size=image_record["input_size"], + original_size=image_record["original_size"], + ) - outputs.append(curr_outputs) + outputs.append( + { + "low_res_masks": low_res_masks, + "masks": masks, + "iou_predictions": iou_predictions + } + ) return outputs