Decouple YOLOX : unit tests for model migration (#3398)
- unit tests for #3382
Showing 16 changed files with 776 additions and 5 deletions.
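These tests follow the existing tests/unit layout, so they should be runnable locally with pytest (assuming a standard development install of the package), for example:

    pytest tests/unit/algo/detection -q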
137 changes: 137 additions & 0 deletions
tests/unit/algo/detection/backbones/test_csp_darknet.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of CSPDarknet.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/test_csp_darknet.py
"""

import pytest
import torch
from otx.algo.detection.backbones.csp_darknet import CSPDarknet
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state.
    Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/utils.py#L26C1-L32C16
    """
    return all(not (isinstance(mod, _BatchNorm) and mod.training != train_state) for mod in modules)


def is_norm(modules):
    """Check if is one of the norms.
    Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_backbones/utils.py#L19-L23
    """
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


class TestCSPDarknet:
    def test_init_with_large_frozen_stages(self) -> None:
        """Test __init__ with large frozen_stages."""
        with pytest.raises(ValueError):  # noqa: PT011
            # frozen_stages must be in range(-1, len(arch_setting) + 1)
            CSPDarknet(frozen_stages=6)

    def test_init_with_large_out_indices(self) -> None:
        """Test __init__ with large out_indices."""
        with pytest.raises(AssertionError):
            CSPDarknet(out_indices=[6])

    def test_freeze_stages(self) -> None:
        """Test _freeze_stages."""
        frozen_stages = 1
        model = CSPDarknet(frozen_stages=frozen_stages)
        model.train()

        for mod in model.stem.modules():
            for param in mod.parameters():
                assert param.requires_grad is False
        for i in range(1, frozen_stages + 1):
            layer = getattr(model, f"stage{i}")
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in layer.parameters():
                assert param.requires_grad is False

    def test_train_with_norm_eval(self) -> None:
        """Test train with norm_eval=True."""
        model = CSPDarknet(norm_eval=True)
        model.train()

        assert check_norm_state(model.modules(), False)

    def test_forward(self) -> None:
        # Test CSPDarknet-P5 forward with widen_factor=0.25
        model = CSPDarknet(arch="P5", widen_factor=0.25, out_indices=range(5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 16, 32, 32))
        assert feat[1].shape == torch.Size((1, 32, 16, 16))
        assert feat[2].shape == torch.Size((1, 64, 8, 8))
        assert feat[3].shape == torch.Size((1, 128, 4, 4))
        assert feat[4].shape == torch.Size((1, 256, 2, 2))

        # Test CSPDarknet-P6 forward with widen_factor=0.25
        model = CSPDarknet(arch="P6", widen_factor=0.25, out_indices=range(6), spp_kernal_sizes=(3, 5, 7))
        model.train()

        imgs = torch.randn(1, 3, 128, 128)
        feat = model(imgs)
        assert feat[0].shape == torch.Size((1, 16, 64, 64))
        assert feat[1].shape == torch.Size((1, 32, 32, 32))
        assert feat[2].shape == torch.Size((1, 64, 16, 16))
        assert feat[3].shape == torch.Size((1, 128, 8, 8))
        assert feat[4].shape == torch.Size((1, 192, 4, 4))
        assert feat[5].shape == torch.Size((1, 256, 2, 2))

        # Test CSPDarknet forward with dict(type='ReLU')
        model = CSPDarknet(widen_factor=0.125, act_cfg={"type": "ReLU"}, out_indices=range(5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))

        # Test CSPDarknet with BatchNorm forward
        model = CSPDarknet(widen_factor=0.125, out_indices=range(5))
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))

        # Test CSPDarknet with custom arch forward
        arch_ovewrite = [[32, 56, 3, True, False], [56, 224, 2, True, False], [224, 512, 1, True, False]]
        model = CSPDarknet(arch_ovewrite=arch_ovewrite, widen_factor=0.25, out_indices=(0, 1, 2, 3))
        model.train()

        imgs = torch.randn(1, 3, 32, 32)
        feat = model(imgs)
        assert len(feat) == 4
        assert feat[0].shape == torch.Size((1, 8, 16, 16))
        assert feat[1].shape == torch.Size((1, 14, 8, 8))
        assert feat[2].shape == torch.Size((1, 56, 4, 4))
        assert feat[3].shape == torch.Size((1, 128, 2, 2))
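For reference, the asserted channel counts follow from the widen factor: assuming the default CSPDarknet-P5 stage widths of (64, 128, 256, 512, 1024) used by the mmdetection reference implementation, widen_factor=0.25 scales them to the (16, 32, 64, 128, 256) checked in the P5 case above. A minimal sketch of that arithmetic, under the same assumption:

# Illustrative sketch only; assumes mmdetection's default P5 stage widths.
base_widths = (64, 128, 256, 512, 1024)  # stem + four stages
widen_factor = 0.25
print([int(w * widen_factor) for w in base_widths])  # [16, 32, 64, 128, 256]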
87 changes: 87 additions & 0 deletions
tests/unit/algo/detection/heads/test_point_generator.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of point generators."""

import pytest
import torch
from otx.algo.detection.heads.point_generator import MlvlPointGenerator


class TestMlvlPointGenerator:
    def test_mlvl_point_generator(self) -> None:
        # Square strides
        mlvl_points = MlvlPointGenerator(strides=[4, 10], offset=0)
        mlvl_points_half_stride_generator = MlvlPointGenerator(strides=[4, 10], offset=0.5)
        assert mlvl_points.num_levels == 2

        # assert self.num_levels == len(featmap_sizes)  # noqa: ERA001
        with pytest.raises(AssertionError):
            mlvl_points.grid_priors(featmap_sizes=[(2, 2)], device="cpu")
        priors = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cpu")
        priors_with_stride = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], with_stride=True, device="cpu")
        assert len(priors) == 2

        # assert last dimension is (coord_x, coord_y, stride_w, stride_h).
        assert priors_with_stride[0].size(1) == 4
        assert priors_with_stride[0][0][2] == 4
        assert priors_with_stride[0][0][3] == 4
        assert priors_with_stride[1][0][2] == 10
        assert priors_with_stride[1][0][3] == 10

        stride_4_feat_2_2 = priors[0]
        assert (stride_4_feat_2_2[1] - stride_4_feat_2_2[0]).sum() == 4
        assert stride_4_feat_2_2.size(0) == 4
        assert stride_4_feat_2_2.size(1) == 2

        stride_10_feat_4_8 = priors[1]
        assert (stride_10_feat_4_8[1] - stride_10_feat_4_8[0]).sum() == 10
        assert stride_10_feat_4_8.size(0) == 4 * 8
        assert stride_10_feat_4_8.size(1) == 2

        # assert the offset of 0.5 * stride
        priors_half_offset = mlvl_points_half_stride_generator.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cpu")

        assert (priors_half_offset[0][0] - priors[0][0]).sum() == 4 * 0.5 * 2
        assert (priors_half_offset[1][0] - priors[1][0]).sum() == 10 * 0.5 * 2
        if torch.cuda.is_available():
            # Square strides
            mlvl_points = MlvlPointGenerator(strides=[4, 10], offset=0)
            mlvl_points_half_stride_generator = MlvlPointGenerator(strides=[4, 10], offset=0.5)
            assert mlvl_points.num_levels == 2

            # assert self.num_levels == len(featmap_sizes)  # noqa: ERA001
            with pytest.raises(AssertionError):
                mlvl_points.grid_priors(featmap_sizes=[(2, 2)], device="cuda")
            priors = mlvl_points.grid_priors(featmap_sizes=[(2, 2), (4, 8)], device="cuda")
            priors_with_stride = mlvl_points.grid_priors(
                featmap_sizes=[(2, 2), (4, 8)],
                with_stride=True,
                device="cuda",
            )
            assert len(priors) == 2

            # assert last dimension is (coord_x, coord_y, stride_w, stride_h).
            assert priors_with_stride[0].size(1) == 4
            assert priors_with_stride[0][0][2] == 4
            assert priors_with_stride[0][0][3] == 4
            assert priors_with_stride[1][0][2] == 10
            assert priors_with_stride[1][0][3] == 10

            stride_4_feat_2_2 = priors[0]
            assert (stride_4_feat_2_2[1] - stride_4_feat_2_2[0]).sum() == 4
            assert stride_4_feat_2_2.size(0) == 4
            assert stride_4_feat_2_2.size(1) == 2

            stride_10_feat_4_8 = priors[1]
            assert (stride_10_feat_4_8[1] - stride_10_feat_4_8[0]).sum() == 10
            assert stride_10_feat_4_8.size(0) == 4 * 8
            assert stride_10_feat_4_8.size(1) == 2

            # assert the offset of 0.5 * stride
            priors_half_offset = mlvl_points_half_stride_generator.grid_priors(
                featmap_sizes=[(2, 2), (4, 8)],
                device="cuda",
            )

            assert (priors_half_offset[0][0] - priors[0][0]).sum() == 4 * 0.5 * 2
            assert (priors_half_offset[1][0] - priors[1][0]).sum() == 10 * 0.5 * 2
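The offset assertions at the end follow from how the priors are laid out: assuming the usual mmdetection MlvlPointGenerator behaviour, the point for grid cell (x, y) sits at ((x + offset) * stride, (y + offset) * stride), so changing the offset from 0 to 0.5 shifts both coordinates by 0.5 * stride, and the summed difference over the two coordinates is 2 * 0.5 * stride: 4 for the stride-4 level and 10 for the stride-10 level.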
52 changes: 52 additions & 0 deletions
tests/unit/algo/detection/heads/test_sim_ota_assigner.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of SimOTAAssigner.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_task_modules/test_assigners/test_simota_assigner.py
"""

import torch
from mmengine.structures import InstanceData  # TODO (sungchul): remove
from otx.algo.detection.heads.sim_ota_assigner import SimOTAAssigner


class TestSimOTAAssigner:
    def test_assign(self) -> None:
        assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
        pred_instances = InstanceData(
            bboxes=torch.Tensor([[23, 23, 43, 43], [4, 5, 6, 7]]),
            scores=torch.FloatTensor([[0.2], [0.8]]),
            priors=torch.Tensor([[30, 30, 8, 8], [4, 5, 6, 7]]),
        )
        gt_instances = InstanceData(bboxes=torch.Tensor([[23, 23, 43, 43]]), labels=torch.LongTensor([0]))
        assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)

        expected_gt_inds = torch.LongTensor([1, 0])
        assert torch.allclose(assign_result.gt_inds, expected_gt_inds)

    def test_assign_with_no_valid_bboxes(self) -> None:
        assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
        pred_instances = InstanceData(
            bboxes=torch.Tensor([[123, 123, 143, 143], [114, 151, 161, 171]]),
            scores=torch.FloatTensor([[0.2], [0.8]]),
            priors=torch.Tensor([[30, 30, 8, 8], [55, 55, 8, 8]]),
        )
        gt_instances = InstanceData(bboxes=torch.Tensor([[0, 0, 1, 1]]), labels=torch.LongTensor([0]))
        assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)

        expected_gt_inds = torch.LongTensor([0, 0])
        assert torch.allclose(assign_result.gt_inds, expected_gt_inds)

    def test_assign_with_empty_gt(self) -> None:
        assigner = SimOTAAssigner(center_radius=2.5, candidate_topk=1, iou_weight=3.0, cls_weight=1.0)
        pred_instances = InstanceData(
            bboxes=torch.Tensor([[[30, 40, 50, 60]], [[4, 5, 6, 7]]]),
            scores=torch.FloatTensor([[0.2], [0.8]]),
            priors=torch.Tensor([[0, 12, 23, 34], [4, 5, 6, 7]]),
        )
        gt_instances = InstanceData(bboxes=torch.empty(0, 4), labels=torch.empty(0))

        assign_result = assigner.assign(pred_instances=pred_instances, gt_instances=gt_instances)
        expected_gt_inds = torch.LongTensor([0, 0])
        assert torch.allclose(assign_result.gt_inds, expected_gt_inds)
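For reference, gt_inds follows the usual mmdetection AssignResult convention (assuming the migrated assigner keeps it): 0 means a prediction is left unassigned (background), and a positive value i means it is matched to the i-th ground-truth box. So [1, 0] in the first test matches only the first prediction to the single ground truth, while [0, 0] in the other two tests means nothing is assigned at all.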
114 changes: 114 additions & 0 deletions
tests/unit/algo/detection/heads/test_yolox_head.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Test of YOLOXHead.
Reference : https://github.com/open-mmlab/mmdetection/blob/v3.2.0/tests/test_models/test_dense_heads/test_yolox_head.py
"""

import torch
from mmengine.config import Config
from mmengine.structures import InstanceData
from otx.algo.detection.heads.yolox_head import YOLOXHead
from otx.algo.modules.conv_module import ConvModule
from otx.algo.modules.depthwise_separable_conv_module import DepthwiseSeparableConvModule  # TODO (sungchul): remove


class TestYOLOXHead:
    def test_predict_by_feat(self):
        s = 256
        img_metas = [
            {
                "img_shape": (s, s, 3),
                "scale_factor": (1.0, 1.0),
            },
        ]
        test_cfg = Config({"score_thr": 0.01, "nms": {"type": "nms", "iou_threshold": 0.65}})
        head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=False, test_cfg=test_cfg)
        feat = [torch.rand(1, 1, s // feat_size, s // feat_size) for feat_size in [4, 8, 16]]
        cls_scores, bbox_preds, objectnesses = head.forward(feat)
        head.predict_by_feat(cls_scores, bbox_preds, objectnesses, img_metas, cfg=test_cfg, rescale=True, with_nms=True)
        head.predict_by_feat(
            cls_scores,
            bbox_preds,
            objectnesses,
            img_metas,
            cfg=test_cfg,
            rescale=False,
            with_nms=False,
        )

    def test_loss_by_feat(self):
        s = 256
        img_metas = [
            {
                "img_shape": (s, s, 3),
                "scale_factor": 1,
            },
        ]
        train_cfg = Config(
            {
                "assigner": {
                    "type": "SimOTAAssigner",
                    "center_radius": 2.5,
                    "candidate_topk": 10,
                    "iou_weight": 3.0,
                    "cls_weight": 1.0,
                },
            },
        )
        head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=False, train_cfg=train_cfg)
        assert not head.use_l1
        assert isinstance(head.multi_level_cls_convs[0][0], ConvModule)

        feat = [torch.rand(1, 1, s // feat_size, s // feat_size) for feat_size in [4, 8, 16]]
        cls_scores, bbox_preds, objectnesses = head.forward(feat)

        # Test that empty ground truth encourages the network to predict
        # background
        gt_instances = InstanceData(bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))

        empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
        # When there is no truth, there should be no cls or box loss;
        # only the objectness loss should be non-zero.
        empty_cls_loss = empty_gt_losses["loss_cls"].sum()
        empty_box_loss = empty_gt_losses["loss_bbox"].sum()
        empty_obj_loss = empty_gt_losses["loss_obj"].sum()
        assert empty_cls_loss.item() == 0, "there should be no cls loss when there are no true boxes"
        assert empty_box_loss.item() == 0, "there should be no box loss when there are no true boxes"
        assert empty_obj_loss.item() > 0, "objectness loss should be non-zero"

        # When truth is non-empty, both cls and box loss should be nonzero
        # for random inputs
        head = YOLOXHead(num_classes=4, in_channels=1, stacked_convs=1, use_depthwise=True, train_cfg=train_cfg)
        assert isinstance(head.multi_level_cls_convs[0][0], DepthwiseSeparableConvModule)
        head.use_l1 = True
        gt_instances = InstanceData(
            bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
            labels=torch.LongTensor([2]),
        )

        one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
        onegt_cls_loss = one_gt_losses["loss_cls"].sum()
        onegt_box_loss = one_gt_losses["loss_bbox"].sum()
        onegt_obj_loss = one_gt_losses["loss_obj"].sum()
        onegt_l1_loss = one_gt_losses["loss_l1"].sum()
        assert onegt_cls_loss.item() > 0, "cls loss should be non-zero"
        assert onegt_box_loss.item() > 0, "box loss should be non-zero"
        assert onegt_obj_loss.item() > 0, "obj loss should be non-zero"
        assert onegt_l1_loss.item() > 0, "l1 loss should be non-zero"

        # Test ground truth out of bound
        gt_instances = InstanceData(
            bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
            labels=torch.LongTensor([2]),
        )
        empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses, [gt_instances], img_metas)
        # When gt_bboxes out of bound, the assign results should be empty,
        # so the cls and bbox loss should be zero.
        empty_cls_loss = empty_gt_losses["loss_cls"].sum()
        empty_box_loss = empty_gt_losses["loss_bbox"].sum()
        empty_obj_loss = empty_gt_losses["loss_obj"].sum()
        assert empty_cls_loss.item() == 0, "there should be no cls loss when gt_bboxes out of bound"
        assert empty_box_loss.item() == 0, "there should be no box loss when gt_bboxes out of bound"
        assert empty_obj_loss.item() > 0, "objectness loss should be non-zero"
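These zero-loss assertions reflect how YOLOX accumulates its losses: classification and box regression losses are computed only over priors that the assigner marks as positive, so with no ground truth, or with boxes placed far outside the image that cannot be assigned, there are no positive samples and both losses are exactly zero, while the objectness loss runs over all priors and therefore stays non-zero.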
3 changes: 3 additions & 0 deletions
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of custom layers of OTX Detection task."""