Skip to content

Commit

Permalink
[Enhance] Support with_velocity=False in CenterPoint and NuScenesData…
Browse files Browse the repository at this point in the history
…set (open-mmlab#1469)

* fixed some bugs of CenterPoint and NuScenesDataset when velocity is ignored

* add bbox_code_size attribute for DataBaseSampler & add with_velocity property for CenterPoint

* add time_dim attribute for LoadPointsFromMultiSweeps

* update the config of the no velocity version centerpoint

* [fix] fix bugs in LoadPointsFromMultiSweeps to pass unit testing

Co-authored-by: Wenwei Zhang <[email protected]>
  • Loading branch information
YuanxianH and ZwwWayne authored Jun 27, 2022
1 parent 348db56 commit 23e858b
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# CenterPoint (pillar encoder) on nuScenes, configured without a velocity head.
# Dataset, model, schedule and runtime defaults are inherited from the files
# listed in `_base_`; everything below overrides them.
_base_ = [
    '../_base_/datasets/nus-3d.py',
    '../_base_/models/centerpoint_02pillar_second_secfpn_nus.py',
    '../_base_/schedules/cyclic_20e.py',
    '../_base_/default_runtime.py',
]

# Changing the point cloud range here means every model component that takes
# a range must be updated as well (see the `model` overrides below).
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]

# The usual 10-class nuScenes detection setup.
class_names = [
    'car',
    'truck',
    'construction_vehicle',
    'bus',
    'trailer',
    'barrier',
    'motorcycle',
    'bicycle',
    'pedestrian',
    'traffic_cone',
]

model = dict(
    pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
    pts_voxel_encoder=dict(point_cloud_range=point_cloud_range),
    pts_bbox_head=dict(
        # 7-value boxes: the two velocity components are not encoded.
        bbox_coder=dict(pc_range=point_cloud_range[:2], code_size=7),
        # `_delete_=True` replaces the inherited `common_heads` rather than
        # merging into it, so that only the four heads below are configured
        # (no `vel` entry).
        common_heads=dict(
            reg=(2, 2),
            height=(1, 2),
            dim=(3, 2),
            rot=(2, 2),
            _delete_=True)),
    # Training / testing settings.
    train_cfg=dict(
        pts=dict(
            point_cloud_range=point_cloud_range,
            # One weight per regression channel:
            # reg (2) + height (1) + dim (3) + rot (2) = 8.
            code_weights=[1.0] * 8)),
    test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2])))

dataset_type = 'NuScenesDataset'
data_root = 'data/nuscenes/'
file_client_args = dict(backend='disk')

# Ground-truth database sampler used by the `ObjectSample` augmentation.
db_sampler = dict(
    data_root=data_root,
    info_path=f'{data_root}nuscenes_dbinfos_train.pkl',
    rate=1.0,
    # Keep only the first 7 values of every sampled database box.
    bbox_code_size=7,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(
            car=5,
            truck=5,
            bus=5,
            trailer=5,
            construction_vehicle=5,
            traffic_cone=5,
            barrier=5,
            motorcycle=5,
            bicycle=5,
            pedestrian=5)),
    classes=class_names,
    sample_groups=dict(
        car=2,
        truck=3,
        construction_vehicle=7,
        bus=4,
        trailer=6,
        barrier=2,
        motorcycle=6,
        bicycle=6,
        pedestrian=2,
        traffic_cone=2),
    points_loader=dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=[0, 1, 2, 3, 4],
        file_client_args=file_client_args))

# Training: load key frame + 9 sweeps, load annotations, paste sampled
# objects, augment, filter by range/class, shuffle and pack.
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=file_client_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=9,
        use_dim=[0, 1, 2, 3, 4],
        file_client_args=file_client_args,
        pad_empty_sweeps=True,
        remove_close=True),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(type='ObjectSample', db_sampler=db_sampler),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.3925, 0.3925],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0, 0, 0]),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='PointShuffle'),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']),
]

