Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Parametrize FPS group #9648

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ model:
# Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
dist_ckpt_parallel_save_within_dp: False # if true, the save is parallelized only within each data-parallel (DP) group; otherwise it is parallelized across the whole world. Restricting it to the DP group might slightly reduce the save overhead
dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
Expand Down
7 changes: 6 additions & 1 deletion nemo/lightning/io/pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = True,
parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
self.save_ckpt_format = save_ckpt_format
Expand All @@ -85,6 +86,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
Expand Down Expand Up @@ -216,8 +218,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
parallelization_group = (
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
Expand Down
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure=False,
ckpt_parallel_save=True,
ckpt_parallel_save_within_dp=False,
ckpt_parallel_load=False,
ckpt_parallel_save_optim=True,
setup_optimizers: bool = True,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
self.torch_dist_multiproc = ckpt_torch_dist_multiproc
self.assume_constant_structure = ckpt_assume_constant_structure
self.parallel_save = ckpt_parallel_save
self.parallel_save_within_dp = ckpt_parallel_save_within_dp
self.parallel_load = ckpt_parallel_load
self.parallel_save_optim = ckpt_parallel_save_optim

Expand Down Expand Up @@ -578,6 +580,7 @@ def checkpoint_io(self) -> CheckpointIO:
torch_dist_multiproc=self.torch_dist_multiproc,
assume_constant_structure=self.assume_constant_structure,
parallel_save=self.parallel_save,
parallel_save_within_dp=self.parallel_save_within_dp,
parallel_load=self.parallel_load,
)
if async_save:
Expand Down
8 changes: 7 additions & 1 deletion nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = False,
parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
super().__init__()
Expand All @@ -218,6 +219,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
Expand All @@ -239,6 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
)

Expand Down Expand Up @@ -377,8 +380,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
parallelization_group = (
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
Expand Down
Loading