RuntimeError: The shape of the 2D attn_mask is torch.Size([399, 399]), but should be (355, 355). #255

vyskocj commented Apr 15, 2023

Hello guys,

Thank you for this nice project. I would like to retrain MaskDINO on a custom dataset (I can use at most a batch size of 2 per GPU if I want to keep the settings as close as possible to the original). My dataset contains bitmasks; I have already rewritten the dataset loader and fixed the following issues:

My fix `gt_masks = targets_per_image.gt_masks.tensor` resolves:

File "/mnt/storage-plzen1/home/vyskocj/ADETR/detrex/projects/maskdino/maskdino.py", line 239, in prepare_targets
padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
AttributeError: 'BitMasks' object has no attribute 'shape'
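
For context, a minimal sketch of how that fix slots into `prepare_targets` (hedged, not the exact detrex code; `h_pad`/`w_pad` come from the surrounding function):

```python
# BitMasks itself has no .shape attribute, so unwrap it to the raw tensor
# before building the padded mask tensor.
gt_masks = targets_per_image.gt_masks.tensor  # BitMasks -> (N, H, W) bool tensor
padded_masks = torch.zeros(
    (gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device
)
padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
```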

My fix `topk = min(self.num_queries, enc_outputs_class_unselected.shape[1])` resolves:

File "/auto/plzen1/home/vyskocj/ADETR/detrex/projects/maskdino/modeling/transformer_decoder/maskdino_decoder.py", line 364, in forward
topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]
RuntimeError: selected index k out of range
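
Similarly, a sketch of where that fix sits in the decoder's two-stage query selection (hedged; names follow the traceback above):

```python
# When an image yields fewer encoder proposals than self.num_queries,
# torch.topk(..., k=self.num_queries) raises "selected index k out of range",
# so clamp k to the number of available proposals.
topk = min(self.num_queries, enc_outputs_class_unselected.shape[1])
topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1]
```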

But I do not know how to handle the following one:

File "/auto/plzen1/home/vyskocj/ADETR/detrex/projects/maskdino/modeling/transformer_decoder/dino_decoder.py", line 247, in forward
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 1029, in forward
attn_output, attn_output_weights = F.multi_head_attention_forward(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 5112, in multi_head_attention_forward
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([399, 399]), but should be (355, 355).

Do you know what could cause this issue and how to fix it?

EDIT:
This happens only when the batch size per GPU is 2. After I decreased the image size and increased the batch size per GPU to 8, the error disappeared. But this is not a real solution...

HaoZhang534 (Contributor) commented

This error happens in the DN (denoising) part: the number of queries does not match the shape of the attention mask. Can you give more details, such as the hyperparameters related to DN and the number of decoder queries?
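
Roughly speaking (a hedged sketch, not the exact detrex variable names): the decoder's self-attention input is the DN queries concatenated with the matching queries, so the 2D mask must be square with side `dn_pad_size + num_queries`. With the numbers from this issue:

```python
# Back-of-the-envelope check with num_queries = 300: the attn_mask was built
# for a DN pad of 399 - 300 = 99 tokens, while the query tensor was built with
# a DN pad of 355 - 300 = 55, i.e. two code paths disagreed on the DN pad size
# (which depends on the max number of GT objects per image in the batch).
num_queries = 300
print(399 - num_queries, 355 - num_queries)  # 99 vs. 55
```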

vyskocj (Author) commented Apr 18, 2023

Thank you for your quick response. Yes, here is my configuration:

model config (based on maskdino_r50.py):

```python
import torch.nn as nn
from detrex.layers import PositionEmbeddingSine
from detrex.modeling.backbone import ResNet, BasicStem

from detectron2.config import LazyCall as L

from projects.maskdino.modeling.meta_arch.maskdino_head import MaskDINOHead
from projects.maskdino.modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder
from projects.maskdino.modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder
from projects.maskdino.modeling.criterion import SetCriterion
from projects.maskdino.modeling.matcher import HungarianMatcher
from projects.maskdino.maskdino_my import MaskDINO
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, ShapeSpec, get_norm



dim = 256
n_class = 530
dn = "seg"
dec_layers = 9
input_shape = {
    "res2": ShapeSpec(channels=256, height=None, width=None, stride=4),
    "res3": ShapeSpec(channels=512, height=None, width=None, stride=8),
    "res4": ShapeSpec(channels=1024, height=None, width=None, stride=16),
    "res5": ShapeSpec(channels=2048, height=None, width=None, stride=32),
}
model = L(MaskDINO)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    sem_seg_head=L(MaskDINOHead)(
        input_shape=input_shape,
        num_classes=n_class,
        pixel_decoder=L(MaskDINOEncoder)(
            input_shape=input_shape,
            transformer_dropout=0.0,
            transformer_nheads=8,
            transformer_dim_feedforward=2048,
            transformer_enc_layers=6,
            conv_dim=dim,
            mask_dim=dim,
            norm = 'GN',
            transformer_in_features=['res3', 'res4', 'res5'],
            common_stride=4,
            num_feature_levels=3,
            total_num_feature_levels=4,
            feature_order='low2high',
        ),
        loss_weight= 1.0,
        ignore_value= -1,
        transformer_predictor=L(MaskDINODecoder)(
            in_channels=dim,
            mask_classification=True,
            num_classes="${..num_classes}",
            hidden_dim=dim,
            num_queries=300,
            nheads=8,
            dim_feedforward=2048,
            dec_layers=dec_layers,
            mask_dim=dim,
            enforce_input_project=False,
            two_stage=True,
            dn=dn,
            noise_scale=0.4,
            dn_num=100,
            initialize_box_type='bitmask',
            initial_pred=True,
            learn_tgt=False,
            total_num_feature_levels= "${..pixel_decoder.total_num_feature_levels}",
            dropout = 0.0,
            activation= 'relu',
            nhead= 8,
            dec_n_points= 4,
            return_intermediate_dec = True,
            query_dim= 4,
            dec_layer_share = False,
            semantic_ce_loss = False,
        ),
    ),
    criterion=L(SetCriterion)(
        num_classes="${..sem_seg_head.num_classes}",
        matcher=L(HungarianMatcher)(
            cost_class = 4.0,
            cost_mask = 5.0,
            cost_dice = 5.0,
            num_points = 12544,
            cost_box=5.0,
            cost_giou=2.0,
            panoptic_on="${..panoptic_on}",
        ),
        weight_dict=dict(),
        eos_coef=0.1,
        losses=['labels', 'masks', 'boxes'],
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        dn=dn,
        dn_losses=['labels', 'masks', 'boxes'],
        panoptic_on="${..panoptic_on}",
        semantic_ce_loss=False
    ),
    num_queries=300,
    object_mask_threshold=0.25,
    overlap_threshold=0.8,
    metadata=MetadataCatalog.get('custom_train'),
    size_divisibility=32,
    sem_seg_postprocess_before_inference=True,
    pixel_mean=[123.675, 116.28, 103.53],
    pixel_std=[58.395, 57.12, 57.375],
    # inference
    semantic_on=False,
    panoptic_on=False,
    instance_on=True,
    test_topk_per_image=100,
    pano_temp=0.06,
    focus_on_box = False,
    transform_eval = True,
)

# set aux loss weight dict
class_weight=4.0
mask_weight=5.0
dice_weight=5.0
box_weight=5.0
giou_weight=2.0
weight_dict = {"loss_ce": class_weight}
weight_dict.update({"loss_mask": mask_weight, "loss_dice": dice_weight})
weight_dict.update({"loss_bbox": box_weight, "loss_giou": giou_weight})
# two stage is the query selection scheme

interm_weight_dict = {k + "_interm": v for k, v in weight_dict.items()}
weight_dict.update(interm_weight_dict)
# denoising training

if dn == "standard":
    weight_dict.update({k + "_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"})
    dn_losses = ["labels", "boxes"]
elif dn == "seg":
    weight_dict.update({k + "_dn": v for k, v in weight_dict.items()})
    dn_losses = ["labels", "masks", "boxes"]
else:
    dn_losses = []
# if deep_supervision:

aux_weight_dict = {}
for i in range(dec_layers):
    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
model.criterion.weight_dict = weight_dict
```
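
A quick sanity check of the assembled dict (illustrative only, continuing the names above; assumes `dn == "seg"` and `dec_layers == 9` as set in the config):

```python
# Every base loss key should now have _interm, _dn, and per-decoder-layer
# _{i} variants after the updates above.
assert "loss_ce" in weight_dict
assert "loss_ce_interm" in weight_dict and "loss_ce_dn" in weight_dict
assert all(f"loss_ce_{i}" in weight_dict for i in range(dec_layers))
```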
the configuration passed via arguments (based on maskdino_r50_coco_instance_seg_50ep.py):

```python
# ... imports

train = get_config("common/train.py").train

# modify dataloader config
dataloader.train.num_workers = 16

# please note that this is the total batch size:
# suppose you're using 4 GPUs for training, then the batch size for
# each GPU is 16 / 4 = 4
dataloader.train.total_batch_size = min(16, batch_size_per_gpu * num_gpus)
train.accum_grad_batch_size = max(1, int(16 / batch_size_per_gpu))  # accumulate gradients to keep an effective batch size of 16
train.wandb.enabled = True
train.wandb.params.name = "maskdino_r50_50ep"
train.seed = 666


# max training iterations
train.max_iter = int(50 * num_train_imgs / dataloader.train.total_batch_size)
# warmup lr scheduler
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        values=[1.0, 0.1],
        milestones=[327778, 355092],
    ),
    warmup_length=10 / train.max_iter,
    warmup_factor=1.0,
)

optimizer = get_config("common/optim.py").AdamW
# lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep

# initialize checkpoint to be loaded
train.init_checkpoint = "https://github.com/IDEA-Research/detrex-storage/releases/download/maskdino-v0.1.0/maskdino_r50_50ep_300q_hid2048_3sd1_instance_maskenhanced_mask46.3ap_box51.7ap.pth" # "detectron2://ImageNetPretrained/torchvision/R-50.pkl"

# run evaluation every X iters
train.eval_period = 2 * math.ceil(num_train_imgs / dataloader.train.total_batch_size)

# log training information every train.log_period iters
train.log_period = 36  # train.eval_period must be divisible by this!

# save checkpoint every X iters
train.checkpointer.period = train.eval_period

# gradient clipping for training
train.clip_grad.enabled = True
train.clip_grad.params.max_norm = 0.01
train.clip_grad.params.norm_type = 2

# set training devices
train.device = "cuda"


# modify optimizer config
optimizer.lr = 1e-4
optimizer.betas = (0.9, 0.999)
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# ... dataloader
```
