Skip to content

Commit

Permalink
Add support for returning filelist and multi-thread download
Browse files Browse the repository at this point in the history
  • Loading branch information
Nik-V9 committed Aug 22, 2024
1 parent 18d8f72 commit d1df02b
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Local
local/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
120 changes: 120 additions & 0 deletions examples/multi_thread_download_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import os
import time
import logging
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed

import tartanair as ta


def download_dataset(env, modality, cam_name):
    """Download one (environment, modality, camera) combination.

    Both the 'easy' and 'hard' difficulties are requested and the archives
    are left zipped (``unzip=False``).

    Every failure is written to the root logger as a single line of the form
    ``Failed to download <env> <modality> <cam_name>: <reason>``.
    ``retry_failed_downloads`` parses that prefix, so keep it stable.

    Args:
        env: Environment name forwarded to ``ta.download``.
        modality: Modality name forwarded to ``ta.download``.
        cam_name: Camera name forwarded to ``ta.download``.

    Returns:
        The ``(success, filelist)`` pair returned by ``ta.download``, or
        ``(False, None)`` when the call raised.
    """
    try:
        # Attempt to download the dataset
        success, filelist = ta.download(env=env,
                                        difficulty=['easy', 'hard'],
                                        modality=modality,
                                        camera_name=cam_name,
                                        unzip=False)
    except Exception as e:
        # Deliberately broad: one failed download must not abort the whole
        # batch. Collapse the message onto one line so the error log keeps
        # exactly one line per failure.
        reason = " ".join(str(e).splitlines())
        logging.error(f"Failed to download {env} {modality} {cam_name}: {reason}")
        return False, None

    if not success:
        # ta.download can also signal failure through its return value
        # without raising; record it instead of dropping it silently.
        logging.error(f"Failed to download {env} {modality} {cam_name}: download reported failure")
    return success, filelist


def download_all_in_parallel(trajectories, modalities, num_workers, submit_delay=10):
    """Submit one download job per (environment, modality, camera) and wait for all.

    The 'imu', 'lidar' and 'flow' modalities are requested for ``lcam_front``
    only; every other modality is requested for all sixteen cameras.

    Args:
        trajectories: Iterable of environment names.
        modalities: Iterable of modality names.
        num_workers: Maximum number of worker threads in the pool.
        submit_delay: Seconds to pause after each submission so the data
            server is not flooded with simultaneous requests. Values <= 0
            disable the pause. Defaults to 10.

    Returns:
        None. Exceptions raised inside a job are re-raised here.
    """
    single_camera_modalities = ('imu', 'lidar', 'flow')
    all_cam_names = ["lcam_back", "lcam_bottom", "lcam_equirect", "lcam_fish", "lcam_front",
                     "lcam_left", "lcam_right", "lcam_top", "rcam_back", "rcam_bottom",
                     "rcam_equirect", "rcam_fish", "rcam_front", "rcam_left", "rcam_right", "rcam_top"]

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        for env in trajectories:
            for modality in modalities:
                if modality in single_camera_modalities:
                    cam_names = ["lcam_front"]
                else:
                    cam_names = all_cam_names
                for cam_name in cam_names:
                    futures.append(executor.submit(download_dataset, env, modality, cam_name))
                    # Wait for a few seconds to avoid overloading the data server
                    if submit_delay > 0:
                        time.sleep(submit_delay)

        # Wait for all futures to complete
        for future in as_completed(futures):
            future.result()  # Re-raises any exception that escaped the job


def _read_failed_downloads(error_log_path):
    """Return the unique (env, modality, cam_name) triples found in the error log."""
    marker = "Failed to download "
    failed = []
    seen = set()
    with open(error_log_path, 'r') as f:
        for line in f:
            # Locate the message by its fixed prefix rather than by counting
            # space-separated fields, so blank lines, continuation lines and
            # a different timestamp layout are skipped instead of crashing.
            _, found, rest = line.partition(marker)
            if not found:
                continue
            fields = rest.split()
            if len(fields) < 3:
                continue
            env, modality, cam_name = fields[:3]
            entry = (env, modality, cam_name.replace(":", ""))
            # The log is appended to across runs; retry each failure once.
            if entry not in seen:
                seen.add(entry)
                failed.append(entry)
    return failed


