Update visual prompting refactoring to develop (#3193)
* Update

* Enable using `print_config`

* Add visual prompting tutorial

* Update unit tests

* Update

* precommit

* Updates for unit test

* Updates for integration tests

* Fix ruff errors

* Update docs

* Fix
sungchul2 authored Mar 26, 2024
1 parent ec23ba8 commit 54422a7
Showing 8 changed files with 508 additions and 108 deletions.
22 changes: 19 additions & 3 deletions src/otx/algo/visual_prompting/segment_anything.py
@@ -496,9 +496,25 @@ def __init__(
         scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
         metric: MetricCallable = VisualPromptingMetricCallable,
         torch_compile: bool = False,
-        **kwargs,
-    ):
-        self.config = {"backbone": backbone, **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], **kwargs}
+        freeze_image_encoder: bool = True,
+        freeze_prompt_encoder: bool = True,
+        freeze_mask_decoder: bool = False,
+        use_stability_score: bool = False,
+        return_single_mask: bool = True,
+        return_extra_metrics: bool = False,
+        stability_score_offset: float = 1.0,
+    ) -> None:
+        self.config = {
+            "backbone": backbone,
+            "freeze_image_encoder": freeze_image_encoder,
+            "freeze_prompt_encoder": freeze_prompt_encoder,
+            "freeze_mask_decoder": freeze_mask_decoder,
+            "use_stability_score": use_stability_score,
+            "return_single_mask": return_single_mask,
+            "return_extra_metrics": return_extra_metrics,
+            "stability_score_offset": stability_score_offset,
+            **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone],
+        }
         super().__init__(
             num_classes=num_classes,
             optimizer=optimizer,
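For illustration, a hedged usage sketch of the new signature. The wrapper class name `OTXSegmentAnything` and the remaining parameter defaults are assumptions not shown in this hunk:

    # Hypothetical call site: the flags are now explicit keyword arguments
    # instead of an opaque **kwargs pass-through, so they are visible to IDEs,
    # type checkers, and generated CLI configs (see `print_config` in the
    # commit message).
    from otx.algo.visual_prompting.segment_anything import OTXSegmentAnything  # assumed name

    model = OTXSegmentAnything(
        backbone="vit_b",
        freeze_image_encoder=True,   # keep the SAM image encoder frozen
        freeze_mask_decoder=False,   # fine-tune only the mask decoder
    )
    assert model.config["freeze_mask_decoder"] is False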
190 changes: 107 additions & 83 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -19,12 +19,9 @@
 from torch import LongTensor, Tensor, nn
 from torch.nn import functional as F  # noqa: N812
 from torchvision import tv_tensors
-from torchvision.tv_tensors import BoundingBoxes, Image
+from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor
 
-from otx.algo.visual_prompting.segment_anything import (
-    DEFAULT_CONFIG_SEGMENT_ANYTHING,
-    SegmentAnything,
-)
+from otx.algo.visual_prompting.segment_anything import DEFAULT_CONFIG_SEGMENT_ANYTHING, SegmentAnything
 from otx.core.data.entity.base import OTXBatchLossEntity, Points
 from otx.core.data.entity.visual_prompting import (
     ZeroShotVisualPromptingBatchDataEntity,
@@ -230,8 +227,8 @@ def expand_reference_info(self, reference_feats: Tensor, new_largest_label: int)
     @torch.no_grad()
     def learn(
         self,
-        images: list[tv_tensors.Image],
-        processed_prompts: list[dict[int, list[tv_tensors.TVTensor]]],
+        images: list[Image],
+        processed_prompts: list[dict[int, list[TVTensor]]],
         reference_feats: Tensor,
         used_indices: Tensor,
         ori_shapes: list[Tensor],
@@ -244,8 +241,8 @@ def learn(
         Currently, single batch is only supported.
 
         Args:
-            images (list[tv_tensors.Image]): List of given images for reference features.
-            processed_prompts (dict[int, list[tv_tensors.TVTensor]]): The class-wise prompts
+            images (list[Image]): List of given images for reference features.
+            processed_prompts (dict[int, list[TVTensor]]): The class-wise prompts
                 processed at OTXZeroShotSegmentAnything._gather_prompts_with_labels.
             reference_feats (Tensor): Reference features for target prediction.
             used_indices (Tensor): To check which indices of reference features are valid.
@@ -269,7 +266,7 @@ def learn(
             # TODO (sungchul): ensemble multi reference features (current : use merged masks)
             ref_mask = torch.zeros(*map(int, ori_shape), dtype=torch.uint8, device=image.device)
             for input_prompt in input_prompts:
-                if isinstance(input_prompt, tv_tensors.Mask):
+                if isinstance(input_prompt, Mask):
                     # directly use annotation information as a mask
                     ref_mask[input_prompt == 1] += 1  # TODO (sungchul): check if the mask is bool or int
                 else:
@@ -321,7 +318,7 @@ def learn(
     @torch.no_grad()
     def infer(
         self,
-        images: list[tv_tensors.Image],
+        images: list[Image],
         reference_feats: Tensor,
         used_indices: Tensor,
         ori_shapes: list[Tensor],
@@ -334,7 +331,7 @@ def infer(
         Get target results by using reference features and target images' features.
 
         Args:
-            images (list[tv_tensors.Image]): Given images for target results.
+            images (list[Image]): Given images for target results.
             reference_feats (Tensor): Reference features for target prediction.
             used_indices (Tensor): To check which indices of reference features are valid.
             ori_shapes (list[Tensor]): Original image size.
@@ -455,66 +452,73 @@ def _predict_masks(
         masks: Tensor
         logits: Tensor
         scores: Tensor
-        num_iter = 3 if is_cascade else 1
-        for i in range(num_iter):
-            if i == 0:
-                # First-step prediction
-                mask_input = torch.zeros(
-                    1,
-                    1,
-                    *(x * 4 for x in image_embeddings.shape[2:]),
-                    device=image_embeddings.device,
-                )
-                has_mask_input = self.has_mask_inputs[0].to(mask_input.device)
-
-            elif i == 1:
-                # Cascaded Post-refinement-1
-                # TODO (sungchul2): Fix the following ruff errors, ticket no. 135852
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:473:21: F821 Undefined name `masks`
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:474:21: F821 Undefined name `logits`
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:475:21: F821 Undefined name `scores`
-                mask_input, best_masks = self._decide_cascade_results(
-                    masks,  # noqa: F821
-                    logits,  # noqa: F821
-                    scores,  # noqa: F821
-                    is_single=True,
-                )
-                if best_masks.sum() == 0:
-                    return best_masks
-
-                has_mask_input = self.has_mask_inputs[1].to(mask_input.device)
-
-            elif i == 2:
-                # Cascaded Post-refinement-2
-                # TODO (sungchul2): Fix the following ruff errors, ticket no. 135852
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:475:21: F821 Undefined name `masks`
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:476:21: F821 Undefined name `logits`
-                # src/otx/algo/visual_prompting/zero_shot_segment_anything.py:477:21: F821 Undefined name `scores`
-                mask_input, best_masks = self._decide_cascade_results(masks, logits, scores)  # noqa: F821
-                if best_masks.sum() == 0:
-                    return best_masks
-
-                has_mask_input = self.has_mask_inputs[1].to(mask_input.device)
-                coords = torch.nonzero(best_masks)
-                y, x = coords[:, 0], coords[:, 1]
-                box_coords = self._preprocess_coords(
-                    torch.tensor([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=torch.float32, device=coords.device),
-                    ori_shape,
-                    self.image_size,
-                )
-                point_coords = torch.cat((point_coords, box_coords), dim=1)
-                point_labels = torch.cat((point_labels, self.point_labels_box.to(point_labels.device)), dim=1)
-
-            high_res_masks, scores, logits = self(
-                mode=mode,
-                image_embeddings=image_embeddings,
-                point_coords=point_coords,
-                point_labels=point_labels,
-                mask_input=mask_input,
-                has_mask_input=has_mask_input,
-                ori_shape=ori_shape,
-            )
-            masks = high_res_masks > self.mask_threshold
+
+        # First-step prediction
+        mask_input = torch.zeros(
+            1,
+            1,
+            *(x * 4 for x in image_embeddings.shape[2:]),
+            device=image_embeddings.device,
+        )
+        has_mask_input = self.has_mask_inputs[0].to(mask_input.device)
+        high_res_masks, scores, logits = self(
+            mode=mode,
+            image_embeddings=image_embeddings,
+            point_coords=point_coords,
+            point_labels=point_labels,
+            mask_input=mask_input,
+            has_mask_input=has_mask_input,
+            ori_shape=ori_shape,
+        )
+        masks = high_res_masks > self.mask_threshold
+
+        if is_cascade:
+            for i in range(2):
+                if i == 0:
+                    # Cascaded Post-refinement-1
+                    mask_input, best_masks = self._decide_cascade_results(
+                        masks,
+                        logits,
+                        scores,
+                        is_single=True,
+                    )
+                    if best_masks.sum() == 0:
+                        return best_masks
+
+                    has_mask_input = self.has_mask_inputs[1].to(mask_input.device)
+
+                else:
+                    # Cascaded Post-refinement-2
+                    mask_input, best_masks = self._decide_cascade_results(masks, logits, scores)
+                    if best_masks.sum() == 0:
+                        return best_masks
+
+                    has_mask_input = self.has_mask_inputs[1].to(mask_input.device)
+                    coords = torch.nonzero(best_masks)
+                    y, x = coords[:, 0], coords[:, 1]
+                    box_coords = self._preprocess_coords(
+                        torch.tensor(
+                            [[[x.min(), y.min()], [x.max(), y.max()]]],
+                            dtype=torch.float32,
+                            device=coords.device,
+                        ),
+                        ori_shape,
+                        self.image_size,
+                    )
+                    point_coords = torch.cat((point_coords, box_coords), dim=1)
+                    point_labels = torch.cat((point_labels, self.point_labels_box.to(point_labels.device)), dim=1)
+
+                high_res_masks, scores, logits = self(
+                    mode=mode,
+                    image_embeddings=image_embeddings,
+                    point_coords=point_coords,
+                    point_labels=point_labels,
+                    mask_input=mask_input,
+                    has_mask_input=has_mask_input,
+                    ori_shape=ori_shape,
+                )
+                masks = high_res_masks > self.mask_threshold
 
         _, best_masks = self._decide_cascade_results(masks, logits, scores)
         return best_masks

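The refactor above hoists the first forward pass out of the loop so the cascaded refinement becomes an explicit, optional stage. A hedged control-flow sketch (the `run_decoder` helper is hypothetical, not OTX code):

    # Sketch of the new control flow: the initial prediction always runs, so
    # masks/logits/scores are bound before any refinement step -- which is what
    # made the F821 "undefined name" noqa workarounds in the old loop obsolete.
    def predict(run_decoder, is_cascade: bool):
        masks, scores, logits = run_decoder(step="initial")
        if is_cascade:
            for step in ("post-refinement-1", "post-refinement-2"):
                masks, scores, logits = run_decoder(step=step, prev=(masks, logits, scores))
        return masks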
@@ -623,17 +627,37 @@ def __init__(
         self,
         backbone: Literal["tiny_vit", "vit_b"],
         num_classes: int = 0,
-        root_reference_info: Path | str = "vpm_zsl_reference_infos",
-        save_outputs: bool = True,
-        pixel_mean: list[float] | None = [123.675, 116.28, 103.53],  # noqa: B006
-        pixel_std: list[float] | None = [58.395, 57.12, 57.375],  # noqa: B006
         optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
         scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
         metric: MetricCallable = VisualPromptingMetricCallable,
         torch_compile: bool = False,
-        **kwargs,
-    ):
-        self.config = {"backbone": backbone, **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone], **kwargs}
+        root_reference_info: Path | str = "vpm_zsl_reference_infos",
+        save_outputs: bool = True,
+        pixel_mean: list[float] | None = [123.675, 116.28, 103.53],  # noqa: B006
+        pixel_std: list[float] | None = [58.395, 57.12, 57.375],  # noqa: B006
+        freeze_image_encoder: bool = True,
+        freeze_prompt_encoder: bool = True,
+        freeze_mask_decoder: bool = True,
+        default_threshold_reference: float = 0.3,
+        default_threshold_target: float = 0.65,
+        use_stability_score: bool = False,
+        return_single_mask: bool = False,
+        return_extra_metrics: bool = False,
+        stability_score_offset: float = 1.0,
+    ) -> None:
+        self.config = {
+            "backbone": backbone,
+            "freeze_image_encoder": freeze_image_encoder,
+            "freeze_prompt_encoder": freeze_prompt_encoder,
+            "freeze_mask_decoder": freeze_mask_decoder,
+            "default_threshold_reference": default_threshold_reference,
+            "default_threshold_target": default_threshold_target,
+            "use_stability_score": use_stability_score,
+            "return_single_mask": return_single_mask,
+            "return_extra_metrics": return_extra_metrics,
+            "stability_score_offset": stability_score_offset,
+            **DEFAULT_CONFIG_SEGMENT_ANYTHING[backbone],
+        }
         super().__init__(
             num_classes=num_classes,
             optimizer=optimizer,
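As with the supervised model, a hypothetical usage sketch of the zero-shot constructor; the class name `OTXZeroShotSegmentAnything` and the remaining defaults are assumptions inferred from the surrounding diff:

    from otx.algo.visual_prompting.zero_shot_segment_anything import OTXZeroShotSegmentAnything  # assumed name

    model = OTXZeroShotSegmentAnything(
        backbone="tiny_vit",
        default_threshold_reference=0.3,  # similarity cut-off when learning references
        default_threshold_target=0.65,    # similarity cut-off when inferring targets
        save_outputs=True,                # persist reference info under root_reference_info
    )

The design point is the same as in `segment_anything.py`: the reference/target thresholds become regular, documented arguments rather than values hidden in `**kwargs`.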
@@ -729,15 +753,15 @@ def _customize_outputs(  # type: ignore[override]
             self.used_indices = outputs[0].get("used_indices")
             return outputs
 
-        masks: list[tv_tensors.Mask] = []
+        masks: list[Mask] = []
         prompts: list[Points] = []
         scores: list[Tensor] = []
         labels: list[LongTensor] = []
         for predicted_masks, used_points in outputs:
             for label, predicted_mask in predicted_masks.items():
                 if len(predicted_mask) == 0:
                     continue
-                masks.append(tv_tensors.Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32))
+                masks.append(Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32))
                 prompts.append(
                     Points(
                         torch.stack([p[:2] for p in used_points[label]], dim=0),
@@ -761,11 +785,11 @@ def _customize_outputs(  # type: ignore[override]
 
     def _gather_prompts_with_labels(
         self,
-        prompts: list[list[tv_tensors.TVTensor]],
+        prompts: list[list[TVTensor]],
         labels: list[Tensor],
-    ) -> list[dict[int, list[tv_tensors.TVTensor]]]:
+    ) -> list[dict[int, list[TVTensor]]]:
         """Gather prompts according to labels."""
-        total_processed_prompts: list[dict[int, list[tv_tensors.TVTensor]]] = []
+        total_processed_prompts: list[dict[int, list[TVTensor]]] = []
         for prompt, label in zip(prompts, labels):
             processed_prompts = defaultdict(list)
             for _prompt, _label in zip(prompt, label):  # type: ignore[arg-type]
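For intuition, a small self-contained sketch of the label-wise grouping this method performs (toy string values stand in for the real torchvision `TVTensor` prompts):

    from collections import defaultdict

    # prompts and labels are parallel lists, one entry per image
    prompts = [["box_a", "point_b", "box_c"]]
    labels = [[1, 0, 1]]

    total_processed_prompts = []
    for prompt, label in zip(prompts, labels):
        processed_prompts = defaultdict(list)
        for _prompt, _label in zip(prompt, label):
            processed_prompts[_label].append(_prompt)
        # sort class-wise so iteration order is deterministic
        total_processed_prompts.append(dict(sorted(processed_prompts.items())))

    print(total_processed_prompts)  # [{0: ['point_b'], 1: ['box_a', 'box_c']}]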
@@ -774,7 +798,7 @@ def _gather_prompts_with_labels(
         total_processed_prompts.append(sorted_processed_prompts)
         return total_processed_prompts
 
-    def apply_image(self, image: tv_tensors.Image | np.ndarray, target_length: int = 1024) -> tv_tensors.Image:
+    def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image:
         """Preprocess image to be used in the model."""
         h, w = image.shape[-2:]
         target_size = self.get_preprocess_shape(h, w, target_length)
20 changes: 11 additions & 9 deletions src/otx/core/model/visual_prompting.py
@@ -112,13 +112,13 @@ def _inference_step(
             ]
             _target = converted_entities["target"]
             _metric.update(preds=_preds, target=_target)
-        elif _name in ["IoU", "F1", "Dice"]:
+        elif _name in ["iou", "f1-score", "dice"]:
             # BinaryJaccardIndex, BinaryF1Score, Dice
             for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]):
                 _metric.update(cvt_preds["masks"], cvt_target["masks"])
 
 
-def _inference_step_for_zeroshot(
+def _inference_step_for_zero_shot(
     model: OTXZeroShotVisualPromptingModel | OVZeroShotVisualPromptingModel,
     metric: MetricCollection,
     inputs: ZeroShotVisualPromptingBatchDataEntity,
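The branch now matches lower-case metric names, which implies the `MetricCollection` is keyed accordingly. A hedged sketch of what such a collection could look like (metric classes are real torchmetrics APIs; the exact collection OTX builds is not shown in this diff, and "dice" would map to a Dice metric whose import path varies by torchmetrics version):

    from torchmetrics import MetricCollection
    from torchmetrics.classification import BinaryF1Score, BinaryJaccardIndex

    # Keys must match the strings tested in _inference_step above.
    visual_prompting_metrics = MetricCollection(
        {
            "iou": BinaryJaccardIndex(),
            "f1-score": BinaryF1Score(),
        }
    )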
@@ -160,7 +160,7 @@ def _inference_step_for_zeroshot(
             _preds.append(_preds[idx] if idx < len(_preds) else pad_prediction)
 
         _metric.update(preds=_preds, target=_target)
-    elif _name in ["IoU", "F1", "Dice"]:
+    elif _name in ["iou", "f1-score", "dice"]:
         # BinaryJaccardIndex, BinaryF1Score, Dice
         for cvt_preds, cvt_target in zip(converted_entities["preds"], converted_entities["target"]):
             _metric.update(
@@ -441,7 +441,7 @@ def test_step(
         Raises:
             TypeError: If the predictions are not of type VisualPromptingBatchPredEntity.
         """
-        _inference_step_for_zeroshot(model=self, metric=self.metric, inputs=inputs)
+        _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs)
 
     def _convert_pred_entity_to_compute_metric(
         self,
@@ -789,6 +789,7 @@ def learn(
         inputs: ZeroShotVisualPromptingBatchDataEntity,
         reset_feat: bool = False,
         default_threshold_reference: float = 0.3,
+        is_cascade: bool = False,
     ) -> tuple[dict[str, np.ndarray], list[np.ndarray]]:
         """`Learn` for reference features."""
         if reset_feat or self.reference_feats is None:
@@ -815,7 +816,7 @@ def learn(
             if "point_coords" in inputs_decoder:
                 # bboxes and points
                 inputs_decoder.update(image_embeddings)
-                prediction = self._predict_masks(inputs_decoder, original_shape, is_cascade=False)
+                prediction = self._predict_masks(inputs_decoder, original_shape, is_cascade=is_cascade)
                 masks = prediction["upscaled_masks"]
             else:
                 log.warning("annotation and polygon will be supported.")
@@ -847,7 +848,7 @@ def infer(
         inputs: ZeroShotVisualPromptingBatchDataEntity,
         reference_feats: np.ndarray,
         used_indices: np.ndarray,
-        is_cascade: bool = False,
+        is_cascade: bool = True,
         threshold: float = 0.0,
         num_bg_points: int = 1,
         default_threshold_target: float = 0.65,
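Taken together, the two hunks above flip the cascade behaviour for the OpenVINO zero-shot model: `learn` now exposes `is_cascade` (off by default) while `infer` cascades by default. A hypothetical call site under those assumptions (`ov_model` and the keys of the returned reference info are illustrative, not confirmed by this diff):

    # Learn references without cascade refinement (the new default for learn) ...
    reference_info, ref_masks = ov_model.learn(inputs, reset_feat=True)

    # ... but refine target masks with the cascade, now the default for infer().
    predictions = ov_model.infer(
        inputs,
        reference_feats=reference_info["reference_feats"],
        used_indices=reference_info["used_indices"],
    )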
@@ -1087,9 +1088,10 @@ def _predict_masks(
                 has_mask_input = self.has_mask_inputs[1]
                 y, x = np.nonzero(masks)
                 box_coords = self.model["decoder"].apply_coords(
-                    np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32),
-                    original_size[0],
+                    np.array([[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32),
+                    original_size,
                 )
+                box_coords = np.expand_dims(box_coords, axis=0)
                 inputs.update(
                     {
                         "point_coords": np.concatenate((inputs["point_coords"], box_coords), axis=1),
@@ -1419,7 +1421,7 @@ def test_step(
         Raises:
             TypeError: If the predictions are not of type VisualPromptingBatchPredEntity.
         """
-        _inference_step_for_zeroshot(model=self, metric=self.metric, inputs=inputs)
+        _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs)
 
     def _convert_pred_entity_to_compute_metric(
         self,
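Returning to the `_predict_masks` hunk above: a small standalone numpy sketch of the shape fix, in which `apply_coords` now receives an unbatched `(2, 2)` box and the batch axis is restored afterwards (hypothetical values, not OTX code):

    import numpy as np

    y, x = np.array([20, 80]), np.array([10, 50])  # nonzero mask coordinates
    box = np.array([[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32)
    print(box.shape)  # (2, 2): top-left and bottom-right corners

    box = np.expand_dims(box, axis=0)  # add the batch axis back
    print(box.shape)  # (1, 2, 2), ready to concatenate onto point_coords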