Skip to content

Commit

Permalink
[Feature] Add FCAF3D on ScanNet, S3DIS and SUN RGB-D (open-mmlab#1547)
Browse files Browse the repository at this point in the history
* add fcaf3d

* fix lint; add docstrings

* add loss tests; fix config test

* fix test_axis_aligned_iou_loss

* revert depthinstance3dboxes.overlaps

* fix deadlock in reduce_mean

* fix comments

* fix lint

* raw mmdetection3d

* revert last commit

* fix comments

* fix comments
  • Loading branch information
filaPro authored Aug 1, 2022
1 parent a9ae58c commit fe8f9d4
Show file tree
Hide file tree
Showing 15 changed files with 1,344 additions and 11 deletions.
16 changes: 16 additions & 0 deletions configs/_base_/models/fcaf3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model = dict(
type='MinkSingleStage3DDetector',
voxel_size=.01,
backbone=dict(type='MinkResNet', in_channels=3, depth=34),
head=dict(
type='FCAF3DHead',
in_channels=(64, 128, 256, 512),
out_channels=128,
voxel_size=.01,
pts_prune_threshold=100000,
pts_assign_threshold=27,
pts_center_threshold=18,
n_classes=18,
n_reg_outs=6),
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.5, score_thr=.01))
104 changes: 104 additions & 0 deletions configs/fcaf3d/fcaf3d_8x2_s3dis-3d-5class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
_base_ = ['fcaf3d_8x2_scannet-3d-18class.py']
n_points = 100000

model = dict(head=dict(n_classes=5))

dataset_type = 'S3DISDataset'
data_root = './data/s3dis/'
class_names = ('table', 'chair', 'sofa', 'bookcase', 'board')
train_area = [1, 2, 3, 4, 6]
test_area = 5

train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='PointSample', num_points=n_points),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.087266, 0.087266],
scale_ratio_range=[.9, 1.1],
translation_std=[.1, .1, .1],
shift_height=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(type='PointSample', num_points=n_points),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=13,
dataset=dict(
type='ConcatDataset',
datasets=[
dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + f's3dis_infos_Area_{i}.pkl',
pipeline=train_pipeline,
filter_empty_gt=True,
classes=class_names,
box_type_3d='Depth') for i in train_area
],
separate_eval=False)),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + f's3dis_infos_Area_{test_area}.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + f's3dis_infos_Area_{test_area}.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'))
105 changes: 105 additions & 0 deletions configs/fcaf3d/fcaf3d_8x2_scannet-3d-18class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
_base_ = ['../_base_/models/fcaf3d.py', '../_base_/default_runtime.py']
n_points = 100000

dataset_type = 'ScanNetDataset'
data_root = './data/scannet/'
class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'garbagebin')
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='GlobalAlignment', rotation_axis=2),
dict(type='PointSample', num_points=n_points),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.087266, 0.087266],
scale_ratio_range=[.9, 1.1],
translation_std=[.1, .1, .1],
shift_height=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(type='PointSample', num_points=n_points),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=10,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_train.pkl',
pipeline=train_pipeline,
filter_empty_gt=True,
classes=class_names,
box_type_3d='Depth')),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'))

optimizer = dict(type='AdamW', lr=0.001, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]
94 changes: 94 additions & 0 deletions configs/fcaf3d/fcaf3d_8x2_sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
_base_ = ['fcaf3d_8x2_scannet-3d-18class.py']
n_points = 100000

model = dict(
head=dict(
n_classes=10, n_reg_outs=8, bbox_loss=dict(type='RotatedIoU3DLoss')))

dataset_type = 'SUNRGBDDataset'
data_root = 'data/sunrgbd/'
class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
'night_stand', 'bookshelf', 'bathtub')
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D'),
dict(type='PointSample', num_points=n_points),
dict(type='RandomFlip3D', sync_2d=False, flip_ratio_bev_horizontal=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
translation_std=[.1, .1, .1],
shift_height=False),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(type='PointSample', num_points=n_points),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type=dataset_type,
modality=dict(use_camera=False, use_lidar=True),
data_root=data_root,
ann_file=data_root + 'sunrgbd_infos_train.pkl',
pipeline=train_pipeline,
filter_empty_gt=True,
classes=class_names,
box_type_3d='Depth')),
val=dict(
type=dataset_type,
modality=dict(use_camera=False, use_lidar=True),
data_root=data_root,
ann_file=data_root + 'sunrgbd_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'),
test=dict(
type=dataset_type,
modality=dict(use_camera=False, use_lidar=True),
data_root=data_root,
ann_file=data_root + 'sunrgbd_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
box_type_3d='Depth'))
4 changes: 2 additions & 2 deletions mmdet3d/models/backbones/mink_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, depth, in_channels, num_stages=4, pool=True):

for i, num_blocks in enumerate(stage_blocks):
setattr(
self, f'layer{i}',
self, f'layer{i + 1}',
self._make_layer(block, 64 * 2**i, stage_blocks[i], stride=2))

def init_weights(self):
Expand Down Expand Up @@ -111,6 +111,6 @@ def forward(self, x):
x = self.maxpool(x)
outs = []
for i in range(self.num_stages):
x = getattr(self, f'layer{i}')(x)
x = getattr(self, f'layer{i + 1}')(x)
outs.append(x)
return outs
3 changes: 2 additions & 1 deletion mmdet3d/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base_conv_bbox_head import BaseConvBboxHead
from .base_mono3d_dense_head import BaseMono3DDenseHead
from .centerpoint_head import CenterHead
from .fcaf3d_head import FCAF3DHead
from .fcos_mono3d_head import FCOSMono3DHead
from .free_anchor3d_head import FreeAnchor3DHead
from .groupfree3d_head import GroupFree3DHead
Expand All @@ -21,5 +22,5 @@
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead',
'MonoFlexHead'
'MonoFlexHead', 'FCAF3DHead'
]
Loading

0 comments on commit fe8f9d4

Please sign in to comment.