Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for returning file list and multi-thread download #34

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 18 additions & 114 deletions examples/multi_thread_download_example.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,24 @@
import os
import time
import logging
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
# General imports.
import sys

# Local imports.
sys.path.append('..')
import tartanair as ta

# Create a TartanAir object.
tartanair_data_root = '/my/path/to/root/folder/for/tartanair-v2'

def download_dataset(env, modality, cam_name):
try:
# Attempt to download the dataset
success, filelist = ta.download(env=env,
difficulty=['easy', 'hard'],
modality=modality,
camera_name=cam_name,
unzip=False)
except Exception as e:
logging.error(f"Failed to download {env} {modality} {cam_name}: {e}")
ta.init(tartanair_data_root)

# Download data from following environments.
env = [ "Prison",
"Ruins",
"UrbanConstruction",
]

def download_all_in_parallel(trajectories, modalities, num_workers):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for env in trajectories:
for modality in modalities:
if modality in ['imu', 'lidar', 'flow']:
cam_names = ["lcam_front"]
else:
cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front",
"lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom",
"rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"]
for cam_name in cam_names:
futures.append(executor.submit(download_dataset, env, modality, cam_name))
# Wait for a few seconds to avoid overloading the data server
time.sleep(10)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution


def retry_failed_downloads(error_log_path, num_workers):
# Read list of environments, modalities and camera names from the error log
trajectories = []
modalities = []
cam_names = []
with open(error_log_path, 'r') as f:
for line in f:
env, modality, cam_name = line.split(" ")[4:7]
cam_name = cam_name.replace(":", "")
trajectories.append(env)
modalities.append(modality)
cam_names.append(cam_name)
# Download data in parallel
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for data_idx in range(len(trajectories)):
env = trajectories[data_idx]
modality = modalities[data_idx]
cam_name = cam_names[data_idx]
futures.append(executor.submit(download_dataset, env, modality, cam_name))
# Wait for a few seconds to avoid overloading the data server
time.sleep(10)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution


def parse_arguments():
parser = argparse.ArgumentParser(description="Download TartanAir datasets.")
parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.")
parser.add_argument("--retry_failed", action='store_true', help="Retry failed downloads.")
parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.")
parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.")
parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.")

return parser.parse_args()


if __name__ == "__main__":
args = parse_arguments()
error_log_path = args.error_log_path if args.error_log_path else '.'
error_log_file = os.path.join(error_log_path, args.error_log_name)

# Create the log directory if it doesn't exist
if not os.path.exists(error_log_path):
os.makedirs(error_log_path)

# Setup logging
logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s')

# Initialize TartanAir Module.
tartanair_data_root = args.data_root
ta.init(tartanair_data_root)

# Define Trajectories and Modalities to be downloaded
trajectories = [
Nik-V9 marked this conversation as resolved.
Show resolved Hide resolved
"AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool",
"AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic",
"ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding",
"CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown",
"Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy",
"ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital",
"House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity",
"ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor",
"Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia",
"OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant",
"RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring",
"SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage",
"ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction",
"VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown"
]
modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg']

download_all_in_parallel(trajectories, modalities, args.num_workers)

if args.retry_failed:
retry_failed_downloads(error_log_file, args.num_workers)
ta.download_multi_thread(env = env,
difficulty = ['easy', 'hard'],
modality = ['image', 'depth'],
camera_name = ['lcam_front', 'lcam_right', 'lcam_back', 'lcam_left', 'lcam_top', 'lcam_bottom'],
unzip = True,
num_workers = 8)
117 changes: 78 additions & 39 deletions tartanair/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Local imports.
from .tartanair_module import TartanAirModule, print_error, print_highlight, print_warn
from os.path import isdir, isfile, join
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

class AirLabDownloader(object):
def __init__(self, bucket_name = 'tartanair2') -> None:
Expand All @@ -23,38 +25,38 @@ def __init__(self, bucket_name = 'tartanair2') -> None:
access_key = "4e54CkGDFg2RmPjaQYmW"
secret_key = "mKdGwketlYUcXQwcPxuzinSxJazoyMpAip47zYdl"

self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=False)
self.client = Minio(endpoint_url, access_key=access_key, secret_key=secret_key, secure=True)
self.bucket_name = bucket_name

def download(self, filelist, destination_path):
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
def download(self, filelist, targetfilelist):
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files

print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.client.fput_object(self.bucket_name, target_file_name, source_file_name)
self.client.fget_object(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)

return True, target_filelist
return True, success_source_files, success_target_files

class CloudFlareDownloader(object):
def __init__(self, bucket_name = "tartanair-v2") -> None:
import boto3
access_key = "be0116e42ced3fd52c32398b5003ecda"
secret_key = "103fab752dab348fa665dc744be9b8fb6f9cf04f82f9409d79c54a88661a0d40"
access_key = "f1ae9efebbc6a9a7cebbd949ba3a12de"
secret_key = "0a21fe771089d82e048ed0a1dd6067cb29a5666bf4fe95f7be9ba6f72482ec8b"
endpoint_url = "https://0a585e9484af268a716f8e6d3be53bbc.r2.cloudflarestorage.com"

self.bucket_name = bucket_name
self.s3 = boto3.client('s3', aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
endpoint_url=endpoint_url)

def download(self, filelist, destination_path):
def download(self, filelist, targetfilelist):
"""
Downloads a file from Cloudflare R2 storage using S3 API.

Expand All @@ -67,26 +69,29 @@ def download(self, filelist, destination_path):
- str: A message indicating success or failure.
"""

