From 5c3723af60640703bdb07cdfb73493397b9f9fc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Tue, 9 Jul 2024 13:27:02 +0200 Subject: [PATCH 1/5] Parametrize FPS group MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- .../nlp/language_modeling/conf/megatron_gpt_config.yaml | 1 + nemo/utils/callbacks/dist_ckpt_io.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index ac1f4a37b232..8b5796759ba3 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -177,6 +177,7 @@ model: dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_parallel_save_within_dp: True # if true (recommended), save will be parallelized within a DP group (whole world otherwise). Setting to False might add a small save overhead. dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves. diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 144c07addaa8..4303e5536386 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -206,6 +206,7 @@ def __init__( torch_dist_multiproc: Optional[int] = None, assume_constant_structure: bool = False, parallel_save: bool = False, + parallel_save_within_dp: bool = True, parallel_load: bool = False, ): super().__init__() @@ -218,6 +219,7 @@ def __init__( self.torch_dist_multiproc = torch_dist_multiproc self.assume_constant_structure = assume_constant_structure self.parallel_save = parallel_save + self.parallel_save_within_dp = parallel_save_within_dp self.parallel_load = parallel_load self._save_sharded_strategy = None @@ -239,6 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): async_save=async_save, torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None), parallel_save=model_cfg.get('dist_ckpt_parallel_save', False), + parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', True), parallel_load=model_cfg.get('dist_ckpt_parallel_load', False), ) @@ -377,8 +380,9 @@ def _determine_dist_ckpt_save_strategy(self): save_strategy.use_cached_ckpt_structure = self.assume_constant_structure if self.parallel_save: + parallelization_group = get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure + save_strategy, parallelization_group, self.assume_constant_structure ) logging.info(f'Using {save_strategy} dist-ckpt save strategy.') From 3a33045805f5613da2d03f46acaed8b4c8f2f810 Mon Sep 17 00:00:00 2001 From: mikolajblaz Date: Tue, 9 Jul 2024 11:28:45 +0000 Subject: [PATCH 2/5] Apply isort and black reformatting Signed-off-by: mikolajblaz --- nemo/utils/callbacks/dist_ckpt_io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 4303e5536386..774435bbfce9 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -380,7 +380,9 @@ def _determine_dist_ckpt_save_strategy(self): save_strategy.use_cached_ckpt_structure = self.assume_constant_structure if self.parallel_save: - parallelization_group = get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None + parallelization_group = ( + get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None + ) save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, parallelization_group, self.assume_constant_structure ) From 8fcdf656dddece294df4f72ce887bffb2396cfb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Tue, 9 Jul 2024 13:49:01 +0200 Subject: [PATCH 3/5] Change deafult to False MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 2 +- nemo/utils/callbacks/dist_ckpt_io.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 8b5796759ba3..1cc3f6e527d4 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -177,7 +177,7 @@ model: dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint - dist_ckpt_parallel_save_within_dp: True # if true (recommended), save will be parallelized within a DP group (whole world otherwise). Setting to False might add a small save overhead. + dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves. diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 774435bbfce9..ad2ad1eebec0 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -206,7 +206,7 @@ def __init__( torch_dist_multiproc: Optional[int] = None, assume_constant_structure: bool = False, parallel_save: bool = False, - parallel_save_within_dp: bool = True, + parallel_save_within_dp: bool = False, parallel_load: bool = False, ): super().__init__() @@ -241,7 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): async_save=async_save, torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None), parallel_save=model_cfg.get('dist_ckpt_parallel_save', False), - parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', True), + parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False), parallel_load=model_cfg.get('dist_ckpt_parallel_load', False), ) From 77de9bf1423e56e93850142ec2be666469bf2d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Tue, 9 Jul 2024 13:52:20 +0200 Subject: [PATCH 4/5] Add logic to new ckptIO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- nemo/lightning/io/pl.py | 7 ++++++- nemo/lightning/pytorch/strategies.py | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 2cadc56e59b4..02b998378ea3 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -77,6 +77,7 @@ def __init__( torch_dist_multiproc: Optional[int] = None, assume_constant_structure: bool = False, parallel_save: bool = True, + parallel_save_within_dp: bool = False, parallel_load: bool = False, ): self.save_ckpt_format = save_ckpt_format @@ -85,6 +86,7 @@ def __init__( self.torch_dist_multiproc = torch_dist_multiproc self.assume_constant_structure = assume_constant_structure self.parallel_save = parallel_save + self.parallel_save_within_dp = parallel_save_within_dp self.parallel_load = parallel_load self._save_sharded_strategy = None @@ -216,8 +218,11 @@ def _determine_dist_ckpt_save_strategy(self): save_strategy.use_cached_ckpt_structure = self.assume_constant_structure if self.parallel_save: + parallelization_group = ( + get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None + ) save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure + save_strategy, parallelization_group, self.assume_constant_structure ) logging.info(f'Using {save_strategy} dist-ckpt save strategy.') diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index d0e502839f2f..093bfeee30b7 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -108,6 +108,7 @@ def __init__( ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere? ckpt_assume_constant_structure=False, ckpt_parallel_save=True, + ckpt_parallel_save_within_dp=False, ckpt_parallel_load=False, ckpt_parallel_save_optim=True, setup_optimizers: bool = True, @@ -143,6 +144,7 @@ def __init__( self.torch_dist_multiproc = ckpt_torch_dist_multiproc self.assume_constant_structure = ckpt_assume_constant_structure self.parallel_save = ckpt_parallel_save + self.parallel_save_within_dp = ckpt_parallel_save_within_dp self.parallel_load = ckpt_parallel_load self.parallel_save_optim = ckpt_parallel_save_optim @@ -578,6 +580,7 @@ def checkpoint_io(self) -> CheckpointIO: torch_dist_multiproc=self.torch_dist_multiproc, assume_constant_structure=self.assume_constant_structure, parallel_save=self.parallel_save, + parallel_save_within_dp=self.parallel_save_within_dp, parallel_load=self.parallel_load, ) if async_save: From 07ac79762249ea5a40a2915b9192a1894a8dc66c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Tue, 9 Jul 2024 14:11:24 +0200 Subject: [PATCH 5/5] Turn on parallel save by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Mikołaj Błaż --- examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 1cc3f6e527d4..1599f38cbfa8 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -176,7 +176,7 @@ model: # Distributed checkpoint setup dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU - dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format