
Commit

Refactoring visual prompting (#3789)
* Divide `OTXSegmentAnything` into `SAMTinyViT` and `SAMViTBase` to specify backbones

* Move `load_checkpoint` and `freeze_networks` to `OTXModel`

* Move export related functions to `OTXModel`

* Create loss module

* Fix an unnecessary condition statement

Co-authored-by: Eunwoo Shin <[email protected]>

* Revert to using a single super model instead of sub-models

* Keep essential parameters only

* Reduce redundant parameters

* Rename `backbone` to `backbone_type`

* Set `load_from` dict as class var

* Divide nn.Module from OTXModel

* Move the zero-shot model (`ZeroShotSAM`) into the `sam` module

* Apply pre-commit fixes

* Fix unit test

* Fix missed part

* Update pytorchcv version

---------

Co-authored-by: Eunwoo Shin <[email protected]>
sungchul2 and eunwoosh authored Aug 16, 2024
1 parent 06ceb10 commit d77423e
Showing 41 changed files with 2,406 additions and 2,229 deletions.
14 changes: 2 additions & 12 deletions src/otx/algo/visual_prompting/__init__.py
@@ -3,16 +3,6 @@
#
"""Module for OTX visual prompting models."""

-from . import backbones, decoders, encoders
-from .segment_anything import OTXSegmentAnything, SegmentAnything
-from .zero_shot_segment_anything import OTXZeroShotSegmentAnything, ZeroShotSegmentAnything
+from .sam import SAM, ZeroShotSAM

-__all__ = [
-"backbones",
-"encoders",
-"decoders",
-"OTXSegmentAnything",
-"SegmentAnything",
-"OTXZeroShotSegmentAnything",
-"ZeroShotSegmentAnything",
-]
+__all__ = ["SAM", "ZeroShotSAM"]
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/backbones/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Backbone modules for OTX visual prompting model."""
6 changes: 3 additions & 3 deletions src/otx/algo/visual_prompting/backbones/tiny_vit.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""TinyViT model for the OTX visual prompting."""
@@ -489,15 +489,15 @@ class TinyViT(nn.Module):

def __init__(
self,
-img_size: int = 224,
+img_size: int = 1024,
in_chans: int = 3,
embed_dims: list[int] | None = None,
depths: list[int] | None = None,
num_heads: list[int] | None = None,
window_sizes: list[int] | None = None,
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
-drop_path_rate: float = 0.1,
+drop_path_rate: float = 0.0,
mbconv_expand_ratio: float = 4.0,
local_conv_size: int = 3,
layer_lr_decay: float = 1.0,
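
With `img_size` and `drop_path_rate` now defaulting to the SAM values (1024 and 0.0), the backbone can be built for SAM without overriding them. A sketch; the keyword values are the `tiny_vit` entries from `SAMImageEncoder.backbone_configs` shown later in this commit:

```python
# Sketch: img_size=1024 and drop_path_rate=0.0 are now the defaults, so only the
# architecture-specific values (taken from the tiny_vit config further down) remain.
from otx.algo.visual_prompting.backbones.tiny_vit import TinyViT

backbone = TinyViT(
    embed_dims=[64, 128, 160, 320],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
)
```
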
17 changes: 10 additions & 7 deletions src/otx/algo/visual_prompting/backbones/vit.py
@@ -5,6 +5,9 @@

from __future__ import annotations

+from functools import partial
+from typing import Callable
+
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
@@ -26,13 +29,13 @@ class ViT(nn.Module):
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
out_chans (int): Number of output channels.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
-norm_layer (nn.Module): Normalization layer.
+norm_layer (Callable[..., nn.Module]): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
-global_attn_indexes (list): Indexes for blocks using global attention.
+global_attn_indexes (tuple[int, ...]): Indexes for blocks using global attention.
"""

def __init__(
@@ -46,12 +49,12 @@ def __init__(
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
act_layer: nn.Module = nn.GELU,
use_abs_pos: bool = True,
-use_rel_pos: bool = False,
+use_rel_pos: bool = True,
rel_pos_zero_init: bool = True,
-window_size: int = 0,
+window_size: int = 14,
global_attn_indexes: tuple[int, ...] = (),
) -> None:
super().__init__()
@@ -130,7 +133,7 @@ class Block(nn.Module):
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
-norm_layer (nn.Module): Normalization layer.
+norm_layer (Callable[..., nn.Module] | nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
@@ -146,7 +149,7 @@ def __init__(
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
-norm_layer: type[nn.Module] = nn.LayerNorm,
+norm_layer: Callable[..., nn.Module] | type[nn.Module] = nn.LayerNorm,
act_layer: type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
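
The `norm_layer` arguments change from a bare `nn.Module` class to a `Callable[..., nn.Module]` defaulting to `partial(nn.LayerNorm, eps=1e-6)`, so the SAM-specific epsilon no longer has to be threaded through the encoder config. A small sketch of the pattern; the dimension 768 is the `vit_b` embedding size, used here only for illustration:

```python
# Sketch of the norm_layer-as-callable pattern: functools.partial binds eps once,
# and each block then builds its own normalization layer from the embedding dim.
from functools import partial

from torch import nn

norm_layer = partial(nn.LayerNorm, eps=1e-6)

norm1 = norm_layer(768)  # equivalent to nn.LayerNorm(768, eps=1e-6)
print(norm1.eps)         # 1e-06
```
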
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/decoders/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Decoder modules for OTX visual prompting model."""
15 changes: 12 additions & 3 deletions src/otx/algo/visual_prompting/decoders/sam_mask_decoder.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""SAM mask decoder model for the OTX visual prompting."""
@@ -21,24 +21,33 @@ class SAMMaskDecoder(nn.Module):
Args:
transformer_dim (int): Channel dimension of the transformer.
Defaults to 256.
+transformer_cfg (dict): Configuration of the transformer.
+Defaults to `{"depth": 2, "embedding_dim": 256, "mlp_dim": 2048, "num_heads": 8}`.
num_multimask_outputs (int): The number of masks to predict when disambiguating masks.
+Defaults to 3.
activation (nn.Module): Type of activation to use when upscaling masks.
+Defaults to ``nn.GELU``.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
+Defaults to 3.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+Defaults to 256.
"""

def __init__(
self,
*,
-transformer_dim: int,
-transformer_cfg: dict,
+transformer_dim: int = 256,
+transformer_cfg: dict | None = None,
num_multimask_outputs: int = 3,
activation: type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
super().__init__()
+if transformer_cfg is None:
+transformer_cfg = {"depth": 2, "embedding_dim": 256, "mlp_dim": 2048, "num_heads": 8}
+
self.transformer_dim = transformer_dim
self.transformer = TwoWayTransformer(**transformer_cfg)

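
With `transformer_dim` and `transformer_cfg` now defaulted, the decoder can be constructed without arguments; passing `None` and filling the dict inside `__init__` also avoids a mutable default argument. A sketch of the two equivalent calls:

```python
# Sketch: both constructions are equivalent after this change.
from otx.algo.visual_prompting.decoders.sam_mask_decoder import SAMMaskDecoder

decoder = SAMMaskDecoder()  # transformer_dim=256 plus the default transformer_cfg
explicit = SAMMaskDecoder(
    transformer_dim=256,
    transformer_cfg={"depth": 2, "embedding_dim": 256, "mlp_dim": 2048, "num_heads": 8},
)
```
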
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/encoders/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Encoder modules for OTX visual prompting model."""
41 changes: 7 additions & 34 deletions src/otx/algo/visual_prompting/encoders/sam_image_encoder.py
@@ -1,9 +1,8 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""SAM image encoder model for the OTX visual prompting."""

-from functools import partial
from typing import ClassVar

from torch import nn
@@ -14,68 +13,42 @@ class SAMImageEncoder(nn.Module):

backbone_configs: ClassVar[dict] = {
"tiny_vit": {
"img_size": 1024,
"embed_dims": [64, 128, 160, 320],
"depths": [2, 2, 6, 2],
"num_heads": [2, 4, 5, 10],
"window_sizes": [7, 7, 14, 7],
"drop_path_rate": 0.0,
},
"vit_b": {
"img_size": 1024,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"out_chans": 256,
"patch_size": 16,
"mlp_ratio": 4,
"qkv_bias": True,
"use_rel_pos": True,
"window_size": 14,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
"global_attn_indexes": [2, 5, 8, 11],
},
"vit_l": {
"img_size": 1024,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"out_chans": 256,
"patch_size": 16,
"mlp_ratio": 4,
"qkv_bias": True,
"use_rel_pos": True,
"window_size": 14,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"global_attn_indexes": [5, 11, 17, 23],
},
"vit_h": {
"img_size": 1024,
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"out_chans": 256,
"patch_size": 16,
"mlp_ratio": 4,
"qkv_bias": True,
"use_rel_pos": True,
"window_size": 14,
"embed_dim": 1280,
"depth": 32,
"num_heads": 16,
"global_attn_indexes": [7, 15, 23, 31],
},
}

-def __new__(cls, backbone: str, *args, **kwargs): # noqa: ARG003
+def __new__(cls, backbone_type: str, *args, **kwargs): # noqa: ARG003
"""Initialize image encoder to target backbone."""
-if backbone.lower() == "tiny_vit":
+if backbone_type.lower() == "tiny_vit":
from otx.algo.visual_prompting.backbones.tiny_vit import TinyViT

-return TinyViT(**{**cls.backbone_configs.get(backbone.lower()), **kwargs}) # type: ignore[dict-item]
-elif backbone.lower() in ["vit_b", "vit_l", "vit_h"]: # noqa: RET505
+return TinyViT(**{**cls.backbone_configs.get(backbone_type.lower()), **kwargs}) # type: ignore[dict-item]
+elif backbone_type.lower() in ["vit_b", "vit_l", "vit_h"]: # noqa: RET505
from otx.algo.visual_prompting.backbones.vit import ViT

-return ViT(**{**cls.backbone_configs.get(backbone.lower()), **kwargs}) # type: ignore[dict-item]
+return ViT(**{**cls.backbone_configs.get(backbone_type.lower()), **kwargs}) # type: ignore[dict-item]

else:
error_log = f"{backbone} is not supported for SAMImageEncoder. Set among tiny_vit and vit_b."
error_log = f"{backbone_type} is not supported for SAMImageEncoder. Set among tiny_vit and vit_b."
raise ValueError(error_log)
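
`SAMImageEncoder.__new__` acts as a factory keyed on `backbone_type`: it returns an instance of the selected backbone class (TinyViT or ViT) rather than a `SAMImageEncoder`, and the per-backbone config now carries only the values that differ from the backbone defaults. A usage sketch:

```python
# Sketch of the factory behavior defined above.
from otx.algo.visual_prompting.encoders.sam_image_encoder import SAMImageEncoder

tiny = SAMImageEncoder(backbone_type="tiny_vit")  # returns a TinyViT instance
vit_b = SAMImageEncoder(backbone_type="vit_b")    # returns a ViT instance

try:
    SAMImageEncoder(backbone_type="resnet50")
except ValueError as err:
    print(err)  # unsupported names fall through to the else branch above
```
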
10 changes: 6 additions & 4 deletions src/otx/algo/visual_prompting/encoders/sam_prompt_encoder.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""SAM prompt encoder model for the OTX visual prompting."""
@@ -20,19 +20,21 @@ class SAMPromptEncoder(nn.Module):
Reference: https://github.com/facebookresearch/segment-anything
Args:
-embed_dim (int): The prompts' embedding dimension.
image_embedding_size (tuple(int, int)): The spatial size of the image embedding, as (H, W).
input_image_size (int): The padded size of the image as input to the image encoder, as (H, W).
+embed_dim (int): The prompts' embedding dimension.
+Defaults to 256.
mask_in_chans (int): The number of hidden channels used for encoding input masks.
+Defaults to 16.
activation (nn.Module): The activation to use when encoding input masks.
"""

def __init__(
self,
-embed_dim: int,
image_embedding_size: tuple[int, int],
input_image_size: tuple[int, int],
-mask_in_chans: int,
+embed_dim: int = 256,
+mask_in_chans: int = 16,
activation: type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
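
With `embed_dim` and `mask_in_chans` defaulted to 256 and 16, only the embedding geometry remains required. The `(64, 64)` and `(1024, 1024)` values below are assumptions consistent with the 1024-pixel input size and 16-pixel patch embedding used elsewhere in this commit, not values taken from this file:

```python
# Sketch: only the spatial sizes are still required; the sizes shown are assumed.
from otx.algo.visual_prompting.encoders.sam_prompt_encoder import SAMPromptEncoder

prompt_encoder = SAMPromptEncoder(
    image_embedding_size=(64, 64),   # assumed: 1024 input / 16-pixel patches
    input_image_size=(1024, 1024),   # assumed: matches img_size=1024
)
```
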
8 changes: 8 additions & 0 deletions src/otx/algo/visual_prompting/losses/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Loss modules for OTX visual prompting model."""
+
+from .sam_loss import SAMCriterion
+
+__all__ = ["SAMCriterion"]
