Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for returning file list and multi-thread download #34

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Local
local/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
30 changes: 30 additions & 0 deletions examples/multi_thread_download_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# General imports.
import sys

# Local imports.
sys.path.append('..')
import tartanair as ta

# Point the toolkit at the local folder that holds (or will hold) the dataset.
data_root = '/my/path/to/root/folder/for/tartanair-v2'
ta.init(data_root)

# --- Example 1: download a hand-picked subset -------------------------------
# Environments to fetch.
selected_envs = ["Prison", "Ruins", "UrbanConstruction"]

# The six left-camera views to fetch for every selected environment.
selected_cameras = [
    'lcam_front',
    'lcam_right',
    'lcam_back',
    'lcam_left',
    'lcam_top',
    'lcam_bottom',
]

ta.download_multi_thread(
    env = selected_envs,
    difficulty = ['easy', 'hard'],
    modality = ['image', 'depth'],
    camera_name = selected_cameras,
    unzip = True,
    num_workers = 8,
)

# --- Example 2: download the entire dataset ---------------------------------
# get_all_data() returns a dict with every available value for env, difficulty,
# modality and camera_name, ready to be unpacked as keyword arguments.
everything = ta.get_all_data()
ta.download_multi_thread(
    **everything,
    unzip = True,
    num_workers = 8,
)
94 changes: 69 additions & 25 deletions tartanair/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Local imports.
from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn
from os.path import isdir, isfile, join
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class AirLabDownloader(object):
def __init__(self, bucket_name = 'tartanair2') -> None:
Expand All @@ -23,38 +25,38 @@ def __init__(self, bucket_name = 'tartanair2') -> None:
access_key = "4e54CkGDFg2RmPjaQYmW"
secret_key = "mKdGwketlYUcXQwcPxuzinSxJazoyMpAip47zYdl"

self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=False)
self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True)
self.bucket_name = bucket_name

def download(self, filelist, destination_path):
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
def download(self, filelist, targetfilelist):
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files

print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.client.fput_object(self.bucket_name, target_file_name, source_file_name)
self.client.fget_object(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)

return True, target_filelist
return True, success_source_files, success_target_files

class CloudFlareDownloader(object):
def __init__(self, bucket_name = "tartanair-v2") -> None:
import boto3
access_key = "be0116e42ced3fd52c32398b5003ecda"
secret_key = "103fab752dab348fa665dc744be9b8fb6f9cf04f82f9409d79c54a88661a0d40"
access_key = "f1ae9efebbc6a9a7cebbd949ba3a12de"
secret_key = "0a21fe771089d82e048ed0a1dd6067cb29a5666bf4fe95f7be9ba6f72482ec8b"
endpoint_url = "https://0a585e9484af268a716f8e6d3be53bbc.r2.cloudflarestorage.com"

self.bucket_name = bucket_name
self.s3 = boto3.client('s3', aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
endpoint_url=endpoint_url)

def download(self, filelist, destination_path):
def download(self, filelist, targetfilelist):
"""
Downloads a file from Cloudflare R2 storage using S3 API.

Expand All @@ -67,26 +69,29 @@ def download(self, filelist, destination_path):
- str: A message indicating success or failure.
"""

from botocore.exceptions import NoCredentialsError
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
from botocore.exceptions import NoCredentialsError, ClientError
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files
try:
print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.s3.download_file(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
except FileNotFoundError:
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)
except ClientError:
print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.")
return False, None
return False, success_source_files, success_target_files
except NoCredentialsError:
print_error("Error: Credentials not available.")
return False, None
return True, target_filelist
return False, success_source_files, success_target_files
except Exception:
print_error("Error: Failed for some reason.")
return False, success_source_files, success_target_files
return True, success_source_files, success_target_files

def get_all_s3_objects(self):
continuation_token = None
Expand Down Expand Up @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist):
os.system(cmd)
print_highlight("Unzipping Completed! ")

