Update visual prompting pipeline for multi-label zero-shot learning support #3993

Merged · 14 commits · Oct 15, 2024
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@

All notable changes to this project will be documented in this file.

## \[Unreleased\]

### Enhancements

- Update visual prompting pipeline for multi-label zero-shot learning support
(https://github.com/openvinotoolkit/training_extensions/pull/3993)

## \[2.3.0\]

### New features
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/losses/sam_loss.py
@@ -49,7 +49,7 @@ def forward(
loss_dice += self.calculate_dice_loss(post_processed_pred_mask, flatten_gt_mask, num_masks)
loss_focal += self.calculate_sigmoid_ce_focal_loss(post_processed_pred_mask, flatten_gt_mask, num_masks)
batch_iou = self.calculate_iou(post_processed_pred_mask, flatten_gt_mask)
loss_iou += nn.functional.mse_loss(iou, batch_iou.unsqueeze(1), reduction="sum") / num_masks
loss_iou += nn.functional.mse_loss(iou, batch_iou, reduction="sum") / num_masks

loss = 20.0 * loss_focal + loss_dice + loss_iou

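The hunk above works together with the `squeeze(1)` change in the visual prompter further down: the predicted IoU now reaches the criterion as a 1-D tensor of per-mask scores, so `batch_iou` no longer needs an extra dimension for `mse_loss` to line up. A minimal sketch of the shape contract, with stand-in tensors rather than real model outputs:

```python
# Minimal sketch (stand-in tensors, not repository code): both arguments to
# mse_loss are now 1-D with one entry per mask, so no unsqueeze is needed.
import torch
from torch import nn

num_masks = 3
iou = torch.rand(num_masks)        # predicted IoU, already squeezed upstream
batch_iou = torch.rand(num_masks)  # IoU computed from predicted vs. GT masks

loss_iou = nn.functional.mse_loss(iou, batch_iou, reduction="sum") / num_masks
```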
104 changes: 57 additions & 47 deletions src/otx/algo/visual_prompting/sam.py
@@ -10,19 +10,20 @@
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Callable, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import torch
import torchvision.transforms.v2 as tvt_v2
from torch import Tensor, nn
from torchvision.tv_tensors import BoundingBoxes, Image, Mask
from torchvision.tv_tensors import BoundingBoxes, Image

from otx.algo.visual_prompting.decoders import SAMMaskDecoder
from otx.algo.visual_prompting.encoders import SAMImageEncoder, SAMPromptEncoder
from otx.algo.visual_prompting.losses.sam_loss import SAMCriterion
from otx.algo.visual_prompting.visual_prompters import SegmentAnything, ZeroShotSegmentAnything
from otx.core.data.entity.base import OTXBatchLossEntity, Points
from otx.core.data.entity.visual_prompting import (
ZeroShotPromptType,
ZeroShotVisualPromptingBatchDataEntity,
ZeroShotVisualPromptingBatchPredEntity,
)
@@ -34,7 +35,6 @@

if TYPE_CHECKING:
import numpy as np
from datumaro import Polygon as dmPolygon
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

from otx.core.metrics import MetricCallable
@@ -56,30 +56,54 @@ class CommonSettingMixin:
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
load_state_dict: Callable[[dict[str, Tensor]], None]

def load_checkpoint(self, load_from: str | None) -> None:
def load_state_dict(
self,
state_dict: dict[str, Any] | None = None,
strict: bool = True,
assign: bool = False,
load_from: str | None = None,
) -> None:
"""Load checkpoint for SAM.

This method loads a pre-trained state dictionary for the SAM model. It can load from
a provided state dictionary or from a URL specified in the `load_from` parameter.

Args:
load_from (Optional[str], optional): Checkpoint path for SAM. Defaults to None.
state_dict (dict[str, Any] | None, optional): The state dictionary to load.
Defaults to None.
strict (bool, optional): Whether to strictly enforce that the keys in state_dict
match the keys returned by this module's state_dict() function. Defaults to True.
assign (bool, optional): Whether to assign tensors from the state dict to the module instead of copying them in place.
Defaults to False.
load_from (str | None, optional): URL to load the checkpoint from. If provided,
this will be used instead of the state_dict argument. Defaults to None.

Raises:
ValueError: If the checkpoint format is not compatible with torch.hub.load_state_dict_from_url.

Note:
If loading from a URL, some keys are removed from the loaded state dictionary
and a 'model.' prefix is added to all remaining keys.
"""
try:
state_dict = torch.hub.load_state_dict_from_url(str(load_from))
for key in [
"image_encoder.norm_head.weight",
"image_encoder.norm_head.bias",
"image_encoder.head.weight",
"image_encoder.head.bias",
]:
if key in state_dict:
state_dict.pop(key)

# add prefix 'model.' to all keys
for key in list(state_dict.keys()):
state_dict["model." + key] = state_dict.pop(key)

self.load_state_dict(state_dict)
if load_from is not None:
_state_dict: dict[str, Any] = torch.hub.load_state_dict_from_url(str(load_from))
for key in [
"image_encoder.norm_head.weight",
"image_encoder.norm_head.bias",
"image_encoder.head.weight",
"image_encoder.head.bias",
]:
if key in _state_dict:
_state_dict.pop(key)

# add prefix 'model.' to all keys
for key in list(_state_dict.keys()):
_state_dict["model." + key] = _state_dict.pop(key)

state_dict = _state_dict
super().load_state_dict(state_dict, strict, assign) # type: ignore[misc]

except (ValueError, RuntimeError) as e:
log.info(
@@ -151,7 +175,7 @@ def forward_for_tracing(
)


class SAM(OTXVisualPromptingModel, CommonSettingMixin):
class SAM(CommonSettingMixin, OTXVisualPromptingModel): # type: ignore[misc]
"""OTX visual prompting model class for Segment Anything Model (SAM)."""

input_size_multiplier = 16
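
Reordering the bases puts `CommonSettingMixin` ahead of the OTX model class in the MRO, so the mixin's new `load_state_dict` override is found first and can still reach the standard implementation through `super()`. This is also why the call sites below switch from `load_checkpoint(...)` to `self.load_state_dict(load_from=...)`. A minimal sketch of the cooperative pattern, with hypothetical class names standing in for the real ones:

```python
# Minimal sketch (hypothetical classes, not repository code): the mixin listed
# first wins the MRO lookup, and super() still reaches the base implementation.
class BaseModel:
    def load_state_dict(self, state_dict, strict=True, assign=False):
        print("BaseModel.load_state_dict")

class SettingMixin:
    def load_state_dict(self, state_dict=None, strict=True, assign=False, load_from=None):
        if load_from is not None:
            state_dict = {"model.weight": 0}  # stand-in for a downloaded checkpoint
        super().load_state_dict(state_dict, strict, assign)

class Model(SettingMixin, BaseModel):  # mixin first, as in SAM / ZeroShotSAM
    pass

Model().load_state_dict(load_from="https://example.com/sam.pth")  # -> BaseModel.load_state_dict
```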
@@ -195,7 +219,7 @@ def __init__(
torch_compile=torch_compile,
)

self.load_checkpoint(load_from=self.load_from[backbone_type])
self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)

def _build_model(self) -> nn.Module:
@@ -219,7 +243,7 @@ def _build_model(self) -> nn.Module:
)


class ZeroShotSAM(OTXZeroShotVisualPromptingModel, CommonSettingMixin):
class ZeroShotSAM(CommonSettingMixin, OTXZeroShotVisualPromptingModel): # type: ignore[misc]
"""Zero-Shot Visual Prompting model."""

def __init__( # noqa: PLR0913
@@ -276,7 +300,7 @@ def __init__( # noqa: PLR0913
freeze_prompt_encoder = True
freeze_mask_decoder = True

self.load_checkpoint(load_from=self.load_from[backbone_type])
self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)

self.save_outputs = save_outputs
@@ -359,23 +383,13 @@ def infer(
def _gather_prompts_with_labels(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]:
) -> list[dict[int, list[ZeroShotPromptType]]]:
"""Gather prompts according to labels."""
total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = []
for batch, batch_labels in enumerate(inputs.labels):
total_processed_prompts: list[dict[int, list[ZeroShotPromptType]]] = []
for prompts, labels in zip(inputs.prompts, inputs.labels):
processed_prompts = defaultdict(list)
for prompt_type in ["prompts", "polygons", "masks"]:
_prompts = getattr(inputs, prompt_type, None)
prompt_labels = getattr(batch_labels, prompt_type, None)
if _prompts is None or prompt_labels is None:
continue

for idx, _label in enumerate(prompt_labels):
if prompt_type in ("prompts", "polygons"):
processed_prompts[int(_label)].append(_prompts[batch][idx])
else:
# for mask
processed_prompts[int(_label)].append(Mask(_prompts[batch][idx]))
for prompt, label in zip(prompts, labels):
processed_prompts[int(label)].append(prompt)

sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x))
total_processed_prompts.append(sorted_processed_prompts)
@@ -411,19 +425,18 @@ def apply_boxes(self, boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_l

def apply_prompts(
self,
prompts: list[Points | BoundingBoxes],
prompts: list[ZeroShotPromptType],
ori_shape: tuple[int, ...],
target_length: int = 1024,
) -> list[Points | BoundingBoxes]:
) -> list[ZeroShotPromptType]:
"""Preprocess prompts to be used in the model."""
transformed_prompts: list[Points | BoundingBoxes] = []
transformed_prompts: list[ZeroShotPromptType] = []
for prompt in prompts:
if isinstance(prompt, Points):
transformed_prompts.append(self.apply_points(prompt, ori_shape, target_length))
elif isinstance(prompt, BoundingBoxes):
transformed_prompts.append(self.apply_boxes(prompt, ori_shape, target_length))
else:
log.info(f"Current prompt ({prompt.__class__.__name__}) is not supported, saved as it is.")
transformed_prompts.append(prompt)
return transformed_prompts

@@ -452,9 +465,6 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
],
masks=entity.masks,
polygons=entity.polygons,
labels=entity.labels,
)

def initialize_reference_info(self) -> None:
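With prompts and labels now carried as flat per-image lists in the batch entity, `_gather_prompts_with_labels` reduces to a `zip` over the matching lists plus a `defaultdict`, which is what allows several labels per image in the zero-shot flow. A minimal sketch of the grouping with stand-in data:

```python
# Minimal sketch (stand-in data, not repository code): group each image's
# prompts by integer label, then sort the groups by label id.
from collections import defaultdict

batch_prompts = [["box_a", "point_b", "mask_c"]]  # one image, three prompts
batch_labels = [[1, 0, 1]]                        # labels aligned with the prompts

total_processed_prompts = []
for prompts, labels in zip(batch_prompts, batch_labels):
    processed = defaultdict(list)
    for prompt, label in zip(prompts, labels):
        processed[int(label)].append(prompt)
    total_processed_prompts.append(dict(sorted(processed.items())))

print(total_processed_prompts)  # [{0: ['point_b'], 1: ['box_a', 'mask_c']}]
```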
@@ -101,7 +101,7 @@ def forward(
multimask_output=False, # when given multiple prompts. if there is single prompt True would be better. # noqa: E501
)
low_res_masks.append(_low_res_masks)
iou_predictions.append(_iou_predictions)
iou_predictions.append(_iou_predictions.squeeze(1))

pred_masks.append(torch.cat(low_res_masks, dim=0))
ious.append(torch.cat(iou_predictions, dim=0))
@@ -614,7 +614,7 @@ def infer(
ori_shape=ori_shape,
is_cascade=is_cascade,
)
predicted_masks[label].append(mask * point_score[2])
predicted_masks[label].append(mask)
used_points[label].append(point_score)

# check overlapping area between different label masks
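
The last hunk stops scaling the stored mask by `point_score[2]`, so `predicted_masks` keeps plain binary masks while the per-point confidences stay in `used_points`. A minimal sketch, under the assumption that the later overlap check compares per-mask scores, of how two labels' masks could be reconciled; the repository's own overlap check (referenced by the comment above) may differ in detail:

```python
# Minimal sketch (hypothetical helper, not repository code): with binary masks
# and separately stored scores, an overlap between two labels can be resolved
# by keeping the higher-scoring mask.
import torch

def resolve_overlap(mask_a, score_a, mask_b, score_b, threshold=0.9):
    """Drop the lower-scoring mask when the two masks mostly cover the same area."""
    inter = torch.logical_and(mask_a, mask_b).sum().item()
    smaller = min(mask_a.sum().item(), mask_b.sum().item())
    if smaller > 0 and inter / smaller > threshold:
        return (mask_a, None) if score_a >= score_b else (None, mask_b)
    return mask_a, mask_b
```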