Skip to content

Commit

Permalink
update visualization scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Abyssaledge committed Sep 1, 2023
1 parent fc41a05 commit 5cbc975
Show file tree
Hide file tree
Showing 3 changed files with 406 additions and 0 deletions.
90 changes: 90 additions & 0 deletions tools/vis/show_bin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np
import os
from os import path as osp
import argparse
from ipdb import set_trace
from tqdm import tqdm

# from pipeline_vis import frame_visualization
from visualizer import Visualizer2D
from utils import get_obj_dict_from_bin_file, get_pc_from_time_stamp

parser = argparse.ArgumentParser()
# running configurations
parser.add_argument('--bin-path', type=str, default='')
parser.add_argument('--gt-bin-path', type=str, default='./data/waymo/waymo_format/gt.bin')
parser.add_argument('--save-folder', type=str, default='')
parser.add_argument('--suffix', type=str, default='')
parser.add_argument('--split', type=str, default='training')
parser.add_argument('--interval', type=int, default=198)
parser.add_argument('--no-gt', action='store_true')
# process
args = parser.parse_args()

def frame_visualization(pc, dets, gts, name='', save_path='./exp.png', figsize=(40, 40)):
visualizer = Visualizer2D(name=name, figsize=figsize)
visualizer.handler_pc(pc, s=0.1)
if gts is not None:
for _, bbox in enumerate(gts):
visualizer.handler_box(bbox, message='', color='black')
for det in dets:
# visualizer.handler_box(det, message='%.2f-' % det.s + str(int(det.type)), color='red', linestyle='dashed', text_color='green', fontsize='small')
visualizer.handler_box(det, message='%.2f' % det.s, color='red', linestyle='dashed', text_color='blue', fontsize='small', center_message=True)
visualizer.save(save_path)
visualizer.close()

if __name__ == '__main__':
bin_path = osp.abspath(args.bin_path)
if args.save_folder == '':
save_folder = osp.join(osp.dirname(bin_path), 'vis_folder')
elif '/' in args.save_folder:
save_folder = args.save_folder
else:
save_folder = osp.join(osp.dirname(bin_path), args.save_folder)

assert 'vis' in save_folder


os.makedirs(save_folder, exist_ok=True)

pred_dict = get_obj_dict_from_bin_file(bin_path, debug=False)
if args.no_gt:
gt_dict = None
else:
gt_dict = get_obj_dict_from_bin_file(args.gt_bin_path, debug=False)

if gt_dict is not None:
ts_list = sorted(list(gt_dict.keys()))
else:
ts_list = sorted(list(pred_dict.keys()))

# with open

for i, ts in tqdm(enumerate(ts_list)):
if i % args.interval != 0:
continue
try:
dets = pred_dict[ts]
except KeyError:
continue

if gt_dict is None:
gts = None
elif ts not in gt_dict:
gts = []
else:
gts = gt_dict[ts]
# set_trace()
pc = get_pc_from_time_stamp(ts, './data/waymo/kitti_format/idx2timestamp.pkl', split=args.split)
# set_trace()

if len(args.suffix) > 0:
suffix = '_' + args.suffix
else:
suffix = ''

frame_visualization(
pc, dets, gts,
save_path=osp.join(save_folder, str(ts) + suffix + '.png'),
figsize=(18, 18)
)
125 changes: 125 additions & 0 deletions tools/vis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import numpy as np
import pickle as pkl
from tqdm import tqdm
from collections import defaultdict

from waymo_open_dataset import label_pb2
from waymo_open_dataset.utils import box_utils
from waymo_open_dataset.protos import metrics_pb2

from ipdb import set_trace
from visualizer import BBox

import torch

def read_bin(file_path):
with open(file_path, 'rb') as f:
objects = metrics_pb2.Objects()
objects.ParseFromString(f.read())
return objects

def object2array(obj):
"""transform box dict in waymo_open_format to array
Args:
box_dict ([dict]): waymo_open_dataset formatted bbox
"""
box = obj.object.box
result = np.array([
box.center_x,
box.center_y,
box.center_z,
box.width,
box.length,
box.height,
box.heading,
obj.score,
float(obj.object.type),
])
return result

def object2BBox(obj):
box = obj.object.box
result = BBox(
box.center_x,
box.center_y,
box.center_z,
box.height,
box.width,
box.length,
box.heading,
obj.score,
float(obj.object.type),
)
return result

def object2mmdetformat(obj):
'''
According to https://github.com/waymo-research/waymo-open-dataset/blob/master/waymo_open_dataset/label.proto#L33
and the definition of LiDARInstance3DBoxes
'''
box = obj.object.box
heading = box.heading - 1.5 * np.pi

while heading < -np.pi:
heading += 2 * np.pi
while heading > np.pi:
heading -= 2 * np.pi

result = np.array(
[
box.center_x,
box.center_y,
box.center_z,
box.width,
box.length,
box.height,
heading,
obj.score,
float(obj.object.type),
]
)
return result

def get_obj_dict_from_bin_file(file_path, debug=False, to_mmdet=False, concat=False):
print(f'Reading {file_path} ...')
pred_data = read_bin(file_path)
objects = pred_data.objects
if debug:
objects = objects[:10]
obj_dict = defaultdict(list)
print('Collecting Bboxes ...')
for o in tqdm(objects):
seg_name = o.context_name
time_stamp = o.frame_timestamp_micros
if to_mmdet:
bbox_for_vis = object2mmdetformat(o)
else:
bbox_for_vis = object2BBox(o)
obj_dict[time_stamp].append(bbox_for_vis)

if concat:
for k in obj_dict:
obj_list = obj_dict[k]
obj_dict[k] = np.stack(obj_list, dim=0)

return obj_dict

idx2ts = None

def get_pc_from_time_stamp(timestamp, path, split='training'):
global idx2ts
if idx2ts is None:
with open(path, 'rb') as fr:
idx2ts = pkl.load(fr)
print('Read idx2ts')
ts2idx = {}
for idx, ts in idx2ts.items():
ts2idx[ts] = idx

curr_idx = ts2idx[timestamp]
pc_root = f'./data/waymo/kitti_format/{split}/velodyne'
pc_path = os.path.join(pc_root, curr_idx + '.bin')
pc = np.fromfile(pc_path, dtype=np.float32).reshape(-1, 6)
pc = pc[:, :3]
return pc
Loading

0 comments on commit 5cbc975

Please sign in to comment.