diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index ac1f4a37b232..1599f38cbfa8 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -176,7 +176,8 @@ model:
   # Distributed checkpoint setup
   dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
   dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
-  dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
+  dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
+  dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead
   dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
   dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
   dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py
index 2cadc56e59b4..02b998378ea3 100644
--- a/nemo/lightning/io/pl.py
+++ b/nemo/lightning/io/pl.py
@@ -77,6 +77,7 @@ def __init__(
         torch_dist_multiproc: Optional[int] = None,
         assume_constant_structure: bool = False,
         parallel_save: bool = True,
+        parallel_save_within_dp: bool = False,
         parallel_load: bool = False,
     ):
         self.save_ckpt_format = save_ckpt_format
@@ -85,6 +86,7 @@ def __init__(
         self.torch_dist_multiproc = torch_dist_multiproc
         self.assume_constant_structure = assume_constant_structure
         self.parallel_save = parallel_save
+        self.parallel_save_within_dp = parallel_save_within_dp
         self.parallel_load = parallel_load

         self._save_sharded_strategy = None
@@ -216,8 +218,11 @@ def _determine_dist_ckpt_save_strategy(self):
             save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

         if self.parallel_save:
+            parallelization_group = (
+                get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+            )
             save_strategy = FullyParallelSaveStrategyWrapper(
-                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+                save_strategy, parallelization_group, self.assume_constant_structure
             )

         logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index 6a84319b4fa2..d75239f7e668 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -108,6 +108,7 @@ def __init__(
         ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure=False,
         ckpt_parallel_save=True,
+        ckpt_parallel_save_within_dp=False,
         ckpt_parallel_load=False,
         ckpt_parallel_save_optim=True,
         **kwargs,
@@ -139,6 +140,7 @@ def __init__(
         self.torch_dist_multiproc = ckpt_torch_dist_multiproc
         self.assume_constant_structure = ckpt_assume_constant_structure
         self.parallel_save = ckpt_parallel_save
+        self.parallel_save_within_dp = ckpt_parallel_save_within_dp
         self.parallel_load = ckpt_parallel_load
         self.parallel_save_optim = ckpt_parallel_save_optim

@@ -566,6 +568,7 @@ def checkpoint_io(self) -> CheckpointIO:
                 torch_dist_multiproc=self.torch_dist_multiproc,
                 assume_constant_structure=self.assume_constant_structure,
                 parallel_save=self.parallel_save,
+                parallel_save_within_dp=self.parallel_save_within_dp,
                 parallel_load=self.parallel_load,
             )
             if async_save:
diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index 144c07addaa8..ad2ad1eebec0 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -206,6 +206,7 @@ def __init__(
         torch_dist_multiproc: Optional[int] = None,
         assume_constant_structure: bool = False,
         parallel_save: bool = False,
+        parallel_save_within_dp: bool = False,
         parallel_load: bool = False,
     ):
         super().__init__()
@@ -218,6 +219,7 @@ def __init__(
         self.torch_dist_multiproc = torch_dist_multiproc
         self.assume_constant_structure = assume_constant_structure
         self.parallel_save = parallel_save
+        self.parallel_save_within_dp = parallel_save_within_dp
         self.parallel_load = parallel_load

         self._save_sharded_strategy = None
@@ -239,6 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
             async_save=async_save,
             torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
             parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
+            parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
            parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
         )

@@ -377,8 +380,11 @@ def _determine_dist_ckpt_save_strategy(self):
             save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

         if self.parallel_save:
+            parallelization_group = (
+                get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+            )
             save_strategy = FullyParallelSaveStrategyWrapper(
-                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+                save_strategy, parallelization_group, self.assume_constant_structure
             )

         logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
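
For reference, a minimal usage sketch (not part of the patch itself): building a DistributedCheckpointIO from a model config that sets the new `dist_ckpt_parallel_save_within_dp` flag, mirroring the `from_config` hunk above. The config values are illustrative only, and the OmegaConf-based config object is an assumption about how the model config is typically supplied; when the flag is true, the fully parallel save is restricted to the data-parallel group (including context parallel), otherwise the whole world participates.

# Illustrative sketch only: enable fully parallel dist-ckpt saves restricted
# to the DP (+CP) group instead of the whole world. Values are examples.
from omegaconf import OmegaConf

from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

model_cfg = OmegaConf.create(
    {
        'dist_ckpt_format': 'torch_dist',
        'dist_ckpt_load_on_device': True,
        'dist_ckpt_parallel_save': True,             # each worker writes its own shard
        'dist_ckpt_parallel_save_within_dp': True,   # parallelize only within the DP group
        'dist_ckpt_parallel_load': False,
    }
)
checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=False)

Note that the parallelization group is only resolved lazily in `_determine_dist_ckpt_save_strategy`, so constructing the IO object does not itself require megatron parallel state to be initialized.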