Commit

Avoid redundant copying of sharding input dir if local
2015aroras committed Dec 20, 2023
1 parent 6ca40de commit b03f9c0
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions scripts/storage_cleaner.py
@@ -802,7 +802,7 @@ def _get_sharded_checkpoint_dirs(
     return sharded_checkpoint_directories
 
 
-def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
+def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str) -> bool:
     max_train_config_size = 1 * 1024 * 1024  # 1MB
 
     if not StorageAdapter.get_storage_type_for_path(local_checkpoint_dir) == StorageType.LOCAL_FS:
@@ -811,7 +811,7 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
     checkpoint_storage = _get_storage_adapter_for_path(local_checkpoint_dir)
     if CONFIG_YAML in checkpoint_storage.list_entries(local_checkpoint_dir, max_file_size=max_train_config_size):
         # Config already exists in the checkpoint
-        return
+        return False
 
     log.info("%s not found in %s, attempting to get it from %s", CONFIG_YAML, local_checkpoint_dir, run_dir)
 
@@ -820,22 +820,26 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
     if run_storage.is_file(run_config_yaml_path):
         local_config_yaml_path = cached_path(run_config_yaml_path)
         shutil.copy(local_config_yaml_path, local_checkpoint_dir)
-        return
+        return True
 
     log.warning("Cannot find training config to add to checkpoint %s", local_checkpoint_dir)
+    return False
 
 
 def _unshard_checkpoint(
     sharded_checkpoint_dir: str, dest_dir: str, run_dir: str, unsharding_config: UnshardCheckpointsConfig
 ):
     local_storage = LocalFileSystemAdapter()
 
-    # Download checkpoint to a temp dir
-    sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
-    src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
-    src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
+    # Download checkpoint to a temp dir if it is in cloud storage
+    if StorageAdapter.get_storage_type_for_path(sharded_checkpoint_dir) != StorageType.LOCAL_FS:
+        sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
+        src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
+        src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
+    else:
+        sharding_input_dir = sharded_checkpoint_dir
 
-    _add_training_config_to_checkpoint(sharding_input_dir, run_dir)
+    training_config_added = _add_training_config_to_checkpoint(sharding_input_dir, run_dir)
 
     # Set unsharder output to a temp dir
     sharding_output_dir: str
@@ -863,7 +867,9 @@ def _unshard_checkpoint(
             e,
         )
 
-        local_storage.delete_path(sharding_input_dir)
+        if training_config_added:
+            local_storage.delete_path(str(Path(sharding_input_dir) / CONFIG_YAML))
+
         local_storage.delete_path(sharding_output_dir)
         return
 
@@ -898,9 +904,6 @@ def _unshard_checkpoint(
     dest_storage = _get_storage_adapter_for_path(dest_dir)
     dest_storage.upload(sharding_output_dir, dest_dir)
 
-    local_storage.delete_path(sharding_input_dir)
-    local_storage.delete_path(sharding_output_dir)
-
 
 def _unshard_checkpoints(
     run_storage: StorageAdapter,
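
Taken together, the hunks above make _unshard_checkpoint reuse a local sharded checkpoint directory in place instead of copying it into a temp dir, and they use the new bool return of _add_training_config_to_checkpoint so that cleanup removes only the config.yaml that was added, never the caller's checkpoint. Below is a minimal, self-contained sketch of that pattern; is_cloud_path and the other helper names are illustrative stand-ins, not the repo's StorageAdapter API (the real check is StorageAdapter.get_storage_type_for_path(...) != StorageType.LOCAL_FS, as shown in the diff).

    # Sketch only: helper names here are stand-ins for illustration.
    import shutil
    import tempfile
    from pathlib import Path

    CONFIG_YAML = "config.yaml"


    def is_cloud_path(path: str) -> bool:
        # Stand-in for StorageAdapter.get_storage_type_for_path(path) != StorageType.LOCAL_FS
        return path.startswith(("s3://", "gs://", "r2://"))


    def prepare_sharding_input_dir(sharded_checkpoint_dir: str) -> str:
        # Download only when the checkpoint lives in cloud storage; a checkpoint
        # already on the local filesystem is used in place, avoiding a redundant copy.
        if is_cloud_path(sharded_checkpoint_dir):
            input_dir = tempfile.mkdtemp()
            # ...download the sharded checkpoint into input_dir here...
            return input_dir
        return sharded_checkpoint_dir


    def add_training_config(input_dir: str, run_config_yaml: str) -> bool:
        # Mirrors the new bool return: True only when config.yaml was actually copied in.
        dest = Path(input_dir) / CONFIG_YAML
        if dest.exists():
            return False
        shutil.copy(run_config_yaml, dest)
        return True


    def remove_added_config(input_dir: str, training_config_added: bool) -> None:
        # input_dir may be the caller's own local checkpoint, so cleanup is limited
        # to the config.yaml that was added, never the directory itself.
        if training_config_added:
            (Path(input_dir) / CONFIG_YAML).unlink(missing_ok=True)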
