Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mcore parallelism initialization in nemo2 #10643

Merged
merged 14 commits into from
Oct 30, 2024
116 changes: 99 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo(
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
context_parallel_size=1,
encoder_tensor_model_parallel_size=0,
encoder_pipeline_model_parallel_size=0,
ericharper marked this conversation as resolved.
Show resolved Hide resolved
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
Expand All @@ -120,6 +122,8 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.context_parallel_size = context_parallel_size
app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
Expand All @@ -139,6 +143,8 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
)

Expand All @@ -149,12 +155,14 @@ def initialize_model_parallel_for_nemo(
set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

set_pipeline_model_parallel_world_size(
app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
)
ericharper marked this conversation as resolved.
Show resolved Hide resolved
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
if HAVE_INTERLEAVED:
set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank)
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
Expand Down Expand Up @@ -247,6 +255,8 @@ def fake_initialize_model_parallel(
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
encoder_tensor_model_parallel_size_=0,
encoder_pipeline_model_parallel_size_=0,
use_tp_pp_dp_mapping=False,
):
"""
Expand Down Expand Up @@ -283,37 +293,109 @@ def fake_initialize_model_parallel(
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
context_parallel_size = min(context_parallel_size_, world_size)

assert (
world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
data_parallel_size = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
if encoder_pipeline_model_parallel_size_ is None:
encoder_pipeline_model_parallel_size = 0
else:
encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
encoder_tensor_model_parallel_size = tensor_model_parallel_size
else:
encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

if encoder_tensor_model_parallel_size > 0:
assert encoder_pipeline_model_parallel_size > 0
assert (
encoder_tensor_model_parallel_size <= tensor_model_parallel_size
), "We do not support encoders with more TP than the decoder."

encoder_model_size = (
encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
)
decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
total_model_size = encoder_model_size + decoder_model_size

assert world_size % total_model_size == 0, (
f'world_size: {world_size} must be divisible by total world_size: '
f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
f'* (decoder_)context_parallel_size {context_parallel_size} + '
f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
f'* context_parallel_size {context_parallel_size}'
)
data_parallel_size = world_size // total_model_size

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
encoder_world_size = encoder_model_size * data_parallel_size
decoder_world_size = decoder_model_size * data_parallel_size
Fixed Show fixed Hide fixed
assert encoder_world_size + decoder_world_size == world_size

virtual_pipeline_model_parallel_rank = None
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

rank_generator = RankGenerator(
if encoder_world_size > 0:
encoder_rank_generator = RankGenerator(
tp=encoder_tensor_model_parallel_size,
ep=1,
dp=data_parallel_size,
pp=encoder_pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=0,
)
else:
encoder_rank_generator = None

decoder_rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size_,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=encoder_world_size,
)

def generator_wrapper(group_type, **kwargs):
from itertools import cycle

"""The `RankGenerator` class produces a hyper-rectangle for a given set of
tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
in addition to the default decoder, we essentially instantiate two `RankGenerator`
classes to construct the parallelism for each module separately, and we then have
to stitch them together for the right groups. For now, this means pp and tp-pp."""
d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
if encoder_rank_generator is None:
for x in d_ranks:
yield x
return
e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
if group_type == 'pp':
# Map 1 encoder tp rank to several decoder tp ranks, because
# these won't be the same size.
for x, y in zip(cycle(e_ranks), d_ranks):
yield x + y
elif group_type == 'tp-pp':
# For this group, we can just return the concatenated
# groups together, because their sizes are the same.
assert len(e_ranks) == len(d_ranks)
for x, y in zip(e_ranks, d_ranks):
yield x + y
else:
for x in e_ranks:
yield x
for x in d_ranks:
yield x

# Build the data-parallel groups.
all_data_parallel_group_ranks_with_cp = []
for ranks in rank_generator.get_ranks('dp'):
for ranks in generator_wrapper('dp'):
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
for ranks_with_cp in generator_wrapper('dp-cp'):
all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
if rank in ranks_with_cp:
data_parallel_group_with_cp = ranks_with_cp
Expand All @@ -329,7 +411,7 @@ def fake_initialize_model_parallel(

# Build the context-parallel groups.
all_context_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('cp'):
for ranks in generator_wrapper('cp'):
all_context_parallel_group_ranks.append(ranks)
if rank in ranks:
context_parallel_group = ranks
Expand All @@ -341,7 +423,7 @@ def fake_initialize_model_parallel(

# Build the model-parallel groups.
all_model_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('tp-pp'):
for ranks in generator_wrapper('tp-pp'):
all_model_parallel_group_ranks.append(ranks)
if rank in ranks:
logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
Expand All @@ -350,7 +432,7 @@ def fake_initialize_model_parallel(
# Build the tensor model-parallel groups.
all_tensor_model_parallel_group_ranks = []
tensor_model_parallel_group = None
for ranks in rank_generator.get_ranks('tp'):
for ranks in generator_wrapper('tp'):
all_tensor_model_parallel_group_ranks.append(ranks)
if rank in ranks:
tensor_model_parallel_group = ranks
Expand All @@ -364,7 +446,7 @@ def fake_initialize_model_parallel(
# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
for ranks in generator_wrapper('ep', independent_ep=True):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

Expand All @@ -375,7 +457,7 @@ def fake_initialize_model_parallel(
pipeline_model_parallel_group = None
embedding_group = None
embedding_rank = None
for ranks in rank_generator.get_ranks('pp'):
for ranks in generator_wrapper('pp'):
all_pipeline_model_parallel_group_ranks.append(ranks)
if rank in ranks:
pipeline_model_parallel_group = ranks
Expand Down
4 changes: 4 additions & 0 deletions nemo/lightning/_strategy_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def init_parallel_ranks(
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
context_parallel_size=parallel_config.context_parallel_size,
encoder_tensor_model_parallel_size=getattr(parallel_config, "encoder_tensor_model_parallel_size", 0),
encoder_pipeline_model_parallel_size=getattr(parallel_config, "encoder_pipeline_model_parallel_size", 0),
seed=seed,
pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
use_fp8=fp8,
Expand Down Expand Up @@ -113,6 +115,8 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
encoder_pipeline_model_parallel_size=app_state.encoder_pipeline_model_parallel_size,
encoder_tensor_model_parallel_size=app_state.encoder_tensor_model_parallel_size,
context_parallel_size=app_state.context_parallel_size,
expert_model_parallel_size=app_state.expert_model_parallel_size,
)
Expand Down
Loading
Loading