From b06bc1ca5199504ea7721276300ce0b1d00b8fc5 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Fri, 25 Oct 2024 09:25:08 +0100 Subject: [PATCH] Reduce memory consumption in test_maskdino_heads.py --- .../heads/test_maskdino_heads.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/unit/algo/instance_segmentation/heads/test_maskdino_heads.py b/tests/unit/algo/instance_segmentation/heads/test_maskdino_heads.py index 297c517dc4..1a2ff4f668 100644 --- a/tests/unit/algo/instance_segmentation/heads/test_maskdino_heads.py +++ b/tests/unit/algo/instance_segmentation/heads/test_maskdino_heads.py @@ -6,26 +6,29 @@ import torch from otx.algo.instance_segmentation.heads import MaskDINODecoderHead, MaskDINOEncoderHead from otx.algo.instance_segmentation.losses import MaskDINOCriterion -from otx.algo.instance_segmentation.maskdino import MaskDINO +from otx.algo.instance_segmentation.utils.utils import ShapeSpec class TestMaskDINOTransformerHeads: @pytest.fixture() def fxt_shape_spec(self): - model = MaskDINO(label_info=3, model_name="resnet50") - _, specs = model._build_backbone() - return specs + return { + "res2": ShapeSpec(channels=16, stride=4), + "res3": ShapeSpec(channels=16, stride=8), + "res4": ShapeSpec(channels=16, stride=16), + "res5": ShapeSpec(channels=16, stride=32), + } def test_maskdino_encoder_decoder_head(self, fxt_shape_spec, num_classes=2): maskdino_encoder_head = MaskDINOEncoderHead("resnet50", fxt_shape_spec) maskdino_decoder_head = MaskDINODecoderHead("resnet50", num_classes=num_classes) criterion = MaskDINOCriterion(num_classes=2) - features = { - "res2": torch.randn(1, 256, 256, 256), - "res3": torch.randn(1, 512, 128, 128), - "res4": torch.randn(1, 1024, 64, 64), - "res5": torch.randn(1, 2048, 32, 32), - } + + fmap_size = (64, 64) + features = {} + for fmap_name, spec in fxt_shape_spec.items(): + features[fmap_name] = torch.rand((1, spec.channels, *fmap_size)) + fmap_size = tuple(x // 2 for x in fmap_size) targets = [ {