From 43f1fc9b25c4a45abb1da7d1b7d39d8f6359527b Mon Sep 17 00:00:00 2001 From: Harim Kang Date: Sat, 17 Aug 2024 11:07:02 +0900 Subject: [PATCH] Refactor Classifier & Head (#3830) * Refactor classification * Add Loss scale * fix unit-tests * fix unit-tests * Fix Loss reduction settings * Revert unnecessary change * Add unit-tests * Rename OTX*Head to *Head * Fix unlabeled_coef * Fix tv model loss scale * Refactor torchvision models * Fix in_features found way * Refactor H-label side * Remove regacy code * Remove _exporter * Fix wrong config * Update docstring & Add unit-test * Update docstring 2 * Remove hard Type assign --- .../multi_class_classification.rst | 2 +- .../algo/callbacks/unlabeled_loss_warmup.py | 2 +- .../algo/classification/backbones/__init__.py | 3 +- .../classification/backbones/torchvision.py | 108 +++ .../classification/classifier/__init__.py | 3 +- .../classifier/base_classifier.py | 78 +- .../classifier/h_label_classifier.py | 145 +++ .../classifier/semi_sl_classifier.py | 71 +- src/otx/algo/classification/efficientnet.py | 24 +- src/otx/algo/classification/heads/__init__.py | 6 +- .../classification/heads/hlabel_cls_head.py | 101 --- .../algo/classification/heads/linear_head.py | 30 - .../heads/multilabel_cls_head.py | 93 +- .../algo/classification/heads/semi_sl_head.py | 51 +- .../heads/vision_transformer_head.py | 22 - src/otx/algo/classification/mobilenet_v3.py | 57 +- src/otx/algo/classification/timm_model.py | 23 +- .../algo/classification/torchvision_model.py | 835 ++++++------------ .../classification/utils/ignored_labels.py | 28 + src/otx/algo/classification/vit.py | 48 +- src/otx/core/model/classification.py | 4 +- src/otx/engine/engine.py | 14 - .../h_label_cls/tv_efficientnet_b3.yaml | 3 +- .../h_label_cls/tv_efficientnet_v2_l.yaml | 3 +- .../h_label_cls/tv_mobilenet_v3_small.yaml | 3 +- .../semisl/tv_efficientnet_b3_semisl.yaml | 3 +- .../semisl/tv_efficientnet_v2_l_semisl.yaml | 3 +- .../semisl/tv_mobilenet_v3_small_semisl.yaml | 3 +- .../multi_class_cls/tv_efficientnet_b3.yaml | 3 +- .../multi_class_cls/tv_efficientnet_v2_l.yaml | 3 +- .../tv_mobilenet_v3_small.yaml | 3 +- .../multi_label_cls/tv_efficientnet_b3.yaml | 3 +- .../multi_label_cls/tv_efficientnet_v2_l.yaml | 3 +- .../tv_mobilenet_v3_small.yaml | 3 +- .../classifier/test_base_classifier.py | 63 ++ .../classifier/test_semi_sl_classifier.py | 45 + .../heads/test_hlabel_cls_head.py | 31 - .../heads/test_multilabel_cls_head.py | 18 - .../classification/heads/test_semi_sl_head.py | 23 +- .../classification/test_torchvision_model.py | 40 +- .../utils/test_ignored_labels.py | 26 + tests/unit/engine/test_engine.py | 4 +- 42 files changed, 862 insertions(+), 1172 deletions(-) create mode 100644 src/otx/algo/classification/backbones/torchvision.py create mode 100644 src/otx/algo/classification/classifier/h_label_classifier.py create mode 100644 src/otx/algo/classification/utils/ignored_labels.py create mode 100644 tests/unit/algo/classification/classifier/test_base_classifier.py create mode 100644 tests/unit/algo/classification/classifier/test_semi_sl_classifier.py create mode 100644 tests/unit/algo/classification/utils/test_ignored_labels.py diff --git a/docs/source/guide/explanation/algorithms/classification/multi_class_classification.rst b/docs/source/guide/explanation/algorithms/classification/multi_class_classification.rst index 3b724a3ee7d..2227a38c121 100644 --- a/docs/source/guide/explanation/algorithms/classification/multi_class_classification.rst +++ b/docs/source/guide/explanation/algorithms/classification/multi_class_classification.rst @@ -110,7 +110,7 @@ There are 58 different models available from torchvision, see `TVModelType int: + """Get the in_features value from the first layer of an nn.Sequential object.""" + for layer in sequential.children(): + if isinstance(layer, nn.Linear): + return layer.in_features + if isinstance(layer, nn.Conv2d): + return layer.in_channels + # Add more conditions if needed for other layer types + msg = "No suitable layer found to extract in_features" + raise ValueError(msg) + + +class TorchvisionBackbone(nn.Module): + """TorchvisionBackbone is a class that represents a backbone model from the torchvision library.""" + + def __init__( + self, + backbone: TVModelType, + pretrained: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + + tv_model_cfg = {"name": backbone} + if pretrained: + tv_model_cfg["weights"] = get_model_weights(backbone) + net = get_model(**tv_model_cfg) + self.features = net.features + + last_layer = list(net.children())[-1] + self.in_features = get_in_features(last_layer) + + def forward(self, *args) -> torch.Tensor: + """Forward pass of the model.""" + return self.features(*args) diff --git a/src/otx/algo/classification/classifier/__init__.py b/src/otx/algo/classification/classifier/__init__.py index ac84666cdbc..f6125044743 100644 --- a/src/otx/algo/classification/classifier/__init__.py +++ b/src/otx/algo/classification/classifier/__init__.py @@ -4,6 +4,7 @@ """Head modules for OTX custom model.""" from .base_classifier import ImageClassifier +from .h_label_classifier import HLabelClassifier from .semi_sl_classifier import SemiSLClassifier -__all__ = ["ImageClassifier", "SemiSLClassifier"] +__all__ = ["ImageClassifier", "SemiSLClassifier", "HLabelClassifier"] diff --git a/src/otx/algo/classification/classifier/base_classifier.py b/src/otx/algo/classification/classifier/base_classifier.py index 3c5126824b2..b1c8daaf037 100644 --- a/src/otx/algo/classification/classifier/base_classifier.py +++ b/src/otx/algo/classification/classifier/base_classifier.py @@ -12,10 +12,13 @@ from __future__ import annotations import copy +import inspect from typing import TYPE_CHECKING import torch +from otx.algo.classification.necks.gap import GlobalAveragePooling +from otx.algo.classification.utils.ignored_labels import get_valid_label_mask from otx.algo.explain.explain_algo import ReciproCAM from otx.algo.modules.base_module import BaseModule @@ -27,30 +30,13 @@ class ImageClassifier(BaseModule): """Image classifiers for supervised classification task. Args: - backbone (dict): The backbone module. See - :mod:`mmpretrain.models.backbones`. - neck (dict, optional): The neck module to process features from - backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. - head (dict, optional): The head module to do prediction and calculate - loss from processed features. See :mod:`mmpretrain.models.heads`. + backbone (nn.Module): The backbone module. + neck (nn.Module | None): The neck module to process features from backbone. + head (nn.Module): The head module to do prediction and calculate loss from processed features. Notice that if the head is not set, almost all methods cannot be used except :meth:`extract_feat`. Defaults to None. - pretrained (str, optional): The pretrained checkpoint path, support - local path and remote path. Defaults to None. - train_cfg (dict, optional): The training setting. The acceptable - fields are: - - - augments (List[dict]): The batch augmentation methods to use. - More details can be found in - :mod:`mmpretrain.model.utils.augment`. - - probs (List[float], optional): The probability of every batch - augmentation methods. If None, choose evenly. Defaults to None. - - Defaults to None. - data_preprocessor (dict, optional): The config for preprocessing input - data. If None or no specified type, it will use - "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for - more details. Defaults to None. + loss (nn.Module): The loss module to calculate the loss. + loss_scale (float, optional): The scaling factor for the loss. Defaults to 1.0. init_cfg (dict, optional): the config to control the initialization. Defaults to None. """ @@ -60,16 +46,10 @@ def __init__( backbone: nn.Module, neck: nn.Module | None, head: nn.Module, - pretrained: str | None = None, - optimize_gap: bool = True, - mean: list[float] | None = None, - std: list[float] | None = None, - to_rgb: bool = False, + loss: nn.Module, + loss_scale: float = 1.0, init_cfg: dict | list[dict] | None = None, ): - if pretrained is not None: - init_cfg = {"type": "Pretrained", "checkpoint": pretrained} - super().__init__(init_cfg=init_cfg) self._is_init = False @@ -78,11 +58,14 @@ def __init__( self.backbone = backbone self.neck = neck self.head = head + self.loss_module = loss + self.loss_scale = loss_scale + self.is_ignored_label_loss = "valid_label_mask" in inspect.getfullargspec(self.loss_module.forward).args self.explainer = ReciproCAM( self._head_forward_fn, num_classes=head.num_classes, - optimize_gap=optimize_gap, + optimize_gap=isinstance(neck, GlobalAveragePooling), ) def forward( @@ -103,8 +86,7 @@ def forward( torch.Tensor: The output logits or loss, depending on the training mode. """ if mode == "tensor": - feats = self.extract_feat(images) - return self.head(feats) + return self.extract_feat(images, stage="head") if mode == "loss": return self.loss(images, labels, **kwargs) if mode == "predict": @@ -115,7 +97,7 @@ def forward( msg = f'Invalid mode "{mode}".' raise RuntimeError(msg) - def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | torch.Tensor: + def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> torch.Tensor: """Extract features from the input tensor with shape (N, C, ...). Args: @@ -133,10 +115,8 @@ def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | tor Defaults to "neck". Returns: - tuple | Tensor: The output of specified stage. - The output depends on detailed implementation. In general, the - output of backbone and neck is a tuple and the output of - pre_logits is a tensor. + torch.Tensor: The output of specified stage. + In general, the output of pre_logits is a tensor. """ x = self.backbone(inputs) @@ -151,7 +131,7 @@ def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | tor return self.head(x) - def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> dict: + def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: """Calculate losses from a batch of inputs and data samples. Args: @@ -161,12 +141,15 @@ def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> dict: every samples. Returns: - dict[str, Tensor]: a dictionary of loss components + torch.Tensor: loss components """ - feats = self.extract_feat(inputs) - return self.head.loss(feats, labels, **kwargs) + cls_score = self.extract_feat(inputs, stage="head") * self.loss_scale + imgs_info = kwargs.pop("imgs_info", None) + if imgs_info is not None and self.is_ignored_label_loss: + kwargs["valid_label_mask"] = get_valid_label_mask(imgs_info, self.head.num_classes).to(cls_score.device) + return self.loss_module(cls_score, labels, **kwargs) / self.loss_scale - def predict(self, inputs: torch.Tensor, **kwargs) -> list[torch.Tensor]: + def predict(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor: """Predict results from a batch of inputs. Args: @@ -206,13 +189,8 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis logits = self.head(x) pred_results = self.head._get_predictions(logits) # noqa: SLF001 - # H-Label Classification Case - if isinstance(pred_results, dict): - scores = pred_results["scores"] - preds = pred_results["labels"] - else: - scores = pred_results.unbind(0) - preds = logits.argmax(-1, keepdim=True).unbind(0) + scores = pred_results.unbind(0) + preds = logits.argmax(-1, keepdim=True).unbind(0) outputs = { "logits": logits, diff --git a/src/otx/algo/classification/classifier/h_label_classifier.py b/src/otx/algo/classification/classifier/h_label_classifier.py new file mode 100644 index 00000000000..71e740a1a76 --- /dev/null +++ b/src/otx/algo/classification/classifier/h_label_classifier.py @@ -0,0 +1,145 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Classifier for H-Label Classification.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + +import torch + +from otx.algo.classification.heads.hlabel_cls_head import HierarchicalClsHead +from otx.algo.classification.utils.ignored_labels import get_valid_label_mask + +from .base_classifier import ImageClassifier + +if TYPE_CHECKING: + from torch import nn + + +class HLabelClassifier(ImageClassifier): + """Hierarchical label classifier. + + Args: + backbone (nn.Module): Backbone network. + neck (nn.Module | None): Neck network. + head (nn.Module): Head network. + multiclass_loss (nn.Module): Multiclass loss function. + multilabel_loss (nn.Module | None, optional): Multilabel loss function. + init_cfg (dict | list[dict] | None, optional): Initialization configuration. + + Attributes: + multiclass_loss (nn.Module): Multiclass loss function. + multilabel_loss (nn.Module | None): Multilabel loss function. + is_ignored_label_loss (bool): Flag indicating if ignored label loss is used. + + Methods: + loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples. + _forward_explain(images): Perform forward pass for explanation. + """ + + def __init__( + self, + backbone: nn.Module, + neck: nn.Module | None, + head: HierarchicalClsHead, + multiclass_loss: nn.Module, + multilabel_loss: nn.Module | None = None, + init_cfg: dict | list[dict] | None = None, + ): + super().__init__( + backbone=backbone, + neck=neck, + head=head, + loss=multiclass_loss, + init_cfg=init_cfg, + ) + + self.multiclass_loss = multiclass_loss + self.multilabel_loss = None + if self.head.num_multilabel_classes > 0 and multilabel_loss is not None: + self.multilabel_loss = multilabel_loss + self.is_ignored_label_loss = "valid_label_mask" in inspect.getfullargspec(self.multilabel_loss.forward).args + + def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + labels (torch.Tensor): The annotation data of + every samples. + + Returns: + torch.Tensor: loss components + """ + cls_scores = self.extract_feat(inputs, stage="head") + loss_score = torch.tensor(0.0, device=cls_scores.device) + + # Multiclass loss + num_effective_heads_in_batch = 0 # consider the label removal case + for i in range(self.head.num_multiclass_heads): + if i not in self.head.empty_multiclass_head_indices: + head_gt = labels[:, i] + logit_range = self.head._get_head_idx_to_logits_range(i) # noqa: SLF001 + head_logits = cls_scores[:, logit_range[0] : logit_range[1]] + valid_mask = head_gt >= 0 + + head_gt = head_gt[valid_mask] + if len(head_gt) > 0: + head_logits = head_logits[valid_mask, :] + loss_score += self.multiclass_loss(head_logits, head_gt) + num_effective_heads_in_batch += 1 + + if num_effective_heads_in_batch > 0: + loss_score /= num_effective_heads_in_batch + + # Multilabel loss + if self.head.num_multilabel_classes > 0: + head_gt = labels[:, self.head.num_multiclass_heads :] + head_logits = cls_scores[:, self.head.num_single_label_classes :] + valid_mask = head_gt > 0 + head_gt = head_gt[valid_mask] + if len(head_gt) > 0 and self.multilabel_loss is not None: + head_logits = head_logits[valid_mask] + imgs_info = kwargs.pop("imgs_info", None) + if imgs_info is not None and self.is_ignored_label_loss: + valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device) + valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :] + valid_label_mask = valid_label_mask[valid_mask] + kwargs["valid_label_mask"] = valid_label_mask + loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs) + + return loss_score + + @torch.no_grad() + def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]: + from otx.algo.explain.explain_algo import feature_vector_fn + + x = self.backbone(images) + backbone_feat = x + + feature_vector = feature_vector_fn(backbone_feat) + saliency_map = self.explainer.func(backbone_feat) + + if hasattr(self, "neck") and self.neck is not None: + x = self.neck(x) + + logits = self.head(x) + pred_results = self.head._get_predictions(logits) # noqa: SLF001 + scores = pred_results["scores"] + preds = pred_results["labels"] + + outputs = { + "logits": logits, + "feature_vector": feature_vector, + "saliency_map": saliency_map, + } + + if not torch.jit.is_tracing(): + outputs["scores"] = scores + outputs["preds"] = preds + + return outputs diff --git a/src/otx/algo/classification/classifier/semi_sl_classifier.py b/src/otx/algo/classification/classifier/semi_sl_classifier.py index 8981e70a953..af35b2149d1 100644 --- a/src/otx/algo/classification/classifier/semi_sl_classifier.py +++ b/src/otx/algo/classification/classifier/semi_sl_classifier.py @@ -10,22 +10,54 @@ from __future__ import annotations import torch +from torch import nn + +from otx.algo.classification.heads.semi_sl_head import OTXSemiSLClsHead from .base_classifier import ImageClassifier class SemiSLClassifier(ImageClassifier): - """Semi-SL Classifier.""" + """Semi-SL classifier. + + Args: + backbone (nn.Module): The backbone network. + neck (nn.Module | None): The neck module. Defaults to None. + head (nn.Module): The head module. + loss (nn.Module): The loss module. + unlabeled_coef (float): The coefficient for the unlabeled loss. Defaults to 1.0. + init_cfg (dict | list[dict] | None): The initialization configuration. Defaults to None. + """ + + head: OTXSemiSLClsHead + + def __init__( + self, + backbone: nn.Module, + neck: nn.Module | None, + head: nn.Module, + loss: nn.Module, + unlabeled_coef: float = 1.0, + init_cfg: dict | list[dict] | None = None, + ): + super().__init__( + backbone=backbone, + neck=neck, + head=head, + loss=loss, + init_cfg=init_cfg, + ) + self.unlabeled_coef = unlabeled_coef def extract_feat( self, - inputs: dict[str, torch.Tensor], + inputs: dict[str, torch.Tensor] | torch.Tensor, stage: str = "neck", - ) -> dict[str, tuple | torch.Tensor] | tuple | torch.Tensor: + ) -> dict[str, torch.Tensor] | tuple[torch.Tensor] | torch.Tensor: """Extract features from the input tensor with shape (N, C, ...). Args: - inputs (dict[str, Tensor]): A batch of inputs. The shape of it should be + inputs (dict[str, torch.Tensor] | torch.Tensor): A batch of inputs. The shape of it should be ``(num_samples, num_channels, *img_shape)``. stage (str): Which stage to output the feature. Choose from: @@ -37,7 +69,7 @@ def extract_feat( Defaults to "neck". Returns: - dict[str, tuple | torch.Tensor] | tuple | Tensor: The output of specified stage. + dict[str, torch.Tensor] | tuple[torch.Tensor] | torch.Tensor: The output of specified stage. The output depends on detailed implementation. In general, the Semi-SL output is a dict of labeled feats and unlabeled feats. """ @@ -56,3 +88,32 @@ def extract_feat( x["unlabeled_strong"] = super().extract_feat(unlabeled_strong_inputs, stage) return x + + def loss(self, inputs: dict[str, torch.Tensor], labels: torch.Tensor, **kwargs) -> torch.Tensor: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + labels (torch.Tensor): The annotation data of + every samples. + + Returns: + torch.Tensor: loss components + """ + semi_inputs = self.extract_feat(inputs) + logits, labels, pseudo_label, mask = self.head.get_logits(semi_inputs, labels) + logits_x, logits_u_s = logits + num_samples = len(logits_x) + + # compute supervised loss + labeled_loss = self.loss_module(logits_x, labels).sum() / num_samples + + unlabeled_loss = torch.tensor(0.0) + num_pseudo_labels = 0 if mask is None else int(mask.sum().item()) + if len(logits_u_s) > 0 and num_pseudo_labels > 0 and mask is not None: + # compute unsupervised loss + unlabeled_loss = (self.loss_module(logits_u_s, pseudo_label) * mask).sum() / mask.sum().item() + unlabeled_loss.masked_fill_(torch.isnan(unlabeled_loss), 0.0) + + return labeled_loss + self.unlabeled_coef * unlabeled_loss diff --git a/src/otx/algo/classification/efficientnet.py b/src/otx/algo/classification/efficientnet.py index d6ab980a3a4..b96e3cedd62 100644 --- a/src/otx/algo/classification/efficientnet.py +++ b/src/otx/algo/classification/efficientnet.py @@ -13,13 +13,12 @@ from torch import Tensor, nn from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet -from otx.algo.classification.classifier.base_classifier import ImageClassifier -from otx.algo.classification.classifier.semi_sl_classifier import SemiSLClassifier +from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, - OTXSemiSLLinearClsHead, + SemiSLLinearClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -91,16 +90,15 @@ def _create_model(self) -> nn.Module: def _build_model(self, num_classes: int) -> nn.Module: backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained) neck = GlobalAveragePooling(dim=2) - loss = nn.CrossEntropyLoss(reduction="none") if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( backbone=backbone, neck=neck, - head=OTXSemiSLLinearClsHead( + head=SemiSLLinearClsHead( num_classes=num_classes, in_channels=backbone.num_features, - loss=loss, ), + loss=nn.CrossEntropyLoss(reduction="none"), ) return ImageClassifier( @@ -109,9 +107,8 @@ def _build_model(self, num_classes: int) -> nn.Module: head=LinearClsHead( num_classes=num_classes, in_channels=backbone.num_features, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=loss, ), + loss=nn.CrossEntropyLoss(), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -188,10 +185,10 @@ def _build_model(self, num_classes: int) -> nn.Module: head=MultiLabelLinearClsHead( num_classes=num_classes, in_channels=backbone.num_features, - scale=7.0, normalized=True, - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ), + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + loss_scale=7.0, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -274,16 +271,15 @@ def _build_model(self, head_config: dict) -> nn.Module: copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) - return ImageClassifier( + return HLabelClassifier( backbone=backbone, neck=nn.Identity(), head=HierarchicalCBAMClsHead( in_channels=backbone.num_features, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **copied_head_config, ), - optimize_gap=False, + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/heads/__init__.py b/src/otx/algo/classification/heads/__init__.py index a920d6782bb..0b3356cdd1b 100644 --- a/src/otx/algo/classification/heads/__init__.py +++ b/src/otx/algo/classification/heads/__init__.py @@ -6,7 +6,7 @@ from .hlabel_cls_head import HierarchicalCBAMClsHead, HierarchicalLinearClsHead, HierarchicalNonLinearClsHead from .linear_head import LinearClsHead from .multilabel_cls_head import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead -from .semi_sl_head import OTXSemiSLLinearClsHead, OTXSemiSLVisionTransformerClsHead +from .semi_sl_head import SemiSLLinearClsHead, SemiSLVisionTransformerClsHead from .vision_transformer_head import VisionTransformerClsHead __all__ = [ @@ -17,6 +17,6 @@ "HierarchicalNonLinearClsHead", "HierarchicalCBAMClsHead", "VisionTransformerClsHead", - "OTXSemiSLLinearClsHead", - "OTXSemiSLVisionTransformerClsHead", + "SemiSLLinearClsHead", + "SemiSLVisionTransformerClsHead", ] diff --git a/src/otx/algo/classification/heads/hlabel_cls_head.py b/src/otx/algo/classification/heads/hlabel_cls_head.py index b0f6cfb9711..cce90446feb 100644 --- a/src/otx/algo/classification/heads/hlabel_cls_head.py +++ b/src/otx/algo/classification/heads/hlabel_cls_head.py @@ -5,7 +5,6 @@ from __future__ import annotations -import inspect from typing import Callable, Sequence import torch @@ -13,7 +12,6 @@ from otx.algo.modules.base_module import BaseModule from otx.algo.utils.weight_init import constant_init, normal_init -from otx.core.data.entity.base import ImageInfo class HierarchicalClsHead(BaseModule): @@ -32,8 +30,6 @@ def __init__( empty_multiclass_head_indices: list[int], in_channels: int, num_classes: int, - multiclass_loss: nn.Module, - multilabel_loss: nn.Module | None = None, thr: float = 0.5, init_cfg: dict | None = None, **kwargs, @@ -52,13 +48,6 @@ def __init__( msg = "num_multiclass_head should be larger than 0" raise ValueError(msg) - self.multiclass_loss = multiclass_loss - self.multilabel_loss = None - self.is_ignored_label_loss = False - if num_multilabel_classes > 0 and multilabel_loss is not None: - self.multilabel_loss = multilabel_loss - self.is_ignored_label_loss = "valid_label_mask" in inspect.getfullargspec(self.multilabel_loss.forward).args - def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: """The process before the final classification head.""" if isinstance(feats, Sequence): @@ -72,78 +61,6 @@ def _get_head_idx_to_logits_range(self, idx: int) -> tuple[int, int]: self.head_idx_to_logits_range[str(idx)][1], ) - def loss(self, feats: tuple[torch.Tensor], labels: torch.Tensor, **kwargs) -> torch.Tensor: - """Calculate losses from the classification score. - - Args: - feats (tuple[Tensor]): The features extracted from the backbone. - Multiple stage inputs are acceptable but only the last stage - will be used to classify. The shape of every item should be - ``(num_samples, num_classes)``. - labels (torch.Tensor): The annotation data of - every samples. - **kwargs: Other keyword arguments to forward the loss module. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - cls_scores = self(feats) - - loss_score = torch.tensor(0.0, device=cls_scores.device) - - # Multiclass loss - num_effective_heads_in_batch = 0 # consider the label removal case - for i in range(self.num_multiclass_heads): - if i not in self.empty_multiclass_head_indices: - head_gt = labels[:, i] - logit_range = self._get_head_idx_to_logits_range(i) - head_logits = cls_scores[:, logit_range[0] : logit_range[1]] - valid_mask = head_gt >= 0 - - head_gt = head_gt[valid_mask] - if len(head_gt) > 0: - head_logits = head_logits[valid_mask, :] - loss_score += self.multiclass_loss(head_logits, head_gt) - num_effective_heads_in_batch += 1 - - if num_effective_heads_in_batch > 0: - loss_score /= num_effective_heads_in_batch - - # Multilabel loss - if self.num_multilabel_classes > 0: - head_gt = labels[:, self.num_multiclass_heads :] - head_logits = cls_scores[:, self.num_single_label_classes :] - valid_mask = head_gt > 0 - head_gt = head_gt[valid_mask] - if len(head_gt) > 0 and self.multilabel_loss is not None: - head_logits = head_logits[valid_mask] - imgs_info = kwargs.pop("imgs_info", None) - if imgs_info is not None and self.is_ignored_label_loss: - valid_label_mask = self.get_valid_label_mask(imgs_info).to(head_logits.device) - valid_label_mask = valid_label_mask[:, self.num_single_label_classes :] - valid_label_mask = valid_label_mask[valid_mask] - kwargs["valid_label_mask"] = valid_label_mask - loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs) - - return loss_score - - def get_valid_label_mask(self, img_metas: list[ImageInfo]) -> torch.Tensor: - """Get valid label mask using ignored_label. - - Args: - img_metas (list[ImageInfo]): The metadata of the input images. - - Returns: - torch.Tensor: The valid label mask. - """ - valid_label_mask = [] - for meta in img_metas: - mask = torch.Tensor([1 for _ in range(self.num_classes)]) - if meta.ignored_labels: - mask[meta.ignored_labels] = 0 - valid_label_mask.append(mask) - return torch.stack(valid_label_mask, dim=0) - def predict( self, feats: tuple[torch.Tensor], @@ -217,8 +134,6 @@ class HierarchicalLinearClsHead(HierarchicalClsHead): due to the label removing in_channels (int): Number of channels in the input feature map. num_classes (int): Number of total classes. - multiclass_loss (dict | None): Config of multi-class loss. - multilabel_loss (dict | None): Config of multi-label loss. thr (float | None): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. """ @@ -232,8 +147,6 @@ def __init__( empty_multiclass_head_indices: list[int], in_channels: int, num_classes: int, - multiclass_loss: nn.Module, - multilabel_loss: nn.Module | None = None, thr: float = 0.5, init_cfg: dict | None = None, **kwargs, @@ -246,8 +159,6 @@ def __init__( empty_multiclass_head_indices=empty_multiclass_head_indices, in_channels=in_channels, num_classes=num_classes, - multiclass_loss=multiclass_loss, - multilabel_loss=multilabel_loss, thr=thr, init_cfg=init_cfg, **kwargs, @@ -278,8 +189,6 @@ class HierarchicalNonLinearClsHead(HierarchicalClsHead): due to the label removing in_channels (int): Number of channels in the input feature map. num_classes (int): Number of total classes. - multiclass_loss (dict | None): Config of multi-class loss. - multilabel_loss (dict | None): Config of multi-label loss. thr (float | None): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. hid_cahnnels (int): Number of channels in the hidden feature map at the classifier. @@ -297,8 +206,6 @@ def __init__( empty_multiclass_head_indices: list[int], in_channels: int, num_classes: int, - multiclass_loss: nn.Module, - multilabel_loss: nn.Module | None = None, thr: float = 0.5, hid_channels: int = 1280, activation_callable: Callable[[], nn.Module] = nn.ReLU, @@ -314,8 +221,6 @@ def __init__( empty_multiclass_head_indices=empty_multiclass_head_indices, in_channels=in_channels, num_classes=num_classes, - multiclass_loss=multiclass_loss, - multilabel_loss=multilabel_loss, thr=thr, init_cfg=init_cfg, **kwargs, @@ -414,8 +319,6 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead): due to the label removing in_channels (int): Number of channels in the input feature map. num_classes (int): Number of total classes. - multiclass_loss (nn.Module): Config of multi-class loss. - multilabel_loss (nn.Module | None, optional): Config of multi-label loss. thr (float, optional): Predictions with scores under the thresholds are considered as negative. Defaults to 0.5. init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. @@ -431,8 +334,6 @@ def __init__( empty_multiclass_head_indices: list[int], in_channels: int, num_classes: int, - multiclass_loss: nn.Module, - multilabel_loss: nn.Module | None = None, thr: float = 0.5, init_cfg: dict | None = None, step_size: int | tuple[int, int] = 7, @@ -446,8 +347,6 @@ def __init__( empty_multiclass_head_indices=empty_multiclass_head_indices, in_channels=in_channels, num_classes=num_classes, - multiclass_loss=multiclass_loss, - multilabel_loss=multilabel_loss, thr=thr, init_cfg=init_cfg, **kwargs, diff --git a/src/otx/algo/classification/heads/linear_head.py b/src/otx/algo/classification/heads/linear_head.py index 26397a43989..055f0a2d1e4 100644 --- a/src/otx/algo/classification/heads/linear_head.py +++ b/src/otx/algo/classification/heads/linear_head.py @@ -28,9 +28,6 @@ class LinearClsHead(BaseModule): num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. - loss (dict): Config of classification loss. Defaults to - ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. - topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. cal_acc (bool): Whether to calculate accuracy during training. If you use batch augmentations like Mixup and CutMix during training, it is pointless to calculate accuracy. @@ -43,8 +40,6 @@ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, - topk: int | tuple = (1,), init_cfg: dict = {"type": "Normal", "layer": "Linear", "std": 0.01}, # noqa: B006 **kwargs, ): @@ -53,9 +48,6 @@ def __init__( self.init_cfg = copy.deepcopy(init_cfg) - self.topk = topk - self.loss_module = loss - self.in_channels = in_channels self.num_classes = num_classes @@ -72,28 +64,6 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: # The final classification head. return self.fc(feats) - ######################################################## - # Copy from mmpretrain.models.heads.cls_head.ClsHead - ######################################################## - - def loss(self, feats: tuple[torch.Tensor] | torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: - """Calculate losses from the classification score. - - Args: - feats (tuple[Tensor]): The features extracted from the backbone. - Multiple stage inputs are acceptable but only the last stage - will be used to classify. The shape of every item should be - ``(num_samples, num_classes)``. - **kwargs: Other keyword arguments to forward the loss module. - - Returns: - torch.Tensor: loss components - """ - cls_score = self(feats) - - loss = self.loss_module(cls_score, labels) - return loss.sum() / cls_score.size(0) - def predict( self, feats: tuple[torch.Tensor], diff --git a/src/otx/algo/classification/heads/multilabel_cls_head.py b/src/otx/algo/classification/heads/multilabel_cls_head.py index 2df5523d988..cc51dcd2d8d 100644 --- a/src/otx/algo/classification/heads/multilabel_cls_head.py +++ b/src/otx/algo/classification/heads/multilabel_cls_head.py @@ -11,7 +11,6 @@ from __future__ import annotations -import inspect from typing import Callable, Sequence import torch @@ -20,7 +19,6 @@ from otx.algo.modules.base_module import BaseModule from otx.algo.utils.weight_init import constant_init, normal_init -from otx.core.data.entity.base import ImageInfo class AnglularLinear(nn.Module): @@ -55,26 +53,17 @@ class MultiLabelClsHead(BaseModule): backbone network and predicts the class labels. Args: - BaseModule (class): The base module class. - - Attributes: - scale (float): The scaling factor for the classification score. - - Methods: - loss(feats, labels, **kwargs): Calculate losses from the classification score. - get_valid_label_mask(img_metas): Get valid label mask using ignored_label. - predict(feats, labels): Inference without augmentation. + num_classes (int): Number of categories. + in_channels (int): Number of channels in the input feature map. + normalized (bool): Normalize input features and weights. + init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. """ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, normalized: bool = False, - scale: float = 1.0, - thr: float | None = None, - topk: int | None = None, init_cfg: dict | None = None, **kwargs, ): @@ -83,58 +72,6 @@ def __init__( self.num_classes = num_classes self.in_channels = in_channels self.normalized = normalized - self.scale = scale - self.loss_module = loss - self.is_ignored_label_loss = "valid_label_mask" in inspect.getfullargspec(self.loss_module.forward).args - - if thr is None and topk is None: - thr = 0.5 - - self.thr = thr - self.topk = topk - - def loss(self, feats: tuple[torch.Tensor], labels: torch.Tensor, **kwargs) -> torch.Tensor: - """Calculate losses from the classification score. - - Args: - feats (tuple[Tensor]): The features extracted from the backbone. - Multiple stage inputs are acceptable but only the last stage - will be used to classify. The shape of every item should be - ``(num_samples, num_classes)``. - labels (torch.Tensor): The annotation data of - every samples. - **kwargs: Other keyword arguments to forward the loss module. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - cls_score = self(feats) * self.scale - imgs_info = kwargs.pop("imgs_info", None) - if imgs_info is not None and self.is_ignored_label_loss: - kwargs["valid_label_mask"] = self.get_valid_label_mask(imgs_info).to(cls_score.device) - loss = self.loss_module(cls_score, labels, avg_factor=cls_score.size(0), **kwargs) - return loss / self.scale - - def get_valid_label_mask(self, img_metas: list[ImageInfo]) -> torch.Tensor: - """Get valid label mask using ignored_label. - - Args: - img_metas (list[ImageInfo]): The metadata of the input images. - - Returns: - torch.Tensor: The valid label mask. - """ - valid_label_mask = [] - for meta in img_metas: - mask = torch.Tensor([1 for _ in range(self.num_classes)]) - if meta.ignored_labels: - mask[meta.ignored_labels] = 0 - valid_label_mask.append(mask) - return torch.stack(valid_label_mask, dim=0) - - # ------------------------------------------------------------------------ # - # Copy from mmpretrain.models.heads.MultiLabelClsHead - # ------------------------------------------------------------------------ # def predict(self, feats: tuple[torch.Tensor], **kwargs) -> torch.Tensor: """Inference without augmentation. @@ -179,8 +116,7 @@ class MultiLabelLinearClsHead(MultiLabelClsHead): num_classes (int): Number of categories. in_channels (int): Number of channels in the input feature map. normalized (bool): Normalize input features and weights. - scale (float): positive scale parameter. - loss (dict): Config of classification loss. + init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. """ fc: nn.Module @@ -189,22 +125,14 @@ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, normalized: bool = False, - scale: float = 1.0, - thr: float | None = None, - topk: int | None = None, init_cfg: dict | None = None, **kwargs, ): super().__init__( num_classes=num_classes, in_channels=in_channels, - loss=loss, normalized=normalized, - scale=scale, - thr=thr, - topk=topk, init_cfg=init_cfg, **kwargs, ) @@ -243,35 +171,26 @@ class MultiLabelNonLinearClsHead(MultiLabelClsHead): hid_channels (int): Number of channels in the hidden feature map. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to nn.ReLU. - scale (float): Positive scale parameter. - loss (dict): Config of classification loss. dropout (bool): Whether use the dropout or not. normalized (bool): Normalize input features and weights in the last linar layer. + init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None. """ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, hid_channels: int = 1280, activation_callable: Callable[..., nn.Module] = nn.ReLU, - scale: float = 1.0, dropout: bool = False, normalized: bool = False, - thr: float | None = None, - topk: int | None = None, init_cfg: dict | None = None, **kwargs, ): super().__init__( num_classes=num_classes, in_channels=in_channels, - loss=loss, normalized=normalized, - scale=scale, - thr=thr, - topk=topk, init_cfg=init_cfg, **kwargs, ) diff --git a/src/otx/algo/classification/heads/semi_sl_head.py b/src/otx/algo/classification/heads/semi_sl_head.py index 7d8dccdafdf..6f233846dfc 100644 --- a/src/otx/algo/classification/heads/semi_sl_head.py +++ b/src/otx/algo/classification/heads/semi_sl_head.py @@ -42,7 +42,6 @@ class OTXSemiSLClsHead(nn.Module): def __init__( self, num_classes: int, - unlabeled_coef: float = 1.0, use_dynamic_threshold: bool = True, min_threshold: float = 0.5, ): @@ -50,14 +49,12 @@ def __init__( Args: num_classes (int): The number of classes. - unlabeled_coef (float, optional): The coefficient for the unlabeled loss. Defaults to 1.0. use_dynamic_threshold (bool, optional): Whether to use a dynamic threshold for pseudo-label selection. Defaults to True. min_threshold (float, optional): The minimum threshold for pseudo-label selection. Defaults to 0.5. """ self.num_classes = num_classes - self.unlabeled_coef = unlabeled_coef self.use_dynamic_threshold = use_dynamic_threshold self.min_threshold = ( min_threshold if self.use_dynamic_threshold else 0.95 @@ -65,37 +62,6 @@ def __init__( self.num_pseudo_label = 0 self.classwise_acc = torch.ones((self.num_classes,)) * self.min_threshold - def loss( - self, - feats: dict[str, torch.Tensor] | tuple[torch.Tensor] | torch.Tensor, - labels: dict[str, torch.Tensor] | torch.Tensor, - **kwargs, - ) -> torch.Tensor: - """Computes the loss function in which unlabeled data is considered. - - Args: - feats (dict[str, Tensor] | Tensor): Input features. - labels (dict[str, Tensor] | Tensor): Target features. - **kwargs: Additional keyword arguments. - - Returns: - Tensor: The computed loss. - """ - logits, labels, pseudo_label, mask = self.get_logits(feats, labels) - logits_x, logits_u_s = logits - num_samples = len(logits_x) - - # compute supervised loss - labeled_loss = self.loss_module(logits_x, labels).sum() / num_samples - - unlabeled_loss = torch.tensor(0.0) - if len(logits_u_s) > 0 and self.num_pseudo_label > 0 and mask is not None: - # compute unsupervised loss - unlabeled_loss = (self.loss_module(logits_u_s, pseudo_label) * mask).sum() / mask.sum().item() - unlabeled_loss.masked_fill_(torch.isnan(unlabeled_loss), 0.0) - - return labeled_loss + self.unlabeled_coef * unlabeled_loss - def get_logits( self, feats: dict[str, torch.Tensor] | tuple[torch.Tensor] | torch.Tensor, @@ -159,23 +125,20 @@ def get_logits( return logits, labels, label_u, mask -class OTXSemiSLLinearClsHead(OTXSemiSLClsHead, LinearClsHead): +class SemiSLLinearClsHead(OTXSemiSLClsHead, LinearClsHead): """LinearClsHead for OTXSemiSLClsHead.""" def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, - unlabeled_coef: float = 1, use_dynamic_threshold: bool = True, min_threshold: float = 0.5, ): - LinearClsHead.__init__(self, num_classes=num_classes, in_channels=in_channels, loss=loss) + LinearClsHead.__init__(self, num_classes=num_classes, in_channels=in_channels) OTXSemiSLClsHead.__init__( self, num_classes=num_classes, - unlabeled_coef=unlabeled_coef, use_dynamic_threshold=use_dynamic_threshold, min_threshold=min_threshold, ) @@ -188,14 +151,11 @@ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, hid_channels: int = 1280, - unlabeled_coef: float = 1, use_dynamic_threshold: bool = True, min_threshold: float = 0.5, ): self.num_classes = num_classes - self.loss_module = loss self.in_channels = in_channels self.hid_channels = hid_channels @@ -210,7 +170,6 @@ def __init__( OTXSemiSLClsHead.__init__( self, num_classes=num_classes, - unlabeled_coef=unlabeled_coef, use_dynamic_threshold=use_dynamic_threshold, min_threshold=min_threshold, ) @@ -230,15 +189,13 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor: return self.classifier(feats) -class OTXSemiSLVisionTransformerClsHead(OTXSemiSLClsHead, VisionTransformerClsHead): +class SemiSLVisionTransformerClsHead(OTXSemiSLClsHead, VisionTransformerClsHead): """VisionTransformerClsHead for OTXSemiSLClsHead.""" def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, - unlabeled_coef: float = 1, use_dynamic_threshold: bool = True, min_threshold: float = 0.5, hidden_dim: int | None = None, @@ -249,7 +206,6 @@ def __init__( self, num_classes=num_classes, in_channels=in_channels, - loss=loss, hidden_dim=hidden_dim, init_cfg=init_cfg, **kwargs, @@ -257,7 +213,6 @@ def __init__( OTXSemiSLClsHead.__init__( self, num_classes=num_classes, - unlabeled_coef=unlabeled_coef, use_dynamic_threshold=use_dynamic_threshold, min_threshold=min_threshold, ) diff --git a/src/otx/algo/classification/heads/vision_transformer_head.py b/src/otx/algo/classification/heads/vision_transformer_head.py index a4b9950b260..f8ec717c367 100644 --- a/src/otx/algo/classification/heads/vision_transformer_head.py +++ b/src/otx/algo/classification/heads/vision_transformer_head.py @@ -34,15 +34,11 @@ def __init__( self, num_classes: int, in_channels: int, - loss: nn.Module, - topk: int | tuple = (1,), hidden_dim: int | None = None, init_cfg: dict = {"type": "Constant", "layer": "Linear", "val": 0}, # noqa: B006 **kwargs, ): super().__init__(init_cfg=init_cfg, **kwargs) - self.topk = topk - self.loss_module = loss self.in_channels = in_channels self.num_classes = num_classes @@ -98,24 +94,6 @@ def forward(self, feats: tuple[list[torch.Tensor]]) -> torch.Tensor: # The final classification head. return self.layers.head(pre_logits) - def loss(self, feats: tuple[torch.Tensor] | torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: - """Calculate losses from the classification score. - - Args: - feats (tuple[Tensor]): The features extracted from the backbone. - Multiple stage inputs are acceptable but only the last stage - will be used to classify. The shape of every item should be - ``(num_samples, num_classes)``. - **kwargs: Other keyword arguments to forward the loss module. - - Returns: - torch.Tensor: loss components - """ - cls_score = self(feats) - - loss = self.loss_module(cls_score, labels) - return loss.sum() / cls_score.size(0) - def predict( self, feats: tuple[torch.Tensor], diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 40f92c594a4..657dbd1b48c 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -14,12 +14,12 @@ from torch import Tensor, nn from otx.algo.classification.backbones import OTXMobileNetV3 -from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier +from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, LinearClsHead, MultiLabelNonLinearClsHead, - OTXSemiSLLinearClsHead, + SemiSLLinearClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -34,8 +34,6 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable @@ -107,17 +105,16 @@ def _create_model(self) -> nn.Module: def _build_model(self, num_classes: int) -> nn.Module: backbone = OTXMobileNetV3(mode=self.mode, input_size=self.input_size) neck = GlobalAveragePooling(dim=2) - loss = nn.CrossEntropyLoss(reduction="none") in_channels = 960 if self.mode == "large" else 576 if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( backbone=backbone, neck=neck, - head=OTXSemiSLLinearClsHead( + head=SemiSLLinearClsHead( num_classes=num_classes, in_channels=in_channels, - loss=loss, ), + loss=nn.CrossEntropyLoss(reduction="none"), ) return ImageClassifier( @@ -126,9 +123,8 @@ def _build_model(self, num_classes: int) -> nn.Module: head=LinearClsHead( num_classes=num_classes, in_channels=in_channels, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=loss, ), + loss=nn.CrossEntropyLoss(), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -203,10 +199,10 @@ def _build_model(self, num_classes: int) -> nn.Module: in_channels=960, hid_channels=1280, normalized=True, - scale=7.0, activation_callable=nn.PReLU(), - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ), + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + loss_scale=7.0, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -248,22 +244,6 @@ def _customize_outputs( labels=logits.argmax(-1, keepdim=True).unbind(0), ) - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, *self.input_size), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") @@ -335,16 +315,15 @@ def _build_model(self, head_config: dict) -> nn.Module: copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) - return ImageClassifier( + return HLabelClassifier( backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size), neck=nn.Identity(), head=HierarchicalCBAMClsHead( in_channels=960, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **copied_head_config, ), - optimize_gap=False, + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -410,22 +389,6 @@ def _convert_pred_entity_to_compute_metric( "target": torch.stack(inputs.labels), } - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, *self.input_size), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, - ) - def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py index c31be7cc474..10d435e1834 100644 --- a/src/otx/algo/classification/timm_model.py +++ b/src/otx/algo/classification/timm_model.py @@ -11,12 +11,12 @@ from torch import nn from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType -from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier +from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, LinearClsHead, MultiLabelLinearClsHead, - OTXSemiSLLinearClsHead, + SemiSLLinearClsHead, ) from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore from otx.algo.classification.necks.gap import GlobalAveragePooling @@ -92,16 +92,15 @@ def _create_model(self) -> nn.Module: def _build_model(self, num_classes: int) -> nn.Module: backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) neck = GlobalAveragePooling(dim=2) - loss = nn.CrossEntropyLoss(reduction="none") if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( backbone=backbone, neck=neck, - head=OTXSemiSLLinearClsHead( + head=SemiSLLinearClsHead( num_classes=num_classes, in_channels=backbone.num_features, - loss=loss, ), + loss=nn.CrossEntropyLoss(reduction="none"), ) return ImageClassifier( @@ -110,9 +109,8 @@ def _build_model(self, num_classes: int) -> nn.Module: head=LinearClsHead( num_classes=num_classes, in_channels=backbone.num_features, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=loss, ), + loss=nn.CrossEntropyLoss(), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -190,9 +188,9 @@ def _build_model(self, num_classes: int) -> nn.Module: num_classes=num_classes, in_channels=backbone.num_features, normalized=True, - scale=7.0, - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ), + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + loss_scale=7.0, ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: @@ -268,16 +266,15 @@ def _create_model(self) -> nn.Module: def _build_model(self, head_config: dict) -> nn.Module: backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) - return ImageClassifier( + return HLabelClassifier( backbone=backbone, neck=nn.Identity(), head=HierarchicalCBAMClsHead( in_channels=backbone.num_features, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), **head_config, ), - optimize_gap=False, + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index fd4a9e6f192..ff429a29ce7 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -5,16 +5,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from copy import deepcopy +from typing import TYPE_CHECKING, Literal import torch -from torch import Tensor, nn -from torchvision.models import get_model, get_model_weights - -from otx.algo.classification.heads import HierarchicalCBAMClsHead, MultiLabelLinearClsHead, OTXSemiSLLinearClsHead +from torch import nn + +from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType +from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier +from otx.algo.classification.heads import ( + HierarchicalCBAMClsHead, + LinearClsHead, + MultiLabelLinearClsHead, + SemiSLLinearClsHead, +) from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore -from otx.algo.explain.explain_algo import ReciproCAM, feature_vector_fn -from otx.core.data.entity.base import OTXBatchLossEntity +from otx.algo.classification.necks.gap import GlobalAveragePooling +from otx.algo.classification.utils import get_classification_layers from otx.core.data.entity.classification import ( HlabelClsBatchDataEntity, HlabelClsBatchPredEntity, @@ -23,421 +30,164 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) -from otx.core.exporter.base import OTXModelExporter -from otx.core.exporter.native import OTXNativeModelExporter -from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable -from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel +from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable +from otx.core.model.classification import ( + OTXHlabelClsModel, + OTXMulticlassClsModel, + OTXMultilabelClsModel, +) from otx.core.schedulers import LRSchedulerListCallable -from otx.core.types.export import TaskLevelExportParameters from otx.core.types.label import HLabelInfo, LabelInfoTypes -from otx.core.types.task import OTXTaskType, OTXTrainType +from otx.core.types.task import OTXTrainType if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable -CLASSIFICATION_BATCH_DATA_ENTITY = ( - MulticlassClsBatchDataEntity | MultilabelClsBatchDataEntity | HlabelClsBatchDataEntity -) -CLASSIFICATION_BATCH_PRED_ENTITY = ( - MulticlassClsBatchPredEntity | MultilabelClsBatchPredEntity | HlabelClsBatchPredEntity -) + from otx.core.metrics import MetricCallable -TVModelType = Literal[ - "alexnet", - "convnext_base", - "convnext_large", - "convnext_small", - "convnext_tiny", - "efficientnet_b0", - "efficientnet_b1", - "efficientnet_b2", - "efficientnet_b3", - "efficientnet_b4", - "efficientnet_b5", - "efficientnet_b6", - "efficientnet_b7", - "efficientnet_v2_l", - "efficientnet_v2_m", - "efficientnet_v2_s", - "googlenet", - "mobilenet_v3_large", - "mobilenet_v3_small", - "regnet_x_16gf", - "regnet_x_1_6gf", - "regnet_x_32gf", - "regnet_x_3_2gf", - "regnet_x_400mf", - "regnet_x_800mf", - "regnet_x_8gf", - "regnet_y_128gf", - "regnet_y_16gf", - "regnet_y_1_6gf", - "regnet_y_32gf", - "regnet_y_3_2gf", - "regnet_y_400mf", - "regnet_y_800mf", - "regnet_y_8gf", - "resnet101", - "resnet152", - "resnet18", - "resnet34", - "resnet50", - "resnext101_32x8d", - "resnext101_64x4d", - "resnext50_32x4d", - "swin_b", - "swin_s", - "swin_t", - "swin_v2_b", - "swin_v2_s", - "swin_v2_t", - "vgg11", - "vgg11_bn", - "vgg13", - "vgg13_bn", - "vgg16", - "vgg16_bn", - "vgg19", - "vgg19_bn", - "wide_resnet101_2", - "wide_resnet50_2", -] - - -class TVClassificationModel(nn.Module): - """TorchVision Model with Loss Computation. - - This class represents a TorchVision model with loss computation for classification tasks. - It takes a backbone model, number of classes, and an optional loss function as input. +class TVModelForMulticlassCls(OTXMulticlassClsModel): + """Torchvision model for multiclass classification. Args: - backbone (TVModelType): The backbone model to use for feature extraction. - num_classes (int): The number of classes for the classification task. - loss (Callable | None, optional): The loss function to use. - freeze_backbone (bool, optional): Whether to freeze the backbone model. Defaults to False. - task (Literal[OTXTaskType.MULTI_CLASS_CLS, OTXTaskType.MULTI_LABEL_CLS, OTXTaskType.H_LABEL_CLS], optional): - The type of classification task. - train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The type of training. - head_config (dict | None, optional): The configuration for the head module. - - Methods: - forward(images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - Performs forward pass of the model. - + label_info (LabelInfoTypes): Information about the labels. + backbone (TVModelType): Backbone model for feature extraction. + pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): Optimizer for model training. Defaults to DefaultOptimizerCallable. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Learning rate scheduler. + Defaults to DefaultSchedulerCallable. + metric (MetricCallable, optional): Metric for model evaluation. Defaults to MultiClassClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): Type of training. + Defaults to OTXTrainType.SUPERVISED. + input_size (tuple[int, int], optional): Input size of the images. Defaults to (224, 224). + + Attributes: + backbone (TVModelType): Backbone model for feature extraction. + pretrained (bool): Whether to use pretrained weights. + classification_layers (nn.ModuleDict): Classification layers for class-incremental learning. """ def __init__( self, + label_info: LabelInfoTypes, backbone: TVModelType, - num_classes: int, - loss: nn.Module, - freeze_backbone: bool = False, - task: Literal[ - OTXTaskType.MULTI_CLASS_CLS, - OTXTaskType.MULTI_LABEL_CLS, - OTXTaskType.H_LABEL_CLS, - ] = OTXTaskType.MULTI_CLASS_CLS, + pretrained: bool = True, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MultiClassClsMetricCallable, + torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, - head_config: dict | None = None, + input_size: tuple[int, int] = (224, 224), ) -> None: - super().__init__() - self.num_classes = num_classes - self.task = task - self.train_type = train_type - self.head_config = head_config if head_config else {} - - net = get_model(name=backbone, weights=get_model_weights(backbone)) - - self.backbone = nn.Sequential(*list(net.children())[:-1]) - self.use_layer_norm_2d = False - - if freeze_backbone: - for param in self.backbone.parameters(): - param.requires_grad = False - - self.softmax = nn.Softmax(dim=-1) - self.loss_module = loss - self.neck: nn.Module | None = None - self.head = self._get_head(net) - - avgpool_index = 0 - for i, layer in enumerate(self.backbone.children()): - if isinstance(layer, nn.AdaptiveAvgPool2d): - avgpool_index = i - self.feature_extractor = nn.Sequential(*list(self.backbone.children())[:avgpool_index]) - self.avgpool = nn.Sequential( - *list(self.backbone.children())[avgpool_index:], - ) # Avgpool and Dropout (if the model has it) - - self.explainer = ReciproCAM( - self._head_forward_fn, - num_classes=num_classes, - optimize_gap=True, - ) - - def _get_head(self, net: nn.Module) -> nn.Module: - """Returns the head of the model.""" - last_layer = list(net.children())[-1] - layers = [] - classifier_len = len(list(last_layer.children())) - if classifier_len >= 1: - feature_channel = list(last_layer.children())[-1].in_features - layers = list(last_layer.children())[:-1] - self.use_layer_norm_2d = layers[0].__class__.__name__ == "LayerNorm2d" - else: - feature_channel = last_layer.in_features - if self.task == OTXTaskType.MULTI_CLASS_CLS: - if self.train_type == OTXTrainType.SEMI_SUPERVISED: - self.neck = nn.Sequential(*layers) if layers else None - return OTXSemiSLLinearClsHead( - num_classes=self.num_classes, - in_channels=feature_channel, - loss=self.loss_module, - ) - if classifier_len >= 1: - return nn.Sequential(*layers, nn.Linear(feature_channel, self.num_classes)) - return nn.Linear(feature_channel, self.num_classes) - if self.task == OTXTaskType.MULTI_LABEL_CLS: - self.neck = nn.Sequential(*layers) if layers else None - return MultiLabelLinearClsHead( - num_classes=self.num_classes, - in_channels=feature_channel, - scale=7.0, - normalized=True, - loss=self.loss_module, - ) - if self.task == OTXTaskType.H_LABEL_CLS: - self.neck = nn.Sequential(*layers, nn.Identity()) if layers else None - return HierarchicalCBAMClsHead( - in_channels=feature_channel, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=self.loss_module, - **self.head_config, - step_size=1, - ) + self.backbone = backbone + self.pretrained = pretrained - msg = f"Task type {self.task} is not supported." - raise NotImplementedError(msg) + super().__init__( + label_info=label_info, + input_size=input_size, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + train_type=train_type, + ) - def forward( - self, - images: torch.Tensor | dict[str, torch.Tensor], - labels: torch.Tensor | dict[str, torch.Tensor] | None = None, - mode: str = "tensor", - **kwargs, - ) -> dict[str, tuple | torch.Tensor] | tuple | torch.Tensor: - """Performs forward pass of the model. - - Args: - images (torch.Tensor | dict[str, torch.Tensor]): The input images. - if Semi-SL task with multi-augmentation, it will comes as a dict. - labels (torch.Tensor | dict[str, torch.Tensor] | None, optional): The ground truth labels. - mode (str, optional): The mode of the forward pass. Defaults to "tensor". - - Returns: - torch.Tensor: The output logits or loss, depending on the training mode. - """ - if mode == "tensor": - return self.extract_feat(images, stage="head") - if mode == "loss": - feats = self.extract_feat(images, stage="neck") - return self.loss(feats, labels) - if mode == "predict": - feats = self.extract_feat(images, stage="neck") - return self.predict(feats, **kwargs) - if mode == "explain": - return self._forward_explain(images) - logits = self.extract_feat(images, stage="head") - return self.softmax(logits) - - def extract_feat( - self, - inputs: dict[str, torch.Tensor] | torch.Tensor, - stage: str = "neck", - ) -> dict[str, tuple | torch.Tensor] | tuple | torch.Tensor: - """Extract features from the input tensor with shape (N, C, ...). + def _create_model(self) -> nn.Module: + # Get classification_layers for class-incr learning + sample_model_dict = self._build_model(num_classes=5).state_dict() + incremental_model_dict = self._build_model(num_classes=6).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", + ) - Args: - inputs (dict[str, torch.Tensor] | torch.Tensor): A batch of inputs. The shape of it should be - ``(num_samples, num_channels, *img_shape)``. - stage (str): Which stage to output the feature. Choose from: + model = self._build_model(num_classes=self.num_classes) + model.init_weights() + return model - - "backbone": The output of backbone network. Returns a tuple - including multiple stages features. - - "neck": The output of neck module. Returns a tuple including - multiple stages features. + def _build_model(self, num_classes: int) -> nn.Module: + backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained) + neck = GlobalAveragePooling(dim=2) - Defaults to "neck". + if self.train_type == OTXTrainType.SEMI_SUPERVISED: + return SemiSLClassifier( + backbone=backbone, + neck=neck, + head=SemiSLLinearClsHead( + num_classes=num_classes, + in_channels=backbone.in_features, + ), + loss=nn.CrossEntropyLoss(reduction="none"), + ) - Returns: - dict[str, tuple | torch.Tensor] | tuple | torch.Tensor: The output of specified stage. - The output depends on detailed implementation. In general, the - output of backbone and neck is a tuple. - """ - if isinstance(inputs, dict): - return self._extract_feat_with_unlabeled(inputs, stage) + return ImageClassifier( + backbone=backbone, + neck=neck, + head=LinearClsHead( + num_classes=num_classes, + in_channels=backbone.in_features, + ), + loss=nn.CrossEntropyLoss(), + ) - x = self._flatten_outputs(self.backbone(inputs)) + def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") - if stage == "backbone": - return x + return MulticlassClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) - if self.neck is not None: - x = self.neck(x) + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - if stage == "neck": - return x + return self.model(images=image, mode="tensor") - return self.head(x) - def _extract_feat_with_unlabeled( - self, - images: dict[str, torch.Tensor], - stage: str = "neck", - ) -> dict[str, torch.Tensor]: - if "labeled" not in images or "weak_transforms" not in images or "strong_transforms" not in images: - msg = "The input dictionary should contain 'labeled', 'weak_transforms', and 'strong_transforms' keys." - raise ValueError(msg) - - labeled_inputs = images["labeled"] - unlabeled_weak_inputs = images["weak_transforms"] - unlabeled_strong_inputs = images["strong_transforms"] - - x = {} - x["labeled"] = self.extract_feat(labeled_inputs, stage) - # For weak augmentation inputs, use no_grad to use as a pseudo-label. - with torch.no_grad(): - x["unlabeled_weak"] = self.extract_feat(unlabeled_weak_inputs, stage) - x["unlabeled_strong"] = self.extract_feat(unlabeled_strong_inputs, stage) - return x - - def loss( - self, - inputs: dict[str, torch.Tensor] | tuple | torch.Tensor, - labels: dict[str, torch.Tensor] | torch.Tensor, - ) -> torch.Tensor: - """Calculates the loss of the model. - - Args: - inputs (dict[str, torch.Tensor] | tuple | torch.Tensor): The outputs of the model backbone. - labels (dict[str, torch.Tensor] | torch.Tensor): The ground truth labels. - - Returns: - torch.Tensor: The computed loss. - """ - if hasattr(self.head, "loss"): - return self.head.loss(inputs, labels) - logits = self.head(inputs) - return self.loss_module(logits, labels).sum() / logits.size(0) - - def predict(self, inputs: torch.Tensor, **kwargs) -> list[torch.Tensor]: - """Predict results from a batch of inputs. - - Args: - inputs (torch.Tensor): The input tensor with shape - (N, C, ...) in general. - **kwargs: Other keyword arguments accepted by the ``predict`` - method of :attr:`head`. - """ - if hasattr(self.head, "predict"): - return self.head.predict(inputs, **kwargs) - return self.softmax(self.head(inputs)) - - @torch.no_grad() - def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]: - backbone_feat = self.feature_extractor(images) - - saliency_map = self.explainer.func(backbone_feat) - feature_vector = feature_vector_fn(backbone_feat) - - x = self._flatten_outputs(self.avgpool(backbone_feat)) - if self.neck is not None: - x = self.neck(x) - logits = self.head(x) - pred_results = self.head.predict(x) if hasattr(self.head, "predict") else self.softmax(logits) - - outputs = { - "logits": logits, - "feature_vector": feature_vector, - "saliency_map": saliency_map, - } - - if not torch.jit.is_tracing(): - if isinstance(pred_results, dict): - outputs["scores"] = pred_results["scores"] - outputs["preds"] = pred_results["labels"] - else: - outputs["scores"] = pred_results.unbind(0) - outputs["preds"] = logits.argmax(-1, keepdim=True).unbind(0) - - return outputs - - @torch.no_grad() - def _head_forward_fn(self, x: torch.Tensor) -> torch.Tensor: - """Performs model's neck and head forward.""" - x = self._flatten_outputs(self.avgpool(x)) - if self.neck is not None: - x = self.neck(x) - return self.head(x) - - def _flatten_outputs(self, x: torch.Tensor) -> torch.Tensor: - # Flatten the output - if len(x.shape) == 4 and not self.use_layer_norm_2d: # If feats is a 4D tensor: (b, c, h, w) - x = x.view(x.size(0), -1) # Flatten the output of the backbone: (b, f) - return x - - -class OTXTVModel(OTXModel): - """OTXTVModel is that represents a TorchVision model for classification. +class TVModelForMultilabelCls(OTXMultilabelClsModel): + """Torchvision model for multilabel classification. Args: - backbone (TVModelType): The backbone architecture of the model. - label_info (LabelInfoTypes): The number of classes for classification. - optimizer (OptimizerCallable, optional): The optimizer to use for training. - Defaults to DefaultOptimizerCallable. - scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler to use. + label_info (LabelInfoTypes): Information about the labels. + backbone (TVModelType): Backbone model for feature extraction. + pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): Optimizer for model training. Defaults to DefaultOptimizerCallable. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): Learning rate scheduler. Defaults to DefaultSchedulerCallable. - metric (MetricCallable, optional): The metric to use for evaluation. Defaults to MultiClassClsMetricCallable. + metric (MetricCallable, optional): Metric for model evaluation. Defaults to MultiLabelClsMetricCallable. torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. - freeze_backbone (bool, optional): Whether to freeze the backbone model. Defaults to False. - task (Literal[OTXTaskType.MULTI_CLASS_CLS, OTXTaskType.MULTI_LABEL_CLS, OTXTaskType.H_LABEL_CLS], optional): - The type of classification task. - train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The type of training. - input_size (tuple[int, int], optional): - Model input size in the order of height and width. Defaults to (224, 224) - """ + input_size (tuple[int, int], optional): Input size of the images. Defaults to (224, 224). - model: TVClassificationModel + Attributes: + backbone (TVModelType): Backbone model for feature extraction. + pretrained (bool): Whether to use pretrained weights. + input_size (tuple[int, int]): Input size of the images. + """ def __init__( self, - backbone: TVModelType, label_info: LabelInfoTypes, + backbone: TVModelType, + pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, - freeze_backbone: bool = False, - task: Literal[ - OTXTaskType.MULTI_CLASS_CLS, - OTXTaskType.MULTI_LABEL_CLS, - OTXTaskType.H_LABEL_CLS, - ] = OTXTaskType.MULTI_CLASS_CLS, - train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, input_size: tuple[int, int] = (224, 224), ) -> None: self.backbone = backbone - self.freeze_backbone = freeze_backbone - self.task = task - - # TODO(@harimkang): Need to make it configurable. - if task == OTXTaskType.MULTI_CLASS_CLS: - metric = MultiClassClsMetricCallable - elif task == OTXTaskType.MULTI_LABEL_CLS: - metric = MultiLabelClsMetricCallable - elif task == OTXTaskType.H_LABEL_CLS: - metric = HLabelClsMetricCallble + self.pretrained = pretrained super().__init__( label_info=label_info, @@ -446,224 +196,151 @@ def __init__( metric=metric, torch_compile=torch_compile, input_size=input_size, - train_type=train_type, ) self.input_size: tuple[int, int] def _create_model(self) -> nn.Module: - if self.task == OTXTaskType.MULTI_CLASS_CLS: - loss = nn.CrossEntropyLoss(reduction="none") - else: - loss = AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum") - head_config = {} - if self.task == OTXTaskType.H_LABEL_CLS: - if not isinstance(self.label_info, HLabelInfo): - msg = "LabelInfo should be HLabelInfo for hierarchical classification." - raise ValueError(msg) - head_config = self.label_info.as_head_config_dict() - - return TVClassificationModel( - backbone=self.backbone, - num_classes=self.num_classes, - loss=loss, - freeze_backbone=self.freeze_backbone, - task=self.task, - train_type=self.train_type, - head_config=head_config, + # Get classification_layers for class-incr learning + sample_model_dict = self._build_model(num_classes=5).state_dict() + incremental_model_dict = self._build_model(num_classes=6).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", ) - def _customize_inputs(self, inputs: CLASSIFICATION_BATCH_DATA_ENTITY) -> dict[str, Any]: - if self.training: - mode = "loss" - elif self.explain_mode: - mode = "explain" - else: - mode = "predict" - - if isinstance(inputs, dict): - # When used with an unlabeled dataset, it comes in as a dict. - images = {key: inputs[key].stacked_images for key in inputs} - labels = {key: torch.cat(inputs[key].labels, dim=0) for key in inputs} - imgs_info = {key: inputs[key].imgs_info for key in inputs} - return { - "images": images, - "labels": labels, - "imgs_info": imgs_info, - "mode": mode, - } - - labels = ( - torch.cat(inputs.labels, dim=0) if self.task == OTXTaskType.MULTI_CLASS_CLS else torch.stack(inputs.labels) + model = self._build_model(num_classes=self.num_classes) + model.init_weights() + return model + + def _build_model(self, num_classes: int) -> nn.Module: + backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained) + return ImageClassifier( + backbone=backbone, + neck=GlobalAveragePooling(dim=2), + head=MultiLabelLinearClsHead( + num_classes=num_classes, + in_channels=backbone.in_features, + normalized=True, + ), + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + loss_scale=7.0, ) - return { - "images": inputs.stacked_images, - "labels": labels, - "imgs_info": inputs.imgs_info, - "mode": mode, - } - - def _customize_outputs( - self, - outputs: Any, # noqa: ANN401 - inputs: CLASSIFICATION_BATCH_DATA_ENTITY, - ) -> CLASSIFICATION_BATCH_PRED_ENTITY | OTXBatchLossEntity: - if self.training: - return OTXBatchLossEntity(loss=outputs) - - # To list, batch-wise - if self.task == OTXTaskType.H_LABEL_CLS: - # To list, batch-wise - if isinstance(outputs, dict): - scores = outputs["scores"] - labels = outputs["labels"] - else: - scores = outputs - labels = outputs.argmax(-1, keepdim=True).unbind(0) - else: - logits = outputs if isinstance(outputs, torch.Tensor) else outputs["logits"] - scores = torch.unbind(logits, 0) - labels = logits.argmax(-1, keepdim=True).unbind(0) - - entity_kwargs = { - "batch_size": inputs.batch_size, - "images": inputs.images, - "imgs_info": inputs.imgs_info, - "scores": scores, - "labels": labels, - } - - if self.task == OTXTaskType.MULTI_CLASS_CLS: - return MulticlassClsBatchPredEntity(**entity_kwargs) - if self.task == OTXTaskType.MULTI_LABEL_CLS: - return MultilabelClsBatchPredEntity(**entity_kwargs) - if self.task == OTXTaskType.H_LABEL_CLS: - return HlabelClsBatchPredEntity(**entity_kwargs) - msg = f"Task type {self.task} is not supported." - raise NotImplementedError(msg) - - @property - def _export_parameters(self) -> TaskLevelExportParameters: - """Defines parameters required to export a particular model implementation.""" - export_params = { - "model_type": "Classification", - "task_type": "classification", - "multilabel": self.task == OTXTaskType.MULTI_LABEL_CLS, - "hierarchical": self.task == OTXTaskType.H_LABEL_CLS, - } - - return super()._export_parameters.wrap(**export_params) - - @property - def _exporter(self) -> OTXModelExporter: - """Creates OTXModelExporter object that can export the model.""" - return OTXNativeModelExporter( - task_level_export_parameters=self._export_parameters, - input_size=(1, 3, *self.input_size), - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - resize_mode="standard", - pad_value=0, - swap_rgb=False, - via_onnx=False, - onnx_export_configuration=None, - output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + + def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") + + return MultilabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], ) - def training_step(self, batch: CLASSIFICATION_BATCH_DATA_ENTITY, batch_idx: int) -> Tensor: - """Performs a single training step on a batch of data. + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - Args: - batch (MulticlassClsBatchDataEntity): The input batch of data. - batch_idx (int): The index of the current batch. + return self.model(images=image, mode="tensor") - Returns: - Tensor: The computed loss for the training step. - """ - loss = super().training_step(batch, batch_idx) - # Collect metrics related to Semi-SL Training. - if self.train_type == OTXTrainType.SEMI_SUPERVISED: - self.log( - "train/unlabeled_coef", - self.model.head.unlabeled_coef, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - self.log( - "train/num_pseudo_label", - self.model.head.num_pseudo_label, - on_step=True, - on_epoch=False, - prog_bar=True, - ) - return loss - def forward_explain(self, inputs: CLASSIFICATION_BATCH_DATA_ENTITY) -> CLASSIFICATION_BATCH_PRED_ENTITY: +class TVModelForHLabelCls(OTXHlabelClsModel): + """TVModelForHLabelCls class represents a Torchvision model for hierarchical label classification. + + Args: + label_info (HLabelInfo): Information about the hierarchical labels. + backbone (TVModelType): The type of Torchvision backbone model. + pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable. Defaults to DefaultOptimizerCallable. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + Defaults to DefaultSchedulerCallable. + metric (MetricCallable, optional): The metric callable. Defaults to HLabelClsMetricCallble. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + input_size (tuple[int, int], optional): The input size of the images. Defaults to (224, 224). + + Attributes: + backbone (TVModelType): The type of Torchvision backbone model. + pretrained (bool): Whether to use pretrained weights. + classification_layers (nn.Module): The classification layers for class-incremental learning. + """ + + label_info: HLabelInfo + + def __init__( + self, + label_info: HLabelInfo, + backbone: TVModelType, + pretrained: bool = True, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = HLabelClsMetricCallble, + torch_compile: bool = False, + input_size: tuple[int, int] = (224, 224), + ) -> None: + self.backbone = backbone + self.pretrained = pretrained + + super().__init__( + label_info=label_info, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + input_size=input_size, + ) + + def _create_model(self) -> nn.Module: + # Get classification_layers for class-incr learning + sample_config = deepcopy(self.label_info.as_head_config_dict()) + sample_config["num_classes"] = 5 + sample_model_dict = self._build_model(head_config=sample_config).state_dict() + sample_config["num_classes"] = 6 + incremental_model_dict = self._build_model(head_config=sample_config).state_dict() + self.classification_layers = get_classification_layers( + sample_model_dict, + incremental_model_dict, + prefix="model.", + ) + + model = self._build_model(head_config=self.label_info.as_head_config_dict()) + model.init_weights() + return model + + def _build_model(self, head_config: dict) -> nn.Module: + backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained) + return HLabelClassifier( + backbone=backbone, + neck=nn.Identity(), + head=HierarchicalCBAMClsHead( + in_channels=backbone.in_features, + **head_config, + ), + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + ) + + def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity: """Model forward explain function.""" outputs = self.model(images=inputs.stacked_images, mode="explain") - entity_kwargs = { - "batch_size": inputs.batch_size, - "images": inputs.images, - "imgs_info": inputs.imgs_info, - "labels": outputs["preds"], - "scores": outputs["scores"], - "saliency_map": outputs["saliency_map"], - "feature_vector": outputs["feature_vector"], - } - - if self.task == OTXTaskType.MULTI_CLASS_CLS: - return MulticlassClsBatchPredEntity(**entity_kwargs) - if self.task == OTXTaskType.MULTI_LABEL_CLS: - return MultilabelClsBatchPredEntity(**entity_kwargs) - if self.task == OTXTaskType.H_LABEL_CLS: - return HlabelClsBatchPredEntity(**entity_kwargs) - msg = f"Task type {self.task} is not supported." - raise NotImplementedError(msg) - - def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + + return HlabelClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) + + def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]: """Model forward function used for the model tracing during model exportation.""" if self.explain_mode: return self.model(images=image, mode="explain") return self.model(images=image, mode="tensor") - - def _convert_pred_entity_to_compute_metric( - self, - preds: CLASSIFICATION_BATCH_PRED_ENTITY, - inputs: CLASSIFICATION_BATCH_DATA_ENTITY, - ) -> MetricInput: - if self.task == OTXTaskType.MULTI_CLASS_CLS: - pred = torch.tensor(preds.labels) - target = torch.tensor(inputs.labels) - elif self.task == OTXTaskType.MULTI_LABEL_CLS: - pred = torch.stack(preds.scores) - target = torch.stack(inputs.labels) - elif self.task == OTXTaskType.H_LABEL_CLS: - hlabel_info: HLabelInfo = self.label_info # type: ignore[assignment] - _labels = torch.stack(preds.labels) if isinstance(preds.labels, list) else preds.labels - _scores = torch.stack(preds.scores) if isinstance(preds.scores, list) else preds.scores - target = torch.stack(inputs.labels) - if hlabel_info.num_multilabel_classes > 0: - preds_multiclass = _labels[:, : hlabel_info.num_multiclass_heads] - preds_multilabel = _scores[:, hlabel_info.num_multiclass_heads :] - pred = torch.cat([preds_multiclass, preds_multilabel], dim=1) - else: - pred = _labels - return { - "preds": pred, - "target": target, - } - - def get_dummy_input(self, batch_size: int = 1) -> CLASSIFICATION_BATCH_DATA_ENTITY: - """Returns a dummy input for classification model.""" - images = [torch.rand(3, *self.input_size) for _ in range(batch_size)] - labels = [torch.LongTensor([0])] * batch_size - - if self.task == OTXTaskType.MULTI_CLASS_CLS: - return MulticlassClsBatchDataEntity(batch_size, images, [], labels=labels) - if self.task == OTXTaskType.MULTI_LABEL_CLS: - return MultilabelClsBatchDataEntity(batch_size, images, [], labels=labels) - if self.task == OTXTaskType.H_LABEL_CLS: - return HlabelClsBatchDataEntity(batch_size, images, [], labels=labels) - msg = f"Task type {self.task} is not supported." - raise NotImplementedError(msg) diff --git a/src/otx/algo/classification/utils/ignored_labels.py b/src/otx/algo/classification/utils/ignored_labels.py new file mode 100644 index 00000000000..fc19e9bd363 --- /dev/null +++ b/src/otx/algo/classification/utils/ignored_labels.py @@ -0,0 +1,28 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Util functions related to ignored_labels.""" + +from __future__ import annotations + +import torch + +from otx.core.data.entity.base import ImageInfo + + +def get_valid_label_mask(img_metas: list[ImageInfo], num_classes: int) -> torch.Tensor: + """Get valid label mask using ignored_label. + + Args: + img_metas (list[ImageInfo]): The metadata of the input images. + + Returns: + torch.Tensor: The valid label mask. + """ + valid_label_mask = [] + for meta in img_metas: + mask = torch.Tensor([1 for _ in range(num_classes)]) + if meta.ignored_labels: + mask[meta.ignored_labels] = 0 + valid_label_mask.append(mask) + return torch.stack(valid_label_mask, dim=0) diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index 3e3b5c0ff70..ffa7cebead9 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -16,11 +16,11 @@ from torch.hub import download_url_to_file from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer -from otx.algo.classification.classifier import ImageClassifier, SemiSLClassifier +from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, MultiLabelLinearClsHead, - OTXSemiSLVisionTransformerClsHead, + SemiSLVisionTransformerClsHead, VisionTransformerClsHead, ) from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore @@ -284,11 +284,11 @@ def _build_model(self, num_classes: int) -> nn.Module: return SemiSLClassifier( backbone=vit_backbone, neck=None, - head=OTXSemiSLVisionTransformerClsHead( + head=SemiSLVisionTransformerClsHead( num_classes=num_classes, in_channels=vit_backbone.embed_dim, - loss=nn.CrossEntropyLoss(reduction="none"), ), + loss=nn.CrossEntropyLoss(reduction="none"), init_cfg=init_cfg, ) @@ -298,37 +298,8 @@ def _build_model(self, num_classes: int) -> nn.Module: head=VisionTransformerClsHead( num_classes=num_classes, in_channels=vit_backbone.embed_dim, - topk=(1, 5) if num_classes >= 5 else (1,), - loss=nn.CrossEntropyLoss(reduction="none"), - ), - init_cfg=init_cfg, - ) - - -class VisionTransformerForMulticlassClsSemiSL(VisionTransformerForMulticlassCls): - """VisionTransformer model for multiclass classification with semi-supervised learning. - - This class extends the `VisionTransformerForMulticlassCls` class and adds support for semi-supervised learning. - It overrides the `_build_model` and `_customize_inputs` methods to incorporate the semi-supervised learning. - - Args: - VisionTransformerForMulticlassCls (class): The base class for VisionTransformer multiclass classification. - """ - - def _build_model(self, num_classes: int) -> nn.Module: - init_cfg = [ - {"std": 0.2, "layer": "Linear", "type": "TruncNormal"}, - {"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"}, - ] - vit_backbone = VisionTransformer(arch=self.arch, img_size=self.input_size) - return SemiSLClassifier( - backbone=vit_backbone, - neck=None, - head=OTXSemiSLVisionTransformerClsHead( - num_classes=num_classes, - in_channels=vit_backbone.embed_dim, - loss=nn.CrossEntropyLoss(reduction="none"), ), + loss=nn.CrossEntropyLoss(), init_cfg=init_cfg, ) @@ -412,8 +383,8 @@ def _build_model(self, num_classes: int) -> nn.Module: head=MultiLabelLinearClsHead( num_classes=num_classes, in_channels=vit_backbone.embed_dim, - loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), ), + loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), init_cfg=init_cfg, ) @@ -497,16 +468,15 @@ def _build_model(self, head_config: dict) -> nn.Module: {"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"}, ] vit_backbone = VisionTransformer(arch=self.arch, img_size=self.input_size, lora=self.lora) - return ImageClassifier( + return HLabelClassifier( backbone=vit_backbone, neck=None, head=HierarchicalCBAMClsHead( in_channels=vit_backbone.embed_dim, - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), step_size=1, **head_config, ), - optimize_gap=False, + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), init_cfg=init_cfg, ) diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index bed73e975e4..0a992e6cde3 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -128,10 +128,10 @@ def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) -> loss = super().training_step(batch, batch_idx) # Collect metrics related to Semi-SL Training. if self.train_type == OTXTrainType.SEMI_SUPERVISED: - if hasattr(self.model.head, "unlabeled_coef"): + if hasattr(self.model, "unlabeled_coef"): self.log( "train/unlabeled_coef", - self.model.head.unlabeled_coef, + self.model.unlabeled_coef, on_step=True, on_epoch=False, prog_bar=True, diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 0d9f3750889..4954105002c 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -45,20 +45,6 @@ from otx.core.metrics import MetricCallable -LITMODULE_PER_TASK = { - OTXTaskType.MULTI_CLASS_CLS: "otx.core.model.module.classification.OTXMulticlassClsLitModule", - OTXTaskType.MULTI_LABEL_CLS: "otx.core.model.module.classification.OTXMultilabelClsLitModule", - OTXTaskType.H_LABEL_CLS: "otx.core.model.module.classification.OTXHlabelClsLitModule", - OTXTaskType.DETECTION: "otx.core.model.module.detection.OTXDetectionLitModule", - OTXTaskType.ROTATED_DETECTION: "otx.core.model.module.rotated_detection.OTXRotatedDetLitModule", - OTXTaskType.INSTANCE_SEGMENTATION: "otx.core.model.module.instance_segmentation.OTXInstanceSegLitModule", - OTXTaskType.SEMANTIC_SEGMENTATION: "otx.core.model.module.segmentation.OTXSegmentationLitModule", - OTXTaskType.ACTION_CLASSIFICATION: "otx.core.model.module.action_classification.OTXActionClsLitModule", - OTXTaskType.VISUAL_PROMPTING: "otx.core.model.module.visual_prompting.OTXVisualPromptingLitModule", - OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: "otx.core.model.module.visual_prompting.OTXZeroShotVisualPromptingLitModule", # noqa: E501 -} - - @contextmanager def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]: """Override `OTXModel.metric_callable` to change the evaluation metric. diff --git a/src/otx/recipe/classification/h_label_cls/tv_efficientnet_b3.yaml b/src/otx/recipe/classification/h_label_cls/tv_efficientnet_b3.yaml index 2e8c7ef0223..2078c98b43b 100644 --- a/src/otx/recipe/classification/h_label_cls/tv_efficientnet_b3.yaml +++ b/src/otx/recipe/classification/h_label_cls/tv_efficientnet_b3.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForHLabelCls init_args: backbone: efficientnet_b3 label_info: 1000 - task: H_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/h_label_cls/tv_efficientnet_v2_l.yaml b/src/otx/recipe/classification/h_label_cls/tv_efficientnet_v2_l.yaml index 8f119c9db7c..0f2d7b60a6a 100644 --- a/src/otx/recipe/classification/h_label_cls/tv_efficientnet_v2_l.yaml +++ b/src/otx/recipe/classification/h_label_cls/tv_efficientnet_v2_l.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForHLabelCls init_args: backbone: efficientnet_v2_l label_info: 1000 - task: H_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/h_label_cls/tv_mobilenet_v3_small.yaml b/src/otx/recipe/classification/h_label_cls/tv_mobilenet_v3_small.yaml index d5e52d5e69e..faab071ff5d 100644 --- a/src/otx/recipe/classification/h_label_cls/tv_mobilenet_v3_small.yaml +++ b/src/otx/recipe/classification/h_label_cls/tv_mobilenet_v3_small.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForHLabelCls init_args: backbone: mobilenet_v3_small label_info: 1000 - task: H_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml index a2b3c1d73d5..f847d078c2c 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_b3_semisl.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: efficientnet_b3 label_info: 1000 - task: MULTI_CLASS_CLS train_type: SEMI_SUPERVISED optimizer: diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml index ffe7e25a99b..1fa2abcd93a 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_efficientnet_v2_l_semisl.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: efficientnet_v2_l label_info: 1000 - task: MULTI_CLASS_CLS train_type: SEMI_SUPERVISED optimizer: diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml index 824881433e4..5fd2567ef9f 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/tv_mobilenet_v3_small_semisl.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: mobilenet_v3_small label_info: 1000 - task: MULTI_CLASS_CLS train_type: SEMI_SUPERVISED optimizer: diff --git a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3.yaml b/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3.yaml index 593a4100570..f06b3b36e32 100644 --- a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3.yaml +++ b/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_b3.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: efficientnet_b3 label_info: 1000 - task: MULTI_CLASS_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l.yaml b/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l.yaml index 47da6e209b7..c72714e9433 100644 --- a/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l.yaml +++ b/src/otx/recipe/classification/multi_class_cls/tv_efficientnet_v2_l.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: efficientnet_v2_l label_info: 1000 - task: MULTI_CLASS_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small.yaml b/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small.yaml index a212f18f51c..4c6975c241a 100644 --- a/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small.yaml +++ b/src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMulticlassCls init_args: backbone: mobilenet_v3_small label_info: 1000 - task: MULTI_CLASS_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_b3.yaml b/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_b3.yaml index eb212d61153..9579f8e5e57 100644 --- a/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_b3.yaml +++ b/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_b3.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMultilabelCls init_args: backbone: efficientnet_b3 label_info: 1000 - task: MULTI_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_v2_l.yaml b/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_v2_l.yaml index 6e429e32a35..3003b26eb48 100644 --- a/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_v2_l.yaml +++ b/src/otx/recipe/classification/multi_label_cls/tv_efficientnet_v2_l.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMultilabelCls init_args: backbone: efficientnet_v2_l label_info: 1000 - task: MULTI_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_label_cls/tv_mobilenet_v3_small.yaml b/src/otx/recipe/classification/multi_label_cls/tv_mobilenet_v3_small.yaml index 073f58d8b0a..492e835ef62 100644 --- a/src/otx/recipe/classification/multi_label_cls/tv_mobilenet_v3_small.yaml +++ b/src/otx/recipe/classification/multi_label_cls/tv_mobilenet_v3_small.yaml @@ -1,9 +1,8 @@ model: - class_path: otx.algo.classification.torchvision_model.OTXTVModel + class_path: otx.algo.classification.torchvision_model.TVModelForMultilabelCls init_args: backbone: mobilenet_v3_small label_info: 1000 - task: MULTI_LABEL_CLS optimizer: class_path: torch.optim.SGD diff --git a/tests/unit/algo/classification/classifier/test_base_classifier.py b/tests/unit/algo/classification/classifier/test_base_classifier.py new file mode 100644 index 00000000000..f27c5c40e27 --- /dev/null +++ b/tests/unit/algo/classification/classifier/test_base_classifier.py @@ -0,0 +1,63 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from otx.algo.classification.backbones import OTXEfficientNet +from otx.algo.classification.classifier import ImageClassifier +from otx.algo.classification.heads import LinearClsHead, MultiLabelLinearClsHead +from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore +from otx.algo.classification.necks.gap import GlobalAveragePooling +from torch import nn + + +class TestImageClassifier: + @pytest.fixture( + params=[ + (LinearClsHead, nn.CrossEntropyLoss, "fxt_multiclass_cls_batch_data_entity"), + (MultiLabelLinearClsHead, AsymmetricAngularLossWithIgnore, "fxt_multilabel_cls_batch_data_entity"), + ], + ids=["multiclass", "multilabel"], + ) + def fxt_model_and_inputs(self, request): + head_cls, loss_cls, input_fxt_name = request.param + backbone = OTXEfficientNet(version="b0") + neck = GlobalAveragePooling(dim=2) + head = head_cls(num_classes=3, in_channels=backbone.num_features) + loss = loss_cls() + fxt_input = request.getfixturevalue(input_fxt_name) + fxt_label = ( + torch.stack(fxt_input.labels) + if isinstance(head, MultiLabelLinearClsHead) + else torch.cat(fxt_input.labels, dim=0) + ) + return ImageClassifier( + backbone=backbone, + neck=neck, + head=head, + loss=loss, + ), fxt_input.stacked_images, fxt_label + + def test_forward(self, fxt_model_and_inputs): + model, images, labels = fxt_model_and_inputs + + output = model(images, labels, mode="tensor") + assert isinstance(output, torch.Tensor) + assert output.shape == (2, 3) + + output = model(images, labels, mode="loss") + assert isinstance(output, torch.Tensor) + + output = model(images, labels, mode="predict") + assert isinstance(output, torch.Tensor) + + output = model(images, labels, mode="explain") + assert isinstance(output, dict) + assert "logits" in output + assert "scores" in output + assert "preds" in output + assert "saliency_map" in output + assert "feature_vector" in output + + with pytest.raises(RuntimeError): + model(images, labels, mode="invalid_mode") diff --git a/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py b/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py new file mode 100644 index 00000000000..8a9a4e3cde9 --- /dev/null +++ b/tests/unit/algo/classification/classifier/test_semi_sl_classifier.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from otx.algo.classification.backbones import OTXEfficientNet +from otx.algo.classification.classifier import SemiSLClassifier +from otx.algo.classification.heads import SemiSLLinearClsHead +from otx.algo.classification.necks.gap import GlobalAveragePooling + + +class TestSemiSLClassifier: + @pytest.fixture() + def fxt_semi_sl_classifier(self): + backbone = OTXEfficientNet(version="b0") + neck = GlobalAveragePooling(dim=2) + head = SemiSLLinearClsHead( + num_classes=2, + in_channels=backbone.num_features, + use_dynamic_threshold=True, + min_threshold=0.5, + ) + loss = torch.nn.CrossEntropyLoss() + return SemiSLClassifier(backbone, neck, head, loss) + + @pytest.fixture() + def fxt_inputs(self): + return { + "labeled": torch.randn(16, 3, 224, 224), + "weak_transforms": torch.randn(16, 3, 224, 224), + "strong_transforms": torch.randn(16, 3, 224, 224), + } + + def test_extract_feat(self, fxt_semi_sl_classifier, fxt_inputs): + output = fxt_semi_sl_classifier.extract_feat(fxt_inputs) + assert isinstance(output, dict) + assert "labeled" in output + assert "unlabeled_weak" in output + assert "unlabeled_strong" in output + + def test_loss(self, fxt_semi_sl_classifier, fxt_inputs): + labels = torch.randint(0, 2, (16,)) + loss = fxt_semi_sl_classifier.loss(fxt_inputs, labels) + assert isinstance(loss, torch.Tensor) + assert loss.item() >= 0.0 diff --git a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py index a32f9bb14d4..3eb1f1464da 100644 --- a/tests/unit/algo/classification/heads/test_hlabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_hlabel_cls_head.py @@ -36,22 +36,6 @@ def fxt_data_sample() -> dict: } -@pytest.fixture() -def fxt_data_sample_with_ignored_labels() -> dict: - return { - "labels": torch.ones((18, 6), dtype=torch.long), - "imgs_info": [ - ImageInfo( - img_idx=i, - ori_shape=(24, 24, 3), - img_shape=(24, 24, 3), - ignored_labels=[3], - ) - for i in range(18) - ], - } - - class TestHierarchicalLinearClsHead: @pytest.fixture() def fxt_head_attrs(self, fxt_hlabel_cifar) -> dict[str, Any]: @@ -79,21 +63,6 @@ def fxt_hlabel_cbam_head(self, fxt_head_attrs) -> nn.Module: def fxt_hlabel_head(self, request) -> nn.Module: return request.getfixturevalue(request.param) - def test_loss( - self, - fxt_hlabel_head, - fxt_data_sample, - fxt_data_sample_with_ignored_labels, - ) -> None: - dummy_input = (torch.ones((18, 24)), torch.ones((18, 24))) - result_without_ignored_labels = fxt_hlabel_head.loss(dummy_input, **fxt_data_sample) - - result_with_ignored_labels = fxt_hlabel_head.loss( - dummy_input, - **fxt_data_sample_with_ignored_labels, - ) - assert result_with_ignored_labels <= result_without_ignored_labels - def test_predict( self, fxt_hlabel_head, diff --git a/tests/unit/algo/classification/heads/test_multilabel_cls_head.py b/tests/unit/algo/classification/heads/test_multilabel_cls_head.py index 68e9ce3c3ba..14d781cd288 100644 --- a/tests/unit/algo/classification/heads/test_multilabel_cls_head.py +++ b/tests/unit/algo/classification/heads/test_multilabel_cls_head.py @@ -7,7 +7,6 @@ import pytest import torch from otx.algo.classification.heads import MultiLabelLinearClsHead, MultiLabelNonLinearClsHead -from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore from otx.core.data.entity.base import ImageInfo from torch import nn @@ -17,7 +16,6 @@ def fxt_linear_head() -> None: return MultiLabelLinearClsHead( num_classes=3, in_channels=5, - loss=AsymmetricAngularLossWithIgnore(), ) @@ -28,7 +26,6 @@ def fxt_non_linear_head() -> None: in_channels=5, hid_channels=10, activation_callable=nn.PReLU(), - loss=AsymmetricAngularLossWithIgnore(), ) @@ -71,21 +68,6 @@ class TestMultiLabelClsHead: def fxt_multilabel_head(self, request) -> nn.Module: return request.getfixturevalue(request.param) - def test_loss( - self, - fxt_multilabel_head, - fxt_data_sample, - fxt_data_sample_with_ignore_labels, - ) -> None: - dummy_input = (torch.ones((2, 5)),) - result_without_ignored_labels = fxt_multilabel_head.loss(dummy_input, **fxt_data_sample) - - result_with_ignored_labels = fxt_multilabel_head.loss( - dummy_input, - **fxt_data_sample_with_ignore_labels, - ) - assert result_with_ignored_labels <= result_without_ignored_labels - def test_predict( self, fxt_multilabel_head, diff --git a/tests/unit/algo/classification/heads/test_semi_sl_head.py b/tests/unit/algo/classification/heads/test_semi_sl_head.py index 0c36a67b09a..04f272fa020 100644 --- a/tests/unit/algo/classification/heads/test_semi_sl_head.py +++ b/tests/unit/algo/classification/heads/test_semi_sl_head.py @@ -3,40 +3,35 @@ import pytest import torch -from otx.algo.classification.heads import OTXSemiSLLinearClsHead -from torch import nn +from otx.algo.classification.heads import SemiSLLinearClsHead class TestSemiSLClsHead: @pytest.fixture() def fxt_semi_sl_head(self): """Semi-SL for Classification Head Settings.""" - return OTXSemiSLLinearClsHead( + return SemiSLLinearClsHead( num_classes=10, in_channels=10, - loss=nn.CrossEntropyLoss(reduction="none"), ) def test_build_type_error(self): """Verifies that SemiSLClsHead parameters check with TypeError.""" with pytest.raises(TypeError): - OTXSemiSLLinearClsHead( + SemiSLLinearClsHead( num_classes=[1], in_channels=10, - loss=nn.CrossEntropyLoss(reduction="none"), ) with pytest.raises(TypeError): - OTXSemiSLLinearClsHead( + SemiSLLinearClsHead( num_classes=10, in_channels=[1], - loss=nn.CrossEntropyLoss(reduction="none"), ) def test_head_initialize(self, fxt_semi_sl_head): """Verifies that SemiSLClsHead parameters check with ValueError.""" assert fxt_semi_sl_head.num_classes == 10 - assert fxt_semi_sl_head.unlabeled_coef == 1.0 assert fxt_semi_sl_head.use_dynamic_threshold assert fxt_semi_sl_head.min_threshold == 0.5 assert fxt_semi_sl_head.num_pseudo_label == 0 @@ -57,16 +52,6 @@ def fxt_head_inputs(self): }, } - def test_loss(self, fxt_semi_sl_head, fxt_head_inputs): - """Verifies that SemiSLClsHead forward function works.""" - loss = fxt_semi_sl_head.loss(**fxt_head_inputs) - # Check that the loss is always non-negative - assert loss >= 0 - - # Check that the loss is proportional to the size of the input - size = sum(v.numel() for v in fxt_head_inputs["feats"].values()) - assert loss <= size - def test_classwise_acc(self, fxt_semi_sl_head, fxt_head_inputs): """Verifies that SemiSLClsHead classwise_acc function works.""" unlabeled_batch_size = fxt_head_inputs["feats"]["unlabeled_weak"].shape[0] diff --git a/tests/unit/algo/classification/test_torchvision_model.py b/tests/unit/algo/classification/test_torchvision_model.py index 7c8dec308bf..14270fbf594 100644 --- a/tests/unit/algo/classification/test_torchvision_model.py +++ b/tests/unit/algo/classification/test_torchvision_model.py @@ -3,17 +3,22 @@ import pytest import torch -from otx.algo.classification.heads import OTXSemiSLLinearClsHead -from otx.algo.classification.torchvision_model import OTXTVModel, TVClassificationModel +from otx.algo.classification.classifier import ImageClassifier +from otx.algo.classification.heads import SemiSLLinearClsHead +from otx.algo.classification.torchvision_model import ( + TVModelForHLabelCls, + TVModelForMulticlassCls, + TVModelForMultilabelCls, +) from otx.core.data.entity.base import OTXBatchLossEntity, OTXBatchPredEntity from otx.core.data.entity.classification import MulticlassClsBatchPredEntity from otx.core.types.export import TaskLevelExportParameters -from otx.core.types.task import OTXTaskType, OTXTrainType +from otx.core.types.task import OTXTaskType @pytest.fixture() def fxt_tv_model(): - return OTXTVModel(backbone="mobilenet_v3_small", label_info=10) + return TVModelForMulticlassCls(backbone="mobilenet_v3_small", label_info=10) @pytest.fixture() @@ -25,33 +30,33 @@ def fxt_tv_model_and_data_entity( fxt_hlabel_multilabel_info, ): if request.param == OTXTaskType.MULTI_CLASS_CLS: - return OTXTVModel(backbone="mobilenet_v3_small", label_info=10), fxt_multiclass_cls_batch_data_entity + return TVModelForMulticlassCls( + backbone="mobilenet_v3_small", + label_info=10, + ), fxt_multiclass_cls_batch_data_entity if request.param == OTXTaskType.MULTI_LABEL_CLS: - return OTXTVModel( + return TVModelForMultilabelCls( backbone="mobilenet_v3_small", label_info=10, - task=OTXTaskType.MULTI_LABEL_CLS, ), fxt_multilabel_cls_batch_data_entity if request.param == OTXTaskType.H_LABEL_CLS: - return OTXTVModel( + return TVModelForHLabelCls( backbone="mobilenet_v3_small", label_info=fxt_hlabel_multilabel_info, - task=OTXTaskType.H_LABEL_CLS, ), fxt_hlabel_cls_batch_data_entity return None class TestOTXTVModel: def test_create_model(self, fxt_tv_model): - assert isinstance(fxt_tv_model.model, TVClassificationModel) + assert isinstance(fxt_tv_model.model, ImageClassifier) - semi_sl_model = OTXTVModel( + semi_sl_model = TVModelForMulticlassCls( backbone="mobilenet_v3_small", label_info=10, - train_type=OTXTrainType.SEMI_SUPERVISED, - task=OTXTaskType.MULTI_CLASS_CLS, + train_type="SEMI_SUPERVISED", ) - assert isinstance(semi_sl_model.model.head, OTXSemiSLLinearClsHead) + assert isinstance(semi_sl_model.model.head, SemiSLLinearClsHead) @pytest.mark.parametrize( "fxt_tv_model_and_data_entity", @@ -88,7 +93,7 @@ def test_export_parameters(self, fxt_tv_model): assert export_parameters.task_type == "classification" @pytest.mark.parametrize("explain_mode", [True, False]) - def test_predict_step(self, fxt_tv_model: OTXTVModel, fxt_multiclass_cls_batch_data_entity, explain_mode): + def test_predict_step(self, fxt_tv_model, fxt_multiclass_cls_batch_data_entity, explain_mode): fxt_tv_model.eval() fxt_tv_model.explain_mode = explain_mode outputs = fxt_tv_model.predict_step(batch=fxt_multiclass_cls_batch_data_entity, batch_idx=0) @@ -99,8 +104,3 @@ def test_predict_step(self, fxt_tv_model: OTXTVModel, fxt_multiclass_cls_batch_d assert outputs.feature_vector.ndim == 2 assert outputs.saliency_map.ndim == 4 assert outputs.saliency_map.shape[-2:] != torch.Size([1, 1]) - - def test_freeze_backbone(self): - freezed_model = OTXTVModel(backbone="resnet50", label_info=10, freeze_backbone=True) - for param in freezed_model.model.backbone.parameters(): - assert not param.requires_grad diff --git a/tests/unit/algo/classification/utils/test_ignored_labels.py b/tests/unit/algo/classification/utils/test_ignored_labels.py new file mode 100644 index 00000000000..c41b3413358 --- /dev/null +++ b/tests/unit/algo/classification/utils/test_ignored_labels.py @@ -0,0 +1,26 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from otx.algo.classification.utils.ignored_labels import get_valid_label_mask +from otx.core.data.entity.base import ImageInfo + + +def test_get_valid_label_mask(): + img_metas = [ + ImageInfo(ignored_labels=[2, 4], img_idx=0, img_shape=(32, 32), ori_shape=(32, 32)), + ImageInfo(ignored_labels=[1, 3, 5], img_idx=1, img_shape=(32, 32), ori_shape=(32, 32)), + ImageInfo(ignored_labels=[0], img_idx=2, img_shape=(32, 32), ori_shape=(32, 32)), + ] + num_classes = 6 + + expected_mask = torch.tensor( + [ + [1.0, 1.0, 0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + ) + + mask = get_valid_label_mask(img_metas, num_classes) + assert torch.equal(mask, expected_mask) diff --git a/tests/unit/engine/test_engine.py b/tests/unit/engine/test_engine.py index e1695e2d9a2..db52a3f871e 100644 --- a/tests/unit/engine/test_engine.py +++ b/tests/unit/engine/test_engine.py @@ -6,7 +6,7 @@ import pytest from otx.algo.classification.efficientnet import EfficientNetForMulticlassCls -from otx.algo.classification.torchvision_model import OTXTVModel +from otx.algo.classification.torchvision_model import TVModelForMulticlassCls from otx.core.model.base import OTXModel, OVModel from otx.core.types.export import OTXExportFormatType from otx.core.types.label import NullLabelInfo @@ -73,7 +73,7 @@ def test_model_init(self, tmp_path, mock_datamodule): assert engine._model.label_info.num_classes == 4321 def test_model_setter(self, fxt_engine, mocker) -> None: - assert isinstance(fxt_engine.model, OTXTVModel) + assert isinstance(fxt_engine.model, TVModelForMulticlassCls) fxt_engine.model = "efficientnet_b0" assert isinstance(fxt_engine.model, EfficientNetForMulticlassCls)