From 360bc133b52b237882f2c0e1dcb550ae57ee6dea Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 3 Oct 2023 13:11:09 -0700
Subject: [PATCH] Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586)

* Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym
Co-authored-by: Sangkug Lym
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../language_modeling/megatron_gpt_model.py   | 48 ++++++++++++++-----
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 3d7a5a127399..46d3f37f5d9b 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import itertools
+import os
 import queue
 import warnings
 from dataclasses import fields
@@ -273,6 +274,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
+        self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
+        self.loss_broadcast_src_rank = None
 
         self.inference_params = None
 
@@ -627,17 +630,29 @@ def training_step(self, dataloader_iter, batch_idx):
             self.allreduce_first_last_embeddings()
 
         ## logging
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        # we can avoid this broadcast by updating the PTL log function to accept specific ranks
-        torch.distributed.broadcast(loss_mean, get_last_rank())
+        if self.log_train_loss:
+            # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+            # it should be broadcast to the other pipeline stages for logging.
+            # we can avoid this broadcast by updating the PTL log function to accept specific ranks
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                if self.loss_broadcast_src_rank is None:
+                    dp_size = parallel_state.get_data_parallel_world_size()
+                    tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                    rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                    last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                    self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+                torch.distributed.broadcast(
+                    loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+                )
+            self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
 
-        # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
-        if self.torch_dtype == torch.float16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
-            loss_scale = self.trainer.precision_plugin.scaler._scale
-            if loss_scale is not None:
-                self.log('loss_scale', loss_scale, batch_size=1)
+            # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
+            if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
+                loss_scale = self.trainer.precision_plugin.scaler._scale
+                if loss_scale is not None:
+                    self.log('loss_scale', loss_scale, batch_size=1)
 
-        self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
         lr = self._optimizer.param_groups[0]['lr']
         self.log('lr', lr, rank_zero_only=True, batch_size=1)
         self.log(
@@ -962,8 +977,19 @@ def on_validation_epoch_end(self):
         else:
             averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
 
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        torch.distributed.broadcast(averaged_loss, get_last_rank())
+        # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+        # it should be broadcast to the other pipeline stages for logging.
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            if self.loss_broadcast_src_rank is None:
+                dp_size = parallel_state.get_data_parallel_world_size()
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+            torch.distributed.broadcast(
+                averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+            )
         self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1)
         self.validation_step_outputs.clear()  # free memory
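
A minimal standalone sketch of the rank arithmetic this patch introduces (illustrative, not part of the diff): the broadcast source is the rank in the last pipeline stage that occupies the same tensor-/data-parallel slot as the receiving rank, and the broadcast is issued only within parallel_state.get_pipeline_model_parallel_group() when pipeline parallelism is enabled. The sketch assumes Megatron-style rank ordering (tensor-parallel ranks vary fastest, then data-parallel, then pipeline-parallel); the function name and example sizes below are hypothetical.

# Illustrative sketch, assuming Megatron-style rank ordering: tensor-parallel ranks vary
# fastest, then data-parallel, then pipeline-parallel, so one pipeline stage spans
# tp_size * dp_size consecutive global ranks and the last stage starts at
# (tp_size * dp_size) * (pp_size - 1), matching the patch above.

def loss_broadcast_src_rank(global_rank: int, tp_size: int, dp_size: int, pp_size: int) -> int:
    """Global rank in the last pipeline stage holding the same tensor-/data-parallel slot."""
    rank_in_dp_tp_group = global_rank % (dp_size * tp_size)  # slot inside the current stage
    last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)  # first rank of the last stage
    return last_pipeline_stage_offset + rank_in_dp_tp_group


if __name__ == "__main__":
    # Hypothetical 8-GPU layout: tp=2, dp=2, pp=2 -> ranks 0-3 form stage 0, ranks 4-7 form stage 1.
    tp, dp, pp = 2, 2, 2
    for rank in range(tp * dp * pp):
        print(f"rank {rank} receives the logged loss from rank {loss_broadcast_src_rank(rank, tp, dp, pp)}")
    # Prints 0<-4, 1<-5, 2<-6, 3<-7; ranks 4-7 are their own source, so the broadcast is a no-op for them.

In the patch itself this value is computed lazily and cached in self.loss_broadcast_src_rank, so the arithmetic runs once per job rather than every step, and only the pp_size ranks that share a tensor-/data-parallel slot participate in each broadcast; when pipeline parallelism is off (or NEMO_LOG_TRAIN_LOSS=0 for training), no broadcast is issued at all.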