Skip to content

Commit

Permalink
add visualization for seg palette
Browse files Browse the repository at this point in the history
  • Loading branch information
theairlab committed Feb 12, 2025
1 parent bca3099 commit 49967ba
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/visualization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
difficulty='easy',
trajectory_id = 'P000',
modality = ['image', 'depth', 'seg'],
camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left'])
camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left'],
show_seg_palette = True)
4 changes: 2 additions & 2 deletions tartanair/tartanair.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def list_envs():
check_init()
return lister.list_envs()

def visualize(env, difficulty, trajectory_id, modality, camera_name = None):
def visualize(env, difficulty, trajectory_id, modality, camera_name = None, show_seg_palette = False):
"""
Interactively visualizes a trajectory from the TartanAir dataset that is saved locally.
Expand All @@ -342,7 +342,7 @@ def visualize(env, difficulty, trajectory_id, modality, camera_name = None):
"""
global visualizer
check_init()
visualizer.visualize(env, difficulty, trajectory_id, modality, camera_name)
visualizer.visualize(env, difficulty, trajectory_id, modality, camera_name, show_seg_palette = show_seg_palette)

def check_init():
global is_init
Expand Down
42 changes: 38 additions & 4 deletions tartanair/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
'''
# General imports.
import os
from colorama import Fore, Style
import cv2
import numpy as np
import json

# Local imports.
from .tartanair_module import TartanAirModule
from .tartanair_module import TartanAirModule, print_error
from .iterator import TartanAirIterator

_CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -23,8 +23,12 @@ def __init__(self, tartanair_data_root, azure_token = None):
# Modality mapped to a reader.
self.modality_to_vis_func = {'image': self.visimage, 'depth': self.visdepth, 'seg': self.visseg}
self.seg_colors = np.loadtxt(_CURRENT_PATH + '/seg_rgbs.txt', delimiter=',',dtype=np.uint8)
# for visualization purpose
self.data_to_colors = np.zeros((256, 3), dtype=np.uint8)
for color in self.seg_colors:
self.data_to_colors[color[2], :] = color

def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modality = [], camera_name = []):
def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modality = [], camera_name = [], show_seg_palette = False):
"""
Visualizes a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.
Expand Down Expand Up @@ -53,6 +57,16 @@ def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modali
sample_images.append( vis_img )
sample_image_names.append(vis_img_name)

# Visualize the semantic segmentation palette
if show_seg_palette:
label_file = os.path.join(self.tartanair_data_root, env, 'seg_label_map.json')
if os.path.isfile(label_file):
vispalette = self.vis_seg_palette(label_file)
cv2.imshow("Semantic Segmentation Palette", vispalette)
else:
print_error("Missing seg_label.json file {}".format(label_file))


#############################
# Visualize the images.
#############################
Expand Down Expand Up @@ -130,7 +144,27 @@ def visdepth(self, depth):
def visseg(self, seg):
    """
    Colorize a semantic segmentation label map for display.

    Every label value in `seg` is used as a row index into
    `self.data_to_colors`; the 3-element row found there becomes the color
    of that pixel.

    Args:
        seg: Integer numpy array of label values. Each value must be a valid
            row index of `self.data_to_colors`.

    Returns:
        Array of shape `seg.shape + (3,)` holding the looked-up colors.
    """
    # A single lookup through the label -> color table. The earlier
    # preallocated buffer and the lookup through the raw `seg_colors` list
    # were both overwritten before use, so they are dropped.
    segvis = self.data_to_colors[seg, :]

    # Integer-array indexing already yields seg.shape + (3,); the reshape is
    # kept so the returned shape stays explicit.
    return segvis.reshape(seg.shape + (3,))

def vis_seg_palette(self, labelfile):
    """
    Build a legend image for the semantic segmentation classes.

    The JSON file at `labelfile` must contain a "name_map" object mapping
    class name -> label index. One 20-pixel-high, 150-pixel-wide swatch is
    drawn per class, filled with that label's entry in
    `self.data_to_colors` and annotated with the class name.

    Args:
        labelfile: Path to the JSON label-map file.

    Returns:
        uint8 image of shape (rows * 20, 150, 3) in OpenCV channel order,
        where rows is the number of classes (at least one row, so the image
        is never empty).
    """
    with open(labelfile, 'r', encoding='utf-8') as f:
        name_to_index = json.load(f)["name_map"]  # {name: ind}

    # List of (ind, name) pairs, kept in the order they appear in the file.
    entries = [(index, name) for name, index in name_to_index.items()]

    row_height = 20
    img_width = 150
    # Keep at least one row: a zero-height image cannot be shown by cv2.imshow.
    num_rows = max(len(entries), 1)
    palette_img = np.zeros((num_rows * row_height, img_width, 3), dtype=np.uint8)

    for row, (index, name) in enumerate(entries):
        top = row * row_height
        bgr = self.data_to_colors[index][::-1]  # Convert RGB to BGR for OpenCV
        palette_img[top:top + row_height, :] = bgr

        # Pick black or white text from the swatch brightness so the label
        # stays readable on dark colors (e.g. the all-zero default entry).
        blue, green, red = (int(channel) for channel in bgr)
        luma = 0.114 * blue + 0.587 * green + 0.299 * red
        text_color = (0, 0, 0) if luma >= 128 else (255, 255, 255)

        cv2.putText(palette_img, name, (10, top + row_height // 2),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, text_color, 1)

    return palette_img

0 comments on commit 49967ba

Please sign in to comment.