Skip to content

Commit

Permalink
enables default data step in megatron parallel to operate on a wider …
Browse files Browse the repository at this point in the history
…variety of tensors (#9641)

* enables default data step in megatron parallel to operate on a wider variety of tensors coming out of the dataloader

* handles the case where a batch is empty

* Apply isort and black reformatting

Signed-off-by: jomitchellnv <[email protected]>

* Allows the default data step to operate on more types
than just dictionaries

Signed-off-by: Jonathan Mitchell <[email protected]>

---------

Signed-off-by: jomitchellnv <[email protected]>
Signed-off-by: Jonathan Mitchell <[email protected]>
Co-authored-by: jomitchellnv <[email protected]>
Co-authored-by: Marc Romeyn <[email protected]>
  • Loading branch information
3 people authored and monica-sekoyan committed Oct 11, 2024
1 parent 60e80d7 commit caab0a7
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions nemo/lightning/megatron_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

import torch
import torch.distributed
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.transformer.transformer_config import TransformerConfig
from pytorch_lightning.utilities import move_data_to_device
from torch import Tensor, nn
from typing_extensions import override

Expand All @@ -43,15 +45,43 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...


def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
batch = next(dataloader_iter)
"""
Moves the data to a device.
In this case we utilize the match function to unpack the dataloader iterator. There may be a wrapper on the dataloader
iter from here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
if isinstance(batch, tuple) and len(batch) == 3:
batch = batch[0]
This will not subset the data for your with context parallel so please override this function if you
want to use context parallel.
if isinstance(batch, dict):
batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
Examples:
If the dataloader_iter returns: [Tuple[<tensor>, <int>, <int>]] -> move to device
If the dataloader_iter returns: [<tensor>, <tensor>] -> move to device
return batch
Returns:
DataT: The data moved to the device.
"""
if parallel_state.get_context_parallel_world_size() > 1:
raise ValueError(
"Default data step is being used in a context parallel environment."
"Please define your own data step that appropriately slices the data for context parallel."
)

match next(dataloader_iter):
# If its wrapped in a tuple, unpack it.
case (batch, int(_), int(_)):
pass
# Canonical case.
case batch:
pass
# If the dataloader_iter is empty, return a ValueError.
case _:
batch = None

if batch is not None:
return move_data_to_device(batch, torch.cuda.current_device())
else:
raise ValueError("None returned from dataloader.")


def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
Expand Down

0 comments on commit caab0a7

Please sign in to comment.