Skip to content

Commit

Permalink
Update mcore parallelism initialization in nemo2 (#10643)
Browse files Browse the repository at this point in the history
* update mcore parallelism initialization

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* Update megatron_init.py

Signed-off-by: Yu Yao <[email protected]>

* add encoder parallel default config

Signed-off-by: yaoyu-33 <[email protected]>

* Fix _strategy_lib.py

Signed-off-by: Yu Yao <[email protected]>

* update megatron_init.py inside lightning

Signed-off-by: yaoyu-33 <[email protected]>

* fix test

Signed-off-by: yaoyu-33 <[email protected]>

* try fix test

Signed-off-by: yaoyu-33 <[email protected]>

* try fix test

Signed-off-by: yaoyu-33 <[email protected]>

* Fix megatron megatron_init.py dp

Signed-off-by: Yu Yao <[email protected]>

* Update lightning megatron_init.py dp

Signed-off-by: Yu Yao <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: Yu Yao <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent a8fd3d6 commit 85e14ca
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 34 deletions.
116 changes: 99 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo(
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
context_parallel_size=1,
encoder_tensor_model_parallel_size=0,
encoder_pipeline_model_parallel_size=0,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
Expand All @@ -120,6 +122,8 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.context_parallel_size = context_parallel_size
app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
Expand All @@ -139,6 +143,8 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
)

Expand All @@ -149,12 +155,14 @@ def initialize_model_parallel_for_nemo(
set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

set_pipeline_model_parallel_world_size(
app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
if HAVE_INTERLEAVED:
set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank)
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
Expand Down Expand Up @@ -247,6 +255,8 @@ def fake_initialize_model_parallel(
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
encoder_tensor_model_parallel_size_=0,
encoder_pipeline_model_parallel_size_=0,
use_tp_pp_dp_mapping=False,
):
"""
Expand Down Expand Up @@ -283,37 +293,109 @@ def fake_initialize_model_parallel(
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
context_parallel_size = min(context_parallel_size_, world_size)

assert (
world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
data_parallel_size = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
if encoder_pipeline_model_parallel_size_ is None:
encoder_pipeline_model_parallel_size = 0
else:
encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
encoder_tensor_model_parallel_size = tensor_model_parallel_size
else:
encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

if encoder_tensor_model_parallel_size > 0:
assert encoder_pipeline_model_parallel_size > 0
assert (
encoder_tensor_model_parallel_size <= tensor_model_parallel_size
), "We do not support encoders with more TP than the decoder."

encoder_model_size = (
encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
)
decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
total_model_size = encoder_model_size + decoder_model_size

assert world_size % total_model_size == 0, (
f'world_size: {world_size} must be divisible by total world_size: '
f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
f'* (decoder_)context_parallel_size {context_parallel_size} + '
f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
f'* context_parallel_size {context_parallel_size}'
)
data_parallel_size = world_size // total_model_size

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
encoder_world_size = encoder_model_size * data_parallel_size
decoder_world_size = decoder_model_size * data_parallel_size
assert encoder_world_size + decoder_world_size == world_size

virtual_pipeline_model_parallel_rank = None
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

rank_generator = RankGenerator(
if encoder_world_size > 0:
encoder_rank_generator = RankGenerator(
tp=encoder_tensor_model_parallel_size,
ep=1,
dp=data_parallel_size,
pp=encoder_pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=0,
)
else:
encoder_rank_generator = None

decoder_rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size_,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=encoder_world_size,
)

def generator_wrapper(group_type, **kwargs):
from itertools import cycle

"""The `RankGenerator` class produces a hyper-rectangle for a given set of
tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
in addition to the default decoder, we essentially instantiate two `RankGenerator`
classes to construct the parallelism for each module separately, and we then have
to stitch them together for the right groups. For now, this means pp and tp-pp."""
d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
if encoder_rank_generator is None:
for x in d_ranks:
yield x
return
e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
if group_type == 'pp':
# Map 1 encoder tp rank to several decoder tp ranks, because
# these won't be the same size.
for x, y in zip(cycle(e_ranks), d_ranks):
yield x + y
elif group_type == 'tp-pp':
# For this group, we can just return the concatenated
# groups together, because their sizes are the same.
assert len(e_ranks) == len(d_ranks)
for x, y in zip(e_ranks, d_ranks):
yield x + y
else:
for x in e_ranks:
yield x
for x in d_ranks:
yield x

# Build the data-parallel groups.
all_data_parallel_group_ranks_with_cp = []
for ranks in rank_generator.get_ranks('dp'):
for ranks in generator_wrapper('dp'):
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
for ranks_with_cp in generator_wrapper('dp-cp'):
all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
if rank in ranks_with_cp:
data_parallel_group_with_cp = ranks_with_cp
Expand All @@ -329,7 +411,7 @@ def fake_initialize_model_parallel(

# Build the context-parallel groups.
all_context_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('cp'):
for ranks in generator_wrapper('cp'):
all_context_parallel_group_ranks.append(ranks)
if rank in ranks:
context_parallel_group = ranks
Expand All @@ -341,7 +423,7 @@ def fake_initialize_model_parallel(

# Build the model-parallel groups.
all_model_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('tp-pp'):
for ranks in generator_wrapper('tp-pp'):
all_model_parallel_group_ranks.append(ranks)
if rank in ranks:
logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
Expand All @@ -350,7 +432,7 @@ def fake_initialize_model_parallel(
# Build the tensor model-parallel groups.
all_tensor_model_parallel_group_ranks = []
tensor_model_parallel_group = None
for ranks in rank_generator.get_ranks('tp'):
for ranks in generator_wrapper('tp'):
all_tensor_model_parallel_group_ranks.append(ranks)
if rank in ranks:
tensor_model_parallel_group = ranks
Expand All @@ -364,7 +446,7 @@ def fake_initialize_model_parallel(
# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
for ranks in generator_wrapper('ep', independent_ep=True):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

Expand All @@ -375,7 +457,7 @@ def fake_initialize_model_parallel(
pipeline_model_parallel_group = None
embedding_group = None
embedding_rank = None
for ranks in rank_generator.get_ranks('pp'):
for ranks in generator_wrapper('pp'):
all_pipeline_model_parallel_group_ranks.append(ranks)
if rank in ranks:
pipeline_model_parallel_group = ranks
Expand Down
4 changes: 4 additions & 0 deletions nemo/lightning/_strategy_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def init_parallel_ranks(
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
context_parallel_size=parallel_config.context_parallel_size,
encoder_tensor_model_parallel_size=getattr(parallel_config, "encoder_tensor_model_parallel_size", 0),
encoder_pipeline_model_parallel_size=getattr(parallel_config, "encoder_pipeline_model_parallel_size", 0),
seed=seed,
pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
use_fp8=fp8,
Expand Down Expand Up @@ -113,6 +115,8 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
encoder_pipeline_model_parallel_size=app_state.encoder_pipeline_model_parallel_size,
encoder_tensor_model_parallel_size=app_state.encoder_tensor_model_parallel_size,
context_parallel_size=app_state.context_parallel_size,
expert_model_parallel_size=app_state.expert_model_parallel_size,
)
Expand Down
Loading

0 comments on commit 85e14ca

Please sign in to comment.