Update visual prompting pipeline for multi-label zero-shot learning support (#3993)

* Add flag for using mask or polygon in `learn`

* Revert some parts in #3769

* Fix metric

* Fix unit test

* Fix

* Fix unit tests

* Fix integration test

* Update docstring

* Fix error during ptq

* Update CHANGELOG

* Update CHANGELOG
sungchul2 authored Oct 15, 2024
1 parent fa272d5 commit 78b560d
Showing 13 changed files with 234 additions and 270 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@

All notable changes to this project will be documented in this file.

+## [Unreleased]
+
+### Enhancements
+
+- Update visual prompting pipeline for multi-label zero-shot learning support
+  (https://github.com/openvinotoolkit/training_extensions/pull/3993)
+
## [2.3.0]

### New features
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/losses/sam_loss.py
@@ -49,7 +49,7 @@ def forward(
loss_dice += self.calculate_dice_loss(post_processed_pred_mask, flatten_gt_mask, num_masks)
loss_focal += self.calculate_sigmoid_ce_focal_loss(post_processed_pred_mask, flatten_gt_mask, num_masks)
batch_iou = self.calculate_iou(post_processed_pred_mask, flatten_gt_mask)
-loss_iou += nn.functional.mse_loss(iou, batch_iou.unsqueeze(1), reduction="sum") / num_masks
+loss_iou += nn.functional.mse_loss(iou, batch_iou, reduction="sum") / num_masks

loss = 20.0 * loss_focal + loss_dice + loss_iou

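This loss change is one half of a coordinated shape fix: `_iou_predictions` is now squeezed to 1-D in the visual prompter (see the `squeeze(1)` hunk further down), so `batch_iou` no longer needs `unsqueeze(1)` to match. A minimal shape sketch, with tensor sizes assumed from the surrounding diff:

```python
# Minimal sketch: both operands of mse_loss are now 1-D of length num_masks.
import torch
from torch import nn

num_masks = 4
iou = torch.rand(num_masks)        # predicted IoU per mask, squeezed upstream
batch_iou = torch.rand(num_masks)  # IoU computed from predicted vs. GT masks
loss_iou = nn.functional.mse_loss(iou, batch_iou, reduction="sum") / num_masks
print(loss_iou.shape)  # torch.Size([]) — a scalar loss term
```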
104 changes: 57 additions & 47 deletions src/otx/algo/visual_prompting/sam.py
@@ -10,19 +10,20 @@
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
-from typing import TYPE_CHECKING, Callable, ClassVar, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

import torch
import torchvision.transforms.v2 as tvt_v2
from torch import Tensor, nn
-from torchvision.tv_tensors import BoundingBoxes, Image, Mask
+from torchvision.tv_tensors import BoundingBoxes, Image

from otx.algo.visual_prompting.decoders import SAMMaskDecoder
from otx.algo.visual_prompting.encoders import SAMImageEncoder, SAMPromptEncoder
from otx.algo.visual_prompting.losses.sam_loss import SAMCriterion
from otx.algo.visual_prompting.visual_prompters import SegmentAnything, ZeroShotSegmentAnything
from otx.core.data.entity.base import OTXBatchLossEntity, Points
from otx.core.data.entity.visual_prompting import (
+ZeroShotPromptType,
ZeroShotVisualPromptingBatchDataEntity,
ZeroShotVisualPromptingBatchPredEntity,
)
@@ -34,7 +35,6 @@

if TYPE_CHECKING:
import numpy as np
-from datumaro import Polygon as dmPolygon
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

from otx.core.metrics import MetricCallable
@@ -56,30 +56,54 @@ class CommonSettingMixin:
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
-load_state_dict: Callable[[dict[str, Tensor]], None]

-def load_checkpoint(self, load_from: str | None) -> None:
+def load_state_dict(
+self,
+state_dict: dict[str, Any] | None = None,
+strict: bool = True,
+assign: bool = False,
+load_from: str | None = None,
+) -> None:
"""Load checkpoint for SAM.
+This method loads a pre-trained state dictionary for the SAM model. It can load from
+a provided state dictionary or from a URL specified in the `load_from` parameter.
Args:
-load_from (Optional[str], optional): Checkpoint path for SAM. Defaults to None.
+state_dict (dict[str, Any] | None, optional): The state dictionary to load.
+Defaults to None.
+strict (bool, optional): Whether to strictly enforce that the keys in state_dict
+match the keys returned by this module's state_dict() function. Defaults to True.
+assign (bool, optional): Whether to copy parameters instead of moving them.
+Defaults to False.
+load_from (str | None, optional): URL to load the checkpoint from. If provided,
+this will be used instead of the state_dict argument. Defaults to None.
Raises:
ValueError: If the checkpoint format is not desirable for torch.hub.load_state_dict_from_url.
+Note:
+If loading from a URL, some keys are removed from the loaded state dictionary
+and a 'model.' prefix is added to all remaining keys.
"""
try:
-state_dict = torch.hub.load_state_dict_from_url(str(load_from))
-for key in [
-"image_encoder.norm_head.weight",
-"image_encoder.norm_head.bias",
-"image_encoder.head.weight",
-"image_encoder.head.bias",
-]:
-if key in state_dict:
-state_dict.pop(key)
-
-# add prefix 'model.' to all keys
-for key in list(state_dict.keys()):
-state_dict["model." + key] = state_dict.pop(key)
-
-self.load_state_dict(state_dict)
+if load_from is not None:
+_state_dict: dict[str, Any] = torch.hub.load_state_dict_from_url(str(load_from))
+for key in [
+"image_encoder.norm_head.weight",
+"image_encoder.norm_head.bias",
+"image_encoder.head.weight",
+"image_encoder.head.bias",
+]:
+if key in _state_dict:
+_state_dict.pop(key)
+
+# add prefix 'model.' to all keys
+for key in list(_state_dict.keys()):
+_state_dict["model." + key] = _state_dict.pop(key)
+
+state_dict = _state_dict
+super().load_state_dict(state_dict, strict, assign)  # type: ignore[misc]

except (ValueError, RuntimeError) as e:
log.info(
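The override folds the old `load_checkpoint` helper into `load_state_dict` itself, so a plain state dict and a URL checkpoint go through a single entry point. A rough standalone sketch of the URL branch (the function name and the bare `torch.nn.Module` target are illustrative, not part of the diff):

```python
import torch

def load_sam_checkpoint(model: torch.nn.Module, load_from: str) -> None:
    """Illustrative only: mirrors the URL branch of CommonSettingMixin.load_state_dict."""
    state_dict = torch.hub.load_state_dict_from_url(load_from)
    # Drop classification-head weights that SAM checkpoints carry but OTX does not use.
    for key in (
        "image_encoder.norm_head.weight",
        "image_encoder.norm_head.bias",
        "image_encoder.head.weight",
        "image_encoder.head.bias",
    ):
        state_dict.pop(key, None)
    # Re-key everything under the 'model.' prefix the OTX wrapper expects.
    state_dict = {f"model.{k}": v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
```

Callers can then pass the URL directly, as the constructors below do with `self.load_state_dict(load_from=self.load_from[backbone_type])`.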
@@ -151,7 +175,7 @@ def forward_for_tracing(
)


-class SAM(OTXVisualPromptingModel, CommonSettingMixin):
+class SAM(CommonSettingMixin, OTXVisualPromptingModel):  # type: ignore[misc]
"""OTX visual prompting model class for Segment Anything Model (SAM)."""

input_size_multiplier = 16
@@ -195,7 +219,7 @@ def __init__(
torch_compile=torch_compile,
)

-self.load_checkpoint(load_from=self.load_from[backbone_type])
+self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)

def _build_model(self) -> nn.Module:
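Note the reordered base classes: with `CommonSettingMixin` first, Python's MRO resolves `load_state_dict` to the mixin's override before the framework base's. A toy sketch of why the order matters (class names here are made up):

```python
class Base:
    def load_state_dict(self, state_dict):
        print("Base handles", state_dict)

class Mixin:
    def load_state_dict(self, state_dict=None, load_from=None):
        print("Mixin intercepts, load_from =", load_from)
        super().load_state_dict(state_dict)  # cooperative call reaches Base

class Model(Mixin, Base):  # mixin first, so its override wins
    pass

Model().load_state_dict({}, load_from="https://example.com/ckpt.pth")
# -> Mixin intercepts, load_from = https://example.com/ckpt.pth
# -> Base handles {}
```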
@@ -219,7 +243,7 @@ def _build_model(self) -> nn.Module:
)


-class ZeroShotSAM(OTXZeroShotVisualPromptingModel, CommonSettingMixin):
+class ZeroShotSAM(CommonSettingMixin, OTXZeroShotVisualPromptingModel):  # type: ignore[misc]
"""Zero-Shot Visual Prompting model."""

def __init__( # noqa: PLR0913
@@ -276,7 +300,7 @@ def __init__(  # noqa: PLR0913
freeze_prompt_encoder = True
freeze_mask_decoder = True

-self.load_checkpoint(load_from=self.load_from[backbone_type])
+self.load_state_dict(load_from=self.load_from[backbone_type])
self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder)

self.save_outputs = save_outputs
@@ -359,23 +383,13 @@ def infer(
def _gather_prompts_with_labels(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
-) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]:
+) -> list[dict[int, list[ZeroShotPromptType]]]:
"""Gather prompts according to labels."""
-total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = []
-for batch, batch_labels in enumerate(inputs.labels):
+total_processed_prompts: list[dict[int, list[ZeroShotPromptType]]] = []
+for prompts, labels in zip(inputs.prompts, inputs.labels):
processed_prompts = defaultdict(list)
-for prompt_type in ["prompts", "polygons", "masks"]:
-_prompts = getattr(inputs, prompt_type, None)
-prompt_labels = getattr(batch_labels, prompt_type, None)
-if _prompts is None or prompt_labels is None:
-continue
-
-for idx, _label in enumerate(prompt_labels):
-if prompt_type in ("prompts", "polygons"):
-processed_prompts[int(_label)].append(_prompts[batch][idx])
-else:
-# for mask
-processed_prompts[int(_label)].append(Mask(_prompts[batch][idx]))
+for prompt, label in zip(prompts, labels):
+processed_prompts[int(label)].append(prompt)

sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x))
total_processed_prompts.append(sorted_processed_prompts)
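The rewrite assumes the batch entity now carries one flat prompt list per image with a parallel label list, so grouping by label reduces to a plain zip. A toy run with made-up prompts:

```python
from collections import defaultdict

# One image in the batch; strings stand in for Points/BoundingBoxes/Mask objects.
prompts = [["box_a", "point_b", "box_c"]]
labels = [[1, 0, 1]]

total = []
for sample_prompts, sample_labels in zip(prompts, labels):
    grouped = defaultdict(list)
    for prompt, label in zip(sample_prompts, sample_labels):
        grouped[int(label)].append(prompt)
    total.append(dict(sorted(grouped.items())))

print(total)  # [{0: ['point_b'], 1: ['box_a', 'box_c']}]
```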
@@ -411,19 +425,18 @@ def apply_boxes(self, boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_l

def apply_prompts(
self,
-prompts: list[Points | BoundingBoxes],
+prompts: list[ZeroShotPromptType],
ori_shape: tuple[int, ...],
target_length: int = 1024,
-) -> list[Points | BoundingBoxes]:
+) -> list[ZeroShotPromptType]:
"""Preprocess prompts to be used in the model."""
-transformed_prompts: list[Points | BoundingBoxes] = []
+transformed_prompts: list[ZeroShotPromptType] = []
for prompt in prompts:
if isinstance(prompt, Points):
transformed_prompts.append(self.apply_points(prompt, ori_shape, target_length))
elif isinstance(prompt, BoundingBoxes):
transformed_prompts.append(self.apply_boxes(prompt, ori_shape, target_length))
else:
log.info(f"Current prompt ({prompt.__class__.__name__}) is not supported, saved as it is.")
transformed_prompts.append(prompt)
return transformed_prompts

@@ -452,9 +465,6 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
],
-masks=entity.masks,
-polygons=entity.polygons,
-labels=entity.labels,
)

def initialize_reference_info(self) -> None:
@@ -101,7 +101,7 @@ def forward(
multimask_output=False, # when given multiple prompts. if there is single prompt True would be better. # noqa: E501
)
low_res_masks.append(_low_res_masks)
-iou_predictions.append(_iou_predictions)
+iou_predictions.append(_iou_predictions.squeeze(1))

pred_masks.append(torch.cat(low_res_masks, dim=0))
ious.append(torch.cat(iou_predictions, dim=0))
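This is the other half of the IoU-loss shape fix: with `multimask_output=False`, the IoU head emits one score per prompt with a trailing singleton dimension (shape assumed from the diff), and squeezing it keeps the concatenated `ious` 1-D:

```python
import torch

_iou_predictions = torch.rand(3, 1)     # one score per prompt, multimask_output=False
squeezed = _iou_predictions.squeeze(1)  # shape (3,) — lines up with batch_iou in the loss
print(squeezed.shape)                   # torch.Size([3])
```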
@@ -614,7 +614,7 @@ def infer(
ori_shape=ori_shape,
is_cascade=is_cascade,
)
-predicted_masks[label].append(mask * point_score[2])
+predicted_masks[label].append(mask)
used_points[label].append(point_score)

# check overlapping area between different label masks
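Stored masks are no longer pre-multiplied by their point confidence; the raw mask is kept, and the score remains available in `used_points`. A small illustration of the behavioral difference (the `(x, y, score)` layout of `point_score` is an assumption):

```python
import torch

mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
point_score = torch.tensor([412.0, 255.0, 0.5])  # assumed (x, y, confidence)

print(mask * point_score[2])  # old: mask values scaled to 0.5, skewing later overlap checks
print(mask)                   # new: binary mask preserved; confidence tracked separately
```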