diff --git a/src/otx/algo/__init__.py b/src/otx/algo/__init__.py index 29312f92f25..f17f0f80b47 100644 --- a/src/otx/algo/__init__.py +++ b/src/otx/algo/__init__.py @@ -8,6 +8,7 @@ action_classification, classification, detection, + instance_segmentation, plugins, segmentation, strategies, @@ -23,4 +24,5 @@ "strategies", "accelerators", "plugins", + "instance_segmentation", ] diff --git a/src/otx/algo/detection/deployment.py b/src/otx/algo/detection/deployment.py new file mode 100644 index 00000000000..f1d8cac9701 --- /dev/null +++ b/src/otx/algo/detection/deployment.py @@ -0,0 +1,19 @@ +"""Functions for mmdeploy adapters.""" +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import importlib + + +def is_mmdeploy_enabled() -> bool: + """Checks if the 'mmdeploy' Python module is installed and available for use. + + Returns: + bool: True if 'mmdeploy' is installed, False otherwise. + + Example: + >>> is_mmdeploy_enabled() + True + """ + return importlib.util.find_spec("mmdeploy") is not None diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index 1f56adb373c..28b7a6a0ef8 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -49,7 +49,7 @@ class AnchorHead(BaseDenseHead): def __init__( self, num_classes: int, - in_channels: tuple[int, ...], + in_channels: tuple[int, ...] | int, anchor_generator: dict, bbox_coder: dict, loss_cls: dict, diff --git a/src/otx/algo/detection/heads/base_head.py b/src/otx/algo/detection/heads/base_head.py index e2bad2827e5..6569a630d7d 100644 --- a/src/otx/algo/detection/heads/base_head.py +++ b/src/otx/algo/detection/heads/base_head.py @@ -10,12 +10,11 @@ from typing import TYPE_CHECKING import torch -from mmcv.ops import batched_nms -from mmengine.model import constant_init +from mmengine.model import BaseModule, constant_init from mmengine.structures import InstanceData -from torch import Tensor, nn +from torch import Tensor -from otx.algo.detection.ops.nms import multiclass_nms +from otx.algo.detection.ops.nms import batched_nms, multiclass_nms from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances if TYPE_CHECKING: @@ -24,7 +23,7 @@ # This class and its supporting functions below lightly adapted from the mmdet BaseDenseHead available at: # https://github.com/open-mmlab/mmdetection/blob/fe3f809a0a514189baf889aa358c498d51ee36cd/mmdet/models/dense_heads/base_dense_head.py -class BaseDenseHead(nn.Module): +class BaseDenseHead(BaseModule): """Base class for DenseHeads. 1. 
The ``init_weights`` method is used to initialize densehead's diff --git a/src/otx/algo/detection/heads/class_incremental_mixin.py b/src/otx/algo/detection/heads/class_incremental_mixin.py index cbe0c9a82d7..aca12b689b7 100644 --- a/src/otx/algo/detection/heads/class_incremental_mixin.py +++ b/src/otx/algo/detection/heads/class_incremental_mixin.py @@ -8,12 +8,13 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.utils.misc import images_to_levels, multi_apply from mmdet.registry import MODELS from torch import Tensor +from otx.algo.detection.utils.utils import images_to_levels, multi_apply + if TYPE_CHECKING: - from mmdet.utils import InstanceList, OptInstanceList + from mmengine.structures import InstanceData @MODELS.register_module() @@ -24,9 +25,9 @@ def get_atss_targets( self, anchor_list: list, valid_flag_list: list[list[Tensor]], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, unmap_outputs: bool = True, ) -> tuple: """Get targets for ATSS head. diff --git a/src/otx/algo/detection/heads/custom_anchor_generator.py b/src/otx/algo/detection/heads/custom_anchor_generator.py index bff9aa7d59b..c8f9d93de51 100644 --- a/src/otx/algo/detection/heads/custom_anchor_generator.py +++ b/src/otx/algo/detection/heads/custom_anchor_generator.py @@ -9,12 +9,13 @@ import numpy as np import torch -from mmdet.registry import TASK_UTILS +from mmengine.registry import TASK_UTILS from torch.nn.modules.utils import _pair # This class and its supporting functions below lightly adapted from the mmdet AnchorGenerator available at: # https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/prior_generators/anchor_generator.py +@TASK_UTILS.register_module() class AnchorGenerator: """Standard anchor generator for 2D anchor-based detectors. diff --git a/src/otx/algo/detection/heads/custom_ssd_head.py b/src/otx/algo/detection/heads/custom_ssd_head.py index ef167ab3890..9fc01723743 100644 --- a/src/otx/algo/detection/heads/custom_ssd_head.py +++ b/src/otx/algo/detection/heads/custom_ssd_head.py @@ -61,7 +61,7 @@ def __init__( init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict], train_cfg: ConfigDict | dict, num_classes: int = 80, - in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256), + in_channels: tuple[int, ...] 
| int = (512, 1024, 512, 256, 256, 256), stacked_convs: int = 0, feat_channels: int = 256, use_depthwise: bool = False, @@ -274,6 +274,9 @@ def _init_layers(self) -> None: self.cls_convs = nn.ModuleList() self.reg_convs = nn.ModuleList() + if isinstance(self.in_channels, int): + self.in_channels = (self.in_channels,) + for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors): if self.use_depthwise: activation_layer = nn.ReLU(inplace=True) diff --git a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py index e7571670a60..69c1fca3b92 100644 --- a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py +++ b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py @@ -6,11 +6,15 @@ import numpy as np import torch +from mmengine.registry import TASK_UTILS from torch import Tensor +from otx.algo.detection.deployment import is_mmdeploy_enabled + # This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at: # https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py +@TASK_UTILS.register_module() class DeltaXYWHBBoxCoder: """Delta XYWH BBox coder. @@ -236,21 +240,6 @@ def delta2bbox( References: .. [1] https://arxiv.org/abs/1311.2524 - - Example: - >>> rois = torch.Tensor([[ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 0., 0., 1., 1.], - >>> [ 5., 5., 5., 5.]]) - >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], - >>> [ 1., 1., 1., 1.], - >>> [ 0., 0., 2., -1.], - >>> [ 0.7, -1.9, -0.5, 0.3]]) - >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) - tensor([[0.0000, 0.0000, 1.0000, 1.0000], - [0.1409, 0.1409, 2.8591, 2.8591], - [0.0000, 0.3161, 4.1945, 0.6839], - [5.0000, 5.0000, 5.0000, 5.0000]]) """ num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 if num_bboxes == 0: @@ -426,3 +415,57 @@ def clip_bboxes( x2 = torch.clamp(x2, 0, max_shape[1]) y2 = torch.clamp(y2, 0, max_shape[0]) return x1, y1, x2, y2 + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode", + backend="default", + ) + def deltaxywhbboxcoder__decode( + self: DeltaXYWHBBoxCoder, + bboxes: Tensor, + pred_bboxes: Tensor, + max_shape: Tensor | None = None, + wh_ratio_clip: float = 16 / 1000, + ) -> Tensor: + """Rewrite `decode` of `DeltaXYWHBBoxCoder` for default backend. + + Rewrite this func to call `delta2bbox` directly. + + Args: + bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) + pred_bboxes (Tensor): Encoded offsets with respect to each roi. + Has shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H + when rois is a grid of anchors.Offset encoding follows [1]_. + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]] + and the length of max_shape should also be B. + wh_ratio_clip (float, optional): The allowed ratio between + width and height. + + Returns: + torch.Tensor: Decoded boxes. + """ + if pred_bboxes.size(0) != bboxes.size(0): + msg = "The batch size of pred_bboxes and bboxes should be equal." 
+ raise ValueError(msg) + if pred_bboxes.ndim == 3 and pred_bboxes.size(1) != bboxes.size(1): + msg = "The number of bboxes should be equal." + raise ValueError(msg) + return delta2bbox_export( + bboxes, + pred_bboxes, + self.means, + self.stds, + max_shape, + wh_ratio_clip, + self.clip_border, + self.add_ctr_clamp, + self.ctr_clamp, + ) diff --git a/src/otx/algo/detection/heads/iou2d_calculator.py b/src/otx/algo/detection/heads/iou2d_calculator.py index bad8a5ea094..214492b38eb 100644 --- a/src/otx/algo/detection/heads/iou2d_calculator.py +++ b/src/otx/algo/detection/heads/iou2d_calculator.py @@ -5,12 +5,14 @@ from __future__ import annotations import torch +from mmengine.registry import TASK_UTILS from otx.algo.detection.utils.bbox_overlaps import bbox_overlaps # This class and its supporting functions below lightly adapted from the mmdet BboxOverlaps2D available at: # https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/iou2d_calculator.py +@TASK_UTILS.register_module() class BboxOverlaps2D: """2D Overlaps (e.g. IoUs, GIoUs) Calculator.""" diff --git a/src/otx/algo/detection/heads/max_iou_assigner.py b/src/otx/algo/detection/heads/max_iou_assigner.py index f95f44585a1..d4e534c395c 100644 --- a/src/otx/algo/detection/heads/max_iou_assigner.py +++ b/src/otx/algo/detection/heads/max_iou_assigner.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Callable import torch +from mmengine.registry import TASK_UTILS from torch import Tensor from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D @@ -19,6 +20,7 @@ # This class and its supporting functions below lightly adapted from the mmdet MaxIoUAssigner available at: # https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/max_iou_assigner.py +@TASK_UTILS.register_module() class MaxIoUAssigner: """Assign a corresponding gt bbox or background to each bbox. diff --git a/src/otx/algo/detection/losses/__init__.py b/src/otx/algo/detection/losses/__init__.py index e5b9e173df0..e6bd6456246 100644 --- a/src/otx/algo/detection/losses/__init__.py +++ b/src/otx/algo/detection/losses/__init__.py @@ -2,6 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 # """Custom OTX Losses for Object Detection.""" +from .accuracy import accuracy +from .cross_entropy_loss import CrossEntropyLoss +from .cross_focal_loss import CrossSigmoidFocalLoss +from .smooth_l1_loss import L1Loss - -__all__ = ["CrossSigmoidFocalLoss, OrdinaryFocalLoss"] +__all__ = [ + "CrossEntropyLoss", + "CrossSigmoidFocalLoss", + "accuracy", + "L1Loss", +] diff --git a/src/otx/algo/detection/losses/accuracy.py b/src/otx/algo/detection/losses/accuracy.py new file mode 100644 index 00000000000..1b991c85e5f --- /dev/null +++ b/src/otx/algo/detection/losses/accuracy.py @@ -0,0 +1,73 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet Accuracy.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from torch import Tensor + + +def accuracy( + pred: Tensor, + target: Tensor, + topk: int | tuple[int] = 1, + thresh: float | None = None, +) -> list[Tensor] | Tensor: + """Calculate accuracy according to the prediction and target. 
+ + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class) + target (torch.Tensor): The target of each prediction, shape (N, ) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + if not isinstance(topk, (int, tuple)): + msg = f"topk must be int or tuple of int, got {type(topk)}" + raise TypeError(msg) + if isinstance(topk, int): + topk = (topk,) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.0) for i in range(len(topk))] + return accu[0] if return_single else accu + if pred.ndim != 2 or target.ndim != 1: + msg = "Input tensors must have 2 dims for pred and 1 dim for target" + raise ValueError(msg) + if pred.size(0) != target.size(0): + msg = "Input tensors must have the same size along the 0th dim" + raise ValueError(msg) + if maxk > pred.size(1): + msg = f"maxk {maxk} exceeds pred dimension {pred.size(1)}" + raise ValueError(msg) + pred_value, pred_label = pred.topk(maxk, dim=1) + pred_label = pred_label.t() # transpose to shape (maxk, N) + correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / pred.size(0))) + return res[0] if return_single else res diff --git a/src/otx/algo/detection/losses/cross_entropy_loss.py b/src/otx/algo/detection/losses/cross_entropy_loss.py index 81c3be1a1b1..76a6757d57d 100644 --- a/src/otx/algo/detection/losses/cross_entropy_loss.py +++ b/src/otx/algo/detection/losses/cross_entropy_loss.py @@ -5,6 +5,7 @@ from __future__ import annotations import torch +from mmengine.registry import MODELS from torch import nn from otx.algo.detection.losses.weighted_loss import weight_reduce_loss @@ -181,6 +182,7 @@ def mask_cross_entropy( )[None] +@MODELS.register_module() class CrossEntropyLoss(nn.Module): """Base Cross Entropy Loss implementation from mmdet.""" diff --git a/src/otx/algo/detection/losses/cross_focal_loss.py b/src/otx/algo/detection/losses/cross_focal_loss.py index abb0668109a..ca47788adcf 100644 --- a/src/otx/algo/detection/losses/cross_focal_loss.py +++ b/src/otx/algo/detection/losses/cross_focal_loss.py @@ -6,12 +6,13 @@ from __future__ import annotations import torch -import torch.nn.functional as F # noqa: N812 -from mmdet.models.losses.focal_loss import py_sigmoid_focal_loss, sigmoid_focal_loss -from mmdet.registry import MODELS +import torch.nn.functional +from mmengine.registry import MODELS from torch import Tensor, nn from torch.cuda.amp import custom_fwd +from otx.algo.detection.losses.focal_loss import py_sigmoid_focal_loss, sigmoid_focal_loss + def cross_sigmoid_focal_loss( inputs: Tensor, @@ -39,7 +40,7 @@ def cross_sigmoid_focal_loss( calculate_loss_func = sigmoid_focal_loss else: inputs_size = inputs.size(1) - targets = F.one_hot(targets, 
num_classes=inputs_size + 1) + targets = torch.nn.functional.one_hot(targets, num_classes=inputs_size + 1) targets = targets[:, :inputs_size] calculate_loss_func = py_sigmoid_focal_loss @@ -109,41 +110,3 @@ def forward( avg_factor=avg_factor, valid_label_mask=valid_label_mask, ) - - -@MODELS.register_module() -class OrdinaryFocalLoss(nn.Module): - """Focal loss without balancing.""" - - def __init__(self, gamma: float = 1.5, **kwargs): - super().__init__() - if gamma < 0: - msg = f"{gamma} is not valid number for gamma." - raise ValueError(msg) - self.gamma = gamma - - def forward( - self, - inputs: Tensor, - targets: Tensor, - label_weights: Tensor | None = None, - avg_factor: float | None = None, - reduction: str = "mean", - **kwargs, - ) -> Tensor: - """Forward function for focal loss.""" - if targets.numel() == 0: - return 0.0 * inputs.sum() - - cross_entropy_value = F.cross_entropy(inputs, targets, reduction="none") - p = torch.exp(-cross_entropy_value) - loss = (1 - p) ** self.gamma * cross_entropy_value - if label_weights is not None: - loss = loss * label_weights - if avg_factor is None: - avg_factor = targets.shape[0] - if reduction == "sum": - return loss.sum() - if reduction == "mean": - return loss.sum() / avg_factor - return loss diff --git a/src/otx/algo/detection/losses/focal_loss.py b/src/otx/algo/detection/losses/focal_loss.py new file mode 100644 index 00000000000..075c62fdbb1 --- /dev/null +++ b/src/otx/algo/detection/losses/focal_loss.py @@ -0,0 +1,125 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ +"""MMDet Focal Loss.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional + +# TODO(Eugene): replace mmcv.sigmoid_focal_loss with torchvision +# https://github.com/openvinotoolkit/training_extensions/pull/3281 +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from otx.algo.detection.losses.weighted_loss import weight_reduce_loss + +if TYPE_CHECKING: + from torch import Tensor + + +# This method is only for debugging +def py_sigmoid_focal_loss( + pred: Tensor, + target: Tensor, + weight: None | Tensor = None, + gamma: float = 2.0, + alpha: float = 0.25, + reduction: str = "mean", + avg_factor: int | None = None, +) -> torch.Tensor: + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+ """ + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + # Actually, pt here denotes (1 - pt) in the Focal Loss paper + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + # Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma) + focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma) + loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + if weight.numel() != loss.numel(): + msg = "The number of elements in weight should be equal to the number of elements in loss." + raise ValueError(msg) + weight = weight.view(loss.size(0), -1) + if weight.ndim != loss.ndim: + msg = "The number of dimensions in weight should be equal to the number of dimensions in loss." + raise ValueError(msg) + return weight_reduce_loss(loss, weight, reduction, avg_factor) + + +def sigmoid_focal_loss( + pred: Tensor, + target: Tensor, + weight: None | Tensor = None, + gamma: float = 2.0, + alpha: float = 0.25, + reduction: str = "mean", + avg_factor: int | None = None, +) -> torch.Tensor: + r"""A wrapper of cuda version `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, "none") + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + if weight.numel() != loss.numel(): + msg = "The number of elements in weight should be equal to the number of elements in loss." + raise ValueError(msg) + weight = weight.view(loss.size(0), -1) + if weight.ndim != loss.ndim: + msg = "The number of dimensions in weight should be equal to the number of dimensions in loss." 
+ raise ValueError(msg) + return weight_reduce_loss(loss, weight, reduction, avg_factor) diff --git a/src/otx/algo/detection/losses/smooth_l1_loss.py b/src/otx/algo/detection/losses/smooth_l1_loss.py new file mode 100644 index 00000000000..5322a238d66 --- /dev/null +++ b/src/otx/algo/detection/losses/smooth_l1_loss.py @@ -0,0 +1,84 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""Smooth L1 Loss.""" +from __future__ import annotations + +import torch +from mmengine.registry import MODELS +from torch import Tensor, nn + +from otx.algo.detection.losses.weighted_loss import weighted_loss + + +@weighted_loss +def l1_loss(pred: Tensor, target: Tensor) -> Tensor: + """L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + + Returns: + Tensor: Calculated loss + """ + if target.numel() == 0: + return pred.sum() * 0 + + if pred.size() != target.size(): + msg = f"pred and target should be in the same size, but got {pred.size()} and {target.size()}" + raise ValueError(msg) + return torch.abs(pred - target) + + +@MODELS.register_module() +class L1Loss(nn.Module): + """L1 loss. + + Args: + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of loss. + """ + + def __init__(self, reduction: str = "mean", loss_weight: float = 1.0) -> None: + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward( + self, + pred: Tensor, + target: Tensor, + weight: Tensor | None = None, + avg_factor: int | None = None, + reduction_override: str | None = None, + ) -> Tensor: + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + + Returns: + Tensor: Calculated loss + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() + if reduction_override not in (None, "none", "mean", "sum"): + msg = f"Unsupported reduction method {reduction_override}" + raise NotImplementedError(msg) + reduction = reduction_override if reduction_override else self.reduction + return self.loss_weight * l1_loss(pred, target, weight, reduction=reduction, avg_factor=avg_factor) diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py index 5a869cde8e6..34d4b4ca271 100644 --- a/src/otx/algo/detection/utils/utils.py +++ b/src/otx/algo/detection/utils/utils.py @@ -5,14 +5,12 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Callable +from typing import Callable import torch +from mmengine.structures import InstanceData from torch import Tensor -if TYPE_CHECKING: - from mmengine.structures import InstanceData - # Methods below come from mmdet.utils and slightly modified. 
# https://github.com/open-mmlab/mmdetection/blob/3.x/mmdet/models/utils/misc.py @@ -210,3 +208,77 @@ def unpack_gt_instances(batch_data_samples: list[InstanceData]) -> tuple: batch_gt_instances_ignore.append(None) return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas + + +def empty_instances( + batch_img_metas: list[dict], + device: torch.device, + task_type: str, + instance_results: list[InstanceData] | None = None, + mask_thr_binary: int | float = 0, + num_classes: int = 80, + score_per_cls: bool = False, +) -> list[InstanceData]: + """Handle predicted instances when RoI is empty. + + Note: If ``instance_results`` is not None, it will be modified + in place internally, and then return ``instance_results`` + + Args: + batch_img_metas (list[dict]): List of image information. + device (torch.device): Device of tensor. + task_type (str): Expected returned task type. it currently + supports bbox and mask. + instance_results (list[:obj:`InstanceData`]): List of instance + results. + mask_thr_binary (int, float): mask binarization threshold. + Defaults to 0. + box_type (str or type): The empty box type. Defaults to `hbox`. + use_box_type (bool): Whether to warp boxes with the box type. + Defaults to False. + num_classes (int): num_classes of bbox_head. Defaults to 80. + score_per_cls (bool): Whether to generate classwise score for + the empty instance. ``score_per_cls`` will be True when the model + needs to produce raw results without nms. Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + """ + if task_type not in ("bbox", "mask"): + msg = f"Only support bbox and mask, but got {task_type}" + raise ValueError(msg) + + if instance_results is not None and len(instance_results) != len(batch_img_metas): + msg = "The length of instance_results should be the same as batch_img_metas" + raise ValueError(msg) + + results_list = [] + for img_id in range(len(batch_img_metas)): + if instance_results is not None: + results = instance_results[img_id] + if not isinstance(results, InstanceData): + msg = f"instance_results should be InstanceData, but got {type(results)}" + raise TypeError(msg) + else: + results = InstanceData() + + if task_type == "bbox": + bboxes = torch.zeros(0, 4, device=device) + results.bboxes = bboxes + score_shape = (0, num_classes + 1) if score_per_cls else (0,) + results.scores = torch.zeros(score_shape, device=device) + results.labels = torch.zeros((0,), device=device, dtype=torch.long) + else: + img_h, img_w = batch_img_metas[img_id]["ori_shape"][:2] + # the type of `im_mask` will be torch.bool or torch.uint8, + # where uint8 if for visualization and debugging. + im_mask = torch.zeros( + 0, + img_h, + img_w, + device=device, + dtype=torch.bool if mask_thr_binary >= 0 else torch.uint8, + ) + results.masks = im_mask + results_list.append(results) + return results_list diff --git a/src/otx/algo/instance_segmentation/__init__.py b/src/otx/algo/instance_segmentation/__init__.py index 5fd14c6c6f6..b22e4e26ca9 100644 --- a/src/otx/algo/instance_segmentation/__init__.py +++ b/src/otx/algo/instance_segmentation/__init__.py @@ -3,6 +3,6 @@ # """Module for OTX instance segmentation models.""" -from . import heads +from . 
import heads, mmdet -__all__ = ["heads"] +__all__ = ["heads", "mmdet"] diff --git a/src/otx/algo/instance_segmentation/heads/__init__.py b/src/otx/algo/instance_segmentation/heads/__init__.py index 17066a70e1f..7d68c6dcb03 100644 --- a/src/otx/algo/instance_segmentation/heads/__init__.py +++ b/src/otx/algo/instance_segmentation/heads/__init__.py @@ -3,7 +3,6 @@ # """Custom head architecture for OTX instance segmentation models.""" -from .custom_roi_head import CustomConvFCBBoxHead, CustomRoIHead from .custom_rtmdet_ins_head import CustomRTMDetInsSepBNHead -__all__ = ["CustomRoIHead", "CustomConvFCBBoxHead", "CustomRTMDetInsSepBNHead"] +__all__ = ["CustomRTMDetInsSepBNHead"] diff --git a/src/otx/algo/instance_segmentation/heads/custom_roi_head.py b/src/otx/algo/instance_segmentation/heads/custom_roi_head.py deleted file mode 100644 index a5699c6737f..00000000000 --- a/src/otx/algo/instance_segmentation/heads/custom_roi_head.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -"""Custom ROI head for OTX template.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch -from mmdet.models.losses import accuracy -from mmdet.models.roi_heads.bbox_heads.convfc_bbox_head import Shared2FCBBoxHead -from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead -from mmdet.models.utils import multi_apply, unpack_gt_instances -from mmdet.registry import MODELS -from mmdet.structures.bbox import bbox2roi -from torch import Tensor - -from otx.algo.detection.heads.class_incremental_mixin import ( - ClassIncrementalMixin, -) -from otx.algo.detection.losses.cross_focal_loss import ( - CrossSigmoidFocalLoss, -) - -if TYPE_CHECKING: - from mmdet.models.task_modules.samplers import SamplingResult - from mmdet.structures import DetDataSample - from mmdet.utils import InstanceList - from mmengine.config import ConfigDict - - -@MODELS.register_module() -class CustomRoIHead(StandardRoIHead): - """CustomRoIHead class for OTX.""" - - def loss(self, x: tuple[Tensor], rpn_results_list: InstanceList, batch_data_samples: list[DetDataSample]) -> dict: - """Perform forward propagation and loss calculation of the detection roi on the features. - - Args: - x (tuple[Tensor]): list of multi-level img features. - rpn_results_list (list[:obj:`InstanceData`]): list of region - proposals. - batch_data_samples (list[:obj:`DetDataSample`]): The batch - data samples. It usually includes information such - as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
- - Returns: - dict[str, Tensor]: A dictionary of loss components - """ - outputs = unpack_gt_instances(batch_data_samples) - batch_gt_instances, batch_gt_instances_ignore, batch_img_metas = outputs - - # assign gts and sample proposals - num_imgs = len(batch_data_samples) - sampling_results = [] - for i in range(num_imgs): - # rename rpn_results.bboxes to rpn_results.priors - rpn_results = rpn_results_list[i] - rpn_results.priors = rpn_results.pop("bboxes") - - assign_result = self.bbox_assigner.assign(rpn_results, batch_gt_instances[i], batch_gt_instances_ignore[i]) - sampling_result = self.bbox_sampler.sample( - assign_result, - rpn_results, - batch_gt_instances[i], - feats=[lvl_feat[i][None] for lvl_feat in x], - ) - sampling_results.append(sampling_result) - - losses = {} - # bbox head loss - if self.with_bbox: - bbox_results = self.bbox_loss(x, sampling_results, batch_img_metas) - losses.update(bbox_results["loss_bbox"]) - - # mask head forward and loss - if self.with_mask: - mask_results = self.mask_loss(x, sampling_results, bbox_results["bbox_feats"], batch_gt_instances) - losses.update(mask_results["loss_mask"]) - - return losses - - def bbox_loss(self, x: tuple[Tensor], sampling_results: list[SamplingResult], batch_img_metas: list[dict]) -> dict: - """Perform forward propagation and loss calculation of the bbox head on the features of the upstream network. - - Args: - x (tuple[Tensor]): list of multi-level img features. - sampling_results (list["obj:`SamplingResult`]): Sampling results. - batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. - - Returns: - dict[str, Tensor]: Usually returns a dictionary with keys: - - - `cls_score` (Tensor): Classification scores. - - `bbox_pred` (Tensor): Box energies / deltas. - - `bbox_feats` (Tensor): Extract bbox RoI features. - - `loss_bbox` (dict): A dictionary of bbox loss components. - """ - rois = bbox2roi([res.bboxes for res in sampling_results]) - bbox_results = self._bbox_forward(x, rois) - - bbox_loss_and_target = self.bbox_head.loss_and_target( - cls_score=bbox_results["cls_score"], - bbox_pred=bbox_results["bbox_pred"], - rois=rois, - sampling_results=sampling_results, - rcnn_train_cfg=self.train_cfg, - batch_img_metas=batch_img_metas, - ) - bbox_results.update(loss_bbox=bbox_loss_and_target["loss_bbox"]) - - return bbox_results - - -@MODELS.register_module() -class CustomConvFCBBoxHead(Shared2FCBBoxHead, ClassIncrementalMixin): - """CustomConvFCBBoxHead class for OTX.""" - - def loss_and_target( - self, - cls_score: Tensor, - bbox_pred: Tensor, - rois: Tensor, - sampling_results: list[SamplingResult], - rcnn_train_cfg: ConfigDict, - batch_img_metas: list[dict], - concat: bool = True, - reduction_override: str | None = None, - ) -> dict: - """Calculate the loss based on the features extracted by the bbox head. - - Args: - cls_score (Tensor): Classification prediction - results of all class, has shape - (batch_size * num_proposals_single_image, num_classes) - bbox_pred (Tensor): Regression prediction results, - has shape - (batch_size * num_proposals_single_image, 4), the last - dimension 4 represents [tl_x, tl_y, br_x, br_y]. - rois (Tensor): RoIs with the shape - (batch_size * num_proposals_single_image, 5) where the first - column indicates batch id of each RoI. - sampling_results (list[obj:SamplingResult]): Assign results of - all images in a batch after sampling. - rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. 
- batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. - concat (bool): Whether to concatenate the results of all - the images in a single batch. Defaults to True. - reduction_override (str, optional): The reduction - method used to override the original reduction - method of the loss. Options are "none", - "mean" and "sum". Defaults to None, - - Returns: - dict: A dictionary of loss and targets components. - The targets are only used for cascade rcnn. - """ - cls_reg_targets = self.get_targets( - sampling_results, - rcnn_train_cfg, - concat=concat, - batch_img_metas=batch_img_metas, - ) - losses = self.loss( - cls_score, - bbox_pred, - rois, - *cls_reg_targets, - reduction_override=reduction_override, # type: ignore[misc] - ) - - # cls_reg_targets is only for cascade rcnn - return {"loss_bbox": losses, "bbox_targets": cls_reg_targets} - - def get_targets( - self, - sampling_results: list[SamplingResult], - rcnn_train_cfg: ConfigDict, - batch_img_metas: list[dict], - concat: bool = True, - ) -> tuple: - """Calculate the ground truth for all samples in a batch according to the sampling_results. - - Almost the same as the implementation in bbox_head, we passed - additional parameters pos_inds_list and neg_inds_list to - `_get_targets_single` function. - - Args: - sampling_results (list[obj:SamplingResult]): Assign results of - all images in a batch after sampling. - rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. - batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. - concat (bool): Whether to concatenate the results of all - the images in a single batch. - - Returns: - tuple[Tensor]: Ground truth for proposals in a single image. - Containing the following list of Tensors: - - - labels (list[Tensor],Tensor): Gt_labels for all - proposals in a batch, each tensor in list has - shape (num_proposals,) when `concat=False`, otherwise - just a single tensor has shape (num_all_proposals,). - - label_weights (list[Tensor]): Labels_weights for - all proposals in a batch, each tensor in list has - shape (num_proposals,) when `concat=False`, otherwise - just a single tensor has shape (num_all_proposals,). - - bbox_targets (list[Tensor],Tensor): Regression target - for all proposals in a batch, each tensor in list - has shape (num_proposals, 4) when `concat=False`, - otherwise just a single tensor has shape - (num_all_proposals, 4), the last dimension 4 represents - [tl_x, tl_y, br_x, br_y]. - - bbox_weights (list[tensor],Tensor): Regression weights for - all proposals in a batch, each tensor in list has shape - (num_proposals, 4) when `concat=False`, otherwise just a - single tensor has shape (num_all_proposals, 4). 
- """ - pos_priors_list = [res.pos_priors for res in sampling_results] - neg_priors_list = [res.neg_priors for res in sampling_results] - pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] - pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] - labels, label_weights, bbox_targets, bbox_weights = multi_apply( - self._get_targets_single, - pos_priors_list, - neg_priors_list, - pos_gt_bboxes_list, - pos_gt_labels_list, - cfg=rcnn_train_cfg, - ) - - valid_label_mask = self.get_valid_label_mask(img_metas=batch_img_metas, all_labels=labels, use_bg=True) - valid_label_mask = [i.to(labels[0].device) for i in valid_label_mask] - - if concat: - labels = torch.cat(labels, 0) - label_weights = torch.cat(label_weights, 0) - bbox_targets = torch.cat(bbox_targets, 0) - bbox_weights = torch.cat(bbox_weights, 0) - valid_label_mask = torch.cat(valid_label_mask, 0) - return labels, label_weights, bbox_targets, bbox_weights, valid_label_mask - - def loss( - self, - cls_score: Tensor, - bbox_pred: Tensor, - rois: Tensor, - labels: Tensor, - label_weights: Tensor, - bbox_targets: Tensor, - bbox_weights: Tensor, - valid_label_mask: Tensor | None = None, - reduction_override: str | None = None, - ) -> dict: - """Loss function for CustomConvFCBBoxHead.""" - losses = {} - if cls_score is not None and cls_score.numel() > 0: - avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.0) - - if isinstance(self.loss_cls, CrossSigmoidFocalLoss): - losses["loss_cls"] = self.loss_cls( - cls_score, - labels, - label_weights, - avg_factor=avg_factor, - reduction_override=reduction_override, - valid_label_mask=valid_label_mask, - ) - else: - losses["loss_cls"] = self.loss_cls( - cls_score, - labels, - label_weights, - avg_factor=avg_factor, - reduction_override=reduction_override, - ) - losses["acc"] = accuracy(cls_score, labels) - if bbox_pred is not None: - bg_class_ind = self.num_classes - # 0~self.num_classes-1 are FG, self.num_classes is BG - pos_inds = (labels >= 0) & (labels < bg_class_ind) - # do not perform bounding box regression for BG anymore. - if pos_inds.any(): - if self.reg_decoded_bbox: - # When the regression loss (e.g. `IouLoss`, - # `GIouLoss`, `DIouLoss`) is applied directly on - # the decoded bounding boxes, it decodes the - # already encoded coordinates to absolute format. 
- bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred) - if self.reg_class_agnostic: - pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds.type(torch.bool)] - else: - pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)[ - pos_inds.type(torch.bool), - labels[pos_inds.type(torch.bool)], - ] - losses["loss_bbox"] = self.loss_bbox( - pos_bbox_pred, - bbox_targets[pos_inds.type(torch.bool)], - bbox_weights[pos_inds.type(torch.bool)], - avg_factor=bbox_targets.size(0), - reduction_override=reduction_override, - ) - else: - losses["loss_bbox"] = bbox_pred[pos_inds].sum() - return losses diff --git a/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py index 580d1a6b931..d80ed84aba5 100644 --- a/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py +++ b/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING import torch -import torch.nn.functional as F # noqa: N812 +import torch.nn.functional from mmcv.ops import RoIAlign, batched_nms from mmdeploy.codebase.mmdet import get_post_processing_params from mmdeploy.codebase.mmdet.models.dense_heads.rtmdet_ins_head import _parse_dynamic_params @@ -73,7 +73,7 @@ def mask_postprocess( mask_logits = mask_logits.sigmoid() for inds in chunks: masks[:, inds] = ( - F.interpolate( + torch.nn.functional.interpolate( mask_logits[:, inds], size=[ img_h, @@ -171,7 +171,11 @@ def _bbox_mask_post_process( # process masks mask_logits = self._mask_predict_by_feat_single(mask_feat, results.kernels, results.priors) - mask_logits = F.interpolate(mask_logits.unsqueeze(0), scale_factor=stride, mode="bilinear") + mask_logits = torch.nn.functional.interpolate( + mask_logits.unsqueeze(0), + scale_factor=stride, + mode="bilinear", + ) if rescale: ori_h, ori_w = img_meta["ori_shape"][:2] @@ -296,7 +300,7 @@ def _custom_nms_with_mask_static( priors = priors[batch_inds, inds, :] mask_logits = _custom_mask_predict_by_feat_single(self, mask_feats, kernels, priors) stride = self.prior_generator.strides[0][0] - mask_logits = F.interpolate(mask_logits, scale_factor=stride, mode="bilinear") + mask_logits = torch.nn.functional.interpolate(mask_logits, scale_factor=stride, mode="bilinear") masks = mask_logits.sigmoid() batch_index = ( diff --git a/src/otx/algo/instance_segmentation/maskrcnn.py b/src/otx/algo/instance_segmentation/maskrcnn.py index a518e83bbd9..f618f1c1737 100644 --- a/src/otx/algo/instance_segmentation/maskrcnn.py +++ b/src/otx/algo/instance_segmentation/maskrcnn.py @@ -8,6 +8,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Literal +from otx.algo.instance_segmentation.mmdet.models.detectors import MaskRCNN from otx.algo.utils.mmconfig import read_mmconfig from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.config.data import TileConfig @@ -18,15 +19,19 @@ from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import LabelInfoTypes +from otx.core.utils.build import modify_num_classes +from otx.core.utils.config import convert_conf_to_mmconfig_dict from otx.core.utils.utils import get_mean_std_from_data_processing if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable + from omegaconf import DictConfig + from torch.nn.modules import Module from otx.core.metrics import MetricCallable -class 
MaskRCNN(MMDetInstanceSegCompatibleModel): +class MMDetMaskRCNN(MMDetInstanceSegCompatibleModel): """MaskRCNN Model.""" def __init__( @@ -53,6 +58,54 @@ def __init__( self.image_size = (1, 3, 1024, 1024) self.tile_image_size = (1, 3, 512, 512) + def get_classification_layers(self, config: DictConfig, prefix: str = "") -> dict[str, dict[str, int]]: + """Return classification layer names by comparing two different number of classes models. + + Args: + config (DictConfig): Config for building model. + model_registry (Registry): Registry for building model. + prefix (str): Prefix of model param name. + Normally it is "model." since OTXModel set it's nn.Module model as self.model + + Return: + dict[str, dict[str, int]] + A dictionary contain classification layer's name and information. + Stride means dimension of each classes, normally stride is 1, but sometimes it can be 4 + if the layer is related bbox regression for object detection. + Extra classes is default class except class from data. + Normally it is related with background classes. + """ + sample_config = deepcopy(config) + modify_num_classes(sample_config, 5) + sample_model_dict = MaskRCNN( + **convert_conf_to_mmconfig_dict(sample_config, to="list"), + ).state_dict() + + modify_num_classes(sample_config, 6) + incremental_model_dict = MaskRCNN( + **convert_conf_to_mmconfig_dict(sample_config, to="list"), + ).state_dict() + + classification_layers = {} + for key in sample_model_dict: + if sample_model_dict[key].shape != incremental_model_dict[key].shape: + sample_model_dim = sample_model_dict[key].shape[0] + incremental_model_dim = incremental_model_dict[key].shape[0] + stride = incremental_model_dim - sample_model_dim + num_extra_classes = 6 * sample_model_dim - 5 * incremental_model_dim + classification_layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes} + return classification_layers + + def _create_model(self) -> Module: + from mmengine.runner import load_checkpoint + + config = deepcopy(self.config) + self.classification_layers = self.get_classification_layers(config, "model.") + detector = MaskRCNN(**convert_conf_to_mmconfig_dict(config, to="list")) + if self.load_from is not None: + load_checkpoint(detector, self.load_from, map_location="cpu") + return detector + @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" @@ -108,6 +161,52 @@ def __init__( self.image_size = (1, 3, 1344, 1344) self.tile_image_size = (1, 3, 512, 512) + def get_classification_layers(self, config: DictConfig, prefix: str = "") -> dict[str, dict[str, int]]: + """Return classification layer names by comparing two different number of classes models. + + Args: + config (DictConfig): Config for building model. + model_registry (Registry): Registry for building model. + prefix (str): Prefix of model param name. + Normally it is "model." since OTXModel set it's nn.Module model as self.model + + Return: + dict[str, dict[str, int]] + A dictionary contain classification layer's name and information. + Stride means dimension of each classes, normally stride is 1, but sometimes it can be 4 + if the layer is related bbox regression for object detection. + Extra classes is default class except class from data. + Normally it is related with background classes. 
+ """ + sample_config = deepcopy(config) + modify_num_classes(sample_config, 5) + sample_model_dict = MaskRCNN(**convert_conf_to_mmconfig_dict(sample_config, to="list")).state_dict() + + modify_num_classes(sample_config, 6) + incremental_model_dict = MaskRCNN( + **convert_conf_to_mmconfig_dict(sample_config, to="list"), + ).state_dict() + + classification_layers = {} + for key in sample_model_dict: + if sample_model_dict[key].shape != incremental_model_dict[key].shape: + sample_model_dim = sample_model_dict[key].shape[0] + incremental_model_dim = incremental_model_dict[key].shape[0] + stride = incremental_model_dim - sample_model_dim + num_extra_classes = 6 * sample_model_dim - 5 * incremental_model_dim + classification_layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes} + return classification_layers + + def _create_model(self) -> Module: + from mmengine.runner import load_checkpoint + + config = deepcopy(self.config) + self.classification_layers = self.get_classification_layers(config, "model.") + detector = MaskRCNN(**convert_conf_to_mmconfig_dict(config, to="list")) + if self.load_from is not None: + load_checkpoint(detector, self.load_from, map_location="cpu") + return detector + @property def _exporter(self) -> OTXModelExporter: """Creates OTXModelExporter object that can export the model.""" diff --git a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml index 13ce4962ebf..ad28fcbae36 100644 --- a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_efficientnetb2b.yaml @@ -14,6 +14,7 @@ data_preprocessor: - 1.0 - 1.0 type: MaskRCNN +_scope_: mmengine backbone: type: efficientnet_b2b out_indices: diff --git a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_r50.yaml b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_r50.yaml index aee1178ba7c..c37f124f0f7 100644 --- a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_r50.yaml +++ b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_r50.yaml @@ -1,5 +1,8 @@ load_from: https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_fpn_mstrain-poly_3x_coco_20210524_201154-21b550bb.pth +type: "MaskRCNN" +_scope_: mmengine backbone: + type: "ResNet" depth: 50 frozen_stages: 1 init_cfg: @@ -15,9 +18,8 @@ backbone: - 1 - 2 - 3 - style: "pytorch" - type: "ResNet" data_preprocessor: + type: "DetDataPreprocessor" bgr_to_rgb: false mean: - 123.675 @@ -29,9 +31,9 @@ data_preprocessor: - 58.395 - 57.12 - 57.375 - type: "DetDataPreprocessor" non_blocking: true neck: + type: "FPN" in_channels: - 256 - 512 @@ -39,10 +41,12 @@ neck: - 2048 num_outs: 5 out_channels: 256 - type: "FPN" roi_head: + type: "CustomRoIHead" bbox_head: + type: "CustomConvFCBBoxHead" bbox_coder: + type: "DeltaXYWHBBoxCoder" target_means: - 0.0 - 0.0 @@ -53,7 +57,6 @@ roi_head: - 0.1 - 0.2 - 0.2 - type: "DeltaXYWHBBoxCoder" fc_out_channels: 1024 in_channels: 256 loss_bbox: @@ -66,8 +69,8 @@ roi_head: num_classes: 5 reg_class_agnostic: false roi_feat_size: 7 - type: "CustomConvFCBBoxHead" bbox_roi_extractor: + type: "SingleRoIExtractor" featmap_strides: - 4 - 8 @@ -78,8 +81,8 @@ roi_head: output_size: 7 sampling_ratio: 0 type: "RoIAlign" - type: "SingleRoIExtractor" mask_head: + type: "FCNMaskHead" conv_out_channels: 256 in_channels: 256 loss_mask: @@ -88,8 +91,8 @@ roi_head: use_mask: true num_classes: 5 
num_convs: 4 - type: "FCNMaskHead" mask_roi_extractor: + type: "SingleRoIExtractor" featmap_strides: - 4 - 8 @@ -100,10 +103,10 @@ roi_head: output_size: 14 sampling_ratio: 0 type: "RoIAlign" - type: "SingleRoIExtractor" - type: "CustomRoIHead" rpn_head: + type: "RPNHead" anchor_generator: + type: "AnchorGenerator" ratios: - 0.5 - 1.0 @@ -116,8 +119,8 @@ rpn_head: - 16 - 32 - 64 - type: "AnchorGenerator" bbox_coder: + type: "DeltaXYWHBBoxCoder" target_means: - 0.0 - 0.0 @@ -128,7 +131,6 @@ rpn_head: - 1.0 - 1.0 - 1.0 - type: "DeltaXYWHBBoxCoder" feat_channels: 256 in_channels: 256 loss_bbox: @@ -138,7 +140,6 @@ rpn_head: loss_weight: 1.0 type: "CrossEntropyLoss" use_sigmoid: true - type: "RPNHead" test_cfg: rcnn: mask_thr_binary: 0.5 @@ -157,38 +158,38 @@ test_cfg: train_cfg: rcnn: assigner: + type: "MaxIoUAssigner" ignore_iof_thr: -1 match_low_quality: true min_pos_iou: 0.5 neg_iou_thr: 0.5 pos_iou_thr: 0.5 - type: "MaxIoUAssigner" debug: false mask_size: 28 pos_weight: -1 sampler: + type: "RandomSampler" add_gt_as_proposals: true neg_pos_ub: -1 num: 512 pos_fraction: 0.25 - type: "RandomSampler" rpn: allowed_border: -1 assigner: + type: "MaxIoUAssigner" ignore_iof_thr: -1 match_low_quality: true min_pos_iou: 0.3 neg_iou_thr: 0.3 pos_iou_thr: 0.7 - type: "MaxIoUAssigner" debug: false pos_weight: -1 sampler: + type: "RandomSampler" add_gt_as_proposals: false neg_pos_ub: -1 num: 256 pos_fraction: 0.5 - type: "RandomSampler" rpn_proposal: max_per_img: 1000 min_bbox_size: 0 @@ -196,4 +197,3 @@ train_cfg: iou_threshold: 0.7 type: "nms" nms_pre: 2000 -type: "MaskRCNN" diff --git a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_swint.yaml b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_swint.yaml index caed1cb7af4..5072f1d2a2e 100644 --- a/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_swint.yaml +++ b/src/otx/algo/instance_segmentation/mmconfigs/maskrcnn_swint.yaml @@ -1,4 +1,6 @@ load_from: https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco_20210908_165006-90a4008c.pth +type: MaskRCNN +_scope_: mmengine backbone: attn_drop_rate: 0.0 convert_weights: true @@ -209,4 +211,3 @@ train_cfg: iou_threshold: 0.7 type: nms nms_pre: 2000 -type: MaskRCNN diff --git a/src/otx/algo/instance_segmentation/mmdet/__init__.py b/src/otx/algo/instance_segmentation/mmdet/__init__.py new file mode 100644 index 00000000000..3bccd4f0295 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/__init__.py @@ -0,0 +1,5 @@ +"""MMDet models for instance segmentation.""" + +from .models.detectors import MaskRCNN + +__all__ = ["MaskRCNN"] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/__init__.py new file mode 100644 index 00000000000..fb557f2ad5e --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +""""MMDet model files.""" +from .backbones import ResNet +from .dense_heads import RPNHead +from .detectors import MaskRCNN +from .samplers import RandomSampler + +__all__ = [ + "ResNet", + "RPNHead", + "MaskRCNN", + "RandomSampler", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/backbones/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/backbones/__init__.py new file mode 100644 index 00000000000..d9321a4b167 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/backbones/__init__.py @@ -0,0 +1,14 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet backbones.""" +from .resnet import ResNet +from .swin import SwinTransformer + +__all__ = [ + "ResNet", + "SwinTransformer", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/backbones/resnet.py b/src/otx/algo/instance_segmentation/mmdet/models/backbones/resnet.py new file mode 100644 index 00000000000..03a9e455549 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/backbones/resnet.py @@ -0,0 +1,341 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet ResNet.""" +from __future__ import annotations + +import warnings +from typing import ClassVar + +import torch +import torch.utils.checkpoint as cp +from mmengine.model import BaseModule +from mmengine.registry import MODELS +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from otx.algo.instance_segmentation.mmdet.models.layers import ResLayer +from otx.algo.modules.conv import build_conv_layer +from otx.algo.modules.norm import build_norm_layer + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet.""" + + expansion = 4 + + def __init__( + self, + inplanes: int, + planes: int, + norm_cfg: dict, + stride: int = 1, + dilation: int = 1, + downsample: nn.Module | None = None, + with_cp: bool = False, + conv_cfg: dict | None = None, + init_cfg: dict | None = None, + ): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. 
+ """ + super().__init__(init_cfg) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.conv1_stride = 1 + self.conv2_stride = stride + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer(norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer(conv_cfg, inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False) + self.add_module(self.norm1_name, norm1) + + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False, + ) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer(conv_cfg, planes, planes * self.expansion, kernel_size=1, bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self) -> nn.Module: + """nn.Module: normalization layer after the first convolution layer.""" + return getattr(self, self.norm1_name) + + @property + def norm2(self) -> nn.Module: + """nn.Module: normalization layer after the second convolution layer.""" + return getattr(self, self.norm2_name) + + @property + def norm3(self) -> nn.Module: + """nn.Module: normalization layer after the third convolution layer.""" + return getattr(self, self.norm3_name) + + def forward(self, x: torch.Tensor) -> nn.Module: + """Forward function.""" + + def _inner_forward(x: torch.Tensor) -> nn.Module: + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + out = cp.checkpoint(_inner_forward, x) if self.with_cp and x.requires_grad else _inner_forward(x) + + return self.relu(out) + + +@MODELS.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. 
+ plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + arch_settings: ClassVar = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + } + + def __init__( + self, + depth: int, + in_channels: int = 3, + stem_channels: int | None = None, + base_channels: int = 64, + num_stages: int = 4, + strides: tuple[int, int, int, int] = (1, 2, 2, 2), + dilations: tuple[int, int, int, int] = (1, 1, 1, 1), + out_indices: tuple[int, int, int, int] = (0, 1, 2, 3), + avg_down: bool = False, + frozen_stages: int = -1, + conv_cfg: dict | None = None, + norm_cfg: dict | None = None, + norm_eval: bool = True, + with_cp: bool = False, + zero_init_residual: bool = True, + pretrained: str | bool | None = None, + init_cfg: list[dict] | dict | None = None, + ): + super().__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + msg = f"invalid depth {depth} for resnet" + raise KeyError(msg) + + block_init_cfg = None + self.init_cfg: list[dict] | dict | None = None + if init_cfg and pretrained: + msg = "init_cfg and pretrained cannot be specified at the same time" + raise ValueError(msg) + if isinstance(pretrained, str): + warnings.warn("DeprecationWarning: pretrained is deprecated, please use init_cfg instead", stacklevel=2) + self.init_cfg = {"type": "Pretrained", "checkpoint": pretrained} + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + {"type": "Kaiming", "layer": "Conv2d"}, + {"type": "Constant", "val": 1, "layer": ["BatchNorm", "GroupNorm"]}, + ] + if self.zero_init_residual: + block_init_cfg = {"type": "Constant", "val": 0, "override": {"name": "norm3"}} + else: + msg = "pretrained must be a str or None" + raise TypeError(msg) + + if norm_cfg is None: + norm_cfg = {"type": "BN", "requires_grad": True} + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + if num_stages > 4 or num_stages < 1: + msg = "num_stages must be in [1, 4]" + raise ValueError(msg) + self.strides = strides + self.dilations = dilations + if len(strides) != len(dilations) != num_stages: + msg = "The length of strides, dilations and out_indices should be the same as num_stages" + raise ValueError(msg) + self.out_indices = out_indices + if max(out_indices) >= num_stages: + msg = "max(out_indices) should be smaller than num_stages" + raise ValueError(msg) + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + 
self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + init_cfg=block_init_cfg, + ) + self.inplanes = planes * self.block.expansion + layer_name = f"layer{i + 1}" + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2 ** (len(self.stage_blocks) - 1) + + def make_res_layer(self, **kwargs) -> ResLayer: + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self) -> nn.Module: + """nn.Module: the normalization layer named "norm1".""" + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels: int, stem_channels: int) -> None: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ) + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self) -> None: + if self.frozen_stages >= 0: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f"layer{i}") + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor) -> tuple: + """Forward function.""" + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode: bool = True) -> None: + """Convert the model into training mode while keep normalization layer freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/src/otx/algo/instance_segmentation/mmdet/models/backbones/swin.py b/src/otx/algo/instance_segmentation/mmdet/models/backbones/swin.py new file mode 100644 index 00000000000..ffd9d75ce34 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/backbones/swin.py @@ -0,0 +1,829 @@ +"""MMDet SwinTransformer.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +# Copyright (c) OpenMMLab. All rights reserved. 
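As a quick orientation on the ResNet backbone defined in resnet.py above, here is a minimal sketch, outside the patch itself, of the multi-scale features it produces. It assumes the packages introduced in this diff (and the otx.algo.modules helpers they import) are installed and behave like their mmcv counterparts.

import torch

from otx.algo.instance_segmentation.mmdet.models.backbones import ResNet

# depth must be one of 50/101/152 per arch_settings above.
backbone = ResNet(
    depth=50,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,  # freezes the stem and layer1 (see _freeze_stages)
    norm_cfg={"type": "BN", "requires_grad": True},
)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

# Expected: four maps with strides 4/8/16/32 and 256/512/1024/2048 channels,
# i.e. torch.Size([1, 256, 56, 56]) ... torch.Size([1, 2048, 7, 7]).
for feat in feats:
    print(feat.shape)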
+ +from __future__ import annotations + +import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn.functional +import torch.utils.checkpoint as cp +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import constant_init, trunc_normal_, trunc_normal_init +from mmengine.registry import MODELS +from mmengine.runner.checkpoint import CheckpointLoader +from mmengine.utils import to_2tuple +from timm.models.layers import DropPath +from torch import nn + +from otx.algo.instance_segmentation.mmdet.models.layers import PatchEmbed, PatchMerging +from otx.algo.modules.norm import build_norm_layer +from otx.algo.modules.transformer import FFN + +# ruff: noqa: PLR0913 + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + window_size: tuple[int, int], + qkv_bias: bool = True, + qk_scale: float | None = None, + attn_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + init_cfg: None = None, + ): + super().__init__() + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + self.init_cfg = init_cfg + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads), + ) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + wh, ww = self.window_size + rel_index_coords = self.double_step_seq(2 * ww - 1, wh, 1, ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer("relative_position_index", rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self) -> None: + """Initialize the weights.""" + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: + """Swin Transformer layer computation. + + Args: + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, Wh*Ww, Wh*Ww), value between (-inf, 0]. 
+ """ + batch_size, num_pred, channels = x.shape + qkv = ( + self.qkv(x) + .reshape(batch_size, num_pred, 3, self.num_heads, channels // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(batch_size // nw, nw, self.num_heads, num_pred, num_pred) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, num_pred, num_pred) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(batch_size, num_pred, channels) + x = self.proj(x) + return self.proj_drop(x) + + @staticmethod + def double_step_seq(step1: int, len1: int, step2: int, len2: int) -> torch.Tensor: + """Generate double step sequence.""" + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + window_size: int, + shift_size: int = 0, + qkv_bias: bool = True, + qk_scale: float | None = None, + attn_drop_rate: float = 0, + proj_drop_rate: float = 0, + dropout_layer: dict | None = None, + init_cfg: None = None, + ): + super().__init__(init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + if self.shift_size < 0 or self.shift_size >= self.window_size: + msg = "shift_size must be in [0, window_size)" + raise ValueError(msg) + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None, + ) + + dropout_layer = {"type": "DropPath", "drop_prob": 0.0} if dropout_layer is None else dropout_layer + _dropout_layer = deepcopy(dropout_layer) + dropout_type = _dropout_layer.pop("type") + if dropout_type != "DropPath": + msg = "Only support `DropPath` dropout layer." 
+ raise ValueError(msg) + self.drop = DropPath(**_dropout_layer) + + def forward(self, query: torch.Tensor, hw_shape: tuple[int, int]) -> torch.Tensor: + """Forward function.""" + b, length, c = query.shape + h, w = hw_shape + if h * w != length: + msg = "The length of query should be equal to H*W." + raise ValueError(msg) + query = query.view(b, h, w, c) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - w % self.window_size) % self.window_size + pad_b = (self.window_size - h % self.window_size) % self.window_size + query = torch.nn.functional.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + h_pad, w_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll(query, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, h_pad, w_pad, 1), device=query.device) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for _h in h_slices: + for _w in w_slices: + img_mask[:, _h, _w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, c) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, h_pad, w_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :h, :w, :].contiguous() + + x = x.view(b, h * w, c) + + return self.drop(x) + + def window_reverse(self, windows: torch.Tensor, h: int, w: int) -> torch.Tensor: + """Reverse the window partition process. + + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + batch_size = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(batch_size, h // window_size, w // window_size, window_size, window_size, -1) + return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, h, w, -1) + + def window_partition(self, x: torch.Tensor) -> torch.Tensor: + """Split x into multi windows. + + Args: + x: (B, H, W, C) + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + batch_size, h, w, c = x.shape + window_size = self.window_size + x = x.view(batch_size, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + return windows.view(-1, window_size, window_size, c) + + +class SwinBlock(BaseModule): + """Basic Swin Transformer block. 
+ + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + window_size: int = 7, + shift: bool = False, + qkv_bias: bool = True, + qk_scale: float | None = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + act_cfg: dict | None = None, + norm_cfg: dict | None = None, + with_cp: bool = False, + init_cfg: None = None, + ): + super().__init__() + + self.init_cfg = init_cfg + self.with_cp = with_cp + + act_cfg = act_cfg if act_cfg is not None else {"type": "GELU"} + norm_cfg = norm_cfg if norm_cfg is not None else {"type": "LN"} + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer={"type": "DropPath", "drop_prob": drop_path_rate}, + init_cfg=None, + ) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer={"type": "DropPath", "drop_prob": drop_path_rate}, + act_cfg=act_cfg, + add_identity=True, + init_cfg=None, + ) + + def forward(self, x: torch.Tensor, hw_shape: torch.Tensor) -> torch.Tensor: + """Forward function.""" + + def _inner_forward(x: torch.Tensor) -> torch.Tensor: + """Inner forward function.""" + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + return self.ffn(x, identity=identity) + + return cp.checkpoint(_inner_forward, x) if self.with_cp and x.requires_grad else _inner_forward(x) + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. 
Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + feedforward_channels: int, + depth: int, + window_size: int = 7, + qkv_bias: bool = True, + qk_scale: float | None = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: list[float] | float = 0.0, + downsample: BaseModule | None = None, + act_cfg: dict | None = None, + norm_cfg: dict | None = None, + with_cp: bool = False, + init_cfg: None = None, + ): + super().__init__(init_cfg=init_cfg) + + act_cfg = act_cfg if act_cfg is not None else {"type": "GELU"} + norm_cfg = norm_cfg if norm_cfg is not None else {"type": "LN"} + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + if len(drop_path_rates) != depth: + msg = "The length of drop_path_rate should be equal to depth." + raise ValueError(msg) + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=i % 2 != 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None, + ) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> torch.Tensor: + """Forward function.""" + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseModule): + """Swin Transformer. + + A PyTorch implement of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/abs/2103.14030 + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None. + convert_weights (bool): The flag indicates whether the + pre-trained model is from the original repo. We may need + to convert some keys to make it compatible. + Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + Default: -1 (-1 means not freezing any parameters). + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__( + self, + pretrain_img_size: int = 224, + in_channels: int = 3, + embed_dims: int = 96, + patch_size: int = 4, + window_size: int = 7, + mlp_ratio: int = 4, + depths: tuple[int, ...] = (2, 2, 6, 2), + num_heads: tuple[int, ...] = (3, 6, 12, 24), + strides: tuple[int, ...] = (4, 2, 2, 2), + out_indices: tuple[int, ...] = (0, 1, 2, 3), + qkv_bias: bool = True, + qk_scale: float | None = None, + patch_norm: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + act_cfg: dict | None = None, + norm_cfg: dict | None = None, + with_cp: bool = False, + pretrained: str | None = None, + convert_weights: bool = False, + frozen_stages: int = -1, + init_cfg: dict | None = None, + ): + act_cfg = act_cfg if act_cfg is not None else {"type": "GELU"} + norm_cfg = norm_cfg if norm_cfg is not None else {"type": "LN"} + self.convert_weights = convert_weights + self.frozen_stages = frozen_stages + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + + if len(pretrain_img_size) != 2: + msg = f"The size of image should have length 1 or 2, but got {len(pretrain_img_size)}" + raise ValueError(msg) + + if init_cfg and pretrained: + msg = "init_cfg and pretrained cannot be set simultaneously" + raise ValueError(msg) + + init_cfg = {} if init_cfg is None else init_cfg + if isinstance(pretrained, str): + warnings.warn("DeprecationWarning: pretrained is deprecated, please use init_cfg instead", stacklevel=2) + self.init_cfg = {"type": "Pretrained", "checkpoint": pretrained} + elif pretrained is None: + self.init_cfg = init_cfg + else: + msg = "pretrained must be a str or None" + raise TypeError(msg) + + super().__init__(init_cfg=init_cfg) + + num_layers = len(depths) + self.out_indices = out_indices + + if strides[0] != patch_size: + msg = "Use non-overlapping patch embed." 
+ raise ValueError(msg) + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type="Conv2d", + kernel_size=patch_size, + stride=strides[0], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * in_channels, + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None, + ) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f"norm{i}" + self.add_module(layer_name, layer) + + def train(self, mode: bool = True) -> None: + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self._freeze_stages() + + def _freeze_stages(self) -> None: + """Freeze stages when training.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f"norm{i-1}") + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self) -> None: + """Initialize the weights.""" + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=0.02, bias=0.0) + elif isinstance(m, nn.LayerNorm): + constant_init(m, 1.0) + else: + if "checkpoint" not in self.init_cfg: + msg = "The checkpoint is not in the init_cfg." 
+ raise ValueError(msg) + ckpt = CheckpointLoader.load_checkpoint(self.init_cfg["checkpoint"], map_location="cpu") + if "state_dict" in ckpt: + _state_dict = ckpt["state_dict"] + elif "model" in ckpt: + _state_dict = ckpt["model"] + else: + _state_dict = ckpt + if self.convert_weights: + # supported loading weight from original repo, + _state_dict = swin_converter(_state_dict) + + state_dict = {} + for k, v in _state_dict.items(): + if k.startswith("backbone."): + state_dict[k[9:]] = v + + # strip prefix of state_dict + if next(iter(state_dict.keys())).startswith("module."): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get("absolute_pos_embed") is not None: + absolute_pos_embed = state_dict["absolute_pos_embed"] + n1, length, c1 = absolute_pos_embed.size() + n2, c2, h, w = self.absolute_pos_embed.size() + if n1 != n2 or c1 != c2 or h * w != length: + warnings.warn("Error in loading absolute_pos_embed, pass", stacklevel=2) + else: + state_dict["absolute_pos_embed"] = ( + absolute_pos_embed.view(n2, h, w, c2).permute(0, 3, 1, 2).contiguous() + ) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + l1, n_h1 = table_pretrained.size() + l2, n_h2 = table_current.size() + if n_h1 != n_h2: + warnings.warn(f"Error in loading {table_key}, pass", stacklevel=2) + elif l1 != l2: + s1 = int(l1**0.5) + s2 = int(l2**0.5) + table_pretrained_resized = torch.nn.functional.interpolate( + table_pretrained.permute(1, 0).reshape(1, n_h1, s1, s1), + size=(s2, s2), + mode="bicubic", + ) + state_dict[table_key] = table_pretrained_resized.view(n_h2, l2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, False) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """Forward function.""" + x, hw_shape = self.patch_embed(x) + x = self.drop_after_pos(x) + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return outs + + +def swin_converter(ckpt: dict) -> OrderedDict: + """Convert the key of pre-trained model from original repo.""" + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x: torch.Tensor) -> torch.Tensor: + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + return x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel) + + def correct_unfold_norm_order(x: torch.Tensor) -> torch.Tensor: + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + return x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + + for k, v in ckpt.items(): + if k.startswith("head"): + continue + if k.startswith("layers"): + new_v = v + if "attn." in k: + new_k = k.replace("attn.", "attn.w_msa.") + elif "mlp." in k: + if "mlp.fc1." in k: + new_k = k.replace("mlp.fc1.", "ffn.layers.0.0.") + elif "mlp.fc2." in k: + new_k = k.replace("mlp.fc2.", "ffn.layers.1.") + else: + new_k = k.replace("mlp.", "ffn.") + elif "downsample" in k: + new_k = k + if "reduction." in k: + new_v = correct_unfold_reduction_order(v) + elif "norm." 
in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace("layers", "stages", 1) + elif k.startswith("patch_embed"): + new_v = v + new_k = k.replace("proj", "projection") if "proj" in k else k + else: + new_v = v + new_k = k + + new_ckpt["backbone." + new_k] = new_v + + return new_ckpt diff --git a/src/otx/algo/instance_segmentation/mmdet/models/base_roi_head.py b/src/otx/algo/instance_segmentation/mmdet/models/base_roi_head.py new file mode 100644 index 00000000000..89079b2ba41 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/base_roi_head.py @@ -0,0 +1,135 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet BaseRoIHead.""" +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING + +from mmengine.model import BaseModule + +if TYPE_CHECKING: + from mmdet.structures import DetDataSample + from mmengine import ConfigDict + from mmengine.structures import InstanceData + from torch import Tensor + + +class BaseRoIHead(BaseModule, metaclass=ABCMeta): + """Base class for RoIHeads.""" + + def __init__( + self, + train_cfg: ConfigDict | dict, + test_cfg: ConfigDict | dict, + bbox_roi_extractor: ConfigDict | dict | list[ConfigDict | dict] | None = None, + bbox_head: ConfigDict | dict | list[ConfigDict | dict] | None = None, + mask_roi_extractor: ConfigDict | dict | list[ConfigDict | dict] | None = None, + mask_head: ConfigDict | dict | list[ConfigDict | dict] | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ) -> None: + super().__init__(init_cfg=init_cfg) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if bbox_head is not None: + self.init_bbox_head(bbox_roi_extractor, bbox_head) + + if mask_head is not None: + self.init_mask_head(mask_roi_extractor, mask_head) + + self.init_assigner_sampler() + + @property + def with_bbox(self) -> bool: + """bool: whether the RoI head contains a `bbox_head`.""" + return hasattr(self, "bbox_head") and self.bbox_head is not None + + @property + def with_mask(self) -> bool: + """bool: whether the RoI head contains a `mask_head`.""" + return hasattr(self, "mask_head") and self.mask_head is not None + + @property + def with_shared_head(self) -> bool: + """bool: whether the RoI head contains a `shared_head`.""" + return hasattr(self, "shared_head") and self.shared_head is not None + + @abstractmethod + def init_bbox_head(self, *args, **kwargs) -> None: + """Initialize ``bbox_head``.""" + + @abstractmethod + def init_mask_head(self, *args, **kwargs) -> None: + """Initialize ``mask_head``.""" + + @abstractmethod + def init_assigner_sampler(self, *args, **kwargs) -> None: + """Initialize assigner and sampler.""" + + @abstractmethod + def loss( + self, + x: tuple[Tensor], + rpn_results_list: list[InstanceData], + batch_data_samples: list[DetDataSample], + ) -> dict: + """Perform forward propagation and loss calculation of the roi head on the features of the upstream network.""" + + def predict( + self, + x: tuple[Tensor], + rpn_results_list: list[InstanceData], + batch_data_samples: list[DetDataSample], + rescale: bool = False, + ) -> list[InstanceData]: + """Forward the roi head and predict detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from upstream network. Each + has shape (N, C, H, W). 
+ rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results to + the original image. Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + if not self.with_bbox: + msg = "Bbox head must be implemented." + raise NotImplementedError(msg) + batch_img_metas = [data_samples.metainfo for data_samples in batch_data_samples] + + # If it has the mask branch, the bbox branch does not need + # to be scaled to the original image scale, because the mask + # branch will scale both bbox and mask at the same time. + bbox_rescale = rescale if not self.with_mask else False + results_list = self.predict_bbox( + x, + batch_img_metas, + rpn_results_list, + rcnn_test_cfg=self.test_cfg, + rescale=bbox_rescale, + ) + + if self.with_mask: + results_list = self.predict_mask(x, batch_img_metas, results_list, rescale=rescale) + + return results_list diff --git a/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/__init__.py new file mode 100644 index 00000000000..a03583d0360 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/__init__.py @@ -0,0 +1,16 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet BBoxHead.""" + +from .bbox_head import BBoxHead +from .convfc_bbox_head import ConvFCBBoxHead, Shared2FCBBoxHead + +__all__ = [ + "BBoxHead", + "ConvFCBBoxHead", + "Shared2FCBBoxHead", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py new file mode 100644 index 00000000000..fb0aac90339 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/bbox_head.py @@ -0,0 +1,442 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
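To make the contract of BaseRoIHead above concrete, the sketch below shows the surface a subclass has to provide; the class name and placeholder bodies are illustrative only and are not part of this patch.

from otx.algo.instance_segmentation.mmdet.models.base_roi_head import BaseRoIHead


class TinyRoIHead(BaseRoIHead):
    """Illustrative subclass: real heads build concrete modules in these hooks."""

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        # attach self.bbox_roi_extractor and self.bbox_head here
        ...

    def init_mask_head(self, mask_roi_extractor, mask_head):
        # attach self.mask_roi_extractor and self.mask_head here
        ...

    def init_assigner_sampler(self):
        # build self.bbox_assigner / self.bbox_sampler from self.train_cfg
        ...

    def loss(self, x, rpn_results_list, batch_data_samples):
        # return a dict of bbox (and optional mask) losses
        ...

Note that predict() itself is inherited: it runs the bbox branch first, skipping rescale when a mask head is present so that predict_mask can rescale boxes and masks together, and it expects predict_bbox()/predict_mask() implementations such as the ones StandardRoIHead adds later in this diff.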
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet BBox Head.""" +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional +from mmengine.model import BaseModule +from mmengine.registry import MODELS, TASK_UTILS +from mmengine.structures import InstanceData +from torch import Tensor, nn +from torch.nn.modules.utils import _pair + +from otx.algo.detection.deployment import is_mmdeploy_enabled +from otx.algo.detection.utils.utils import empty_instances +from otx.algo.instance_segmentation.mmdet.models.layers import multiclass_nms_torch +from otx.algo.instance_segmentation.mmdet.structures.bbox import scale_boxes + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +class BBoxHead(BaseModule): + """Simplest RoI head, with only two fc layers for classification and regression respectively.""" + + def __init__( + self, + in_channels: int, + roi_feat_size: int, + num_classes: int, + bbox_coder: dict, + loss_cls: dict, + loss_bbox: dict, + with_avg_pool: bool = False, + with_cls: bool = True, + with_reg: bool = True, + predict_box_type: str = "hbox", + reg_class_agnostic: bool = False, + reg_decoded_bbox: bool = False, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ) -> None: + super().__init__(init_cfg=init_cfg) + if not with_cls and not with_reg: + msg = "with_cls and with_reg cannot be both False" + raise ValueError(msg) + self.with_avg_pool = with_avg_pool + self.with_cls = with_cls + self.with_reg = with_reg + self.roi_feat_size = _pair(roi_feat_size) + self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1] + self.in_channels = in_channels + self.num_classes = num_classes + self.predict_box_type = predict_box_type + self.reg_class_agnostic = reg_class_agnostic + self.reg_decoded_bbox = reg_decoded_bbox + + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + + in_channels = self.in_channels + if self.with_avg_pool: + self.avg_pool = nn.AvgPool2d(self.roi_feat_size) + else: + in_channels *= self.roi_feat_area + if self.with_cls: + # need to add background class + cls_channels = num_classes + 1 + self.fc_cls = nn.Linear(in_features=in_channels, out_features=cls_channels) + if self.with_reg: + box_dim = self.bbox_coder.encode_size + out_dim_reg = box_dim if reg_class_agnostic else box_dim * num_classes + self.fc_reg = nn.Linear(in_features=in_channels, out_features=out_dim_reg) + self.debug_imgs = None + if init_cfg is None: + self.init_cfg = [] + if self.with_cls: + self.init_cfg += [ + { + "type": "Normal", + "std": 0.01, + "override": {"name": "fc_cls"}, + }, + ] + if self.with_reg: + self.init_cfg += [ + { + "type": "Normal", + "std": 0.001, + "override": {"name": "fc_reg"}, + }, + ] + + @property + def custom_cls_channels(self) -> bool: + """Get custom_cls_channels from loss_cls.""" + return getattr(self.loss_cls, "custom_cls_channels", False) + + def _get_targets_single( + self, + pos_priors: Tensor, + neg_priors: Tensor, + pos_gt_bboxes: Tensor, + pos_gt_labels: Tensor, + cfg: ConfigDict, + ) -> tuple: + """Calculate the ground truth for proposals in the single image according to the sampling results. + + Args: + pos_priors (Tensor): Contains all the positive boxes, + has shape (num_pos, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. 
+ neg_priors (Tensor): Contains all the negative boxes, + has shape (num_neg, 4), the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + pos_gt_bboxes (Tensor): Contains gt_boxes for + all positive samples, has shape (num_pos, 4), + the last dimension 4 + represents [tl_x, tl_y, br_x, br_y]. + pos_gt_labels (Tensor): Contains gt_labels for + all positive samples, has shape (num_pos, ). + cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. + + Returns: + Tuple[Tensor]: Ground truth for proposals + in a single image. Containing the following Tensors: + + - labels(Tensor): Gt_labels for all proposals, has + shape (num_proposals,). + - label_weights(Tensor): Labels_weights for all + proposals, has shape (num_proposals,). + - bbox_targets(Tensor):Regression target for all + proposals, has shape (num_proposals, 4), the + last dimension 4 represents [tl_x, tl_y, br_x, br_y]. + - bbox_weights(Tensor):Regression weights for all + proposals, has shape (num_proposals, 4). + """ + num_pos = pos_priors.size(0) + num_neg = neg_priors.size(0) + num_samples = num_pos + num_neg + + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = pos_priors.new_full((num_samples,), self.num_classes, dtype=torch.long) + reg_dim = pos_gt_bboxes.size(-1) if self.reg_decoded_bbox else self.bbox_coder.encode_size + label_weights = pos_priors.new_zeros(num_samples) + bbox_targets = pos_priors.new_zeros(num_samples, reg_dim) + bbox_weights = pos_priors.new_zeros(num_samples, reg_dim) + if num_pos > 0: + labels[:num_pos] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[:num_pos] = pos_weight + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode(pos_priors, pos_gt_bboxes) + else: + # When the regression loss (e.g. `IouLoss`, `GIouLoss`) + # is applied directly on the decoded bounding boxes, both + # the predicted boxes and regression targets should be with + # absolute coordinate format. + pos_bbox_targets = pos_gt_bboxes + bbox_targets[:num_pos, :] = pos_bbox_targets + bbox_weights[:num_pos, :] = 1 + if num_neg > 0: + label_weights[-num_neg:] = 1.0 + + return labels, label_weights, bbox_targets, bbox_weights + + def predict_by_feat( + self, + rois: tuple[Tensor], + cls_scores: tuple[Tensor], + bbox_preds: tuple[Tensor], + batch_img_metas: list[dict], + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + ) -> list[InstanceData]: + """Transform a batch of output features extracted from the head into bbox results. + + Args: + rois (tuple[Tensor]): Tuple of boxes to be transformed. + Each has shape (num_boxes, 5). last dimension 5 arrange as + (batch_index, x1, y1, x2, y2). + cls_scores (tuple[Tensor]): Tuple of box scores, each has shape + (num_boxes, num_classes + 1). + bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each + has shape (num_boxes, num_classes * 4). + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Instance segmentation + results of each image after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). 
+ - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if len(cls_scores) != len(bbox_preds): + msg = "The length of cls_scores and bbox_preds should be the same." + raise ValueError(msg) + result_list = [] + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single( + roi=rois[img_id], + cls_score=cls_scores[img_id], + bbox_pred=bbox_preds[img_id], + img_meta=img_meta, + rescale=rescale, + rcnn_test_cfg=rcnn_test_cfg, + ) + result_list.append(results) + + return result_list + + def _predict_by_feat_single( + self, + roi: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + ) -> InstanceData: + """Transform a single image's features extracted from the head into bbox results. + + Args: + roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5). + last dimension 5 arrange as (batch_index, x1, y1, x2, y2). + cls_score (Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor): Box energies / deltas. + has shape (num_boxes, num_classes * 4). + img_meta (dict): image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None + + Returns: + :obj:`InstanceData`: Detection results of each image\ + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + results = InstanceData() + if roi.shape[0] == 0: + return empty_instances( + [img_meta], + roi.device, + task_type="bbox", + instance_results=[results], + num_classes=self.num_classes, + score_per_cls=rcnn_test_cfg is None, + )[0] + + scores = torch.nn.functional.softmax(cls_score, dim=-1) if cls_score is not None else None + + img_shape = img_meta["img_shape"] + num_rois = roi.size(0) + # bbox_pred would be None in some detector when with_reg is False, + # e.g. Grid R-CNN. 
+ num_classes = 1 if self.reg_class_agnostic else self.num_classes + roi = roi.repeat_interleave(num_classes, dim=0) + bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size) + bboxes = self.bbox_coder.decode(roi[..., 1:], bbox_pred, max_shape=img_shape) + + if rescale and bboxes.size(0) > 0: + if img_meta.get("scale_factor") is None: + msg = "scale_factor must be specified in img_meta" + raise ValueError(msg) + scale_factor = [1 / s for s in img_meta["scale_factor"]] + bboxes = scale_boxes(bboxes, scale_factor) # type: ignore [arg-type] + + # Get the inside tensor when `bboxes` is a box type + box_dim = bboxes.size(-1) + bboxes = bboxes.view(num_rois, -1) + + det_bboxes, det_labels = multiclass_nms_torch( # type: ignore [misc] + bboxes, + scores, + rcnn_test_cfg.score_thr, + rcnn_test_cfg.nms, + rcnn_test_cfg.max_per_img, + box_dim=box_dim, + ) + results.bboxes = det_bboxes[:, :-1] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + return results + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import get_post_processing_params + from mmdeploy.core import FUNCTION_REWRITER, mark + + from otx.algo.detection.ops.nms import multiclass_nms + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.forward", + ) + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.CustomConvFCBBoxHead.forward", + ) + def bbox_head__forward(self: BBoxHead, x: Tensor) -> tuple[Tensor]: + """Rewrite `forward` for default backend. + + This function uses the specific `forward` function for the BBoxHead + or ConvFCBBoxHead after adding marks. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + x (Tensor): Input image tensor. + + Returns: + tuple(Tensor, Tensor): The (cls_score, bbox_pred). The cls_score + has shape (N, num_det, num_cls) and the bbox_pred has shape + (N, num_det, 4). + """ + ctx = FUNCTION_REWRITER.get_context() + + @mark("bbox_head_forward", inputs=["bbox_feats"], outputs=["cls_score", "bbox_pred"]) + def __forward(self: BBoxHead, x: Tensor) -> tuple[Tensor]: + return ctx.origin_func(self, x) + + return __forward(self, x) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.bbox_heads.bbox_head.BBoxHead.predict_by_feat", + ) + def bbox_head__predict_by_feat( + self: BBoxHead, + rois: Tensor, + cls_scores: tuple[Tensor], + bbox_preds: tuple[Tensor], + batch_img_metas: list[dict], + rcnn_test_cfg: dict, + rescale: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Rewrite `predict_by_feat` of `BBoxHead` for default backend. + + Transform network output for a batch into bbox predictions. Support + `reg_class_agnostic == False` case. + + Args: + rois (tuple[Tensor]): Tuple of boxes to be transformed. + Each has shape (num_boxes, 5). last dimension 5 arrange as + (batch_index, x1, y1, x2, y2). + cls_scores (tuple[Tensor]): Tuple of box scores, each has shape + (num_boxes, num_classes + 1). + bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each + has shape (num_boxes, num_classes * 4). + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. 
+ + Returns: + - dets (Tensor): Classification bboxes and scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + warnings.warn(f"rescale: {rescale} is not supported in ONNX export. Ignored.", stacklevel=2) + ctx = FUNCTION_REWRITER.get_context() + if rois.ndim != 3: + msg = "Only support export two stage model to ONNX with batch dimension." + raise ValueError(msg) + + img_shape = batch_img_metas[0]["img_shape"] + if self.custom_cls_channels: + scores = self.loss_cls.get_activation(cls_scores) + else: + scores = torch.nn.functional.softmax(cls_scores, dim=-1) if cls_scores is not None else None + + if bbox_preds is not None: + # num_classes = 1 if self.reg_class_agnostic else self.num_classes + # if num_classes > 1: + # rois = rois.repeat_interleave(num_classes, dim=1) + bboxes = self.bbox_coder.decode(rois[..., 1:], bbox_preds, max_shape=img_shape) + else: + bboxes = rois[..., 1:].clone() + if img_shape is not None: + max_shape = bboxes.new_tensor(img_shape)[..., :2] + min_xy = bboxes.new_tensor(0) + max_xy = torch.cat([max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2) + bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) + bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) + + batch_size = scores.shape[0] + device = scores.device + # ignore background class + scores = scores[..., : self.num_classes] + if not self.reg_class_agnostic: + # only keep boxes with the max scores + max_inds = scores.reshape(-1, self.num_classes).argmax(1, keepdim=True) + encode_size = self.bbox_coder.encode_size + bboxes = bboxes.reshape(-1, self.num_classes, encode_size) + dim0_inds = torch.arange(bboxes.shape[0], device=device).unsqueeze(-1) + bboxes = bboxes[dim0_inds, max_inds].reshape(batch_size, -1, encode_size) + # get nms params + post_params = get_post_processing_params(ctx.cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = rcnn_test_cfg["nms"].get("iou_threshold", post_params.iou_threshold) + score_threshold = rcnn_test_cfg.get("score_thr", post_params.score_threshold) + if torch.onnx.is_in_onnx_export(): + pre_top_k = post_params.pre_top_k + else: + # For two stage partition post processing + pre_top_k = -1 if post_params.pre_top_k >= bboxes.shape[1] else post_params.pre_top_k + keep_top_k = rcnn_test_cfg.get("max_per_img", post_params.keep_top_k) + return multiclass_nms( + bboxes, + scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/convfc_bbox_head.py b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/convfc_bbox_head.py new file mode 100644 index 00000000000..dd2f1397d97 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/bbox_heads/convfc_bbox_head.py @@ -0,0 +1,203 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
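The least obvious step in bbox_head__predict_by_feat above is the class-aware box selection used when reg_class_agnostic is False. The self-contained sketch below, with made-up shapes, replays that indexing on small tensors to show what it keeps.

import torch

batch_size, num_rois, num_classes, encode_size = 2, 3, 4, 4
scores = torch.rand(batch_size, num_rois, num_classes)
bboxes = torch.rand(batch_size, num_rois, num_classes, encode_size)

# Keep, for every RoI, only the box decoded for its highest-scoring class.
max_inds = scores.reshape(-1, num_classes).argmax(1, keepdim=True)          # (B*R, 1)
flat = bboxes.reshape(-1, num_classes, encode_size)                         # (B*R, C, 4)
dim0_inds = torch.arange(flat.shape[0]).unsqueeze(-1)                       # (B*R, 1)
picked = flat[dim0_inds, max_inds].reshape(batch_size, -1, encode_size)     # (B, R, 4)

The selected boxes and the per-class scores are then handed to multiclass_nms with thresholds taken from the deploy post-processing params, which is what the rewriter returns.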
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet ConvFCBBoxHead.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mmengine.registry import MODELS +from torch import Tensor, nn + +from .bbox_head import BBoxHead + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +@MODELS.register_module() +class ConvFCBBoxHead(BBoxHead): + r"""More general bbox head, with shared conv and fc layers and two optional separated branches. + + .. code-block:: none + + /-> cls convs -> cls fcs -> cls + shared convs -> shared fcs + \-> reg convs -> reg fcs -> reg + """ + + def __init__( + self, + num_shared_convs: int = 0, + num_shared_fcs: int = 0, + num_cls_convs: int = 0, + num_cls_fcs: int = 0, + num_reg_convs: int = 0, + num_reg_fcs: int = 0, + conv_out_channels: int = 256, + fc_out_channels: int = 1024, + conv_cfg: dict | ConfigDict | None = None, + norm_cfg: dict | ConfigDict | None = None, + init_cfg: dict | ConfigDict | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, init_cfg=init_cfg, **kwargs) # type: ignore [misc] + if num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs + num_reg_convs + num_reg_fcs <= 0: + msg = ( + "Pls specify at least one of num_shared_convs, num_shared_fcs, num_cls_convs, num_cls_fcs, " + "num_reg_convs, num_reg_fcs" + ) + raise ValueError(msg) + if (num_cls_convs > 0 or num_reg_convs > 0) and num_shared_fcs != 0: + msg = "Shared FC layers are mutually exclusive with cls/reg conv layers" + raise ValueError(msg) + if (not self.with_cls) and (num_cls_convs != 0 or num_cls_fcs != 0): + msg = "num_cls_convs and num_cls_fcs should be zero if without classification" + raise ValueError(msg) + if (not self.with_reg) and (num_reg_convs != 0 or num_reg_fcs != 0): + msg = "num_reg_convs and num_reg_fcs should be zero if without regression" + raise ValueError(msg) + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_cls_convs = num_cls_convs + self.num_cls_fcs = num_cls_fcs + self.num_reg_convs = num_reg_convs + self.num_reg_fcs = num_reg_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # add shared convs and fcs + self.shared_convs, self.shared_fcs, last_layer_dim = self._add_conv_fc_branch( + self.num_shared_convs, + self.num_shared_fcs, + self.in_channels, + True, + ) + self.shared_out_channels = last_layer_dim + + # add cls specific branch + self.cls_convs, self.cls_fcs, self.cls_last_dim = self._add_conv_fc_branch( + self.num_cls_convs, + self.num_cls_fcs, + self.shared_out_channels, + ) + + # add reg specific branch + self.reg_convs, self.reg_fcs, self.reg_last_dim = self._add_conv_fc_branch( + self.num_reg_convs, + self.num_reg_fcs, + self.shared_out_channels, + ) + + self.relu = nn.ReLU(inplace=True) + # reconstruct fc_cls and fc_reg since input channels are changed + if self.with_cls: + cls_channels = self.num_classes + 1 + self.fc_cls = nn.Linear(in_features=self.cls_last_dim, out_features=cls_channels) + if self.with_reg: + box_dim = self.bbox_coder.encode_size + out_dim_reg = box_dim if self.reg_class_agnostic else box_dim * self.num_classes + self.fc_reg = nn.Linear(in_features=self.reg_last_dim, out_features=out_dim_reg) + + if init_cfg is None: + # when init_cfg is None, + # It has been set to + # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], + # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] + # 
after `super(ConvFCBBoxHead, self).__init__()` + # we only need to append additional configuration + # for `shared_fcs`, `cls_fcs` and `reg_fcs` + self.init_cfg += [ + { + "type": "Xavier", + "distribution": "uniform", + "override": [ + {"name": "shared_fcs"}, + {"name": "cls_fcs"}, + {"name": "reg_fcs"}, + ], + }, + ] + + def _add_conv_fc_branch( + self, + num_branch_convs: int, + num_branch_fcs: int, + in_channels: int, + is_shared: bool = False, + ) -> tuple: + """Add shared or separable branch. + + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + # for shared branch, only consider self.with_avg_pool + # for separated branches, also consider self.num_shared_fcs + if (is_shared or self.num_shared_fcs == 0) and not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = last_layer_dim if i == 0 else self.fc_out_channels + branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x: Tensor) -> tuple: + """Forward features from the upstream network. + + Args: + x (Tensor): Features from the upstream network, each is a 4D-tensor. + + Returns: + tuple: A tuple of classification scores and bbox prediction. + + - cls_score (Tensor): Classification scores for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for all \ + scale levels, each is a 4D-tensor, the channels number \ + is num_base_priors * 4. + """ + # shared part + + if self.num_shared_fcs > 0: + x = x.flatten(1) + + for fc in self.shared_fcs: + x = self.relu(fc(x)) + # separate branches + x_cls = x + x_reg = x + + cls_score = self.fc_cls(x_cls) if self.with_cls else None + bbox_pred = self.fc_reg(x_reg) if self.with_reg else None + return cls_score, bbox_pred + + +@MODELS.register_module() +class Shared2FCBBoxHead(ConvFCBBoxHead): + """Shared 2 FC BBox Head.""" + + def __init__(self, fc_out_channels: int = 1024, *args, **kwargs) -> None: + super().__init__( # type: ignore [misc] + num_shared_convs=0, + num_shared_fcs=2, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, # noqa: B026 + **kwargs, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py b/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py new file mode 100644 index 00000000000..51a74ef73db --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/custom_roi_head.py @@ -0,0 +1,740 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDetection StandardRoIHead.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from mmengine.registry import MODELS, TASK_UTILS +from torch import Tensor + +from otx.algo.detection.deployment import is_mmdeploy_enabled +from otx.algo.detection.heads.class_incremental_mixin import ( + ClassIncrementalMixin, +) +from otx.algo.detection.losses import CrossSigmoidFocalLoss, accuracy +from otx.algo.detection.utils.structures import SamplingResult +from otx.algo.detection.utils.utils import empty_instances, multi_apply, unpack_gt_instances +from otx.algo.instance_segmentation.mmdet.models.bbox_heads.convfc_bbox_head import Shared2FCBBoxHead +from otx.algo.instance_segmentation.mmdet.models.mask_heads.fcn_mask_head import FCNMaskHead +from otx.algo.instance_segmentation.mmdet.structures.bbox import bbox2roi + +from .base_roi_head import BaseRoIHead +from .roi_extractors import SingleRoIExtractor + +if TYPE_CHECKING: + from mmdet.structures.det_data_sample import DetDataSample + from mmengine.config import ConfigDict + from mmengine.structures import InstanceData + + +@MODELS.register_module() +class StandardRoIHead(BaseRoIHead): + """Simplest base roi head including one bbox head and one mask head.""" + + def init_assigner_sampler(self) -> None: + """Initialize assigner and sampler.""" + self.bbox_assigner = TASK_UTILS.build(self.train_cfg["assigner"]) + self.bbox_sampler = TASK_UTILS.build(self.train_cfg["sampler"], default_args={"context": self}) + + def init_bbox_head(self, bbox_roi_extractor: ConfigDict | dict, bbox_head: ConfigDict | dict) -> None: + """Initialize box head and box roi extractor. + + Args: + bbox_roi_extractor (dict or ConfigDict): Config of box + roi extractor. + bbox_head (dict or ConfigDict): Config of box in box head. + """ + if bbox_roi_extractor["type"] != SingleRoIExtractor.__name__: + msg = f"bbox_roi_extractor should be SingleRoIExtractor, but got {bbox_roi_extractor['type']}" + raise ValueError(msg) + + if bbox_head["type"] != CustomConvFCBBoxHead.__name__: + msg = f"bbox_head should be CustomConvFCBBoxHead, but got {bbox_head['type']}" + raise ValueError(msg) + + bbox_roi_extractor.pop("type") + bbox_head.pop("type") + + self.bbox_roi_extractor = SingleRoIExtractor(**bbox_roi_extractor) + self.bbox_head = CustomConvFCBBoxHead(**bbox_head) + + def init_mask_head(self, mask_roi_extractor: ConfigDict | dict, mask_head: ConfigDict | dict) -> None: + """Initialize mask head and mask roi extractor. + + Args: + mask_roi_extractor (dict or ConfigDict): Config of mask roi + extractor. + mask_head (dict or ConfigDict): Config of mask in mask head. + """ + if mask_roi_extractor["type"] != SingleRoIExtractor.__name__: + msg = f"mask_roi_extractor should be SingleRoIExtractor, but got {mask_roi_extractor['type']}" + raise ValueError(msg) + mask_roi_extractor.pop("type") + self.mask_roi_extractor = SingleRoIExtractor(**mask_roi_extractor) + + if mask_head["type"] != FCNMaskHead.__name__: + msg = f"mask_head should be FCNMaskHead, but got {mask_head['type']}" + raise ValueError(msg) + + mask_head.pop("type") + self.mask_head = FCNMaskHead(**mask_head) + + def forward( + self, + x: tuple[Tensor], + rpn_results_list: list[InstanceData], + batch_data_samples: list[DetDataSample] | None = None, + ) -> tuple: + """Network forward process. Usually includes backbone, neck and head forward without any post-processing. 
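+
+        This plain ``forward`` path is not exercised by OTX: training goes
+        through :meth:`CustomRoIHead.loss` and inference through
+        :meth:`predict_bbox` / :meth:`predict_mask`, so calling it directly
+        raises ``NotImplementedError``.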
+ + Args: + x (List[Tensor]): Multi-level features that may have different + resolutions. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple: A tuple of features from ``bbox_head`` and ``mask_head`` + forward. + """ + raise NotImplementedError + + def _bbox_forward(self, x: tuple[Tensor], rois: Tensor) -> dict: + """Box head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + """ + bbox_feats = self.bbox_roi_extractor(x[: self.bbox_roi_extractor.num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) + + return {"cls_score": cls_score, "bbox_pred": bbox_pred, "bbox_feats": bbox_feats} + + def mask_loss( + self, + x: tuple[Tensor], + sampling_results: list[SamplingResult], + bbox_feats: Tensor, + batch_gt_instances: list[InstanceData], + ) -> dict: + """Perform forward propagation and loss calculation of the mask head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + bbox_feats (Tensor): Extract bbox RoI features. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + - `mask_targets` (Tensor): Mask target of each positive\ + proposals in the image. + - `loss_mask` (dict): A dictionary of mask loss components. + """ + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward(x, pos_rois) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results["mask_preds"], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg, + ) + + mask_results.update(loss_mask=mask_loss_and_target["loss_mask"]) + return mask_results + + def _mask_forward( + self, + x: tuple[Tensor], + rois: Tensor | None = None, + pos_inds: Tensor | None = None, + bbox_feats: Tensor | None = None, + ) -> dict: + """Mask head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + pos_inds (Tensor, optional): Indices of positive samples. + Defaults to None. + bbox_feats (Tensor): Extract bbox RoI features. Defaults to None. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. 
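+
+        Example:
+            An illustrative sketch of the expected ``rois`` layout (a batch
+            index is prepended to every box); the box coordinates below are
+            made-up values:
+
+            >>> import torch
+            >>> from otx.algo.instance_segmentation.mmdet.structures.bbox import bbox2roi
+            >>> rois = bbox2roi([torch.tensor([[10.0, 10.0, 50.0, 50.0]])])
+            >>> rois.shape
+            torch.Size([1, 5])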
+ """ + if not ((rois is not None) ^ (pos_inds is not None and bbox_feats is not None)): + msg = "rois is None xor (pos_inds is not None and bbox_feats is not None)" + raise ValueError(msg) + if rois is not None: + mask_feats = self.mask_roi_extractor(x[: self.mask_roi_extractor.num_inputs], rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + else: + if bbox_feats is None: + msg = "bbox_feats should not be None when rois is None" + raise ValueError(msg) + mask_feats = bbox_feats[pos_inds] + + mask_preds = self.mask_head(mask_feats) + return {"mask_preds": mask_preds, "mask_feats": mask_feats} + + def predict_bbox( + self, + x: tuple[Tensor], + batch_img_metas: list[dict], + rpn_results_list: list[InstanceData], + rcnn_test_cfg: ConfigDict | dict, + rescale: bool = False, + ) -> list[InstanceData]: + """Forward the bbox head and predict detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + proposals = [res.bboxes for res in rpn_results_list] + rois = bbox2roi(proposals) + + if rois.shape[0] == 0: + return empty_instances( + batch_img_metas, + rois.device, + task_type="bbox", + num_classes=self.bbox_head.num_classes, + score_per_cls=rcnn_test_cfg is None, + ) + + bbox_results = self._bbox_forward(x, rois) + + # split batch bbox prediction back to each image + cls_scores = bbox_results["cls_score"] + bbox_preds = bbox_results["bbox_pred"] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = rois.split(num_proposals_per_img, 0) + cls_scores = cls_scores.split(num_proposals_per_img, 0) + + # some detector with_reg is False, bbox_preds will be None + if bbox_preds is not None: + # the bbox prediction of some detectors like SABL is not Tensor + if isinstance(bbox_preds, torch.Tensor): + bbox_preds = bbox_preds.split(num_proposals_per_img, 0) + else: + bbox_preds = self.bbox_head.bbox_pred_split(bbox_preds, num_proposals_per_img) + else: + bbox_preds = (None,) * len(proposals) + + return self.bbox_head.predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale, + ) + + def predict_mask( + self, + x: tuple[Tensor], + batch_img_metas: list[dict], + results_list: list[InstanceData], + rescale: bool = False, + ) -> list[InstanceData]: + """Forward the mask head and predict detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. 
+ Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # don't need to consider aug_test. + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + return empty_instances( + batch_img_metas, + mask_rois.device, + task_type="mask", + instance_results=results_list, + mask_thr_binary=self.test_cfg["mask_thr_binary"], + ) + + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results["mask_preds"] + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + + return self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale, + ) + + +@MODELS.register_module() +class CustomRoIHead(StandardRoIHead): + """CustomRoIHead class for OTX.""" + + def loss( + self, + x: tuple[Tensor], + rpn_results_list: list[InstanceData], + batch_data_samples: list[DetDataSample], + ) -> dict: + """Perform forward propagation and loss calculation of the detection roi on the features. + + Args: + x (tuple[Tensor]): list of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop("bboxes") + + assign_result = self.bbox_assigner.assign(rpn_results, batch_gt_instances[i], batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x], + ) + sampling_results.append(sampling_result) + + losses = {} + # bbox head loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results, batch_img_metas) + losses.update(bbox_results["loss_bbox"]) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(x, sampling_results, bbox_results["bbox_feats"], batch_gt_instances) + losses.update(mask_results["loss_mask"]) + + return losses + + def bbox_loss(self, x: tuple[Tensor], sampling_results: list[SamplingResult], batch_img_metas: list[dict]) -> dict: + """Perform forward propagation and loss calculation of the bbox head on the features of the upstream network. + + Args: + x (tuple[Tensor]): list of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. 
+ - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + rois = bbox2roi([res.bboxes for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results["cls_score"], + bbox_pred=bbox_results["bbox_pred"], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg, + batch_img_metas=batch_img_metas, + ) + bbox_results.update(loss_bbox=bbox_loss_and_target["loss_bbox"]) + + return bbox_results + + +@MODELS.register_module() +class CustomConvFCBBoxHead(Shared2FCBBoxHead, ClassIncrementalMixin): + """CustomConvFCBBoxHead class for OTX.""" + + def loss_and_target( + self, + cls_score: Tensor, + bbox_pred: Tensor, + rois: Tensor, + sampling_results: list[SamplingResult], + rcnn_train_cfg: ConfigDict, + batch_img_metas: list[dict], + concat: bool = True, + reduction_override: str | None = None, + ) -> dict: + """Calculate the loss based on the features extracted by the bbox head. + + Args: + cls_score (Tensor): Classification prediction + results of all class, has shape + (batch_size * num_proposals_single_image, num_classes) + bbox_pred (Tensor): Regression prediction results, + has shape + (batch_size * num_proposals_single_image, 4), the last + dimension 4 represents [tl_x, tl_y, br_x, br_y]. + rois (Tensor): RoIs with the shape + (batch_size * num_proposals_single_image, 5) where the first + column indicates batch id of each RoI. + sampling_results (list[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. + concat (bool): Whether to concatenate the results of all + the images in a single batch. Defaults to True. + reduction_override (str, optional): The reduction + method used to override the original reduction + method of the loss. Options are "none", + "mean" and "sum". Defaults to None, + + Returns: + dict: A dictionary of loss and targets components. + The targets are only used for cascade rcnn. + """ + cls_reg_targets = self.get_targets( + sampling_results, + rcnn_train_cfg, + concat=concat, + batch_img_metas=batch_img_metas, + ) + losses = self.loss( + cls_score, + bbox_pred, + rois, + *cls_reg_targets, + reduction_override=reduction_override, # type: ignore[misc] + ) + + # cls_reg_targets is only for cascade rcnn + return {"loss_bbox": losses, "bbox_targets": cls_reg_targets} + + def get_targets( + self, + sampling_results: list[SamplingResult], + rcnn_train_cfg: ConfigDict, + batch_img_metas: list[dict], + concat: bool = True, + ) -> tuple: + """Calculate the ground truth for all samples in a batch according to the sampling_results. + + Almost the same as the implementation in bbox_head, we passed + additional parameters pos_inds_list and neg_inds_list to + `_get_targets_single` function. + + Args: + sampling_results (list[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + batch_img_metas (list[Dict]): Meta information of each image, e.g., image size, scaling factor, etc. + concat (bool): Whether to concatenate the results of all + the images in a single batch. + + Returns: + tuple[Tensor]: Ground truth for proposals in a single image. 
+ Containing the following list of Tensors: + + - labels (list[Tensor],Tensor): Gt_labels for all + proposals in a batch, each tensor in list has + shape (num_proposals,) when `concat=False`, otherwise + just a single tensor has shape (num_all_proposals,). + - label_weights (list[Tensor]): Labels_weights for + all proposals in a batch, each tensor in list has + shape (num_proposals,) when `concat=False`, otherwise + just a single tensor has shape (num_all_proposals,). + - bbox_targets (list[Tensor],Tensor): Regression target + for all proposals in a batch, each tensor in list + has shape (num_proposals, 4) when `concat=False`, + otherwise just a single tensor has shape + (num_all_proposals, 4), the last dimension 4 represents + [tl_x, tl_y, br_x, br_y]. + - bbox_weights (list[tensor],Tensor): Regression weights for + all proposals in a batch, each tensor in list has shape + (num_proposals, 4) when `concat=False`, otherwise just a + single tensor has shape (num_all_proposals, 4). + """ + pos_priors_list = [res.pos_priors for res in sampling_results] + neg_priors_list = [res.neg_priors for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + labels, label_weights, bbox_targets, bbox_weights = multi_apply( + self._get_targets_single, + pos_priors_list, + neg_priors_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + cfg=rcnn_train_cfg, + ) + + valid_label_mask = self.get_valid_label_mask(img_metas=batch_img_metas, all_labels=labels, use_bg=True) + valid_label_mask = [i.to(labels[0].device) for i in valid_label_mask] + + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + bbox_targets = torch.cat(bbox_targets, 0) + bbox_weights = torch.cat(bbox_weights, 0) + valid_label_mask = torch.cat(valid_label_mask, 0) + return labels, label_weights, bbox_targets, bbox_weights, valid_label_mask + + def loss( + self, + cls_score: Tensor, + bbox_pred: Tensor, + rois: Tensor, + labels: Tensor, + label_weights: Tensor, + bbox_targets: Tensor, + bbox_weights: Tensor, + valid_label_mask: Tensor | None = None, + reduction_override: str | None = None, + ) -> dict: + """Loss function for CustomConvFCBBoxHead.""" + losses = {} + if cls_score is not None and cls_score.numel() > 0: + avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.0) + + if isinstance(self.loss_cls, CrossSigmoidFocalLoss): + losses["loss_cls"] = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override, + valid_label_mask=valid_label_mask, + ) + else: + losses["loss_cls"] = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override, + ) + losses["acc"] = accuracy(cls_score, labels) + if bbox_pred is not None: + bg_class_ind = self.num_classes + # 0~self.num_classes-1 are FG, self.num_classes is BG + pos_inds = (labels >= 0) & (labels < bg_class_ind) + # do not perform bounding box regression for BG anymore. + if pos_inds.any(): + if self.reg_decoded_bbox: + # When the regression loss (e.g. `IouLoss`, + # `GIouLoss`, `DIouLoss`) is applied directly on + # the decoded bounding boxes, it decodes the + # already encoded coordinates to absolute format. 
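+                    # `rois[:, 1:]` drops the leading batch-index column so the
+                    # coder receives plain (x1, y1, x2, y2) priors.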
+ bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred) + if self.reg_class_agnostic: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds.type(torch.bool)] + else: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)[ + pos_inds.type(torch.bool), + labels[pos_inds.type(torch.bool)], + ] + losses["loss_bbox"] = self.loss_bbox( + pos_bbox_pred, + bbox_targets[pos_inds.type(torch.bool)], + bbox_weights[pos_inds.type(torch.bool)], + avg_factor=bbox_targets.size(0), + reduction_override=reduction_override, + ) + else: + losses["loss_bbox"] = bbox_pred[pos_inds].sum() + return losses + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.StandardRoIHead.predict_bbox", + ) + def standard_roi_head__predict_bbox( + self: StandardRoIHead, + x: tuple[Tensor], + batch_img_metas: list[dict], + rpn_results_list: list[Tensor], + rcnn_test_cfg: ConfigDict | dict, + rescale: bool = False, + ) -> list[Tensor]: + """Rewrite `predict_bbox` of `StandardRoIHead` for default backend. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[Tensor]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[Tensor]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - dets (Tensor): Classification bboxes and scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + rois = rpn_results_list[0] + rois_dims = int(rois.shape[-1]) + batch_index = ( + torch.arange(rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(rois.size(0), rois.size(1), 1) + ) + rois = torch.cat([batch_index, rois[..., : rois_dims - 1]], dim=-1) + batch_size = rois.shape[0] + num_proposals_per_img = rois.shape[1] + + # Eliminate the batch dimension + rois = rois.view(-1, rois_dims) + bbox_results = self._bbox_forward(x, rois) + cls_scores = bbox_results["cls_score"] + bbox_preds = bbox_results["bbox_pred"] + + # Recover the batch dimension + rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) + cls_scores = cls_scores.reshape(batch_size, num_proposals_per_img, cls_scores.size(-1)) + + bbox_preds = bbox_preds.reshape(batch_size, num_proposals_per_img, bbox_preds.size(-1)) + return self.bbox_head.predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale, + ) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.custom_roi_head.StandardRoIHead.predict_mask", + ) + def standard_roi_head__predict_mask( + self: StandardRoIHead, + x: tuple[Tensor], + batch_img_metas: list[dict], + results_list: list[Tensor], + rescale: bool = False, + ) -> tuple[Tensor, Tensor, Tensor]: + """Forward the mask head and predict detection results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. 
+ + Returns: + list[Tensor]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + dets, det_labels = results_list + batch_size = dets.size(0) + det_bboxes = dets[..., :4] + # expand might lead to static shape, use broadcast instead + batch_index = torch.arange(det_bboxes.size(0), device=det_bboxes.device).float().view( + -1, + 1, + 1, + ) + det_bboxes.new_zeros((det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1) + mask_rois = torch.cat([batch_index, det_bboxes], dim=-1) + mask_rois = mask_rois.view(-1, 5) + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results["mask_preds"] + num_det = det_bboxes.shape[1] + segm_results: Tensor = self.mask_head.predict_by_feat( + mask_preds, + results_list, + batch_img_metas, + self.test_cfg, + rescale=rescale, + ) + segm_results = segm_results.reshape(batch_size, num_det, segm_results.shape[-2], segm_results.shape[-1]) + return dets, det_labels, segm_results diff --git a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/__init__.py new file mode 100644 index 00000000000..539d0b20d3d --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/__init__.py @@ -0,0 +1,11 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +""""MMDet RPNHead.""" + +from .rpn_head import RPNHead + +__all__ = ["RPNHead"] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py new file mode 100644 index 00000000000..56f7a2d7b46 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rpn_head.py @@ -0,0 +1,473 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet RPNHead.""" +from __future__ import annotations + +import copy +import warnings +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional +from mmengine.registry import MODELS +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from otx.algo.detection.deployment import is_mmdeploy_enabled +from otx.algo.detection.heads.anchor_head import AnchorHead +from otx.algo.detection.ops.nms import batched_nms +from otx.algo.instance_segmentation.mmdet.structures.bbox import ( + empty_box_as, + get_box_wh, +) +from otx.algo.modules.conv_module import ConvModule + +# ruff: noqa: PLW2901 + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +@MODELS.register_module() +class RPNHead(AnchorHead): + """Implementation of RPN head. + + Args: + in_channels (int): Number of channels in the input feature map. + num_classes (int): Number of categories excluding the background + category. Defaults to 1. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \ + list[dict]): Initialization config dict. 
+ num_convs (int): Number of convolution layers in the head. + Defaults to 1. + """ + + def __init__( + self, + in_channels: int, + num_classes: int = 1, + init_cfg: dict | None = None, + num_convs: int = 1, + **kwargs, + ) -> None: + self.num_convs = num_convs + if init_cfg is None: + init_cfg = {"type": "Normal", "layer": "Conv2d", "std": 0.01} + + if num_classes != 1: + msg = "num_classes must be 1 for RPNHead" + raise ValueError(msg) + super().__init__(num_classes=num_classes, in_channels=in_channels, init_cfg=init_cfg, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers of the head.""" + if self.num_convs > 1: + rpn_convs = [] + for i in range(self.num_convs): + in_channels = self.in_channels if i == 0 else self.feat_channels + # use ``inplace=False`` to avoid error: one of the variables + # needed for gradient computation has been modified by an + # inplace operation. + rpn_convs.append(ConvModule(in_channels, self.feat_channels, 3, padding=1, inplace=False)) + self.rpn_conv = nn.Sequential(*rpn_convs) + else: + self.rpn_conv = nn.Conv2d(self.in_channels, self.feat_channels, 3, padding=1) + self.rpn_cls = nn.Conv2d(self.feat_channels, self.num_base_priors * self.cls_out_channels, 1) + reg_dim = self.bbox_coder.encode_size + self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * reg_dim, 1) + + def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level \ + the channels number is num_base_priors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale \ + level, the channels number is num_base_priors * 4. + """ + x = self.rpn_conv(x) + x = torch.nn.functional.relu(x) + rpn_cls_score = self.rpn_cls(x) + rpn_bbox_pred = self.rpn_reg(x) + return rpn_cls_score, rpn_bbox_pred + + def loss_by_feat( + self, + cls_scores: list[Tensor], + bbox_preds: list[Tensor], + batch_gt_instances: list[InstanceData], + batch_img_metas: list[dict], + batch_gt_instances_ignore: list[InstanceData] | None = None, + ) -> dict: + """Calculate the loss based on the features extracted by the detection head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level, + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + batch_gt_instances (list[obj:InstanceData]): Batch of gt_instance. + It usually includes ``bboxes`` and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[obj:InstanceData], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
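+
+        Example:
+            The returned dictionary only renames the generic anchor-head loss
+            terms (values are the corresponding loss tensors)::
+
+                {"loss_rpn_cls": ..., "loss_rpn_bbox": ...}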
+ """ + losses = super().loss_by_feat( + cls_scores, + bbox_preds, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, + ) + return {"loss_rpn_cls": losses["loss_cls"], "loss_rpn_bbox": losses["loss_bbox"]} + + def _predict_by_feat_single( + self, + cls_score_list: list[Tensor], + bbox_pred_list: list[Tensor], + score_factor_list: list[Tensor], + mlvl_priors: list[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True, + ) -> InstanceData: + """Transform a single image's features extracted from the head into bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Be compatible with + BaseDenseHead. Not used in RPNHead. + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (ConfigDict, optional): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
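+
+        Example:
+            Illustrative sketch of the ``nms_pre`` pre-selection used in the
+            body, where a descending sort stands in for ``topk`` (scores are
+            made-up values):
+
+            >>> import torch
+            >>> scores = torch.tensor([0.1, 0.9, 0.4, 0.7])
+            >>> _, rank_inds = scores.sort(descending=True)
+            >>> rank_inds[:2].tolist()
+            [1, 3]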
+ """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta["img_shape"] + nms_pre = cfg.get("nms_pre", -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + level_ids = [] + for level_idx, (cls_score, bbox_pred, priors) in enumerate(zip(cls_score_list, bbox_pred_list, mlvl_priors)): + if cls_score.size()[-2:] != bbox_pred.size()[-2:]: + msg = "cls_score and bbox_pred should have the same size" + raise RuntimeError(msg) + + reg_dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, reg_dim) + cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels) + scores = cls_score.sigmoid() if self.use_sigmoid_cls else cls_score.softmax(-1)[:, :-1] + + scores = torch.squeeze(scores) + if 0 < nms_pre < scores.shape[0]: + # sort is faster than topk + # _, topk_inds = scores.topk(cfg.nms_pre) + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] + bbox_pred = bbox_pred[topk_inds, :] + priors = priors[topk_inds] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + + # use level id to implement the separate level nms + level_ids.append(scores.new_full((scores.size(0),), level_idx, dtype=torch.long)) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = torch.cat(mlvl_valid_priors) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.scores = torch.cat(mlvl_scores) + results.level_ids = torch.cat(level_ids) + + return self._bbox_post_process(results=results, cfg=cfg, rescale=rescale, img_meta=img_meta) + + def _bbox_post_process( + self, + results: InstanceData, + cfg: ConfigDict, + img_meta: dict, + rescale: bool = False, + with_nms: bool = True, + ) -> InstanceData: + """Bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + cfg (ConfigDict): Test / postprocessing configuration. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Default to True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
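+
+        Example:
+            Illustrative sketch of the ``min_bbox_size`` filter applied in the
+            body; ``get_box_wh`` is assumed to return per-box widths and
+            heights, and the boxes are made-up values:
+
+            >>> import torch
+            >>> from otx.algo.instance_segmentation.mmdet.structures.bbox import get_box_wh
+            >>> bboxes = torch.tensor([[0.0, 0.0, 4.0, 4.0], [0.0, 0.0, 0.5, 10.0]])
+            >>> w, h = get_box_wh(bboxes)
+            >>> ((w > 1) & (h > 1)).tolist()
+            [True, False]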
+ """ + if not with_nms: + msg = "`with_nms` must be True in RPNHead" + raise RuntimeError(msg) + + if rescale: + msg = "Rescale is not implemented in RPNHead" + raise NotImplementedError + + # filter small size bboxes + if cfg.get("min_bbox_size", -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + if results.bboxes.numel() > 0: + bboxes = results.bboxes + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.level_ids, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[: cfg.max_per_img] + + # in visualization + results.labels = results.scores.new_zeros(len(results), dtype=torch.long) + del results.level_ids + else: + # To avoid some potential error + results_ = InstanceData() + results_.bboxes = empty_box_as(results.bboxes) + results_.scores = results.scores.new_zeros(0) + results_.labels = results.scores.new_zeros(0) + results = results_ + return results + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import gather_topk, get_post_processing_params, pad_with_value_if_necessary + from mmdeploy.core import FUNCTION_REWRITER + from mmdeploy.utils import is_dynamic_shape + + from otx.algo.detection.ops.nms import multiclass_nms + + @FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.instance_segmentation.mmdet.models.dense_heads.rpn_head.RPNHead.predict_by_feat", + ) + def rpn_head__predict_by_feat( + self: RPNHead, + cls_scores: list[Tensor], + bbox_preds: list[Tensor], + batch_img_metas: list[dict], + score_factors: list[Tensor] | None = None, + cfg: ConfigDict | None = None, + rescale: bool = False, + with_nms: bool = True, + **kwargs, + ) -> tuple: + """Rewrite `predict_by_feat` of `RPNHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. 
+ Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, + batch_mlvl_scores, batch_mlvl_centerness + """ + warnings.warn(f"score_factors: {score_factors} is not used in RPNHead", stacklevel=2) + warnings.warn(f"rescale: {rescale} is not used in RPNHead", stacklevel=2) + warnings.warn(f"kwargs: {kwargs} is not used in RPNHead", stacklevel=2) + ctx = FUNCTION_REWRITER.get_context() + img_metas = batch_img_metas + if len(cls_scores) != len(bbox_preds): + msg = "cls_scores and bbox_preds should have the same length" + raise ValueError(msg) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors(featmap_sizes, device=device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + if len(mlvl_cls_scores) != len(mlvl_bbox_preds) != len(mlvl_anchors): + msg = "mlvl_cls_scores, mlvl_bbox_preds and mlvl_anchors should have the same length" + raise ValueError(msg) + + cfg = self.test_cfg if cfg is None else cfg + if cfg is None: + warnings.warn("cfg is None, use default cfg", stacklevel=2) + cfg = { + "max_per_img": 1000, + "min_bbox_size": 0, + "nms": {"iou_threshold": 0.7, "type": "nms"}, + "nms_pre": 1000, + } + batch_size = mlvl_cls_scores[0].shape[0] + pre_topk = cfg.get("nms_pre", -1) + + # loop over features, decode boxes + mlvl_valid_bboxes = [] + mlvl_scores = [] + mlvl_valid_anchors = [] + for cls_score, bbox_pred, anchors in zip( + mlvl_cls_scores, + mlvl_bbox_preds, + mlvl_anchors, + ): + if cls_score.size()[-2:] != bbox_pred.size()[-2:]: + msg = "cls_score and bbox_pred should have the same size" + raise ValueError(msg) + cls_score = cls_score.permute(0, 2, 3, 1) + if self.use_sigmoid_cls: + cls_score = cls_score.reshape(batch_size, -1) + scores = cls_score.sigmoid() + else: + cls_score = cls_score.reshape(batch_size, -1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. 
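+                # With that convention (and num_classes == 1 in RPN), channel 0
+                # of the softmax output is the foreground probability, so only
+                # that slice is kept below.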
+ scores = cls_score.softmax(-1)[..., 0] + scores = scores.reshape(batch_size, -1, 1) + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, dim) + + # use static anchor if input shape is static + if not is_dynamic_flag: + anchors = anchors.data + + anchors = anchors.unsqueeze(0) + + # topk in tensorrt does not support shape 0: + _, topk_inds = scores.squeeze(2).topk(pre_topk) + bbox_pred, scores = gather_topk( + bbox_pred, + scores, + inds=topk_inds, + batch_size=batch_size, + is_batched=True, + ) + anchors = gather_topk(anchors, inds=topk_inds, batch_size=batch_size, is_batched=False) + mlvl_valid_bboxes.append(bbox_pred) + mlvl_scores.append(scores) + mlvl_valid_anchors.append(anchors) + + batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) + batch_mlvl_bboxes = self.bbox_coder.decode( + batch_mlvl_anchors, + batch_mlvl_bboxes, + max_shape=img_metas[0]["img_shape"], + ) + # ignore background class + if not self.use_sigmoid_cls: + batch_mlvl_scores = batch_mlvl_scores[..., : self.num_classes] + if not with_nms: + return batch_mlvl_bboxes, batch_mlvl_scores + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg["nms"].get("iou_threshold", post_params.iou_threshold) + score_threshold = cfg.get("score_thr", post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get("max_per_img", post_params.keep_top_k) + # only one class in rpn + max_output_boxes_per_class = keep_top_k + return multiclass_nms( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/detectors/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/detectors/__init__.py new file mode 100644 index 00000000000..c37e7929270 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/detectors/__init__.py @@ -0,0 +1,15 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet Detectors.""" + +from .mask_rcnn import MaskRCNN +from .two_stage import TwoStageDetector + +__all__ = [ + "MaskRCNN", + "TwoStageDetector", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/detectors/base.py b/src/otx/algo/instance_segmentation/mmdet/models/detectors/base.py new file mode 100644 index 00000000000..686b89730f0 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/detectors/base.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet BaseDetector.""" +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, TypeAlias + +import torch +from mmdet.structures.det_data_sample import DetDataSample +from mmengine.model import BaseModel +from torch import Tensor + +ForwardResults: TypeAlias = dict[str, torch.Tensor] | list[DetDataSample] | tuple[torch.Tensor] | torch.Tensor + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + from mmengine.structures import InstanceData + + +class BaseDetector(BaseModel, metaclass=ABCMeta): + """Base class for detectors. + + Args: + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. + """ + + def __init__( + self, + data_preprocessor: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ): + super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + @property + def with_neck(self) -> bool: + """bool: whether the detector has a neck.""" + return hasattr(self, "neck") and self.neck is not None + + @property + def with_bbox(self) -> bool: + """bool: whether the detector has a bbox head.""" + return (hasattr(self, "roi_head") and self.roi_head.with_bbox) or ( + hasattr(self, "bbox_head") and self.bbox_head is not None + ) + + def forward(self, inputs: torch.Tensor, data_samples: list[DetDataSample], mode: str = "tensor") -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DetDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle either back propagation or + parameter update, which are supposed to be done in :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == "loss": + return self.loss(inputs, data_samples) + if mode == "predict": + return self.predict(inputs, data_samples) + msg = f"Invalid mode {mode}. Only supports loss and predict mode." 
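+        # "tensor" mode is not handled in this adaptation; deployment/export
+        # instead goes through the mmdeploy function rewriters when available.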
+ raise RuntimeError(msg) + + @abstractmethod + def loss(self, batch_inputs: Tensor, batch_data_samples: list[DetDataSample]) -> dict | tuple: + """Calculate losses from a batch of inputs and data samples.""" + + @abstractmethod + def predict(self, batch_inputs: Tensor, batch_data_samples: list[DetDataSample]) -> list[DetDataSample]: + """Predict results from a batch of inputs and data samples with post-processing.""" + + @abstractmethod + def _forward(self, batch_inputs: Tensor, batch_data_samples: list[DetDataSample]) -> tuple: + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + + @abstractmethod + def extract_feat(self, batch_inputs: Tensor) -> tuple: + """Extract features from images.""" + + def add_pred_to_datasample( + self, + data_samples: list[DetDataSample], + results_list: list[InstanceData], + ) -> list[DetDataSample]: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances'. And the ``pred_instances`` usually + contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + for data_sample, pred_instances in zip(data_samples, results_list): + data_sample.pred_instances = pred_instances + return data_samples diff --git a/src/otx/algo/instance_segmentation/mmdet/models/detectors/mask_rcnn.py b/src/otx/algo/instance_segmentation/mmdet/models/detectors/mask_rcnn.py new file mode 100644 index 00000000000..63fd07dd5c5 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/detectors/mask_rcnn.py @@ -0,0 +1,44 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ +"""MMDet MaskRCNN.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mmengine.registry import MODELS + +from .two_stage import TwoStageDetector + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +@MODELS.register_module() +class MaskRCNN(TwoStageDetector): + """Implementation of `Mask R-CNN `.""" + + def __init__( + self, + backbone: ConfigDict, + rpn_head: ConfigDict, + roi_head: ConfigDict, + train_cfg: ConfigDict, + test_cfg: ConfigDict, + neck: ConfigDict | dict | None = None, + data_preprocessor: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + **kwargs, + ) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py b/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py new file mode 100644 index 00000000000..13620deae0c --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/detectors/two_stage.py @@ -0,0 +1,353 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet TwoStageDetector.""" +from __future__ import annotations + +import copy +import warnings +from typing import TYPE_CHECKING, Callable + +import torch +from mmengine.registry import MODELS +from torch import Tensor + +from otx.algo.detection.backbones.pytorchcv_backbones import _build_pytorchcv_model +from otx.algo.detection.deployment import is_mmdeploy_enabled +from otx.algo.instance_segmentation.mmdet.models.custom_roi_head import CustomRoIHead +from otx.algo.instance_segmentation.mmdet.models.dense_heads import RPNHead +from otx.algo.instance_segmentation.mmdet.models.necks import FPN + +from .base import BaseDetector + +if TYPE_CHECKING: + from mmdet.structures.det_data_sample import DetDataSample + from mmengine.config import ConfigDict + + from otx.algo.instance_segmentation.mmdet.models.detectors.base import ForwardResults + + +class TwoStageDetector(BaseDetector): + """Base class for two-stage detectors. + + Two-stage detectors typically consisting of a region proposal network and a + task-specific regression head. 
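+
+    The constructor consumes plain config dictionaries. A minimal skeleton of
+    the expected layout is sketched below; the empty dicts stand in for the
+    module-specific options, which depend on the chosen backbone and heads:
+
+    .. code-block:: python
+
+        cfg = {
+            "backbone": {},                        # built via MODELS.build (or pytorchcv)
+            "neck": {"type": "FPN"},               # only FPN is accepted; "type" is popped
+            "rpn_head": {"type": "RPNHead"},       # num_classes is forced to 1
+            "roi_head": {"type": "CustomRoIHead"},
+            "train_cfg": {"rpn": {}, "rcnn": {}},
+            "test_cfg": {"rpn": {}, "rcnn": {}},
+        }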
+ """ + + def __init__( + self, + backbone: ConfigDict | dict, + neck: ConfigDict | dict, + rpn_head: ConfigDict | dict, + roi_head: ConfigDict | dict, + train_cfg: ConfigDict | dict, + test_cfg: ConfigDict | dict, + data_preprocessor: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + **kwargs, + ) -> None: + super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg) + try: + self.backbone = MODELS.build(backbone) + except KeyError: + self.backbone = _build_pytorchcv_model(**backbone) + + if neck["type"] != FPN.__name__: + msg = f"neck type must be {FPN.__name__}, but got {neck['type']}" + raise ValueError(msg) + # pop out type for FPN + neck.pop("type") + self.neck = FPN(**neck) + + rpn_train_cfg = train_cfg["rpn"] + rpn_head_ = rpn_head.copy() + rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg["rpn"]) + rpn_head_num_classes = rpn_head_.get("num_classes", None) + if rpn_head_num_classes is None: + rpn_head_.update(num_classes=1) + elif rpn_head_num_classes != 1: + warnings.warn( + "The `num_classes` should be 1 in RPN, but get " + f"{rpn_head_num_classes}, please set " + "rpn_head.num_classes = 1 in your config file.", + stacklevel=2, + ) + rpn_head_.update(num_classes=1) + if rpn_head_["type"] != RPNHead.__name__: + msg = f"rpn_head type must be {RPNHead.__name__}, but got {rpn_head_['type']}" + raise ValueError(msg) + # pop out type for RPNHead + rpn_head_.pop("type") + self.rpn_head = RPNHead(**rpn_head_) + + # update train and test cfg here for now + rcnn_train_cfg = train_cfg["rcnn"] + roi_head.update(train_cfg=rcnn_train_cfg) + roi_head.update(test_cfg=test_cfg["rcnn"]) + if roi_head["type"] != CustomRoIHead.__name__: + msg = f"roi_head type must be {CustomRoIHead.__name__}, but got {roi_head['type']}" + raise ValueError(msg) + # pop out type for RoIHead + roi_head.pop("type") + self.roi_head = CustomRoIHead(**roi_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def _load_from_state_dict( + self, + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: list[str] | str, + unexpected_keys: list[str] | str, + error_msgs: list[str] | str, + ) -> None: + """Exchange bbox_head key to rpn_head key when loading single-stage weights into two-stage model.""" + bbox_head_prefix = prefix + ".bbox_head" if prefix else "bbox_head" + bbox_head_keys = [k for k in state_dict if k.startswith(bbox_head_prefix)] + rpn_head_prefix = prefix + ".rpn_head" if prefix else "rpn_head" + rpn_head_keys = [k for k in state_dict if k.startswith(rpn_head_prefix)] + if len(bbox_head_keys) != 0 and len(rpn_head_keys) == 0: + for bbox_head_key in bbox_head_keys: + rpn_head_key = rpn_head_prefix + bbox_head_key[len(bbox_head_prefix) :] + state_dict[rpn_head_key] = state_dict.pop(bbox_head_key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @property + def with_rpn(self) -> bool: + """bool: whether the detector has RPN.""" + return hasattr(self, "rpn_head") and self.rpn_head is not None + + def extract_feat(self, batch_inputs: Tensor) -> tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). + + Returns: + tuple[Tensor]: Multi-level features that may have + different resolutions. 
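+
+        Example:
+            Illustrative output shapes for a ``(1, 3, 800, 800)`` input with a
+            typical ResNet-50 + FPN Mask R-CNN configuration (an assumption,
+            not something fixed by this class)::
+
+                [(1, 256, 200, 200), (1, 256, 100, 100), (1, 256, 50, 50),
+                 (1, 256, 25, 25), (1, 256, 13, 13)]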
+ """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + def _forward(self, batch_inputs: Tensor, batch_data_samples: list[DetDataSample]) -> tuple: + """Network forward process. Usually includes backbone, neck and head forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): Each item contains + the meta information of each image and corresponding + annotations. + + Returns: + tuple: A tuple of features from ``rpn_head`` and ``roi_head`` + forward. + """ + results = () + x = self.extract_feat(batch_inputs) + + if self.with_rpn: + rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False) + else: + if batch_data_samples[0].get("proposals", None) is None: + msg = "No 'proposals' in data samples." + raise ValueError(msg) + rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples] + roi_outs = self.roi_head.forward(x, rpn_results_list, batch_data_samples) + return (*results, roi_outs) + + def loss(self, batch_inputs: Tensor, batch_data_samples: list[DetDataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + x = self.extract_feat(batch_inputs) + + losses = {} + + # RPN forward and loss + if self.with_rpn: + proposal_cfg = self.train_cfg.get("rpn_proposal", self.test_cfg["rpn"]) + rpn_data_samples = copy.deepcopy(batch_data_samples) + # set cat_id of gt_labels to 0 in RPN + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = torch.zeros_like(data_sample.gt_instances.labels) + + rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( + x, + rpn_data_samples, + proposal_cfg=proposal_cfg, + ) + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in list(keys): + if "loss" in key and "rpn" not in key: + rpn_losses[f"rpn_{key}"] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + if batch_data_samples[0].get("proposals", None) is None: + msg = "No 'proposals' in data samples." + raise ValueError(msg) + # use pre-defined proposals in InstanceData for the second stage + # to extract ROI features. + rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples] + + roi_losses = self.roi_head.loss(x, rpn_results_list, batch_data_samples) + losses.update(roi_losses) + + return losses + + def predict( + self, + batch_inputs: Tensor, + batch_data_samples: list[DetDataSample], + rescale: bool = True, + ) -> list[DetDataSample]: + """Predict results from a batch of inputs and data samples with post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Return the detection results of the + input images. The returns value is DetDataSample, + which usually contain 'pred_instances'. And the + ``pred_instances`` usually contains following keys. 
+ + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + if not self.with_bbox: + msg = "Bbox head is not implemented." + raise NotImplementedError(msg) + x = self.extract_feat(batch_inputs) + + # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get("proposals", None) is None: + rpn_results_list = self.rpn_head.predict(x, batch_data_samples, rescale=False) + else: + rpn_results_list = [data_sample.proposals for data_sample in batch_data_samples] + + results_list = self.roi_head.predict(x, rpn_results_list, batch_data_samples, rescale=rescale) + + return self.add_pred_to_datasample(batch_data_samples, results_list) + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER, mark + from mmdeploy.utils import is_dynamic_shape + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.detectors.two_stage.TwoStageDetector.extract_feat", + ) + def two_stage_detector__extract_feat(self: TwoStageDetector, img: Tensor) -> list[Tensor]: + """Rewrite `extract_feat` for default backend. + + This function uses the specific `extract_feat` function for the two + stage detector after adding marks. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + img (Tensor | List[Tensor]): Input image tensor(s). + + Returns: + list[Tensor]: Each item with shape (N, C, H, W) corresponds one + level of backbone and neck features. + """ + ctx = FUNCTION_REWRITER.get_context() + + @mark("extract_feat", inputs="img", outputs="feat") + def __extract_feat_impl(self: TwoStageDetector, img: Tensor) -> Callable: + return ctx.origin_func(self, img) + + return __extract_feat_impl(self, img) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.detectors.two_stage.TwoStageDetector.forward", + ) + def two_stage_detector__forward( + self: TwoStageDetector, + batch_inputs: torch.Tensor, + data_samples: list[DetDataSample], + mode: str = "tensor", + **kwargs, + ) -> ForwardResults: + """Rewrite `forward` for default backend. + + Support configured dynamic/static shape for model input and return + detection result as Tensor instead of numpy array. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + mode (str): export mode, not used. + + Returns: + tuple[Tensor]: Detection results of the + input images. + - dets (Tensor): Classification bboxes and scores. + Has a shape (num_instances, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + warnings.warn(f"{mode}, {kwargs} not used", stacklevel=2) + ctx = FUNCTION_REWRITER.get_context() + deploy_cfg = ctx.cfg + + # get origin input shape as tensor to support onnx dynamic shape + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + img_shape = torch._shape_as_tensor(batch_inputs)[2:] # noqa: SLF001 + if not is_dynamic_flag: + img_shape = [int(val) for val in img_shape] + + # set the metainfo + # note that we can not use `set_metainfo`, deepcopy would crash the + # onnx trace. 
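+        # `set_field` writes the metainfo entry in place (no deepcopy), so the traced
+        # graph keeps `img_shape` as a tensor for dynamic shape or a list of ints otherwise.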
+ for data_sample in data_samples: + data_sample.set_field(name="img_shape", value=img_shape, field_type="metainfo") + + x = self.extract_feat(batch_inputs) + + if data_samples[0].get("proposals", None) is None: + rpn_results_list = self.rpn_head.predict(x, data_samples, rescale=False) + else: + rpn_results_list = [data_sample.proposals for data_sample in data_samples] + + return self.roi_head.predict(x, rpn_results_list, data_samples, rescale=False) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/layers/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/layers/__init__.py new file mode 100644 index 00000000000..bee16f3fbc2 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/layers/__init__.py @@ -0,0 +1,18 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet Layers.""" + +from .bbox_nms import multiclass_nms_torch +from .res_layer import ResLayer +from .transformer import PatchEmbed, PatchMerging + +__all__ = [ + "multiclass_nms_torch", + "ResLayer", + "PatchEmbed", + "PatchMerging", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/layers/bbox_nms.py b/src/otx/algo/instance_segmentation/mmdet/models/layers/bbox_nms.py new file mode 100644 index 00000000000..229fe1dfb7a --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/layers/bbox_nms.py @@ -0,0 +1,111 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet NMS.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import Tensor + +from otx.algo.detection.ops.nms import batched_nms + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +def multiclass_nms_torch( + multi_bboxes: Tensor, + multi_scores: Tensor, + score_thr: float, + nms_cfg: ConfigDict | dict, + max_num: int = -1, + score_factors: Tensor | None = None, + return_inds: bool = False, + box_dim: int = 4, +) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor]: + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains + the arguments of nms operations. + max_num (int, optional): if there are more than max_num bboxes after + NMS, only top max_num will be kept. Default to -1. + score_factors (Tensor, optional): The factors multiplied to scores + before applying NMS. Default to None. + return_inds (bool, optional): Whether return the indices of kept + bboxes. Default to False. + box_dim (int): The dimension of boxes. Defaults to 4. + + Returns: + Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]: + (dets, labels, indices (optional)), tensors of shape (k, 5), + (k), and (k). Dets are boxes with scores. Labels are 0-based. 
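+
+    Example:
+        >>> # Illustrative sketch; the nms_cfg keys are assumed to follow the
+        >>> # batched_nms convention used elsewhere in OTX ("type", "iou_threshold").
+        >>> dets, labels = multiclass_nms_torch(
+        ...     multi_bboxes=torch.rand(100, 4),
+        ...     multi_scores=torch.rand(100, 4),
+        ...     score_thr=0.05,
+        ...     nms_cfg={"type": "nms", "iou_threshold": 0.5},
+        ... )  # doctest: +SKIP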
+ """ + num_classes = multi_scores.size(1) - 1 + # exclude background category + if multi_bboxes.shape[1] > box_dim: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, box_dim) + else: + bboxes = multi_bboxes[:, None].expand(multi_scores.size(0), num_classes, box_dim) + + scores = multi_scores[:, :-1] + + labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) + labels = labels.view(1, -1).expand_as(scores) + + bboxes = bboxes.reshape(-1, box_dim) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + # remove low scoring boxes + valid_mask = scores > score_thr + # multiply score_factor after threshold to preserve more bboxes, improve + # mAP by 1% for YOLOv3 + if score_factors is not None: + # expand the shape to match original shape of score + score_factors = score_factors.view(-1, 1).expand(multi_scores.size(0), num_classes) + score_factors = score_factors.reshape(-1) + scores = scores * score_factors + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + else: + # TensorRT NMS plugin has invalid output filled with -1 + # add dummy data to make detection output correct. + bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0) + scores = torch.cat([scores, scores.new_zeros(1)], dim=0) + labels = torch.cat([labels, labels.new_zeros(1)], dim=0) + + if bboxes.numel() == 0: + if torch.onnx.is_in_onnx_export(): + msg = "[ONNX Error] Can not record NMS as it has not been executed this time" + raise RuntimeError(msg) + dets = torch.cat([bboxes, scores[:, None]], -1) + if return_inds: + return dets, labels, inds + return dets, labels + + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + if return_inds: + return dets, labels[keep], inds[keep] + return dets, labels[keep] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/layers/res_layer.py b/src/otx/algo/instance_segmentation/mmdet/models/layers/res_layer.py new file mode 100644 index 00000000000..91590c6ae86 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/layers/res_layer.py @@ -0,0 +1,106 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet ResLayer.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mmengine.model import BaseModule, Sequential +from torch import nn + +from otx.algo.modules.conv import build_conv_layer +from otx.algo.modules.norm import build_norm_layer + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Defaults to 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False + conv_cfg (dict): dictionary to construct and config conv layer. + Defaults to None + norm_cfg (dict): dictionary to construct and config norm layer. 
+ Defaults to dict(type='BN') + downsample_first (bool): Downsample at the first block or last block. + False for Hourglass, True for ResNet. Defaults to True + """ + + def __init__( + self, + block: BaseModule, + inplanes: int, + planes: int, + num_blocks: int, + norm_cfg: dict, + stride: int = 1, + avg_down: bool = False, + conv_cfg: ConfigDict | dict | None = None, + downsample_first: bool = True, + **kwargs, + ) -> None: + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False, + ), + ) + downsample.extend( + [ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False, + ), + build_norm_layer(norm_cfg, planes * block.expansion)[1], + ], + ) + downsample = nn.Sequential(*downsample) + + layers = [] + if downsample_first: + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs, + ), + ) + inplanes = planes * block.expansion + layers.extend( + [ + block(inplanes=inplanes, planes=planes, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, **kwargs) + for _ in range(1, num_blocks) + ], + ) + + super().__init__(*layers) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/layers/transformer.py b/src/otx/algo/instance_segmentation/mmdet/models/layers/transformer.py new file mode 100644 index 00000000000..2cfc9886036 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/layers/transformer.py @@ -0,0 +1,324 @@ +"""MMDet Transformer layers.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +# Copyright (c) OpenMMLab. All rights reserved. + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Sequence + +import torch +import torch.nn.functional +from mmengine.model import BaseModule +from mmengine.utils import to_2tuple +from torch import nn + +from otx.algo.modules.conv import build_conv_layer +from otx.algo.modules.norm import build_norm_layer + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed). + + so that input can get fully covered by filter you specified. It support two modes "same" and "corner". + The "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around input. + The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". 
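+
+    Example:
+        >>> # Pad a 15x17 map so that a 16x16 kernel with stride 16 fully covers it.
+        >>> pad = AdaptivePadding(kernel_size=16, stride=16, padding="corner")
+        >>> pad(torch.rand(1, 1, 15, 17)).shape
+        torch.Size([1, 1, 16, 32])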
+ """ + + def __init__( + self, + kernel_size: int | tuple = 1, + stride: int | tuple | None = 1, + dilation: int | tuple = 1, + padding: str = "corner", + ) -> None: + super().__init__() + + if padding not in ("same", "corner"): + msg = f"padding mode only support 'same' and 'corner', but got {padding}" + raise ValueError(msg) + + self.padding = to_2tuple(padding) + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.dilation = to_2tuple(dilation) + + def get_pad_shape(self, input_shape: tuple[int, int]) -> tuple[int, int]: + """Get the padding shape.""" + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h: int = math.ceil(input_h / stride_h) + output_w: int = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function for AdaptivePadding.""" + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == "corner": + x = torch.nn.functional.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == "same": + x = torch.nn.functional.pad( + x, + [ + pad_w // 2, + pad_w - pad_w // 2, + pad_h // 2, + pad_h - pad_h // 2, + ], + ) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. 
+ """ + + def __init__( + self, + in_channels: int = 3, + embed_dims: int = 768, + conv_type: str = "Conv2d", + kernel_size: int = 16, + stride: int = 16, + padding: int | tuple | str = "corner", + dilation: int = 1, + bias: bool = True, + norm_cfg: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | None = None, + ) -> None: + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.adap_padding: nn.Module | None + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + ) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + {"type": conv_type}, + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]: + """Forward function for PatchEmbed. + + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple = 2, + stride: int | tuple | None = None, + padding: int | tuple | str = "corner", + dilation: int | tuple = 1, + bias: bool = False, + norm_cfg: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | None = None, + ) -> None: + super().__init__(init_cfg=init_cfg) + norm_cfg = norm_cfg if norm_cfg is not None else {"type": "LN"} + self.in_channels = in_channels + self.out_channels = out_channels + stride = stride if stride else kernel_size + + _kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.adap_padding: nn.Module | None + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=_kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + ) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold(kernel_size=_kernel_size, dilation=dilation, padding=padding, stride=stride) + + sample_dim = _kernel_size[0] * _kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x: torch.Tensor, input_size: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, tuple[int, int]]: + """Forward function for PatchMerging. + + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + batch_size, length, channels = x.shape + if not isinstance(input_size, Sequence): + msg = f"Expect input_size is `Sequence` but get {input_size}" + raise TypeError(msg) + + h, w = input_size + if h * w != length: + msg = "input feature has wrong size" + raise ValueError(msg) + + x = x.view(batch_size, h, w, channels).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + h, w = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = ( + h + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + w + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 + ) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/__init__.py new file mode 100644 index 00000000000..a8b2842517c --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/__init__.py @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet MaskHeads.""" +from .fcn_mask_head import FCNMaskHead + +__all__ = [ + "FCNMaskHead", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py new file mode 100644 index 00000000000..201977abc24 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/mask_heads/fcn_mask_head.py @@ -0,0 +1,546 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet FCNMaskHead.""" +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import numpy as np +import torch +import torch.nn.functional +from mmengine.model import BaseModule, ModuleList +from mmengine.registry import MODELS +from torch import Tensor, nn +from torch.nn.modules.utils import _pair + +from otx.algo.detection.deployment import is_mmdeploy_enabled +from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss +from otx.algo.detection.utils.structures import SamplingResult +from otx.algo.detection.utils.utils import empty_instances +from otx.algo.instance_segmentation.mmdet.structures.mask import mask_target +from otx.algo.modules.conv import build_conv_layer +from otx.algo.modules.conv_module import ConvModule + +BYTES_PER_FLOAT = 4 +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + from mmengine.structures import InstanceData + + +@MODELS.register_module() +class FCNMaskHead(BaseModule): + """FCNMaskHead.""" + + def __init__( + self, + num_convs: int = 4, + roi_feat_size: int = 14, + in_channels: int = 256, + conv_kernel_size: int = 3, + conv_out_channels: int = 256, + num_classes: int = 80, + class_agnostic: int = False, + conv_cfg: ConfigDict | dict | None = None, + norm_cfg: ConfigDict | dict | None = None, + loss_mask: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ) -> None: + if init_cfg is not None: + msg = "To prevent abnormal initialization behavior, init_cfg is not allowed to be set" + raise ValueError(msg) + + super().__init__(init_cfg=init_cfg) + self.num_convs = num_convs + # WARN: roi_feat_size is reserved and not used + self.roi_feat_size = _pair(roi_feat_size) + self.in_channels = in_channels + self.conv_kernel_size = conv_kernel_size + self.conv_out_channels = conv_out_channels + self.num_classes = num_classes + self.class_agnostic = class_agnostic + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.predictor_cfg = {"type": "Conv"} + self.loss_mask = MODELS.build(loss_mask) if loss_mask else CrossEntropyLoss(use_mask=True, loss_weight=1.0) + + self.convs = ModuleList() + for i in range(self.num_convs): + in_channels = self.in_channels if i == 0 else self.conv_out_channels + padding = (self.conv_kernel_size - 1) // 2 + self.convs.append( + ConvModule( + in_channels, + self.conv_out_channels, + self.conv_kernel_size, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + ), + ) + upsample_in_channels = self.conv_out_channels if self.num_convs > 0 else in_channels + + _scale_factor = 2 + upsample_cfg = { + "in_channels": upsample_in_channels, + "out_channels": self.conv_out_channels, + "kernel_size": _scale_factor, + "stride": 
_scale_factor, + } + self.upsample = nn.ConvTranspose2d(**upsample_cfg) + out_channels = 1 if self.class_agnostic else self.num_classes + logits_in_channel = self.conv_out_channels + self.conv_logits = build_conv_layer(self.predictor_cfg, logits_in_channel, out_channels, 1) + self.relu = nn.ReLU(inplace=True) + self.debug_imgs = None + + def init_weights(self) -> None: + """Initialize the weights.""" + super().init_weights() + for m in [self.upsample, self.conv_logits]: + if m is None: + continue + if hasattr(m, "weight") and hasattr(m, "bias"): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Forward features from the upstream network. + + Args: + x (Tensor): Extract mask RoI features. + + Returns: + Tensor: Predicted foreground masks. + """ + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + x = self.relu(x) + return self.conv_logits(x) + + def get_targets( + self, + sampling_results: list[SamplingResult], + batch_gt_instances: list[InstanceData], + rcnn_train_cfg: ConfigDict, + ) -> Tensor: + """Calculate the ground truth for all samples in a batch according to the sampling_results. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + Tensor: Mask target of each positive proposals in the image. + """ + pos_proposals = [res.pos_priors for res in sampling_results] + pos_assigned_gt_inds = [res.pos_assigned_gt_inds for res in sampling_results] + gt_masks = [res.masks for res in batch_gt_instances] + return mask_target(pos_proposals, pos_assigned_gt_inds, gt_masks, rcnn_train_cfg) + + def loss_and_target( + self, + mask_preds: Tensor, + sampling_results: list[SamplingResult], + batch_gt_instances: list[InstanceData], + rcnn_train_cfg: ConfigDict, + ) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (num_pos, num_classes, h, w). + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. 
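+
+            In this port the returned dict has the nested form
+            ``{"loss_mask": {"loss_mask": Tensor}, "mask_targets": Tensor}``.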
+ """ + mask_targets = self.get_targets( + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=rcnn_train_cfg, + ) + + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + + loss = {} + if mask_preds.size(0) == 0: + loss_mask = mask_preds.sum() + elif self.class_agnostic: + loss_mask = self.loss_mask(mask_preds, mask_targets, torch.zeros_like(pos_labels)) + else: + loss_mask = self.loss_mask(mask_preds, mask_targets, pos_labels) + loss["loss_mask"] = loss_mask + return {"loss_mask": loss, "mask_targets": mask_targets} + + def predict_by_feat( + self, + mask_preds: tuple[Tensor], + results_list: list[InstanceData], + batch_img_metas: list[dict], + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False, + ) -> list[InstanceData]: + """Transform a batch of output features extracted from the head into mask results. + + Args: + mask_preds (tuple[Tensor]): Tuple of predicted foreground masks, + each has shape (n, num_classes, h, w). + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + if len(mask_preds) != len(results_list) != len(batch_img_metas): + msg = "The number of inputs should be the same." + raise ValueError(msg) + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + results = results_list[img_id] + bboxes = results.bboxes + if bboxes.shape[0] == 0: + results_list[img_id] = empty_instances( + [img_meta], + bboxes.device, + task_type="mask", + instance_results=[results], + mask_thr_binary=rcnn_test_cfg.mask_thr_binary, + )[0] + else: + im_mask = self._predict_by_feat_single( + mask_preds=mask_preds[img_id], + bboxes=bboxes, + labels=results.labels, + img_meta=img_meta, + rcnn_test_cfg=rcnn_test_cfg, + rescale=rescale, + activate_map=activate_map, + ) + results.masks = im_mask + return results_list + + def _predict_by_feat_single( + self, + mask_preds: Tensor, + bboxes: Tensor, + labels: Tensor, + img_meta: dict, + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False, + ) -> Tensor: + """Get segmentation masks from mask_preds and bboxes. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (n, num_classes, h, w). + bboxes (Tensor): Predicted bboxes, has shape (n, 4) + labels (Tensor): Labels of bboxes, has shape (n, ) + img_meta (dict): image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. 
+ + Returns: + Tensor: Encoded masks, has shape (n, img_w, img_h) + """ + scale_factor = bboxes.new_tensor(img_meta["scale_factor"]).repeat((1, 2)) + img_h, img_w = img_meta["ori_shape"][:2] + device = bboxes.device + + mask_preds = mask_preds.sigmoid() if not activate_map else bboxes.new_tensor(mask_preds) + + if rescale: # in-placed rescale the bboxes + bboxes /= scale_factor + else: + w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1] + img_h = np.round(img_h * h_scale.item()).astype(np.int32) + img_w = np.round(img_w * w_scale.item()).astype(np.int32) + + num_preds = len(mask_preds) + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == "cpu": + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs minimal number of + # operations. + num_chunks = num_preds + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + # the types of img_w and img_h are np.int32, + # when the image resolution is large, + # the calculation of num_chunks will overflow. + # so we need to change the types of img_w and img_h to int. + # See https://github.com/open-mmlab/mmdetection/pull/5191 + num_chunks = int(np.ceil(num_preds * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT)) + if num_chunks > num_preds: + msg = "Default GPU_MEM_LIMIT is too small; try increasing it" + raise ValueError(msg) + chunks = torch.chunk(torch.arange(num_preds, device=device), num_chunks) + + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + num_preds, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8, + ) + + if not self.class_agnostic: + mask_preds = mask_preds[range(num_preds), labels][:, None] + + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_preds[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == "cpu", + ) + + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + im_mask[(inds, *spatial_inds)] = masks_chunk + return im_mask + + +def _do_paste_mask(masks: Tensor, boxes: Tensor, img_h: int, img_w: int, skip_empty: bool = True) -> tuple: + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. 
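+    # On CPU (skip_empty=True) the pasting region is first clipped to the union of all
+    # boxes, so grid_sample only touches the area that can actually contain a mask.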
+ device = masks.device + if skip_empty: + box_values, _ = boxes.min(dim=0) + x0_int, y0_int = torch.clamp(box_values.floor()[:2] - 1, min=0).to(torch.int32) + x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + num_preds = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + # IsInf op is not supported with ONNX<=1.7.0 + if not torch.onnx.is_in_onnx_export(): + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(num_preds, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(num_preds, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = torch.nn.functional.grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + return img_masks[:, 0], () + + +if is_mmdeploy_enabled(): + from mmdeploy.codebase.mmdet.deploy import get_post_processing_params + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat", + ) + def fcn_mask_head__predict_by_feat( + self: FCNMaskHead, + mask_preds: Tensor, + results_list: list[Tensor], + batch_img_metas: list[dict], + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False, + ) -> Tensor: + """Transform a batch of output features extracted from the head into mask results. + + Args: + mask_preds (tuple[Tensor]): Tuple of predicted foreground masks, + each has shape (n, num_classes, h, w). + results_list (list[Tensor]): Detection results of + each image. + batch_img_metas (list[dict]): List of image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. + + Returns: + list[Tensor]: Detection results of each image + after the post process. Each item usually contains following keys. + + - dets (Tensor): Classification scores, has a shape + (num_instance, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - masks (Tensor): Has a shape (num_instances, H, W). 
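+
+            When ``export_postprocess_mask`` is disabled in the deploy config, the
+            sigmoid-activated per-RoI masks of shape (num_instances, 1, h, w) are
+            returned as-is and pasting to full image size is left to the runtime.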
+ """ + warnings.warn(f"rescale: {rescale} is not supported in deploy mode", stacklevel=2) + warnings.warn(f"activate_map: {activate_map} is not supported in deploy mode", stacklevel=2) + + ctx = FUNCTION_REWRITER.get_context() + ori_shape = batch_img_metas[0]["img_shape"] + dets, det_labels = results_list + dets = dets.view(-1, 5) + det_labels = det_labels.view(-1) + mask_preds = mask_preds.sigmoid() + bboxes = dets[:, :4] + labels = det_labels + threshold = rcnn_test_cfg.mask_thr_binary + if not self.class_agnostic: + box_inds = torch.arange(mask_preds.shape[0], device=bboxes.device) + mask_pred = mask_preds[box_inds, labels][:, None] + + # grid sample is not supported by most engine + # so we add a flag to disable it. + mmdet_params = get_post_processing_params(ctx.cfg) + export_postprocess_mask = mmdet_params.get("export_postprocess_mask", False) + if not export_postprocess_mask: + return mask_pred + + masks, _ = _do_paste_mask_ops(mask_pred, bboxes, ori_shape[0], ori_shape[1], skip_empty=False) + if threshold >= 0: + masks = (masks >= threshold).to(dtype=torch.bool) + return masks + + def _do_paste_mask_ops( + masks: Tensor, + boxes: Tensor, + img_h: int, + img_w: int, + skip_empty: bool = True, + ) -> Tensor: + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. 
+ device = masks.device + if skip_empty: + box_values, _ = boxes.min(dim=0) + x0_int, y0_int = torch.clamp(box_values.floor()[:2] - 1, min=0).to(dtype=torch.int32) + x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + num_preds = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + gx = img_x[:, None, :].expand(num_preds, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(num_preds, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = torch.nn.functional.grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + return img_masks[:, 0], () diff --git a/src/otx/algo/instance_segmentation/mmdet/models/necks/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/necks/__init__.py new file mode 100644 index 00000000000..136bfe9b16b --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/necks/__init__.py @@ -0,0 +1,12 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMdetection necks.""" +from .fpn import FPN + +__all__ = [ + "FPN", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/necks/fpn.py b/src/otx/algo/instance_segmentation/mmdet/models/necks/fpn.py new file mode 100644 index 00000000000..f37f8d71563 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/necks/fpn.py @@ -0,0 +1,179 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet Feature Pyramid Network.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch.nn.functional +from mmengine.model import BaseModule +from mmengine.registry import MODELS +from torch import Tensor, nn + +from otx.algo.modules.conv_module import ConvModule + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +@MODELS.register_module() +class FPN(BaseModule): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Defaults to 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Defaults to -1, which means the + last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Defaults to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. 
+ Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Defaults to False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Defaults to False. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. Defaults to None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Defaults to None. + act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + activation layer in ConvModule. Defaults to None. + upsample_cfg (:obj:`ConfigDict` or dict, optional): Config dict + for interpolate layer. Defaults to dict(mode='nearest'). + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict]): Initialization config dict. + """ + + def __init__( + self, + in_channels: list[int], + out_channels: int, + num_outs: int, + start_level: int = 0, + end_level: int = -1, + relu_before_extra_convs: bool = False, + no_norm_on_lateral: bool = False, + conv_cfg: ConfigDict | dict | None = None, + norm_cfg: ConfigDict | dict | None = None, + act_cfg: ConfigDict | dict | None = None, + upsample_cfg: dict | None = None, + init_cfg: dict | None = None, + ) -> None: + init_cfg = {"type": "Xavier", "layer": "Conv2d", "distribution": "uniform"} if init_cfg is None else init_cfg + super().__init__(init_cfg=init_cfg) + if not isinstance(in_channels, list): + msg = f"in_channels must be a list, but got {type(in_channels)}" + raise AssertionError(msg) # noqa: TRY004 + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = {"mode": "nearest"} if upsample_cfg is None else upsample_cfg + + if end_level in (-1, self.num_ins - 1): + self.backbone_end_level = self.num_ins + if num_outs < self.num_ins - start_level: + msg = "num_outs should not be less than the number of output levels" + raise ValueError(msg) + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + if end_level >= self.num_ins: + msg = "end_level must be less than len(in_channels)" + raise ValueError(msg) + if num_outs != end_level - start_level + 1: + msg = "num_outs must be equal to end_level - start_level + 1" + raise ValueError(msg) + self.start_level = start_level + self.end_level = end_level + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False, + ) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False, + ) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + def forward(self, inputs: tuple[Tensor]) -> tuple: + """Forward function. + + Args: + inputs (tuple[Tensor]): Features from the upstream network, each + is a 4D-tensor. + + Returns: + tuple: Feature maps, each is a 4D-tensor. 
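+
+        Example:
+            >>> # Illustrative: 4 backbone levels (strides 4/8/16/32) of a 256x256 input,
+            >>> # with one extra max-pooled level appended to reach num_outs=5.
+            >>> fpn = FPN(in_channels=[16, 32, 64, 128], out_channels=8, num_outs=5)
+            >>> feats = [torch.rand(1, c, 256 // s, 256 // s)
+            ...          for c, s in zip([16, 32, 64, 128], [4, 8, 16, 32])]
+            >>> [tuple(o.shape[-2:]) for o in fpn(feats)]
+            [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]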
+ """ + if len(inputs) != len(self.in_channels): + msg = f"len(inputs) is not equal to len(in_channels): {len(inputs)} != {len(self.in_channels)}" + raise ValueError(msg) + + # build laterals + laterals = [lateral_conv(inputs[i + self.start_level]) for i, lateral_conv in enumerate(self.lateral_convs)] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if "scale_factor" in self.upsample_cfg: + # fix runtime error of "+=" inplace operation in PyTorch 1.10 + laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate(laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + torch.nn.functional.interpolate( + laterals[i], + size=prev_shape, + **self.upsample_cfg, + ) + + # build outputs + # part 1: from original levels + outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + for _ in range(self.num_outs - used_backbone_levels): + outs.append(torch.nn.functional.max_pool2d(outs[-1], 1, stride=2)) + return tuple(outs) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/__init__.py new file mode 100644 index 00000000000..071196b7984 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/__init__.py @@ -0,0 +1,11 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet RoI Extractors.""" + +from .single_level_roi_extractor import SingleRoIExtractor + +__all__ = ["SingleRoIExtractor"] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/base_roi_extractor.py b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/base_roi_extractor.py new file mode 100644 index 00000000000..dd1bd94658f --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/base_roi_extractor.py @@ -0,0 +1,112 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet RoI Extractors.""" +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING + +import torch + +# TODO(Eugene): replace mmcv.sigmoid_focal_loss with torchvision +# https://github.com/openvinotoolkit/training_extensions/pull/3281 +from mmcv.ops import RoIAlign +from mmengine.model import BaseModule +from torch import Tensor, nn + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +class BaseRoIExtractor(BaseModule, metaclass=ABCMeta): + """Base class for RoI extractor. + + Args: + roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and + arguments. + out_channels (int): Output channels of RoI layers. + featmap_strides (list[int]): Strides of input feature maps. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
+ """ + + def __init__( + self, + roi_layer: ConfigDict | dict, + out_channels: int, + featmap_strides: list[int], + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ) -> None: + super().__init__(init_cfg=init_cfg) + self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) + self.out_channels = out_channels + self.featmap_strides = featmap_strides + + @property + def num_inputs(self) -> int: + """int: Number of input feature maps.""" + return len(self.featmap_strides) + + def build_roi_layers(self, layer_cfg: ConfigDict | dict, featmap_strides: list[int]) -> nn.ModuleList: + """Build RoI operator to extract feature from each level feature map. + + Args: + layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and + config RoI layer operation. Options are modules under + ``mmcv/ops`` such as ``RoIAlign``. + featmap_strides (list[int]): The stride of input feature map w.r.t + to the original image size, which would be used to scale RoI + coordinate (original image coordinate system) to feature + coordinate system. + + Returns: + :obj:`nn.ModuleList`: The RoI extractor modules for each level + feature map. + """ + cfg = layer_cfg.copy() + layer_type = cfg.pop("type") + if layer_type != RoIAlign.__name__: + msg = f"Unsupported RoI layer type {layer_type}" + raise ValueError(msg) + return nn.ModuleList([RoIAlign(spatial_scale=1 / s, **cfg) for s in featmap_strides]) + + def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor: + """Scale RoI coordinates by scale factor. + + Args: + rois (Tensor): RoI (Region of Interest), shape (n, 5) + scale_factor (float): Scale factor that RoI will be multiplied by. + + Returns: + Tensor: Scaled RoI. + """ + cx = (rois[:, 1] + rois[:, 3]) * 0.5 + cy = (rois[:, 2] + rois[:, 4]) * 0.5 + w = rois[:, 3] - rois[:, 1] + h = rois[:, 4] - rois[:, 2] + new_w = w * scale_factor + new_h = h * scale_factor + x1 = cx - new_w * 0.5 + x2 = cx + new_w * 0.5 + y1 = cy - new_h * 0.5 + y2 = cy + new_h * 0.5 + return torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) + + @abstractmethod + def forward(self, feats: tuple[Tensor], rois: Tensor, roi_scale_factor: float | None = None) -> Tensor: + """Extractor ROI feats. + + Args: + feats (Tuple[Tensor]): Multi-scale features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + roi_scale_factor (Optional[float]): RoI scale factor. + Defaults to None. + + Returns: + Tensor: RoI feature. + """ diff --git a/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py new file mode 100644 index 00000000000..1f6e5bbcca1 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/roi_extractors/single_level_roi_extractor.py @@ -0,0 +1,215 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMdet Single RoI Extractor.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from mmengine.registry import MODELS +from torch import Tensor + +from otx.algo.detection.deployment import is_mmdeploy_enabled + +from .base_roi_extractor import BaseRoIExtractor + +if TYPE_CHECKING: + from mmengine.config import ConfigDict + + +# ruff: noqa: ARG004 + + +@MODELS.register_module() +class SingleRoIExtractor(BaseRoIExtractor): + """Extract RoI features from a single level feature map. + + If there are multiple input feature levels, each RoI is mapped to a level + according to its scale. The mapping rule is proposed in + `FPN `_. + + Args: + roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and + arguments. + out_channels (int): Output channels of RoI layers. + featmap_strides (List[int]): Strides of input feature maps. + finest_scale (int): Scale threshold of mapping to level 0. + Defaults to 56. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. + """ + + def __init__( + self, + roi_layer: ConfigDict | dict, + out_channels: int, + featmap_strides: list[int], + finest_scale: int = 56, + init_cfg: ConfigDict | dict | list[ConfigDict | dict] | None = None, + ) -> None: + super().__init__( + roi_layer=roi_layer, + out_channels=out_channels, + featmap_strides=featmap_strides, + init_cfg=init_cfg, + ) + self.finest_scale = finest_scale + + def map_roi_levels(self, rois: Tensor, num_levels: int) -> Tensor: + """Map rois to corresponding feature levels by scales. + + - scale < finest_scale * 2: level 0 + - finest_scale * 2 <= scale < finest_scale * 4: level 1 + - finest_scale * 4 <= scale < finest_scale * 8: level 2 + - scale >= finest_scale * 8: level 3 + + Args: + rois (Tensor): Input RoIs, shape (k, 5). + num_levels (int): Total level number. + + Returns: + Tensor: Level index (0-based) of each RoI, shape (k, ) + """ + scale = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])) + target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6)) + return target_lvls.clamp(min=0, max=num_levels - 1).long() + + def forward(self, feats: tuple[Tensor], rois: Tensor, roi_scale_factor: float | None = None) -> Tensor: + """Extractor ROI feats. + + Args: + feats (Tuple[Tensor]): Multi-scale features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + roi_scale_factor (Optional[float]): RoI scale factor. + Defaults to None. + + Returns: + Tensor: RoI feature. 
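+
+        Example:
+            >>> # Illustrative shapes only; assumes ``extractor`` was built with a
+            >>> # 7x7 ``RoIAlign`` layer and ``out_channels=256``.
+            >>> feats = tuple(torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(4))
+            >>> rois = torch.tensor([[0.0, 8.0, 8.0, 72.0, 72.0]])
+            >>> extractor(feats, rois).shape
+            torch.Size([1, 256, 7, 7])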
+ """ + # convert fp32 to fp16 when amp is on + rois = rois.type_as(feats[0]) + out_size = self.roi_layers[0].output_size + num_levels = len(feats) + roi_feats = feats[0].new_zeros(rois.size(0), self.out_channels, *out_size) + + if num_levels == 1: + if len(rois) == 0: + return roi_feats + return self.roi_layers[0](feats[0], rois) + + target_lvls = self.map_roi_levels(rois, num_levels) + + if roi_scale_factor is not None: + rois = self.roi_rescale(rois, roi_scale_factor) + + for i in range(num_levels): + mask = target_lvls == i + inds = mask.nonzero(as_tuple=False).squeeze(1) + if inds.numel() > 0: + rois_ = rois[inds] + roi_feats_t = self.roi_layers[i](feats[i], rois_) + roi_feats[inds] = roi_feats_t + else: + # Sometimes some pyramid levels will not be used for RoI + # feature extraction and this will cause an incomplete + # computation graph in one GPU, which is different from those + # in other GPUs and will cause a hanging error. + # Therefore, we add it to ensure each feature pyramid is + # included in the computation graph to avoid runtime bugs. + roi_feats += sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + feats[i].sum() * 0.0 + return roi_feats + + +if is_mmdeploy_enabled(): + from mmdeploy.core.rewriters import FUNCTION_REWRITER + from torch import Graph + from torch.autograd import Function + + class SingleRoIExtractorOpenVINO(Function): + """This class adds support for ExperimentalDetectronROIFeatureExtractor when exporting to OpenVINO. + + The `forward` method returns the original output, which is calculated in + advance and added to the SingleRoIExtractorOpenVINO class. In addition, the + list of arguments is changed here to be more suitable for + ExperimentalDetectronROIFeatureExtractor. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def forward( + g: Graph, + output_size: int, + featmap_strides: int, + sample_num: int, + rois: torch.Value, + *feats: tuple[torch.Value], + ) -> Tensor: + """Run forward.""" + return SingleRoIExtractorOpenVINO.origin_output + + @staticmethod + def symbolic( + g: Graph, + output_size: int, + featmap_strides: list[int], + sample_num: int, + rois: torch.Value, + *feats: tuple[torch.Value], + ) -> Graph: + """Symbolic function for creating onnx op.""" + from torch.onnx.symbolic_opset10 import _slice + + rois = _slice(g, rois, axes=[1], starts=[1], ends=[5]) + domain = "org.openvinotoolkit" + op_name = "ExperimentalDetectronROIFeatureExtractor" + return g.op( + f"{domain}::{op_name}", + rois, + *feats, + output_size_i=output_size, + pyramid_scales_i=featmap_strides, + sampling_ratio_i=sample_num, + image_id_i=0, + distribute_rois_between_levels_i=1, + preserve_rois_order_i=0, + aligned_i=1, + outputs=1, + ) + + @FUNCTION_REWRITER.register_rewriter( + "otx.algo.instance_segmentation.mmdet.models.roi_extractors." + "single_level_roi_extractor.SingleRoIExtractor.forward", + backend="openvino", + ) + def single_roi_extractor__forward__openvino( + self: SingleRoIExtractor, + feats: tuple[Tensor], + rois: Tensor, + roi_scale_factor: float | None = None, + ) -> Tensor: + """Replaces SingleRoIExtractor with SingleRoIExtractorOpenVINO when exporting to OpenVINO. + + This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO. + """ + ctx = FUNCTION_REWRITER.get_context() + + # Adding original output to SingleRoIExtractorOpenVINO. 
+ state = torch._C._get_tracing_state() # noqa: SLF001 + origin_output = ctx.origin_func(self, feats, rois, roi_scale_factor) + SingleRoIExtractorOpenVINO.origin_output = origin_output + torch._C._set_tracing_state(state) # noqa: SLF001 + + output_size = self.roi_layers[0].output_size[0] + featmap_strides = self.featmap_strides + sample_num = self.roi_layers[0].sampling_ratio + + args = (output_size, featmap_strides, sample_num, rois, *feats) + return SingleRoIExtractorOpenVINO.apply(*args) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/samplers/__init__.py b/src/otx/algo/instance_segmentation/mmdet/models/samplers/__init__.py new file mode 100644 index 00000000000..f0a6102ed11 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/samplers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet samplers.""" + +from .random_sampler import RandomSampler + +__all__ = [ + "RandomSampler", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/models/samplers/random_sampler.py b/src/otx/algo/instance_segmentation/mmdet/models/samplers/random_sampler.py new file mode 100644 index 00000000000..0f4a4c41607 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/samplers/random_sampler.py @@ -0,0 +1,171 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMdet Random sampler.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from mmengine.registry import TASK_UTILS +from torch import Tensor + +from otx.algo.detection.utils.structures import AssignResult, SamplingResult + +if TYPE_CHECKING: + from mmengine.structures import InstanceData + from numpy import ndarray + + +@TASK_UTILS.register_module() +class RandomSampler: + """Random sampler. + + Args: + num (int): Number of samples + pos_fraction (float): Fraction of positive samples + neg_pos_up (int): Upper bound number of negative and + positive samples. Defaults to -1. + add_gt_as_proposals (bool): Whether to add ground truth + boxes as proposals. Defaults to True. + """ + + def __init__(self, num: int, pos_fraction: float, neg_pos_ub: int = -1, add_gt_as_proposals: bool = True, **kwargs): + from otx.algo.instance_segmentation.mmdet.models.utils.util_random import ensure_rng + + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + self.rng = ensure_rng(kwargs.get("rng", None)) + + def random_choice(self, gallery: Tensor | ndarray | list, num: int) -> Tensor | ndarray: + """Random select some elements from the gallery. + + If `gallery` is a Tensor, the returned indices will be a Tensor; + If `gallery` is a ndarray or list, the returned indices will be a + ndarray. + + Args: + gallery (Tensor | ndarray | list): indices pool. + num (int): expected sample num. + + Returns: + Tensor or ndarray: sampled indices. 
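+
+        Example:
+            >>> # Illustrative only; the sampler settings below are hypothetical.
+            >>> sampler = RandomSampler(num=256, pos_fraction=0.5)
+            >>> sampler.random_choice(torch.arange(10), 3).shape
+            torch.Size([3])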
+ """ + if len(gallery) < num: + msg = f"Cannot sample {num} elements from a set of size {len(gallery)}" + raise ValueError(msg) + + is_tensor = isinstance(gallery, torch.Tensor) + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + _gallery: Tensor = torch.tensor(gallery, dtype=torch.long, device=device) if not is_tensor else gallery + perm = torch.randperm(_gallery.numel())[:num].to(device=_gallery.device) + rand_inds = _gallery[perm] + if not is_tensor: + rand_inds = rand_inds.cpu().numpy() + return rand_inds + + def _sample_pos(self, assign_result: AssignResult, num_expected: int, **kwargs: dict) -> Tensor | ndarray: + """Randomly sample some positive samples. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + num_expected (int): The number of expected positive samples + + Returns: + Tensor or ndarray: sampled indices. + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False) + if pos_inds.numel() != 0: + pos_inds = pos_inds.squeeze(1) + if pos_inds.numel() <= num_expected: + return pos_inds + return self.random_choice(pos_inds, num_expected) + + def _sample_neg(self, assign_result: AssignResult, num_expected: int, **kwargs: dict) -> Tensor | ndarray: + """Randomly sample some negative samples. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + num_expected (int): The number of expected positive samples + + Returns: + Tensor or ndarray: sampled indices. + """ + neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False) + if neg_inds.numel() != 0: + neg_inds = neg_inds.squeeze(1) + if len(neg_inds) <= num_expected: + return neg_inds + return self.random_choice(neg_inds, num_expected) + + def sample( + self, + assign_result: AssignResult, + pred_instances: InstanceData, + gt_instances: InstanceData, + **kwargs, + ) -> SamplingResult: + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Assigning results. + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``priors``, and the priors can + be anchors or points, or the bboxes predicted by the + previous stage, has shape (n, 4). The bboxes predicted by + the current model or stage will be named ``bboxes``, + ``labels``, and ``scores``, the same as the ``InstanceData`` + in other places. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``bboxes``, with shape (k, 4), + and ``labels``, with shape (k, ). + + Returns: + :obj:`SamplingResult`: Sampling result. + """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + gt_labels = gt_instances.labels + if len(priors.shape) < 2: + priors = priors[None, :] + + gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + priors = torch.cat([gt_bboxes, priors], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = priors.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=priors, **kwargs) # noqa: SLF001 + # We found that sampled indices have duplicated items occasionally. 
+ # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=priors, **kwargs) # noqa: SLF001 + neg_inds = neg_inds.unique() + + return SamplingResult( + pos_inds=pos_inds, + neg_inds=neg_inds, + priors=priors, + gt_bboxes=gt_bboxes, + assign_result=assign_result, + gt_flags=gt_flags, + ) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/utils/util_random.py b/src/otx/algo/instance_segmentation/mmdet/models/utils/util_random.py new file mode 100644 index 00000000000..b76d452764a --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/models/utils/util_random.py @@ -0,0 +1,37 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ +"""MMDet Utility functions for random number generation.""" +from __future__ import annotations + +import numpy as np + + +def ensure_rng(rng: int | np.random.RandomState | None = None) -> np.random.RandomState: + """Coerces input into a random number generator. + + If the input is None, then a global random state is returned. + + If the input is a numeric value, then that is used as a seed to construct a + random state. Otherwise the input is returned as-is. + + Adapted from [1]_. + + Args: + rng (int | numpy.random.RandomState | None): + if None, then defaults to the global rng. Otherwise this can be an + integer or a RandomState class + Returns: + (numpy.random.RandomState) : rng - + a numpy random number generator + + References: + .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501 + """ + if rng is None: + return np.random.mtrand._rand # noqa: SLF001 + if isinstance(rng, int): + return np.random.RandomState(rng) + return rng diff --git a/src/otx/algo/instance_segmentation/mmdet/structures/bbox/__init__.py b/src/otx/algo/instance_segmentation/mmdet/structures/bbox/__init__.py new file mode 100644 index 00000000000..987dacfc808 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/structures/bbox/__init__.py @@ -0,0 +1,20 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet bbox structures.""" +from .transforms import ( + bbox2roi, + empty_box_as, + get_box_wh, + scale_boxes, +) + +__all__ = [ + "bbox2roi", + "empty_box_as", + "get_box_wh", + "scale_boxes", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/structures/bbox/transforms.py b/src/otx/algo/instance_segmentation/mmdet/structures/bbox/transforms.py new file mode 100644 index 00000000000..73a35038b72 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/structures/bbox/transforms.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet bbox transforms.""" +from __future__ import annotations + +import torch +from torch import Tensor + + +def bbox2roi(bbox_list: list[Tensor]) -> Tensor: + """Convert a list of bboxes to roi format. + + Args: + bbox_list (List[Union[Tensor, :obj:`BaseBoxes`]): a list of bboxes + corresponding to a batch of images. + + Returns: + Tensor: shape (n, box_dim + 1), where ``box_dim`` depends on the + different box types. For example, If the box type in ``bbox_list`` + is HorizontalBoxes, the output shape is (n, 5). Each row of data + indicates [batch_ind, x1, y1, x2, y2]. + """ + rois_list = [] + for img_id, bboxes in enumerate(bbox_list): + img_inds = bboxes.new_full((bboxes.size(0), 1), img_id) + rois = torch.cat([img_inds, bboxes], dim=-1) + rois_list.append(rois) + return torch.cat(rois_list, 0) + + +def scale_boxes(boxes: Tensor, scale_factor: tuple[float, float]) -> Tensor: + """Scale boxes with type of tensor or box type. + + Args: + boxes (Tensor or :obj:`BaseBoxes`): boxes need to be scaled. Its type + can be a tensor or a box type. + scale_factor (Tuple[float, float]): factors for scaling boxes. + The length should be 2. + + Returns: + Union[Tensor, :obj:`BaseBoxes`]: Scaled boxes. + """ + # Tensor boxes will be treated as horizontal boxes + repeat_num = int(boxes.size(-1) / 2) + scale_factor = boxes.new_tensor(scale_factor).repeat((1, repeat_num)) + return boxes * scale_factor + + +def get_box_wh(boxes: Tensor) -> tuple[Tensor, Tensor]: + """Get the width and height of boxes with type of tensor or box type. + + Args: + boxes (Tensor or :obj:`BaseBoxes`): boxes with type of tensor + or box type. + + Returns: + Tuple[Tensor, Tensor]: the width and height of boxes. + """ + # Tensor boxes will be treated as horizontal boxes by defaults + w = boxes[:, 2] - boxes[:, 0] + h = boxes[:, 3] - boxes[:, 1] + return w, h + + +def empty_box_as(boxes: Tensor) -> Tensor: + """Generate empty box according to input ``boxes` type and device. + + Args: + boxes (Tensor or :obj:`BaseBoxes`): boxes with type of tensor + or box type. + + Returns: + Union[Tensor, BaseBoxes]: Generated empty box. + """ + return boxes.new_zeros(0, 4) diff --git a/src/otx/algo/instance_segmentation/mmdet/structures/mask/__init__.py b/src/otx/algo/instance_segmentation/mmdet/structures/mask/__init__.py new file mode 100644 index 00000000000..2f8a8286252 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/structures/mask/__init__.py @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +"""MMDet mask structures.""" +from .mask_target import mask_target + +__all__ = [ + "mask_target", +] diff --git a/src/otx/algo/instance_segmentation/mmdet/structures/mask/mask_target.py b/src/otx/algo/instance_segmentation/mmdet/structures/mask/mask_target.py new file mode 100644 index 00000000000..f7ec5ff2b93 --- /dev/null +++ b/src/otx/algo/instance_segmentation/mmdet/structures/mask/mask_target.py @@ -0,0 +1,89 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. 
+# Please refer to https://github.com/open-mmlab/mmdetection/ +"""MMDet Mask Structure.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +if TYPE_CHECKING: + from mmdet.structures.mask import BitmapMasks + + +def mask_target( + pos_proposals_list: list[torch.Tensor], + pos_assigned_gt_inds_list: list[torch.Tensor], + gt_masks_list: list[BitmapMasks], + cfg: dict, +) -> torch.Tensor: + """Compute mask target for positive proposals in multiple images. + + Args: + pos_proposals_list (list[Tensor]): Positive proposals in multiple + images, each has shape (num_pos, 4). + pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each + positive proposals, each has shape (num_pos,). + gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of + each image. + cfg (dict): Config dict that specifies the mask size. + + Returns: + Tensor: Mask target of each image, has shape (num_pos, w, h). + """ + cfg_list = [cfg for _ in range(len(pos_proposals_list))] + mask_targets = map(mask_target_single, pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, cfg_list) + _mask_targets = list(mask_targets) + if len(_mask_targets) > 0: + _mask_targets = torch.cat(_mask_targets) + return _mask_targets + + +def mask_target_single( + pos_proposals: torch.Tensor, + pos_assigned_gt_inds: torch.Tensor, + gt_masks: BitmapMasks, + cfg: dict, +) -> torch.Tensor | np.ndarray: + """Compute mask target for each positive proposal in the image. + + Args: + pos_proposals (Tensor): Positive proposals. + pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals. + gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap + or Polygon. + cfg (dict): Config dict that indicate the mask size. + + Returns: + Tensor: Mask target of each positive proposals in the image. + """ + device = pos_proposals.device + mask_size = _pair(cfg["mask_size"]) + binarize = not cfg.get("soft_mask_target", False) + num_pos = pos_proposals.size(0) + if num_pos > 0: + proposals_np = pos_proposals.cpu().numpy() + maxh, maxw = gt_masks.height, gt_masks.width + proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw) + proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh) + pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() + + mask_targets = gt_masks.crop_and_resize( + proposals_np, + mask_size, + device=device, + inds=pos_assigned_gt_inds, + binarize=binarize, + ).to_ndarray() + + mask_targets = torch.from_numpy(mask_targets).float().to(device) + else: + mask_targets = pos_proposals.new_zeros((0, *mask_size)) + + return mask_targets diff --git a/src/otx/algo/modules/activation.py b/src/otx/algo/modules/activation.py index 23aacd0b77a..7d507650776 100644 --- a/src/otx/algo/modules/activation.py +++ b/src/otx/algo/modules/activation.py @@ -4,6 +4,8 @@ """This implementation replaces the functionality of mmcv.cnn.bricks.activation.build_activation_layer.""" +import copy + import torch from torch import nn @@ -59,7 +61,8 @@ def build_activation_layer(cfg: dict) -> nn.Module: Returns: nn.Module: Created activation layer. 
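+
+    Example:
+        A minimal sketch, assuming ``"ReLU"`` is registered in ``ACTIVATION_DICT``:
+
+        >>> cfg = {"type": "ReLU", "inplace": True}
+        >>> layer = build_activation_layer(cfg)
+        >>> isinstance(layer, nn.ReLU)
+        True
+        >>> "type" in cfg  # cfg is deep-copied first, so the caller's dict is not mutated
+        True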
""" - activation_type = cfg.pop("type", None) + _cfg = copy.deepcopy(cfg) + activation_type = _cfg.pop("type", None) if activation_type is None: msg = "The cfg dict must contain the key 'type'" raise KeyError(msg) @@ -67,4 +70,4 @@ def build_activation_layer(cfg: dict) -> nn.Module: msg = f"Cannot find {activation_type} in {ACTIVATION_DICT.keys()}" raise KeyError(msg) - return ACTIVATION_DICT[activation_type](**cfg) + return ACTIVATION_DICT[activation_type](**_cfg) diff --git a/src/otx/algo/modules/conv_module.py b/src/otx/algo/modules/conv_module.py index b56648c1391..cb1afc65060 100644 --- a/src/otx/algo/modules/conv_module.py +++ b/src/otx/algo/modules/conv_module.py @@ -122,7 +122,7 @@ class ConvModule(nn.Module): def __init__( self, - in_channels: int, + in_channels: int | tuple[int, ...], out_channels: int, kernel_size: int | tuple[int, int], stride: int | tuple[int, int] = 1, diff --git a/src/otx/algo/modules/transformer.py b/src/otx/algo/modules/transformer.py new file mode 100644 index 00000000000..1d904261e2e --- /dev/null +++ b/src/otx/algo/modules/transformer.py @@ -0,0 +1,109 @@ +"""MMCV Transformer modules.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# This class and its supporting functions are adapted from the mmdet. +# Please refer to https://github.com/open-mmlab/mmdetection/ + +# Copyright (c) OpenMMLab. All rights reserved. + +from __future__ import annotations + +import copy + +import torch +import torch.nn.functional +from mmengine.model import BaseModule, Sequential +from timm.models.layers import DropPath +from torch import nn + +from otx.algo.modules.activation import build_activation_layer + + +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + layer_scale_init_value (float): Initial value of scale factor in + LayerScale. Default: 1.0 + """ + + def __init__( + self, + embed_dims: int = 256, + feedforward_channels: int = 1024, + num_fcs: int = 2, + act_cfg: dict | None = None, + ffn_drop: float = 0.0, + dropout_layer: dict | None = None, + add_identity: bool = True, + init_cfg: dict | None = None, + ): + super().__init__(init_cfg) + if num_fcs < 2: + msg = "The number of fully-connected layers in FFNs should be at least 2." 
+ raise ValueError(msg) + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + + if act_cfg is None: + act_cfg = {"type": "ReLU", "inplace": True} + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + nn.Linear(in_channels, feedforward_channels), + build_activation_layer(act_cfg), + nn.Dropout(ffn_drop), + ), + ) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + + if dropout_layer: + _dropout_layer = copy.deepcopy(dropout_layer) + dropout_type = _dropout_layer.pop("type") + if dropout_type != "DropPath": + msg = f"Unsupported dropout type {dropout_type}" + raise NotImplementedError(msg) + self.dropout_layer = DropPath(**_dropout_layer) + else: + self.dropout_layer = torch.nn.Identity() + + self.add_identity = add_identity + self.gamma2 = nn.Identity() + + def forward(self, x: torch.Tensor, identity: torch.Tensor | None = None) -> torch.Tensor: + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) diff --git a/src/otx/core/data/transform_libs/mmdet.py b/src/otx/core/data/transform_libs/mmdet.py index cb83f8a0a86..ba32e28e066 100644 --- a/src/otx/core/data/transform_libs/mmdet.py +++ b/src/otx/core/data/transform_libs/mmdet.py @@ -99,14 +99,13 @@ def _generate_gt_masks( gt_masks (BitmapMasks or PolygonMasks): The generated ground truth masks. """ if len(otx_data_entity.masks): - gt_masks = BitmapMasks(otx_data_entity.masks.numpy(), height, width) - else: - gt_masks = PolygonMasks( - [[np.array(polygon.points)] for polygon in otx_data_entity.polygons], - height, - width, - ) - return gt_masks + return BitmapMasks(otx_data_entity.masks.numpy(), height, width) + + return PolygonMasks( + [[np.array(polygon.points)] for polygon in otx_data_entity.polygons], + height, + width, + ) @TRANSFORMS.register_module(force=True) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index ade0c51068e..86ef004e330 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -417,6 +417,7 @@ def _customize_inputs(self, entity: InstanceSegBatchDataEntity) -> dict[str, Any ): # NOTE: ground-truth masks are resized in training, but not in inference height, width = img_info.img_shape if self.training else img_info.ori_shape + mmdet_masks: BitmapMasks | PolygonMasks if len(masks): mmdet_masks = BitmapMasks(masks.data.cpu().numpy(), height, width) else: diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml index 656adae1091..fcd1f7a4254 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: efficientnetb2b diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml index 
90501de99dc..6e7f039db98 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_efficientnetb2b_tile.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: efficientnetb2b diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml index abd99edc284..c94f6eb3ebb 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: r50 diff --git a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml index 468d295bd02..0916a4b070d 100644 --- a/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml +++ b/src/otx/recipe/instance_segmentation/maskrcnn_r50_tile.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: r50 diff --git a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml index 5d6e98b8233..2c4d0cfe29a 100644 --- a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml +++ b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: efficientnetb2b diff --git a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml index 0c3d6fa4e34..1b613d560df 100644 --- a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml +++ b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml @@ -1,5 +1,5 @@ model: - class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNN + class_path: otx.algo.instance_segmentation.maskrcnn.MMDetMaskRCNN init_args: label_info: 80 variant: r50 diff --git a/tests/unit/algo/instance_segmentation/heads/test_custom_roi_head.py b/tests/unit/algo/instance_segmentation/heads/test_custom_roi_head.py index 9ed6683250a..2ee6df317e2 100644 --- a/tests/unit/algo/instance_segmentation/heads/test_custom_roi_head.py +++ b/tests/unit/algo/instance_segmentation/heads/test_custom_roi_head.py @@ -10,8 +10,8 @@ import torch from mmdet.structures import DetDataSample from mmengine.structures import InstanceData -from otx.algo.instance_segmentation.heads.custom_roi_head import CustomRoIHead -from otx.algo.instance_segmentation.maskrcnn import MaskRCNN +from otx.algo.instance_segmentation.maskrcnn import MMDetMaskRCNN +from otx.algo.instance_segmentation.mmdet.models.custom_roi_head import CustomRoIHead @pytest.fixture() @@ -68,7 +68,7 @@ def test_ignore_label( fxt_data_sample_with_ignored_label, fxt_instance_list, ) -> None: - maskrcnn = MaskRCNN(3, "r50") + maskrcnn = MMDetMaskRCNN(3, "r50") input_tensors = [ torch.randn([4, 256, 144, 256]), torch.randn([4, 256, 72, 128]), diff --git a/tests/unit/algo/instance_segmentation/test_mmdet_decouple.py b/tests/unit/algo/instance_segmentation/test_mmdet_decouple.py new file mode 100644 index 
00000000000..f084d4f971e --- /dev/null +++ b/tests/unit/algo/instance_segmentation/test_mmdet_decouple.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path + +from otx.core.model.utils.mmdet import create_model +from otx.core.types.task import OTXTaskType +from otx.engine import Engine +from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK + + +class TestDecoupleMMDetInstanceSeg: + def test_maskrcnn(self, tmp_path: Path) -> None: + tmp_path_train = tmp_path / OTXTaskType.INSTANCE_SEGMENTATION + engine = Engine.from_config( + config_path=DEFAULT_CONFIG_PER_TASK[OTXTaskType.INSTANCE_SEGMENTATION], + data_root="tests/assets/car_tree_bug", + work_dir=tmp_path_train, + device="cpu", + ) + + new_model, _ = create_model(engine.model.config, engine.model.load_from) + engine.model.model = new_model + + train_metric = engine.train(max_epochs=1) + assert len(train_metric) > 0 + + test_metric = engine.test() + assert len(test_metric) > 0 + + predict_result = engine.predict() + assert len(predict_result) > 0 + + # Export IR Model + exported_model_path: Path | dict[str, Path] = engine.export() + if isinstance(exported_model_path, Path): + assert exported_model_path.exists() + test_metric_from_ov_model = engine.test(checkpoint=exported_model_path, accelerator="cpu") + assert len(test_metric_from_ov_model) > 0 diff --git a/tests/unit/core/model/test_inst_segmentation.py b/tests/unit/core/model/test_inst_segmentation.py index 39bf006d271..317dcaeb8d2 100644 --- a/tests/unit/core/model/test_inst_segmentation.py +++ b/tests/unit/core/model/test_inst_segmentation.py @@ -6,7 +6,7 @@ import pytest import torch from otx.algo.explain.explain_algo import feature_vector_fn -from otx.algo.instance_segmentation.maskrcnn import MaskRCNN +from otx.algo.instance_segmentation.maskrcnn import MMDetMaskRCNN from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel from otx.core.types.export import TaskLevelExportParameters @@ -14,7 +14,7 @@ class TestOTXInstanceSegModel: @pytest.fixture() def otx_model(self) -> MMDetInstanceSegCompatibleModel: - return MaskRCNN(label_info=1, variant="efficientnetb2b") + return MMDetMaskRCNN(label_info=1, variant="efficientnetb2b") def test_create_model(self, otx_model) -> None: mmdet_model = otx_model._create_model()