Debug training WIP

constantinpape committed Jul 29, 2024
1 parent 4cf4420 commit fb9219a
Showing 3 changed files with 44 additions and 17 deletions.
micro_sam/training/sam_trainer.py (31 additions & 3 deletions)
@@ -181,10 +181,15 @@ def _compute_loss(self, batched_outputs, y_one_hot):
     # Functionality for iterative prompting loss
     #
 
-    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
+    def _get_best_masks(self, batched_outputs, batched_iou_predictions, input_size, original_size):
         # Batched mask and logit (low-res mask) predictions.
-        masks = torch.stack([m["masks"] for m in batched_outputs])
         logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
+        # masks = torch.stack([m["masks"] for m in batched_outputs])
+
+        masks = torch.stack([
+            self.model.sam.postprocess_masks(log, input_size=input_size, original_size=original_size)
+            for log in logits
+        ])
 
         # Determine the best IOU across the multi-object prediction axis
         # and turn this into a mask we can use for indexing.
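The rewritten _get_best_masks recomputes the full-resolution masks from the low-res logits via postprocess_masks, instead of reading precomputed masks out of batched_outputs (which the forward pass no longer returns by default, see trainable_sam.py below). For context, segment_anything's postprocess_masks roughly does the following; this is a simplified sketch rather than the library code, with img_size being the encoder's square input size (1024 for the standard SAM models):

    import torch.nn.functional as F

    def postprocess_masks_sketch(low_res_masks, input_size, original_size, img_size=1024):
        # Upsample from the low-res mask grid (256x256 for SAM) to the padded square input.
        masks = F.interpolate(low_res_masks, (img_size, img_size), mode="bilinear", align_corners=False)
        # Crop away the padding that was added to make the resized image square.
        masks = masks[..., : input_size[0], : input_size[1]]
        # Resize to the resolution of the original input image.
        return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)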
@@ -241,7 +246,10 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim
                 with torch.no_grad():
                     # Get the mask and logit predictions corresponding to the predicted object
                     # (per actual object) with the best IOU.
-                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
+                    masks, logits = self._get_best_masks(
+                        batched_outputs, batched_iou_predictions,
+                        input_size=batched_inputs[0]["input_size"], original_size=batched_inputs[0]["original_size"]
+                    )
                     batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
 
         loss = loss / num_subiter
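Note that input_size and original_size are taken from batched_inputs[0], which assumes all images in the batch share the same preprocessed and original shapes. Each element of batched_inputs is a dict following the batched-input convention used by TrainableSAM; a sketch of the keys relevant here, with example values that are purely illustrative:

    import torch

    batched_input = {
        "image": torch.zeros(3, 1024, 1024),  # preprocessed, padded image tensor
        "input_size": (1024, 683),            # shape after resizing, before padding to a square
        "original_size": (1536, 1024),        # shape of the raw input image
        # plus prompt entries such as "point_coords", "point_labels" or "boxes"
    }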
@@ -350,15 +358,25 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
         n_iter = 0
         t_per_iter = time.time()
+
+        import json
+        bench_iters = 5
+        tbs = []
+        tfs = []
+
         for x, y in self.train_loader:
             input_check_done = self._check_input_normalization(x, input_check_done)
 
             self.optimizer.zero_grad()
 
+            tf = time.time()
             with forward_context():
                 (loss, mask_loss, iou_regression_loss, model_iou) = self._interactive_train_iteration(x, y)
+            tf = time.time() - tf
 
+            tb = time.time()
             backprop(loss)
+            tb = time.time() - tb
 
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
@@ -370,6 +388,16 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
                     break
             progress.update(1)
 
+            tfs.append(tf)
+            tbs.append(tb)
+
+            if n_iter >= bench_iters:
+                with open("training_times.json", "w") as f:
+                    json.dump({
+                        "forward_pass": tfs, "backward_pass": tbs
+                    }, f)
+                quit()
+
         t_per_iter = (time.time() - t_per_iter) / n_iter
         return t_per_iter

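One caveat about these wall-clock measurements (an observation, not part of the commit): CUDA kernels are launched asynchronously, so timing the forward call with time.time() alone can shift GPU work into the backward timing unless the device is synchronized first. A minimal sketch of a synchronized timing helper, assuming training runs on a CUDA device:

    import time
    import torch

    def timed_call(fn, *args, **kwargs):
        # Flush pending GPU work so it is not attributed to this call.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        result = fn(*args, **kwargs)
        # Wait for the kernels launched by fn before stopping the clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return result, time.time() - start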
micro_sam/training/trainable_sam.py (11 additions & 13 deletions)
@@ -20,11 +20,13 @@ def __init__(
         self,
         sam: Sam,
         device: Union[str, torch.device],
+        upsample_masks: bool = False,
     ) -> None:
         super().__init__()
         self.sam = sam
         self.device = device
         self.transform = ResizeLongestSide(sam.image_encoder.img_size)
+        self.upsample_masks = upsample_masks
 
     def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
         """Resize, normalize pixel values and pad to a square input.
@@ -111,18 +113,14 @@ def forward(
                 multimask_output=multimask_output,
             )
 
-            masks = self.sam.postprocess_masks(
-                low_res_masks,
-                input_size=image_record["input_size"],
-                original_size=image_record["original_size"],
-            )
-
-            outputs.append(
-                {
-                    "low_res_masks": low_res_masks,
-                    "masks": masks,
-                    "iou_predictions": iou_predictions
-                }
-            )
+            this_output = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
+            if self.upsample_masks:
+                masks = self.sam.postprocess_masks(
+                    low_res_masks,
+                    input_size=image_record["input_size"],
+                    original_size=image_record["original_size"],
+                )
+                this_output["masks"] = masks
+            outputs.append(this_output)
 
         return outputs
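A minimal usage sketch for the new flag; loading the checkpoint from a local path is assumed here:

    from segment_anything import sam_model_registry
    from micro_sam.training.trainable_sam import TrainableSAM

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # local checkpoint, path assumed

    # With upsample_masks=False (the new default) the forward pass returns only
    # "low_res_masks" and "iou_predictions"; full-resolution masks are computed
    # on demand via sam.postprocess_masks, as _get_best_masks now does.
    model = TrainableSAM(sam, device="cuda", upsample_masks=False)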
micro_sam/training/training.py (2 additions & 1 deletion)
@@ -147,6 +147,7 @@ def train_sam(
     scheduler_kwargs: Optional[Dict[str, Any]] = None,
     save_every_kth_epoch: Optional[int] = None,
     pbar_signals: Optional[QObject] = None,
+    **kwargs,
 ) -> None:
     """Run training for a SAM model.
@@ -260,7 +261,7 @@ def train_sam(
         lr_scheduler=scheduler,
         logger=trainers.SamLogger,
         log_image_interval=100,
-        mixed_precision=True,
+        mixed_precision=False,
         convert_inputs=convert_inputs,
         n_objects_per_batch=n_objects_per_batch,
         n_sub_iteration=n_sub_iteration,
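With mixed_precision=False the benchmark above measures full fp32 forward and backward times. In torch_em-style trainers this flag typically selects the forward_context used in the training loop, roughly like the following sketch (an assumed pattern, not the torch_em source):

    from contextlib import nullcontext
    import torch

    mixed_precision = False
    forward_context = (lambda: torch.autocast("cuda")) if mixed_precision else nullcontext

    # Matches the usage in _train_epoch_impl above:
    with forward_context():
        pass  # forward pass and loss computation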
