Skip to content

Commit

Permalink
Parametrize FPS group (NVIDIA#9648) (NVIDIA#9669)
Browse files Browse the repository at this point in the history
* Parametrize FPS group

* Apply isort and black reformatting

* Change default to False

* Add logic to new ckptIO

* Turn on parallel save by default

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: mikolajblaz <[email protected]>
Co-authored-by: mikolajblaz <[email protected]>
Co-authored-by: Dmytro Pykhtar <[email protected]>
Signed-off-by: tonyjie <[email protected]>
  • Loading branch information
3 people authored and tonyjie committed Aug 6, 2024
1 parent 9d0349e commit 6587378
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ model:
# Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
dist_ckpt_parallel_save_within_dp: False # if true, save is parallelized only within each data-parallel (DP) group rather than across the whole world, which might slightly reduce the save overhead
dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
Expand Down
7 changes: 6 additions & 1 deletion nemo/lightning/io/pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = True,
parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
self.save_ckpt_format = save_ckpt_format
Expand All @@ -85,6 +86,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
Expand Down Expand Up @@ -216,8 +218,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
parallelization_group = (
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
Expand Down
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure=False,
ckpt_parallel_save=True,
ckpt_parallel_save_within_dp=False,
ckpt_parallel_load=False,
ckpt_parallel_save_optim=True,
**kwargs,
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
self.torch_dist_multiproc = ckpt_torch_dist_multiproc
self.assume_constant_structure = ckpt_assume_constant_structure
self.parallel_save = ckpt_parallel_save
self.parallel_save_within_dp = ckpt_parallel_save_within_dp
self.parallel_load = ckpt_parallel_load
self.parallel_save_optim = ckpt_parallel_save_optim

Expand Down Expand Up @@ -566,6 +568,7 @@ def checkpoint_io(self) -> CheckpointIO:
torch_dist_multiproc=self.torch_dist_multiproc,
assume_constant_structure=self.assume_constant_structure,
parallel_save=self.parallel_save,
parallel_save_within_dp=self.parallel_save_within_dp,
parallel_load=self.parallel_load,
)
if async_save:
Expand Down
8 changes: 7 additions & 1 deletion nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = False,
parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
super().__init__()
Expand All @@ -218,6 +219,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
Expand All @@ -239,6 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
)

Expand Down Expand Up @@ -377,8 +380,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
parallelization_group = (
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
Expand Down

0 comments on commit 6587378

Please sign in to comment.