Skip to content

Commit

Permalink
[Fix] Fix PartA2 bug (open-mmlab#1538)
Browse files Browse the repository at this point in the history
* [Fix] fixed compatibility of MMCV (open-mmlab#1508)

* fixed GPG key error

* fixed compatibility of MMCV

* fix parta2 bug

Co-authored-by: Wenwei Zhang <[email protected]>
Co-authored-by: atinfinity <[email protected]>
  • Loading branch information
3 people authored Jun 8, 2022
1 parent 36a0d19 commit ed3c1eb
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
pipeline=train_pipeline,
modality=input_modality,
classes=class_names,
box_type_3d='LiDAR',
test_mode=False)),
val=dict(
type=dataset_type,
Expand All @@ -100,6 +101,7 @@
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
box_type_3d='LiDAR',
test_mode=True),
test=dict(
type=dataset_type,
Expand All @@ -110,6 +112,7 @@
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
box_type_3d='LiDAR',
test_mode=True))

# Part-A2 uses a different learning rate from what SECOND uses.
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(self,
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)

self.train_cfg = train_cfg
self.test_cfg = test_cfg
if neck is not None:
self.neck = build_neck(neck)

Expand Down
9 changes: 5 additions & 4 deletions mmdet3d/models/middle_encoders/sparse_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmcv.runner import BaseModule, auto_fp16

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.sparse_block import replace_feature
from ..builder import MIDDLE_ENCODERS


Expand Down Expand Up @@ -168,10 +169,11 @@ def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
:obj:`SparseConvTensor`: Upsampled feature.
"""
x = lateral_layer(x_lateral)
x.features = torch.cat((x_bottom.features, x.features), dim=1)
x = replace_feature(x, torch.cat((x_bottom.features, x.features),
dim=1))
x_merge = merge_layer(x)
x = self.reduce_channel(x, x_merge.features.shape[1])
x.features = x_merge.features + x.features
x = replace_feature(x, x_merge.features + x.features)
x = upsample_layer(x)
return x

Expand All @@ -191,8 +193,7 @@ def reduce_channel(x, out_channels):
n, in_channels = features.shape
assert (in_channels % out_channels
== 0) and (in_channels >= out_channels)

x.features = features.view(n, out_channels, -1).sum(dim=2)
x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
return x

def make_encoder_layers(self, make_block, norm_cfg, in_channels):
Expand Down
12 changes: 10 additions & 2 deletions mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
import numpy as np
import torch
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential

from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE

if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import (SparseConvTensor, SparseMaxPool3d,
SparseSequential)
else:
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential

from mmcv.runner import BaseModule
from torch import nn as nn

Expand Down Expand Up @@ -252,7 +260,7 @@ def forward(self, seg_feats, part_feats):
sparse_idx[:, 2], sparse_idx[:, 3]]
seg_features = seg_feats[sparse_idx[:, 0], sparse_idx[:, 1],
sparse_idx[:, 2], sparse_idx[:, 3]]
coords = sparse_idx.int()
coords = sparse_idx.int().contiguous()
part_features = SparseConvTensor(part_features, coords, sparse_shape,
rcnn_batch_size)
seg_features = SparseConvTensor(seg_features, coords, sparse_shape,
Expand Down

0 comments on commit ed3c1eb

Please sign in to comment.