def retry_failed_downloads(error_log_path, num_workers, submit_delay=10):
    """Re-run every download recorded as failed in the error log.

    Only lines containing ``Failed to download <env> <modality> <cam_name>:``
    are used; all other lines are ignored. Each distinct combination is
    retried once, in the order it first appears in the log.

    Args:
        error_log_path: Path of the log file to read.
        num_workers: Maximum number of worker threads in the pool.
        submit_delay: Seconds to pause after each submission so the data
            server is not flooded with simultaneous requests. Values <= 0
            disable the pause. Defaults to 10.

    Returns:
        None. Exceptions raised inside a job are re-raised here.
    """
    # Read list of environments, modalities and camera names from the error log
    failed_downloads = _read_failed_downloads(error_log_path)

    # Download data in parallel
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        for env, modality, cam_name in failed_downloads:
            futures.append(executor.submit(download_dataset, env, modality, cam_name))
            # Wait for a few seconds to avoid overloading the data server
            if submit_delay > 0:
                time.sleep(submit_delay)

        # Wait for all futures to complete
        for future in as_completed(futures):
            future.result()  # Re-raises any exception that escaped the job


def parse_arguments(argv=None):
    """Parse the command-line options of the download example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_root``, ``retry_failed``,
        ``error_log_name``, ``error_log_path`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser(description="Download TartanAir datasets.")
    parser.add_argument("--data_root", type=str, required=True, help="Root directory for TartanAir data.")
    parser.add_argument("--retry_failed", action='store_true', help="Retry failed downloads.")
    parser.add_argument("--error_log_name", type=str, default="error_log.txt", help="Name of the error log file.")
    parser.add_argument("--error_log_path", type=str, default="", help="Path to store the error log file.")
    parser.add_argument("--num_workers", type=int, default=24, help="Number of workers for parallel downloads.")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: download every listed environment/modality in
    # parallel, then optionally retry whatever was logged as failed.
    args = parse_arguments()
    # An empty --error_log_path means "current directory".
    error_log_path = args.error_log_path if args.error_log_path else '.'
    error_log_file = os.path.join(error_log_path, args.error_log_name)

    # Create the log directory if it doesn't exist. exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(error_log_path, exist_ok=True)

    # Setup logging. Only ERROR and above reach the file; the retry step
    # reads the failure messages back from it.
    logging.basicConfig(filename=error_log_file, level=logging.ERROR, format='%(asctime)s:%(levelname)s:%(message)s')

    # Initialize TartanAir Module.
    tartanair_data_root = args.data_root
    ta.init(tartanair_data_root)

    # Define Trajectories and Modalities to be downloaded
    trajectories = [
        "AbandonedCable", "AbandonedFactory", "AbandonedFactory2", "AbandonedSchool",
        "AmericanDiner", "AmusementPark", "AncientTowns", "Antiquity3D", "Apocalyptic",
        "ArchVizTinyHouseDay", "ArchVizTinyHouseNight", "BrushifyMoon", "CarWelding",
        "CastleFortress", "CoalMine", "ConstructionSite", "CountryHouse", "CyberPunkDowntown",
        "Cyberpunk", "DesertGasStation", "Downtown", "EndofTheWorld", "FactoryWeather", "Fantasy",
        "ForestEnv", "Gascola", "GothicIsland", "GreatMarsh", "HQWesternSaloon", "HongKong", "Hospital",
        "House", "IndustrialHangar", "JapaneseAlley", "JapaneseCity", "MiddleEast", "ModUrbanCity",
        "ModernCityDowntown", "ModularNeighborhood", "ModularNeighborhoodIntExt", "NordicHarbor",
        "Ocean", "Office", "OldBrickHouseDay", "OldBrickHouseNight", "OldIndustrialCity", "OldScandinavia",
        "OldTownFall", "OldTownNight", "OldTownSummer", "OldTownWinter", "PolarSciFi", "Prison", "Restaurant",
        "RetroOffice", "Rome", "Ruins", "SeasideTown", "SeasonalForestAutumn", "SeasonalForestSpring",
        "SeasonalForestSummerNight", "SeasonalForestWinter", "SeasonalForestWinterNight", "Sewerage",
        "ShoreCaves", "Slaughter", "SoulCity", "Supermarket", "TerrainBlending", "UrbanConstruction",
        "VictorianStreet", "WaterMillDay", "WaterMillNight", "WesternDesertTown"
    ]
    modalities = ['imu', 'lidar', 'flow', 'image', 'depth', 'seg']

    download_all_in_parallel(trajectories, modalities, args.num_workers)

    if args.retry_failed:
        retry_failed_downloads(error_log_file, args.num_workers)
37 changes: 21 additions & 16 deletions tartanair/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def unzip_files(self, zipfilelist):
os.system(cmd)
print_highlight("Unzipping Completed! ")

def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, **kwargs):
def download(self, env = [], difficulty = [], modality = [], camera_name = [], config = None, unzip = False, download = True, **kwargs):
"""
Downloads a trajectory from the TartanAir dataset. A trajectory includes a set of images and a corresponding trajectory text file describing the motion.
Expand Down Expand Up @@ -283,27 +283,32 @@ def download(self, env = [], difficulty = [], modality = [], camera_name = [], c

# Check that the environments are valid.
if not self.check_env_valid(env):
return False
return False, None
# Check that the modalities are valid
if not self.check_modality_valid(modality):
return False
return False, None
# Check that the difficulty are valid
if not self.check_difficulty_valid(difficulty):
return False
return False, None
# Check that the camera names are valid
if not self.check_camera_valid(camera_name):
return False
return False, None

zipfilelist = self.generate_filelist(env, difficulty, modality, camera_name)
# import ipdb;ipdb.set_trace()
if not self.doublecheck_filelist(zipfilelist):
return False

suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
if suc:
print_highlight("Download completed! Enjoy using TartanAir!")

if unzip:
self.unzip_files(targetfilelist)

return True
return False, None

if download:
suc, targetfilelist = self.downloader.download(zipfilelist, self.tartanair_data_root)
if suc:
print_highlight("Download completed! Enjoy using TartanAir!")

if unzip:
self.unzip_files(targetfilelist)
else:
targetfilelist = []
for source_file_name in zipfilelist:
target_file_name = source_file_name.replace('/', '_')
targetfilelist.append(target_file_name)

return True, targetfilelist
7 changes: 4 additions & 3 deletions tartanair/tartanair.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def init(tartanair_root):
return True


def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False):
def download(env = [], difficulty = [], trajectory_id = [], modality = [], camera_name = [], config = None, unzip = False, download = True):
"""
Download data from the TartanAir dataset. This method will download the data from the Azure server and store it in the `tartanair_root` directory.
Expand All @@ -101,7 +101,8 @@ def download(env = [], difficulty = [], trajectory_id = [], modality = [], camer

global downloader
check_init()
downloader.download(env, difficulty, modality, camera_name, config, unzip)
success, filelist = downloader.download(env, difficulty, modality, camera_name, config, unzip, download)
return success, filelist

def customize(env, difficulty, trajectory_id, modality, new_camera_models_params = [{}], num_workers = 1, device = "cpu"):
"""
Expand Down Expand Up @@ -418,4 +419,4 @@ def get_random_accessor():
"""
global random_accessor
check_init()
return random_accessor
return random_accessor

0 comments on commit d1df02b

Please sign in to comment.