From 2eb87fea466b68d25e066ac37b39a29bcc429112 Mon Sep 17 00:00:00 2001
From: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com>
Date: Mon, 15 Jul 2024 09:12:49 -0700
Subject: [PATCH] enables default data step in megatron parallel to operate on
 a wider variety of tensors - second try (#9671)

* enables default data step in megatron parallel to operate on a wider variety of tensors coming out of the dataloader

Signed-off-by: Jonathan Mitchell

* handles the case where a batch is empty

Signed-off-by: Jonathan Mitchell

* Apply isort and black reformatting

Signed-off-by: jomitchellnv
Signed-off-by: Jonathan Mitchell

* Allows the default data step to operate on more types than just dictionaries

Signed-off-by: Jonathan Mitchell

* Apply isort and black reformatting

Signed-off-by: jomitchellnv

---------

Signed-off-by: Jonathan Mitchell
Signed-off-by: jomitchellnv
Co-authored-by: jomitchellnv
Co-authored-by: John St. John
---
 nemo/lightning/megatron_parallel.py | 30 +++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 2f23087170040..43f058f700f47 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -25,9 +25,11 @@
 
 import torch
 import torch.distributed
+from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
+from pytorch_lightning.utilities import move_data_to_device
 from torch import Tensor, nn
 from typing_extensions import override
 
@@ -43,15 +45,35 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...
 
 
 def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
+    """
+    Moves the data to a device.
+
+    In this case we unpack the dataloader iterator. There may be a wrapper on the dataloader
+    iter from here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
+
+    This will not subset the data for you with context parallel, so please override this function if you
+    want to use context parallel.
+
+    Examples:
+        If the dataloader_iter returns: [Tuple[<tensor>, <tensor>, <tensor>]] -> move to device
+        If the dataloader_iter returns: [<tensor>, <tensor>] -> move to device
+
+    Returns:
+        DataT: The data moved to the device.
+    """
+    if parallel_state.get_context_parallel_world_size() > 1:
+        raise ValueError(
+            "Default data step is being used in a context parallel environment. "
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
+
     batch = next(dataloader_iter)
 
+    # If it's wrapped in a tuple, unpack it.
     if isinstance(batch, tuple) and len(batch) == 3:
         batch = batch[0]
 
-    if isinstance(batch, dict):
-        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
-
-    return batch
+    return move_data_to_device(batch, torch.cuda.current_device())
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
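
Reviewer note (not part of the patch): a minimal sketch of why swapping the dict-only `.cuda()` branch for `move_data_to_device` lets the default data step accept a wider variety of batch structures. It uses the CPU device so it runs anywhere; the patched code passes `torch.cuda.current_device()` instead.

```python
import torch
from pytorch_lightning.utilities import move_data_to_device

device = torch.device("cpu")  # stand-in for torch.cuda.current_device()

# A plain dict of tensors -- the only case the old code moved to the device.
dict_batch = {"tokens": torch.arange(4), "labels": torch.arange(4)}

# Structures the old code silently left on the host: tuples, lists, and
# arbitrary nesting of tensors coming out of the dataloader.
tuple_batch = (torch.randn(2, 3), torch.randn(2))
nested_batch = {"inputs": [torch.randn(2, 3), torch.randn(2, 3)], "mask": torch.ones(2, 3)}

for name, batch in [("dict", dict_batch), ("tuple", tuple_batch), ("nested", nested_batch)]:
    # move_data_to_device recurses through dicts, tuples, lists, and dataclasses,
    # moving every tensor it finds while preserving the container structure.
    moved = move_data_to_device(batch, device)
    print(name, "->", type(moved).__name__)
```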
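Since the default step now raises when context parallelism is enabled, a hypothetical sketch of what a user-defined data step might look like for that case. The function name `cp_data_step` and the naive sequence-dimension slicing are illustrative assumptions, not part of the patch or of any NeMo API; real context-parallel setups typically need the load-balanced chunking megatron-core expects.

```python
from typing import Dict, Iterator

import torch
from megatron.core import parallel_state
from pytorch_lightning.utilities import move_data_to_device


def cp_data_step(dataloader_iter: Iterator) -> Dict[str, torch.Tensor]:
    """Illustrative data step that slices the sequence dimension across CP ranks."""
    batch = next(dataloader_iter)
    # Mirror the default step's unwrapping of the strategy's 3-tuple wrapper.
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]

    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = parallel_state.get_context_parallel_rank()
    if cp_size > 1:
        # Naive equal split of dim=1 (sequence) per CP rank; assumes a dict batch.
        sliced = {}
        for key, value in batch.items():
            if torch.is_tensor(value) and value.dim() >= 2:
                chunk = value.size(1) // cp_size
                sliced[key] = value[:, cp_rank * chunk : (cp_rank + 1) * chunk]
            else:
                sliced[key] = value
        batch = sliced

    return move_data_to_device(batch, torch.cuda.current_device())
```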