From 49967badd6b7614e1ca52343f572961d45f4e73b Mon Sep 17 00:00:00 2001
From: amigoshan
Date: Wed, 12 Feb 2025 12:02:14 -0500
Subject: [PATCH] add visualization for seg palette

---
 examples/visualization_example.py |  3 ++-
 tartanair/tartanair.py            |  4 +--
 tartanair/visualizer.py           | 42 ++++++++++++++++++++++++++++---
 3 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/examples/visualization_example.py b/examples/visualization_example.py
index 4e233e4..19812e3 100644
--- a/examples/visualization_example.py
+++ b/examples/visualization_example.py
@@ -22,4 +22,5 @@
     difficulty='easy',
     trajectory_id = 'P000',
     modality = ['image', 'depth', 'seg'],
-    camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left'])
\ No newline at end of file
+    camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left'],
+    show_seg_palette = True)
\ No newline at end of file
diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py
index 511596a..7021c73 100644
--- a/tartanair/tartanair.py
+++ b/tartanair/tartanair.py
@@ -324,7 +324,7 @@ def list_envs():
     check_init()
     return lister.list_envs()
 
-def visualize(env, difficulty, trajectory_id, modality, camera_name = None):
+def visualize(env, difficulty, trajectory_id, modality, camera_name = None, show_seg_palette = False):
     """
     Interactively visualizes a trajectory from the TartanAir dataset that is saved locally.
 
@@ -342,7 +342,7 @@ def visualize(env, difficulty, trajectory_id, modality, camera_name = None):
     """
     global visualizer
     check_init()
-    visualizer.visualize(env, difficulty, trajectory_id, modality, camera_name)
+    visualizer.visualize(env, difficulty, trajectory_id, modality, camera_name, show_seg_palette = show_seg_palette)
 
 def check_init():
     global is_init
diff --git a/tartanair/visualizer.py b/tartanair/visualizer.py
index c0ffee8..eeefa03 100644
--- a/tartanair/visualizer.py
+++ b/tartanair/visualizer.py
@@ -6,12 +6,12 @@
 '''
 # General imports.
 import os
-from colorama import Fore, Style
 import cv2
 import numpy as np
+import json
 
 # Local imports.
-from .tartanair_module import TartanAirModule
+from .tartanair_module import TartanAirModule, print_error
 from .iterator import TartanAirIterator
 
 _CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
@@ -23,8 +23,12 @@ def __init__(self, tartanair_data_root, azure_token = None):
         # Modality mapped to a reader.
         self.modality_to_vis_func = {'image': self.visimage, 'depth': self.visdepth, 'seg': self.visseg}
         self.seg_colors = np.loadtxt(_CURRENT_PATH + '/seg_rgbs.txt', delimiter=',',dtype=np.uint8)
+        # for visualization purpose
+        self.data_to_colors = np.zeros((256, 3), dtype=np.uint8)
+        for color in self.seg_colors:
+            self.data_to_colors[color[2], :] = color
 
-    def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modality = [], camera_name = []):
+    def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modality = [], camera_name = [], show_seg_palette = False):
         """
         Visualizes a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.
 
@@ -53,6 +57,16 @@ def visualize(self, env, difficulty = ['easy'], trajectory_id = ['P000'], modali
             sample_images.append( vis_img )
             sample_image_names.append(vis_img_name)
 
+        # Visualize the semantic segmentation palette
+        if show_seg_palette:
+            label_file = os.path.join(self.tartanair_data_root, env, 'seg_label_map.json')
+            if os.path.isfile(label_file):
+                vispalette = self.vis_seg_palette(label_file)
+                cv2.imshow("Semantic Segmentation Palette", vispalette)
+            else:
+                print_error("Missing seg_label_map.json file {}".format(label_file))
+
+
         #############################
         # Visualize the images.
         #############################
@@ -130,7 +144,27 @@ def visdepth(self, depth):
 
     def visseg(self, seg):
         segvis = np.zeros(seg.shape+(3,), dtype=np.uint8)
-        segvis = self.seg_colors[ seg, : ]
+        segvis = self.data_to_colors[ seg, : ]
         segvis = segvis.reshape( seg.shape+(3,) )
 
         return segvis
+
+    def vis_seg_palette(self, labelfile):
+
+        with open(labelfile,'r') as f:
+            seglabels = json.load(f)
+        seglabels = seglabels["name_map"] # {name: ind}
+        segvalues = [(seglabels[lab], lab) for lab in seglabels] # list of (ind, name) pairs
+
+        num_classes = len(segvalues)
+        img_height = 20
+        img_width = 150
+        palette_img = np.zeros((num_classes * img_height, img_width, 3), dtype=np.uint8)
+
+        for i, (idx, label) in enumerate(segvalues):
+            color = self.data_to_colors[idx]
+            palette_img[i * img_height : (i + 1) * img_height, :] = color[::-1] # Convert RGB to BGR for OpenCV
+
+            cv2.putText(palette_img, label, (10, i * img_height + img_height // 2), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
+
+        return palette_img
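-- 
Usage sketch (illustrative only, not applied by the patch): the snippet below exercises the new show_seg_palette flag, following the usual ta.init(...) / ta.visualize(...) pattern from the examples. The data root './tartanair_data' and the environment name 'ArchVizTinyHouseDay' are assumed placeholders; substitute a locally downloaded environment that includes the 'seg' modality.

    import tartanair as ta

    # Point the toolkit at a local TartanAir-V2 dataset root (placeholder path).
    ta.init('./tartanair_data')

    # Visualize a locally stored trajectory. With show_seg_palette = True an extra
    # OpenCV window lists each segmentation class name next to its palette color.
    ta.visualize('ArchVizTinyHouseDay',          # placeholder environment name
                 difficulty = 'easy',
                 trajectory_id = 'P000',
                 modality = ['image', 'seg'],
                 camera_name = ['lcam_front'],
                 show_seg_palette = True)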