Skip to content

Commit

Permalink
Restore trainable sam
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 25, 2024
1 parent d7e42c2 commit 4cf4420
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions micro_sam/training/trainable_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@ class TrainableSAM(nn.Module):
Args:
sam: The SegmentAnything Model.
device: The device for training.
upsampled_masks: Whether to return the output masks in the original input shape.
"""
def __init__(
self,
sam: Sam,
device: Union[str, torch.device],
upsampled_masks: bool = True,
) -> None:
super().__init__()
self.sam = sam
self.device = device
self.upsampled_masks = upsampled_masks
self.transform = ResizeLongestSide(sam.image_encoder.img_size)

def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
Expand Down Expand Up @@ -114,16 +111,18 @@ def forward(
multimask_output=multimask_output,
)

curr_outputs = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}

if self.upsampled_masks:
masks = self.sam.postprocess_masks(
low_res_masks,
input_size=image_record["input_size"],
original_size=image_record["original_size"],
)
curr_outputs["masks"] = masks
masks = self.sam.postprocess_masks(
low_res_masks,
input_size=image_record["input_size"],
original_size=image_record["original_size"],
)

outputs.append(curr_outputs)
outputs.append(
{
"low_res_masks": low_res_masks,
"masks": masks,
"iou_predictions": iou_predictions
}
)

return outputs

0 comments on commit 4cf4420

Please sign in to comment.