Skip to content

Commit

Permalink
MMDet MaskRCNN ResNet50/SwinTransformer Decouple (#3281)
Browse files Browse the repository at this point in the history
* migrate mmdet maskrcnn modules

* ignore mypy, ruff errors

* skip mypy error

* add maskrcnn

* add cross-entropy loss

* style changes

* mypy changes and style changes

* update style

* remove box structures

* modify resnet

* add annotation

* fix all mypy issues

* fix mypy issues

* style changes

* remove unused losses

* remove focal_loss_pb

* fix all ruff and mypy issues

* remove duplicates

* remove mmdet mask structures

* remove duplicates

* style changes

* add new test

* test style change

* change device for unit test

* add deployment files

* remove deployment from inst-seg

* update deployment

* add mmdeploy maskrcnn opset

* replace mmcv.cnn module

* remove upsample building

* remove upsample building

* use batch_nms from otx

* add swintransformer

* add transformers

* add swin transformer

* update instance_segmentation/maskrcnn.py

* update nms

* change rotate detection recipe
  • Loading branch information
eugene123tw authored Apr 25, 2024
1 parent 4d80c49 commit fffa2eb
Show file tree
Hide file tree
Showing 73 changed files with 6,627 additions and 432 deletions.
2 changes: 2 additions & 0 deletions src/otx/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
action_classification,
classification,
detection,
instance_segmentation,
plugins,
segmentation,
strategies,
Expand All @@ -23,4 +24,5 @@
"strategies",
"accelerators",
"plugins",
"instance_segmentation",
]
19 changes: 19 additions & 0 deletions src/otx/algo/detection/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Functions for mmdeploy adapters."""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import importlib


def is_mmdeploy_enabled() -> bool:
    """Check whether the ``mmdeploy`` Python package can be located.

    Only the import spec is looked up; ``mmdeploy`` itself is not imported.

    Returns:
        bool: True if an import spec for ``mmdeploy`` is found, False otherwise.

    Example:
        >>> is_mmdeploy_enabled()  # doctest: +SKIP
        True
    """
    # A bare ``import importlib`` does not guarantee that the ``importlib.util``
    # submodule has been loaded, so ``importlib.util.find_spec`` may raise
    # AttributeError. Import the helper explicitly instead.
    from importlib.util import find_spec

    return find_spec("mmdeploy") is not None
2 changes: 1 addition & 1 deletion src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AnchorHead(BaseDenseHead):
def __init__(
self,
num_classes: int,
in_channels: tuple[int, ...],
in_channels: tuple[int, ...] | int,
anchor_generator: dict,
bbox_coder: dict,
loss_cls: dict,
Expand Down
9 changes: 4 additions & 5 deletions src/otx/algo/detection/heads/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from typing import TYPE_CHECKING

import torch
from mmcv.ops import batched_nms
from mmengine.model import constant_init
from mmengine.model import BaseModule, constant_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from torch import Tensor

from otx.algo.detection.ops.nms import multiclass_nms
from otx.algo.detection.ops.nms import batched_nms, multiclass_nms
from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances

if TYPE_CHECKING:
Expand All @@ -24,7 +23,7 @@

# This class and its supporting functions below lightly adapted from the mmdet BaseDenseHead available at:
# https://github.com/open-mmlab/mmdetection/blob/fe3f809a0a514189baf889aa358c498d51ee36cd/mmdet/models/dense_heads/base_dense_head.py
class BaseDenseHead(nn.Module):
class BaseDenseHead(BaseModule):
"""Base class for DenseHeads.
1. The ``init_weights`` method is used to initialize densehead's
Expand Down
9 changes: 5 additions & 4 deletions src/otx/algo/detection/heads/class_incremental_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.utils.misc import images_to_levels, multi_apply
from mmdet.registry import MODELS
from torch import Tensor

from otx.algo.detection.utils.utils import images_to_levels, multi_apply

if TYPE_CHECKING:
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.structures import InstanceData