# Testing: same point loading, wrapped in a single-scale, no-flip
# `MultiScaleFlipAug3D` whose inner rotation/scale/translation are identity.
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=file_client_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=9,
        use_dim=[0, 1, 2, 3, 4],
        file_client_args=file_client_args,
        pad_empty_sweeps=True,
        remove_close=True),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points']),
        ]),
]

# Pipeline used for data and GT loading in the show function. Keep its
# loading steps consistent with `test_pipeline` (e.g. the file client).
eval_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=file_client_args),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=9,
        use_dim=[0, 1, 2, 3, 4],
        file_client_args=file_client_args,
        pad_empty_sweeps=True,
        remove_close=True),
    dict(
        type='DefaultFormatBundle3D',
        class_names=class_names,
        with_label=False),
    dict(type='Collect3D', keys=['points']),
]

train_load_interval = 1000

# Passed to the train, val and test dataset configs below.
with_velocity = False

data = dict(
    train=dict(
        type='CBGSDataset',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file=f'{data_root}nuscenes_infos_train.pkl',
            pipeline=train_pipeline,
            load_interval=train_load_interval,
            classes=class_names,
            with_velocity=with_velocity,
            test_mode=False,
            use_valid_flag=True,
            # box_type_3d='LiDAR' is used for the kitti and nuscenes datasets,
            # box_type_3d='Depth' for the sunrgbd and scannet datasets.
            box_type_3d='LiDAR')),
    val=dict(
        pipeline=test_pipeline,
        classes=class_names,
        with_velocity=with_velocity),
    test=dict(
        pipeline=test_pipeline,
        classes=class_names,
        with_velocity=with_velocity))
9 changes: 6 additions & 3 deletions mmdet3d/datasets/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def _format_bbox(self, results, jsonfile_prefix=None):
print('Start to convert detection format...')
for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
annos = []
boxes = output_to_nusc_box(det)
boxes = output_to_nusc_box(det, self.with_velocity)
sample_token = self.data_infos[sample_id]['token']
boxes = lidar_nusc_box_to_global(self.data_infos[sample_id], boxes,
mapped_class_names,
Expand Down Expand Up @@ -573,7 +573,7 @@ def show(self, results, out_dir, show=False, pipeline=None):
file_name, show)


