Commit 33445e5

Update.

Victarry committed Nov 26, 2024
1 parent ef5cffc
Showing 1 changed file with 9 additions and 7 deletions.

nemo/collections/nlp/modules/common/megatron/megatron_init.py (16 changes: 9 additions & 7 deletions)
@@ -254,6 +254,7 @@ def fake_initialize_model_parallel(
     pipeline_model_parallel_split_rank_=None,
     virtual_pipeline_model_parallel_size_=None,
     expert_model_parallel_size_=1,
+    expert_tensor_parallel_size_=None,
     context_parallel_size_=1,
     encoder_tensor_model_parallel_size_=0,
     encoder_pipeline_model_parallel_size_=0,
@@ -349,18 +350,18 @@ def fake_initialize_model_parallel(
 
     decoder_rank_generator = RankGenerator(
         tp=tensor_model_parallel_size,
-        ep=expert_model_parallel_size_,
+        ep=1,
         dp=data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=context_parallel_size,
         order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )
     # Build expert rank generator
-    if expert_tensor_parallel_size is None:
-        expert_tensor_parallel_size = tensor_model_parallel_size
+    if expert_tensor_parallel_size_ is None:
+        expert_tensor_parallel_size_ = tensor_model_parallel_size
     expert_tensor_model_pipeline_parallel_size = (
-        expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size
+        expert_tensor_parallel_size_ * expert_model_parallel_size_ * pipeline_model_parallel_size
     )
     expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
     if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
@@ -370,12 +371,12 @@ def fake_initialize_model_parallel(
 
     # TODO: support expert specific ordering
     expert_decoder_rank_generator = RankGenerator(
-        tp=expert_tensor_parallel_size,
-        ep=expert_model_parallel_size,
+        tp=expert_tensor_parallel_size_,
+        ep=expert_model_parallel_size_,
         dp=expert_data_parallel_size,
         pp=pipeline_model_parallel_size,
         cp=1,
-        order=order,
+        order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
         rank_offset=encoder_world_size,
     )
 
@@ -391,6 +392,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
     in addition to the default decoder, we essentially instantiate two `RankGenerator`
     classes to construct the parallelism for each module separately, and we then have
     to stitch them together for the right groups. For now, this means pp and tp-pp."""
+    from itertools import cycle
     if is_expert:
         d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
     else:
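
In short, the commit stops threading expert parallelism through the dense decoder's rank grid (ep is now hard-coded to 1 there), standardizes on the trailing-underscore argument names, and derives the expert data-parallel size from whatever remains of the decoder world once expert TP, EP, and PP are laid out. A minimal sketch of that arithmetic, with made-up world sizes for illustration (only the variable names follow the diff; this is not the actual NeMo code):

# Sketch of the grouping math this commit sets up; the sizes below
# are illustrative examples, not values from the repository.
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
context_parallel_size = 1
expert_model_parallel_size_ = 2
expert_tensor_parallel_size_ = None  # new argument; None falls back to dense TP

decoder_world_size = 16

# Dense decoder grid: with ep fixed to 1, EP no longer fragments the
# data-parallel dimension of the non-expert layers.
data_parallel_size = decoder_world_size // (
    tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)  # 16 // 4 = 4

# Expert grid: its own TP/EP sizes, with DP derived from the remainder.
if expert_tensor_parallel_size_ is None:
    expert_tensor_parallel_size_ = tensor_model_parallel_size
expert_tensor_model_pipeline_parallel_size = (
    expert_tensor_parallel_size_ * expert_model_parallel_size_ * pipeline_model_parallel_size
)  # 2 * 2 * 2 = 8
assert decoder_world_size % expert_tensor_model_pipeline_parallel_size == 0
expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
print(expert_data_parallel_size)  # 2

Both generators cover the same decoder_world_size ranks, just partitioned differently, which is why generator_wrapper has to stitch the two layouts together for the pp and tp-pp groups.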