@MODELS.register_module()
Expand All @@ -24,9 +25,9 @@ def get_atss_targets(
self,
anchor_list: list,
valid_flag_list: list[list[Tensor]],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
unmap_outputs: bool = True,
) -> tuple:
"""Get targets for ATSS head.
Expand Down
3 changes: 2 additions & 1 deletion src/otx/algo/detection/heads/custom_anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

import numpy as np
import torch
from mmdet.registry import TASK_UTILS
from mmengine.registry import TASK_UTILS
from torch.nn.modules.utils import _pair


# This class and its supporting functions below lightly adapted from the mmdet AnchorGenerator available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/prior_generators/anchor_generator.py
@TASK_UTILS.register_module()
class AnchorGenerator:
"""Standard anchor generator for 2D anchor-based detectors.
Expand Down
5 changes: 4 additions & 1 deletion src/otx/algo/detection/heads/custom_ssd_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict],
train_cfg: ConfigDict | dict,
num_classes: int = 80,
in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256),
in_channels: tuple[int, ...] | int = (512, 1024, 512, 256, 256, 256),
stacked_convs: int = 0,
feat_channels: int = 256,
use_depthwise: bool = False,
Expand Down Expand Up @@ -274,6 +274,9 @@ def _init_layers(self) -> None:
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()

if isinstance(self.in_channels, int):
self.in_channels = (self.in_channels,)

for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors):
if self.use_depthwise:
activation_layer = nn.ReLU(inplace=True)
Expand Down
73 changes: 58 additions & 15 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

import numpy as np
import torch
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.deployment import is_mmdeploy_enabled


# This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
@TASK_UTILS.register_module()
class DeltaXYWHBBoxCoder:
"""Delta XYWH BBox coder.
Expand Down Expand Up @@ -236,21 +240,6 @@ def delta2bbox(
References:
.. [1] https://arxiv.org/abs/1311.2524
Example:
>>> rois = torch.Tensor([[ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 0., 0., 1., 1.],
>>> [ 5., 5., 5., 5.]])
>>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
>>> [ 1., 1., 1., 1.],
>>> [ 0., 0., 2., -1.],
>>> [ 0.7, -1.9, -0.5, 0.3]])
>>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
tensor([[0.0000, 0.0000, 1.0000, 1.0000],
[0.1409, 0.1409, 2.8591, 2.8591],
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
"""
num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
if num_bboxes == 0:
Expand Down Expand Up @@ -426,3 +415,57 @@ def clip_bboxes(
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2


if is_mmdeploy_enabled():
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        func_name="otx.algo.detection.heads.delta_xywh_bbox_coder.DeltaXYWHBBoxCoder.decode",
        backend="default",
    )
    def deltaxywhbboxcoder__decode(
        self: DeltaXYWHBBoxCoder,
        bboxes: Tensor,
        pred_bboxes: Tensor,
        max_shape: Tensor | None = None,
        wh_ratio_clip: float = 16 / 1000,
    ) -> Tensor:
        """Export-time replacement for ``DeltaXYWHBBoxCoder.decode`` (default backend).

        Registered through mmdeploy's ``FUNCTION_REWRITER``; after validating the
        leading dimensions of its inputs it forwards directly to
        ``delta2bbox_export`` together with the coder's own settings.

        Args:
            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4).
            pred_bboxes (torch.Tensor): Encoded offsets with respect to each roi.
                Has shape (B, N, num_classes * 4) or (B, N, 4) or
                (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
                when rois is a grid of anchors.
            max_shape (torch.Tensor, optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W). Passed through unchanged.
                Defaults to None.
            wh_ratio_clip (float, optional): The allowed ratio between
                width and height. Defaults to 16 / 1000.

        Returns:
            torch.Tensor: Decoded boxes, as produced by ``delta2bbox_export``.

        Raises:
            ValueError: If the first dimensions of ``pred_bboxes`` and ``bboxes``
                differ, or if ``pred_bboxes`` is 3-dimensional and the second
                dimensions differ.
        """
        if pred_bboxes.size(0) != bboxes.size(0):
            batch_err = "The batch size of pred_bboxes and bboxes should be equal."
            raise ValueError(batch_err)

        # The per-image box count is only comparable for batched (3-D) predictions.
        if pred_bboxes.ndim == 3:
            if pred_bboxes.size(1) != bboxes.size(1):
                count_err = "The number of bboxes should be equal."
                raise ValueError(count_err)

        # Coder configuration forwarded positionally to the export-friendly decoder.
        means, stds = self.means, self.stds
        decoded = delta2bbox_export(
            bboxes,
            pred_bboxes,
            means,
            stds,
            max_shape,
            wh_ratio_clip,
            self.clip_border,
            self.add_ctr_clamp,
            self.ctr_clamp,
        )
        return decoded
2 changes: 2 additions & 0 deletions src/otx/algo/detection/heads/iou2d_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from __future__ import annotations

import torch
from mmengine.registry import TASK_UTILS

from otx.algo.detection.utils.bbox_overlaps import bbox_overlaps