def output_to_nusc_box(detection):
def output_to_nusc_box(detection, with_velocity=True):
"""Convert the output to the box class in the nuScenes.
Args:
Expand All @@ -600,7 +600,10 @@ def output_to_nusc_box(detection):
box_list = []
for i in range(len(box3d)):
quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
velocity = (*box3d.tensor[i, 7:9], 0.0)
if with_velocity:
velocity = (*box3d.tensor[i, 7:9], 0.0)
else:
velocity = (0, 0, 0)
# velo_val = np.linalg.norm(box3d[i, 7:9])
# velo_ori = box3d[i, 6]
# velocity = (
Expand Down
11 changes: 11 additions & 0 deletions mmdet3d/datasets/pipelines/dbsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class DataBaseSampler(object):
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Default: None.
bbox_code_size (int, optional): The number of bbox dimensions.
Default: None.
points_loader(dict, optional): Config of points loader. Default:
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
"""
Expand All @@ -100,6 +102,7 @@ def __init__(self,
prepare,
sample_groups,
classes=None,
bbox_code_size=None,
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
Expand Down Expand Up @@ -143,13 +146,21 @@ def __init__(self,

self.db_infos = db_infos

self.bbox_code_size = bbox_code_size
if bbox_code_size is not None:
for k, info_cls in self.db_infos.items():
for info in info_cls:
info['box3d_lidar'] = info['box3d_lidar'][:self.
bbox_code_size]

# load sample groups
# TODO: more elegant way to load sample groups
self.sample_groups = []
for name, num in sample_groups.items():
self.sample_groups.append({name: int(num)})

self.group_db_infos = self.db_infos # just use db_infos

self.sample_classes = []
self.sample_max_nums = []
for group_info in self.sample_groups:
Expand Down
12 changes: 10 additions & 2 deletions mmdet3d/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class LoadPointsFromMultiSweeps(object):
Defaults to 5.
use_dim (list[int], optional): Which dimension to use.
Defaults to [0, 1, 2, 4].
time_dim (int, optional): Which dimension to represent the timestamps
of each points. Defaults to 4.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
Expand All @@ -125,18 +127,24 @@ def __init__(self,
sweeps_num=10,
load_dim=5,
use_dim=[0, 1, 2, 4],
time_dim=4,
file_client_args=dict(backend='disk'),
pad_empty_sweeps=False,
remove_close=False,
test_mode=False):
self.load_dim = load_dim
self.sweeps_num = sweeps_num
self.use_dim = use_dim
self.time_dim = time_dim
assert time_dim < load_dim, \
f'Expect the timestamp dimension < {load_dim}, got {time_dim}'
self.file_client_args = file_client_args.copy()
self.file_client = None
self.pad_empty_sweeps = pad_empty_sweeps
self.remove_close = remove_close
self.test_mode = test_mode
assert max(use_dim) < load_dim, \
f'Expect all used dimensions < {load_dim}, got {use_dim}'

def _load_points(self, pts_filename):
"""Private function to load point clouds data.
Expand Down Expand Up @@ -197,7 +205,7 @@ def __call__(self, results):
cloud arrays.
"""
points = results['points']
points.tensor[:, 4] = 0
points.tensor[:, self.time_dim] = 0
sweep_points_list = [points]
ts = results['timestamp']
if self.pad_empty_sweeps and len(results['sweeps']) == 0:
Expand All @@ -224,7 +232,7 @@ def __call__(self, results):
points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
'sensor2lidar_rotation'].T
points_sweep[:, :3] += sweep['sensor2lidar_translation']
points_sweep[:, 4] = ts - sweep_ts
points_sweep[:, self.time_dim] = ts - sweep_ts
points_sweep = points.new_point(points_sweep)
sweep_points_list.append(points_sweep)

Expand Down
52 changes: 36 additions & 16 deletions mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def __init__(self,
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))

self.with_velocity = 'vel' in common_heads.keys()

def forward_single(self, x):
"""Forward function for CenterPoint.
Expand Down Expand Up @@ -490,8 +492,12 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0]))

anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32)
if self.with_velocity:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32)
else:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)

ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
Expand Down Expand Up @@ -548,19 +554,27 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
vx, vy = task_boxes[idx][k][7:]
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0)
])
if self.with_velocity:
vx, vy = task_boxes[idx][k][7:]
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0),
vx.unsqueeze(0),
vy.unsqueeze(0)
])
else:
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])

heatmaps.append(heatmap)
anno_boxes.append(anno_box)
Expand Down Expand Up @@ -594,11 +608,17 @@ def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
avg_factor=max(num_pos, 1))
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(preds_dict[0]['reg'], preds_dict[0]['height'],
preds_dict[0]['dim'], preds_dict[0]['rot'],
preds_dict[0]['vel']),
dim=1)
if self.with_velocity:
preds_dict[0]['anno_box'] = torch.cat(
(preds_dict[0]['reg'], preds_dict[0]['height'],
preds_dict[0]['dim'], preds_dict[0]['rot'],
preds_dict[0]['vel']),
dim=1)
else:
preds_dict[0]['anno_box'] = torch.cat(
(preds_dict[0]['reg'], preds_dict[0]['height'],
preds_dict[0]['dim'], preds_dict[0]['rot']),
dim=1)

# Regression loss for dimension, offset, height, rotation
ind = inds[task_id]
Expand Down
Loading

0 comments on commit 23e858b

Please sign in to comment.