From e7e679827f28c6d00c945616aa167c14d4b60c48 Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 25 Oct 2024 12:50:28 -0700
Subject: [PATCH] update megatron_init.py inside lightning

Signed-off-by: yaoyu-33
---
 nemo/lightning/megatron_init.py | 116 +++++++++++++++++++++++++++-----
 1 file changed, 99 insertions(+), 17 deletions(-)

diff --git a/nemo/lightning/megatron_init.py b/nemo/lightning/megatron_init.py
index c060d140cb8c..f7920a55a959 100644
--- a/nemo/lightning/megatron_init.py
+++ b/nemo/lightning/megatron_init.py
@@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo(
     virtual_pipeline_model_parallel_size=None,
     pipeline_model_parallel_split_rank=None,
     context_parallel_size=1,
+    encoder_tensor_model_parallel_size=0,
+    encoder_pipeline_model_parallel_size=0,
     micro_batch_size=None,
     global_batch_size=None,
     rampup_batch_size=None,
@@ -120,6 +122,8 @@
     app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
     app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
     app_state.context_parallel_size = context_parallel_size
+    app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
+    app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
     app_state.use_fp8 = use_fp8
     app_state.init_mpi_proc_group = init_mpi_proc_group
     (
@@ -139,6 +143,8 @@
         pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
         context_parallel_size_=context_parallel_size,
         expert_model_parallel_size_=expert_model_parallel_size,
+        encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
+        encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
         use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
     )
 
@@ -149,12 +155,14 @@
     set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
     set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)
 
+    set_pipeline_model_parallel_world_size(
+        app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
+    )
+    set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
     set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
     if HAVE_INTERLEAVED:
         set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
         set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank)
-    set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
-    set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
 
     tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
     if seed is not None:
@@ -247,6 +255,8 @@ def fake_initialize_model_parallel(
     virtual_pipeline_model_parallel_size_=None,
     expert_model_parallel_size_=1,
     context_parallel_size_=1,
+    encoder_tensor_model_parallel_size_=0,
+    encoder_pipeline_model_parallel_size_=0,
     use_tp_pp_dp_mapping=False,
 ):
     """
@@ -283,37 +293,109 @@ def fake_initialize_model_parallel(
     model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
     context_parallel_size = min(context_parallel_size_, world_size)
 
-    assert (
-        world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
-    ), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
-    data_parallel_size = world_size // (
-        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+    if encoder_pipeline_model_parallel_size_ is None:
+        encoder_pipeline_model_parallel_size = 0
+    else:
+        encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_
+
+    if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
+        encoder_tensor_model_parallel_size = tensor_model_parallel_size
+    else:
+        encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_
+
+    if encoder_tensor_model_parallel_size > 0:
+        assert encoder_pipeline_model_parallel_size > 0
+        assert (
+            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
+        ), "We do not support encoders with more TP than the decoder."
+
+    encoder_model_size = (
+        encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
+    )
+    decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+    total_model_size = encoder_model_size + decoder_model_size
+
+    assert world_size % total_model_size == 0, (
+        f'world_size: {world_size} must be divisible by total world_size: '
+        f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
+        f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
+        f'* (decoder_)context_parallel_size {context_parallel_size} + '
+        f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
+        f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
+        f'* context_parallel_size {context_parallel_size}'
     )
+    data_parallel_size = world_size // total_model_size
 
-    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
+    encoder_world_size = encoder_model_size * data_parallel_size
+    decoder_world_size = decoder_model_size * data_parallel_size
+    assert encoder_world_size + decoder_world_size == world_size
 
     virtual_pipeline_model_parallel_rank = None
     if virtual_pipeline_model_parallel_size_ is not None:
         virtual_pipeline_model_parallel_rank = 0
 
-    rank_generator = RankGenerator(
+    if encoder_world_size > 0:
+        encoder_rank_generator = RankGenerator(
+            tp=encoder_tensor_model_parallel_size,
+            ep=1,
+            dp=data_parallel_size,
+            pp=encoder_pipeline_model_parallel_size,
+            cp=context_parallel_size,
+            order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+            rank_offset=0,
+        )
+    else:
+        encoder_rank_generator = None
+
+    decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,
         ep=expert_model_parallel_size_,
         dp=data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=context_parallel_size,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
+        rank_offset=encoder_world_size,
     )
 
+    def generator_wrapper(group_type, **kwargs):
+        from itertools import cycle
+
+        """The `RankGenerator` class produces a hyper-rectangle for a given set of
+        tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
+        in addition to the default decoder, we essentially instantiate two `RankGenerator`
+        classes to construct the parallelism for each module separately, and we then have
+        to stitch them together for the right groups. For now, this means pp and tp-pp."""
+        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
+        if encoder_rank_generator is None:
+            for x in d_ranks:
+                yield x
+            return
+        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
+        if group_type == 'pp':
+            # Map 1 encoder tp rank to several decoder tp ranks, because
+            # these won't be the same size.
+            for x, y in zip(cycle(e_ranks), d_ranks):
+                yield x + y
+        elif group_type == 'tp-pp':
+            # For this group, we can just return the concatenated
+            # groups together, because their sizes are the same.
+            assert len(e_ranks) == len(d_ranks)
+            for x, y in zip(e_ranks, d_ranks):
+                yield x + y
+        else:
+            for x in e_ranks:
+                yield x
+            for x in d_ranks:
+                yield x
+
     # Build the data-parallel groups.
     all_data_parallel_group_ranks_with_cp = []
-    for ranks in rank_generator.get_ranks('dp'):
+    for ranks in generator_wrapper('pp'):
         if rank in ranks:
             data_parallel_group = list(ranks)
             logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')
 
-    for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
+    for ranks_with_cp in generator_wrapper('dp-cp'):
         all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
         if rank in ranks_with_cp:
             data_parallel_group_with_cp = ranks_with_cp
@@ -329,7 +411,7 @@ def fake_initialize_model_parallel(
 
     # Build the context-parallel groups.
     all_context_parallel_group_ranks = []
-    for ranks in rank_generator.get_ranks('cp'):
+    for ranks in generator_wrapper('cp'):
         all_context_parallel_group_ranks.append(ranks)
         if rank in ranks:
             context_parallel_group = ranks
@@ -341,7 +423,7 @@ def fake_initialize_model_parallel(
 
     # Build the model-parallel groups.
     all_model_parallel_group_ranks = []
-    for ranks in rank_generator.get_ranks('tp-pp'):
+    for ranks in generator_wrapper('tp-pp'):
        all_model_parallel_group_ranks.append(ranks)
         if rank in ranks:
             logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
@@ -350,7 +432,7 @@ def fake_initialize_model_parallel(
     # Build the tensor model-parallel groups.
     all_tensor_model_parallel_group_ranks = []
     tensor_model_parallel_group = None
-    for ranks in rank_generator.get_ranks('tp'):
+    for ranks in generator_wrapper('tp'):
         all_tensor_model_parallel_group_ranks.append(ranks)
         if rank in ranks:
             tensor_model_parallel_group = ranks
@@ -364,7 +446,7 @@ def fake_initialize_model_parallel(
     # EP rank
     expert_model_parallel_rank = 0
     if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
-        for ranks in rank_generator.get_ranks('ep', independent_ep=True):
+        for ranks in generator_wrapper('ep', independent_ep=True):
             if rank in ranks:
                 expert_model_parallel_rank = list(ranks).index(rank)
 
@@ -375,7 +457,7 @@ def fake_initialize_model_parallel(
     pipeline_model_parallel_group = None
     embedding_group = None
     embedding_rank = None
-    for ranks in rank_generator.get_ranks('pp'):
+    for ranks in generator_wrapper('pp'):
         all_pipeline_model_parallel_group_ranks.append(ranks)
         if rank in ranks:
             pipeline_model_parallel_group = ranks
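The new divisibility check above splits the world into an encoder block followed by a decoder block: world_size must be a multiple of encoder_model_size + decoder_model_size, and the encoder occupies the first encoder_world_size ranks because the decoder RankGenerator is created with rank_offset=encoder_world_size. A minimal standalone sketch of that accounting (the helper name and the example sizes are illustrative, not part of the patch):

def split_world(world_size, tp, pp, cp, enc_tp, enc_pp):
    # Mirrors the size accounting in the hunk above.
    encoder_model_size = enc_tp * enc_pp * cp
    decoder_model_size = tp * pp * cp
    total_model_size = encoder_model_size + decoder_model_size
    assert world_size % total_model_size == 0, (world_size, total_model_size)
    data_parallel_size = world_size // total_model_size
    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size
    assert encoder_world_size + decoder_world_size == world_size
    return data_parallel_size, encoder_world_size, decoder_world_size

# 12 ranks, decoder TP=2/PP=2/CP=1, encoder TP=2/PP=1:
# encoder_model_size=2 and decoder_model_size=4, so DP=2,
# encoder ranks 0-3 come first, decoder ranks 4-11 follow.
print(split_world(12, tp=2, pp=2, cp=1, enc_tp=2, enc_pp=1))  # (2, 4, 8)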
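generator_wrapper then stitches the two rank hyper-rectangles back together. For 'pp' groups the encoder typically has fewer groups than the decoder (its TP can be smaller), so each decoder pipeline group is paired with an encoder group drawn cyclically; for 'tp-pp' the group counts match and the lists are concatenated pairwise. A small sketch of the 'pp' branch with hand-written rank lists (real ones come from RankGenerator.get_ranks, and the exact pairing shown is only illustrative):

from itertools import cycle

def stitch_pp(e_ranks, d_ranks):
    # Pair each decoder pp-group with an encoder pp-group, reusing the
    # (smaller) encoder side cyclically, as in the 'pp' branch above.
    return [x + y for x, y in zip(cycle(e_ranks), d_ranks)]

# Hypothetical 10-rank layout: encoder TP=1/PP=1 and decoder TP=2/PP=2 with DP=2.
# Encoder ranks 0-1 each form a one-rank pp-group; decoder ranks 2-9 form four
# two-rank pp-groups. Every stitched group spans the full encoder+decoder pipeline.
e_ranks = [[0], [1]]
d_ranks = [[2, 6], [3, 7], [4, 8], [5, 9]]
print(stitch_pp(e_ranks, d_ranks))
# [[0, 2, 6], [1, 3, 7], [0, 4, 8], [1, 5, 9]]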