from botocore.exceptions import NoCredentialsError
target_filelist = []
for source_file_name in filelist:
target_file_name = join(destination_path, source_file_name.replace('/', '_'))
target_filelist.append(target_file_name)
from botocore.exceptions import NoCredentialsError, ClientError
success_source_files, success_target_files = [], []
for source_file_name, target_file_name in zip(filelist, targetfilelist):
print('--')
if isfile(target_file_name):
print_error('Error: Target file {} already exists..'.format(target_file_name))
return False, None
return False, success_source_files, success_target_files
try:
print(f" Downloading {source_file_name} from {self.bucket_name}...")
self.s3.download_file(self.bucket_name, source_file_name, target_file_name)
print(f" Successfully downloaded {source_file_name} to {target_file_name}!")
except FileNotFoundError:
success_source_files.append(source_file_name)
success_target_files.append(target_file_name)
except ClientError:
print_error(f"Error: The file {source_file_name} was not found in the bucket {self.bucket_name}.")
return False, None
return False, success_source_files, success_target_files
except NoCredentialsError:
print_error("Error: Credentials not available.")
return False, None
return True, target_filelist
return False, success_source_files, success_target_files
except Exception:
print_error("Error: Failed for some reason.")
return False, success_source_files, success_target_files
return True, success_source_files, success_target_files

def get_all_s3_objects(self):
continuation_token = None
Expand Down Expand Up @@ -169,7 +174,7 @@ def unzip_files(self, zipfilelist):
os.system(cmd)
print_highlight("Unzipping Completed! ")

def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs):
def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, **kwargs):
"""
Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.

Expand Down Expand Up @@ -207,32 +212,66 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c

# Check that the environments are valid.
if not self.check_env_valid(env):
return False, None
return False
# Check that the modalities are valid
if not self.check_modality_valid(modality):
return False, None
return False
# Check that the difficulty are valid
if not self.check_difficulty_valid(difficulty):
return False, None
return False
# Check that the camera names are valid
if not self.check_camera_valid(camera_name):
return False, None
return False

zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name)
# import ipdb;ipdb.set_trace()
if not self.doublecheck_filelist(zipfilelist):
return False, None
return False

if download:
suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
# generate the target file list:
targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
all_success_filelist = []

if unzip:
self.unzip_files(targetfilelist)
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)

# download failed files untill success
trail_count = 0
while not suc:
zipfilelist = [ff for ff in zipfilelist if ff not in success_source_files]
if len(zipfilelist) == 0:
print_warn("No failed files are found! ")
break

targetfilelist = [join(self.tartanair_data_root, zipfile.replace('/', '_')) for zipfile in zipfilelist]
suc, success_source_files, success_target_files = self.downloader.download(zipfilelist, targetfilelist)
all_success_filelist.extend(success_target_files)
trail_count += 1
if trail_count >= max_failure_trial:
break

if suc:
print_highlight("Download completed! Enjoy using TartanAir!")
else:
targetfilelist = []
for source_file_name in zipfilelist:
target_file_name = source_file_name.replace('/', '_')
targetfilelist.append(target_file_name)
print_warn("Download with failure! The following files are not downloaded ..")
for ff in zipfilelist:
print_warn(ff)

return True, targetfilelist
if unzip:
self.unzip_files(all_success_filelist)

return True

def download_multi_thread(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, max_failure_trial = 3, num_workers = 8, **kwargs):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for ee in env:
for dd in difficulty:
futures.append(executor.submit(self.download, env = [ee], difficulty = [dd], modality = modality, camera_name = camera_name,
config = config, unzip = unzip, max_failure_trial = max_failure_trial,))
# Wait for a few seconds to avoid overloading the data server
time.sleep(2)

# Wait for all futures to complete
for future in as_completed(futures):
future.result() # This will re-raise any exceptions caught during the futures' execution
28 changes: 25 additions & 3 deletions tartanair/tartanair.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def init(tartanair_root):
return True


def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True):
def download(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.

Expand All @@ -107,8 +107,30 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer

global downloader
check_init()
success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download)
return success, filelist
downloader.download(env, difficulty, modality, camera_name, config, unzip)

def download_multi_thread(env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, num_workers = 8):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.

:param env: The environment to download. Can be a list of environments.
:type env: str or list
:param difficulty: The difficulty of the trajectory. Can be a list of difficulties. Valid difficulties are: easy, hard.
:type difficulty: str or list
:param trajectory_id: The id of the trajectory to download. Can be a list of trajectory ids of form P000, P001, etc.
:type trajectory_id: str or list
:param modality: The modality to download. Can be a list of modalities. Valid modalities are: image, depth, seg, imu{_acc, _gyro, _time, ...}, lidar. Default will include all.
:type modality: str or list
:param camera_name: The camera name to download. Can be a list of camera names. Default will include all. Choices are `lcam_front`, `lcam_right`, `lcam_back`, `lcam_left`, `lcam_top`, `lcam_bottom`, `rcam_front`, `rcam_right`, `rcam_back`, `rcam_left`, `rcam_top`, `rcam_bottom`, `lcam_fish`, `rcam_fish`, `lcam_equirect`, `rcam_equirect`.
Modalities IMU and LIDAR do not need camera names specified.
:type camera_name: str or list
:param config: Optional. Path to a yaml file containing the download configuration. If a config file is provided, the other arguments will be ignored.
:type config: str
"""

global downloader
check_init()
downloader.download_multi_thread(env = env, difficulty = difficulty, modality = modality, camera_name = camera_name, config = config, unzip = unzip, num_workers = num_workers)

def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
"""
Expand Down
Loading