Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (NVIDIA#7576) (NVIDIA#7586)

* Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Sangkug Lym <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Sasha Meister <[email protected]>
3 people authored and ssh-meister committed Oct 10, 2023
1 parent e52c99b commit cbb499c
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import itertools
+import os
 import queue
 import warnings
 from dataclasses import fields
@@ -273,6 +274,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
+        self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
+        self.loss_broadcast_src_rank = None
 
         self.inference_params = None

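A note on the first new attribute: `log_train_loss` is read once from the environment when the model is built, so train-loss logging (and the broadcast that supports it) can be switched off without a config change. A minimal sketch of the gate's semantics; the variable name and default come from the diff, everything else is illustrative:

```python
import os

# Unset or "1" -> True (log reduced_train_loss each step); "0" -> False.
log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))

# e.g. `NEMO_LOG_TRAIN_LOSS=0 python train.py` (hypothetical launch) would
# skip both the per-step loss broadcast and the per-step log call.
print(log_train_loss)
```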
@@ -627,17 +630,29 @@ def training_step(self, dataloader_iter, batch_idx):
         self.allreduce_first_last_embeddings()
 
         ## logging
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        # we can avoid this broadcast by updating the PTL log function to accept specific ranks
-        torch.distributed.broadcast(loss_mean, get_last_rank())
+        if self.log_train_loss:
+            # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+            # it should be casted to other pipeline stages for logging.
+            # we can avoid this broadcast by updating the PTL log function to accept specific ranks
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                if self.loss_broadcast_src_rank is None:
+                    dp_size = parallel_state.get_data_parallel_world_size()
+                    tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                    rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                    last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                    self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+                torch.distributed.broadcast(
+                    loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+                )
+            self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
 
-        # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
-        if self.torch_dtype == torch.float16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
-            loss_scale = self.trainer.precision_plugin.scaler._scale
-            if loss_scale is not None:
-                self.log('loss_scale', loss_scale, batch_size=1)
+            # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
+            if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
+                loss_scale = self.trainer.precision_plugin.scaler._scale
+                if loss_scale is not None:
+                    self.log('loss_scale', loss_scale, batch_size=1)
 
-        self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
         lr = self._optimizer.param_groups[0]['lr']
         self.log('lr', lr, rank_zero_only=True, batch_size=1)
         self.log(
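The source rank above is computed arithmetically instead of via `get_last_rank()`: with Megatron-LM's default rank ordering (tensor-parallel fastest-varying, then data-parallel, then pipeline-parallel), each pipeline stage owns a contiguous block of `dp_size * tp_size` global ranks, and a rank's peer in the last stage sits at the same offset within the final block. A standalone sketch of that arithmetic, using a hypothetical 16-GPU layout (tp=2, pp=2, dp=4) for illustration:

```python
# Illustrative only: replays the commit's rank arithmetic outside NeMo,
# assuming Megatron-LM's default rank ordering (tp fastest, pp slowest).

def loss_broadcast_src_rank(global_rank: int, dp_size: int, tp_size: int, pp_size: int) -> int:
    # Offset of this rank inside its pipeline stage's dp*tp block of ranks.
    rank_in_dp_tp_group = global_rank % (dp_size * tp_size)
    # First global rank of the last pipeline stage.
    last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
    return last_pipeline_stage_offset + rank_in_dp_tp_group

# Global rank 3 lives in pipeline stage 0; its peer in the last stage is
# rank 11, so stage 0 receives loss_mean from rank 11 over the PP group.
assert loss_broadcast_src_rank(3, dp_size=4, tp_size=2, pp_size=2) == 11
```

The result is cached in `self.loss_broadcast_src_rank` (note the `is None` guard), so the arithmetic runs once per process rather than on every step.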
@@ -962,8 +977,19 @@ def on_validation_epoch_end(self):
         else:
             averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
 
-        # we can only log on one rank if it is rank zero so we broadcast from last rank
-        torch.distributed.broadcast(averaged_loss, get_last_rank())
+        # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
+        # it should be casted to other pipeline stages for logging.
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            if self.loss_broadcast_src_rank is None:
+                dp_size = parallel_state.get_data_parallel_world_size()
+                tp_size = parallel_state.get_tensor_model_parallel_world_size()
+                pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
+                last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
+                self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
+            torch.distributed.broadcast(
+                averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+            )
 
         self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1)
         self.validation_step_outputs.clear()  # free memory
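Both call sites reduce to the same pattern: broadcast a scalar tensor from one source rank to the other members of its pipeline-parallel group, rather than broadcasting world-wide from the globally last rank. A self-contained sketch of that pattern in plain `torch.distributed` (not NeMo code; the `gloo` backend and the whole-world group are stand-ins for the demo, launched with e.g. `torchrun --nproc_per_node=2 demo.py`):

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # torchrun supplies rank/world env vars
rank, world = dist.get_rank(), dist.get_world_size()

# Stand-in for parallel_state.get_pipeline_model_parallel_group().
group = dist.new_group(ranks=list(range(world)))

# Only the "last pipeline stage" holds the real loss; others hold a placeholder.
loss = torch.tensor(1.234) if rank == world - 1 else torch.tensor(0.0)

# In-place broadcast from the source rank to every member of the group.
dist.broadcast(loss, src=world - 1, group=group)
print(f"rank {rank} sees loss {loss.item():.3f}")

dist.destroy_process_group()
```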
