
update megatron_init.py inside lightning
Signed-off-by: yaoyu-33 <[email protected]>
yaoyu-33 committed Oct 25, 2024
1 parent 8a13481 commit e7e6798
Showing 1 changed file with 99 additions and 17 deletions.
116 changes: 99 additions & 17 deletions nemo/lightning/megatron_init.py
@@ -95,6 +95,8 @@ def initialize_model_parallel_for_nemo(
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
context_parallel_size=1,
encoder_tensor_model_parallel_size=0,
encoder_pipeline_model_parallel_size=0,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
@@ -120,6 +122,8 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.context_parallel_size = context_parallel_size
app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
@@ -139,6 +143,8 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
)

@@ -149,12 +155,14 @@ def initialize_model_parallel_for_nemo(
set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

set_pipeline_model_parallel_world_size(
app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
)
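# Note: the pipeline-parallel world size registered here now covers both decoder
# and encoder stages, e.g. pipeline_model_parallel_size=4 plus
# encoder_pipeline_model_parallel_size=1 registers 5 stages (illustrative sizes).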
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
if HAVE_INTERLEAVED:
set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank)
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
@@ -247,6 +255,8 @@ def fake_initialize_model_parallel(
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
encoder_tensor_model_parallel_size_=0,
encoder_pipeline_model_parallel_size_=0,
use_tp_pp_dp_mapping=False,
):
"""
Expand Down Expand Up @@ -283,37 +293,109 @@ def fake_initialize_model_parallel(
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
context_parallel_size = min(context_parallel_size_, world_size)

assert (
world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
data_parallel_size = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
if encoder_pipeline_model_parallel_size_ is None:
encoder_pipeline_model_parallel_size = 0
else:
encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
encoder_tensor_model_parallel_size = tensor_model_parallel_size
else:
encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

if encoder_tensor_model_parallel_size > 0:
assert encoder_pipeline_model_parallel_size > 0
assert (
encoder_tensor_model_parallel_size <= tensor_model_parallel_size
), "We do not support encoders with more TP than the decoder."

encoder_model_size = (
encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
)
decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
total_model_size = encoder_model_size + decoder_model_size

assert world_size % total_model_size == 0, (
f'world_size: {world_size} must be divisible by total_model_size: '
f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
f'* (decoder_)context_parallel_size {context_parallel_size} + '
f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
f'* context_parallel_size {context_parallel_size}'
)
data_parallel_size = world_size // total_model_size

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
encoder_world_size = encoder_model_size * data_parallel_size
decoder_world_size = decoder_model_size * data_parallel_size
assert encoder_world_size + decoder_world_size == world_size
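# Worked example with hypothetical sizes: encoder TP=1, encoder PP=1, decoder TP=2,
# decoder PP=2, CP=1, world_size=10:
#   encoder_model_size = 1 * 1 * 1 = 1, decoder_model_size = 2 * 2 * 1 = 4,
#   total_model_size = 5, so data_parallel_size = 10 // 5 = 2,
#   encoder_world_size = 1 * 2 = 2, decoder_world_size = 4 * 2 = 8 (sums to 10).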

virtual_pipeline_model_parallel_rank = None
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

rank_generator = RankGenerator(
if encoder_world_size > 0:
encoder_rank_generator = RankGenerator(
tp=encoder_tensor_model_parallel_size,
ep=1,
dp=data_parallel_size,
pp=encoder_pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=0,
)
else:
encoder_rank_generator = None

decoder_rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size_,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=encoder_world_size,
)
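# The decoder generator is offset by encoder_world_size, so (presumably) encoder
# ranks occupy [0, encoder_world_size) and decoder ranks follow directly after.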

def generator_wrapper(group_type, **kwargs):
"""The `RankGenerator` class produces a hyper-rectangle for a given set of
tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
in addition to the default decoder, we essentially instantiate two `RankGenerator`
classes to construct the parallelism for each module separately, and we then have
to stitch them together for the right groups. For now, this means pp and tp-pp."""
from itertools import cycle

d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
if encoder_rank_generator is None:
for x in d_ranks:
yield x
return
e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
if group_type == 'pp':
# Map 1 encoder tp rank to several decoder tp ranks, because
# these won't be the same size.
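# For example (hypothetical ranks): with e_ranks=[[0], [1]] and
# d_ranks=[[2, 6], [3, 7], [4, 8], [5, 9]], cycling the encoder groups
# yields stitched pp groups [0, 2, 6], [1, 3, 7], [0, 4, 8], [1, 5, 9].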
for x, y in zip(cycle(e_ranks), d_ranks):
yield x + y
elif group_type == 'tp-pp':
# For this group, we can just return the concatenated
# groups together, because their sizes are the same.
assert len(e_ranks) == len(d_ranks)
for x, y in zip(e_ranks, d_ranks):
yield x + y
else:
for x in e_ranks:
yield x
for x in d_ranks:
yield x

# Build the data-parallel groups.
all_data_parallel_group_ranks_with_cp = []
for ranks in rank_generator.get_ranks('dp'):
for ranks in generator_wrapper('dp'):
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
for ranks_with_cp in generator_wrapper('dp-cp'):
all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
if rank in ranks_with_cp:
data_parallel_group_with_cp = ranks_with_cp
@@ -329,7 +411,7 @@ def fake_initialize_model_parallel(

# Build the context-parallel groups.
all_context_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('cp'):
for ranks in generator_wrapper('cp'):
all_context_parallel_group_ranks.append(ranks)
if rank in ranks:
context_parallel_group = ranks
@@ -341,7 +423,7 @@ def fake_initialize_model_parallel(

# Build the model-parallel groups.
all_model_parallel_group_ranks = []
for ranks in rank_generator.get_ranks('tp-pp'):
for ranks in generator_wrapper('tp-pp'):
all_model_parallel_group_ranks.append(ranks)
if rank in ranks:
logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
@@ -350,7 +432,7 @@ def fake_initialize_model_parallel(
# Build the tensor model-parallel groups.
all_tensor_model_parallel_group_ranks = []
tensor_model_parallel_group = None
for ranks in rank_generator.get_ranks('tp'):
for ranks in generator_wrapper('tp'):
all_tensor_model_parallel_group_ranks.append(ranks)
if rank in ranks:
tensor_model_parallel_group = ranks
@@ -364,7 +446,7 @@ def fake_initialize_model_parallel(
# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
for ranks in generator_wrapper('ep', independent_ep=True):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

@@ -375,7 +457,7 @@ def fake_initialize_model_parallel(
pipeline_model_parallel_group = None
embedding_group = None
embedding_rank = None
for ranks in rank_generator.get_ranks('pp'):
for ranks in generator_wrapper('pp'):
all_pipeline_model_parallel_group_ranks.append(ranks)
if rank in ranks:
pipeline_model_parallel_group = ranks
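For reference, a minimal sketch of calling fake_initialize_model_parallel with the new encoder arguments. The keyword names follow the trailing-underscore convention visible in the diff; the exact decoder TP/PP parameter names, the world_size/rank arguments, and the return value are assumptions about parts of the function this commit does not show.

# Hypothetical call: 10 ranks, decoder TP=2 / PP=2, plus one encoder pipeline
# stage with TP=1 (sizes chosen purely for illustration).
result = fake_initialize_model_parallel(
    world_size=10,
    rank=3,
    tensor_model_parallel_size_=2,
    pipeline_model_parallel_size_=2,
    context_parallel_size_=1,
    encoder_tensor_model_parallel_size_=1,
    encoder_pipeline_model_parallel_size_=1,
)
# With these sizes the code computes data_parallel_size = 2 and builds separate
# encoder/decoder RankGenerators that generator_wrapper stitches together.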
