From d1df02b1653449c0c5ccbf061048ef54bba05c36 Mon Sep 17 00:00:00 2001 From: Nik-V9 Date: Thu, 22 Aug 2024 18:25:52 +0000 Subject: [PATCH 1/3] Add support for returning filelist and multi-thread download --- .gitignore | 3 + examples/multi_thread_download_example.py | 120 ++++++++++++++++++++++ tartanair/downloader.py | 37 ++++--- tartanair/tartanair.py | 7 +- 4 files changed, 148 insertions(+), 19 deletions(-) create mode 100644 examples/multi_thread_download_example.py diff --git a/.gitignore b/.gitignore index ee74ee0..f8a9584 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Local +local/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/examples/multi_thread_download_example.py b/examples/multi_thread_download_example.py new file mode 100644 index 0000000..47d782f --- /dev/null +++ b/examples/multi_thread_download_example.py @@ -0,0 +1,120 @@ +import os +import time +import logging +import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed + +import tartanair as ta + + +def download_dataset(env, modality, cam_name): + try: + # Attempt to download the dataset + success, filelist = ta.download(env=env, + difficulty=['easy', 'hard'], + modality=modality, + camera_name=cam_name, + unzip=False) + except Exception as e: + logging.error(f"Failed to download {env} {modality} {cam_name}: {e}") + + +def download_all_in_parallel(trajectories, modalities, num_workers): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for env in trajectories: + for modality in modalities: + if modality in ['imu', 'lidar', 'flow']: + cam_names = ["lcam_front"] + else: + cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front", + "lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom", + "rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"] + for cam_name in cam_names: + futures.append(executor.submit(download_dataset, env, modality, cam_name)) + # Wait for a few seconds to avoid overloading the data server + time.sleep(10) + + # Wait for all futures to complete + for future in as_completed(futures): + future.result() # This will re-raise any exceptions caught during the futures' execution + + +def retry_failed_downloads(error_log_path, num_workers): + # Read list of environments, modalities and camera names from the error log + trajectories = [] + modalities = [] + cam_names = [] + with open(error_log_path, 'r') as f: + for line in f: + env, modality, cam_name = line.split(" ")[4:7] + cam_name = cam_name.replace(":", "") + trajectories.append(env) + modalities.append(modality) + cam_names.append(cam_name) + # Download data in parallel + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for data_idx in range(len(trajectories)): + env = trajectories[data_idx] + modality = modalities[data_idx] + cam_name = cam_names[data_idx] + futures.append(executor.submit(download_dataset, env, modality, cam_name)) + # Wait for a few seconds to avoid overloading the data server + time.sleep(10) + + # Wait for all futures to complete + for future in as_completed(futures): + future.result() # This will re-raise any exceptions caught during the futures' execution + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Download TartanAir datasets.") + parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.") + parser.add_argument("--retry_failed", action='store_true', help="Retry failed 
downloads.") + parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.") + parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.") + parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + error_log_path = args.error_log_path if args.error_log_path else '.' + error_log_file = os.path.join(error_log_path, args.error_log_name) + + # Create the log directory if it doesn't exist + if not os.path.exists(error_log_path): + os.makedirs(error_log_path) + + # Setup logging + logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s') + + # Initialize TartanAir Module. + tartanair_data_root = args.data_root + ta.init(tartanair_data_root) + + # Define Trajectories and Modalities to be downloaded + trajectories = [ + "AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool", + "AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic", + "ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding", + "CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown", + "Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy", + "ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital", + "House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity", + "ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor", + "Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia", + "OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant", + "RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring", + "SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage", + "ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction", + "VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown" + ] + modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg'] + + download_all_in_parallel(trajectories, modalities, args.num_workers) + + if args.retry_failed: + retry_failed_downloads(error_log_file, args.num_workers) diff --git a/tartanair/downloader.py b/tartanair/downloader.py index 8c29afb..5c12ecc 100644 --- a/tartanair/downloader.py +++ b/tartanair/downloader.py @@ -245,7 +245,7 @@ def unzip_files(self, zipfilelist): os.system(cmd) print_highlight("Unzipping Completed! ") - def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, **kwargs): + def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs): """ Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion. @@ -283,27 +283,32 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c # Check that the environments are valid. 
if not self.check_env_valid(env): - return False + return False, None # Check that the modalities are valid if not self.check_modality_valid(modality): - return False + return False, None # Check that the difficulty are valid if not self.check_difficulty_valid(difficulty): - return False + return False, None # Check that the camera names are valid if not self.check_camera_valid(camera_name): - return False + return False, None zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name) - # import ipdb;ipdb.set_trace() if not self.doublecheck_filelist(zipfilelist): - return False - - suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root) - if suc: - print_highlight("Download completed! Enjoy using TartanAir!") - - if unzip: - self.unzip_files(targetfilelist) - - return True + return False, None + + if download: + suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root) + if suc: + print_highlight("Download completed! Enjoy using TartanAir!") + + if unzip: + self.unzip_files(targetfilelist) + else: + targetfilelist = [] + for source_file_name in zipfilelist: + target_file_name = source_file_name.replace('/', '_') + targetfilelist.append(target_file_name) + + return True, targetfilelist diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py index 45aa2eb..dd450d2 100644 --- a/tartanair/tartanair.py +++ b/tartanair/tartanair.py @@ -80,7 +80,7 @@ def init(tartanair_root): return True -def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False): +def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True): """ Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. @@ -101,7 +101,8 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer global downloader check_init() - downloader.download(env, difficulty, modality, camera_name, config, unzip) + success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download) + return success, filelist def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"): """ @@ -418,4 +419,4 @@ def get_random_accessor(): """ global random_accessor check_init() - return random_accessor \ No newline at end of file + return random_accessor From 4d52eac830c30a13672f7bee4fa2973304b6c846 Mon Sep 17 00:00:00 2001 From: amigoshan Date: Thu, 5 Dec 2024 14:45:18 -0500 Subject: [PATCH 2/3] unify the interface, add multi-thread and failure checking to the downloader script --- examples/multi_thread_download_example.py | 132 +++------------------- tartanair/data_cacher | 2 +- tartanair/downloader.py | 117 ++++++++++++------- tartanair/tartanair.py | 28 ++++- 4 files changed, 122 insertions(+), 157 deletions(-) diff --git a/examples/multi_thread_download_example.py b/examples/multi_thread_download_example.py index 47d782f..91b2fd8 100644 --- a/examples/multi_thread_download_example.py +++ b/examples/multi_thread_download_example.py @@ -1,120 +1,24 @@ -import os -import time -import logging -import argparse -from concurrent.futures import ThreadPoolExecutor, as_completed +# General imports. +import sys +# Local imports. +sys.path.append('..') import tartanair as ta +# Create a TartanAir object. 
+tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2' -def download_dataset(env, modality, cam_name): - try: - # Attempt to download the dataset - success, filelist = ta.download(env=env, - difficulty=['easy', 'hard'], - modality=modality, - camera_name=cam_name, - unzip=False) - except Exception as e: - logging.error(f"Failed to download {env} {modality} {cam_name}: {e}") +ta.init(tartanair_data_root) +# Download data from following environments. +env = [ "Prison", + "Ruins", + "UrbanConstruction", +] -def download_all_in_parallel(trajectories, modalities, num_workers): - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for env in trajectories: - for modality in modalities: - if modality in ['imu', 'lidar', 'flow']: - cam_names = ["lcam_front"] - else: - cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front", - "lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom", - "rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"] - for cam_name in cam_names: - futures.append(executor.submit(download_dataset, env, modality, cam_name)) - # Wait for a few seconds to avoid overloading the data server - time.sleep(10) - - # Wait for all futures to complete - for future in as_completed(futures): - future.result() # This will re-raise any exceptions caught during the futures' execution - - -def retry_failed_downloads(error_log_path, num_workers): - # Read list of environments, modalities and camera names from the error log - trajectories = [] - modalities = [] - cam_names = [] - with open(error_log_path, 'r') as f: - for line in f: - env, modality, cam_name = line.split(" ")[4:7] - cam_name = cam_name.replace(":", "") - trajectories.append(env) - modalities.append(modality) - cam_names.append(cam_name) - # Download data in parallel - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for data_idx in range(len(trajectories)): - env = trajectories[data_idx] - modality = modalities[data_idx] - cam_name = cam_names[data_idx] - futures.append(executor.submit(download_dataset, env, modality, cam_name)) - # Wait for a few seconds to avoid overloading the data server - time.sleep(10) - - # Wait for all futures to complete - for future in as_completed(futures): - future.result() # This will re-raise any exceptions caught during the futures' execution - - -def parse_arguments(): - parser = argparse.ArgumentParser(description="Download TartanAir datasets.") - parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.") - parser.add_argument("--retry_failed", action='store_true', help="Retry failed downloads.") - parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.") - parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.") - parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.") - - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_arguments() - error_log_path = args.error_log_path if args.error_log_path else '.' 
- error_log_file = os.path.join(error_log_path, args.error_log_name) - - # Create the log directory if it doesn't exist - if not os.path.exists(error_log_path): - os.makedirs(error_log_path) - - # Setup logging - logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s') - - # Initialize TartanAir Module. - tartanair_data_root = args.data_root - ta.init(tartanair_data_root) - - # Define Trajectories and Modalities to be downloaded - trajectories = [ - "AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool", - "AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic", - "ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding", - "CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown", - "Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy", - "ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital", - "House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity", - "ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor", - "Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia", - "OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant", - "RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring", - "SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage", - "ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction", - "VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown" - ] - modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg'] - - download_all_in_parallel(trajectories, modalities, args.num_workers) - - if args.retry_failed: - retry_failed_downloads(error_log_file, args.num_workers) +ta.download_multi_thread(env = env, + difficulty = ['easy', 'hard'], + modality = ['image', 'depth'], + camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'], + unzip = True, + num_workers = 8) diff --git a/tartanair/data_cacher b/tartanair/data_cacher index 69e887d..50d8965 160000 --- a/tartanair/data_cacher +++ b/tartanair/data_cacher @@ -1 +1 @@ -Subproject commit 69e887d87b9f91c30e62a12b6f004bd0419bb1db +Subproject commit 50d89654c5a5bd76d341b6c7e60afe740b435a39 diff --git a/tartanair/downloader.py b/tartanair/downloader.py index fd5d8a3..a7f4676 100644 --- a/tartanair/downloader.py +++ b/tartanair/downloader.py @@ -14,6 +14,8 @@ # Local imports. 
from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn from os.path import isdir, isfile, join +from concurrent.futures import ThreadPoolExecutor, as_completed +import time class AirLabDownloader(object): def __init__(self, bucket_name = 'tartanair2') -> None: @@ -23,30 +25,30 @@ def __init__(self, bucket_name = 'tartanair2') -> None: access_key = "4e54CkGDFg2RmPjaQYmW" secret_key = "mKdGwketlYUcXQwcPxuzinSxJazoyMpAip47zYdl" - self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=False) + self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True) self.bucket_name = bucket_name - def download(self, filelist, destination_path): - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + def download(self, filelist, targetfilelist): + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files print(f" Downloading {source_file_name} from {self.bucket_name}...") - self.client.fput_object(self.bucket_name, target_file_name, source_file_name) + self.client.fget_object(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) - return True, target_filelist + return True, success_source_files, success_target_files class CloudFlareDownloader(object): def __init__(self, bucket_name = "tartanair-v2") -> None: import boto3 - access_key = "be0116e42ced3fd52c32398b5003ecda" - secret_key = "103fab752dab348fa665dc744be9b8fb6f9cf04f82f9409d79c54a88661a0d40" + access_key = "f1ae9efebbc6a9a7cebbd949ba3a12de" + secret_key = "0a21fe771089d82e048ed0a1dd6067cb29a5666bf4fe95f7be9ba6f72482ec8b" endpoint_url = "https://0a585e9484af268a716f8e6d3be53bbc.r2.cloudflarestorage.com" self.bucket_name = bucket_name @@ -54,7 +56,7 @@ def __init__(self, bucket_name = "tartanair-v2") -> None: aws_secret_access_key=secret_key, endpoint_url=endpoint_url) - def download(self, filelist, destination_path): + def download(self, filelist, targetfilelist): """ Downloads a file from Cloudflare R2 storage using S3 API. @@ -67,26 +69,29 @@ def download(self, filelist, destination_path): - str: A message indicating success or failure. 
""" - from botocore.exceptions import NoCredentialsError - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + from botocore.exceptions import NoCredentialsError, ClientError + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files try: print(f" Downloading {source_file_name} from {self.bucket_name}...") self.s3.download_file(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") - except FileNotFoundError: + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) + except ClientError: print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.") - return False, None + return False, success_source_files, success_target_files except NoCredentialsError: print_error("Error: Credentials not available.") - return False, None - return True, target_filelist + return False, success_source_files, success_target_files + except Exception: + print_error("Error: Failed for some reason.") + return False, success_source_files, success_target_files + return True, success_source_files, success_target_files def get_all_s3_objects(self): continuation_token = None @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist): os.system(cmd) print_highlight("Unzipping Completed! ") - def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs): + def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs): """ Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion. @@ -207,32 +212,66 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c # Check that the environments are valid. if not self.check_env_valid(env): - return False, None + return False # Check that the modalities are valid if not self.check_modality_valid(modality): - return False, None + return False # Check that the difficulty are valid if not self.check_difficulty_valid(difficulty): - return False, None + return False # Check that the camera names are valid if not self.check_camera_valid(camera_name): - return False, None + return False zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name) + # import ipdb;ipdb.set_trace() if not self.doublecheck_filelist(zipfilelist): - return False, None + return False - if download: - suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root) - if suc: - print_highlight("Download completed! 
Enjoy using TartanAir!") + # generate the target file list: + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + all_success_filelist = [] - if unzip: - self.unzip_files(targetfilelist) + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + + # download failed files untill success + trail_count = 0 + while not suc: + zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files] + if len(zipfilelist) == 0: + print_warn("No failed files are found! ") + break + + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + trail_count += 1 + if trail_count >= max_failure_trial: + break + + if suc: + print_highlight("Download completed! Enjoy using TartanAir!") else: - targetfilelist = [] - for source_file_name in zipfilelist: - target_file_name = source_file_name.replace('/', '_') - targetfilelist.append(target_file_name) + print_warn("Download with failure! The following files are not downloaded ..") + for ff in zipfilelist: + print_warn(ff) - return True, targetfilelist + if unzip: + self.unzip_files(all_success_filelist) + + return True + + def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for ee in env: + for dd in difficulty: + futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name, + config = config, unzip = unzip, max_failure_trial = max_failure_trial,)) + # Wait for a few seconds to avoid overloading the data server + time.sleep(2) + + # Wait for all futures to complete + for future in as_completed(futures): + future.result() # This will re-raise any exceptions caught during the futures' execution \ No newline at end of file diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py index e547bea..e461a91 100644 --- a/tartanair/tartanair.py +++ b/tartanair/tartanair.py @@ -86,7 +86,7 @@ def init(tartanair_root): return True -def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True): +def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False): """ Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. @@ -107,8 +107,30 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer global downloader check_init() - success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download) - return success, filelist + downloader.download(env, difficulty, modality, camera_name, config, unzip) + +def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8): + """ + Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. + + :param env: The environment to download. Can be a list of environments. 
+    :type env: str or list
+    :param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard.
+    :type difficulty: str or list
+    :param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc.
+    :type trajectory_id: str or list
+    :param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all.
+    :type modality: str or list
+    :param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`.
+        Modalities IMU and LIDAR do not need camera names specified.
+    :type camera_name: str or list
+    :param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored.
+    :type config: str
+    """
+
+    global downloader
+    check_init()
+    downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers)
 
 def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
     """

From 27bd62e95557fa70fa625f6b715e9a3937376ff1 Mon Sep 17 00:00:00 2001
From: amigoshan
Date: Thu, 5 Dec 2024 17:24:51 -0500
Subject: [PATCH 3/3] add example for all data

---
 examples/multi_thread_download_example.py | 6 ++++++
 tartanair/tartanair.py                    | 9 ++++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/multi_thread_download_example.py b/examples/multi_thread_download_example.py
index 91b2fd8..4c35f53 100644
--- a/examples/multi_thread_download_example.py
+++ b/examples/multi_thread_download_example.py
@@ -22,3 +22,9 @@
                          camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'],
                          unzip = True,
                          num_workers = 8)
+
+# To download the entire dataset
+alldata = ta.get_all_data() # this fills in all available data for env, difficulty, modality and camera_name
+ta.download_multi_thread(**alldata,
+                         unzip = True,
+                         num_workers = 8)
diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py
index e461a91..511596a 100644
--- a/tartanair/tartanair.py
+++ b/tartanair/tartanair.py
@@ -84,7 +84,14 @@ def init(tartanair_root):
     is_init = True
 
     return True
-    
+
+def get_all_data():
+    global downloader
+    return {"env": downloader.env_names,
+            "difficulty": downloader.difficulty_names,
+            "modality": downloader.modality_names,
+            "camera_name": downloader.camera_names,
+            }
 
 def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False):
     """