diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py
index bd97fc292..22a412d93 100644
--- a/scripts/storage_cleaner.py
+++ b/scripts/storage_cleaner.py
@@ -802,7 +802,7 @@ def _get_sharded_checkpoint_dirs(
     return sharded_checkpoint_directories
 
 
-def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
+def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str) -> bool:
     max_train_config_size = 1 * 1024 * 1024  # 1MB
 
     if not StorageAdapter.get_storage_type_for_path(local_checkpoint_dir) == StorageType.LOCAL_FS:
@@ -811,7 +811,7 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
     checkpoint_storage = _get_storage_adapter_for_path(local_checkpoint_dir)
     if CONFIG_YAML in checkpoint_storage.list_entries(local_checkpoint_dir, max_file_size=max_train_config_size):
         # Config already exists in the checkpoint
-        return
+        return False
 
     log.info("%s not found in %s, attempting to get it from %s", CONFIG_YAML, local_checkpoint_dir, run_dir)
 
@@ -820,9 +820,10 @@ def _add_training_config_to_checkpoint(local_checkpoint_dir: str, run_dir: str):
     if run_storage.is_file(run_config_yaml_path):
         local_config_yaml_path = cached_path(run_config_yaml_path)
         shutil.copy(local_config_yaml_path, local_checkpoint_dir)
-        return
+        return True
 
     log.warning("Cannot find training config to add to checkpoint %s", local_checkpoint_dir)
+    return False
 
 
 def _unshard_checkpoint(
@@ -830,12 +831,15 @@ def _unshard_checkpoint(
 ):
     local_storage = LocalFileSystemAdapter()
 
-    # Download checkpoint to a temp dir
-    sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
-    src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
-    src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
+    # Download checkpoint to a temp dir if it is in cloud storage
+    if StorageAdapter.get_storage_type_for_path(sharded_checkpoint_dir) != StorageType.LOCAL_FS:
+        sharding_input_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)
+        src_storage = _get_storage_adapter_for_path(sharded_checkpoint_dir)
+        src_storage.download_folder(sharded_checkpoint_dir, sharding_input_dir)
+    else:
+        sharding_input_dir = sharded_checkpoint_dir
 
-    _add_training_config_to_checkpoint(sharding_input_dir, run_dir)
+    training_config_added = _add_training_config_to_checkpoint(sharding_input_dir, run_dir)
 
     # Set unsharder output to a temp dir
     sharding_output_dir: str
@@ -863,7 +867,9 @@ def _unshard_checkpoint(
                 e,
             )
 
-            local_storage.delete_path(sharding_input_dir)
+            if training_config_added:
+                local_storage.delete_path(str(Path(sharding_input_dir) / CONFIG_YAML))
+
             local_storage.delete_path(sharding_output_dir)
             return
 
@@ -898,9 +904,6 @@ def _unshard_checkpoint(
     dest_storage = _get_storage_adapter_for_path(dest_dir)
     dest_storage.upload(sharding_output_dir, dest_dir)
 
-    local_storage.delete_path(sharding_input_dir)
-    local_storage.delete_path(sharding_output_dir)
-
 
 def _unshard_checkpoints(
     run_storage: StorageAdapter,
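
A minimal sketch of the cleanup contract this diff introduces, using hypothetical names (`add_training_config`, `unshard`) rather than the module's real helpers: the config-copying helper now reports whether it actually added `config.yaml` to the input directory, and the caller removes only that copied file on failure, never the input directory itself, since that directory may now be the caller's own local checkpoint rather than a temp download.

```python
# Illustrative sketch only; not code from this PR.
import shutil
from pathlib import Path

CONFIG_YAML = "config.yaml"


def add_training_config(checkpoint_dir: Path, run_dir: Path) -> bool:
    """Copy run_dir/config.yaml into checkpoint_dir; return True only if a copy was made."""
    if (checkpoint_dir / CONFIG_YAML).exists():
        return False  # config already present; nothing for the caller to clean up
    run_config = run_dir / CONFIG_YAML
    if run_config.is_file():
        shutil.copy(run_config, checkpoint_dir)
        return True
    return False


def unshard(checkpoint_dir: Path, run_dir: Path) -> None:
    config_added = add_training_config(checkpoint_dir, run_dir)
    try:
        pass  # run the unsharder against checkpoint_dir here
    except Exception:
        # Remove only what this function added; checkpoint_dir may be a
        # user-owned local checkpoint, so it must not be deleted wholesale.
        if config_added:
            (checkpoint_dir / CONFIG_YAML).unlink(missing_ok=True)
        raise
```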