Skip to content

Commit

Permalink
Decouple YOLOX : unit tests for model migration (#3398)
Browse files Browse the repository at this point in the history
- unit tests for #3382
  • Loading branch information
sungchul2 authored Apr 25, 2024
1 parent 73b85eb commit 4c18ab1
Show file tree
Hide file tree
Showing 16 changed files with 776 additions and 5 deletions.
137 changes: 137 additions & 0 deletions tests/unit/algo/detection/backbones/test_csp_darknet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of CSPDarknet.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/test_csp_darknet.py
"""

import pytest
import torch
from otx.algo.detection.backbones.csp_darknet import CSPDarknet
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm


def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/utils.py#L26C1-L32C16
"""
return all(not (isinstance(mod, _BatchNorm) and mod.training != train_state) for mod in modules)


def is_norm(modules):
"""Check if is one of the norms.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/utils.py#L19-L23
"""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False


class TestCSPDarknet:
def test_init_with_large_frozen_stages(self) -> None:
"""Test __init__ with large frozen_stages."""
with pytest.raises(ValueError): # noqa: PT011
# frozen_stages must in range(-1, len(arch_setting) + 1)
CSPDarknet(frozen_stages=6)

def test_init_with_large_out_indices(self) -> None:
"""Test __init__ with large out_indices."""
with pytest.raises(AssertionError):
CSPDarknet(out_indices=[6])

def test_freeze_stages(self) -> None:
"""Test _freeze_stages."""
frozen_stages = 1
model = CSPDarknet(frozen_stages=frozen_stages)
model.train()

for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f"stage{i}")
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False

def test_train_with_norm_eval(self) -> None:
"""Test train with norm_eval=True."""
model = CSPDarknet(norm_eval=True)
model.train()

assert check_norm_state(model.modules(), False)

def test_forward(self) -> None:
# Test CSPDarknet-P5 forward with widen_factor=0.5
model = CSPDarknet(arch="P5", widen_factor=0.25, out_indices=range(5))
model.train()

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 32, 16, 16))
assert feat[2].shape == torch.Size((1, 64, 8, 8))
assert feat[3].shape == torch.Size((1, 128, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))

# Test CSPDarknet-P6 forward with widen_factor=0.5
model = CSPDarknet(arch="P6", widen_factor=0.25, out_indices=range(6), spp_kernal_sizes=(3, 5, 7))
model.train()

imgs = torch.randn(1, 3, 128, 128)
feat = model(imgs)
assert feat[0].shape == torch.Size((1, 16, 64, 64))
assert feat[1].shape == torch.Size((1, 32, 32, 32))
assert feat[2].shape == torch.Size((1, 64, 16, 16))
assert feat[3].shape == torch.Size((1, 128, 8, 8))
assert feat[4].shape == torch.Size((1, 192, 4, 4))
assert feat[5].shape == torch.Size((1, 256, 2, 2))

# Test CSPDarknet forward with dict(type='ReLU')
model = CSPDarknet(widen_factor=0.125, act_cfg={"type": "ReLU"}, out_indices=range(5))
model.train()

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))

# Test CSPDarknet with BatchNorm forward
model = CSPDarknet(widen_factor=0.125, out_indices=range(5))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.train()

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))

# Test CSPDarknet with custom arch forward
arch_ovewrite = [[32, 56, 3, True, False], [56, 224, 2, True, False], [224, 512, 1, True, False]]
model = CSPDarknet(arch_ovewrite=arch_ovewrite, widen_factor=0.25, out_indices=(0, 1, 2, 3))
model.train()

imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 8, 16, 16))
assert feat[1].shape == torch.Size((1, 14, 8, 8))
assert feat[2].shape == torch.Size((1, 56, 4, 4))
assert feat[3].shape == torch.Size((1, 128, 2, 2))
87 changes: 87 additions & 0 deletions tests/unit/algo/detection/heads/test_point_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of point generators."""

import pytest
import torch
from otx.algo.detection.heads.point_generator import MlvlPointGenerator


class TestMlvlPointGenerator:
def test_mlvl_point_generator(self) -> None:
# Square strides
mlvl_points = MlvlPointGenerator(strides=[4, 10], offset=0)
mlvl_points_half_stride_generator = MlvlPointGenerator(strides=[4, 10], offset=0.5)
assert mlvl_points.num_levels == 2

# assert self.num_levels == len(featmap_sizes) # noqa: ERA001
with pytest.raises(AssertionError):
mlvl_points.grid_priors(featmap_sizes=[(2, 2)], device="cpu")
priors = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cpu")
priors_with_stride = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], with_stride=True, device="cpu")
assert len(priors) == 2

# assert last dimension is (coord_x, coord_y, stride_w, stride_h).
assert priors_with_stride[0].size(1) == 4
assert priors_with_stride[0][0][2] == 4
assert priors_with_stride[0][0][3] == 4
assert priors_with_stride[1][0][2] == 10
assert priors_with_stride[1][0][3] == 10

stride_4_feat_2_2 = priors[0]
assert (stride_4_feat_2_2[1] - stride_4_feat_2_2[0]).sum() == 4
assert stride_4_feat_2_2.size(0) == 4
assert stride_4_feat_2_2.size(1) == 2

stride_10_feat_4_8 = priors[1]
assert (stride_10_feat_4_8[1] - stride_10_feat_4_8[0]).sum() == 10
assert stride_10_feat_4_8.size(0) == 4 * 8
assert stride_10_feat_4_8.size(1) == 2

# assert the offset of 0.5 * stride
priors_half_offset = mlvl_points_half_stride_generator.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cpu")

assert (priors_half_offset[0][0] - priors[0][0]).sum() == 4 * 0.5 * 2
assert (priors_half_offset[1][0] - priors[1][0]).sum() == 10 * 0.5 * 2
if torch.cuda.is_available():
# Square strides
mlvl_points = MlvlPointGenerator(strides=[4, 10], offset=0)
mlvl_points_half_stride_generator = MlvlPointGenerator(strides=[4, 10], offset=0.5)
assert mlvl_points.num_levels == 2

# assert self.num_levels == len(featmap_sizes) # noqa: ERA001
with pytest.raises(AssertionError):
mlvl_points.grid_priors(featmap_sizes=[(2, 2)], device="cuda")
priors = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cuda")
priors_with_stride = mlvl_points.grid_priors(
featmap_sizes=[(2, 2), (4, 8)],
with_stride=True,
device="cuda",
)
assert len(priors) == 2

# assert last dimension is (coord_x, coord_y, stride_w, stride_h).
assert priors_with_stride[0].size(1) == 4
assert priors_with_stride[0][0][2] == 4
assert priors_with_stride[0][0][3] == 4
assert priors_with_stride[1][0][2] == 10
assert priors_with_stride[1][0][3] == 10

stride_4_feat_2_2 = priors[0]
assert (stride_4_feat_2_2[1] - stride_4_feat_2_2[0]).sum() == 4
assert stride_4_feat_2_2.size(0) == 4
assert stride_4_feat_2_2.size(1) == 2

stride_10_feat_4_8 = priors[1]
assert (stride_10_feat_4_8[1] - stride_10_feat_4_8[0]).sum() == 10
assert stride_10_feat_4_8.size(0) == 4 * 8
assert stride_10_feat_4_8.size(1) == 2

# assert the offset of 0.5 * stride
priors_half_offset = mlvl_points_half_stride_generator.grid_priors(
featmap_sizes=[(2, 2), (4, 8)],
device="cuda",
)

assert (priors_half_offset[0][0] - priors[0][0]).sum() == 4 * 0.5 * 2
assert (priors_half_offset[1][0] - priors[1][0]).sum() == 10 * 0.5 * 2
52 changes: 52 additions & 0 deletions tests/unit/algo/detection/heads/test_sim_ota_assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of SimOTAAssigner.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_task_modules/test_assigners/test_simota_assigner.py
"""

import torch
from mmengine.structures import InstanceData # TODO (sungchul): remove
from otx.algo.detection.heads.sim_ota_assigner import SimOTAAssigner


class TestSimOTAAssigner:
def test_assign(self) -> None:
assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
pred_instances = InstanceData(
bboxes=torch.Tensor([[23, 23, 43, 43], [4, 5, 6, 7]]),
scores=torch.FloatTensor([[0.2], [0.8]]),
priors=torch.Tensor([[30, 30, 8, 8], [4, 5, 6, 7]]),
)
gt_instances = InstanceData(bboxes=torch.Tensor([[23, 23, 43, 43]]), labels=torch.LongTensor([0]))
assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)

expected_gt_inds = torch.LongTensor([1, 0])
assert torch.allclose(assign_result.gt_inds, expected_gt_inds)

def test_assign_with_no_valid_bboxes(self) -> None:
assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
pred_instances = InstanceData(
bboxes=torch.Tensor([[123, 123, 143, 143], [114, 151, 161, 171]]),
scores=torch.FloatTensor([[0.2], [0.8]]),
priors=torch.Tensor([[30, 30, 8, 8], [55, 55, 8, 8]]),
)
gt_instances = InstanceData(bboxes=torch.Tensor([[0, 0, 1, 1]]), labels=torch.LongTensor([0]))
assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)

expected_gt_inds = torch.LongTensor([0, 0])
assert torch.allclose(assign_result.gt_inds, expected_gt_inds)

def test_assign_with_empty_gt(self) -> None:
assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
pred_instances = InstanceData(
bboxes=torch.Tensor([[[30, 40, 50, 60]], [[4, 5, 6, 7]]]),
scores=torch.FloatTensor([[0.2], [0.8]]),
priors=torch.Tensor([[0, 12, 23, 34], [4, 5, 6, 7]]),
)
gt_instances = InstanceData(bboxes=torch.empty(0, 4), labels=torch.empty(0))

assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)
expected_gt_inds = torch.LongTensor([0, 0])
assert torch.allclose(assign_result.gt_inds, expected_gt_inds)
114 changes: 114 additions & 0 deletions tests/unit/algo/detection/heads/test_yolox_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of YOLOXHead.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_dense_heads/test_yolox_head.py
"""

import torch
from mmengine.config import Config
from mmengine.structures import InstanceData
from otx.algo.detection.heads.yolox_head import YOLOXHead
from otx.algo.modules.conv_module import ConvModule
from otx.algo.modules.depthwise_separable_conv_module import DepthwiseSeparableConvModule # TODO (sungchul): remove


class TestYOLOXHead:
def test_predict_by_feat(self):
s = 256
img_metas = [
{
"img_shape": (s, s, 3),
"scale_factor": (1.0, 1.0),
},
]
test_cfg = Config({"score_thr": 0.01, "nms": {"type": "nms", "iou_threshold": 0.65}})
head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=False, test_cfg=test_cfg)
feat = [torch.rand(1, 1, s // feat_size, s // feat_size) for feat_size in [4, 8, 16]]
cls_scores, bbox_preds, objectnesses = head.forward(feat)
head.predict_by_feat(cls_scores, bbox_preds, objectnesses, img_metas, cfg=test_cfg, rescale=True, with_nms=True)
head.predict_by_feat(
cls_scores,
bbox_preds,
objectnesses,
img_metas,
cfg=test_cfg,
rescale=False,
with_nms=False,
)

def test_loss_by_feat(self):
s = 256
img_metas = [
{
"img_shape": (s, s, 3),
"scale_factor": 1,
},
]
train_cfg = Config(
{
"assigner": {
"type": "SimOTAAssigner",
"center_radius": 2.5,
"candidate_topk": 10,
"iou_weight": 3.0,
"cls_weight": 1.0,
},
},
)
head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=False, train_cfg=train_cfg)
assert not head.use_l1
assert isinstance(head.multi_level_cls_convs[0][0], ConvModule)

feat = [torch.rand(1, 1, s // feat_size, s // feat_size) for feat_size in [4, 8, 16]]
cls_scores, bbox_preds, objectnesses = head.forward(feat)

# Test that empty ground truth encourages the network to predict
# background
gt_instances = InstanceData(bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))

empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
# When there is no truth, the cls loss should be nonzero but there
# should be no box loss.
empty_cls_loss = empty_gt_losses["loss_cls"].sum()
empty_box_loss = empty_gt_losses["loss_bbox"].sum()
empty_obj_loss = empty_gt_losses["loss_obj"].sum()
assert empty_cls_loss.item() == 0, "there should be no cls loss when there are no true boxes"
assert empty_box_loss.item() == 0, "there should be no box loss when there are no true boxes"
assert empty_obj_loss.item() > 0, "objectness loss should be non-zero"

# When truth is non-empty then both cls and box loss should be nonzero
# for random inputs
head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=True, train_cfg=train_cfg)
assert isinstance(head.multi_level_cls_convs[0][0], DepthwiseSeparableConvModule)
head.use_l1 = True
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([2]),
)

one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
onegt_cls_loss = one_gt_losses["loss_cls"].sum()
onegt_box_loss = one_gt_losses["loss_bbox"].sum()
onegt_obj_loss = one_gt_losses["loss_obj"].sum()
onegt_l1_loss = one_gt_losses["loss_l1"].sum()
assert onegt_cls_loss.item() > 0, "cls loss should be non-zero"
assert onegt_box_loss.item() > 0, "box loss should be non-zero"
assert onegt_obj_loss.item() > 0, "obj loss should be non-zero"
assert onegt_l1_loss.item() > 0, "l1 loss should be non-zero"

# Test groud truth out of bound
gt_instances = InstanceData(
bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
labels=torch.LongTensor([2]),
)
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
# When gt_bboxes out of bound, the assign results should be empty,
# so the cls and bbox loss should be zero.
empty_cls_loss = empty_gt_losses["loss_cls"].sum()
empty_box_loss = empty_gt_losses["loss_bbox"].sum()
empty_obj_loss = empty_gt_losses["loss_obj"].sum()
assert empty_cls_loss.item() == 0, "there should be no cls loss when gt_bboxes out of bound"
assert empty_box_loss.item() == 0, "there should be no box loss when gt_bboxes out of bound"
assert empty_obj_loss.item() > 0, "objectness loss should be non-zero"
3 changes: 3 additions & 0 deletions tests/unit/algo/detection/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of custom layers of OTX Detection task."""
Loading

0 comments on commit 4c18ab1

Please sign in to comment.