Skip to content

Commit

Permalink
Merge pull request #34 from castacks/multi_thread_download
Browse files Browse the repository at this point in the history
Add support for returning file list and multi-thread download
  • Loading branch information
Amigoshan authored Dec 6, 2024
2 parents 289678d + 27bd62e commit d8938e6
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Local
local/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
30 changes: 30 additions & 0 deletions examples/multi_thread_download_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# General imports.
import sys

# Local imports.
sys.path.append('..')
import tartanair as ta

# Create a TartanAir object.
tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2'

ta.init(tartanair_data_root)

# Download data from following environments.
env = [ "Prison",
"Ruins",
"UrbanConstruction",
]

ta.download_multi_thread(env = env,
difficulty = ['easy', 'hard'],
modality = ['image', 'depth'],
camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'],
unzip = True,
num_workers = 8)

# To download the entire dataset
alldata = ta.get_all_data() # this fill in all available data for env, difficulty, modality and camera_name
ta.download_multi_thread(**alldata,
unzip = True,
num_workers = 8)
86 changes: 65 additions & 21 deletions tartanair/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Local imports.
from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn
from os.path import isdir, isfile, join
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class AirLabDownloader(object):
def __init__(self, bucket_name = 'tartanair2') -> None:
Expand All @@ -26,21 +28,21 @@ def __init__(self, bucket_name = 'tartanair2') -> None:
self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True)
self.bucket_name = bucket_name

def download(self, filelist, destination_path):
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
def download(self, filelist, targetfilelist):
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files

print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.client.fget_object(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)

return True, target_filelist
return True, success_source_files, success_target_files

class CloudFlareDownloader(object):
def __init__(self, bucket_name = "tartanair-v2") -> None:
Expand All @@ -54,7 +56,7 @@ def __init__(self, bucket_name = "tartanair-v2") -> None:
aws_secret_access_key=secret_key,
endpoint_url=endpoint_url)

def download(self, filelist, destination_path):
def download(self, filelist, targetfilelist):
"""
Downloads a file from Cloudflare R2 storage using S3 API.
Expand All @@ -67,26 +69,29 @@ def download(self, filelist, destination_path):
- str: A message indicating success or failure.
"""

from botocore.exceptions import NoCredentialsError
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
from botocore.exceptions import NoCredentialsError, ClientError
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files
try:
print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.s3.download_file(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
except FileNotFoundError:
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)
except ClientError:
print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.")
return False, None
return False, success_source_files, success_target_files
except NoCredentialsError:
print_error("Error: Credentials not available.")
return False, None
return True, target_filelist
return False, success_source_files, success_target_files
except Exception:
print_error("Error: Failed for some reason.")
return False, success_source_files, success_target_files
return True, success_source_files, success_target_files

def get_all_s3_objects(self):
continuation_token = None
Expand Down Expand Up @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist):
os.system(cmd)
print_highlight("Unzipping Completed! ")

def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, **kwargs):
def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs):
"""
Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.
Expand Down Expand Up @@ -223,11 +228,50 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c
if not self.doublecheck_filelist(zipfilelist):
return False

suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
# generate the target file list:
targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
all_success_filelist = []

suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)

# download failed files untill success
trail_count = 0
while not suc:
zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files]
if len(zipfilelist) == 0:
print_warn("No failed files are found! ")
break

targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)
trail_count += 1
if trail_count >= max_failure_trial:
break

if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
else:
print_warn("Download with failure! The following files are not downloaded ..")
for ff in zipfilelist:
print_warn(ff)

if unzip:
self.unzip_files(targetfilelist)
self.unzip_files(all_success_filelist)

return True

def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for ee in env:
for dd in difficulty:
futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name,
config = config, unzip = unzip, max_failure_trial = max_failure_trial,))
# Wait for a few seconds to avoid overloading the data server
time.sleep(2)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution
35 changes: 32 additions & 3 deletions tartanair/tartanair.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,14 @@ def init(tartanair_root):
is_init = True

return True


def get_all_data():
global downloader
return {"env": downloader.env_names,
"difficulty": downloader.difficulty_names,
"modality": downloader.modality_names,
"camera_name": downloader.camera_names,
}

def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False):
"""
Expand All @@ -109,6 +116,29 @@ def download(env = [], difficulty = [], modality = [], camera_name = [], config
check_init()
downloader.download(env, difficulty, modality, camera_name, config, unzip)

def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.
:param env: The environment to download. Can be a list of environments.
:type env: str or list
:param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard.
:type difficulty: str or list
:param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc.
:type trajectory_id: str or list
:param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all.
:type modality: str or list
:param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`.
Modalities IMU and LIDAR do not need camera names specified.
:type camera_name: str or list
:param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored.
:type config: str
"""

global downloader
check_init()
downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers)

def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
"""
Synthesizes raw data into new camera-models. A few camera models are provided, although you can also provide your own camera models. The currently available camera models are:
Expand Down Expand Up @@ -399,7 +429,6 @@ def evaluate_traj(est_traj,
:return: A dictionary containing the evaluation metrics, which include ATE, RPE, the ground truth trajectory, and the estimated trajectory after alignment and scaling if those were requested
:rtype: dict
"""
global evaluator
check_init()
Expand All @@ -424,4 +453,4 @@ def evaluate_traj(est_traj,
# """
# global random_accessor
# check_init()
# return random_accessor
# return random_accessor

0 comments on commit d8938e6

Please sign in to comment.