diff --git a/src/otx/algo/segmentation/backbones/dinov2.py b/src/otx/algo/segmentation/backbones/dinov2.py
index 5468870ffef..ce1d605fe38 100644
--- a/src/otx/algo/segmentation/backbones/dinov2.py
+++ b/src/otx/algo/segmentation/backbones/dinov2.py
@@ -24,7 +24,7 @@ class DinoVisionTransformer(nn.Module):
     def __init__(
         self,
-        name: str,
+        model_name: str,
         freeze_backbone: bool,
         out_index: list[int],
         pretrained_weights: str | None = None,
@@ -37,17 +37,17 @@ def __init__(
         if ci_data_root is not None and Path(ci_data_root).exists():
             pretrained = False
 
-        self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=name, pretrained=pretrained)
+        self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=pretrained)
 
         if ci_data_root is not None and Path(ci_data_root).exists():
-            ckpt_filename = f"{name}4_pretrain.pth"
+            ckpt_filename = f"{model_name}4_pretrain.pth"
             ckpt_path = Path(ci_data_root) / "torch" / "hub" / "checkpoints" / ckpt_filename
             if not ckpt_path.exists():
                 msg = (
                     f"Internal cache was specified but cannot find weights file: {ckpt_filename}. load from torch hub."
                 )
                 logger.warning(msg)
-                self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=name, pretrained=True)
+                self.backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=True)
             else:
                 self.backbone.load_state_dict(torch.load(ckpt_path))
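The rename from `name` to `model_name` lines this backbone up with the `model_name` argument used throughout the rest of this patch. For reference, a minimal sketch of the cache-first loading behaviour the hunk above implements — the standalone function and its signature are illustrative, not part of the patch:

```python
from pathlib import Path

import torch


def load_dinov2_backbone(model_name: str, ci_data_root: str | None) -> torch.nn.Module:
    """Prefer a local CI weights cache; fall back to downloading from torch hub."""
    if ci_data_root is not None:
        ckpt_path = Path(ci_data_root) / "torch" / "hub" / "checkpoints" / f"{model_name}4_pretrain.pth"
        if ckpt_path.exists():
            # Build the architecture without weights, then load the cached checkpoint.
            backbone = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=False)
            backbone.load_state_dict(torch.load(ckpt_path))
            return backbone
    # No cache (or cache miss): download pretrained weights from torch hub.
    return torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=model_name, pretrained=True)
```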
- """ - - def __init__( - self, - in_channels: int, - stem_channels: int, - out_channels: int, - expand_ratio: int, - normalization: Callable[..., nn.Module] = nn.BatchNorm2d, - with_cp: bool = False, - num_stages: int = 1, - strides: tuple[int, int] = (2, 2), - extra_stride: bool = False, - input_norm: bool = False, - ) -> None: - """StemV2 initialization.""" - super().__init__() - - if num_stages < 1: - msg = "num_stages must be greater than 0." - raise ValueError(msg) - if not isinstance(strides, (tuple, list)): - msg = "strides must be tuple or list." - raise TypeError(msg) - - if len(strides) != 1 + num_stages: - msg = "len(strides) must equal to num_stages + 1." - raise ValueError(msg) - - self.in_channels = in_channels - self.out_channels = out_channels - self.normalization = normalization - self.with_cp = with_cp - self.num_stages = num_stages - - self.input_norm = None - if input_norm: - self.input_norm = nn.InstanceNorm2d(in_channels) - - self.conv1 = Conv2dModule( - in_channels=in_channels, - out_channels=stem_channels, - kernel_size=3, - stride=strides[0], - padding=1, - normalization=build_norm_layer(self.normalization, num_features=stem_channels), - activation=build_activation_layer(nn.ReLU), - ) - - self.conv2 = None - if extra_stride: - self.conv2 = Conv2dModule( - in_channels=stem_channels, - out_channels=stem_channels, - kernel_size=3, - stride=2, - padding=1, - normalization=build_norm_layer(self.normalization, num_features=stem_channels), - activation=build_activation_layer(nn.ReLU), - ) - - mid_channels = int(round(stem_channels * expand_ratio)) - internal_branch_channels = stem_channels // 2 - out_branch_channels = self.out_channels // 2 - - self.branch1, self.branch2 = nn.ModuleList(), nn.ModuleList() - for stage in range(1, num_stages + 1): - self.branch1.append( - nn.Sequential( - Conv2dModule( - internal_branch_channels, - internal_branch_channels, - kernel_size=3, - stride=strides[stage], - padding=1, - groups=internal_branch_channels, - normalization=build_norm_layer(normalization, num_features=internal_branch_channels), - activation=None, - ), - Conv2dModule( - internal_branch_channels, - out_branch_channels if stage == num_stages else internal_branch_channels, - kernel_size=1, - stride=1, - padding=0, - normalization=build_norm_layer( - normalization, - num_features=out_branch_channels if stage == num_stages else internal_branch_channels, - ), - activation=build_activation_layer(nn.ReLU), - ), - ), - ) - - self.branch2.append( - nn.Sequential( - Conv2dModule( - internal_branch_channels, - mid_channels, - kernel_size=1, - stride=1, - padding=0, - normalization=build_norm_layer(normalization, num_features=mid_channels), - activation=build_activation_layer(nn.ReLU), - ), - Conv2dModule( - mid_channels, - mid_channels, - kernel_size=3, - stride=strides[stage], - padding=1, - groups=mid_channels, - normalization=build_norm_layer(normalization, num_features=mid_channels), - activation=None, - ), - Conv2dModule( - mid_channels, - out_branch_channels if stage == num_stages else internal_branch_channels, - kernel_size=1, - stride=1, - padding=0, - normalization=build_norm_layer( - normalization, - num_features=out_branch_channels if stage == num_stages else internal_branch_channels, - ), - activation=build_activation_layer(nn.ReLU), - ), - ), - ) - - def _inner_forward(self, x: torch.Tensor) -> list[torch.Tensor]: - """Forward pass of Stem module. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). 
-
-        Returns:
-            list[torch.Tensor]: List of output tensors at each stage of the backbone.
-        """
-        if self.input_norm is not None:
-            x = self.input_norm(x)
-
-        y = self.conv1(x)
-        if self.conv2 is not None:
-            y = self.conv2(y)
-
-        out_list = [y]
-        for stage in range(self.num_stages):
-            y1, y2 = y.chunk(2, dim=1)
-
-            y1 = self.branch1[stage](y1)
-            y2 = self.branch2[stage](y2)
-
-            y = torch.cat((y1, y2), dim=1)
-            y = channel_shuffle(y, 2)
-            out_list.append(y)
-
-        return out_list
-
-    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
-        """Forward."""
-        return cp.checkpoint(self._inner_forward, x) if self.with_cp and x.requires_grad else self._inner_forward(x)
-
-
 class ShuffleUnit(nn.Module):
     """InvertedResidual block for ShuffleNetV2 backbone.
 
@@ -1193,7 +1009,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
         return out
 
 
-class NNLiteHRNet(nn.Module):
+class LiteHRNetModule(nn.Module):
     """Lite-HRNet backbone.
 
     `High-Resolution Representations for Labeling Pixels and Regions
@@ -1215,8 +1031,8 @@ class NNLiteHRNet(nn.Module):
     def __init__(
         self,
-        stem: dict[str, Any],
         num_stages: int,
+        stem_configuration: dict[str, Any],
         stages_spec: dict[str, Any],
         in_channels: int = 3,
         normalization: Callable[..., nn.Module] = partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
@@ -1233,11 +1049,7 @@ def __init__(
         self.norm_eval = norm_eval
         self.with_cp = with_cp
         self.zero_init_residual = zero_init_residual
-        self.stem = Stem(
-            in_channels,
-            normalization=self.normalization,
-            **stem,
-        )
+        self.stem = Stem(in_channels=in_channels, **stem_configuration, normalization=normalization)
         self.num_stages = num_stages
         self.stages_spec = stages_spec
@@ -1435,14 +1247,7 @@ class LiteHRNetBackbone:
 
     LITEHRNET_CFG: ClassVar[dict[str, Any]] = {
         "lite_hrnet_s": {
-            "stem": {
-                "stem_channels": 32,
-                "out_channels": 32,
-                "expand_ratio": 1,
-                "strides": [2, 2],
-                "extra_stride": True,
-                "input_norm": False,
-            },
+            "stem_configuration": {"extra_stride": True},
             "num_stages": 2,
             "stages_spec": {
                 "num_modules": [4, 4],
@@ -1456,14 +1261,7 @@ class LiteHRNetBackbone:
             "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth",
         },
         "lite_hrnet_18": {
-            "stem": {
-                "stem_channels": 32,
-                "out_channels": 32,
-                "expand_ratio": 1,
-                "strides": [2, 2],
-                "extra_stride": False,
-                "input_norm": False,
-            },
+            "stem_configuration": {},
             "num_stages": 3,
             "stages_spec": {
                 "num_modules": [2, 4, 2],
@@ -1477,14 +1275,7 @@ class LiteHRNetBackbone:
             "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth",
         },
         "lite_hrnet_x": {
-            "stem": {
-                "stem_channels": 60,
-                "out_channels": 60,
-                "expand_ratio": 1,
-                "strides": [2, 1],
-                "extra_stride": False,
-                "input_norm": False,
-            },
+            "stem_configuration": {"stem_channels": 60, "out_channels": 60, "strides": (2, 1)},
            "num_stages": 4,
            "stages_spec": {
                "weighting_module_version": "v1",
@@ -1500,10 +1291,9 @@ class LiteHRNetBackbone:
         },
     }
 
-    def __new__(cls, version: str) -> NNLiteHRNet:
+    def __new__(cls, model_name: str) -> LiteHRNetModule:
         """Constructor for LiteHRNet backbone."""
-        if version not in cls.LITEHRNET_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.LITEHRNET_CFG:
+            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)
-
-        return NNLiteHRNet(**cls.LITEHRNET_CFG[version])
+        return LiteHRNetModule(**cls.LITEHRNET_CFG[model_name])
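Dropping `StemV2` and giving `Stem` the common defaults lets `LITEHRNET_CFG` shrink to only the values that differ per variant. A small sketch of how the trimmed `stem_configuration` entries expand against those defaults (a plain dict merge, shown outside the class purely for illustration):

```python
# Defaults now live on Stem.__init__ (see the hunk at @@ -522 above).
stem_defaults = {"stem_channels": 32, "out_channels": 32, "expand_ratio": 1, "strides": (2, 2)}

# "lite_hrnet_x" only spells out what deviates from those defaults.
stem_configuration = {"stem_channels": 60, "out_channels": 60, "strides": (2, 1)}

effective = {**stem_defaults, **stem_configuration}
print(effective)
# {'stem_channels': 60, 'out_channels': 60, 'expand_ratio': 1, 'strides': (2, 1)}
```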
diff --git a/src/otx/algo/segmentation/backbones/mscan.py b/src/otx/algo/segmentation/backbones/mscan.py
index f0c37afe1ae..a58484f8ffa 100644
--- a/src/otx/algo/segmentation/backbones/mscan.py
+++ b/src/otx/algo/segmentation/backbones/mscan.py
@@ -324,7 +324,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         return x, h, w
 
 
-class NNMSCAN(nn.Module):
+class MSCANModule(nn.Module):
     """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
 
     This backbone is the implementation of `SegNeXt: Rethinking
@@ -462,10 +462,10 @@ class MSCAN:
         },
     }
 
-    def __new__(cls, version: str) -> NNMSCAN:
+    def __new__(cls, model_name: str) -> MSCANModule:
         """Constructor for MSCAN backbone."""
-        if version not in cls.MSCAN_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.MSCAN_CFG:
+            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)
 
-        return NNMSCAN(**cls.MSCAN_CFG[version])
+        return MSCANModule(**cls.MSCAN_CFG[model_name])
diff --git a/src/otx/algo/segmentation/dino_v2_seg.py b/src/otx/algo/segmentation/dino_v2_seg.py
index e8e5b810721..14a8ef6e74b 100644
--- a/src/otx/algo/segmentation/dino_v2_seg.py
+++ b/src/otx/algo/segmentation/dino_v2_seg.py
@@ -10,7 +10,7 @@
 from otx.algo.segmentation.backbones import DinoVisionTransformer
 from otx.algo.segmentation.heads import FCNHead
 from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
-from otx.algo.segmentation.segmentors import BaseSegmModel
+from otx.algo.segmentation.segmentors import BaseSegmentationModel
 from otx.core.model.segmentation import OTXSegmentationModel
 
 if TYPE_CHECKING:
@@ -26,15 +26,15 @@ class DinoV2Seg(OTXSegmentationModel):
     ]
 
     def _build_model(self) -> nn.Module:
-        if self.model_version not in self.AVAILABLE_MODEL_VERSIONS:
-            msg = f"Model version {self.model_version} is not supported."
+        if self.model_name not in self.AVAILABLE_MODEL_VERSIONS:
+            msg = f"Model version {self.model_name} is not supported."
             raise ValueError(msg)
 
-        backbone = DinoVisionTransformer(name=self.model_version, freeze_backbone=True, out_index=[8, 9, 10, 11])
-        decode_head = FCNHead(self.model_version, num_classes=self.num_classes)
+        backbone = DinoVisionTransformer(model_name=self.model_name, freeze_backbone=True, out_index=[8, 9, 10, 11])
+        decode_head = FCNHead(self.model_name, num_classes=self.num_classes)
         criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index)  # type: ignore[attr-defined]
-        return BaseSegmModel(
+        return BaseSegmentationModel(
             backbone=backbone,
             decode_head=decode_head,
             criterion=criterion,
diff --git a/src/otx/algo/segmentation/heads/base_segm_head.py b/src/otx/algo/segmentation/heads/base_segm_head.py
index 2c62683b2e0..5d5ab25f80a 100644
--- a/src/otx/algo/segmentation/heads/base_segm_head.py
+++ b/src/otx/algo/segmentation/heads/base_segm_head.py
@@ -16,7 +16,7 @@
 from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
 
 
-class BaseSegmHead(nn.Module):
+class BaseSegmentationHead(nn.Module):
     """Base class for segmentation heads.
 
     Args:
diff --git a/src/otx/algo/segmentation/heads/fcn_head.py b/src/otx/algo/segmentation/heads/fcn_head.py
index 7f7801aa09e..7018e2c89c6 100644
--- a/src/otx/algo/segmentation/heads/fcn_head.py
+++ b/src/otx/algo/segmentation/heads/fcn_head.py
@@ -16,13 +16,13 @@
 from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.modules import IterativeAggregator
 
-from .base_segm_head import BaseSegmHead
+from .base_segm_head import BaseSegmentationHead
 
 if TYPE_CHECKING:
     from pathlib import Path
 
 
-class NNFCNHead(BaseSegmHead):
+class FCNHeadModule(BaseSegmentationHead):
     """Fully Convolution Networks for Semantic Segmentation with aggregation.
 
     This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
@@ -218,7 +218,6 @@ class FCNHead:
             "aggregator_use_concat": False,
         },
         "dinov2_vits14": {
-            "normalization": partial(build_norm_layer, nn.SyncBatchNorm, requires_grad=True),
             "in_channels": [384, 384, 384, 384],
             "in_index": [0, 1, 2, 3],
             "input_transform": "resize_concat",
@@ -227,10 +226,16 @@ class FCNHead:
         },
     }
 
-    def __new__(cls, version: str, num_classes: int) -> NNFCNHead:
+    def __new__(cls, model_name: str, num_classes: int) -> FCNHeadModule:
         """Constructor for FCNHead."""
-        if version not in cls.FCNHEAD_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.FCNHEAD_CFG:
+            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)
 
-        return NNFCNHead(**cls.FCNHEAD_CFG[version], num_classes=num_classes)
+        normalization = (
+            partial(build_norm_layer, nn.SyncBatchNorm, requires_grad=True)
+            if model_name == "dinov2_vits14"
+            else partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True)
+        )
+
+        return FCNHeadModule(**cls.FCNHEAD_CFG[model_name], num_classes=num_classes, normalization=normalization)
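The DINOv2-specific `normalization` override moves out of `FCNHEAD_CFG` and into `__new__`, so the config entries stay purely architectural; only the DINOv2 head needs `SyncBatchNorm`. The selection in isolation might look like this (a sketch; `pick_normalization` is a made-up helper name, `build_norm_layer` is the import already used by this module):

```python
from functools import partial

from torch import nn

from otx.algo.modules.norm import build_norm_layer


def pick_normalization(model_name: str):
    """SyncBatchNorm for the DINOv2 head, plain BatchNorm2d for everything else."""
    norm_cls = nn.SyncBatchNorm if model_name == "dinov2_vits14" else nn.BatchNorm2d
    return partial(build_norm_layer, norm_cls, requires_grad=True)
```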
diff --git a/src/otx/algo/segmentation/heads/ham_head.py b/src/otx/algo/segmentation/heads/ham_head.py
index 67f2213888e..7d1cd16baa7 100644
--- a/src/otx/algo/segmentation/heads/ham_head.py
+++ b/src/otx/algo/segmentation/heads/ham_head.py
@@ -17,7 +17,7 @@
 from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.modules import resize
 
-from .base_segm_head import BaseSegmHead
+from .base_segm_head import BaseSegmentationHead
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -67,7 +67,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return f.relu(x + enjoy, inplace=True)
 
 
-class NNLightHamHead(BaseSegmHead):
+class LightHamHeadModule(BaseSegmentationHead):
     """SegNeXt decode head."""
 
     def __init__(
@@ -338,10 +338,10 @@ class LightHamHead:
         },
     }
 
-    def __new__(cls, version: str, num_classes: int) -> NNLightHamHead:
+    def __new__(cls, model_name: str, num_classes: int) -> LightHamHeadModule:
         """Constructor for FCNHead."""
-        if version not in cls.HAMHEAD_CFG:
-            msg = f"model type '{version}' is not supported"
+        if model_name not in cls.HAMHEAD_CFG:
+            msg = f"model type '{model_name}' is not supported"
             raise KeyError(msg)
-
-        return NNLightHamHead(**cls.HAMHEAD_CFG[version], num_classes=num_classes)
+        return LightHamHeadModule(**cls.HAMHEAD_CFG[model_name], num_classes=num_classes)
diff --git a/src/otx/algo/segmentation/huggingface_model.py b/src/otx/algo/segmentation/huggingface_model.py
index 83629896ed8..428a489e1fa 100644
--- a/src/otx/algo/segmentation/huggingface_model.py
+++ b/src/otx/algo/segmentation/huggingface_model.py
@@ -39,7 +39,7 @@ class HuggingFaceModelForSegmentation(OTXSegmentationModel):
     """A class representing a Hugging Face model for segmentation.
 
     Args:
-        model_name_or_path (str): The name or path of the pre-trained model.
+        model_name (str): The name or path of the pre-trained model.
         label_info (LabelInfoTypes): The label information for the model.
         optimizer (OptimizerCallable, optional): The optimizer for training the model.
             Defaults to DefaultOptimizerCallable.
@@ -52,30 +52,30 @@ class HuggingFaceModelForSegmentation(OTXSegmentationModel):
     Example:
         1. API
             >>> model = HuggingFaceModelForSegmentation(
-            ...     model_name_or_path="nvidia/segformer-b0-finetuned-ade-512-512",
+            ...     model_name="nvidia/segformer-b0-finetuned-ade-512-512",
             ...     label_info=<Number-of-classes>,
             ... )
         2. CLI
             >>> otx train \
             ... --model otx.algo.segmentation.huggingface_model.HuggingFaceModelForSegmentation \
-            ... --model.model_name_or_path nvidia/segformer-b0-finetuned-ade-512-512
+            ... --model.model_name nvidia/segformer-b0-finetuned-ade-512-512
     """
 
     def __init__(
         self,
-        model_name_or_path: str,  # https://huggingface.co/models?pipeline_tag=image-segmentation
         label_info: LabelInfoTypes,
+        model_name: str,  # https://huggingface.co/models?pipeline_tag=image-segmentation
         input_size: tuple[int, int] = (512, 512),  # input size of default semantic segmentation data recipe
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = SegmCallable,  # type: ignore[assignment]
         torch_compile: bool = False,
     ) -> None:
-        self.model_name = model_name_or_path
         self.load_from = None
         super().__init__(
             label_info=label_info,
+            model_name=model_name,
             input_size=input_size,
             optimizer=optimizer,
             scheduler=scheduler,
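With `model_name` now forwarded to `OTXSegmentationModel.__init__` instead of being assigned locally, only the keyword changes at call sites. A quick smoke check mirroring the docstring example (the label count here is arbitrary):

```python
model = HuggingFaceModelForSegmentation(
    model_name="nvidia/segformer-b0-finetuned-ade-512-512",
    label_info=2,
)
assert model.model_name == "nvidia/segformer-b0-finetuned-ade-512-512"
```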
diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py
index 11b20d421f7..57eba12fc38 100644
--- a/src/otx/algo/segmentation/litehrnet.py
+++ b/src/otx/algo/segmentation/litehrnet.py
@@ -12,7 +12,7 @@
 from otx.algo.segmentation.backbones import LiteHRNetBackbone
 from otx.algo.segmentation.heads import FCNHead
 from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
-from otx.algo.segmentation.segmentors import BaseSegmModel
+from otx.algo.segmentation.segmentors import BaseSegmentationModel
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
 from otx.core.exporter.base import OTXModelExporter
 from otx.core.exporter.native import OTXNativeModelExporter
@@ -32,14 +32,14 @@ class LiteHRNet(OTXSegmentationModel):
     ]
 
     def _build_model(self) -> nn.Module:
-        if self.model_version not in self.AVAILABLE_MODEL_VERSIONS:
-            msg = f"Model version {self.model_version} is not supported."
+        if self.model_name not in self.AVAILABLE_MODEL_VERSIONS:
+            msg = f"Model version {self.model_name} is not supported."
             raise ValueError(msg)
 
-        backbone = LiteHRNetBackbone(self.model_version)
-        decode_head = FCNHead(self.model_version, num_classes=self.num_classes)
+        backbone = LiteHRNetBackbone(self.model_name)
+        decode_head = FCNHead(self.model_name, num_classes=self.num_classes)
         criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index)  # type: ignore[attr-defined]
-        return BaseSegmModel(
+        return BaseSegmentationModel(
             backbone=backbone,
             decode_head=decode_head,
             criterion=criterion,
@@ -87,7 +87,7 @@ def _exporter(self) -> OTXModelExporter:
     @property
     def ignore_scope(self) -> dict[str, Any]:
         """Get the ignored scope for LiteHRNet."""
-        if self.model_version == "large":
+        if self.model_name == "large":
             return {
                 "ignored_scope": {
                     "patterns": ["__module.model.decode_head.aggregator/*"],
@@ -175,7 +175,7 @@ def ignore_scope(self) -> dict[str, Any]:
             "preset": "performance",
         }
 
-        if self.model_version == "medium":
+        if self.model_name == "medium":
             return {
                 "ignored_scope": {
                     "patterns": ["__module.model.backbone/*"],
@@ -263,7 +263,7 @@ def ignore_scope(self) -> dict[str, Any]:
             "preset": "mixed",
         }
 
-        if self.model_version == "small":
+        if self.model_name == "small":
             return {
                 "ignored_scope": {
                     "names": [
diff --git a/src/otx/algo/segmentation/segmentors/__init__.py b/src/otx/algo/segmentation/segmentors/__init__.py
index 7b7456cded1..a730b141399 100644
--- a/src/otx/algo/segmentation/segmentors/__init__.py
+++ b/src/otx/algo/segmentation/segmentors/__init__.py
@@ -3,7 +3,7 @@
 #
 """Module for base NN segmentation models."""
 
-from .base_model import BaseSegmModel
+from .base_model import BaseSegmentationModel
 from .mean_teacher import MeanTeacher
 
-__all__ = ["BaseSegmModel", "MeanTeacher"]
+__all__ = ["BaseSegmentationModel", "MeanTeacher"]
diff --git a/src/otx/algo/segmentation/segmentors/base_model.py b/src/otx/algo/segmentation/segmentors/base_model.py
index c66c49f84f5..9cad16b45ea 100644
--- a/src/otx/algo/segmentation/segmentors/base_model.py
+++ b/src/otx/algo/segmentation/segmentors/base_model.py
@@ -14,7 +14,7 @@
 from otx.core.data.entity.base import ImageInfo
 
 
-class BaseSegmModel(nn.Module):
+class BaseSegmentationModel(nn.Module):
     """Base Segmentation Model.
 
     Args:
diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py
index 57e3176faa7..33a7a763193 100644
--- a/src/otx/algo/segmentation/segnext.py
+++ b/src/otx/algo/segmentation/segnext.py
@@ -10,7 +10,7 @@
 from otx.algo.segmentation.backbones import MSCAN
 from otx.algo.segmentation.heads import LightHamHead
 from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
-from otx.algo.segmentation.segmentors import BaseSegmModel
+from otx.algo.segmentation.segmentors import BaseSegmentationModel
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
 from otx.core.model.segmentation import OTXSegmentationModel
 
@@ -29,14 +29,14 @@ class SegNext(OTXSegmentationModel):
 
     def _build_model(self) -> nn.Module:
         # initialize backbones
-        if self.model_version not in self.AVAILABLE_MODEL_VERSIONS:
-            msg = f"Model version {self.model_version} is not supported."
+        if self.model_name not in self.AVAILABLE_MODEL_VERSIONS:
+            msg = f"Model version {self.model_name} is not supported."
             raise ValueError(msg)
 
-        backbone = MSCAN(version=self.model_version)
-        decode_head = LightHamHead(version=self.model_version, num_classes=self.num_classes)
+        backbone = MSCAN(model_name=self.model_name)
+        decode_head = LightHamHead(model_name=self.model_name, num_classes=self.num_classes)
         criterion = CrossEntropyLossWithIgnore(ignore_index=self.label_info.ignore_index)  # type: ignore[attr-defined]
-        return BaseSegmModel(
+        return BaseSegmentationModel(
             backbone=backbone,
             decode_head=decode_head,
             criterion=criterion,
diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py
index 81b006d4d74..3bd7a1d308a 100644
--- a/src/otx/cli/cli.py
+++ b/src/otx/cli/cli.py
@@ -190,7 +190,11 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser,
     if "logger" in added_arguments:
         parser.link_arguments("workspace.work_dir", "logger.init_args.save_dir", apply_on="instantiate")
         parser.link_arguments("workspace.work_dir", "logger.init_args.log_dir", apply_on="instantiate")
-    if "checkpoint" in added_arguments and "--checkpoint" in sys.argv:
+    if (
+        "checkpoint" in added_arguments
+        and "--checkpoint" in sys.argv
+        and any("openvino_model.yaml" in arg for arg in sys.argv)
+    ):
         # This is code for an OVModel that uses checkpoint in model.model_name.
         parser.link_arguments("checkpoint", "model.init_args.model_name")
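Because torch models now use `model_name` as an architecture identifier, unconditionally linking `--checkpoint` into `model.init_args.model_name` would clobber it; the added `any("openvino_model.yaml" in arg ...)` guard restricts the link to OpenVINO recipes, where `model_name` really does hold the weights path. A stripped-down illustration of the jsonargparse mechanism being guarded (the `OVModel` class below is a stand-in, not the real one):

```python
from jsonargparse import ArgumentParser


class OVModel:
    def __init__(self, model_name: str | None = None) -> None:
        self.model_name = model_name  # for OV models, this holds the exported weights path


parser = ArgumentParser()
parser.add_argument("--checkpoint", type=str)
parser.add_class_arguments(OVModel, "model")
# The guarded call: copies the parsed --checkpoint value into model.model_name.
parser.link_arguments("checkpoint", "model.model_name")
```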
""" - self.model_version = model_version + self.model_name = model_name self.unsupervised_weight = unsupervised_weight self.semisl_start_epoch = semisl_start_epoch self.drop_unreliable_pixels_percent = drop_unreliable_pixels_percent diff --git a/src/otx/recipe/semantic_segmentation/dino_v2.yaml b/src/otx/recipe/semantic_segmentation/dino_v2.yaml index 713b8e92624..8987ff8c398 100644 --- a/src/otx/recipe/semantic_segmentation/dino_v2.yaml +++ b/src/otx/recipe/semantic_segmentation/dino_v2.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.dino_v2_seg.DinoV2Seg init_args: label_info: 2 - model_version: dinov2_vits14 + model_name: dinov2_vits14 input_size: - 560 - 560 diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml index e7a20d7e369..35dc76bf161 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_18.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - model_version: lite_hrnet_18 + model_name: lite_hrnet_18 optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml index d353ffdfc4c..8ebc1d40867 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_s.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - model_version: lite_hrnet_s + model_name: lite_hrnet_s optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml index 85bb55d55ca..f6b1e0d39cc 100644 --- a/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml +++ b/src/otx/recipe/semantic_segmentation/litehrnet_x.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.litehrnet.LiteHRNet init_args: label_info: 2 - model_version: lite_hrnet_x + model_name: lite_hrnet_x optimizer: class_path: torch.optim.Adam diff --git a/src/otx/recipe/semantic_segmentation/segnext_b.yaml b/src/otx/recipe/semantic_segmentation/segnext_b.yaml index 49626e58d6c..cad76ebf83b 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_b.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_b.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - model_version: segnext_base + model_name: segnext_base optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/segnext_s.yaml b/src/otx/recipe/semantic_segmentation/segnext_s.yaml index e8eab1d22e7..25f6ca9b1a9 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_s.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_s.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - model_version: segnext_small + model_name: segnext_small optimizer: class_path: torch.optim.AdamW diff --git a/src/otx/recipe/semantic_segmentation/segnext_t.yaml b/src/otx/recipe/semantic_segmentation/segnext_t.yaml index 755c26ee49c..fddc1256b5b 100644 --- a/src/otx/recipe/semantic_segmentation/segnext_t.yaml +++ b/src/otx/recipe/semantic_segmentation/segnext_t.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.segmentation.segnext.SegNext init_args: label_info: 2 - model_version: segnext_tiny + model_name: segnext_tiny optimizer: class_path: torch.optim.AdamW diff --git 
diff --git a/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml
index 7dc5ece097c..721085499db 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/dino_v2_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.dino_v2_seg.DinoV2Seg
   init_args:
     label_info: 2
-    model_version: dinov2_vits14
+    model_name: dinov2_vits14
     train_type: SEMI_SUPERVISED
     input_size:
       - 560
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml
index a98f1ab47a2..d0f570c6c36 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_18_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.litehrnet.LiteHRNet
   init_args:
     label_info: 2
-    model_version: lite_hrnet_18
+    model_name: lite_hrnet_18
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml
index c0cd0de594f..aa73a34aee6 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_s_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.litehrnet.LiteHRNet
   init_args:
     label_info: 2
-    model_version: lite_hrnet_s
+    model_name: lite_hrnet_s
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml
index ab757f65887..5abb0004bac 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/litehrnet_x_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.litehrnet.LiteHRNet
   init_args:
     label_info: 2
-    model_version: lite_hrnet_x
+    model_name: lite_hrnet_x
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml
index 395d0fb5c5e..abc39e7273f 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_b_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.segnext.SegNext
   init_args:
     label_info: 2
-    model_version: segnext_base
+    model_name: segnext_base
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml
index 8748572ed6b..477d591a49f 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_s_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.segnext.SegNext
   init_args:
     label_info: 2
-    model_version: segnext_small
+    model_name: segnext_small
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml
index b6b884b2759..f4a4d750620 100644
--- a/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml
+++ b/src/otx/recipe/semantic_segmentation/semisl/segnext_t_semisl.yaml
@@ -2,7 +2,7 @@ model:
   class_path: otx.algo.segmentation.segnext.SegNext
   init_args:
     label_info: 2
-    model_version: segnext_tiny
+    model_name: segnext_tiny
     train_type: SEMI_SUPERVISED
 
     optimizer:
diff --git a/tests/unit/algo/segmentation/backbones/test_dinov2.py b/tests/unit/algo/segmentation/backbones/test_dinov2.py
index 0e5f920d67e..45fb2aaf84e 100644
--- a/tests/unit/algo/segmentation/backbones/test_dinov2.py
+++ b/tests/unit/algo/segmentation/backbones/test_dinov2.py
@@ -30,7 +30,7 @@ def mock_torch_hub_load(self, mocker, mock_backbone):
         return mocker.patch("otx.algo.segmentation.backbones.dinov2.torch.hub.load", return_value=mock_backbone)
 
     def test_init(self, mock_backbone, mock_backbone_named_parameters):
-        dino = DinoVisionTransformer(name="dinov2_vits14", freeze_backbone=True, out_index=[8, 9, 10, 11])
+        dino = DinoVisionTransformer(model_name="dinov2_vits14", freeze_backbone=True, out_index=[8, 9, 10, 11])
 
         assert dino.backbone == mock_backbone
         for parameter in mock_backbone_named_parameters.values():
@@ -39,7 +39,7 @@ def test_init(self, mock_backbone, mock_backbone_named_parameters):
     @pytest.fixture()
     def dino_vit(self) -> DinoVisionTransformer:
         return DinoVisionTransformer(
-            name="dinov2_vits14",
+            model_name="dinov2_vits14",
             freeze_backbone=True,
             out_index=[8, 9, 10, 11],
         )
diff --git a/tests/unit/algo/segmentation/backbones/test_litehrnet.py b/tests/unit/algo/segmentation/backbones/test_litehrnet.py
index 32242fc3549..099587345a8 100644
--- a/tests/unit/algo/segmentation/backbones/test_litehrnet.py
+++ b/tests/unit/algo/segmentation/backbones/test_litehrnet.py
@@ -2,7 +2,7 @@
 
 import pytest
 import torch
-from otx.algo.segmentation.backbones.litehrnet import NeighbourSupport, NNLiteHRNet, SpatialWeightingV2, StemV2
+from otx.algo.segmentation.backbones.litehrnet import LiteHRNetModule, NeighbourSupport, SpatialWeightingV2, Stem
 
 
 class TestSpatialWeightingV2:
@@ -15,27 +15,27 @@ def test_forward(self) -> None:
         assert outputs is not None
 
 
-class TestStemV2:
+class TestStem:
     @pytest.fixture()
-    def stemv2(self) -> StemV2:
-        return StemV2(in_channels=32, stem_channels=32, out_channels=32, expand_ratio=1)
+    def stem(self) -> Stem:
+        return Stem(in_channels=32, stem_channels=32, out_channels=32, expand_ratio=1)
 
     def test_init(self) -> None:
-        stemv2_extra_stride = StemV2(
+        stem_extra_stride = Stem(
             in_channels=32,
             stem_channels=32,
             out_channels=32,
             expand_ratio=1,
             extra_stride=True,
         )
-        assert stemv2_extra_stride is not None
+        assert stem_extra_stride is not None
 
-        stemv2_input_norm = StemV2(in_channels=32, stem_channels=32, out_channels=32, expand_ratio=1, input_norm=True)
-        assert stemv2_input_norm is not None
+        stem_input_norm = Stem(in_channels=32, stem_channels=32, out_channels=32, expand_ratio=1, input_norm=True)
+        assert stem_input_norm is not None
 
-    def test_forward(self, stemv2) -> None:
+    def test_forward(self, stem) -> None:
         inputs = torch.randn(1, 32, 32, 32)
-        outputs = stemv2(inputs)
+        outputs = stem(inputs)
         assert outputs is not None
 
 
@@ -49,18 +49,11 @@ def test_forward(self) -> None:
         assert outputs is not None
 
 
-class TestNNLiteHRNet:
+class TestLiteHRNetModule:
     @pytest.fixture()
     def cfg(self) -> dict:
         return {
-            "stem": {
-                "stem_channels": 32,
-                "out_channels": 32,
-                "expand_ratio": 1,
-                "strides": (2, 2),
-                "extra_stride": False,
-                "input_norm": False,
-            },
+            "stem_configuration": {},
             "num_stages": 3,
             "stages_spec": {
                 "num_modules": (2, 4, 2),
@@ -78,8 +71,8 @@ def cfg(self) -> dict:
         }
 
     @pytest.fixture()
-    def backbone(self, cfg) -> NNLiteHRNet:
-        return NNLiteHRNet(**cfg)
+    def backbone(self, cfg) -> LiteHRNetModule:
+        return LiteHRNetModule(**cfg)
 
     @pytest.fixture()
     def mock_torch_load(self, mocker) -> MagicMock:
@@ -100,7 +93,7 @@ def pretrained_weight(self, tmp_path) -> str:
         return str(weight)
 
     def test_init(self, cfg) -> None:
-        model = NNLiteHRNet(**cfg)
+        model = LiteHRNetModule(**cfg)
         assert model is not None
 
     def test_forward(self, cfg, backbone) -> None:
@@ -127,7 +120,7 @@ def test_load_pretrained_weights(
         mock_torch_load,
         mock_load_checkpoint_to_model,
     ):
-        model = NNLiteHRNet(**cfg)
+        model = LiteHRNetModule(**cfg)
         model.load_pretrained_weights(pretrained=pretrained_weight)
         mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
diff --git a/tests/unit/algo/segmentation/backbones/test_mscan.py b/tests/unit/algo/segmentation/backbones/test_mscan.py
index 441e121ead1..da242475c5a 100644
--- a/tests/unit/algo/segmentation/backbones/test_mscan.py
+++ b/tests/unit/algo/segmentation/backbones/test_mscan.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 from otx.algo.segmentation.backbones import mscan as target_file
-from otx.algo.segmentation.backbones.mscan import NNMSCAN, DropPath, drop_path
+from otx.algo.segmentation.backbones.mscan import DropPath, MSCANModule, drop_path
 
 
 @pytest.mark.parametrize("dim", [1, 2, 3, 4])
@@ -59,7 +59,7 @@ def test_forward(self):
 class TestMSCABlock:
     def test_init(self):
         num_stages = 4
-        mscan = NNMSCAN(num_stages=num_stages)
+        mscan = MSCANModule(num_stages=num_stages)
 
         for i in range(num_stages):
             assert hasattr(mscan, f"patch_embed{i + 1}")
@@ -68,7 +68,7 @@ def test_init(self):
 
     def test_forward(self):
         num_stages = 4
-        mscan = NNMSCAN(num_stages=num_stages)
+        mscan = MSCANModule(num_stages=num_stages)
 
         x = torch.rand(8, 3, 3, 3)
         out = mscan.forward(x)
@@ -93,14 +93,14 @@ def mock_torch_load(self, mocker) -> MagicMock:
         return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load")
 
     def test_load_pretrained_weights(self, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model):
-        NNMSCAN(pretrained_weights=pretrained_weight)
+        MSCANModule(pretrained_weights=pretrained_weight)
 
         mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
         mock_load_checkpoint_to_model.assert_called_once()
 
     def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_checkpoint_to_model):
         pretrained_weight = "www.fake.com/fake.pth"
-        NNMSCAN(pretrained_weights=pretrained_weight)
+        MSCANModule(pretrained_weights=pretrained_weight)
 
         cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
         mock_load_from_http.assert_called_once_with(filename=pretrained_weight, map_location="cpu", model_dir=cache_dir)
diff --git a/tests/unit/algo/segmentation/heads/test_class_incremental.py b/tests/unit/algo/segmentation/heads/test_class_incremental.py
index 2ea0bf1713c..0b8562c8632 100644
--- a/tests/unit/algo/segmentation/heads/test_class_incremental.py
+++ b/tests/unit/algo/segmentation/heads/test_class_incremental.py
@@ -14,7 +14,7 @@ class MockGT:
 
 class TestClassIncrementalMixin:
     def test_ignore_label(self) -> None:
-        hrnet = LiteHRNet(3, input_size=(128, 128), model_version="lite_hrnet_18")
+        hrnet = LiteHRNet(3, input_size=(128, 128), model_name="lite_hrnet_18")
 
         seg_logits = torch.randn(1, 3, 128, 128)
         # no annotations for class=3
diff --git a/tests/unit/algo/segmentation/heads/test_ham_head.py b/tests/unit/algo/segmentation/heads/test_ham_head.py
index d8148154627..2352af4f91e 100644
--- a/tests/unit/algo/segmentation/heads/test_ham_head.py
+++ b/tests/unit/algo/segmentation/heads/test_ham_head.py
@@ -6,11 +6,11 @@
 import pytest
 import torch
 from otx.algo.modules.norm import build_norm_layer
-from otx.algo.segmentation.heads.ham_head import NNLightHamHead
+from otx.algo.segmentation.heads.ham_head import LightHamHeadModule
 from torch import nn
 
 
-class TestNNLightHamHead:
+class TestLightHamHeadModule:
     @pytest.fixture()
     def head_config(self) -> dict[str, Any]:
         return {
@@ -26,7 +26,7 @@ def head_config(self) -> dict[str, Any]:
         }
 
     def test_init(self, head_config):
-        light_ham_head = NNLightHamHead(**head_config)
+        light_ham_head = LightHamHeadModule(**head_config)
         assert light_ham_head.ham_channels == head_config["ham_channels"]
 
     @pytest.fixture()
@@ -43,7 +43,7 @@ def fake_input(self, batch_size) -> list[torch.Tensor]:
         ]
 
     def test_forward(self, head_config, fake_input, batch_size):
-        light_ham_head = NNLightHamHead(**head_config)
+        light_ham_head = LightHamHeadModule(**head_config)
         out = light_ham_head.forward(fake_input)
         assert out.size()[0] == batch_size
         assert out.size()[2] == fake_input[head_config["in_index"][0]].size()[2]
diff --git a/tests/unit/algo/segmentation/segmentors/test_base_model.py b/tests/unit/algo/segmentation/segmentors/test_base_model.py
index 33a33af4dda..d970ead0c32 100644
--- a/tests/unit/algo/segmentation/segmentors/test_base_model.py
+++ b/tests/unit/algo/segmentation/segmentors/test_base_model.py
@@ -3,17 +3,17 @@
 #
 import pytest
 import torch
-from otx.algo.segmentation.segmentors.base_model import BaseSegmModel
+from otx.algo.segmentation.segmentors.base_model import BaseSegmentationModel
 from otx.core.data.entity.base import ImageInfo
 
 
-class TestBaseSegmModel:
+class TestBaseSegmentationModel:
     @pytest.fixture()
     def model(self):
         backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 64, kernel_size=3, padding=1))
         decode_head = torch.nn.Sequential(torch.nn.Conv2d(64, 2, kernel_size=1))
         decode_head.num_classes = 3
-        return BaseSegmModel(backbone, decode_head)
+        return BaseSegmentationModel(backbone, decode_head)
 
     @pytest.fixture()
     def inputs(self):
diff --git a/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py
index f7be592b5fa..1bf3bede605 100644
--- a/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py
+++ b/tests/unit/algo/segmentation/segmentors/test_mean_teacher.py
@@ -5,7 +5,7 @@
 import pytest
 import torch
 from otx.algo.segmentation.losses import CrossEntropyLossWithIgnore
-from otx.algo.segmentation.segmentors import BaseSegmModel, MeanTeacher
+from otx.algo.segmentation.segmentors import BaseSegmentationModel, MeanTeacher
 from otx.core.data.entity.base import ImageInfo
 from torch import nn
 
@@ -16,7 +16,7 @@ def model(self):
         decode_head = nn.Conv2d(3, 2, 1)
         decode_head.num_classes = 2
         loss = CrossEntropyLossWithIgnore(ignore_index=255)
-        model = BaseSegmModel(
+        model = BaseSegmentationModel(
             backbone=nn.Sequential(nn.Conv2d(3, 5, 1), nn.ReLU(), nn.Conv2d(5, 3, 1)),
             decode_head=decode_head,
             criterion=loss,
diff --git a/tests/unit/algo/segmentation/test_dino_v2_seg.py b/tests/unit/algo/segmentation/test_dino_v2_seg.py
index 5353a43616a..d51c3d75665 100644
--- a/tests/unit/algo/segmentation/test_dino_v2_seg.py
+++ b/tests/unit/algo/segmentation/test_dino_v2_seg.py
@@ -10,7 +10,7 @@ class TestDinoV2Seg:
     @pytest.fixture(scope="class")
     def fxt_dino_v2_seg(self) -> DinoV2Seg:
-        return DinoV2Seg(label_info=10, model_version="dinov2_vits14", input_size=(560, 560))
+        return DinoV2Seg(label_info=10, model_name="dinov2_vits14", input_size=(560, 560))
 
     def test_dino_v2_seg_init(self, fxt_dino_v2_seg):
         assert isinstance(fxt_dino_v2_seg, DinoV2Seg)
diff --git a/tests/unit/algo/segmentation/test_huggingface_model.py b/tests/unit/algo/segmentation/test_huggingface_model.py
index 36693561692..053aef6fd2e 100644
--- a/tests/unit/algo/segmentation/test_huggingface_model.py
+++ b/tests/unit/algo/segmentation/test_huggingface_model.py
@@ -24,7 +24,7 @@ class TestHuggingFaceModelForSegmentation:
     @pytest.fixture()
     def fxt_seg_model(self):
         return HuggingFaceModelForSegmentation(
-            model_name_or_path="nvidia/segformer-b0-finetuned-ade-512-512",
+            model_name="nvidia/segformer-b0-finetuned-ade-512-512",
             label_info=2,
         )
 
@@ -86,7 +86,7 @@ def mock_automodel(self, mocker) -> MagicMock:
     def test_set_input_size(self, mock_pretrainedconfig, mock_automodel):
         input_size = (1, 3, 1024, 1024)
         HuggingFaceModelForSegmentation(
-            model_name_or_path="facebook/deit-tiny-patch16-224",
+            model_name="facebook/deit-tiny-patch16-224",
             label_info=10,
             input_size=input_size,
         )
diff --git a/tests/unit/algo/segmentation/test_segnext.py b/tests/unit/algo/segmentation/test_segnext.py
index 375ad9d0b61..cd30a4bea96 100644
--- a/tests/unit/algo/segmentation/test_segnext.py
+++ b/tests/unit/algo/segmentation/test_segnext.py
@@ -10,7 +10,7 @@ class TestSegNext:
     @pytest.fixture()
     def fxt_segnext(self) -> SegNext:
-        return SegNext(10, model_version="segnext_base", input_size=(512, 512))
+        return SegNext(10, model_name="segnext_base", input_size=(512, 512))
 
     def test_segnext_init(self, fxt_segnext):
         assert isinstance(fxt_segnext, SegNext)
diff --git a/tests/unit/core/model/test_segmentation.py b/tests/unit/core/model/test_segmentation.py
index d364c9ab273..b7181ce87cc 100644
--- a/tests/unit/core/model/test_segmentation.py
+++ b/tests/unit/core/model/test_segmentation.py
@@ -18,7 +18,7 @@ class TestOTXSegmentationModel:
     @pytest.fixture()
     def model(self, label_info, optimizer, scheduler, metric, torch_compile):
-        return OTXSegmentationModel(label_info, (512, 512), optimizer, scheduler, metric, torch_compile)
+        return OTXSegmentationModel(label_info, "segm_model", (512, 512), optimizer, scheduler, metric, torch_compile)
 
     @pytest.fixture()
     def batch_data_entity(self):
@@ -76,6 +76,8 @@ def test_dispatch_label_info(self, model, label_info, expected_label_info):
 
     def test_init(self, model):
         assert model.num_classes == 3
+        assert model.model_name == "segm_model"
+        assert model.input_size == (512, 512)
 
     def test_customize_inputs(self, model, batch_data_entity):
         customized_inputs = model._customize_inputs(batch_data_entity)
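Note that `model_name` slots in as the second positional parameter of `OTXSegmentationModel.__init__`, which is why the fixture above gains `"segm_model"` in second position. Keyword arguments sidestep the shift entirely, e.g. for a concrete subclass (values borrowed from the tests above):

```python
from otx.algo.segmentation.litehrnet import LiteHRNet

# model_name is now a required argument on every OTX segmentation model.
model = LiteHRNet(label_info=3, input_size=(128, 128), model_name="lite_hrnet_18")
```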