diff --git a/examples/multi_thread_download_example.py b/examples/multi_thread_download_example.py index 47d782f..91b2fd8 100644 --- a/examples/multi_thread_download_example.py +++ b/examples/multi_thread_download_example.py @@ -1,120 +1,24 @@ -import os -import time -import logging -import argparse -from concurrent.futures import ThreadPoolExecutor, as_completed +# General imports. +import sys +# Local imports. +sys.path.append('..') import tartanair as ta +# Create a TartanAir object. +tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2' -def download_dataset(env, modality, cam_name): - try: - # Attempt to download the dataset - success, filelist = ta.download(env=env, - difficulty=['easy', 'hard'], - modality=modality, - camera_name=cam_name, - unzip=False) - except Exception as e: - logging.error(f"Failed to download {env} {modality} {cam_name}: {e}") +ta.init(tartanair_data_root) +# Download data from following environments. +env = [ "Prison", + "Ruins", + "UrbanConstruction", +] -def download_all_in_parallel(trajectories, modalities, num_workers): - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for env in trajectories: - for modality in modalities: - if modality in ['imu', 'lidar', 'flow']: - cam_names = ["lcam_front"] - else: - cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front", - "lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom", - "rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"] - for cam_name in cam_names: - futures.append(executor.submit(download_dataset, env, modality, cam_name)) - # Wait for a few seconds to avoid overloading the data server - time.sleep(10) - - # Wait for all futures to complete - for future in as_completed(futures): - future.result() # This will re-raise any exceptions caught during the futures' execution - - -def retry_failed_downloads(error_log_path, num_workers): - # Read list of environments, modalities and camera names from the error log - trajectories = [] - modalities = [] - cam_names = [] - with open(error_log_path, 'r') as f: - for line in f: - env, modality, cam_name = line.split(" ")[4:7] - cam_name = cam_name.replace(":", "") - trajectories.append(env) - modalities.append(modality) - cam_names.append(cam_name) - # Download data in parallel - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [] - for data_idx in range(len(trajectories)): - env = trajectories[data_idx] - modality = modalities[data_idx] - cam_name = cam_names[data_idx] - futures.append(executor.submit(download_dataset, env, modality, cam_name)) - # Wait for a few seconds to avoid overloading the data server - time.sleep(10) - - # Wait for all futures to complete - for future in as_completed(futures): - future.result() # This will re-raise any exceptions caught during the futures' execution - - -def parse_arguments(): - parser = argparse.ArgumentParser(description="Download TartanAir datasets.") - parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.") - parser.add_argument("--retry_failed", action='store_true', help="Retry failed downloads.") - parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.") - parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.") - parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.") - - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_arguments() - error_log_path = args.error_log_path if args.error_log_path else '.' - error_log_file = os.path.join(error_log_path, args.error_log_name) - - # Create the log directory if it doesn't exist - if not os.path.exists(error_log_path): - os.makedirs(error_log_path) - - # Setup logging - logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s') - - # Initialize TartanAir Module. - tartanair_data_root = args.data_root - ta.init(tartanair_data_root) - - # Define Trajectories and Modalities to be downloaded - trajectories = [ - "AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool", - "AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic", - "ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding", - "CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown", - "Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy", - "ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital", - "House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity", - "ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor", - "Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia", - "OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant", - "RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring", - "SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage", - "ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction", - "VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown" - ] - modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg'] - - download_all_in_parallel(trajectories, modalities, args.num_workers) - - if args.retry_failed: - retry_failed_downloads(error_log_file, args.num_workers) +ta.download_multi_thread(env = env, + difficulty = ['easy', 'hard'], + modality = ['image', 'depth'], + camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'], + unzip = True, + num_workers = 8) diff --git a/tartanair/data_cacher b/tartanair/data_cacher index 69e887d..50d8965 160000 --- a/tartanair/data_cacher +++ b/tartanair/data_cacher @@ -1 +1 @@ -Subproject commit 69e887d87b9f91c30e62a12b6f004bd0419bb1db +Subproject commit 50d89654c5a5bd76d341b6c7e60afe740b435a39 diff --git a/tartanair/downloader.py b/tartanair/downloader.py index fd5d8a3..a7f4676 100644 --- a/tartanair/downloader.py +++ b/tartanair/downloader.py @@ -14,6 +14,8 @@ # Local imports. from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn from os.path import isdir, isfile, join +from concurrent.futures import ThreadPoolExecutor, as_completed +import time class AirLabDownloader(object): def __init__(self, bucket_name = 'tartanair2') -> None: @@ -23,30 +25,30 @@ def __init__(self, bucket_name = 'tartanair2') -> None: access_key = "4e54CkGDFg2RmPjaQYmW" secret_key = "mKdGwketlYUcXQwcPxuzinSxJazoyMpAip47zYdl" - self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=False) + self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True) self.bucket_name = bucket_name - def download(self, filelist, destination_path): - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + def download(self, filelist, targetfilelist): + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files print(f" Downloading {source_file_name} from {self.bucket_name}...") - self.client.fput_object(self.bucket_name, target_file_name, source_file_name) + self.client.fget_object(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) - return True, target_filelist + return True, success_source_files, success_target_files class CloudFlareDownloader(object): def __init__(self, bucket_name = "tartanair-v2") -> None: import boto3 - access_key = "be0116e42ced3fd52c32398b5003ecda" - secret_key = "103fab752dab348fa665dc744be9b8fb6f9cf04f82f9409d79c54a88661a0d40" + access_key = "f1ae9efebbc6a9a7cebbd949ba3a12de" + secret_key = "0a21fe771089d82e048ed0a1dd6067cb29a5666bf4fe95f7be9ba6f72482ec8b" endpoint_url = "https://0a585e9484af268a716f8e6d3be53bbc.r2.cloudflarestorage.com" self.bucket_name = bucket_name @@ -54,7 +56,7 @@ def __init__(self, bucket_name = "tartanair-v2") -> None: aws_secret_access_key=secret_key, endpoint_url=endpoint_url) - def download(self, filelist, destination_path): + def download(self, filelist, targetfilelist): """ Downloads a file from Cloudflare R2 storage using S3 API. @@ -67,26 +69,29 @@ def download(self, filelist, destination_path): - str: A message indicating success or failure. """ - from botocore.exceptions import NoCredentialsError - target_filelist = [] - for source_file_name in filelist: - target_file_name = join(destination_path, source_file_name.replace('/', '_')) - target_filelist.append(target_file_name) + from botocore.exceptions import NoCredentialsError, ClientError + success_source_files, success_target_files = [], [] + for source_file_name, target_file_name in zip(filelist, targetfilelist): print('--') if isfile(target_file_name): print_error('Error: Target file {} already exists..'.format(target_file_name)) - return False, None + return False, success_source_files, success_target_files try: print(f" Downloading {source_file_name} from {self.bucket_name}...") self.s3.download_file(self.bucket_name, source_file_name, target_file_name) print(f" Successfully downloaded {source_file_name} to {target_file_name}!") - except FileNotFoundError: + success_source_files.append(source_file_name) + success_target_files.append(target_file_name) + except ClientError: print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.") - return False, None + return False, success_source_files, success_target_files except NoCredentialsError: print_error("Error: Credentials not available.") - return False, None - return True, target_filelist + return False, success_source_files, success_target_files + except Exception: + print_error("Error: Failed for some reason.") + return False, success_source_files, success_target_files + return True, success_source_files, success_target_files def get_all_s3_objects(self): continuation_token = None @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist): os.system(cmd) print_highlight("Unzipping Completed! ") - def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs): + def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs): """ Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion. @@ -207,32 +212,66 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c # Check that the environments are valid. if not self.check_env_valid(env): - return False, None + return False # Check that the modalities are valid if not self.check_modality_valid(modality): - return False, None + return False # Check that the difficulty are valid if not self.check_difficulty_valid(difficulty): - return False, None + return False # Check that the camera names are valid if not self.check_camera_valid(camera_name): - return False, None + return False zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name) + # import ipdb;ipdb.set_trace() if not self.doublecheck_filelist(zipfilelist): - return False, None + return False - if download: - suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root) - if suc: - print_highlight("Download completed! Enjoy using TartanAir!") + # generate the target file list: + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + all_success_filelist = [] - if unzip: - self.unzip_files(targetfilelist) + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + + # download failed files untill success + trail_count = 0 + while not suc: + zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files] + if len(zipfilelist) == 0: + print_warn("No failed files are found! ") + break + + targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist] + suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist) + all_success_filelist.extend(success_target_files) + trail_count += 1 + if trail_count >= max_failure_trial: + break + + if suc: + print_highlight("Download completed! Enjoy using TartanAir!") else: - targetfilelist = [] - for source_file_name in zipfilelist: - target_file_name = source_file_name.replace('/', '_') - targetfilelist.append(target_file_name) + print_warn("Download with failure! The following files are not downloaded ..") + for ff in zipfilelist: + print_warn(ff) - return True, targetfilelist + if unzip: + self.unzip_files(all_success_filelist) + + return True + + def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for ee in env: + for dd in difficulty: + futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name, + config = config, unzip = unzip, max_failure_trial = max_failure_trial,)) + # Wait for a few seconds to avoid overloading the data server + time.sleep(2) + + # Wait for all futures to complete + for future in as_completed(futures): + future.result() # This will re-raise any exceptions caught during the futures' execution \ No newline at end of file diff --git a/tartanair/tartanair.py b/tartanair/tartanair.py index e547bea..e461a91 100644 --- a/tartanair/tartanair.py +++ b/tartanair/tartanair.py @@ -86,7 +86,7 @@ def init(tartanair_root): return True -def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True): +def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False): """ Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. @@ -107,8 +107,30 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer global downloader check_init() - success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download) - return success, filelist + downloader.download(env, difficulty, modality, camera_name, config, unzip) + +def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8): + """ + Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory. + + :param env: The environment to download. Can be a list of environments. + :type env: str or list + :param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard. + :type difficulty: str or list + :param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc. + :type trajectory_id: str or list + :param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all. + :type modality: str or list + :param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`. + Modalities IMU and LIDAR do not need camera names specified. + :type camera_name: str or list + :param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored. + :type config: str + """ + + global downloader + check_init() + downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers) def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"): """