diff --git a/.gitignore b/.gitignore index 5db2ffc..58aac05 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Local +local/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/examples/multi_thread_download_example.py b/examples/multi_thread_download_example.py new file mode 100644 index 0000000..4c35f53 --- /dev/null +++ b/examples/multi_thread_download_example.py @@ -0,0 +1,30 @@ +# General imports. +import sys + +# Local imports. +sys.path.append('..') +import tartanair as ta + +# Create a TartanAir object. +tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2' + +ta.init(tartanair_data_root) + +# Download data from following environments. +env = [ "Prison", + "Ruins", + "UrbanConstruction", +] + +ta.download_multi_thread(env = env, + difficulty = ['easy', 'hard'], + modality = ['image', 'depth'], + camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'], + unzip = True, + num_workers = 8) + +# To download the entire dataset +alldata = ta.get_all_data() # this fill in all available data for env, difficulty, modality and camera_name +ta.download_multi_thread(**alldata, + unzip = True, + num_workers = 8) diff --git a/tartanair/downloader.py b/tartanair/downloader.py index c4df357..a7f4676 100644 --- a/tartanair/downloader.py +++ b/tartanair/downloader.py @@ -14,6 +14,8 @@ # Local imports. from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn from os.path import isdir, isfile, join +from concurrent.futures import ThreadPoolExecutor, as_completed +import time class AirLabDownloader(object): def __init__(self, bucket_name = 'tartanair2') -> None: @@ -26,21 +28,21 @@ def __init__(self, bucket_name = 'tartanair2') -> None: self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True) self.bucket_name = bucket_name - def download(self, filelist, destination_path): - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + def download(self, filelist, targetfilelist): + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files print(f" Downloading {source_file_name} from {self.bucket_name}...") self.client.fget_object(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) - return True, target_filelist + return True, success_source_files, success_target_files class CloudFlareDownloader(object): def __init__(self, bucket_name = "tartanair-v2") -> None: @@ -54,7 +56,7 @@ def __init__(self, bucket_name = "tartanair-v2") -> None: aws_secret_access_key=secret_key, endpoint_url=endpoint_url) - def download(self, filelist, destination_path): + def download(self, filelist, targetfilelist): """ Downloads a file from Cloudflare R2 storage using S3 API. @@ -67,26 +69,29 @@ def download(self, filelist, destination_path): - str: A message indicating success or failure. """ - from botocore.exceptions import NoCredentialsError - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + from botocore.exceptions import NoCredentialsError, ClientError + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files try: print(f" Downloading {source_file_name} from {self.bucket_name}...") self.s3.download_file(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") - except FileNotFoundError: + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) + except ClientError: print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.") - return False, None + return False, success_source_files, success_target_files except NoCredentialsError: print_error("Error: Credentials not available.") - return False, None - return True, target_filelist + return False, success_source_files, success_target_files + except Exception: + print_error("Error: Failed for some reason.") + return False, success_source_files, success_target_files + return True, success_source_files, success_target_files def get_all_s3_objects(self): continuation_token = None @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist): os.system(cmd) print_highlight("Unzipping Completed! ") - def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, **kwargs): + def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs): """ Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion. @@ -223,11 +228,50 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c if not self.doublecheck_filelist(zipfilelist): return False - suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root) + # generate the target file list: + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + all_success_filelist = [] + + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + + # download failed files untill success + trail_count = 0 + while not suc: + zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files] + if len(zipfilelist) == 0: + print_warn("No failed files are found! ") + break + + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + trail_count += 1 + if trail_count >= max_failure_trial: + break + if suc: print_highlight("Download completed! Enjoy using TartanAir!") + else: + print_warn("Download with failure! The following files are not downloaded ..") + for ff in zipfilelist: + print_warn(ff) if unzip: - self.unzip_files(targetfilelist) + self.unzip_files(all_success_filelist) return True + + def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for ee in env: + for dd in difficulty: + futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name, + config = config, unzip = unzip, max_failure_trial = max_failure_trial,)) + # Wait for a few seconds to avoid overloading the data server + time.sleep(2) + + # Wait for all futures to complete + for future in as_completed(futures): + future.result() # This will re-raise any exceptions caught during the futures' execution \ No newline at end of file diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py index cfed869..511596a 100644 --- a/tartanair/tartanair.py +++ b/tartanair/tartanair.py @@ -84,7 +84,14 @@ def init(tartanair_root): is_init = True return True - + +def get_all_data(): + global downloader + return {"env": downloader.env_names, + "difficulty": downloader.difficulty_names, + "modality": downloader.modality_names, + "camera_name": downloader.camera_names, + } def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False): """ @@ -109,6 +116,29 @@ def download(env = [], difficulty = [], modality = [], camera_name = [], config check_init() downloader.download(env, difficulty, modality, camera_name, config, unzip) +def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8): + """ + Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. + + :param env: The environment to download. Can be a list of environments. + :type env: str or list + :param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard. + :type difficulty: str or list + :param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc. + :type trajectory_id: str or list + :param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all. + :type modality: str or list + :param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`. + Modalities IMU and LIDAR do not need camera names specified. + :type camera_name: str or list + :param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored. + :type config: str + """ + + global downloader + check_init() + downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers) + def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"): """ Synthesizes raw data into new camera-models. A few camera models are provided, although you can also provide your own camera models. The currently available camera models are: @@ -399,7 +429,6 @@ def evaluate_traj(est_traj, :return: A dictionary containing the evaluation metrics, which include ATE, RPE, the ground truth trajectory, and the estimated trajectory after alignment and scaling if those were requested :rtype: dict - """ global evaluator check_init() @@ -424,4 +453,4 @@ def evaluate_traj(est_traj, # """ # global random_accessor # check_init() -# return random_accessor \ No newline at end of file +# return random_accessor