Skip to content

Commit

Permalink
Refactor Classification Backbone to Factory Class Design (#3864)
Browse files Browse the repository at this point in the history
* update for releases 2.2.0rc0

* Refactor Classification Backbone to Factory Design

* Revert "update for releases 2.2.0rc0"

This reverts commit d1bd1d5.

* Remove deprecated docstring

---------

Co-authored-by: Yunchu Lee <[email protected]>
  • Loading branch information
harimkang and yunchu authored Aug 21, 2024
1 parent a5470fc commit 1e8ec38
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 216 deletions.
6 changes: 3 additions & 3 deletions src/otx/algo/classification/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
#
"""Backbone modules for OTX custom model."""

from .efficientnet import OTXEfficientNet
from .mobilenet_v3 import OTXMobileNetV3
from .efficientnet import EfficientNetBackbone
from .mobilenet_v3 import MobileNetV3Backbone
from .timm import TimmBackbone
from .torchvision import TorchvisionBackbone
from .vision_transformer import VisionTransformer

__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3", "VisionTransformer", "TorchvisionBackbone"]
__all__ = ["EfficientNetBackbone", "TimmBackbone", "MobileNetV3Backbone", "VisionTransformer", "TorchvisionBackbone"]
226 changes: 107 additions & 119 deletions src/otx/algo/classification/backbones/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import math
from pathlib import Path
from typing import Callable, Literal
from typing import Any, Callable, ClassVar, Literal

import torch
from pytorchcv.models.model_store import download_model
Expand All @@ -17,7 +17,6 @@
from otx.algo.modules.activation import Swish, build_activation_layer
from otx.algo.modules.conv_module import Conv2dModule
from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model

PRETRAINED_ROOT = "https://github.com/osmr/imgclsmob/releases/download/v0.0.364/"
pretrained_urls = {
Expand Down Expand Up @@ -419,12 +418,10 @@ class EfficientNet(nn.Module):
bn_eps : float, default 1e-5. Small float added to variance in Batch norm.
in_channels : int, default 3. Number of input channels.
in_size : tuple of two ints, default (224, 224). Spatial size of the expected input image.
dropout_cls : dict, dropout configurations.
pooling_type : str, default 'avg'. Pooling type to use.
bn_eval : bool, default False. Whether to use BatchNorm eval mode.
bn_frozen : bool, default False. Whether to freeze BatchNorm parameters.
instance_norm_first : bool, default False. Whether to use instance normalization first.
pretrained : bool, default False. Whether to load ImageNet pre-trained weights.
"""

def __init__(
Expand All @@ -439,17 +436,14 @@ def __init__(
bn_eps: float = 1e-5,
in_channels: int = 3,
in_size: tuple[int, int] = (224, 224),
dropout_cls: dict | None = None,
pooling_type: str | None = "avg",
bn_eval: bool = False,
bn_frozen: bool = False,
instance_norm_first: bool = False,
pretrained: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.num_classes = 1000
self.pretrained = pretrained
self.in_size = in_size
self.input_IN = nn.InstanceNorm2d(3, affine=True) if instance_norm_first else None
self.bn_eval = bn_eval
Expand Down Expand Up @@ -528,164 +522,158 @@ def _init_params(self) -> None:
def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:
    """Run the backbone and return its feature map.

    The classification-head paths (logits, embeddings, loss-dependent outputs)
    are not part of the backbone: this method only produces features.

    Args:
        x: Input passed to the feature extractor.
        **kwargs: Accepted and ignored, so callers that still pass extra
            keyword arguments keep working.

    Returns:
        A one-element tuple holding the output of ``self.features``.
    """
    # Optional input normalisation layer; it is None unless the model was built
    # with ``instance_norm_first=True``.
    if self.input_IN is not None:
        x = self.input_IN(x)

    y = self.features(x)
    return (y,)


# Version suffixes accepted when building an EfficientNet backbone; the same
# strings are used as the keys of the per-version configuration table.
EFFICIENTNET_VERSION = Literal["b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8"]


class EfficientNetBackbone:
    """Factory that builds a configured ``EfficientNet`` for a given version.

    ``EfficientNetBackbone(version, ...)`` never yields an ``EfficientNetBackbone``
    instance: ``__new__`` constructs and returns an ``EfficientNet`` module.

    Attributes:
        EFFICIENTNET_CFG: Per-version settings (``input_size``, ``depth_factor``,
            ``width_factor``) keyed by version string.
        init_block_channels: Base channel count of the initial block, scaled by
            the width factor.
        layers: Base number of units per layer group, scaled by the depth factor.
        downsample: Per layer group flag; a non-zero value starts a new stage,
            zero appends the group to the previous stage.
        channels_per_layers: Base channel count per layer group, scaled by the
            width factor.
        expansion_factors_per_layers: Expansion factor per layer group.
        kernel_sizes_per_layers: Kernel size per layer group.
        strides_per_stage: Stride per layer group; only the first value of each
            resulting stage is passed to ``EfficientNet``.
        final_block_channels: Base channel count of the final block, scaled only
            when the width factor exceeds 1.0.
    """

    EFFICIENTNET_CFG: ClassVar[dict[str, Any]] = {
        "b0": {
            "input_size": (224, 224),
            "depth_factor": 1.0,
            "width_factor": 1.0,
        },
        "b1": {
            "input_size": (240, 240),
            "depth_factor": 1.1,
            "width_factor": 1.0,
        },
        "b2": {
            "input_size": (260, 260),
            "depth_factor": 1.2,
            "width_factor": 1.1,
        },
        "b3": {
            "input_size": (300, 300),
            "depth_factor": 1.4,
            "width_factor": 1.2,
        },
        "b4": {
            "input_size": (380, 380),
            "depth_factor": 1.8,
            "width_factor": 1.4,
        },
        "b5": {
            "input_size": (456, 456),
            "depth_factor": 2.2,
            "width_factor": 1.6,
        },
        "b6": {
            "input_size": (528, 528),
            "depth_factor": 2.6,
            "width_factor": 1.8,
        },
        "b7": {
            "input_size": (600, 600),
            "depth_factor": 3.1,
            "width_factor": 2.0,
        },
        "b8": {
            "input_size": (672, 672),
            "depth_factor": 3.6,
            "width_factor": 2.2,
        },
    }

    init_block_channels: ClassVar[int] = 32
    layers: ClassVar[list[int]] = [1, 2, 2, 3, 3, 4, 1]
    downsample: ClassVar[list[int]] = [1, 1, 1, 1, 0, 1, 0]
    channels_per_layers: ClassVar[list[int]] = [16, 24, 40, 80, 112, 192, 320]
    expansion_factors_per_layers: ClassVar[list[int]] = [1, 6, 6, 6, 6, 6, 6]
    kernel_sizes_per_layers: ClassVar[list[int]] = [3, 3, 5, 3, 5, 5, 3]
    strides_per_stage: ClassVar[list[int]] = [1, 2, 2, 2, 1, 2, 1]
    final_block_channels: ClassVar[int] = 1280

    @staticmethod
    def _group_by_stage(values: list[int], repeats: list[int], downsample: list[int]) -> list[list[int]]:
        """Repeat each value ``repeats[i]`` times and group the results into stages.

        A layer group whose ``downsample`` flag is non-zero opens a new stage;
        a group with flag zero is appended to the most recent stage.

        Raises:
            IndexError: If the first group has a zero flag (no stage to extend).
        """
        stages: list[list[int]] = []
        for value, repeat, starts_new_stage in zip(values, repeats, downsample):
            unit_values = [value] * repeat
            if starts_new_stage != 0:
                stages.append(unit_values)
            else:
                # Build a new list rather than extending in place so no list is
                # ever shared between the result and an earlier intermediate.
                stages[-1] = stages[-1] + unit_values
        return stages

    def __new__(
        cls,
        version: EFFICIENTNET_VERSION,
        input_size: tuple[int, int] | None = None,
        pretrained: bool = True,
        **kwargs,
    ) -> EfficientNet:
        """Build an ``EfficientNet`` for ``version`` and optionally load weights.

        Args:
            version: Key into ``EFFICIENTNET_CFG`` (``"b0"`` ... ``"b8"``).
            input_size: Spatial input size; falls back to the version's
                configured ``input_size`` when omitted (or otherwise falsy).
            pretrained: When true, ``download_model`` is called to load the
                ``efficientnet_<version>`` weights into the new model.
            **kwargs: Forwarded unchanged to the ``EfficientNet`` constructor.

        Returns:
            The constructed ``EfficientNet`` instance (not an instance of this
            class, so no ``__init__`` of this class runs afterwards).

        Raises:
            KeyError: If ``version`` has no entry in ``EFFICIENTNET_CFG``.
        """
        try:
            cfg = cls.EFFICIENTNET_CFG[version]
        except KeyError:
            msg = f"Unsupported EfficientNet version {version}"
            raise KeyError(msg) from None

        # Read settings by key: positional unpacking of ``.values()`` would
        # silently break if the key order of a config entry ever changed.
        depth_factor = cfg["depth_factor"]
        width_factor = cfg["width_factor"]
        input_size = input_size or cfg["input_size"]

        effnet_layers = [math.ceil(li * depth_factor) for li in cls.layers]
        channels_per_layers = [round_channels(ci * width_factor) for ci in cls.channels_per_layers]

        channels = cls._group_by_stage(channels_per_layers, effnet_layers, cls.downsample)
        kernel_sizes = cls._group_by_stage(cls.kernel_sizes_per_layers, effnet_layers, cls.downsample)
        expansion_factors = cls._group_by_stage(cls.expansion_factors_per_layers, effnet_layers, cls.downsample)
        grouped_strides = cls._group_by_stage(cls.strides_per_stage, effnet_layers, cls.downsample)
        # Only the stride of the first unit in each stage is handed to the model.
        strides_per_stage = [stage[0] for stage in grouped_strides]

        init_block_channels = round_channels(cls.init_block_channels * width_factor)

        final_block_channels = cls.final_block_channels
        if width_factor > 1.0:
            final_block_channels = round_channels(final_block_channels * width_factor)

        model = EfficientNet(
            channels=channels,
            init_block_channels=init_block_channels,
            final_block_channels=final_block_channels,
            kernel_sizes=kernel_sizes,
            strides_per_stage=strides_per_stage,
            expansion_factors=expansion_factors,
            tf_mode=False,
            bn_eps=1e-5,
            in_size=input_size,
            **kwargs,
        )

        if pretrained:
            cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
            download_model(net=model, model_name=f"efficientnet_{version}", local_model_store_dir_path=str(cache_dir))
            print(f"Download model weight in {cache_dir!s}")
        return model
Loading

0 comments on commit 1e8ec38

Please sign in to comment.