Skip to content

Commit

Permalink
Reduce memory consumption in test_maskdino_heads.py
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw committed Oct 25, 2024
1 parent 938b42b commit b06bc1c
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions tests/unit/algo/instance_segmentation/heads/test_maskdino_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,29 @@
import torch
from otx.algo.instance_segmentation.heads import MaskDINODecoderHead, MaskDINOEncoderHead
from otx.algo.instance_segmentation.losses import MaskDINOCriterion
from otx.algo.instance_segmentation.maskdino import MaskDINO
from otx.algo.instance_segmentation.utils.utils import ShapeSpec


class TestMaskDINOTransformerHeads:
@pytest.fixture()
def fxt_shape_spec(self):
model = MaskDINO(label_info=3, model_name="resnet50")
_, specs = model._build_backbone()
return specs
return {
"res2": ShapeSpec(channels=16, stride=4),
"res3": ShapeSpec(channels=16, stride=8),
"res4": ShapeSpec(channels=16, stride=16),
"res5": ShapeSpec(channels=16, stride=32),
}

def test_maskdino_encoder_decoder_head(self, fxt_shape_spec, num_classes=2):
maskdino_encoder_head = MaskDINOEncoderHead("resnet50", fxt_shape_spec)
maskdino_decoder_head = MaskDINODecoderHead("resnet50", num_classes=num_classes)
criterion = MaskDINOCriterion(num_classes=2)
features = {
"res2": torch.randn(1, 256, 256, 256),
"res3": torch.randn(1, 512, 128, 128),
"res4": torch.randn(1, 1024, 64, 64),
"res5": torch.randn(1, 2048, 32, 32),
}

fmap_size = (64, 64)
features = {}
for fmap_name, spec in fxt_shape_spec.items():
features[fmap_name] = torch.rand((1, spec.channels, *fmap_size))
fmap_size = tuple(x // 2 for x in fmap_size)

targets = [
{
Expand Down

0 comments on commit b06bc1c

Please sign in to comment.