
Commit 71a819e

Update Semantic Segmentation refactoring (reply comments from previous PR) (#3863)

* reply comments

* align names. reply comments

* fix recipes

* fixed naming for OV

* fix integration tests

* minor pre-commit fix
kprokofi authored Aug 22, 2024
1 parent fe87ec2 commit 71a819e
Showing 39 changed files with 133 additions and 339 deletions.
8 changes: 4 additions & 4 deletions src/otx/algo/segmentation/backbones/dinov2.py
@@ -24,7 +24,7 @@ class DinoVisionTransformer(nn.Module):

     def __init__(
         self,
-        name: str,
+        model_name: str,
         freeze_backbone: bool,
         out_index: list[int],
         pretrained_weights: str | None = None,
@@ -37,17 +37,17 @@ def __init__(
         if ci_data_root is not None and Path(ci_data_root).exists():
             pretrained = False

-        self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=name, pretrained=pretrained)
+        self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=pretrained)

         if ci_data_root is not None and Path(ci_data_root).exists():
-            ckpt_filename = f"{name}4_pretrain.pth"
+            ckpt_filename = f"{model_name}4_pretrain.pth"
             ckpt_path = Path(ci_data_root) / "torch" / "hub" / "checkpoints" / ckpt_filename
             if not ckpt_path.exists():
                 msg = (
                     f"Internal cache was specified but cannot find weights file: {ckpt_filename}. load from torch hub."
                 )
                 logger.warning(msg)
-                self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=name, pretrained=True)
+                self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=True)
             else:
                 self.backbone.load_state_dict(torch.load(ckpt_path))

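With this change the DINOv2 backbone is selected via `model_name` instead of `name`. A minimal usage sketch (not part of the diff; `"dinov2_vits14"` is an assumed torch.hub identifier from the facebookresearch/dinov2 repo, not a name confirmed by this commit):

```python
# Hedged sketch of the renamed constructor; "dinov2_vits14" is an assumption.
from otx.algo.segmentation.backbones import DinoVisionTransformer

backbone = DinoVisionTransformer(
    model_name="dinov2_vits14",  # was `name=` before this commit
    freeze_backbone=True,
    out_index=[8, 9, 10, 11],  # the indices DinoV2Seg passes in dino_v2_seg.py below
)
```
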
238 changes: 14 additions & 224 deletions src/otx/algo/segmentation/backbones/litehrnet.py
@@ -522,10 +522,10 @@ class Stem(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        stem_channels: int,
-        out_channels: int,
-        expand_ratio: int,
-        normalization: Callable[..., nn.Module] = nn.BatchNorm2d,
+        stem_channels: int = 32,
+        out_channels: int = 32,
+        expand_ratio: int = 1,
+        normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         with_cp: bool = False,
         strides: tuple[int, int] = (2, 2),
         extra_stride: bool = False,
@@ -666,190 +666,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)


-class StemV2(nn.Module):
-    """StemV2.
-    Args:
-        in_channels (int): Number of input image channels. Typically 3.
-        stem_channels (int): Number of output channels of the stem layer.
-        out_channels (int): Number of output channels of the backbone network.
-        expand_ratio (int): Expansion ratio of the internal channels.
-        normalization (Callable[..., nn.Module]): Normalization layer module.
-            Defaults to ``nn.BatchNorm2d``.
-        with_cp (bool): Use checkpointing to save memory during forward pass.
-        num_stages (int): Number of stages in the backbone network.
-        strides (tuple[int, int]): Strides of the first and subsequent stages.
-        extra_stride (bool): Use an extra stride in the second stage.
-        input_norm (bool): Use instance normalization on the input image.
-    Raises:
-        ValueError: If num_stages is less than 1.
-        TypeError: If strides is not a tuple or list.
-        ValueError: If len(strides) is not equal to num_stages + 1.
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        stem_channels: int,
-        out_channels: int,
-        expand_ratio: int,
-        normalization: Callable[..., nn.Module] = nn.BatchNorm2d,
-        with_cp: bool = False,
-        num_stages: int = 1,
-        strides: tuple[int, int] = (2, 2),
-        extra_stride: bool = False,
-        input_norm: bool = False,
-    ) -> None:
-        """StemV2 initialization."""
-        super().__init__()
-
-        if num_stages < 1:
-            msg = "num_stages must be greater than 0."
-            raise ValueError(msg)
-        if not isinstance(strides, (tuple, list)):
-            msg = "strides must be tuple or list."
-            raise TypeError(msg)
-
-        if len(strides) != 1 + num_stages:
-            msg = "len(strides) must equal to num_stages + 1."
-            raise ValueError(msg)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.normalization = normalization
-        self.with_cp = with_cp
-        self.num_stages = num_stages
-
-        self.input_norm = None
-        if input_norm:
-            self.input_norm = nn.InstanceNorm2d(in_channels)
-
-        self.conv1 = Conv2dModule(
-            in_channels=in_channels,
-            out_channels=stem_channels,
-            kernel_size=3,
-            stride=strides[0],
-            padding=1,
-            normalization=build_norm_layer(self.normalization, num_features=stem_channels),
-            activation=build_activation_layer(nn.ReLU),
-        )
-
-        self.conv2 = None
-        if extra_stride:
-            self.conv2 = Conv2dModule(
-                in_channels=stem_channels,
-                out_channels=stem_channels,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-                normalization=build_norm_layer(self.normalization, num_features=stem_channels),
-                activation=build_activation_layer(nn.ReLU),
-            )
-
-        mid_channels = int(round(stem_channels * expand_ratio))
-        internal_branch_channels = stem_channels // 2
-        out_branch_channels = self.out_channels // 2
-
-        self.branch1, self.branch2 = nn.ModuleList(), nn.ModuleList()
-        for stage in range(1, num_stages + 1):
-            self.branch1.append(
-                nn.Sequential(
-                    Conv2dModule(
-                        internal_branch_channels,
-                        internal_branch_channels,
-                        kernel_size=3,
-                        stride=strides[stage],
-                        padding=1,
-                        groups=internal_branch_channels,
-                        normalization=build_norm_layer(normalization, num_features=internal_branch_channels),
-                        activation=None,
-                    ),
-                    Conv2dModule(
-                        internal_branch_channels,
-                        out_branch_channels if stage == num_stages else internal_branch_channels,
-                        kernel_size=1,
-                        stride=1,
-                        padding=0,
-                        normalization=build_norm_layer(
-                            normalization,
-                            num_features=out_branch_channels if stage == num_stages else internal_branch_channels,
-                        ),
-                        activation=build_activation_layer(nn.ReLU),
-                    ),
-                ),
-            )
-
-            self.branch2.append(
-                nn.Sequential(
-                    Conv2dModule(
-                        internal_branch_channels,
-                        mid_channels,
-                        kernel_size=1,
-                        stride=1,
-                        padding=0,
-                        normalization=build_norm_layer(normalization, num_features=mid_channels),
-                        activation=build_activation_layer(nn.ReLU),
-                    ),
-                    Conv2dModule(
-                        mid_channels,
-                        mid_channels,
-                        kernel_size=3,
-                        stride=strides[stage],
-                        padding=1,
-                        groups=mid_channels,
-                        normalization=build_norm_layer(normalization, num_features=mid_channels),
-                        activation=None,
-                    ),
-                    Conv2dModule(
-                        mid_channels,
-                        out_branch_channels if stage == num_stages else internal_branch_channels,
-                        kernel_size=1,
-                        stride=1,
-                        padding=0,
-                        normalization=build_norm_layer(
-                            normalization,
-                            num_features=out_branch_channels if stage == num_stages else internal_branch_channels,
-                        ),
-                        activation=build_activation_layer(nn.ReLU),
-                    ),
-                ),
-            )
-
-    def _inner_forward(self, x: torch.Tensor) -> list[torch.Tensor]:
-        """Forward pass of Stem module.
-        Args:
-            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
-        Returns:
-            list[torch.Tensor]: List of output tensors at each stage of the backbone.
-        """
-        if self.input_norm is not None:
-            x = self.input_norm(x)
-
-        y = self.conv1(x)
-        if self.conv2 is not None:
-            y = self.conv2(y)
-
-        out_list = [y]
-        for stage in range(self.num_stages):
-            y1, y2 = y.chunk(2, dim=1)
-
-            y1 = self.branch1[stage](y1)
-            y2 = self.branch2[stage](y2)
-
-            y = torch.cat((y1, y2), dim=1)
-            y = channel_shuffle(y, 2)
-            out_list.append(y)
-
-        return out_list
-
-    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
-        """Forward."""
-        return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)
-
-
 class ShuffleUnit(nn.Module):
     """InvertedResidual block for ShuffleNetV2 backbone.
@@ -1193,7 +1009,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
         return out


-class NNLiteHRNet(nn.Module):
+class LiteHRNetModule(nn.Module):
     """Lite-HRNet backbone.
     `High-Resolution Representations for Labeling Pixels and Regions
@@ -1215,8 +1031,8 @@ class NNLiteHRNet(nn.Module):

     def __init__(
         self,
-        stem: dict[str, Any],
         num_stages: int,
+        stem_configuration: dict[str, Any],
         stages_spec: dict[str, Any],
         in_channels: int = 3,
         normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
@@ -1233,11 +1049,7 @@ def __init__(
         self.norm_eval = norm_eval
         self.with_cp = with_cp
         self.zero_init_residual = zero_init_residual
-        self.stem = Stem(
-            in_channels,
-            normalization=self.normalization,
-            **stem,
-        )
+        self.stem = Stem(in_channels=in_channels, **stem_configuration, normalization=normalization)
         self.num_stages = num_stages
         self.stages_spec = stages_spec

@@ -1435,14 +1247,7 @@ class LiteHRNetBackbone:

     LITEHRNET_CFG: ClassVar[dict[str, Any]] = {
         "lite_hrnet_s": {
-            "stem": {
-                "stem_channels": 32,
-                "out_channels": 32,
-                "expand_ratio": 1,
-                "strides": [2, 2],
-                "extra_stride": True,
-                "input_norm": False,
-            },
+            "stem_configuration": {"extra_stride": True},
             "num_stages": 2,
             "stages_spec": {
                 "num_modules": [4, 4],
@@ -1456,14 +1261,7 @@
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth",
},
"lite_hrnet_18": {
"stem": {
"stem_channels": 32,
"out_channels": 32,
"expand_ratio": 1,
"strides": [2, 2],
"extra_stride": False,
"input_norm": False,
},
"stem_configuration": {},
"num_stages": 3,
"stages_spec": {
"num_modules": [2, 4, 2],
@@ -1477,14 +1275,7 @@
"pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth",
},
"lite_hrnet_x": {
"stem": {
"stem_channels": 60,
"out_channels": 60,
"expand_ratio": 1,
"strides": [2, 1],
"extra_stride": False,
"input_norm": False,
},
"stem_configuration": {"stem_channels": 60, "out_channels": 60, "strides": (2, 1)},
"num_stages": 4,
"stages_spec": {
"weighting_module_version": "v1",
@@ -1500,10 +1291,9 @@
         },
     }

-    def __new__(cls, version: str) -> NNLiteHRNet:
+    def __new__(cls, model_name: str) -> LiteHRNetModule:
         """Constructor for LiteHRNet backbone."""
-        if version not in cls.LITEHRNET_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.LITEHRNET_CFG:
+            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)

-        return NNLiteHRNet(**cls.LITEHRNET_CFG[version])
+        return LiteHRNetModule(**cls.LITEHRNET_CFG[model_name])
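
After this refactor the per-variant `stem_configuration` dicts only carry overrides; everything else falls back to the new `Stem` defaults (`stem_channels=32`, `out_channels=32`, `expand_ratio=1`, `strides=(2, 2)`). A sketch of the factory contract under that assumption:

```python
# Hedged sketch: LiteHRNetBackbone.__new__ returns a LiteHRNetModule and
# raises KeyError for names missing from LITEHRNET_CFG (see __new__ above).
from otx.algo.segmentation.backbones.litehrnet import LiteHRNetBackbone

backbone = LiteHRNetBackbone(model_name="lite_hrnet_18")  # empty stem_configuration -> pure Stem defaults

try:
    LiteHRNetBackbone(model_name="lite_hrnet_xxl")  # hypothetical unsupported name
except KeyError as err:
    print(err)  # model type 'lite_hrnet_xxl' is not supported
```
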
10 changes: 5 additions & 5 deletions src/otx/algo/segmentation/backbones/mscan.py
@@ -324,7 +324,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         return x, h, w


-class NNMSCAN(nn.Module):
+class MSCANModule(nn.Module):
     """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
     This backbone is the implementation of `SegNeXt: Rethinking
@@ -462,10 +462,10 @@ class MSCAN:
         },
     }

-    def __new__(cls, version: str) -> NNMSCAN:
+    def __new__(cls, model_name: str) -> MSCANModule:
         """Constructor for MSCAN backbone."""
-        if version not in cls.MSCAN_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.MSCAN_CFG:
-            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)

-        return NNMSCAN(**cls.MSCAN_CFG[version])
+        return MSCANModule(**cls.MSCAN_CFG[model_name])
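
MSCAN now follows the same factory pattern as LiteHRNetBackbone: `version` becomes `model_name` and the returned module is `MSCANModule`. The valid `MSCAN_CFG` keys are elided from this diff, so the key below is a placeholder only:

```python
# Hedged sketch; "segnext_tiny" is a placeholder key, not confirmed by this
# diff. Use a real entry from MSCAN_CFG.
from otx.algo.segmentation.backbones.mscan import MSCAN

backbone = MSCAN(model_name="segnext_tiny")  # returns an MSCANModule, or raises KeyError
```
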
12 changes: 6 additions & 6 deletions src/otx/algo/segmentation/dino_v2_seg.py
@@ -10,7 +10,7 @@
 from otx.algo.segmentation.backbones import DinoVisionTransformer
 from otx.algo.segmentation.heads import FCNHead
 from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
-from otx.algo.segmentation.segmentors import BaseSegmModel
+from otx.algo.segmentation.segmentors import BaseSegmentationModel
 from otx.core.model.segmentation import OTXSegmentationModel

 if TYPE_CHECKING:
@@ -26,15 +26,15 @@ class DinoV2Seg(OTXSegmentationModel):
     ]

     def _build_model(self) -> nn.Module:
-        if self.model_version not in self.AVAILABLE_MODEL_VERSIONS:
-            msg = f"Model version {self.model_version} is not supported."
+        if self.model_name not in self.AVAILABLE_MODEL_VERSIONS:
+            msg = f"Model version {self.model_name} is not supported."
             raise ValueError(msg)

-        backbone = DinoVisionTransformer(name=self.model_version, freeze_backbone=True, out_index=[8, 9, 10, 11])
-        decode_head = FCNHead(self.model_version, num_classes=self.num_classes)
+        backbone = DinoVisionTransformer(model_name=self.model_name, freeze_backbone=True, out_index=[8, 9, 10, 11])
+        decode_head = FCNHead(self.model_name, num_classes=self.num_classes)
         criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index)  # type: ignore[attr-defined]

-        return BaseSegmModel(
+        return BaseSegmentationModel(
             backbone=backbone,
             decode_head=decode_head,
             criterion=criterion,
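Structurally `_build_model` is unchanged: it still composes a backbone, a decode head, and a criterion into one segmentor, just under the new names. A condensed sketch of that composition (the `num_classes` and `ignore_index` values are illustrative, not from the diff):

```python
# Hedged sketch mirroring the calls _build_model makes above; the concrete
# num_classes and ignore_index values are assumptions.
from otx.algo.segmentation.backbones import DinoVisionTransformer
from otx.algo.segmentation.heads import FCNHead
from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
from otx.algo.segmentation.segmentors import BaseSegmentationModel

backbone = DinoVisionTransformer(model_name="dinov2_vits14", freeze_backbone=True, out_index=[8, 9, 10, 11])
decode_head = FCNHead("dinov2_vits14", num_classes=21)  # 21 classes: illustrative
criterion = CrossEntropyLossWithIgnore(ignore_index=255)  # 255: a common ignore index, assumed here
model = BaseSegmentationModel(backbone=backbone, decode_head=decode_head, criterion=criterion)
```
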
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/heads/base_segm_head.py
@@ -16,7 +16,7 @@
 from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http


-class BaseSegmHead(nn.Module):
+class BaseSegmentationHead(nn.Module):
     """Base class for segmentation heads.
     Args:
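For downstream code the head rename is a one-line import change, assuming the module path stays `otx.algo.segmentation.heads.base_segm_head`:

```python
# Before this commit (assumed external usage):
#   from otx.algo.segmentation.heads.base_segm_head import BaseSegmHead
from otx.algo.segmentation.heads.base_segm_head import BaseSegmentationHead


class MyCustomHead(BaseSegmentationHead):  # hypothetical subclass
    """Sketch only; a real head would implement the base class's forward/loss hooks."""
```
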
