From 1e8ec38e7e83bdc11f18281a0d838a128b8e99d6 Mon Sep 17 00:00:00 2001 From: Harim Kang Date: Wed, 21 Aug 2024 15:32:37 +0900 Subject: [PATCH] Refactor Classification Backbone to Factory Class Design (#3864) * update for releases 2.2.0rc0 * Refactor Classification Backbone to Factory Design * Revert "update for releases 2.2.0rc0" This reverts commit d1bd1d5d04d62cdcf57ff270ba2a3a33f52a333f. * Remove deprecated docstring --------- Co-authored-by: Yunchu Lee --- .../algo/classification/backbones/__init__.py | 6 +- .../classification/backbones/efficientnet.py | 226 +++++++++--------- .../classification/backbones/mobilenet_v3.py | 173 ++++++++------ src/otx/algo/classification/efficientnet.py | 8 +- src/otx/algo/classification/mobilenet_v3.py | 22 +- .../backbones/test_otx_efficientnet.py | 6 +- .../backbones/test_otx_mobilenet_v3.py | 6 +- .../classifier/test_base_classifier.py | 4 +- .../classifier/test_semi_sl_classifier.py | 4 +- 9 files changed, 239 insertions(+), 216 deletions(-) diff --git a/src/otx/algo/classification/backbones/__init__.py b/src/otx/algo/classification/backbones/__init__.py index e3b4d4a753c..229e4515e64 100644 --- a/src/otx/algo/classification/backbones/__init__.py +++ b/src/otx/algo/classification/backbones/__init__.py @@ -3,10 +3,10 @@ # """Backbone modules for OTX custom model.""" -from .efficientnet import OTXEfficientNet -from .mobilenet_v3 import OTXMobileNetV3 +from .efficientnet import EfficientNetBackbone +from .mobilenet_v3 import MobileNetV3Backbone from .timm import TimmBackbone from .torchvision import TorchvisionBackbone from .vision_transformer import VisionTransformer -__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3", "VisionTransformer", "TorchvisionBackbone"] +__all__ = ["EfficientNetBackbone", "TimmBackbone", "MobileNetV3Backbone", "VisionTransformer", "TorchvisionBackbone"] diff --git a/src/otx/algo/classification/backbones/efficientnet.py b/src/otx/algo/classification/backbones/efficientnet.py index fe28bb6ef09..667c13cb6c8 100644 --- a/src/otx/algo/classification/backbones/efficientnet.py +++ b/src/otx/algo/classification/backbones/efficientnet.py @@ -7,7 +7,7 @@ import math from pathlib import Path -from typing import Callable, Literal +from typing import Any, Callable, ClassVar, Literal import torch from pytorchcv.models.model_store import download_model @@ -17,7 +17,6 @@ from otx.algo.modules.activation import Swish, build_activation_layer from otx.algo.modules.conv_module import Conv2dModule from otx.algo.modules.norm import build_norm_layer -from otx.algo.utils.mmengine_utils import load_checkpoint_to_model PRETRAINED_ROOT = "https://github.com/osmr/imgclsmob/releases/download/v0.0.364/" pretrained_urls = { @@ -419,12 +418,10 @@ class EfficientNet(nn.Module): bn_eps : float, default 1e-5. Small float added to variance in Batch norm. in_channels : int, default 3. Number of input channels. in_size : tuple of two ints, default (224, 224). Spatial size of the expected input image. - dropout_cls : dict, dropout configurations. pooling_type : str, default 'avg'. Pooling type to use. bn_eval : bool, default False. Whether to use BatchNorm eval mode. bn_frozen : bool, default False. Whether to freeze BatchNorm parameters. instance_norm_first : bool, default False. Whether to use instance normalization first. - pretrained : bool, default False. Whether to load ImageNet pre-trained weights. """ def __init__( @@ -439,17 +436,14 @@ def __init__( bn_eps: float = 1e-5, in_channels: int = 3, in_size: tuple[int, int] = (224, 224), - dropout_cls: dict | None = None, pooling_type: str | None = "avg", bn_eval: bool = False, bn_frozen: bool = False, instance_norm_first: bool = False, - pretrained: bool = False, **kwargs, ): super().__init__(**kwargs) self.num_classes = 1000 - self.pretrained = pretrained self.in_size = in_size self.input_IN = nn.InstanceNorm2d(3, affine=True) if instance_norm_first else None self.bn_eval = bn_eval @@ -528,164 +522,158 @@ def _init_params(self) -> None: def forward( self, x: torch.Tensor, - return_featuremaps: bool = False, - get_embeddings: bool = False, + **kwargs, ) -> tuple | list[torch.Tensor] | torch.Tensor: """Forward.""" if self.input_IN is not None: x = self.input_IN(x) y = self.features(x) - if return_featuremaps: - return (y,) - - glob_features = self._glob_feature_vector(y, self.pooling_type, reduce_dims=False) - - logits = self.output(glob_features.view(x.shape[0], -1)) - - if not self.training and self.classification: - return [logits] - - if get_embeddings: - out_data = [logits, glob_features.view(x.shape[0], -1)] - elif self.loss in ["softmax", "am_softmax"]: - out_data = logits if self.lr_finder.enable and self.lr_finder.mode == "automatic" else [logits] - - elif self.loss in ["triplet"]: - out_data = [logits, glob_features] - else: - msg = f"Unsupported loss: {self.loss}" - raise KeyError(msg) - - if self.lr_finder.enable and self.lr_finder.mode == "automatic": - return out_data - return tuple(out_data) + return (y,) EFFICIENTNET_VERSION = Literal["b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8"] -class OTXEfficientNet(EfficientNet): - """Create EfficientNet model with specific parameters. - - Args: - version : str. Version of EfficientNet ('b0'...'b8'). - in_size : tuple of two ints. Spatial size of the expected input image. +class EfficientNetBackbone: + """EfficientNetBackbone class represents the backbone architecture of EfficientNet models. + + Attributes: + EFFICIENTNET_CFG (ClassVar[dict[str, Any]]): A dictionary containing configuration parameters + for different versions of EfficientNet. + init_block_channels (ClassVar[int]): The number of channels in the initial block of the backbone. + layers (ClassVar[list[int]]): A list specifying the number of layers in each stage of the backbone. + downsample (ClassVar[list[int]]): A list specifying whether downsampling is applied. + channels_per_layers (ClassVar[list[int]]): A list specifying the number of channels. + expansion_factors_per_layers (ClassVar[list[int]]): A list specifying the expansion factor. + kernel_sizes_per_layers (ClassVar[list[int]]): A list specifying the kernel size in each stage of the backbone. + strides_per_stage (ClassVar[list[int]]): A list specifying the stride in each stage of the backbone. + final_block_channels (ClassVar[int]): The number of channels in the final block of the backbone. """ - def __init__(self, version: EFFICIENTNET_VERSION, input_size: tuple[int, int] | None = None, **kwargs): - self.model_name = "efficientnet_" + version - - if version == "b0": - in_size = (224, 224) - depth_factor = 1.0 - width_factor = 1.0 - elif version == "b1": - in_size = (240, 240) - depth_factor = 1.1 - width_factor = 1.0 - elif version == "b2": - in_size = (260, 260) - depth_factor = 1.2 - width_factor = 1.1 - elif version == "b3": - in_size = (300, 300) - depth_factor = 1.4 - width_factor = 1.2 - elif version == "b4": - in_size = (380, 380) - depth_factor = 1.8 - width_factor = 1.4 - elif version == "b5": - in_size = (456, 456) - depth_factor = 2.2 - width_factor = 1.6 - elif version == "b6": - in_size = (528, 528) - depth_factor = 2.6 - width_factor = 1.8 - elif version == "b7": - in_size = (600, 600) - depth_factor = 3.1 - width_factor = 2.0 - elif version == "b8": - in_size = (672, 672) - depth_factor = 3.6 - width_factor = 2.2 - else: - msg = f"Unsupported EfficientNet version {version}" - raise ValueError(msg) - - if input_size is not None: - in_size = input_size - - init_block_channels = 32 - layers = [1, 2, 2, 3, 3, 4, 1] - downsample = [1, 1, 1, 1, 0, 1, 0] - channels_per_layers = [16, 24, 40, 80, 112, 192, 320] - expansion_factors_per_layers = [1, 6, 6, 6, 6, 6, 6] - kernel_sizes_per_layers = [3, 3, 5, 3, 5, 5, 3] - _strides_per_stage = [1, 2, 2, 2, 1, 2, 1] - final_block_channels = 1280 - - layers = [int(math.ceil(li * depth_factor)) for li in layers] - channels_per_layers = [round_channels(ci * width_factor) for ci in channels_per_layers] + EFFICIENTNET_CFG: ClassVar[dict[str, Any]] = { + "b0": { + "input_size": (224, 224), + "depth_factor": 1.0, + "width_factor": 1.0, + }, + "b1": { + "input_size": (240, 240), + "depth_factor": 1.1, + "width_factor": 1.0, + }, + "b2": { + "input_size": (260, 260), + "depth_factor": 1.2, + "width_factor": 1.1, + }, + "b3": { + "input_size": (300, 300), + "depth_factor": 1.4, + "width_factor": 1.2, + }, + "b4": { + "input_size": (380, 380), + "depth_factor": 1.8, + "width_factor": 1.4, + }, + "b5": { + "input_size": (456, 456), + "depth_factor": 2.2, + "width_factor": 1.6, + }, + "b6": { + "input_size": (528, 528), + "depth_factor": 2.6, + "width_factor": 1.8, + }, + "b7": { + "input_size": (600, 600), + "depth_factor": 3.1, + "width_factor": 2.0, + }, + "b8": { + "input_size": (672, 672), + "depth_factor": 3.6, + "width_factor": 2.2, + }, + } + + init_block_channels: ClassVar[int] = 32 + layers: ClassVar[list[int]] = [1, 2, 2, 3, 3, 4, 1] + downsample: ClassVar[list[int]] = [1, 1, 1, 1, 0, 1, 0] + channels_per_layers: ClassVar[list[int]] = [16, 24, 40, 80, 112, 192, 320] + expansion_factors_per_layers: ClassVar[list[int]] = [1, 6, 6, 6, 6, 6, 6] + kernel_sizes_per_layers: ClassVar[list[int]] = [3, 3, 5, 3, 5, 5, 3] + strides_per_stage: ClassVar[list[int]] = [1, 2, 2, 2, 1, 2, 1] + final_block_channels: ClassVar[int] = 1280 + + def __new__( + cls, + version: EFFICIENTNET_VERSION, + input_size: tuple[int, int] | None = None, + pretrained: bool = True, + **kwargs, + ) -> EfficientNet: + """Create a new instance of the EfficientNet class. + + Args: + version (EFFICIENTNET_VERSION): The version of EfficientNet to use. + input_size (tuple[int, int] | None, optional): The input size of the model. Defaults to None. + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + **kwargs: Additional keyword arguments to be passed to the EfficientNet constructor. + + Returns: + EfficientNet: The created EfficientNet model instance. + """ + origin_input_size, depth_factor, width_factor = cls.EFFICIENTNET_CFG[version].values() + input_size = input_size or origin_input_size + effnet_layers = [int(math.ceil(li * depth_factor)) for li in cls.layers] + channels_per_layers = [round_channels(ci * width_factor) for ci in cls.channels_per_layers] from functools import reduce channels: list = reduce( lambda x, y: [*x, [y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], - zip(channels_per_layers, layers, downsample), + zip(channels_per_layers, effnet_layers, cls.downsample), [], ) kernel_sizes: list = reduce( lambda x, y: [*x, [y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], - zip(kernel_sizes_per_layers, layers, downsample), + zip(cls.kernel_sizes_per_layers, effnet_layers, cls.downsample), [], ) expansion_factors: list = reduce( lambda x, y: [*x, [y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], - zip(expansion_factors_per_layers, layers, downsample), + zip(cls.expansion_factors_per_layers, effnet_layers, cls.downsample), [], ) strides_per_stage: list = reduce( lambda x, y: [*x, [y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], - zip(_strides_per_stage, layers, downsample), + zip(cls.strides_per_stage, effnet_layers, cls.downsample), [], ) strides_per_stage = [si[0] for si in strides_per_stage] + init_block_channels = round_channels(cls.init_block_channels * width_factor) - init_block_channels = round_channels(init_block_channels * width_factor) - + final_block_channels = cls.final_block_channels if width_factor > 1.0: final_block_channels = round_channels(final_block_channels * width_factor) - super().__init__( + model = EfficientNet( channels=channels, init_block_channels=init_block_channels, final_block_channels=final_block_channels, kernel_sizes=kernel_sizes, strides_per_stage=strides_per_stage, expansion_factors=expansion_factors, - dropout_cls={"dist": "none"}, tf_mode=False, bn_eps=1e-5, - in_size=in_size, + in_size=input_size, **kwargs, ) - self.init_weights(self.pretrained) - - def forward(self, x: torch.Tensor, return_featuremaps: bool = True, get_embeddings: bool = False) -> torch.Tensor: - """Forward.""" - return super().forward(x, return_featuremaps=return_featuremaps, get_embeddings=get_embeddings) - - def init_weights(self, pretrained: bool | str | None = None) -> None: - """Initialize weights.""" - if isinstance(pretrained, str) and Path(pretrained).exists(): - checkpoint = torch.load(pretrained, None) - load_checkpoint_to_model(self, checkpoint) - print(f"init weight - {pretrained}") - elif pretrained: + if pretrained: cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" - download_model(net=self, model_name=self.model_name, local_model_store_dir_path=str(cache_dir)) + download_model(net=model, model_name=f"efficientnet_{version}", local_model_store_dir_path=str(cache_dir)) print(f"Download model weight in {cache_dir!s}") + return model diff --git a/src/otx/algo/classification/backbones/mobilenet_v3.py b/src/otx/algo/classification/backbones/mobilenet_v3.py index 5dc18a45194..7f30a2e0e60 100644 --- a/src/otx/algo/classification/backbones/mobilenet_v3.py +++ b/src/otx/algo/classification/backbones/mobilenet_v3.py @@ -9,7 +9,7 @@ from __future__ import annotations import math -from pathlib import Path +from typing import Any, ClassVar, Literal import torch from torch import nn @@ -269,12 +269,20 @@ def forward( class MobileNetV3(MobileNetV3Base): - """MobileNetV3.""" + """MobileNetV3 constructor. - def __init__(self, cfgs: list, mode: str, instance_norm_conv1: bool = False, **kwargs): + Args: + layer_cfgs (list): List of layer configurations. + instance_norm_conv1 (bool, optional): Whether to use instance normalization in the first convolutional layer. + Defaults to False. + **kwargs: Additional keyword arguments. + + """ + + def __init__(self, layer_cfgs: list, instance_norm_conv1: bool = False, **kwargs): super().__init__(**kwargs) # setting of inverted residual blocks - self.cfgs = cfgs + self.cfgs = layer_cfgs # building first layer input_channel = make_divisible(16 * self.width_mult, 8) stride = 1 if self.in_size[0] < 100 else 2 @@ -282,7 +290,7 @@ def __init__(self, cfgs: list, mode: str, instance_norm_conv1: bool = False, **k # building inverted residual blocks block = InvertedResidual flag = True - output_channel: int | dict[str, int] + output_channel: int for k, t, c, use_se, use_hs, s in self.cfgs: _s = s if (self.in_size[0] < 100) and (s == 2) and flag: @@ -295,10 +303,6 @@ def __init__(self, cfgs: list, mode: str, instance_norm_conv1: bool = False, **k self.features = nn.Sequential(*layers) # building last several layers self.conv = conv_1x1_bn(input_channel, exp_size, self.loss) - output_channel = {"large": 1280, "small": 1024} - output_channel = ( - make_divisible(output_channel[mode] * self.width_mult, 8) if self.width_mult > 1.0 else output_channel[mode] - ) self._initialize_weights() def extract_features(self, x: torch.Tensor) -> tuple[torch.Tensor]: @@ -306,13 +310,6 @@ def extract_features(self, x: torch.Tensor) -> tuple[torch.Tensor]: y = self.conv(self.features(x)) return (y,) - def infer_head(self, x: torch.Tensor, skip_pool: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - """Inference head.""" - glob_features = self._glob_feature_vector(x, self.pooling_type, reduce_dims=False) if not skip_pool else x - - logits = self.classifier(glob_features.view(x.shape[0], -1)) - return glob_features, logits - def _initialize_weights(self) -> None: """Initialize weights.""" for m in self.modules(): @@ -330,59 +327,95 @@ def _initialize_weights(self) -> None: m.bias.data.zero_() -class OTXMobileNetV3(MobileNetV3): - """MobileNetV3 model for OTX.""" - - backbone_configs = { # noqa: RUF012 - "small": [ - # k, t, c, SE, HS, s - [3, 1, 16, 1, 0, 2], - [3, 4.5, 24, 0, 0, 2], - [3, 3.67, 24, 0, 0, 1], - [5, 4, 40, 1, 1, 2], - [5, 6, 40, 1, 1, 1], - [5, 6, 40, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 3, 48, 1, 1, 1], - [5, 6, 96, 1, 1, 2], - [5, 6, 96, 1, 1, 1], - [5, 6, 96, 1, 1, 1], - ], - "large": [ - # k, t, c, SE, HS, s - [3, 1, 16, 0, 0, 1], - [3, 4, 24, 0, 0, 2], - [3, 3, 24, 0, 0, 1], - [5, 3, 40, 1, 0, 2], - [5, 3, 40, 1, 0, 1], - [5, 3, 40, 1, 0, 1], - [3, 6, 80, 0, 1, 2], - [3, 2.5, 80, 0, 1, 1], - [3, 2.3, 80, 0, 1, 1], - [3, 2.3, 80, 0, 1, 1], - [3, 6, 112, 1, 1, 1], - [3, 6, 112, 1, 1, 1], - [5, 6, 160, 1, 1, 2], - [5, 6, 160, 1, 1, 1], - [5, 6, 160, 1, 1, 1], - ], - } +class MobileNetV3Backbone: + """MobileNetV3Backbone class represents the backbone architecture of MobileNetV3. - def __init__(self, mode: str = "large", width_mult: float = 1.0, **kwargs): - super().__init__(self.backbone_configs[mode], mode=mode, width_mult=width_mult, **kwargs) - self.key = "mobilenetv3_" + mode - if width_mult != 1.0: - self.key = self.key + f"_{int(width_mult * 100):03d}" # pylint: disable=consider-using-f-string - self.init_weights(self.pretrained) + Args: + mode (Literal["small", "large"], optional): The mode of the backbone architecture. Defaults to "large". + width_mult (float, optional): Width multiplier for the backbone architecture. Defaults to 1.0. + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + **kwargs: Additional keyword arguments to be passed to the MobileNetV3 model. - def init_weights(self, pretrained: str | bool | None = None) -> None: - """Initialize weights.""" - checkpoint = None - if isinstance(pretrained, str) and Path(pretrained).exists(): - checkpoint = torch.load(pretrained, None) - print(f"init weight - {pretrained}") - elif pretrained is not None: - checkpoint = load_from_http(pretrained_urls[self.key]) - print(f"init weight - {pretrained_urls[self.key]}") - if checkpoint is not None: - load_checkpoint_to_model(self, checkpoint) + Returns: + MobileNetV3: An instance of the MobileNetV3 model. + + Examples: + # Create a MobileNetV3Backbone instance + backbone = MobileNetV3Backbone(mode="small", width_mult=0.75, pretrained=False) + + # Create a MobileNetV3 model with the specified backbone + model = MobileNetV3(backbone=backbone) + """ + + MV3_CFG: ClassVar[dict[str, Any]] = { + "small": { + "layer_cfgs": [ + # k, t, c, SE, HS, s + [3, 1, 16, 1, 0, 2], + [3, 4.5, 24, 0, 0, 2], + [3, 3.67, 24, 0, 0, 1], + [5, 4, 40, 1, 1, 2], + [5, 6, 40, 1, 1, 1], + [5, 6, 40, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 3, 48, 1, 1, 1], + [5, 6, 96, 1, 1, 2], + [5, 6, 96, 1, 1, 1], + [5, 6, 96, 1, 1, 1], + ], + "out_channels": 576, + "hid_channels": 1024, + }, + "large": { + "layer_cfgs": [ + # k, t, c, SE, HS, s + [3, 1, 16, 0, 0, 1], + [3, 4, 24, 0, 0, 2], + [3, 3, 24, 0, 0, 1], + [5, 3, 40, 1, 0, 2], + [5, 3, 40, 1, 0, 1], + [5, 3, 40, 1, 0, 1], + [3, 6, 80, 0, 1, 2], + [3, 2.5, 80, 0, 1, 1], + [3, 2.3, 80, 0, 1, 1], + [3, 2.3, 80, 0, 1, 1], + [3, 6, 112, 1, 1, 1], + [3, 6, 112, 1, 1, 1], + [5, 6, 160, 1, 1, 2], + [5, 6, 160, 1, 1, 1], + [5, 6, 160, 1, 1, 1], + ], + "out_channels": 960, + "hid_channels": 1280, + }, + } + + def __new__( + cls, + mode: Literal["small", "large"] = "large", + width_mult: float = 1.0, + pretrained: bool = True, + **kwargs, + ) -> MobileNetV3: + """Create a new instance of the MobileNetV3 class. + + Args: + mode (Literal["small", "large"], optional): The mode of the MobileNetV3 model. Defaults to "large". + width_mult (float, optional): Width multiplier for the MobileNetV3 model. Defaults to 1.0. + pretrained (bool, optional): Whether to load pretrained weights for the MobileNetV3 model. Defaults to True. + **kwargs: Additional keyword arguments to be passed to the MobileNetV3 constructor. + + Returns: + MobileNetV3: A new instance of the MobileNetV3 class. + """ + model = MobileNetV3( + layer_cfgs=cls.MV3_CFG[mode]["layer_cfgs"], + width_mult=width_mult, + **kwargs, + ) + if pretrained: + key = f"mobilenetv3_{mode}" if width_mult == 1.0 else f"mobilenetv3_{mode}_{int(width_mult * 100):03d}" + checkpoint = load_from_http(pretrained_urls[key]) + print(f"init weight - {pretrained_urls[key]}") + load_checkpoint_to_model(model, checkpoint) + return model diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index 2f5c00d544e..d0ce1421b03 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -11,7 +11,7 @@ from torch import Tensor, nn -from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet +from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, EfficientNetBackbone from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, @@ -87,7 +87,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained) + backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained) neck = GlobalAveragePooling(dim=2) if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( @@ -177,7 +177,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained) + backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, neck=GlobalAveragePooling(dim=2), @@ -265,7 +265,7 @@ def _build_model(self, head_config: dict) -> nn.Module: if not isinstance(self.label_info, HLabelInfo): raise TypeError(self.label_info) - backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained) + backbone = EfficientNetBackbone(version=self.version, input_size=self.input_size, pretrained=self.pretrained) copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 18963d53b5c..c0e83840aa5 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -12,7 +12,7 @@ import torch from torch import Tensor, nn -from otx.algo.classification.backbones import OTXMobileNetV3 +from otx.algo.classification.backbones import MobileNetV3Backbone from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, @@ -102,16 +102,16 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = OTXMobileNetV3(mode=self.mode, input_size=self.input_size) + backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size) + backbone_out_chennels = MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"] neck = GlobalAveragePooling(dim=2) - in_channels = 960 if self.mode == "large" else 576 if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( backbone=backbone, neck=neck, head=SemiSLLinearClsHead( num_classes=num_classes, - in_channels=in_channels, + in_channels=backbone_out_chennels, ), loss=nn.CrossEntropyLoss(reduction="none"), ) @@ -121,7 +121,7 @@ def _build_model(self, num_classes: int) -> nn.Module: neck=neck, head=LinearClsHead( num_classes=num_classes, - in_channels=in_channels, + in_channels=backbone_out_chennels, ), loss=nn.CrossEntropyLoss(), ) @@ -190,13 +190,14 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: + backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size) return ImageClassifier( - backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size), + backbone=backbone, neck=GlobalAveragePooling(dim=2), head=MultiLabelNonLinearClsHead( num_classes=num_classes, - in_channels=960, - hid_channels=1280, + in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"], + hid_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["hid_channels"], normalized=True, activation=nn.PReLU(), ), @@ -314,11 +315,12 @@ def _build_model(self, head_config: dict) -> nn.Module: copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) + backbone = MobileNetV3Backbone(mode=self.mode, input_size=self.input_size) return HLabelClassifier( - backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size), + backbone=backbone, neck=nn.Identity(), head=HierarchicalCBAMClsHead( - in_channels=960, + in_channels=MobileNetV3Backbone.MV3_CFG[self.mode]["out_channels"], **copied_head_config, ), multiclass_loss=nn.CrossEntropyLoss(), diff --git a/tests/unit/algo/classification/backbones/test_otx_efficientnet.py b/tests/unit/algo/classification/backbones/test_otx_efficientnet.py index 3d7fb9017fd..e6e96708ebc 100644 --- a/tests/unit/algo/classification/backbones/test_otx_efficientnet.py +++ b/tests/unit/algo/classification/backbones/test_otx_efficientnet.py @@ -4,17 +4,17 @@ import pytest import torch -from otx.algo.classification.backbones.efficientnet import OTXEfficientNet +from otx.algo.classification.backbones.efficientnet import EfficientNetBackbone class TestOTXEfficientNet: @pytest.mark.parametrize("version", ["b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8"]) def test_forward(self, version): - model = OTXEfficientNet(version, pretrained=None) + model = EfficientNetBackbone(version, pretrained=None) assert model(torch.randn(1, 3, 244, 244))[0].shape[-1] == 8 assert model(torch.randn(1, 3, 244, 244))[0].shape[-2] == 8 def test_set_input_size(self): input_size = (300, 300) - model = OTXEfficientNet("b0", input_size=input_size, pretrained=None) + model = EfficientNetBackbone("b0", input_size=input_size, pretrained=None) assert model.in_size == input_size diff --git a/tests/unit/algo/classification/backbones/test_otx_mobilenet_v3.py b/tests/unit/algo/classification/backbones/test_otx_mobilenet_v3.py index b0275a34fc6..aa5da358b7c 100644 --- a/tests/unit/algo/classification/backbones/test_otx_mobilenet_v3.py +++ b/tests/unit/algo/classification/backbones/test_otx_mobilenet_v3.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from otx.algo.classification.backbones.mobilenet_v3 import OTXMobileNetV3 +from otx.algo.classification.backbones.mobilenet_v3 import MobileNetV3Backbone class TestOTXMobileNetV3: def test_forward(self): - model = OTXMobileNetV3() + model = MobileNetV3Backbone() assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 960, 8, 8]) def test_glob_feature_vector(self): - model = OTXMobileNetV3() + model = MobileNetV3Backbone() assert model._glob_feature_vector(torch.randn([1, 960, 8, 8]), "avg").shape == torch.Size([1, 960]) assert model._glob_feature_vector(torch.randn([1, 960, 8, 8]), "max").shape == torch.Size([1, 960]) assert model._glob_feature_vector(torch.randn([1, 960, 8, 8]), "avg+max").shape == torch.Size([1, 960]) diff --git a/tests/unit/algo/classification/classifier/test_base_classifier.py b/tests/unit/algo/classification/classifier/test_base_classifier.py index f27c5c40e27..39f7821a01b 100644 --- a/tests/unit/algo/classification/classifier/test_base_classifier.py +++ b/tests/unit/algo/classification/classifier/test_base_classifier.py @@ -3,7 +3,7 @@ import pytest import torch -from otx.algo.classification.backbones import OTXEfficientNet +from otx.algo.classification.backbones import EfficientNetBackbone from otx.algo.classification.classifier import ImageClassifier from otx.algo.classification.heads import LinearClsHead, MultiLabelLinearClsHead from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore @@ -21,7 +21,7 @@ class TestImageClassifier: ) def fxt_model_and_inputs(self, request): head_cls, loss_cls, input_fxt_name = request.param - backbone = OTXEfficientNet(version="b0") + backbone = EfficientNetBackbone(version="b0") neck = GlobalAveragePooling(dim=2) head = head_cls(num_classes=3, in_channels=backbone.num_features) loss = loss_cls() diff --git a/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py b/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py index 8a9a4e3cde9..c9100ede578 100644 --- a/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py +++ b/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py @@ -3,7 +3,7 @@ import pytest import torch -from otx.algo.classification.backbones import OTXEfficientNet +from otx.algo.classification.backbones import EfficientNetBackbone from otx.algo.classification.classifier import SemiSLClassifier from otx.algo.classification.heads import SemiSLLinearClsHead from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -12,7 +12,7 @@ class TestSemiSLClassifier: @pytest.fixture() def fxt_semi_sl_classifier(self): - backbone = OTXEfficientNet(version="b0") + backbone = EfficientNetBackbone(version="b0") neck = GlobalAveragePooling(dim=2) head = SemiSLLinearClsHead( num_classes=2,