make torch_dist ckpt strategy as default (NVIDIA#9852) (NVIDIA#10291)
copy of NVIDIA#9852

Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
2 people authored and Hainan Xu committed Nov 5, 2024
1 parent 13b46f3 commit 49693bc
Showing 3 changed files with 16 additions and 2 deletions.
@@ -175,7 +175,7 @@ model:
   fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint.

   # Distributed checkpoint setup
-  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
+  dist_ckpt_format: 'torch_dist' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
   dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
   dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
   dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead
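The hunk above flips the YAML default, so a model config that omits dist_ckpt_format now resolves to torch_dist and zarr has to be requested explicitly. A minimal sketch of the resulting config subsection, assuming OmegaConf (which NeMo uses for its YAML configs); the keys and values simply mirror the hunk and are not a complete model config:

from omegaconf import OmegaConf

# Only the distributed-checkpoint keys from the hunk above; everything else is omitted.
cfg = OmegaConf.create(
    {
        "model": {
            "dist_ckpt_format": "torch_dist",   # new default; set "zarr" to keep the old backend
            "dist_ckpt_load_on_device": True,   # load checkpoint weights directly on GPU
            "dist_ckpt_parallel_save": True,    # each worker writes its own part of the checkpoint
        }
    }
)

assert cfg.model.dist_ckpt_format == "torch_dist"
# A config that never sets the key falls back the same way the code now does:
assert OmegaConf.create({"model": {}}).model.get("dist_ckpt_format", "torch_dist") == "torch_dist"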
nemo/lightning/io/pl.py (7 additions, 0 deletions)
@@ -217,6 +217,13 @@ def _determine_dist_ckpt_save_strategy(self):
         are passed in config or in case of a fully parallel save in which case
         a parallelization wrapper is applied.
         """
+        if self.save_ckpt_format == 'zarr':
+            logging.warning(
+                f'`zarr` distributed checkpoint backend is deprecated.'
+                f' Distributed optimizer checkpoint saving might be extremely slow.'
+                f' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
+            )
+
         if self.async_save and self.save_ckpt_format != 'torch_dist':
             raise ValueError('Async dist-ckpt save supported only for torch_dist format')

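The guard added above (and repeated in nemo/utils/callbacks/dist_ckpt_io.py below) is small enough to read in isolation. A standalone sketch of the same logic, with save_ckpt_format and async_save passed as plain arguments instead of instance attributes; check_dist_ckpt_format is an illustrative name, not a function from this commit:

import logging

def check_dist_ckpt_format(save_ckpt_format: str, async_save: bool) -> None:
    # Warn when the deprecated zarr backend is still selected.
    if save_ckpt_format == 'zarr':
        logging.warning(
            '`zarr` distributed checkpoint backend is deprecated.'
            ' Distributed optimizer checkpoint saving might be extremely slow.'
            ' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
        )
    # Async saving is only supported with the torch_dist format.
    if async_save and save_ckpt_format != 'torch_dist':
        raise ValueError('Async dist-ckpt save supported only for torch_dist format')

check_dist_ckpt_format('zarr', async_save=False)       # logs the deprecation warning
check_dist_ckpt_format('torch_dist', async_save=True)  # passes silently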
nemo/utils/callbacks/dist_ckpt_io.py (8 additions, 1 deletion)
@@ -242,7 +242,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
             it should be provided separately. Defaults to False.
         """
         return cls(
-            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
+            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'torch_dist'),
             load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
             load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
             async_save=async_save,
@@ -390,6 +390,13 @@ def _determine_dist_ckpt_save_strategy(self):
         are passed in config or in case of a fully parallel save in which case
         a parallelization wrapper is applied.
         """
+        if self.save_ckpt_format == 'zarr':
+            logging.warning(
+                f'`zarr` distributed checkpoint backend is deprecated.'
+                f' Distributed optimizer checkpoint saving might be extremely slow.'
+                f' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
+            )
+
         if self.async_save and self.save_ckpt_format != 'torch_dist':
             raise ValueError('Async dist-ckpt save supported only for torch_dist format')

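With the from_config default flipped as well, a model config dict that never mentions dist_ckpt_format now yields a torch_dist checkpoint IO. A hedged sketch of that resolution; build_ckpt_io_kwargs is a hypothetical stand-in for the from_config classmethod shown above and only returns the keyword arguments it would pass on:

def build_ckpt_io_kwargs(model_cfg: dict, async_save: bool = False) -> dict:
    # Mirrors the .get() defaults used in from_config above.
    return {
        'save_ckpt_format': model_cfg.get('dist_ckpt_format', 'torch_dist'),
        'load_directly_on_device': model_cfg.get('dist_ckpt_load_on_device', True),
        'load_strictness': model_cfg.get('dist_ckpt_load_strictness', None),
        'async_save': async_save,
    }

# A config without the key now resolves to the torch_dist backend.
assert build_ckpt_io_kwargs({})['save_ckpt_format'] == 'torch_dist'
# zarr is still selectable, but saving with it triggers the deprecation warning above.
assert build_ckpt_io_kwargs({'dist_ckpt_format': 'zarr'})['save_ckpt_format'] == 'zarr'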
