From 85e14ca6c4d357a85d611ce5322f26ae206c3a46 Mon Sep 17 00:00:00 2001 From: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Date: Tue, 29 Oct 2024 20:30:35 -0700 Subject: [PATCH] Update mcore parallelism initialization in nemo2 (#10643) * update mcore parallelism initialization Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Update megatron_init.py Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> * add encoder parallel default config Signed-off-by: yaoyu-33 * Fix _strategy_lib.py Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> * update megatron_init.py inside lightning Signed-off-by: yaoyu-33 * fix test Signed-off-by: yaoyu-33 * try fix test Signed-off-by: yaoyu-33 * try fix test Signed-off-by: yaoyu-33 * Fix megatron megatron_init.py dp Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> * Update lightning megatron_init.py dp Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> --------- Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: yaoyu-33 Co-authored-by: Pablo Garay --- .../modules/common/megatron/megatron_init.py | 116 +++++++++++++++--- nemo/lightning/_strategy_lib.py | 4 + nemo/lightning/megatron_init.py | 116 +++++++++++++++--- .../pytorch/strategies/megatron_strategy.py | 8 ++ nemo/utils/app_state.py | 66 ++++++++++ tests/lightning/test_strategy_lib.py | 6 + 6 files changed, 282 insertions(+), 34 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index c060d140cb8c..10b939d4aecb 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo( virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, context_parallel_size=1, + encoder_tensor_model_parallel_size=0, + encoder_pipeline_model_parallel_size=0, micro_batch_size=None, global_batch_size=None, rampup_batch_size=None, @@ -120,6 +122,8 @@ def initialize_model_parallel_for_nemo( app_state.pipeline_model_parallel_size = pipeline_model_parallel_size app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size app_state.context_parallel_size = context_parallel_size + app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size + app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size app_state.use_fp8 = use_fp8 app_state.init_mpi_proc_group = init_mpi_proc_group ( @@ -139,6 +143,8 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, context_parallel_size_=context_parallel_size, expert_model_parallel_size_=expert_model_parallel_size, + encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size, + encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size, use_tp_pp_dp_mapping=use_tp_pp_dp_mapping, ) @@ -149,12 +155,14 @@ def initialize_model_parallel_for_nemo( set_expert_model_parallel_world_size(app_state.expert_model_parallel_size) set_expert_model_parallel_rank(app_state.expert_model_parallel_rank) + set_pipeline_model_parallel_world_size( + app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size + ) + 
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank) set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank) if HAVE_INTERLEAVED: set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size) set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank) - set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size) - set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank) tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker) if seed is not None: @@ -247,6 +255,8 @@ def fake_initialize_model_parallel( virtual_pipeline_model_parallel_size_=None, expert_model_parallel_size_=1, context_parallel_size_=1, + encoder_tensor_model_parallel_size_=0, + encoder_pipeline_model_parallel_size_=0, use_tp_pp_dp_mapping=False, ): """ @@ -283,37 +293,109 @@ def fake_initialize_model_parallel( model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size context_parallel_size = min(context_parallel_size_, world_size) - assert ( - world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0 - ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}' - data_parallel_size = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + if encoder_pipeline_model_parallel_size_ is None: + encoder_pipeline_model_parallel_size = 0 + else: + encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_ + + if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0: + encoder_tensor_model_parallel_size = tensor_model_parallel_size + else: + encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_ + + if encoder_tensor_model_parallel_size > 0: + assert encoder_pipeline_model_parallel_size > 0 + assert ( + encoder_tensor_model_parallel_size <= tensor_model_parallel_size + ), "We do not support encoders with more TP than the decoder." 
+ + encoder_model_size = ( + encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size + ) + decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + total_model_size = encoder_model_size + decoder_model_size + + assert world_size % total_model_size == 0, ( + f'world_size: {world_size} must be divisible by total world_size: ' + f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} ' + f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} ' + f'* (decoder_)context_parallel_size {context_parallel_size} + ' + f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} ' + f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} ' + f'* context_parallel_size {context_parallel_size}' ) + data_parallel_size = world_size // total_model_size - num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + encoder_world_size = encoder_model_size * data_parallel_size + decoder_world_size = decoder_model_size * data_parallel_size + assert encoder_world_size + decoder_world_size == world_size virtual_pipeline_model_parallel_rank = None if virtual_pipeline_model_parallel_size_ is not None: virtual_pipeline_model_parallel_rank = 0 - rank_generator = RankGenerator( + if encoder_world_size > 0: + encoder_rank_generator = RankGenerator( + tp=encoder_tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=encoder_pipeline_model_parallel_size, + cp=context_parallel_size, + order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', + rank_offset=0, + ) + else: + encoder_rank_generator = None + + decoder_rank_generator = RankGenerator( tp=tensor_model_parallel_size, ep=expert_model_parallel_size_, dp=data_parallel_size, pp=pipeline_model_parallel_size, cp=context_parallel_size, order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', + rank_offset=encoder_world_size, ) + def generator_wrapper(group_type, **kwargs): + from itertools import cycle + + """The `RankGenerator` class produces a hyper-rectangle for a given set of + tensor, pipeline, data, expert, and context parallelism. If we have an encoder, + in addition to the default decoder, we essentially instantiate two `RankGenerator` + classes to construct the parallelism for each module separately, and we then have + to stitch them together for the right groups. For now, this means pp and tp-pp.""" + d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs) + if encoder_rank_generator is None: + for x in d_ranks: + yield x + return + e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs) + if group_type == 'pp': + # Map 1 encoder tp rank to several decoder tp ranks, because + # these won't be the same size. + for x, y in zip(cycle(e_ranks), d_ranks): + yield x + y + elif group_type == 'tp-pp': + # For this group, we can just return the concatenated + # groups together, because their sizes are the same. + assert len(e_ranks) == len(d_ranks) + for x, y in zip(e_ranks, d_ranks): + yield x + y + else: + for x in e_ranks: + yield x + for x in d_ranks: + yield x + # Build the data-parallel groups. 
all_data_parallel_group_ranks_with_cp = [] - for ranks in rank_generator.get_ranks('dp'): + for ranks in generator_wrapper('dp'): if rank in ranks: data_parallel_group = list(ranks) logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}') - for ranks_with_cp in rank_generator.get_ranks('dp-cp'): + for ranks_with_cp in generator_wrapper('dp-cp'): all_data_parallel_group_ranks_with_cp.append(ranks_with_cp) if rank in ranks_with_cp: data_parallel_group_with_cp = ranks_with_cp @@ -329,7 +411,7 @@ def fake_initialize_model_parallel( # Build the context-parallel groups. all_context_parallel_group_ranks = [] - for ranks in rank_generator.get_ranks('cp'): + for ranks in generator_wrapper('cp'): all_context_parallel_group_ranks.append(ranks) if rank in ranks: context_parallel_group = ranks @@ -341,7 +423,7 @@ def fake_initialize_model_parallel( # Build the model-parallel groups. all_model_parallel_group_ranks = [] - for ranks in rank_generator.get_ranks('tp-pp'): + for ranks in generator_wrapper('tp-pp'): all_model_parallel_group_ranks.append(ranks) if rank in ranks: logging.info(f'Rank {rank} has model parallel group: {list(ranks)}') @@ -350,7 +432,7 @@ def fake_initialize_model_parallel( # Build the tensor model-parallel groups. all_tensor_model_parallel_group_ranks = [] tensor_model_parallel_group = None - for ranks in rank_generator.get_ranks('tp'): + for ranks in generator_wrapper('tp'): all_tensor_model_parallel_group_ranks.append(ranks) if rank in ranks: tensor_model_parallel_group = ranks @@ -364,7 +446,7 @@ def fake_initialize_model_parallel( # EP rank expert_model_parallel_rank = 0 if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: - for ranks in rank_generator.get_ranks('ep', independent_ep=True): + for ranks in generator_wrapper('ep', independent_ep=True): if rank in ranks: expert_model_parallel_rank = list(ranks).index(rank) @@ -375,7 +457,7 @@ def fake_initialize_model_parallel( pipeline_model_parallel_group = None embedding_group = None embedding_rank = None - for ranks in rank_generator.get_ranks('pp'): + for ranks in generator_wrapper('pp'): all_pipeline_model_parallel_group_ranks.append(ranks) if rank in ranks: pipeline_model_parallel_group = ranks diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index ba4847219ed3..40a79c94c59f 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -84,6 +84,8 @@ def init_parallel_ranks( pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size, context_parallel_size=parallel_config.context_parallel_size, + encoder_tensor_model_parallel_size=getattr(parallel_config, "encoder_tensor_model_parallel_size", 0), + encoder_pipeline_model_parallel_size=getattr(parallel_config, "encoder_pipeline_model_parallel_size", 0), seed=seed, pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None), use_fp8=fp8, @@ -113,6 +115,8 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None: pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + encoder_pipeline_model_parallel_size=app_state.encoder_pipeline_model_parallel_size, + encoder_tensor_model_parallel_size=app_state.encoder_tensor_model_parallel_size, 
context_parallel_size=app_state.context_parallel_size, expert_model_parallel_size=app_state.expert_model_parallel_size, ) diff --git a/nemo/lightning/megatron_init.py b/nemo/lightning/megatron_init.py index c060d140cb8c..10b939d4aecb 100644 --- a/nemo/lightning/megatron_init.py +++ b/nemo/lightning/megatron_init.py @@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo( virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, context_parallel_size=1, + encoder_tensor_model_parallel_size=0, + encoder_pipeline_model_parallel_size=0, micro_batch_size=None, global_batch_size=None, rampup_batch_size=None, @@ -120,6 +122,8 @@ def initialize_model_parallel_for_nemo( app_state.pipeline_model_parallel_size = pipeline_model_parallel_size app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size app_state.context_parallel_size = context_parallel_size + app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size + app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size app_state.use_fp8 = use_fp8 app_state.init_mpi_proc_group = init_mpi_proc_group ( @@ -139,6 +143,8 @@ def initialize_model_parallel_for_nemo( pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, context_parallel_size_=context_parallel_size, expert_model_parallel_size_=expert_model_parallel_size, + encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size, + encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size, use_tp_pp_dp_mapping=use_tp_pp_dp_mapping, ) @@ -149,12 +155,14 @@ def initialize_model_parallel_for_nemo( set_expert_model_parallel_world_size(app_state.expert_model_parallel_size) set_expert_model_parallel_rank(app_state.expert_model_parallel_rank) + set_pipeline_model_parallel_world_size( + app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size + ) + set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank) set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank) if HAVE_INTERLEAVED: set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size) set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank) - set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size) - set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank) tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker) if seed is not None: @@ -247,6 +255,8 @@ def fake_initialize_model_parallel( virtual_pipeline_model_parallel_size_=None, expert_model_parallel_size_=1, context_parallel_size_=1, + encoder_tensor_model_parallel_size_=0, + encoder_pipeline_model_parallel_size_=0, use_tp_pp_dp_mapping=False, ): """ @@ -283,37 +293,109 @@ def fake_initialize_model_parallel( model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size context_parallel_size = min(context_parallel_size_, world_size) - assert ( - world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0 - ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}' - data_parallel_size = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + if 
encoder_pipeline_model_parallel_size_ is None: + encoder_pipeline_model_parallel_size = 0 + else: + encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_ + + if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0: + encoder_tensor_model_parallel_size = tensor_model_parallel_size + else: + encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_ + + if encoder_tensor_model_parallel_size > 0: + assert encoder_pipeline_model_parallel_size > 0 + assert ( + encoder_tensor_model_parallel_size <= tensor_model_parallel_size + ), "We do not support encoders with more TP than the decoder." + + encoder_model_size = ( + encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size + ) + decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + total_model_size = encoder_model_size + decoder_model_size + + assert world_size % total_model_size == 0, ( + f'world_size: {world_size} must be divisible by total world_size: ' + f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} ' + f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} ' + f'* (decoder_)context_parallel_size {context_parallel_size} + ' + f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} ' + f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} ' + f'* context_parallel_size {context_parallel_size}' ) + data_parallel_size = world_size // total_model_size - num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size + encoder_world_size = encoder_model_size * data_parallel_size + decoder_world_size = decoder_model_size * data_parallel_size + assert encoder_world_size + decoder_world_size == world_size virtual_pipeline_model_parallel_rank = None if virtual_pipeline_model_parallel_size_ is not None: virtual_pipeline_model_parallel_rank = 0 - rank_generator = RankGenerator( + if encoder_world_size > 0: + encoder_rank_generator = RankGenerator( + tp=encoder_tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=encoder_pipeline_model_parallel_size, + cp=context_parallel_size, + order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', + rank_offset=0, + ) + else: + encoder_rank_generator = None + + decoder_rank_generator = RankGenerator( tp=tensor_model_parallel_size, ep=expert_model_parallel_size_, dp=data_parallel_size, pp=pipeline_model_parallel_size, cp=context_parallel_size, order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp', + rank_offset=encoder_world_size, ) + def generator_wrapper(group_type, **kwargs): + from itertools import cycle + + """The `RankGenerator` class produces a hyper-rectangle for a given set of + tensor, pipeline, data, expert, and context parallelism. If we have an encoder, + in addition to the default decoder, we essentially instantiate two `RankGenerator` + classes to construct the parallelism for each module separately, and we then have + to stitch them together for the right groups. For now, this means pp and tp-pp.""" + d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs) + if encoder_rank_generator is None: + for x in d_ranks: + yield x + return + e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs) + if group_type == 'pp': + # Map 1 encoder tp rank to several decoder tp ranks, because + # these won't be the same size. 
+ for x, y in zip(cycle(e_ranks), d_ranks): + yield x + y + elif group_type == 'tp-pp': + # For this group, we can just return the concatenated + # groups together, because their sizes are the same. + assert len(e_ranks) == len(d_ranks) + for x, y in zip(e_ranks, d_ranks): + yield x + y + else: + for x in e_ranks: + yield x + for x in d_ranks: + yield x + # Build the data-parallel groups. all_data_parallel_group_ranks_with_cp = [] - for ranks in rank_generator.get_ranks('dp'): + for ranks in generator_wrapper('dp'): if rank in ranks: data_parallel_group = list(ranks) logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}') - for ranks_with_cp in rank_generator.get_ranks('dp-cp'): + for ranks_with_cp in generator_wrapper('dp-cp'): all_data_parallel_group_ranks_with_cp.append(ranks_with_cp) if rank in ranks_with_cp: data_parallel_group_with_cp = ranks_with_cp @@ -329,7 +411,7 @@ def fake_initialize_model_parallel( # Build the context-parallel groups. all_context_parallel_group_ranks = [] - for ranks in rank_generator.get_ranks('cp'): + for ranks in generator_wrapper('cp'): all_context_parallel_group_ranks.append(ranks) if rank in ranks: context_parallel_group = ranks @@ -341,7 +423,7 @@ def fake_initialize_model_parallel( # Build the model-parallel groups. all_model_parallel_group_ranks = [] - for ranks in rank_generator.get_ranks('tp-pp'): + for ranks in generator_wrapper('tp-pp'): all_model_parallel_group_ranks.append(ranks) if rank in ranks: logging.info(f'Rank {rank} has model parallel group: {list(ranks)}') @@ -350,7 +432,7 @@ def fake_initialize_model_parallel( # Build the tensor model-parallel groups. all_tensor_model_parallel_group_ranks = [] tensor_model_parallel_group = None - for ranks in rank_generator.get_ranks('tp'): + for ranks in generator_wrapper('tp'): all_tensor_model_parallel_group_ranks.append(ranks) if rank in ranks: tensor_model_parallel_group = ranks @@ -364,7 +446,7 @@ def fake_initialize_model_parallel( # EP rank expert_model_parallel_rank = 0 if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: - for ranks in rank_generator.get_ranks('ep', independent_ep=True): + for ranks in generator_wrapper('ep', independent_ep=True): if rank in ranks: expert_model_parallel_rank = list(ranks).index(rank) @@ -375,7 +457,7 @@ def fake_initialize_model_parallel( pipeline_model_parallel_group = None embedding_group = None embedding_rank = None - for ranks in rank_generator.get_ranks('pp'): + for ranks in generator_wrapper('pp'): all_pipeline_model_parallel_group_ranks.append(ranks) if rank in ranks: pipeline_model_parallel_group = ranks diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index c61c3371cc3c..c22df7cc9dfe 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -97,6 +97,8 @@ class ParallelismConfig: expert_model_parallel_size: int moe_extended_tp: bool pipeline_dtype: torch.dtype + encoder_tensor_model_parallel_size: int = 0 + encoder_pipeline_model_parallel_size: int = 0 class MegatronStrategy(DDPStrategy, io.IOMixin): @@ -177,6 +179,8 @@ def __init__( sequence_parallel: bool = False, expert_model_parallel_size: int = 1, moe_extended_tp: bool = False, + encoder_tensor_model_parallel_size: Optional[int] = 0, + encoder_pipeline_model_parallel_size: Optional[int] = 0, data_sampler: Optional["DataSampler"] = None, parallel_devices: Optional[List[torch.device]] = None, 
cluster_environment=None, # TODO: Add type-hint @@ -220,6 +224,8 @@ def __init__( self.moe_extended_tp = moe_extended_tp self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel + self.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size + self.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size self.lazy_init = lazy_init self.ckpt_load_optimizer = ckpt_load_optimizer self.ckpt_save_optimizer = ckpt_save_optimizer @@ -821,6 +827,8 @@ def parallelism(self) -> ParallelismConfig: sequence_parallel=self.sequence_parallel, expert_model_parallel_size=self.expert_model_parallel_size, moe_extended_tp=self.moe_extended_tp, + encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size, + encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size, pipeline_dtype=self.pipeline_dtype, ) diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 7a60c3969df3..37193cfdd8c5 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -50,6 +50,8 @@ def __init__(self): self._expert_model_parallel_size = None self._pipeline_model_parallel_size = None self._virtual_pipeline_model_parallel_size = None + self._encoder_tensor_model_parallel_size = None + self._encoder_pipeline_model_parallel_size = None self._pipeline_model_parallel_group = None self._pipeline_model_parallel_split_rank = None self._is_megatron_initialized = False @@ -200,6 +202,38 @@ def pipeline_model_parallel_size(self, size): """ self._pipeline_model_parallel_size = size + @property + def encoder_tensor_model_parallel_size(self): + """Property returns the number of GPUs in each model parallel group. + Returns: + Number of GPUs in each model parallel group. + """ + return self._encoder_tensor_model_parallel_size + + @encoder_tensor_model_parallel_size.setter + def encoder_tensor_model_parallel_size(self, size): + """Property sets the number of GPUs in each model parallel group. + Args: + size (int): Number of GPUs in each model parallel group. + """ + self._encoder_tensor_model_parallel_size = size + + @property + def encoder_pipeline_model_parallel_size(self): + """Property returns the number of GPUs in each model parallel group. + Returns: + Number of GPUs in each model parallel group. + """ + return self._encoder_pipeline_model_parallel_size + + @encoder_pipeline_model_parallel_size.setter + def encoder_pipeline_model_parallel_size(self, size): + """Property sets the number of GPUs in each model parallel group. + Args: + size (int): Number of GPUs in each model parallel group. + """ + self._encoder_pipeline_model_parallel_size = size + @property def use_tp_pp_dp_mapping(self): return self._use_tp_pp_dp_mapping @@ -336,6 +370,38 @@ def virtual_pipeline_model_parallel_rank(self, rank): """ self._virtual_pipeline_model_parallel_rank = rank + @property + def encoder_tensor_model_parallel_rank(self): + """Property returns the encoder tensor model parallel rank. + Returns: + Tensor model parallel rank. + """ + return self._encoder_tensor_model_parallel_rank + + @encoder_tensor_model_parallel_rank.setter + def encoder_tensor_model_parallel_rank(self, rank): + """Property sets the encoder tensor model parallel rank. + Args: + rank (int): Tensor model parallel rank. + """ + self._encoder_tensor_model_parallel_rank = rank + + @property + def encoder_pipeline_model_parallel_rank(self): + """Property returns the encoder pipeline model parallel rank. + Returns: + Tensor model parallel rank. 
+ """ + return self._encoder_pipeline_model_parallel_rank + + @encoder_pipeline_model_parallel_rank.setter + def encoder_pipeline_model_parallel_rank(self, rank): + """Property sets the encoder pipeline model parallel rank. + Args: + rank (int): Tensor model parallel rank. + """ + self._encoder_pipeline_model_parallel_rank = rank + @property def pipeline_model_parallel_split_rank(self): """Property returns the rank at which Encoder and Decoder are split into different pipelines for Megatrron Encoder-Decoder models. diff --git a/tests/lightning/test_strategy_lib.py b/tests/lightning/test_strategy_lib.py index 4410d0b1b910..241debd16316 100644 --- a/tests/lightning/test_strategy_lib.py +++ b/tests/lightning/test_strategy_lib.py @@ -78,6 +78,8 @@ def test_init_parallel_ranks() -> None: mock_parallel_config.virtual_pipeline_model_parallel_size = 4 mock_parallel_config.context_parallel_size = 2 mock_parallel_config.expert_model_parallel_size = 2 + mock_parallel_config.encoder_tensor_model_parallel_size = 0 + mock_parallel_config.encoder_pipeline_model_parallel_size = 0 mock_parallel_config.tp_comm_overlap = False mock_parallel_config.pipeline_model_parallel_split_rank = None @@ -99,6 +101,8 @@ def test_init_parallel_ranks() -> None: "context_parallel_size": 2, "expert_model_parallel_size": 2, "pipeline_model_parallel_split_rank": None, + "encoder_pipeline_model_parallel_size": 0, + "encoder_tensor_model_parallel_size": 0, "use_fp8": False, "init_mpi_proc_group": False, } @@ -135,6 +139,8 @@ def test_init_model_parallel(mock_mpu, *args): pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, + encoder_pipeline_model_parallel_size=None, + encoder_tensor_model_parallel_size=None, context_parallel_size=2, expert_model_parallel_size=2, )
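
The divisibility check added to fake_initialize_model_parallel treats the encoder and decoder as two separately sized models that share a single data-parallel dimension. The arithmetic can be checked in isolation with the small sketch below; the sizes are illustrative assumptions, not values taken from the patch.

# Standalone sketch of the world-size bookkeeping performed by
# fake_initialize_model_parallel when an encoder is configured.
# All sizes here are toy values chosen for illustration.

world_size = 10
tensor_model_parallel_size = 4            # decoder TP
pipeline_model_parallel_size = 2          # decoder PP
context_parallel_size = 1
encoder_tensor_model_parallel_size = 2    # must be <= decoder TP
encoder_pipeline_model_parallel_size = 1

encoder_model_size = (
    encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
)                                                              # 2
decoder_model_size = (
    tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)                                                              # 8
total_model_size = encoder_model_size + decoder_model_size     # 10

assert world_size % total_model_size == 0                      # new divisibility rule
data_parallel_size = world_size // total_model_size            # 1
encoder_world_size = encoder_model_size * data_parallel_size   # ranks 0-1 host the encoder
decoder_world_size = decoder_model_size * data_parallel_size   # ranks 2-9 host the decoder
assert encoder_world_size + decoder_world_size == world_size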
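
generator_wrapper stitches together the groups produced by the encoder and decoder RankGenerator instances. For 'pp' it reuses each (smaller) set of encoder groups across the more numerous decoder groups via itertools.cycle, so every pipeline group runs encoder stages first and decoder stages after; for 'tp-pp' both generators emit the same number of groups, which are concatenated pairwise. A toy illustration of that stitching follows, using the same assumed sizes as the sketch above with hand-written rank lists rather than real RankGenerator output.

from itertools import cycle

# Encoder: TP=2, PP=1, rank_offset=0  -> ranks 0-1, one pp group per encoder TP rank.
e_pp_ranks = [[0], [1]]
# Decoder: TP=4, PP=2, rank_offset=2  -> ranks 2-9, one pp group per decoder TP rank.
d_pp_ranks = [[2, 6], [3, 7], [4, 8], [5, 9]]

# 'pp': cycle the encoder groups over the (more numerous) decoder groups.
pp_groups = [e + d for e, d in zip(cycle(e_pp_ranks), d_pp_ranks)]
# -> [[0, 2, 6], [1, 3, 7], [0, 4, 8], [1, 5, 9]]

# 'tp-pp': same number of groups on both sides, so concatenate pairwise.
e_mp_ranks = [[0, 1]]
d_mp_ranks = [[2, 3, 4, 5, 6, 7, 8, 9]]
model_parallel_groups = [e + d for e, d in zip(e_mp_ranks, d_mp_ranks)]
# -> [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]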
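
On the NeMo 2.0 side, the new sizes are plain constructor arguments on MegatronStrategy and flow through ParallelismConfig and _strategy_lib.init_parallel_ranks into initialize_model_parallel_for_nemo. A hypothetical configuration sketch is shown below; the surrounding trainer and model wiring is omitted and the sizes are again illustrative, not taken from the patch.

from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

# Decoder uses TP=4, PP=2; the encoder gets its own, smaller TP group and one
# extra pipeline stage placed ahead of the decoder stages.
strategy = MegatronStrategy(
    tensor_model_parallel_size=4,
    pipeline_model_parallel_size=2,
    encoder_tensor_model_parallel_size=2,    # must not exceed the decoder TP size
    encoder_pipeline_model_parallel_size=1,  # counted into the total PP world size
)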