def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, **kwargs):
def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs):
"""
Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.

Expand Down Expand Up @@ -223,11 +228,50 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c
if not self.doublecheck_filelist(zipfilelist):
return False

suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
# generate the target file list:
targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
all_success_filelist = []

suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)

# Retry the files that failed until they all succeed or max_failure_trial retries have been used.
trail_count = 0
while not suc:
zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files]
if len(zipfilelist) == 0:
print_warn("No failed files are found! ")
break

targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)
trail_count += 1
if trail_count >= max_failure_trial:
break

if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
else:
print_warn("Download with failure! The following files are not downloaded ..")
for ff in zipfilelist:
print_warn(ff)

if unzip:
self.unzip_files(targetfilelist)
self.unzip_files(all_success_filelist)

return True

def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs):
    """
    Download data concurrently by issuing one `self.download` call per (environment, difficulty) pair on a thread pool.

    :param env: The environment(s) to download. A single name or a list of names.
    :type env: str or list
    :param difficulty: The difficulty level(s) to download. A single name or a list of names.
    :type difficulty: str or list
    :param modality: The modalities to download; forwarded unchanged to every `download` call.
    :type modality: str or list
    :param camera_name: The camera names to download; forwarded unchanged to every `download` call.
    :type camera_name: str or list
    :param config: Optional. Path to a download configuration file. When given, a single `download` call is made with it instead of splitting the work per environment and difficulty.
    :type config: str
    :param unzip: Forwarded to `download`; whether the downloaded archives should be unzipped.
    :type unzip: bool
    :param max_failure_trial: Forwarded to `download`; the maximum number of retry rounds for failed files.
    :type max_failure_trial: int
    :param num_workers: The maximum number of worker threads in the pool.
    :type num_workers: int
    :param kwargs: Accepted for signature compatibility with `download`; not forwarded.
    :return: None. Any exception raised inside a worker is re-raised here.
    """
    if config is not None:
        # The public wrapper documents that a config file overrides the explicit
        # selection, so env/difficulty are typically empty here and fanning out over
        # them would submit no work at all (or repeat the same config-driven download
        # once per pair). Delegate to one plain download call instead.
        self.download(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name,
                      config = config, unzip = unzip, max_failure_trial = max_failure_trial)
        return

    # Accept a single name as well as a list: iterating a bare string would
    # otherwise split it into individual characters.
    env_list = [env] if isinstance(env, str) else list(env)
    difficulty_list = [difficulty] if isinstance(difficulty, str) else list(difficulty)
    jobs = [(ee, dd) for ee in env_list for dd in difficulty_list]

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        for index, (ee, dd) in enumerate(jobs):
            if index > 0:
                # Wait a few seconds between submissions to avoid overloading the data server.
                # No wait is needed before the first job or after the last one.
                time.sleep(2)
            futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name,
                                           config = config, unzip = unzip, max_failure_trial = max_failure_trial))

        # Wait for all futures to complete.
        for future in as_completed(futures):
            future.result()  # This will re-raise any exception raised during the future's execution.
35 changes: 32 additions & 3 deletions tartanair/tartanair.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,14 @@ def init(tartanair_root):
is_init = True

return True


def get_all_data():
    """
    Return every selectable option known to the downloader, keyed by the matching download argument name.

    The returned dictionary can be unpacked directly into `download` or `download_multi_thread` to request the complete dataset.

    :return: A dictionary with the keys `env`, `difficulty`, `modality` and `camera_name`, holding the downloader's `env_names`, `difficulty_names`, `modality_names` and `camera_names` respectively.
    :rtype: dict
    """
    global downloader
    # Same guard as the other module-level entry points: the downloader is only usable after init().
    check_init()
    return {"env": downloader.env_names,
            "difficulty": downloader.difficulty_names,
            "modality": downloader.modality_names,
            "camera_name": downloader.camera_names,
            }

def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False):
"""
Expand All @@ -109,6 +116,29 @@ def download(env = [], difficulty = [], modality = [], camera_name = [], config
check_init()
downloader.download(env, difficulty, modality, camera_name, config, unzip)

def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8, max_failure_trial = 3):
    """
    Download data from the TartanAir dataset using multiple threads. The work is split into one download job per (environment, difficulty) pair, and the data is stored in the `tartanair_root` directory.

    :param env: The environment to download. Can be a list of environments.
    :type env: str or list
    :param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard.
    :type difficulty: str or list
    :param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all.
    :type modality: str or list
    :param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`.
        Modalities IMU and LIDAR do not need camera names specified.
    :type camera_name: str or list
    :param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored.
    :type config: str
    :param unzip: Whether to unzip the downloaded files. Default is False.
    :type unzip: bool
    :param num_workers: The maximum number of download threads. Default is 8.
    :type num_workers: int
    :param max_failure_trial: The maximum number of retry rounds for files that failed to download. Default is 3.
    :type max_failure_trial: int
    """

    global downloader
    check_init()

    # The downloader fans out by iterating over env and difficulty, so wrap a single
    # name in a list; a bare string would otherwise be iterated character by character.
    if isinstance(env, str):
        env = [env]
    if isinstance(difficulty, str):
        difficulty = [difficulty]

    downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, max_failure_trial = max_failure_trial, num_workers = num_workers)

def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
"""
Synthesizes raw data into new camera-models. A few camera models are provided, although you can also provide your own camera models. The currently available camera models are:
Expand Down Expand Up @@ -399,7 +429,6 @@ def evaluate_traj(est_traj,

:return: A dictionary containing the evaluation metrics, which include ATE, RPE, the ground truth trajectory, and the estimated trajectory after alignment and scaling if those were requested
:rtype: dict

"""
global evaluator
check_init()
Expand All @@ -424,4 +453,4 @@ def evaluate_traj(est_traj,
# """
# global random_accessor
# check_init()
# return random_accessor
# return random_accessor
Loading