Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586)

* Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Sangkug Lym <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 3, 2023
1 parent 0a4fba3 commit 6498110
Showing 1 changed file with 37 additions and 11 deletions.
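
In short, this change gates train-loss logging in MegatronGPTModel behind a NEMO_LOG_TRAIN_LOSS environment variable (on by default) and, when pipeline parallelism is enabled, broadcasts the loss from the last pipeline stage only within the pipeline-parallel group, rather than from the globally last rank across all ranks. A minimal sketch of the new toggle, mirroring the os.getenv call added in __init__ in the diff below (standalone Python, not a NeMo API):

    import os

    # Loss logging stays on unless the job is launched with NEMO_LOG_TRAIN_LOSS=0;
    # any other integer value keeps the per-step broadcast and log() call.
    log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
    print(log_train_loss)  # True by default, False when NEMO_LOG_TRAIN_LOSS=0
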
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py (37 additions, 11 deletions)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import itertools
+import os
 import queue
 import warnings
 from dataclasses import fields
@@ -273,6 +274,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
+        self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
+        self.loss_broadcast_src_rank = None
 
         self.inference_params = None
 
@@ -627,17 +630,29 @@ def training_step(self, dataloader_iter, batch_idx):
             self.allreduce_first_last_embeddings()
 
         ## logging
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        # we can avoid this broadcast by updating the PTL log function to accept specific ranks
-        torch.distributed.broadcast(loss_mean, get_last_rank())
+        if self.log_train_loss:
+            # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+            # it should be casted to other pipeline stages for logging.
+            # we can avoid this broadcast by updating the PTL log function to accept specific ranks
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                if self.loss_broadcast_src_rank is None:
+                    dp_size = parallel_state.get_data_parallel_world_size()
+                    tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                    rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                    last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                    self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+                torch.distributed.broadcast(
+                    loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+                )
+            self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
 
-        # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
-        if self.torch_dtype == torch.float16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
-            loss_scale = self.trainer.precision_plugin.scaler._scale
-            if loss_scale is not None:
-                self.log('loss_scale', loss_scale, batch_size=1)
+            # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
+            if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
+                loss_scale = self.trainer.precision_plugin.scaler._scale
+                if loss_scale is not None:
+                    self.log('loss_scale', loss_scale, batch_size=1)
 
-        self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
         lr = self._optimizer.param_groups[0]['lr']
         self.log('lr', lr, rank_zero_only=True, batch_size=1)
         self.log(
@@ -962,8 +977,19 @@ def on_validation_epoch_end(self):
         else:
             averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
 
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        torch.distributed.broadcast(averaged_loss, get_last_rank())
+        # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+        # it should be casted to other pipeline stages for logging.
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            if self.loss_broadcast_src_rank is None:
+                dp_size = parallel_state.get_data_parallel_world_size()
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+            torch.distributed.broadcast(
+                averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+            )
 
         self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1)
         self.validation_step_outputs.clear()  # free memory
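
Both the training and validation paths above compute the broadcast source the same way and cache it in self.loss_broadcast_src_rank, so the arithmetic runs once per process. It assumes the default Megatron rank layout, where tensor-parallel rank varies fastest, then data-parallel rank, then pipeline stage, so each pipeline-parallel group is {r, r + dp*tp, r + 2*dp*tp, ...}; note that torch.distributed.broadcast takes src as a global rank even when a group is passed. A minimal sketch of that arithmetic with made-up sizes (plain Python, no NeMo or torch; the helper name is illustrative, not part of the change):

    def loss_broadcast_src_rank(global_rank: int, dp_size: int, tp_size: int, pp_size: int) -> int:
        """Global rank holding the loss: the member of this rank's pipeline-parallel
        group that sits in the last pipeline stage.

        Assumes the default Megatron layout (tensor fastest, then data, then pipeline),
        so each pipeline stage owns a contiguous block of dp_size * tp_size global ranks.
        """
        ranks_per_stage = dp_size * tp_size
        # Position of this rank inside its pipeline stage; every member of the same
        # pipeline-parallel group shares this position.
        rank_in_dp_tp_group = global_rank % ranks_per_stage
        # First global rank of the last pipeline stage.
        last_pipeline_stage_offset = ranks_per_stage * (pp_size - 1)
        return last_pipeline_stage_offset + rank_in_dp_tp_group

    # Example: tp=2, dp=4, pp=3 -> 24 ranks, stages [0..7], [8..15], [16..23].
    # Global rank 5 is in stage 0; its pipeline-parallel group is {5, 13, 21},
    # so the loss computed on the last stage is broadcast from global rank 21.
    assert loss_broadcast_src_rank(5, dp_size=4, tp_size=2, pp_size=3) == 21
    # A rank already in the last stage resolves to itself.
    assert loss_broadcast_src_rank(21, dp_size=4, tp_size=2, pp_size=3) == 21
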

