Skip to content

Commit

Permalink
Refactor Classifier & Head (#3830)
Browse files Browse the repository at this point in the history
* Refactor classification

* Add Loss scale

* fix unit-tests

* fix unit-tests

* Fix Loss reduction settings

* Revert unnecessary change

* Add unit-tests

* Rename OTX*Head to *Head

* Fix unlabeled_coef

* Fix tv model loss scale

* Refactor torchvision models

* Fix in_features found way

* Refactor H-label side

* Remove regacy code

* Remove _exporter

* Fix wrong config

* Update docstring & Add unit-test

* Update docstring 2

* Remove hard Type assign
  • Loading branch information
harimkang authored Aug 17, 2024
1 parent fa530a7 commit 43f1fc9
Show file tree
Hide file tree
Showing 42 changed files with 862 additions and 1,172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ There are 58 different models available from torchvision, see `TVModelType <http

.. code-block:: shell
(otx) ...$ otx train --model otx.algo.classification.torchvision_model.OTXTVModel --backbone {backbone_name} ...
(otx) ...$ otx train --model otx.algo.classification.torchvision_model.TVModelForMulticlassCls --backbone {backbone_name} ...
************************
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/callbacks/unlabeled_loss_warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def on_train_batch_start(
if trainer.model is None:
msg = "Model is not found in the trainer."
raise ValueError(msg)
trainer.model.model.head.unlabeled_coef = self.unlabeled_coef
trainer.model.model.unlabeled_coef = self.unlabeled_coef
self.current_step += 1
3 changes: 2 additions & 1 deletion src/otx/algo/classification/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .efficientnet import OTXEfficientNet
from .mobilenet_v3 import OTXMobileNetV3
from .timm import TimmBackbone
from .torchvision import TorchvisionBackbone
from .vision_transformer import VisionTransformer

__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3", "VisionTransformer"]
__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3", "VisionTransformer", "TorchvisionBackbone"]
108 changes: 108 additions & 0 deletions src/otx/algo/classification/backbones/torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Torchvison model's Backbone Class."""

from typing import Literal

import torch
from torch import nn
from torchvision.models import get_model, get_model_weights

TVModelType = Literal[
"alexnet",
"convnext_base",
"convnext_large",
"convnext_small",
"convnext_tiny",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_l",
"efficientnet_v2_m",
"efficientnet_v2_s",
"googlenet",
"mobilenet_v3_large",
"mobilenet_v3_small",
"regnet_x_16gf",
"regnet_x_1_6gf",
"regnet_x_32gf",
"regnet_x_3_2gf",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_8gf",
"regnet_y_128gf",
"regnet_y_16gf",
"regnet_y_1_6gf",
"regnet_y_32gf",
"regnet_y_3_2gf",
"regnet_y_400mf",
"regnet_y_800mf",
"regnet_y_8gf",
"resnet101",
"resnet152",
"resnet18",
"resnet34",
"resnet50",
"resnext101_32x8d",
"resnext101_64x4d",
"resnext50_32x4d",
"swin_b",
"swin_s",
"swin_t",
"swin_v2_b",
"swin_v2_s",
"swin_v2_t",
"vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19",
"vgg19_bn",
"wide_resnet101_2",
"wide_resnet50_2",
]


def get_in_features(sequential: nn.Sequential) -> int:
"""Get the in_features value from the first layer of an nn.Sequential object."""
for layer in sequential.children():
if isinstance(layer, nn.Linear):
return layer.in_features
if isinstance(layer, nn.Conv2d):
return layer.in_channels
# Add more conditions if needed for other layer types
msg = "No suitable layer found to extract in_features"
raise ValueError(msg)


class TorchvisionBackbone(nn.Module):
"""TorchvisionBackbone is a class that represents a backbone model from the torchvision library."""

def __init__(
self,
backbone: TVModelType,
pretrained: bool = False,
**kwargs,
):
super().__init__(**kwargs)

tv_model_cfg = {"name": backbone}
if pretrained:
tv_model_cfg["weights"] = get_model_weights(backbone)
net = get_model(**tv_model_cfg)
self.features = net.features

last_layer = list(net.children())[-1]
self.in_features = get_in_features(last_layer)

def forward(self, *args) -> torch.Tensor:
"""Forward pass of the model."""
return self.features(*args)
3 changes: 2 additions & 1 deletion src/otx/algo/classification/classifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Head modules for OTX custom model."""

from .base_classifier import ImageClassifier
from .h_label_classifier import HLabelClassifier
from .semi_sl_classifier import SemiSLClassifier

__all__ = ["ImageClassifier", "SemiSLClassifier"]
__all__ = ["ImageClassifier", "SemiSLClassifier", "HLabelClassifier"]
78 changes: 28 additions & 50 deletions src/otx/algo/classification/classifier/base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from __future__ import annotations

import copy
import inspect
from typing import TYPE_CHECKING

import torch

from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils.ignored_labels import get_valid_label_mask
from otx.algo.explain.explain_algo import ReciproCAM
from otx.algo.modules.base_module import BaseModule

Expand All @@ -27,30 +30,13 @@ class ImageClassifier(BaseModule):
"""Image classifiers for supervised classification task.
Args:
backbone (dict): The backbone module. See
:mod:`mmpretrain.models.backbones`.
neck (dict, optional): The neck module to process features from
backbone. See :mod:`mmpretrain.models.necks`. Defaults to None.
head (dict, optional): The head module to do prediction and calculate
loss from processed features. See :mod:`mmpretrain.models.heads`.
backbone (nn.Module): The backbone module.
neck (nn.Module | None): The neck module to process features from backbone.
head (nn.Module): The head module to do prediction and calculate loss from processed features.
Notice that if the head is not set, almost all methods cannot be
used except :meth:`extract_feat`. Defaults to None.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
train_cfg (dict, optional): The training setting. The acceptable
fields are:
- augments (List[dict]): The batch augmentation methods to use.
More details can be found in
:mod:`mmpretrain.model.utils.augment`.
- probs (List[float], optional): The probability of every batch
augmentation methods. If None, choose evenly. Defaults to None.
Defaults to None.
data_preprocessor (dict, optional): The config for preprocessing input
data. If None or no specified type, it will use
"ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
more details. Defaults to None.
loss (nn.Module): The loss module to calculate the loss.
loss_scale (float, optional): The scaling factor for the loss. Defaults to 1.0.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
Expand All @@ -60,16 +46,10 @@ def __init__(
backbone: nn.Module,
neck: nn.Module | None,
head: nn.Module,
pretrained: str | None = None,
optimize_gap: bool = True,
mean: list[float] | None = None,
std: list[float] | None = None,
to_rgb: bool = False,
loss: nn.Module,
loss_scale: float = 1.0,
init_cfg: dict | list[dict] | None = None,
):
if pretrained is not None:
init_cfg = {"type": "Pretrained", "checkpoint": pretrained}

super().__init__(init_cfg=init_cfg)

self._is_init = False
Expand All @@ -78,11 +58,14 @@ def __init__(
self.backbone = backbone
self.neck = neck
self.head = head
self.loss_module = loss
self.loss_scale = loss_scale
self.is_ignored_label_loss = "valid_label_mask" in inspect.getfullargspec(self.loss_module.forward).args

self.explainer = ReciproCAM(
self._head_forward_fn,
num_classes=head.num_classes,
optimize_gap=optimize_gap,
optimize_gap=isinstance(neck, GlobalAveragePooling),
)

def forward(
Expand All @@ -103,8 +86,7 @@ def forward(
torch.Tensor: The output logits or loss, depending on the training mode.
"""
if mode == "tensor":
feats = self.extract_feat(images)
return self.head(feats)
return self.extract_feat(images, stage="head")
if mode == "loss":
return self.loss(images, labels, **kwargs)
if mode == "predict":
Expand All @@ -115,7 +97,7 @@ def forward(
msg = f'Invalid mode "{mode}".'
raise RuntimeError(msg)

def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | torch.Tensor:
def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> torch.Tensor:
"""Extract features from the input tensor with shape (N, C, ...).
Args:
Expand All @@ -133,10 +115,8 @@ def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | tor
Defaults to "neck".
Returns:
tuple | Tensor: The output of specified stage.
The output depends on detailed implementation. In general, the
output of backbone and neck is a tuple and the output of
pre_logits is a tensor.
torch.Tensor: The output of specified stage.
In general, the output of pre_logits is a tensor.
"""
x = self.backbone(inputs)

Expand All @@ -151,7 +131,7 @@ def extract_feat(self, inputs: torch.Tensor, stage: str = "neck") -> tuple | tor

return self.head(x)

def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> dict:
def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
"""Calculate losses from a batch of inputs and data samples.
Args:
Expand All @@ -161,12 +141,15 @@ def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> dict:
every samples.
Returns:
dict[str, Tensor]: a dictionary of loss components
torch.Tensor: loss components
"""
feats = self.extract_feat(inputs)
return self.head.loss(feats, labels, **kwargs)
cls_score = self.extract_feat(inputs, stage="head") * self.loss_scale
imgs_info = kwargs.pop("imgs_info", None)
if imgs_info is not None and self.is_ignored_label_loss:
kwargs["valid_label_mask"] = get_valid_label_mask(imgs_info, self.head.num_classes).to(cls_score.device)
return self.loss_module(cls_score, labels, **kwargs) / self.loss_scale

def predict(self, inputs: torch.Tensor, **kwargs) -> list[torch.Tensor]:
def predict(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
"""Predict results from a batch of inputs.
Args:
Expand Down Expand Up @@ -206,13 +189,8 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis

logits = self.head(x)
pred_results = self.head._get_predictions(logits) # noqa: SLF001
# H-Label Classification Case
if isinstance(pred_results, dict):
scores = pred_results["scores"]
preds = pred_results["labels"]
else:
scores = pred_results.unbind(0)
preds = logits.argmax(-1, keepdim=True).unbind(0)
scores = pred_results.unbind(0)
preds = logits.argmax(-1, keepdim=True).unbind(0)

outputs = {
"logits": logits,
Expand Down
Loading

0 comments on commit 43f1fc9

Please sign in to comment.