# This class and its supporting functions below lightly adapted from the mmdet BboxOverlaps2D available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/iou2d_calculator.py
@TASK_UTILS.register_module()
class BboxOverlaps2D:
"""2D Overlaps (e.g. IoUs, GIoUs) Calculator."""

Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/detection/heads/max_iou_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Callable

import torch
from mmengine.registry import TASK_UTILS
from torch import Tensor

from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D
Expand All @@ -19,6 +20,7 @@

# This class and its supporting functions below lightly adapted from the mmdet MaxIoUAssigner available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/assigners/max_iou_assigner.py
@TASK_UTILS.register_module()
class MaxIoUAssigner:
"""Assign a corresponding gt bbox or background to each bbox.
Expand Down
12 changes: 10 additions & 2 deletions src/otx/algo/detection/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
#
"""Custom OTX Losses for Object Detection."""
from .accuracy import accuracy
from .cross_entropy_loss import CrossEntropyLoss
from .cross_focal_loss import CrossSigmoidFocalLoss
from .smooth_l1_loss import L1Loss


__all__ = ["CrossSigmoidFocalLoss, OrdinaryFocalLoss"]
__all__ = [
"CrossEntropyLoss",
"CrossSigmoidFocalLoss",
"accuracy",
"L1Loss",
]
73 changes: 73 additions & 0 deletions src/otx/algo/detection/losses/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# This class and its supporting functions are adapted from the mmdet.
# Please refer to https://github.com/open-mmlab/mmdetection/

"""MMDet Accuracy."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from torch import Tensor


def accuracy(
    pred: Tensor,
    target: Tensor,
    topk: int | tuple[int] = 1,
    thresh: float | None = None,
) -> list[Tensor] | Tensor:
    """Compute top-k accuracy, in percent, of predictions against targets.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class).
        target (torch.Tensor): The target of each prediction, shape (N, ).
        topk (int | tuple[int], optional): A prediction counts as correct when
            the target is among its ``topk`` highest-scoring classes.
            Defaults to 1.
        thresh (float, optional): If not None, predictions whose score is not
            larger than this threshold are considered incorrect.
            Defaults to None.

    Returns:
        torch.Tensor | list[torch.Tensor]: If ``topk`` is a single integer, one
            tensor holding the accuracy. If ``topk`` is a tuple, a list with one
            accuracy tensor per entry of ``topk``, in the same order.

    Raises:
        TypeError: If ``topk`` is neither an int nor a tuple.
        ValueError: If ``pred`` is not 2-D or ``target`` is not 1-D, if their
            first dimensions differ, or if the largest ``k`` exceeds the number
            of classes in ``pred``.
    """
    if isinstance(topk, int):
        ks = (topk,)
        single = True
    elif isinstance(topk, tuple):
        ks = topk
        single = False
    else:
        msg = f"topk must be int or tuple of int, got {type(topk)}"
        raise TypeError(msg)

    largest_k = max(ks)
    num_samples = pred.size(0)

    # With no predictions there is nothing to score: report zero for every k.
    if num_samples == 0:
        zeros = [pred.new_tensor(0.0) for _ in ks]
        return zeros[0] if single else zeros

    if not (pred.ndim == 2 and target.ndim == 1):
        msg = "Input tensors must have 2 dims for pred and 1 dim for target"
        raise ValueError(msg)
    if target.size(0) != num_samples:
        msg = "Input tensors must have the same size along the 0th dim"
        raise ValueError(msg)
    if largest_k > pred.size(1):
        msg = f"maxk {largest_k} exceeds pred dimension {pred.size(1)}"
        raise ValueError(msg)

    top_scores, top_labels = pred.topk(largest_k, dim=1)
    # Work in (maxk, N) layout so that the first k rows hold the top-k guesses.
    hits = top_labels.t() == target.view(1, -1)
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct
        hits = hits & (top_scores > thresh).t()

    to_percent = 100.0 / num_samples
    per_k = [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent) for k in ks]
    return per_k[0] if single else per_k
2 changes: 2 additions & 0 deletions src/otx/algo/detection/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import torch
from mmengine.registry import MODELS
from torch import nn

from otx.algo.detection.losses.weighted_loss import weight_reduce_loss
Expand Down Expand Up @@ -181,6 +182,7 @@ def mask_cross_entropy(
)[None]


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
"""Base Cross Entropy Loss implementation from mmdet."""

Expand Down
Loading

0 comments on commit fffa2eb

Please sign in to comment.