Skip to content

Commit

Permalink
Broadcast loss only when using pipeline parallelism and within the pi…
Browse files Browse the repository at this point in the history
…peline parallel domain

Signed-off-by: Sangkug Lym <[email protected]>
  • Loading branch information
erhoo82 committed Sep 29, 2023
1 parent 165c275 commit 74fa963
Showing 1 changed file with 41 additions and 11 deletions.
52 changes: 41 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import itertools
import queue
import warnings
Expand Down Expand Up @@ -282,6 +283,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.loss_broadcast_src_rank = None

def get_gpt_module_list(self):
if isinstance(self.model, list):
Expand Down Expand Up @@ -595,17 +598,31 @@ def training_step(self, dataloader_iter, batch_idx):
self.allreduce_first_last_embeddings()

## logging
# we can only log on one rank if it is rank zero so we broadcast from last rank
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
torch.distributed.broadcast(loss_mean, get_last_rank())
if self.log_train_loss:
# When using pipeline parallelism, loss is calculated only in the last pipeline stage and
# it should be casted to other pipeline stages for logging.
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
local_rank = torch.distributed.get_rank() % (dp_size * tp_size)
last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + local_rank
torch.distributed.broadcast(
loss_mean,
self.loss_broadcast_src_rank,
group=parallel_state.get_pipeline_model_parallel_group(),
)
self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)

# (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
loss_scale = self.trainer.precision_plugin.scaler._scale
if loss_scale is not None:
self.log('loss_scale', loss_scale, batch_size=1)
# (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
loss_scale = self.trainer.precision_plugin.scaler._scale
if loss_scale is not None:
self.log('loss_scale', loss_scale, batch_size=1)

self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, rank_zero_only=True, batch_size=1)
self.log(
Expand Down Expand Up @@ -897,8 +914,21 @@ def validation_epoch_end(self, outputs):
else:
averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()

# we can only log on one rank if it is rank zero so we broadcast from last rank
torch.distributed.broadcast(averaged_loss, get_last_rank())
# When using pipeline parallelism, loss is calculated only in the last pipeline stage and
# it should be casted to other pipeline stages for logging.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
local_rank = torch.distributed.get_rank() % (dp_size * tp_size)
last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + local_rank
torch.distributed.broadcast(
averaged_loss,
self.loss_broadcast_src_rank,
group=parallel_state.get_pipeline_model_parallel_group(),
)

self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1)

Expand Down

0 comments on commit 74fa963

Please sign in to comment.