Skip to content

Commit

Permalink
Remove leftover mmengine dependencies in object detection (#3432)
Browse files Browse the repository at this point in the history
* Migrate load_checkpoint

* Migrate instance data
  • Loading branch information
jaegukhyun authored Apr 30, 2024
1 parent 8766676 commit 88db52e
Show file tree
Hide file tree
Showing 14 changed files with 540 additions and 89 deletions.
10 changes: 4 additions & 6 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any

import torch
from mmengine.structures import InstanceData
from omegaconf import DictConfig
from torchvision import tv_tensors

Expand All @@ -23,6 +22,7 @@
from otx.algo.detection.losses.iou_loss import GIoULoss
from otx.algo.detection.necks.fpn import FPN
from otx.algo.detection.ssd import SingleStageDetector
from otx.algo.utils.mmengine_utils import InstanceData, load_checkpoint
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
Expand Down Expand Up @@ -67,8 +67,6 @@ def __init__(
self.tile_image_size = self.image_size

def _create_model(self) -> nn.Module:
from mmengine.runner import load_checkpoint

detector = self._build_model(num_classes=self.label_info.num_classes)
detector.init_weights()
self.classification_layers = self.get_classification_layers(prefix="model.")
Expand Down Expand Up @@ -116,15 +114,15 @@ def _customize_outputs(
for img_info, prediction in zip(inputs.imgs_info, predictions):
if not isinstance(prediction, InstanceData):
raise TypeError(prediction)
scores.append(prediction.scores)
scores.append(prediction.scores) # type: ignore[attr-defined]
bboxes.append(
tv_tensors.BoundingBoxes(
prediction.bboxes,
prediction.bboxes, # type: ignore[attr-defined]
format="XYXY",
canvas_size=img_info.ori_shape,
),
)
labels.append(prediction.labels)
labels.append(prediction.labels) # type: ignore[attr-defined]

if self.explain_mode:
if not isinstance(outputs, dict):
Expand Down
8 changes: 4 additions & 4 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING

import torch
from mmengine.structures import InstanceData
from torch import Tensor, nn

from otx.algo.detection.heads.anchor_generator import AnchorGenerator
Expand All @@ -19,6 +18,7 @@
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap
from otx.algo.utils.mmengine_utils import InstanceData

if TYPE_CHECKING:
from omegaconf import DictConfig
Expand Down Expand Up @@ -251,13 +251,13 @@ def _get_targets_single(
anchors = flat_anchors[inside_flags]

pred_instances = InstanceData(priors=anchors)
assign_result = self.assigner.assign(pred_instances, gt_instances, gt_instances_ignore)
assign_result = self.assigner.assign(pred_instances, gt_instances, gt_instances_ignore) # type: ignore[arg-type]
# No sampling is required except for RPN and
# Guided Anchoring algorithms
sampling_result = self.sampler.sample(assign_result, pred_instances, gt_instances)

num_valid_anchors = anchors.shape[0]
target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox else self.bbox_coder.encode_size
target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox else self.bbox_coder.encode_size # type: ignore[attr-defined]
bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim)
bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim)

Expand Down Expand Up @@ -352,7 +352,7 @@ def get_targets(
raise ValueError(msg)

if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
batch_gt_instances_ignore = [None] * num_imgs # type: ignore[list-item]

# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
Expand Down
15 changes: 8 additions & 7 deletions src/otx/algo/detection/heads/atss_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from otx.algo.detection.utils.structures import AssignResult

if TYPE_CHECKING:
from mmengine.structures import InstanceData
from omegaconf import DictConfig

from otx.algo.utils.mmengine_utils import InstanceData


def bbox_center_distance(bboxes: Tensor, priors: Tensor) -> Tensor:
"""Compute the center distance between bboxes and priors.
Expand Down Expand Up @@ -120,10 +121,10 @@ def assign(
Returns:
:obj:`AssignResult`: The assign result.
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
gt_labels = gt_instances.labels
gt_bboxes_ignore = gt_instances_ignore.bboxes if gt_instances_ignore is not None else None
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]
gt_labels = gt_instances.labels # type: ignore[attr-defined]
gt_bboxes_ignore = gt_instances_ignore.bboxes if gt_instances_ignore is not None else None # type: ignore[attr-defined]

inf = 100000000
priors = priors[:, :4]
Expand All @@ -145,8 +146,8 @@ def assign(

else:
# Dynamic cost ATSSAssigner in DDOD
cls_scores = pred_instances.scores
bbox_preds = pred_instances.bboxes
cls_scores = pred_instances.scores # type: ignore[attr-defined]
bbox_preds = pred_instances.bboxes # type: ignore[attr-defined]

# compute cls cost for bbox and GT
cls_cost = torch.sigmoid(cls_scores[:, gt_labels])
Expand Down
6 changes: 3 additions & 3 deletions src/otx/algo/detection/heads/atss_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import torch
from mmengine.structures import InstanceData
from torch import Tensor, nn

from otx.algo.detection.heads.anchor_head import AnchorHead
Expand All @@ -20,6 +19,7 @@
from otx.algo.detection.utils.utils import anchor_inside_flags, multi_apply, reduce_mean, unmap
from otx.algo.modules.conv_module import ConvModule
from otx.algo.utils.mmcv_utils import Scale
from otx.algo.utils.mmengine_utils import InstanceData

EPS = 1e-12

Expand Down Expand Up @@ -208,7 +208,7 @@ def loss_by_feat( # type: ignore[override]
bbox_preds: list[Tensor],
centernesses: list[Tensor],
batch_gt_instances: list[InstanceData],
batch_img_metas: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict[str, Tensor]:
"""Compute losses of the head.
Expand Down Expand Up @@ -530,7 +530,7 @@ def _get_targets_single( # type: ignore[override]
pred_instances = InstanceData(priors=anchors)
assign_result = self.assigner.assign( # type: ignore[call-arg]
pred_instances,
num_level_anchors_inside,
num_level_anchors_inside, # type: ignore[arg-type]
gt_instances,
gt_instances_ignore,
)
Expand Down
54 changes: 24 additions & 30 deletions src/otx/algo/detection/heads/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from typing import TYPE_CHECKING

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from otx.algo.detection.ops.nms import batched_nms, multiclass_nms
from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_det_entity
from otx.algo.modules.base_module import BaseModule
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.data.entity.detection import DetBatchDataEntity

if TYPE_CHECKING:
Expand Down Expand Up @@ -405,25 +405,25 @@ def _bbox_post_process(
"""
if rescale:
scale_factor = [1 / s for s in img_meta["scale_factor"]]
results.bboxes = results.bboxes * results.bboxes.new_tensor(scale_factor).repeat(
(1, int(results.bboxes.size(-1) / 2)),
results.bboxes = results.bboxes * results.bboxes.new_tensor(scale_factor).repeat( # type: ignore[attr-defined]
(1, int(results.bboxes.size(-1) / 2)), # type: ignore[attr-defined]
)

if hasattr(results, "score_factors"):
score_factors = results.pop("score_factors")
results.scores = results.scores * score_factors
results.scores = results.scores * score_factors # type: ignore[attr-defined]

# filter small size bboxes
if cfg.get("min_bbox_size", -1) >= 0:
w = results.bboxes[:, 2] - results.bboxes[:, 0]
h = results.bboxes[:, 3] - results.bboxes[:, 1]
w = results.bboxes[:, 2] - results.bboxes[:, 0] # type: ignore[attr-defined]
h = results.bboxes[:, 3] - results.bboxes[:, 1] # type: ignore[attr-defined]
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
results = results[valid_mask]

if with_nms and results.bboxes.numel() > 0:
bboxes = results.bboxes
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms)
if with_nms and results.bboxes.numel() > 0: # type: ignore[attr-defined]
bboxes = results.bboxes # type: ignore[attr-defined]
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms) # type: ignore[attr-defined]
results = results[keep_idxs]
# some nms would reweight the score, such as softnms
results.scores = det_bboxes[:, -1]
Expand All @@ -436,7 +436,7 @@ def export(
x: tuple[Tensor],
batch_img_metas: list[dict],
rescale: bool = False,
) -> list[InstanceData]:
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform forward propagation of the detection head and predict detection results.
Args:
Expand All @@ -449,8 +449,8 @@ def export(
Defaults to False.
Returns:
list[obj:`InstanceData`]: Detection results of each image
after the post process.
tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Detection results of each image after the post process.
"""
outs = self(x)

Expand All @@ -465,7 +465,7 @@ def export_by_feat(
cfg: DictConfig | None = None,
rescale: bool = False,
with_nms: bool = True,
) -> list[InstanceData]:
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform a batch of output features extracted from the head into bbox results.
Note: When score_factors is not None, the cls_scores are
Expand Down Expand Up @@ -493,15 +493,13 @@ def export_by_feat(
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- scores (Tensor): Classification scores, has a shape
(num_instance, )
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
if batch_img_metas is None:
batch_img_metas = [{}]
Expand Down Expand Up @@ -543,7 +541,7 @@ def _export_by_feat_single(
cfg: DictConfig,
rescale: bool = False,
with_nms: bool = True,
) -> InstanceData:
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform a single image's features extracted from the head into bbox results.
Args:
Expand Down Expand Up @@ -571,16 +569,12 @@ def _export_by_feat_single(
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
batch_size = cls_score_list[0].shape[0]
with_score_factors = score_factor_list[0] is not None
Expand Down
12 changes: 6 additions & 6 deletions src/otx/algo/detection/heads/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from abc import ABCMeta, abstractmethod

import torch
from mmengine.structures import InstanceData

from otx.algo.detection.utils.structures import AssignResult, SamplingResult
from otx.algo.utils.mmengine_utils import InstanceData


class BaseSampler(metaclass=ABCMeta):
Expand Down Expand Up @@ -72,9 +72,9 @@ def sample(
Returns:
:obj:`SamplingResult`: Sampling result.
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
gt_labels = gt_instances.labels
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]
gt_labels = gt_instances.labels # type: ignore[attr-defined]
if len(priors.shape) < 2:
priors = priors[None, :]

Expand Down Expand Up @@ -158,8 +158,8 @@ def sample(
Returns:
:obj:`SamplingResult`: sampler results
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]

pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
Expand Down
4 changes: 2 additions & 2 deletions src/otx/algo/detection/heads/class_incremental_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from otx.algo.detection.utils.utils import images_to_levels, multi_apply

if TYPE_CHECKING:
from mmengine.structures import InstanceData
from otx.algo.utils.mmengine_utils import InstanceData


class ClassIncrementalMixin:
Expand Down Expand Up @@ -54,7 +54,7 @@ def get_atss_targets(

# compute targets for each image
if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
batch_gt_instances_ignore = [None] * num_imgs # type: ignore[list-item]
(
all_anchors,
all_labels,
Expand Down
12 changes: 6 additions & 6 deletions src/otx/algo/detection/heads/max_iou_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from otx.algo.detection.utils.structures import AssignResult

if TYPE_CHECKING:
from mmengine.structures import InstanceData
from otx.algo.utils.mmengine_utils import InstanceData


# This class and its supporting functions below lightly adapted from the mmdet MaxIoUAssigner available at:
Expand Down Expand Up @@ -122,7 +122,7 @@ def assign(
:obj:`AssignResult`: The assign result.
Example:
>>> from mmengine.structures import InstanceData
>>> from otx.algo.utils.mmengine_utils import InstanceData
>>> self = MaxIoUAssigner(0.5, 0.5)
>>> pred_instances = InstanceData()
>>> pred_instances.priors = torch.Tensor([[0, 0, 10, 10],
Expand All @@ -134,10 +134,10 @@ def assign(
>>> expected_gt_inds = torch.LongTensor([1, 0])
>>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
gt_labels = gt_instances.labels
gt_bboxes_ignore = gt_instances_ignore.bboxes if gt_instances_ignore is not None else None
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]
gt_labels = gt_instances.labels # type: ignore[attr-defined]
gt_bboxes_ignore = gt_instances_ignore.bboxes if gt_instances_ignore is not None else None # type: ignore[attr-defined]

assign_on_cpu = (self.gpu_assign_thr > 0) and (gt_bboxes.shape[0] > self.gpu_assign_thr)
# compute overlap and assign gt on CPU when number of GT is large
Expand Down
Loading

0 comments on commit 88db52e

Please sign in to comment.