diff --git a/src/otx/algo/classification/backbones/__init__.py b/src/otx/algo/classification/backbones/__init__.py index b92c0811293..4a9c8a98aaa 100644 --- a/src/otx/algo/classification/backbones/__init__.py +++ b/src/otx/algo/classification/backbones/__init__.py @@ -4,7 +4,7 @@ """Backbone modules for OTX custom model.""" from .otx_efficientnet import OTXEfficientNet -from .otx_efficientnet_v2 import OTXEfficientNetV2 +from .timm import TimmBackbone from .mobilenet_v3 import OTXMobileNetV3 -__all__ = ["OTXEfficientNet", "OTXEfficientNetV2", "OTXMobileNetV3"] +__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3"] diff --git a/src/otx/algo/classification/backbones/otx_efficientnet_v2.py b/src/otx/algo/classification/backbones/timm.py similarity index 67% rename from src/otx/algo/classification/backbones/otx_efficientnet_v2.py rename to src/otx/algo/classification/backbones/timm.py index 27fc7912e96..8048aa44bbb 100644 --- a/src/otx/algo/classification/backbones/otx_efficientnet_v2.py +++ b/src/otx/algo/classification/backbones/timm.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """EfficientNetV2 model. Original papers: @@ -11,10 +11,11 @@ import os +import torch import timm -from mmengine.runner import load_checkpoint -from mmpretrain.registry import MODELS +from otx.algo.utils.mmengine_utils import load_from_http, load_checkpoint_to_model from torch import nn +from typing import Literal PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/" pretrained_urls = { @@ -22,7 +23,7 @@ "efficientnetv2_s_1k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth", } -NAME_DICT = { +TIMM_MODEL_NAME_DICT = { "mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k", "mobilenetv3_large_1k": "mobilenetv3_large_100_miil", "tresnet": "tresnet_m", @@ -33,22 +34,34 @@ "efficientnetv2_b0": "tf_efficientnetv2_b0", } - -class TimmModelsWrapper(nn.Module): - """Timm model wrapper.""" - - def __init__(self, model_name, pretrained=False, pooling_type="avg", **kwargs): +TimmModelType = Literal[ + "mobilenetv3_large_21k", + "mobilenetv3_large_1k", + "tresnet", + "efficientnetv2_s_21k", + "efficientnetv2_s_1k", + "efficientnetv2_m_21k", + "efficientnetv2_m_1k", + "efficientnetv2_b0", +] + + +class TimmBackbone(nn.Module): + def __init__( + self, + backbone: TimmModelType, + pretrained=False, + pooling_type="avg", + **kwargs, + ): super().__init__(**kwargs) - self.model_name = model_name + self.backbone = backbone self.pretrained = pretrained - if model_name in ["mobilenetv3_large_100_miil_in21k", "mobilenetv3_large_100_miil"]: - self.is_mobilenet = True - else: - self.is_mobilenet = False + self.is_mobilenet = backbone.startswith("mobilenet") - self.model = timm.create_model(NAME_DICT[self.model_name], pretrained=pretrained, num_classes=1000) + self.model = timm.create_model(TIMM_MODEL_NAME_DICT[self.backbone], pretrained=pretrained, num_classes=1000) if self.pretrained: - print(f"init weight - {pretrained_urls[self.model_name]}") + print(f"init weight - {pretrained_urls[self.backbone]}") self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx. self.num_head_features = self.model.num_features self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features @@ -85,20 +98,14 @@ def get_config_optim(self, lrs): return parameters - -@MODELS.register_module() -class OTXEfficientNetV2(TimmModelsWrapper): - """EfficientNetV2 for OTX.""" - - def __init__(self, version="s_21k", **kwargs): - self.model_name = "efficientnetv2_" + version - super().__init__(model_name=self.model_name, **kwargs) - - def init_weights(self, pretrained=None): + def init_weights(self, pretrained: str | bool | None = None): """Initialize weights.""" + checkpoint = None if isinstance(pretrained, str) and os.path.exists(pretrained): - load_checkpoint(self, pretrained) + checkpoint = torch.load(pretrained, None) print(f"init weight - {pretrained}") elif pretrained is not None: - load_checkpoint(self, pretrained_urls[self.model_name]) - print(f"init weight - {pretrained_urls[self.model_name]}") + checkpoint = load_from_http(pretrained_urls[self.key]) + print(f"init weight - {pretrained_urls[self.key]}") + if checkpoint is not None: + load_checkpoint_to_model(self, checkpoint) diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py index 08de563bd73..69007658e2f 100644 --- a/src/otx/algo/classification/efficientnet_v2.py +++ b/src/otx/algo/classification/efficientnet_v2.py @@ -4,18 +4,35 @@ """EfficientNetV2 model implementation.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable -from otx.algo.utils.mmconfig import read_mmconfig +import torch +from torch import nn + +from otx.algo.classification.backbones.timm import TimmBackbone +from otx.algo.classification.classifier.base_classifier import ImageClassifier +from otx.algo.classification.heads import HierarchicalLinearClsHead, LinearClsHead, MultiLabelLinearClsHead +from otx.algo.classification.necks.gap import GlobalAveragePooling from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.data.entity.base import OTXBatchLossEntity +from otx.core.data.entity.classification import ( + HlabelClsBatchDataEntity, + HlabelClsBatchPredEntity, + MulticlassClsBatchDataEntity, + MulticlassClsBatchPredEntity, + MultilabelClsBatchDataEntity, + MultilabelClsBatchPredEntity, +) +from otx.core.exporter.base import OTXModelExporter +from otx.core.exporter.native import OTXNativeModelExporter +from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable from otx.core.model.classification import ( - MMPretrainHlabelClsModel, - MMPretrainMulticlassClsModel, - MMPretrainMultilabelClsModel, + OTXHlabelClsModel, + OTXMulticlassClsModel, + OTXMultilabelClsModel, ) -from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import HLabelInfo, LabelInfoTypes @@ -25,82 +42,369 @@ from otx.core.metrics import MetricCallable -class EfficientNetV2ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel): - """EfficientNetV2 Model for hierarchical label classification task.""" +class EfficientNetV2ForMulticlassCls(OTXMulticlassClsModel): + """EfficientNetV2 Model for multi-class classification task.""" def __init__( self, label_info: HLabelInfo, + loss_callable: Callable[[], nn.Module] = nn.CrossEntropyLoss, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = HLabelClsMetricCallble, + metric: MetricCallable = MultiClassClsMetricCallable, torch_compile: bool = False, ) -> None: - config = read_mmconfig("efficientnet_v2_light", subdir_name="hlabel_classification") + self.head_config = {"loss_callable": loss_callable} super().__init__( label_info=label_info, - config=config, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) + def _create_model(self) -> nn.Module: + loss = self.head_config["loss_callable"] + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=LinearClsHead( + num_classes=self.label_info.num_classes, + in_channels=1280, + topk=(1, 5) if self.label_info.num_classes >= 5 else (1,), + loss=loss if isinstance(loss, nn.Module) else loss(), + ), + ) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix) + + def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.cat(inputs.labels, dim=0), + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MulticlassClsBatchDataEntity, + ) -> MulticlassClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + preds = logits.argmax(-1, keepdim=True).unbind(0) + + return MulticlassClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=preds, + ) + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: This should be done via onnx + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") -class EfficientNetV2ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel): + return MulticlassClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") + + +class EfficientNetV2ForMultilabelCls(OTXMultilabelClsModel): """EfficientNetV2 Model for multi-label classification task.""" def __init__( self, label_info: LabelInfoTypes, - light: bool = False, + loss_callable: Callable[[], nn.Module] = nn.CrossEntropyLoss, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiClassClsMetricCallable, + metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, ) -> None: - model_name = "efficientnet_v2_light" if light else "efficientnet_v2" - config = read_mmconfig(model_name=model_name, subdir_name="multiclass_classification") + self.head_config = {"loss_callable": loss_callable} + super().__init__( label_info=label_info, - config=config, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) + def _create_model(self) -> nn.Module: + loss = self.head_config["loss_callable"] + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=MultiLabelLinearClsHead( + num_classes=self.label_info.num_classes, + in_channels=1280, + loss=loss if isinstance(loss, nn.Module) else loss(), + normalized=True, + scale=7.0, + ), + ) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix) + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix) + def _customize_inputs(self, inputs: MultilabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" -class EfficientNetV2ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel): - """EfficientNetV2 Model for multi-class classification task.""" + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: MultilabelClsBatchDataEntity, + ) -> MultilabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] + scores = torch.unbind(logits, 0) + + return MultilabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=logits.argmax(-1, keepdim=True).unbind(0), + ) + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: This should be done via onnx + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + + def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return MultilabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") + + +class EfficientNetV2ForHLabelCls(OTXHlabelClsModel): + """EfficientNetV2 Model for hierarchical label classification task.""" + + label_info: HLabelInfo def __init__( self, - label_info: LabelInfoTypes, + label_info: HLabelInfo, + multiclass_loss_callable: Callable[[], nn.Module] = nn.CrossEntropyLoss, + multilabel_loss_callable: Callable[[], nn.Module] = nn.CrossEntropyLoss, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = MultiLabelClsMetricCallable, + metric: MetricCallable = HLabelClsMetricCallble, torch_compile: bool = False, ) -> None: - config = read_mmconfig("efficientnet_v2_light", subdir_name="multilabel_classification") + self.head_config = { + "multiclass_loss_callable": multiclass_loss_callable, + "multilabel_loss_callable": multilabel_loss_callable, + } + super().__init__( label_info=label_info, - config=config, optimizer=optimizer, scheduler=scheduler, metric=metric, torch_compile=torch_compile, ) + def _create_model(self) -> nn.Module: + multiclass_loss = self.head_config["multiclass_loss_callable"] + multilabel_loss = self.head_config["multilabel_loss_callable"] + return ImageClassifier( + backbone=TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True), + neck=GlobalAveragePooling(dim=2), + head=HierarchicalLinearClsHead( + in_channels=1280, + multiclass_loss=multiclass_loss if isinstance(multiclass_loss, nn.Module) else multiclass_loss(), + multilabel_loss=multilabel_loss if isinstance(multilabel_loss, nn.Module) else multilabel_loss(), + **self.label_info.as_head_config_dict(), + ), + ) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multilabel", add_prefix) + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) + + def _customize_inputs(self, inputs: HlabelClsBatchDataEntity) -> dict[str, Any]: + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" + else: + mode = "predict" + + return { + "images": inputs.stacked_images, + "labels": torch.stack(inputs.labels), + "imgs_info": inputs.imgs_info, + "mode": mode, + } + + def _customize_outputs( + self, + outputs: Any, # noqa: ANN401 + inputs: HlabelClsBatchDataEntity, + ) -> HlabelClsBatchPredEntity | OTXBatchLossEntity: + if self.training: + return OTXBatchLossEntity(loss=outputs) + + # To list, batch-wise + if isinstance(outputs, dict): + scores = outputs["pred_scores"] + labels = outputs["pred_labels"] + else: + scores = outputs + labels = outputs.argmax(-1, keepdim=True) + + return HlabelClsBatchPredEntity( + batch_size=inputs.batch_size, + images=inputs.images, + imgs_info=inputs.imgs_info, + scores=scores, + labels=labels, + ) + + def _convert_pred_entity_to_compute_metric( + self, + preds: HlabelClsBatchPredEntity, + inputs: HlabelClsBatchDataEntity, + ) -> MetricInput: + hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] + + _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels + _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores + if hlabel_info.num_multilabel_classes > 0: + preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] + preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] + pred_result = torch.cat([preds_multiclass, preds_multilabel], dim=1) + else: + pred_result = _labels + return { + "preds": pred_result, + "target": torch.stack(inputs.labels), + } + + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: This should be done via onnx + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + + def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return HlabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") + + return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/classification/mmconfigs/hlabel_classification/efficientnet_v2_light.yaml b/src/otx/algo/classification/mmconfigs/hlabel_classification/efficientnet_v2_light.yaml deleted file mode 100644 index 49250229bc4..00000000000 --- a/src/otx/algo/classification/mmconfigs/hlabel_classification/efficientnet_v2_light.yaml +++ /dev/null @@ -1,33 +0,0 @@ -backbone: - pretrained: true - type: OTXEfficientNetV2 -head: - num_multiclass_heads: 0 - num_multilabel_classes: 0 - in_channels: 1280 - num_classes: 1000 - multiclass_loss_cfg: - loss_weight: 1.0 - type: CrossEntropyLoss - multilabel_loss_cfg: - reduction: sum - gamma_neg: 1.0 - gamma_pos: 0.0 - type: AsymmetricAngularLossWithIgnore - normalized: true - scale: 7.0 - type: CustomHierarchicalLinearClsHead -neck: - type: GlobalAveragePooling -data_preprocessor: - mean: - - 123.675 - - 116.28 - - 103.53 - std: - - 58.395 - - 57.12 - - 57.375 - to_rgb: False - type: ClsDataPreprocessor -type: ImageClassifier diff --git a/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2.yaml b/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2.yaml deleted file mode 100644 index 077d00d3a6c..00000000000 --- a/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2.yaml +++ /dev/null @@ -1,38 +0,0 @@ -backbone: - pretrained: true - type: OTXEfficientNetV2 -head: - act_cfg: - type: HSwish - dropout_rate: 0.2 - in_channels: 1280 - init_cfg: - bias: 0.0 - layer: Linear - mean: 0.0 - std: 0.01 - type: Normal - loss: - loss_weight: 1.0 - type: CrossEntropyLoss - mid_channels: - - 1280 - num_classes: 1000 - topk: - - 1 - - 5 - type: StackedLinearClsHead -neck: - type: GlobalAveragePooling -data_preprocessor: - mean: - - 123.675 - - 116.28 - - 103.53 - std: - - 58.395 - - 57.12 - - 57.375 - to_rgb: False - type: ClsDataPreprocessor -type: ImageClassifier diff --git a/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2_light.yaml b/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2_light.yaml deleted file mode 100644 index c2599fa9605..00000000000 --- a/src/otx/algo/classification/mmconfigs/multiclass_classification/efficientnet_v2_light.yaml +++ /dev/null @@ -1,27 +0,0 @@ -backbone: - pretrained: true - type: OTXEfficientNetV2 -head: - in_channels: 1280 - loss: - loss_weight: 1.0 - type: CrossEntropyLoss - num_classes: 1000 - topk: - - 1 - - 5 - type: LinearClsHead -neck: - type: GlobalAveragePooling -data_preprocessor: - mean: - - 123.675 - - 116.28 - - 103.53 - std: - - 58.395 - - 57.12 - - 57.375 - to_rgb: False - type: ClsDataPreprocessor -type: ImageClassifier diff --git a/src/otx/algo/classification/mmconfigs/multilabel_classification/efficientnet_v2_light.yaml b/src/otx/algo/classification/mmconfigs/multilabel_classification/efficientnet_v2_light.yaml deleted file mode 100644 index 29eb048563c..00000000000 --- a/src/otx/algo/classification/mmconfigs/multilabel_classification/efficientnet_v2_light.yaml +++ /dev/null @@ -1,28 +0,0 @@ -backbone: - pretrained: true - type: OTXEfficientNetV2 -head: - in_channels: 1280 - num_classes: 1000 - loss: - reduction: sum - gamma_neg: 1.0 - gamma_pos: 0.0 - type: AsymmetricAngularLossWithIgnore - normalized: true - scale: 7.0 - type: CustomMultiLabelLinearClsHead -neck: - type: GlobalAveragePooling -data_preprocessor: - mean: - - 123.675 - - 116.28 - - 103.53 - std: - - 58.395 - - 57.12 - - 57.375 - to_rgb: False - type: ClsDataPreprocessor -type: ImageClassifier diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml new file mode 100644 index 00000000000..9ccddbc6636 --- /dev/null +++ b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml @@ -0,0 +1,99 @@ +model: + class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForHLabelCls + init_args: + multiclass_loss_callable: + class_path: torch.nn.CrossEntropyLoss + + multilabel_loss_callable: + class_path: otx.algo.classification.losses.AsymmetricAngularLossWithIgnore + init_args: + reduction: sum + gamma_neg: 1.0 + gamma_pos: 0.0 + + optimizer: + class_path: torch.optim.SGD + init_args: + lr: 0.0071 + momentum: 0.9 + weight_decay: 0.0001 + + scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.1 + patience: 1 + monitor: val/accuracy + +engine: + task: H_LABEL_CLS + device: auto + +callback_monitor: val/accuracy + +data: ../../_base_/data/torchvision_base.yaml +overrides: + max_epochs: 90 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 3 + data: + task: H_LABEL_CLS + config: + mem_cache_img_max_size: + - 500 + - 500 + stack_images: True + data_format: datumaro + train_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop + init_args: + scale: 224 + backend: cv2 + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + sampler: + class_path: otx.algo.samplers.balanced_sampler.BalancedSampler + val_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + test_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2_light.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2_light.yaml deleted file mode 100644 index 96348a98d48..00000000000 --- a/src/otx/recipe/classification/h_label_cls/efficientnet_v2_light.yaml +++ /dev/null @@ -1,45 +0,0 @@ -model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForHLabelCls - init_args: - optimizer: - class_path: torch.optim.SGD - init_args: - lr: 0.0071 - momentum: 0.9 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.1 - patience: 1 - monitor: val/accuracy - -engine: - task: H_LABEL_CLS - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmpretrain_base.yaml -overrides: - max_epochs: 90 - callbacks: - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - patience: 3 - data: - task: H_LABEL_CLS - config: - data_format: datumaro - train_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: RandomResizedCrop - - direction: horizontal - prob: 0.5 - type: RandomFlip - - type: PackInputs diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml new file mode 100644 index 00000000000..429cc1accd7 --- /dev/null +++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml @@ -0,0 +1,93 @@ +model: + class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMulticlassCls + init_args: + label_info: 1000 + loss_callable: + class_path: torch.nn.CrossEntropyLoss + init_args: + reduction: none + + optimizer: + class_path: torch.optim.SGD + init_args: + lr: 0.0071 + momentum: 0.9 + weight_decay: 0.0001 + + scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.1 + patience: 1 + monitor: val/accuracy + +engine: + task: MULTI_CLASS_CLS + device: auto + +callback_monitor: val/accuracy + +data: ../../_base_/data/torchvision_base.yaml +overrides: + max_epochs: 90 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 3 + data: + config: + mem_cache_img_max_size: + - 500 + - 500 + stack_images: True + train_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop + init_args: + scale: 224 + backend: cv2 + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + sampler: + class_path: otx.algo.samplers.balanced_sampler.BalancedSampler + val_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + test_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2_light.yaml b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2_light.yaml deleted file mode 100644 index a9f2ac442da..00000000000 --- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2_light.yaml +++ /dev/null @@ -1,46 +0,0 @@ -model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMulticlassCls - init_args: - label_info: 1000 - light: True - - optimizer: - class_path: torch.optim.SGD - init_args: - lr: 0.0071 - momentum: 0.9 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.1 - patience: 1 - monitor: val/accuracy - -engine: - task: MULTI_CLASS_CLS - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmpretrain_base.yaml -overrides: - max_epochs: 90 - callbacks: - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - patience: 3 - data: - config: - train_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: RandomResizedCrop - - direction: horizontal - prob: 0.5 - type: RandomFlip - - type: PackInputs diff --git a/src/otx/recipe/classification/multi_class_cls/otx_efficientnet_v2.yaml b/src/otx/recipe/classification/multi_class_cls/otx_efficientnet_v2.yaml deleted file mode 100644 index 8e8044f7dba..00000000000 --- a/src/otx/recipe/classification/multi_class_cls/otx_efficientnet_v2.yaml +++ /dev/null @@ -1,46 +0,0 @@ -model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMulticlassCls - init_args: - label_info: 1000 - light: false - - optimizer: - class_path: torch.optim.SGD - init_args: - lr: 0.0071 - momentum: 0.9 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.1 - patience: 1 - monitor: val/accuracy - -engine: - task: MULTI_CLASS_CLS - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmpretrain_base.yaml -overrides: - max_epochs: 90 - callbacks: - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - patience: 3 - data: - config: - train_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: RandomResizedCrop - - direction: horizontal - prob: 0.5 - type: RandomFlip - - type: PackInputs diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml new file mode 100644 index 00000000000..0923fa91dc2 --- /dev/null +++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml @@ -0,0 +1,97 @@ +model: + class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMultilabelCls + init_args: + label_info: 1000 + loss_callable: + class_path: otx.algo.classification.losses.AsymmetricAngularLossWithIgnore + init_args: + reduction: sum + gamma_neg: 1.0 + gamma_pos: 0.0 + + optimizer: + class_path: torch.optim.SGD + init_args: + lr: 0.0071 + momentum: 0.9 + weight_decay: 0.0001 + + scheduler: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.1 + patience: 1 + monitor: val/accuracy + +engine: + task: MULTI_LABEL_CLS + device: auto + +callback_monitor: val/accuracy + +data: ../../_base_/data/torchvision_base.yaml +overrides: + max_epochs: 90 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 3 + data: + task: MULTI_LABEL_CLS + config: + data_format: datumaro + mem_cache_img_max_size: + - 500 + - 500 + stack_images: True + train_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop + init_args: + scale: 224 + backend: cv2 + - class_path: otx.core.data.transform_libs.torchvision.RandomFlip + init_args: + prob: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + sampler: + class_path: otx.algo.samplers.balanced_sampler.BalancedSampler + val_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + test_subset: + batch_size: 64 + transforms: + - class_path: otx.core.data.transform_libs.torchvision.Resize + init_args: + scale: 224 + transform_bbox: false + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: False + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2_light.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2_light.yaml deleted file mode 100644 index f7347155b57..00000000000 --- a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2_light.yaml +++ /dev/null @@ -1,61 +0,0 @@ -model: - class_path: otx.algo.classification.efficientnet_v2.EfficientNetV2ForMultilabelCls - init_args: - label_info: 1000 - - optimizer: - class_path: torch.optim.SGD - init_args: - lr: 0.0071 - momentum: 0.9 - weight_decay: 0.0001 - - scheduler: - class_path: lightning.pytorch.cli.ReduceLROnPlateau - init_args: - mode: max - factor: 0.1 - patience: 1 - monitor: val/accuracy - -engine: - task: MULTI_LABEL_CLS - device: auto - -callback_monitor: val/accuracy - -data: ../../_base_/data/mmpretrain_base.yaml -overrides: - max_epochs: 90 - callbacks: - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - patience: 3 - data: - task: MULTI_LABEL_CLS - config: - data_format: datumaro - train_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: Resize - - direction: horizontal - prob: 0.5 - type: RandomFlip - - type: PackInputs - val_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: Resize - - type: PackInputs - test_subset: - transforms: - - type: LoadImageFromFile - - backend: cv2 - scale: 224 - type: Resize - - type: PackInputs diff --git a/src/otx/tools/converter.py b/src/otx/tools/converter.py index e0757b1524e..145ef0ff2fd 100644 --- a/src/otx/tools/converter.py +++ b/src/otx/tools/converter.py @@ -34,7 +34,7 @@ }, "Custom_Image_Classification_EfficientNet-V2-S": { "task": OTXTaskType.MULTI_CLASS_CLS, - "model_name": "efficientnet_v2_light", + "model_name": "efficientnet_v2", }, "Custom_Image_Classification_MobileNet-V3-large-1x": { "task": OTXTaskType.MULTI_CLASS_CLS, diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index c0461cd058d..aebd134534f 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -95,8 +95,11 @@ def test_predict_with_explain( if "dino" in model_name or "rtmdet_inst_tiny" in model_name: pytest.skip("DINO and Rtmdet_tiny are not supported.") + # TODO(GalyaZalesskaya): https://jira.devtools.intel.com/browse/CVS-138604 -> mobilenet_v3_large, efficientnet_v2 if "mobilenet_v3_large" in model_name: pytest.skip("There's issue with mobilenet_v3_large model. Skip for now.") + if "efficientnet_v2" in model_name: + pytest.skip("There's issue with efficientnet_v2 model. Skip for now.") if "ssd_mobilenetv2" in model_name: pytest.skip("There's issue with SSD model. Skip for now.") diff --git a/tests/unit/algo/classification/backbones/test_otx_efficientnet_v2.py b/tests/unit/algo/classification/backbones/test_timm.py similarity index 69% rename from tests/unit/algo/classification/backbones/test_otx_efficientnet_v2.py rename to tests/unit/algo/classification/backbones/test_timm.py index 8dcdeb3698a..0716e2c08cb 100644 --- a/tests/unit/algo/classification/backbones/test_otx_efficientnet_v2.py +++ b/tests/unit/algo/classification/backbones/test_timm.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from otx.algo.classification.backbones.otx_efficientnet_v2 import OTXEfficientNetV2 +from otx.algo.classification.backbones.timm import TimmBackbone class TestOTXEfficientNetV2: def test_forward(self): - model = OTXEfficientNetV2() + model = TimmBackbone(backbone="efficientnetv2_s_21k") model.init_weights() assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8]) def test_get_config_optim(self): - model = OTXEfficientNetV2() + model = TimmBackbone(backbone="efficientnetv2_s_21k") assert model.get_config_optim([0.01])[0]["lr"] == 0.01 assert model.get_config_optim(0.01)[0]["lr"] == 0.01 diff --git a/tests/unit/engine/utils/test_api.py b/tests/unit/engine/utils/test_api.py index e136058d463..d9b7e3f82af 100644 --- a/tests/unit/engine/utils/test_api.py +++ b/tests/unit/engine/utils/test_api.py @@ -29,11 +29,10 @@ def test_list_models_pattern() -> None: models = list_models(pattern="efficient") target = [ - "efficientnet_v2_light", "efficientnet_b0_light", + "efficientnet_v2", "maskrcnn_efficientnetb2b", "maskrcnn_efficientnetb2b_tile", - "otx_efficientnet_v2", "otx_efficientnet_b0", "tv_efficientnet_b0", "tv_efficientnet_b1", @@ -52,4 +51,3 @@ def test_list_models_print_table(capfd: pytest.CaptureFixture) -> None: assert "Model Name" in out assert "Recipe Path" in out assert "otx_efficientnet_b0" in out - assert "otx_efficientnet_v2" in out