diff --git a/CHANGELOG.md b/CHANGELOG.md index 3640f36e4fd..cdf5c5840d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ All notable changes to this project will be documented in this file. +## \[Unreleased\] + +### Enhancements + +- Update visual prompting pipeline for multi-label zero-shot learning support + (https://github.com/openvinotoolkit/training_extensions/pull/3993) + ## \[2.3.0\] ### New features diff --git a/src/otx/algo/visual_prompting/losses/sam_loss.py b/src/otx/algo/visual_prompting/losses/sam_loss.py index 064427f3058..c98b775e534 100644 --- a/src/otx/algo/visual_prompting/losses/sam_loss.py +++ b/src/otx/algo/visual_prompting/losses/sam_loss.py @@ -49,7 +49,7 @@ def forward( loss_dice += self.calculate_dice_loss(post_processed_pred_mask, flatten_gt_mask, num_masks) loss_focal += self.calculate_sigmoid_ce_focal_loss(post_processed_pred_mask, flatten_gt_mask, num_masks) batch_iou = self.calculate_iou(post_processed_pred_mask, flatten_gt_mask) - loss_iou += nn.functional.mse_loss(iou, batch_iou.unsqueeze(1), reduction="sum") / num_masks + loss_iou += nn.functional.mse_loss(iou, batch_iou, reduction="sum") / num_masks loss = 20.0 * loss_focal + loss_dice + loss_iou diff --git a/src/otx/algo/visual_prompting/sam.py b/src/otx/algo/visual_prompting/sam.py index e691acf60e2..023936adec2 100644 --- a/src/otx/algo/visual_prompting/sam.py +++ b/src/otx/algo/visual_prompting/sam.py @@ -10,12 +10,12 @@ from collections import defaultdict from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Callable, ClassVar, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal import torch import torchvision.transforms.v2 as tvt_v2 from torch import Tensor, nn -from torchvision.tv_tensors import BoundingBoxes, Image, Mask +from torchvision.tv_tensors import BoundingBoxes, Image from otx.algo.visual_prompting.decoders import SAMMaskDecoder from otx.algo.visual_prompting.encoders import SAMImageEncoder, SAMPromptEncoder @@ -23,6 +23,7 @@ from otx.algo.visual_prompting.visual_prompters import SegmentAnything, ZeroShotSegmentAnything from otx.core.data.entity.base import OTXBatchLossEntity, Points from otx.core.data.entity.visual_prompting import ( + ZeroShotPromptType, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, ) @@ -34,7 +35,6 @@ if TYPE_CHECKING: import numpy as np - from datumaro import Polygon as dmPolygon from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from otx.core.metrics import MetricCallable @@ -56,30 +56,54 @@ class CommonSettingMixin: "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", } - load_state_dict: Callable[[dict[str, Tensor]], None] - def load_checkpoint(self, load_from: str | None) -> None: + def load_state_dict( + self, + state_dict: dict[str, Any] | None = None, + strict: bool = True, + assign: bool = False, + load_from: str | None = None, + ) -> None: """Load checkpoint for SAM. + This method loads a pre-trained state dictionary for the SAM model. It can load from + a provided state dictionary or from a URL specified in the `load_from` parameter. + Args: - load_from (Optional[str], optional): Checkpoint path for SAM. Defaults to None. + state_dict (dict[str, Any] | None, optional): The state dictionary to load. + Defaults to None. 
+ strict (bool, optional): Whether to strictly enforce that the keys in state_dict + match the keys returned by this module's state_dict() function. Defaults to True. + assign (bool, optional): Whether to copy parameters instead of moving them. + Defaults to False. + load_from (str | None, optional): URL to load the checkpoint from. If provided, + this will be used instead of the state_dict argument. Defaults to None. + + Raises: + ValueError: If the checkpoint format is not desirable for torch.hub.load_state_dict_from_url. + + Note: + If loading from a URL, some keys are removed from the loaded state dictionary + and a 'model.' prefix is added to all remaining keys. """ try: - state_dict = torch.hub.load_state_dict_from_url(str(load_from)) - for key in [ - "image_encoder.norm_head.weight", - "image_encoder.norm_head.bias", - "image_encoder.head.weight", - "image_encoder.head.bias", - ]: - if key in state_dict: - state_dict.pop(key) - - # add prefix 'model.' to all keys - for key in list(state_dict.keys()): - state_dict["model." + key] = state_dict.pop(key) - - self.load_state_dict(state_dict) + if load_from is not None: + _state_dict: dict[str, Any] = torch.hub.load_state_dict_from_url(str(load_from)) + for key in [ + "image_encoder.norm_head.weight", + "image_encoder.norm_head.bias", + "image_encoder.head.weight", + "image_encoder.head.bias", + ]: + if key in _state_dict: + _state_dict.pop(key) + + # add prefix 'model.' to all keys + for key in list(_state_dict.keys()): + _state_dict["model." + key] = _state_dict.pop(key) + + state_dict = _state_dict + super().load_state_dict(state_dict, strict, assign) # type: ignore[misc] except (ValueError, RuntimeError) as e: log.info( @@ -151,7 +175,7 @@ def forward_for_tracing( ) -class SAM(OTXVisualPromptingModel, CommonSettingMixin): +class SAM(CommonSettingMixin, OTXVisualPromptingModel): # type: ignore[misc] """OTX visual prompting model class for Segment Anything Model (SAM).""" input_size_multiplier = 16 @@ -195,7 +219,7 @@ def __init__( torch_compile=torch_compile, ) - self.load_checkpoint(load_from=self.load_from[backbone_type]) + self.load_state_dict(load_from=self.load_from[backbone_type]) self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder) def _build_model(self) -> nn.Module: @@ -219,7 +243,7 @@ def _build_model(self) -> nn.Module: ) -class ZeroShotSAM(OTXZeroShotVisualPromptingModel, CommonSettingMixin): +class ZeroShotSAM(CommonSettingMixin, OTXZeroShotVisualPromptingModel): # type: ignore[misc] """Zero-Shot Visual Prompting model.""" def __init__( # noqa: PLR0913 @@ -276,7 +300,7 @@ def __init__( # noqa: PLR0913 freeze_prompt_encoder = True freeze_mask_decoder = True - self.load_checkpoint(load_from=self.load_from[backbone_type]) + self.load_state_dict(load_from=self.load_from[backbone_type]) self.freeze_networks(freeze_image_encoder, freeze_prompt_encoder, freeze_mask_decoder) self.save_outputs = save_outputs @@ -359,23 +383,13 @@ def infer( def _gather_prompts_with_labels( self, inputs: ZeroShotVisualPromptingBatchDataEntity, - ) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]: + ) -> list[dict[int, list[ZeroShotPromptType]]]: """Gather prompts according to labels.""" - total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = [] - for batch, batch_labels in enumerate(inputs.labels): + total_processed_prompts: list[dict[int, list[ZeroShotPromptType]]] = [] + for prompts, labels in zip(inputs.prompts, inputs.labels): processed_prompts = 
defaultdict(list) - for prompt_type in ["prompts", "polygons", "masks"]: - _prompts = getattr(inputs, prompt_type, None) - prompt_labels = getattr(batch_labels, prompt_type, None) - if _prompts is None or prompt_labels is None: - continue - - for idx, _label in enumerate(prompt_labels): - if prompt_type in ("prompts", "polygons"): - processed_prompts[int(_label)].append(_prompts[batch][idx]) - else: - # for mask - processed_prompts[int(_label)].append(Mask(_prompts[batch][idx])) + for prompt, label in zip(prompts, labels): + processed_prompts[int(label)].append(prompt) sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x)) total_processed_prompts.append(sorted_processed_prompts) @@ -411,19 +425,18 @@ def apply_boxes(self, boxes: BoundingBoxes, ori_shape: tuple[int, ...], target_l def apply_prompts( self, - prompts: list[Points | BoundingBoxes], + prompts: list[ZeroShotPromptType], ori_shape: tuple[int, ...], target_length: int = 1024, - ) -> list[Points | BoundingBoxes]: + ) -> list[ZeroShotPromptType]: """Preprocess prompts to be used in the model.""" - transformed_prompts: list[Points | BoundingBoxes] = [] + transformed_prompts: list[ZeroShotPromptType] = [] for prompt in prompts: if isinstance(prompt, Points): transformed_prompts.append(self.apply_points(prompt, ori_shape, target_length)) elif isinstance(prompt, BoundingBoxes): transformed_prompts.append(self.apply_boxes(prompt, ori_shape, target_length)) else: - log.info(f"Current prompt ({prompt.__class__.__name__}) is not supported, saved as it is.") transformed_prompts.append(prompt) return transformed_prompts @@ -452,9 +465,6 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot self.apply_prompts(prompt, info.ori_shape, self.model.image_size) for prompt, info in zip(entity.prompts, entity.imgs_info) ], - masks=entity.masks, - polygons=entity.polygons, - labels=entity.labels, ) def initialize_reference_info(self) -> None: diff --git a/src/otx/algo/visual_prompting/visual_prompters/segment_anything.py b/src/otx/algo/visual_prompting/visual_prompters/segment_anything.py index 49ecf8958ef..a15fc4c8c4c 100644 --- a/src/otx/algo/visual_prompting/visual_prompters/segment_anything.py +++ b/src/otx/algo/visual_prompting/visual_prompters/segment_anything.py @@ -101,7 +101,7 @@ def forward( multimask_output=False, # when given multiple prompts. if there is single prompt True would be better. 
# noqa: E501 ) low_res_masks.append(_low_res_masks) - iou_predictions.append(_iou_predictions) + iou_predictions.append(_iou_predictions.squeeze(1)) pred_masks.append(torch.cat(low_res_masks, dim=0)) ious.append(torch.cat(iou_predictions, dim=0)) @@ -614,7 +614,7 @@ def infer( ori_shape=ori_shape, is_cascade=is_cascade, ) - predicted_masks[label].append(mask * point_score[2]) + predicted_masks[label].append(mask) used_points[label].append(point_score) # check overlapping area between different label masks diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 0047e9350fe..10bdeda3405 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -7,7 +7,7 @@ from collections import defaultdict from functools import partial -from typing import Callable, Literal +from typing import Callable import torch from datumaro import Bbox as dmBbox @@ -26,9 +26,9 @@ from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchDataEntity, VisualPromptingDataEntity, + ZeroShotPromptType, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingDataEntity, - ZeroShotVisualPromptingLabel, ) from otx.core.types.label import NullLabelInfo from otx.core.utils.mask_util import polygon_to_bitmap @@ -74,7 +74,7 @@ def __init__( # if using only point prompt self.prob = 0.0 - self.label_info = NullLabelInfo() + self.label_info = NullLabelInfo() # TODO (sungchul): update label_info for multi-label support def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: item = self.dm_subset[index] @@ -177,27 +177,42 @@ class OTXZeroShotVisualPromptingDataset(OTXDataset[ZeroShotVisualPromptingDataEn Args: dm_subset (dmDataset): The subset of the dataset. transforms (Transforms): Data transformations to be applied. + use_mask (bool): Whether to use bitmap mask prompt for `learn`. + use_mask has the top priority rather than use_polygon, use_bbox, and use_point. + Defaults to False. + use_polygon (bool): Whether to use polygon prompt for `learn`. + use_polygon has higher priority than use_bbox and use_point. + Defaults to False. use_bbox (bool): Whether to use bounding box prompt. - If both use_bbox and use_point are False, use_bbox is set to True as default. - If both are True, divide the probability into both. Defaults to True. use_point (bool): Whether to use point prompt. - If both use_bbox and use_point are False, use_bbox is set to True as default. - If both are True, divide the probability into both. Defaults to False. **kwargs: Additional keyword arguments passed to the base class. + + Examples: + - use_mask == True : use bitmap mask as a prompt no matter what use_polygon, use_bbox, and use_point. + - use_mask == False and use_polygon == True : use polygon as a prompt no matter what use_bbox and use_point. + - use_mask == False and use_polygon == False + - use_bbox == False and use_point == False : set use_bbox to True as default. + - use_bbox == True and use_point == False : use bbox as a prompt. + - use_bbox == False and use_point == True : use point as a prompt. + - use_bbox == True and use_point == True : divide the probability into both. 
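The priority rules listed above reduce to a single selection step per annotation. The following is a minimal, self-contained sketch of that order; the helper name and the `prob` default are illustrative assumptions (the dataset itself draws against `self.prob` set in `__init__`), not part of the PR:

    import torch

    def select_prompt_type(use_mask: bool, use_polygon: bool, use_bbox: bool, use_point: bool, prob: float = 0.5) -> str:
        """Return which prompt type to draw for one annotation (illustrative helper only)."""
        if use_mask:
            return "mask"      # bitmap mask wins over every other prompt type
        if use_polygon:
            return "polygon"   # polygon wins over bbox/point
        if not use_bbox and not use_point:
            use_bbox = True    # default fallback when neither is requested
        if use_bbox and use_point:
            # split the probability between bbox and point prompts
            return "bbox" if torch.rand(1) < prob else "point"
        return "bbox" if use_bbox else "point"
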
""" def __init__( self, dm_subset: dmDataset, transforms: Transforms, + use_mask: bool = False, + use_polygon: bool = False, use_bbox: bool = True, use_point: bool = False, stack_images: bool = True, **kwargs, ) -> None: super().__init__(dm_subset, transforms, stack_images=stack_images, **kwargs) + self.use_mask = use_mask + self.use_polygon = use_polygon if not use_bbox and not use_point: # if both are False, use bbox as default use_bbox = True @@ -209,17 +224,17 @@ def __init__( # if using only point prompt self.prob = 0.0 - self.label_info = NullLabelInfo() + self.label_info = NullLabelInfo() # TODO (sungchul): update label_info for multi-label support def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None: item = self.dm_subset[index] img = item.media_as(dmImage) img_data, img_shape = self._get_img_data_and_shape(img) - gt_prompts: list[tvBoundingBoxes | Points] = [] + prompts: list[ZeroShotPromptType] = [] gt_masks: list[tvMask] = [] gt_polygons: list[dmPolygon] = [] - gt_labels: dict[Literal["prompts", "polygons", "masks"], list[int]] = defaultdict(list) + gt_labels: list[int] = [] for annotation in item.annotations: if isinstance(annotation, dmPolygon): # generate prompts from polygon @@ -229,7 +244,15 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None # skip very small region continue - if torch.rand(1) < self.prob: + if self.use_mask: + # get mask + prompts.append(mask) + + elif self.use_polygon: + # get polygon + prompts.append(annotation) + + elif torch.rand(1) < self.prob: # get bbox bbox = tvBoundingBoxes( annotation.get_bbox(), @@ -238,7 +261,7 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None dtype=torch.float32, ) bbox = convert_bounding_box_format(bbox, new_format=tvBoundingBoxFormat.XYXY) - gt_prompts.append(bbox) + prompts.append(bbox) else: # get center point point = Points( @@ -246,11 +269,9 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None canvas_size=img_shape, dtype=torch.float32, ) - gt_prompts.append(point) + prompts.append(point) - gt_labels["prompts"].append(annotation.label) - gt_labels["polygons"].append(annotation.label) - gt_labels["masks"].append(annotation.label) + gt_labels.append(annotation.label) gt_masks.append(mask) gt_polygons.append(annotation) @@ -258,12 +279,10 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None elif isinstance(annotation, (dmBbox, dmMask, dmPoints)): pass - if not gt_prompts: + if not prompts: return None - labels = { - str(prompt_type): torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items() - } + labels = torch.as_tensor(gt_labels, dtype=torch.int64) masks = tvMask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) return ZeroShotVisualPromptingDataEntity( @@ -274,9 +293,9 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None ori_shape=img_shape, ), masks=masks, - labels=ZeroShotVisualPromptingLabel(**labels), + labels=labels, polygons=gt_polygons, - prompts=gt_prompts, + prompts=prompts, ) @property diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 8f5eed33b57..a0a9472035f 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -6,19 +6,22 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Type +from datumaro 
import Polygon as dmPolygon from torchvision import tv_tensors +from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes +from torchvision.tv_tensors import Mask as tvMask from otx.core.data.entity.base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity, Points from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType if TYPE_CHECKING: - from datumaro import Polygon as dmPolygon from torch import LongTensor - from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes - from torchvision.tv_tensors import Mask as tvMask + + +ZeroShotPromptType = Type[tvBoundingBoxes | Points | tvMask | dmPolygon] @register_pytree_node @@ -127,15 +130,6 @@ class VisualPromptingBatchPredEntity(OTXBatchPredEntity, VisualPromptingBatchDat """Data entity to represent model output predictions for visual prompting task.""" -@dataclass -class ZeroShotVisualPromptingLabel: - """Label dataclass for zero-shot visual prompting data entity.""" - - prompts: LongTensor | None = None - polygons: LongTensor | None = None - masks: LongTensor | None = None - - @register_pytree_node @dataclass class ZeroShotVisualPromptingDataEntity(OTXDataEntity): @@ -143,10 +137,9 @@ class ZeroShotVisualPromptingDataEntity(OTXDataEntity): Attributes: masks (tvMask): The masks of the instances. - labels (ZeroShotVisualPromptingLabel): The labels of the instances - for each prompt. + labels (LongTensor): The labels of the instances. polygons (list[dmPolygon]): The polygons of the instances. - prompts (list[tvBoundingBoxes | Points]): The prompts of the instances. + prompts (list[ZeroShotPromptType]): The prompts of the instances. """ @property @@ -155,9 +148,9 @@ def task(self) -> OTXTaskType: return OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING masks: tvMask - labels: ZeroShotVisualPromptingLabel + labels: LongTensor polygons: list[dmPolygon] - prompts: list[tvBoundingBoxes | Points] + prompts: list[ZeroShotPromptType] @dataclass @@ -166,15 +159,15 @@ class ZeroShotVisualPromptingBatchDataEntity(OTXBatchDataEntity[ZeroShotVisualPr Attributes: masks (list[tvMask]): List of masks. - labels (list[ZeroShotVisualPromptingLabel]): List of labels. + labels (list[LongTensor]): List of labels. polygons (list[list[dmPolygon]]): List of polygons. - prompts (list[list[tvBoundingBoxes | Points]]): List of prompts. + prompts (list[list[ZeroShotPromptType]]): List of prompts. 
""" masks: list[tvMask] - labels: list[ZeroShotVisualPromptingLabel] + labels: list[LongTensor] polygons: list[list[dmPolygon]] - prompts: list[list[tvBoundingBoxes | Points]] + prompts: list[list[ZeroShotPromptType]] @property def task(self) -> OTXTaskType: @@ -217,10 +210,7 @@ def pin_memory(self) -> ZeroShotVisualPromptingBatchDataEntity: for prompts in self.prompts ], masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], - labels=[ - ZeroShotVisualPromptingLabel(**{k: v.pin_memory() for k, v in label.__dict__.items()}) - for label in self.labels - ], + labels=[label.pin_memory() for label in self.labels], ) ) diff --git a/src/otx/core/metrics/visual_prompting.py b/src/otx/core/metrics/visual_prompting.py index a5d009c0ae0..6b841926c76 100644 --- a/src/otx/core/metrics/visual_prompting.py +++ b/src/otx/core/metrics/visual_prompting.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # """Module for OTX Dice metric used for the OTX semantic segmentation task.""" + from __future__ import annotations from torchmetrics import MetricCollection @@ -12,6 +13,7 @@ def _visual_prompting_metric_callable(label_info: LabelInfo) -> MetricCollection: # noqa: ARG001 + # TODO (sungchul): consider to use iseg and sseg's metrics return MetricCollection( metrics={ "iou": BinaryJaccardIndex(), diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index cbc08ec4019..5e3ca58aae1 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -31,7 +31,6 @@ VisualPromptingBatchPredEntity, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, - ZeroShotVisualPromptingLabel, ) from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter @@ -59,7 +58,10 @@ def _convert_pred_entity_to_compute_metric( preds: VisualPromptingBatchPredEntity | ZeroShotVisualPromptingBatchPredEntity, inputs: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity, ) -> MetricInput: - """Convert the prediction entity to the format required by the compute metric function.""" + """Convert the prediction entity to the format required by the compute metric function. 
+ + TODO (sungchul): consider to use iseg and sseg's metrics + """ pred_info = [] target_info = [] @@ -82,10 +84,14 @@ def _convert_pred_entity_to_compute_metric( inputs.polygons, inputs.labels, ): - bit_masks = masks if len(masks) else polygon_to_bitmap(polygons, *imgs_info.ori_shape) + bit_masks = ( + masks + if len(masks) + else tv_tensors.Mask(polygon_to_bitmap(polygons, *imgs_info.ori_shape), dtype=torch.uint8) + ) target_info.append( { - "masks": tv_tensors.Mask(bit_masks, dtype=torch.bool).data, + "masks": bit_masks.data, "labels": torch.cat(list(labels.values())) if isinstance(labels, dict) else labels, }, ) @@ -94,49 +100,19 @@ def _convert_pred_entity_to_compute_metric( def _inference_step( - model: OTXVisualPromptingModel | OVVisualPromptingModel, - metric: MetricCollection, - inputs: VisualPromptingBatchDataEntity, -) -> None: - """Perform a single inference step on a batch of data from the inference set.""" - preds = model.forward(inputs) - - if not isinstance(preds, VisualPromptingBatchPredEntity): - raise TypeError(preds) - - converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs) # type: ignore[assignment] - - for _name, _metric in metric.items(): - if _name == "mAP": - # MeanAveragePrecision - _preds = [ - {k: v > 0.5 if k == "masks" else v.squeeze(1) if k == "scores" else v for k, v in ett.items()} - for ett in converted_entities["preds"] - ] - _target = converted_entities["target"] - _metric.update(preds=_preds, target=_target) - elif _name in ["iou", "f1-score", "dice"]: - # BinaryJaccardIndex, BinaryF1Score, Dice - for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]): - _metric.update(cvt_preds["masks"], cvt_target["masks"]) - - -def _inference_step_for_zero_shot( - model: OTXZeroShotVisualPromptingModel | OVZeroShotVisualPromptingModel, + model: OTXVisualPromptingModel + | OVVisualPromptingModel + | OTXZeroShotVisualPromptingModel + | OVZeroShotVisualPromptingModel, metric: MetricCollection, - inputs: ZeroShotVisualPromptingBatchDataEntity, + inputs: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity, ) -> None: """Perform a single inference step on a batch of data from the inference set.""" - preds = model.forward(inputs) + preds = model.forward(inputs) # type: ignore[arg-type] - if not isinstance(preds, ZeroShotVisualPromptingBatchPredEntity): + if not isinstance(preds, (VisualPromptingBatchPredEntity, ZeroShotVisualPromptingBatchPredEntity)): raise TypeError(preds) - # filter labels using corresponding ground truth - inputs.labels = [ - label.masks if inputs.masks and label.masks is not None else label.polygons for label in inputs.labels - ] - converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs) # type: ignore[assignment] for _name, _metric in metric.items(): @@ -150,11 +126,23 @@ def _inference_step_for_zero_shot( _metric.update(preds=_preds, target=_target) elif _name in ["iou", "f1-score", "dice"]: # BinaryJaccardIndex, BinaryF1Score, Dice + # TODO (sungchul): change to multi-class metric + # Currently, label_info is NullLabelInfo and it is required to be changed for multi-label support. + # But huge changes is required, it will be changed in the near future. 
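For reference, the binary metrics below consume one flattened mask pair per label. A minimal, standalone sketch of that update pattern (the 64x64 size and random tensors are illustrative assumptions):

    import torch
    from torchmetrics.classification import BinaryJaccardIndex

    iou = BinaryJaccardIndex()
    # instance masks for one label, stacked along dim 0
    mask_preds = (torch.rand(2, 64, 64) > 0.5).float()
    mask_target = torch.randint(0, 2, (1, 64, 64), dtype=torch.uint8)
    # merge instances into a single binary map, then flatten for the binary metric
    iou.update(
        mask_preds.sum(dim=0).clamp(0, 1).float().flatten(),
        mask_target.sum(dim=0).clamp(0, 1).flatten(),
    )
    print(iou.compute())
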
for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]): - _metric.update( - cvt_preds["masks"].sum(dim=0).clamp(0, 1), - cvt_target["masks"].sum(dim=0).clamp(0, 1), - ) + max_label = torch.cat((cvt_preds["labels"], cvt_target["labels"])).max() + for label in range(max_label + 1): + mask_preds = cvt_preds["masks"][cvt_preds["labels"] == label] + mask_target = cvt_target["masks"][cvt_target["labels"] == label] + if len(mask_preds) == 0: + mask_preds = torch.zeros((1, *mask_target.shape[1:]), device=model.device) + if len(mask_target) == 0: + mask_target = torch.zeros((1, *mask_preds.shape[1:]), device=model.device, dtype=torch.uint8) + + _metric.update( + mask_preds.sum(dim=0).clamp(0, 1).float().flatten(), + mask_target.sum(dim=0).clamp(0, 1).flatten(), + ) class OTXVisualPromptingModel(OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity]): @@ -162,7 +150,7 @@ class OTXVisualPromptingModel(OTXModel[VisualPromptingBatchDataEntity, VisualPro def __init__( self, - label_info: LabelInfoTypes = NullLabelInfo(), + label_info: LabelInfoTypes = NullLabelInfo(), # TODO (sungchul): update label_info for multi-label support input_size: tuple[int, int] = (1024, 1024), optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, @@ -172,7 +160,7 @@ def __init__( msg = f"Given label_info={label_info} has no effect." log.debug(msg) super().__init__( - label_info=NullLabelInfo(), + label_info=NullLabelInfo(), # TODO (sungchul): update label_info for multi-label support input_size=input_size, optimizer=optimizer, scheduler=scheduler, @@ -348,7 +336,7 @@ class OTXZeroShotVisualPromptingModel( def __init__( self, input_size: tuple[int, int], - label_info: LabelInfoTypes = NullLabelInfo(), + label_info: LabelInfoTypes = NullLabelInfo(), # TODO (sungchul): update label_info for multi-label support optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = VisualPromptingMetricCallable, @@ -357,7 +345,7 @@ def __init__( msg = f"Given label_info={label_info} has no effect." log.debug(msg) super().__init__( - label_info=NullLabelInfo(), + label_info=NullLabelInfo(), # TODO (sungchul): update label_info for multi-label support input_size=input_size, optimizer=optimizer, scheduler=scheduler, @@ -579,7 +567,7 @@ def test_step( Raises: TypeError: If the predictions are not of type ZeroShotVisualPromptingBatchDataEntity. 
""" - _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) + _inference_step(model=self, metric=self.metric, inputs=inputs) def _convert_pred_entity_to_compute_metric( self, @@ -596,7 +584,7 @@ def _set_label_info(self, _: LabelInfoTypes) -> None: def get_dummy_input(self, batch_size: int = 1) -> ZeroShotVisualPromptingBatchDataEntity: """Returns a dummy input for ZSL VPT model.""" images = [torch.rand(3, *self.input_size) for _ in range(batch_size)] - labels = [ZeroShotVisualPromptingLabel(prompts=torch.LongTensor([0]))] * batch_size + labels = [{"points": torch.LongTensor([0] * batch_size)}] * batch_size prompts = [torch.zeros((1, 2))] * batch_size infos = [] for i, img in enumerate(images): @@ -768,7 +756,7 @@ def _customize_outputs( labels: list[Tensor] = [] for image_output in outputs: masks.append(tv_tensors.Mask(np.concatenate(image_output.hard_predictions), device=self.device)) - scores.append(torch.as_tensor(np.concatenate(image_output.scores), device=self.device)) + scores.append(torch.as_tensor(np.concatenate(image_output.scores)[:, 0], device=self.device)) labels.append(torch.as_tensor(image_output.labels, device=self.device)) return VisualPromptingBatchPredEntity( @@ -934,7 +922,7 @@ def _convert_pred_entity_to_compute_metric( def _create_label_info_from_ov_ir(self) -> LabelInfo: """Create NullLabelInfo since Visual Prompting tasks has no use of label information.""" - return NullLabelInfo() + return NullLabelInfo() # TODO (sungchul): update label_info for multi-label support def _set_label_info(self, _: LabelInfoTypes) -> None: msg = f"Reconfiguring label_info has no effect on {self.__class__.__name__}." @@ -1125,10 +1113,9 @@ def _customize_inputs( # type: ignore[override] images: list[np.ndarray] = [] processed_prompts: list[dict[str, Any]] = [] - for image, prompts, polygons, labels in zip( + for image, prompts, labels in zip( entity.images, entity.prompts, - entity.polygons, entity.labels, ): # preprocess image encoder inputs @@ -1136,25 +1123,28 @@ def _customize_inputs( # type: ignore[override] images.append(numpy_image) if self.training: - _bboxes: list[Prompt] = [] + _boxes: list[Prompt] = [] _points: list[Prompt] = [] _polygons: list[Prompt] = [] - for prompt, label in zip(prompts, labels.prompts): # type: ignore[arg-type] + for prompt, label in zip(prompts, labels): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): - _bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + _boxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) elif isinstance(prompt, Points): _points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + elif isinstance(prompt, dmPolygon): + _polygons.extend( + [ + Prompt(np.array(polygon.points, dtype=np.int32), label.cpu().numpy()) + for polygon in prompt + ], + ) - if polygons and labels.polygons is not None: - for polygon, label in zip(polygons, labels.polygons): - _polygons.append(Prompt(np.array(polygon.points, dtype=np.int32), label.cpu().numpy())) - - # TODO (sungchul, sovrasov): support mask? + # TODO (sungchul, sovrasov): support mask? 
# preprocess decoder inputs processed_prompts.append( { - "boxes": _bboxes, + "boxes": _boxes, "points": _points, "polygons": _polygons, }, @@ -1269,7 +1259,7 @@ def transform_fn( _labels: dict[str, list[int]] = defaultdict(list) # use only the first prompt - for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0].prompts): # type: ignore[arg-type] + for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0]): if isinstance(prompt, tv_tensors.BoundingBoxes): bboxes.append(prompt.cpu().numpy()) _labels["bboxes"].append(label.cpu().numpy()) @@ -1521,7 +1511,7 @@ def test_step( Raises: TypeError: If the predictions are not of type ZeroShotVisualPromptingBatchPredEntity. """ - _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) + _inference_step(model=self, metric=self.metric, inputs=inputs) def _convert_pred_entity_to_compute_metric( self, @@ -1533,7 +1523,7 @@ def _convert_pred_entity_to_compute_metric( def _create_label_info_from_ov_ir(self) -> LabelInfo: """Create NullLabelInfo since Visual Prompting tasks has no use of label information.""" - return NullLabelInfo() + return NullLabelInfo() # TODO (sungchul): update label_info for multi-label support def _set_label_info(self, _: LabelInfoTypes) -> None: msg = f"Reconfiguring label_info has no effect on {self.__class__.__name__}." @@ -1543,7 +1533,7 @@ def get_dummy_input(self, batch_size: int = 1) -> ZeroShotVisualPromptingBatchDa """Returns a dummy input for classification OV model.""" # Resize is embedded to the OV model, which means we don't need to know the actual size images = [torch.rand(3, 224, 224) for _ in range(batch_size)] - labels = [ZeroShotVisualPromptingLabel(prompts=torch.LongTensor([0]))] * batch_size + labels = [torch.LongTensor([0] * batch_size)] * batch_size prompts = [torch.zeros((1, 2))] * batch_size infos = [] for i, img in enumerate(images): diff --git a/tests/unit/algo/visual_prompting/test_sam.py b/tests/unit/algo/visual_prompting/test_sam.py index 94f4f1fc3d9..33bf8f6df39 100644 --- a/tests/unit/algo/visual_prompting/test_sam.py +++ b/tests/unit/algo/visual_prompting/test_sam.py @@ -19,11 +19,8 @@ class TestCommonSettingMixin: - def test_load_checkpoint_success(self, mocker) -> None: - # Mock torch.hub.load_state_dict_from_url + def test_load_state_dict_success(self, mocker) -> None: mock_load_state_dict_from_url = mocker.patch("torch.hub.load_state_dict_from_url") - - # Mock state dictionary returned by load_state_dict_from_url mock_state_dict = { "image_encoder.norm_head.weight": torch.tensor([1.0]), "image_encoder.norm_head.bias": torch.tensor([1.0]), @@ -33,18 +30,39 @@ def test_load_checkpoint_success(self, mocker) -> None: } mock_load_state_dict_from_url.return_value = mock_state_dict - # Create an instance of CommonSettingMixin and set the mock model - mixin = CommonSettingMixin() - mixin.load_state_dict = mock.Mock() + # Mock only nn.Module's load_state_dict + mock_module_load_state_dict = mocker.patch.object(nn.Module, "load_state_dict") + + # Create a test class that inherits from nn.Module and CommonSettingMixin + class TestMixin(CommonSettingMixin, nn.Module): + def __init__(self): + super().__init__() + self.some_param = nn.Parameter(torch.randn(1)) - # Call the load_checkpoint method - mixin.load_checkpoint("https://example.com/checkpoint.pth") + # Create an instance of the test class + test_mixin = TestMixin() - # Assertions + # Call load_state_dict (this will use CommonSettingMixin's implementation) + test_mixin.load_state_dict(state_dict=None, 
load_from="https://example.com/checkpoint.pth") + + # Verify that load_state_dict_from_url was called mock_load_state_dict_from_url.assert_called_once_with("https://example.com/checkpoint.pth") - mixin.load_state_dict.assert_called_once_with(mock_state_dict) - def test_load_checkpoint_failure(self, mocker) -> None: + # Verify that nn.Module's load_state_dict was called with the expected arguments + expected_state_dict = { + k: v + for k, v in mock_state_dict.items() + if k + not in [ + "image_encoder.norm_head.weight", + "image_encoder.norm_head.bias", + "image_encoder.head.weight", + "image_encoder.head.bias", + ] + } + mock_module_load_state_dict.assert_called_once_with(expected_state_dict, True, False) + + def test_load_state_dict_failure(self, mocker) -> None: mock_load_state_dict_from_url = mocker.patch( "torch.hub.load_state_dict_from_url", side_effect=ValueError("Invalid URL"), @@ -52,13 +70,10 @@ def test_load_checkpoint_failure(self, mocker) -> None: mock_log_info = mocker.patch("logging.info") mixin = CommonSettingMixin() - mixin.load_checkpoint("invalid_url") + mixin.load_state_dict(load_from="invalid_url") mock_load_state_dict_from_url.assert_called_once_with("invalid_url") - mock_log_info.assert_called_once_with( - "Invalid URL: invalid_url is not desirable format for torch.hub.load_state_dict_from_url. " - "To manually load invalid_url, try to set it to trainer.checkpoint.", - ) + mock_log_info.assert_called_once() @pytest.mark.parametrize("freeze_image_encoder", [True, False]) @pytest.mark.parametrize("freeze_prompt_encoder", [True, False]) @@ -132,7 +147,7 @@ def sam(self) -> SAM: def test_initialization(self, mocker) -> None: mock_freeze_networks = mocker.patch.object(CommonSettingMixin, "freeze_networks") - mock_load_checkpoint = mocker.patch.object(CommonSettingMixin, "load_checkpoint") + mock_load_state_dict = mocker.patch.object(CommonSettingMixin, "load_state_dict") sam = SAM(backbone_type="tiny_vit") @@ -144,7 +159,7 @@ def test_initialization(self, mocker) -> None: assert sam.return_extra_metrics is False assert sam.stability_score_offset == 1.0 - mock_load_checkpoint.assert_called_once_with(load_from=sam.load_from["tiny_vit"]) + mock_load_state_dict.assert_called_once_with(load_from=sam.load_from["tiny_vit"]) mock_freeze_networks.assert_called_once_with(True, True, False) def test_build_model(self, sam: SAM) -> None: @@ -251,9 +266,7 @@ def test_gather_prompts_with_labels(self, zero_shot_sam: ZeroShotSAM, fxt_zero_s results = zero_shot_sam._gather_prompts_with_labels(entity) assert torch.all(results[0][1][0] == entity.prompts[0][0]) - assert torch.all(results[0][1][1] == entity.masks[0]) assert torch.all(results[0][2][0] == entity.prompts[0][1]) - assert results[0][2][1] == entity.polygons[0][0] @pytest.mark.parametrize( ("image", "expected"), diff --git a/tests/unit/algo/visual_prompting/visual_prompters/test_segment_anything.py b/tests/unit/algo/visual_prompting/visual_prompters/test_segment_anything.py index c585ae1718b..e423de28357 100644 --- a/tests/unit/algo/visual_prompting/visual_prompters/test_segment_anything.py +++ b/tests/unit/algo/visual_prompting/visual_prompters/test_segment_anything.py @@ -107,7 +107,7 @@ def test_forward( assert results[0][0][0].shape == torch.Size(ori_shapes[0]) # check ious - assert results[1][0].ndim == 2 + assert results[1][0].ndim == 1 @pytest.mark.parametrize( "ori_shape", @@ -406,7 +406,7 @@ def test_infer(self, mocker, zero_shot_segment_anything: ZeroShotSegmentAnything mocker.patch.object( 
zero_shot_segment_anything.prompt_getter, "get_prompt_candidates", - return_value=({0: torch.tensor([[0, 0, 0.5], [1000, 1000, 0.7]])}, {0: torch.tensor([[500, 500]])}), + return_value=({0: torch.tensor([[0, 0, 1.0], [1000, 1000, 1.0]])}, {0: torch.tensor([[500, 500]])}), ) def _patch_predict_masks(**kwargs) -> Tensor: diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index 0dacf5fd15a..617495daa13 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -18,7 +18,6 @@ ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, ZeroShotVisualPromptingDataEntity, - ZeroShotVisualPromptingLabel, ) from torchvision import tv_tensors @@ -94,10 +93,10 @@ def fxt_vpm_data_entity() -> ( dtype=torch.float32, ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) + fake_masks = tv_tensors.Mask(torch.ones(2, *img_size)) fake_labels = {"bboxes": torch.as_tensor([1], dtype=torch.int64), "points": torch.as_tensor([1])} fake_polygons = [None] - fake_scores = torch.tensor([[1.0]]) + fake_scores = torch.tensor([1.0]) # define data entity single_data_entity = VisualPromptingDataEntity( image=fake_image, @@ -123,7 +122,7 @@ def fxt_vpm_data_entity() -> ( images=[fake_image], imgs_info=[fake_image_info], masks=[fake_masks], - labels=[fake_labels], + labels=[torch.cat(list(fake_labels.values()))], polygons=[fake_polygons], bboxes=[fake_bboxes], points=[fake_points], @@ -151,14 +150,10 @@ def fxt_zero_shot_vpm_data_entity() -> ( dtype=torch.float32, ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) - fake_labels = ZeroShotVisualPromptingLabel( - prompts=torch.as_tensor([1, 2], dtype=torch.int64), - masks=torch.as_tensor([1], dtype=torch.int64), - polygons=torch.as_tensor([2], dtype=torch.int64), - ) + fake_masks = tv_tensors.Mask(torch.ones(2, *img_size)) + fake_labels = torch.as_tensor([1, 2], dtype=torch.int64) fake_polygons = [Polygon(points=[1, 1, 1, 2, 2, 2, 2, 1])] - fake_scores = torch.tensor([[1.0]]) + fake_scores = torch.tensor([1.0]) # define data entity single_data_entity = ZeroShotVisualPromptingDataEntity( image=fake_image, @@ -182,7 +177,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( images=[fake_image], imgs_info=[fake_image_info], masks=[fake_masks], - labels=[fake_labels.prompts], + labels=[fake_labels], polygons=[fake_polygons], prompts=[[fake_bboxes, fake_points]], scores=[fake_scores], diff --git a/tests/unit/core/data/dataset/test_visual_prompting.py b/tests/unit/core/data/dataset/test_visual_prompting.py index a651cf1589b..38294e49496 100644 --- a/tests/unit/core/data/dataset/test_visual_prompting.py +++ b/tests/unit/core/data/dataset/test_visual_prompting.py @@ -9,7 +9,7 @@ from datumaro import Dataset as DmDataset from otx.core.data.dataset.visual_prompting import OTXVisualPromptingDataset, OTXZeroShotVisualPromptingDataset from otx.core.data.entity.base import ImageInfo, Points -from otx.core.data.entity.visual_prompting import ZeroShotVisualPromptingLabel +from torch import Tensor from torchvision.transforms.v2 import Identity, Transform from torchvision.tv_tensors import BoundingBoxes, Image, Mask @@ -103,7 +103,7 @@ def test_get_item_impl_subset( assert hasattr(entity, "masks") assert isinstance(entity.masks, Mask) assert hasattr(entity, "labels") - assert isinstance(entity.labels, ZeroShotVisualPromptingLabel) + assert isinstance(entity.labels, Tensor) assert 
hasattr(entity, "polygons") assert isinstance(entity.polygons, list) assert hasattr(entity, "prompts") diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 7243cd2a3f4..49995d7aec7 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -17,7 +17,6 @@ from otx.core.data.entity.base import Points from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchPredEntity, - ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, ) from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter @@ -27,7 +26,6 @@ OVVisualPromptingModel, OVZeroShotVisualPromptingModel, _inference_step, - _inference_step_for_zero_shot, ) from otx.core.types.export import TaskLevelExportParameters from torchvision import tv_tensors @@ -60,7 +58,7 @@ def test_inference_step(mocker, otx_visual_prompting_model, fxt_vpm_data_entity) _inference_step(otx_visual_prompting_model, otx_visual_prompting_model.metric, fxt_vpm_data_entity[1]) for v in mocker_updates.values(): - v.assert_called_once() + v.assert_called() def test_inference_step_for_zero_shot(mocker, otx_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: @@ -73,70 +71,10 @@ def test_inference_step_for_zero_shot(mocker, otx_visual_prompting_model, fxt_ze for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") - _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) + _inference_step(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) for v in mocker_updates.values(): - v.assert_called_once() - - -def test_inference_step_for_zero_shot_with_more_preds( - mocker, - otx_visual_prompting_model, - fxt_zero_shot_vpm_data_entity, -) -> None: - """Test _inference_step_for_zero_shot with more preds.""" - otx_visual_prompting_model.configure_metric() - entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) - pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) - preds = {} - for k, v in pred_entity.__dict__.items(): - if k in ["batch_size", "polygons"]: - preds[k] = v - else: - preds[k] = v * 2 - mocker.patch.object( - otx_visual_prompting_model, - "forward", - return_value=ZeroShotVisualPromptingBatchPredEntity(**preds), - ) - mocker_updates = {} - for k, v in otx_visual_prompting_model.metric.items(): - mocker_updates[k] = mocker.patch.object(v, "update") - - _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) - - for v in mocker_updates.values(): - v.assert_called_once() - - -def test_inference_step_for_zero_shot_with_more_target( - mocker, - otx_visual_prompting_model, - fxt_zero_shot_vpm_data_entity, -) -> None: - """Test _inference_step_for_zero_shot with more target.""" - otx_visual_prompting_model.configure_metric() - entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) - pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) - mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) - mocker_updates = {} - for k, v in otx_visual_prompting_model.metric.items(): - mocker_updates[k] = mocker.patch.object(v, "update") - target = {} - for k, v in entity.__dict__.items(): - if k in ["batch_size"]: - target[k] = v - else: - target[k] = v * 2 - - _inference_step_for_zero_shot( - otx_visual_prompting_model, - otx_visual_prompting_model.metric, - ZeroShotVisualPromptingBatchDataEntity(**target), - ) - - 
for v in mocker_updates.values(): - v.assert_called_once() + v.assert_called() class TestOTXVisualPromptingModel: @@ -480,9 +418,9 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ reset_feat=False, ) - assert reference_info["reference_feats"].shape == torch.Size((3, 1, 256)) + assert reference_info["reference_feats"].shape == torch.Size((2, 1, 256)) assert 1 in reference_info["used_indices"] - assert ref_masks[0].shape == torch.Size((3, 1024, 1024)) + assert ref_masks[0].shape == torch.Size((2, 1024, 1024)) def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test infer."""
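Taken together, the entity changes mean a zero-shot sample now carries a flat list of prompts (`ZeroShotPromptType`) and a single per-instance label tensor instead of the removed `ZeroShotVisualPromptingLabel`. A minimal sketch of building such an entity, modeled on the updated test fixtures and assuming the otx package layout shown in this diff (the 4x4 image size and label values are illustrative):

    import torch
    from torchvision import tv_tensors
    from datumaro import Polygon
    from otx.core.data.entity.base import ImageInfo, Points
    from otx.core.data.entity.visual_prompting import ZeroShotVisualPromptingDataEntity

    img_size = (4, 4)
    image = tv_tensors.Image(torch.rand(3, *img_size))
    bbox = tv_tensors.BoundingBoxes([[0, 0, 1, 1]], format="XYXY", canvas_size=img_size, dtype=torch.float32)
    point = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32)

    entity = ZeroShotVisualPromptingDataEntity(
        image=image,
        img_info=ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size),
        masks=tv_tensors.Mask(torch.ones(2, *img_size, dtype=torch.uint8)),
        labels=torch.as_tensor([1, 2], dtype=torch.int64),  # one label per prompt
        polygons=[Polygon(points=[1, 1, 1, 2, 2, 2, 2, 1])],
        prompts=[bbox, point],  # mixed prompt types are allowed under ZeroShotPromptType
    )
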