Commit
unify the interface, add multi-thread and failure checking to the downloader script
theairlab committed Dec 5, 2024
1 parent 57615f1 commit 4d52eac
Showing 4 changed files with 122 additions and 157 deletions.
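For orientation before the file diffs: the unified download interface this commit introduces can be exercised roughly as below. This is a sketch based on the signatures in the diffs; the data root is the placeholder path from the example file, and running it would start real downloads.

import tartanair as ta

# Placeholder dataset root taken from the example file.
ta.init('/my/path/to/root/folder/for/tartanair-v2')

# Single-threaded download with the unified argument set.
ta.download(env=['Prison'],
            difficulty=['easy'],
            modality=['image', 'depth'],
            camera_name=['lcam_front'],
            unzip=True)

# Multi-threaded variant: same arguments plus num_workers.
ta.download_multi_thread(env=['Prison', 'Ruins', 'UrbanConstruction'],
                         difficulty=['easy', 'hard'],
                         modality=['image', 'depth'],
                         camera_name=['lcam_front', 'lcam_left'],
                         unzip=True,
                         num_workers=8)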
132 changes: 18 additions & 114 deletions examples/multi_thread_download_example.py
@@ -1,120 +1,24 @@
import os
import time
import logging
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
# General imports.
import sys

# Local imports.
sys.path.append('..')
import tartanair as ta

# Create a TartanAir object.
tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2'

def download_dataset(env, modality, cam_name):
try:
# Attempt to download the dataset
success, filelist = ta.download(env=env,
difficulty=['easy', 'hard'],
modality=modality,
camera_name=cam_name,
unzip=False)
except Exception as e:
logging.error(f"Failed to download {env} {modality} {cam_name}: {e}")
ta.init(tartanair_data_root)

# Download data from following environments.
env = [ "Prison",
"Ruins",
"UrbanConstruction",
]

def download_all_in_parallel(trajectories, modalities, num_workers):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for env in trajectories:
for modality in modalities:
if modality in ['imu', 'lidar', 'flow']:
cam_names = ["lcam_front"]
else:
cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front",
"lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom",
"rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"]
for cam_name in cam_names:
futures.append(executor.submit(download_dataset, env, modality, cam_name))
# Wait for a few seconds to avoid overloading the data server
time.sleep(10)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution


def retry_failed_downloads(error_log_path, num_workers):
# Read list of environments, modalities and camera names from the error log
trajectories = []
modalities = []
cam_names = []
with open(error_log_path, 'r') as f:
for line in f:
env, modality, cam_name = line.split(" ")[4:7]
cam_name = cam_name.replace(":", "")
trajectories.append(env)
modalities.append(modality)
cam_names.append(cam_name)
# Download data in parallel
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for data_idx in range(len(trajectories)):
env = trajectories[data_idx]
modality = modalities[data_idx]
cam_name = cam_names[data_idx]
futures.append(executor.submit(download_dataset, env, modality, cam_name))
# Wait for a few seconds to avoid overloading the data server
time.sleep(10)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution


def parse_arguments():
parser = argparse.ArgumentParser(description="Download TartanAir datasets.")
parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.")
parser.add_argument("--retry_failed", action='store_true', help="Retry failed downloads.")
parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.")
parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.")
parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.")

return parser.parse_args()


if __name__ == "__main__":
args = parse_arguments()
error_log_path = args.error_log_path if args.error_log_path else '.'
error_log_file = os.path.join(error_log_path, args.error_log_name)

# Create the log directory if it doesn't exist
if not os.path.exists(error_log_path):
os.makedirs(error_log_path)

# Setup logging
logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s')

# Initialize TartanAir Module.
tartanair_data_root = args.data_root
ta.init(tartanair_data_root)

# Define Trajectories and Modalities to be downloaded
trajectories = [
"AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool",
"AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic",
"ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding",
"CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown",
"Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy",
"ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital",
"House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity",
"ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor",
"Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia",
"OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant",
"RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring",
"SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage",
"ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction",
"VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown"
]
modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg']

download_all_in_parallel(trajectories, modalities, args.num_workers)

if args.retry_failed:
retry_failed_downloads(error_log_file, args.num_workers)
ta.download_multi_thread(env = env,
difficulty = ['easy', 'hard'],
modality = ['image', 'depth'],
camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'],
unzip = True,
num_workers = 8)
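The deleted example also limited imu, lidar and flow to a single camera before submitting jobs. If that behaviour is still wanted with the unified interface, a hypothetical helper like the one below (the names SINGLE_CAMERA_MODALITIES and cameras_for are mine, not the module's) can rebuild the per-modality camera list, reusing the ta and env from the example above.

# Hypothetical helper mirroring the per-modality camera logic of the deleted example.
SINGLE_CAMERA_MODALITIES = {'imu', 'lidar', 'flow'}
ALL_CAMERAS = ['lcam_front', 'lcam_right', 'lcam_back',
               'lcam_left', 'lcam_top', 'lcam_bottom']

def cameras_for(modality):
    # The old script only requested lcam_front for imu, lidar and flow.
    return ['lcam_front'] if modality in SINGLE_CAMERA_MODALITIES else ALL_CAMERAS

for modality in ['imu', 'image', 'depth']:
    ta.download_multi_thread(env=env,
                             difficulty=['easy', 'hard'],
                             modality=[modality],
                             camera_name=cameras_for(modality),
                             unzip=True,
                             num_workers=8)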
117 changes: 78 additions & 39 deletions tartanair/downloader.py
@@ -14,6 +14,8 @@
# Local imports.
from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn
from os.path import isdir, isfile, join
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class AirLabDownloader(object):
def __init__(self, bucket_name = 'tartanair2') -> None:
@@ -23,38 +25,38 @@ def __init__(self, bucket_name = 'tartanair2') -> None:
access_key = "4e54CkGDFg2RmPjaQYmW"
secret_key = "mKdGwketlYUcXQwcPxuzinSxJazoyMpAip47zYdl"

self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=False)
self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True)
self.bucket_name = bucket_name

def download(self, filelist, destination_path):
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
def download(self, filelist, targetfilelist):
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files

print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.client.fput_object(self.bucket_name, target_file_name, source_file_name)
self.client.fget_object(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)

return True, target_filelist
return True, success_source_files, success_target_files

class CloudFlareDownloader(object):
def __init__(self, bucket_name = "tartanair-v2") -> None:
import boto3
access_key = "be0116e42ced3fd52c32398b5003ecda"
secret_key = "103fab752dab348fa665dc744be9b8fb6f9cf04f82f9409d79c54a88661a0d40"
access_key = "f1ae9efebbc6a9a7cebbd949ba3a12de"
secret_key = "0a21fe771089d82e048ed0a1dd6067cb29a5666bf4fe95f7be9ba6f72482ec8b"
endpoint_url = "https://0a585e9484af268a716f8e6d3be53bbc.r2.cloudflarestorage.com"

self.bucket_name = bucket_name
self.s3 = boto3.client('s3', aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
endpoint_url=endpoint_url)

def download(self, filelist, destination_path):
def download(self, filelist, targetfilelist):
"""
Downloads a file from Cloudflare R2 storage using S3 API.
@@ -67,26 +69,29 @@ def download(self, filelist, targetfilelist):
- str: A message indicating success or failure.
"""

from botocore.exceptions import NoCredentialsError
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
from botocore.exceptions import NoCredentialsError, ClientError
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files
try:
print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.s3.download_file(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
except FileNotFoundError:
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)
except ClientError:
print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.")
return False, None
return False, success_source_files, success_target_files
except NoCredentialsError:
print_error("Error: Credentials not available.")
return False, None
return True, target_filelist
return False, success_source_files, success_target_files
except Exception:
print_error("Error: Failed for some reason.")
return False, success_source_files, success_target_files
return True, success_source_files, success_target_files

def get_all_s3_objects(self):
continuation_token = None
@@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist):
os.system(cmd)
print_highlight("Unzipping Completed! ")

def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs):
def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs):
"""
Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.
@@ -207,32 +212,66 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c

# Check that the environments are valid.
if not self.check_env_valid(env):
return False, None
return False
# Check that the modalities are valid
if not self.check_modality_valid(modality):
return False, None
return False
# Check that the difficulty are valid
if not self.check_difficulty_valid(difficulty):
return False, None
return False
# Check that the camera names are valid
if not self.check_camera_valid(camera_name):
return False, None
return False

zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name)
# import ipdb;ipdb.set_trace()
if not self.doublecheck_filelist(zipfilelist):
return False, None
return False

if download:
suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
# generate the target file list:
targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
all_success_filelist = []

if unzip:
self.unzip_files(targetfilelist)
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)

# download failed files until success
trail_count = 0
while not suc:
zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files]
if len(zipfilelist) == 0:
print_warn("No failed files are found! ")
break

targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)
trail_count += 1
if trail_count >= max_failure_trial:
break

if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
else:
targetfilelist = []
for source_file_name in zipfilelist:
target_file_name = source_file_name.replace('/', '_')
targetfilelist.append(target_file_name)
print_warn("Download with failure! The following files are not downloaded ..")
for ff in zipfilelist:
print_warn(ff)

return True, targetfilelist
if unzip:
self.unzip_files(all_success_filelist)

return True

def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for ee in env:
for dd in difficulty:
futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name,
config = config, unzip = unzip, max_failure_trial = max_failure_trial,))
# Wait for a few seconds to avoid overloading the data server
time.sleep(2)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution
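As a reading aid, the failure handling added to download above reduces to the bounded-retry pattern sketched here; download_with_retries, fetch, and the variable names are stand-ins rather than the module's real API, but the control flow mirrors the diff (and, as in the commit, only the successfully downloaded archives would then be unzipped).

def download_with_retries(fetch, source_files, target_files, max_failure_trial=3):
    # fetch(sources, targets) is assumed to return (all_ok, ok_sources, ok_targets),
    # matching the downloader classes in this commit.
    downloaded = []
    ok, ok_sources, ok_targets = fetch(source_files, target_files)
    downloaded.extend(ok_targets)

    trials = 0
    while not ok and trials < max_failure_trial:
        # Keep only the pairs whose source file has not been fetched yet.
        pending = [(s, t) for s, t in zip(source_files, target_files)
                   if s not in ok_sources]
        if not pending:
            break
        source_files, target_files = [list(x) for x in zip(*pending)]
        ok, ok_sources, ok_targets = fetch(source_files, target_files)
        downloaded.extend(ok_targets)
        trials += 1

    return ok, downloaded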
28 changes: 25 additions & 3 deletions tartanair/tartanair.py
@@ -86,7 +86,7 @@ def init(tartanair_root):
return True


def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True):
def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.
@@ -107,8 +107,30 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer

global downloader
check_init()
success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download)
return success, filelist
downloader.download(env, difficulty, modality, camera_name, config, unzip)

def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.
:param env: The environment to download. Can be a list of environments.
:type env: str or list
:param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard.
:type difficulty: str or list
:param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc.
:type trajectory_id: str or list
:param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all.
:type modality: str or list
:param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`.
Modalities IMU and LIDAR do not need camera names specified.
:type camera_name: str or list
:param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored.
:type config: str
"""

global downloader
check_init()
downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers)

def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
"""

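Because download_multi_thread calls future.result() on every worker, an exception inside any download job propagates to the caller. A defensive call site might therefore look like this sketch (the root path and argument values are placeholders):

import tartanair as ta

ta.init('/my/path/to/root/folder/for/tartanair-v2')  # placeholder root

try:
    ta.download_multi_thread(env=['Prison'],
                             difficulty=['easy'],
                             modality=['image'],
                             camera_name=['lcam_front'],
                             unzip=False,
                             num_workers=4)
except Exception as exc:
    # A failed worker re-raises here; log it and decide whether to rerun the batch.
    print(f'Multi-thread download aborted: {exc}')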