From b5ce971cce57babb4e29181fda1139ad1667a0d1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:28:32 -0700 Subject: [PATCH] Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP>1 (#8334) (#8346) Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: Eric Harper --- .../nlp/models/language_modeling/megatron_gpt_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 0a9c65be42ab..9c3657d4c4ef 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -937,8 +937,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ # Transfer needed data to GPU required_keys = set() - max_seqlen = batch.pop('max_seqlen').squeeze() if 'max_seqlen' in batch else None - cu_seqlens_argmin = batch.pop('cu_seqlens_argmin') if 'cu_seqlens_argmin' in batch else None + max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None + cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None if parallel_state.get_pipeline_model_parallel_world_size() == 1: required_keys.update(batch.keys()) else: