From 2409503d6e82dc8c0f2e3378f31c4320a108f3bf Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Wed, 10 Jul 2024 17:38:07 +0200
Subject: [PATCH] =?UTF-8?q?Revert=20"enables=20default=20data=20step=20in?=
 =?UTF-8?q?=20megatron=20parallel=20to=20operate=20on=20a=20wider=20?=
 =?UTF-8?q?=E2=80=A6"=20(#9666)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Tugrul Konuk
---
 nemo/lightning/megatron_parallel.py | 42 +++++------------------------
 1 file changed, 6 insertions(+), 36 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 73913ada0cff..2f2308717004 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -25,11 +25,9 @@
 
 import torch
 import torch.distributed
-from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
-from pytorch_lightning.utilities import move_data_to_device
 from torch import Tensor, nn
 from typing_extensions import override
 
@@ -45,43 +43,15 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...
 
 
 def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
-    """
-    Moves the data to a device.
-
-    In this case we utilize the match function to unpack the dataloader iterator. There may be a wrapper on the dataloader
-    iter from here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
+    batch = next(dataloader_iter)
 
-    This will not subset the data for your with context parallel so please override this function if you
-    want to use context parallel.
+    if isinstance(batch, tuple) and len(batch) == 3:
+        batch = batch[0]
 
-    Examples:
-        If the dataloader_iter returns: [Tuple[<tensor>, <tensor>, <tensor>]] -> move to device
-        If the dataloader_iter returns: [<tensor>, <tensor>] -> move to device
+    if isinstance(batch, dict):
+        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
 
-    Returns:
-        DataT: The data moved to the device.
-    """
-    if parallel_state.get_context_parallel_world_size() > 1:
-        raise ValueError(
-            "Default data step is being used in a context parallel environment."
-            "Please define your own data step that appropriately slices the data for context parallel."
-        )
-
-    match next(dataloader_iter):
-        # If its wrapped in a tuple, unpack it.
-        case (batch, int(_), int(_)):
-            pass
-        # Canonical case.
-        case batch:
-            pass
-        # If the dataloader_iter is empty, return a ValueError.
-        case _:
-            batch = None
-
-    if batch is not None:
-        return move_data_to_device(batch, torch.cuda.current_device())
-    else:
-        raise ValueError("None returned from dataloader.")
+    return batch
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
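
Note for downstream users (illustrative, not part of the applied diff): after this revert,
default_data_step again only unwraps (batch, idx, idx) 3-tuples and moves dict batches to the
GPU; the generic device placement via move_data_to_device and the context-parallel guard are
gone. The sketch below shows how a custom data step could recover that behaviour. It reuses
only helpers the removed code itself imported (megatron.core.parallel_state and
pytorch_lightning's move_data_to_device); the function name custom_data_step and the idea of
plugging it in wherever your setup accepts a data-step callable are assumptions for
illustration, not NeMo API.

from typing import Iterator

import torch
from megatron.core import parallel_state
from pytorch_lightning.utilities import move_data_to_device


def custom_data_step(dataloader_iter: Iterator):
    # Refuse to run under context parallelism, as the reverted code did, because this
    # step does not slice batches along the sequence dimension for context parallel.
    if parallel_state.get_context_parallel_world_size() > 1:
        raise ValueError(
            "custom_data_step does not slice data for context parallel; "
            "provide a context-parallel-aware data step instead."
        )

    batch = next(dataloader_iter)

    # Some dataloader wrappers yield (batch, batch_idx, dataloader_idx) tuples; unwrap them.
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]

    if batch is None:
        raise ValueError("None returned from dataloader.")

    # move_data_to_device handles plain tensors as well as nested dicts/lists/tuples.
    return move_data_to_device(batch, torch.cuda.current_device())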