Skip to content

Commit

Permalink
Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP…
Browse files Browse the repository at this point in the history
…>1 (#8334) (#8346)

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Sangkug Lym <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
  • Loading branch information
3 people authored Feb 16, 2024
1 parent 5a86625 commit b5ce971
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -937,8 +937,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

# Transfer needed data to GPU
required_keys = set()
max_seqlen = batch.pop('max_seqlen').squeeze() if 'max_seqlen' in batch else None
cu_seqlens_argmin = batch.pop('cu_seqlens_argmin') if 'cu_seqlens_argmin' in batch else None
max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None
cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
required_keys.update(batch.keys())
else:
Expand Down

0 comments on commit b5ce971

Please sign in to comment.