From 4dfbe36420c09d6c3c2d7df36ed6bd847c6c6138 Mon Sep 17 00:00:00 2001
From: Jonathan Mitchell
Date: Mon, 8 Jul 2024 11:34:18 -0700
Subject: [PATCH 1/4] enables default data step in megatron parallel to
 operate on a wider variety of tensors coming out of the dataloader

---
 nemo/lightning/megatron_parallel.py | 35 ++++++++++++++++++++++++-----
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 2f2308717004..af7126357a7a 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -25,9 +25,11 @@
 import torch
 import torch.distributed
+from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
+from pytorch_lightning.utilities import move_data_to_device
 from torch import Tensor, nn
 from typing_extensions import override
 
@@ -43,15 +45,36 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...
 
 
 def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
-    batch = next(dataloader_iter)
+    """
+    Moves the data to a device.
+
+    In this case we utilize the match statement to unpack the dataloader
+    iter from here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
+
+    This will not subset the data for you with context parallel, so please override this function if you
+    want to use context parallel.
 
-    if isinstance(batch, tuple) and len(batch) == 3:
-        batch = batch[0]
+    Examples:
+        If the dataloader_iter returns: [Tuple[<tensor>, <int>, <int>]] -> move to device
+        If the dataloader_iter returns: [<tensor>, <tensor>] -> move to device
 
-    if isinstance(batch, dict):
-        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
+    Returns:
+        DataT: The data moved to the device.
+    """
+    if parallel_state.get_context_parallel_world_size() > 1:
+        raise ValueError("Default data step is being used in a context parallel environment."
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
+
+    match next(dataloader_iter):
+        # If it's wrapped in a tuple, unpack it.
+        case (batch, int(_), int(_)):
+            pass
+        # Canonical case.
+        case batch:
+            pass
 
-    return batch
+    return move_data_to_device(batch, torch.cuda.current_device())
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:

From 70cd4c4ccb01122d7d9a600cb5a61b54e1df59c4 Mon Sep 17 00:00:00 2001
From: Jonathan Mitchell
Date: Mon, 8 Jul 2024 11:57:07 -0700
Subject: [PATCH 2/4] handles the case where a batch is empty

---
 nemo/lightning/megatron_parallel.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index af7126357a7a..a9d0e496afba 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -73,8 +73,14 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         # Canonical case.
         case batch:
             pass
-
-    return move_data_to_device(batch, torch.cuda.current_device())
+        # If the dataloader_iter is empty, return None.
+        case _:
+            batch = None
+
+    if batch is not None:
+        return move_data_to_device(batch, torch.cuda.current_device())
+    else:
+        raise ValueError("No valid batch found from dataloader_iter.")
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:

From 1778c4f90f0a24297f54390533306fe4851a66e0 Mon Sep 17 00:00:00 2001
From: jomitchellnv
Date: Mon, 8 Jul 2024 18:58:38 +0000
Subject: [PATCH 3/4] Apply isort and black reformatting

Signed-off-by: jomitchellnv
---
 nemo/lightning/megatron_parallel.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index a9d0e496afba..10bf3b4fe30f 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -62,9 +62,10 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         DataT: The data moved to the device.
     """
     if parallel_state.get_context_parallel_world_size() > 1:
-        raise ValueError("Default data step is being used in a context parallel environment."
-            "Please define your own data step that appropriately slices the data for context parallel."
-        )
+        raise ValueError(
+            "Default data step is being used in a context parallel environment."
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
 
     match next(dataloader_iter):
         # If it's wrapped in a tuple, unpack it.
         case (batch, int(_), int(_)):
             pass
@@ -76,9 +77,9 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         # If the dataloader_iter is empty, return None.
         case _:
             batch = None
-
+
     if batch is not None:
-        return move_data_to_device(batch, torch.cuda.current_device())
+        return move_data_to_device(batch, torch.cuda.current_device())
     else:
         raise ValueError("No valid batch found from dataloader_iter.")

From 5971ff5137342ac7071d0b0ceff16c445401f892 Mon Sep 17 00:00:00 2001
From: Jonathan Mitchell
Date: Tue, 9 Jul 2024 12:03:10 -0700
Subject: [PATCH 4/4] Allows the default data step to operate on more types
 than just dictionaries

Signed-off-by: Jonathan Mitchell
---
 nemo/lightning/megatron_parallel.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 10bf3b4fe30f..73913ada0cff 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -74,14 +74,14 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         # Canonical case.
         case batch:
             pass
-        # If the dataloader_iter is empty, return None.
+        # If the dataloader_iter is empty, raise a ValueError.
         case _:
             batch = None
 
     if batch is not None:
         return move_data_to_device(batch, torch.cuda.current_device())
     else:
-        raise ValueError("No valid batch found from dataloader_iter.")
+        raise ValueError("None returned from dataloader.")
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
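
Applied in order, the four patches leave default_data_step roughly as reconstructed below. This is an editor's sketch rather than the literal committed file: as written in the patches, the bare case _: arm follows the irrefutable capture case batch:, which CPython rejects at compile time ("name capture 'batch' makes remaining patterns unreachable"), so the sketch keeps the capture arm last and expresses the empty-batch branch as an explicit None check after the match. The local DataT TypeVar is a stand-in for the module's own alias.

from typing import Iterator, TypeVar

import torch
from megatron.core import parallel_state
from pytorch_lightning.utilities import move_data_to_device

DataT = TypeVar("DataT")  # stand-in for the module's own DataT alias


def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
    # Context parallel needs per-rank sequence slicing, which this
    # default deliberately refuses to guess at.
    if parallel_state.get_context_parallel_world_size() > 1:
        raise ValueError(
            "Default data step is being used in a context parallel environment."
            "Please define your own data step that appropriately slices the data for context parallel."
        )

    match next(dataloader_iter):
        # Wrapped case: likely a Lightning-style
        # (batch, batch_idx, dataloader_idx) triple; keep only the batch.
        case (batch, int(_), int(_)):
            pass
        # Canonical case: the iterator yields the batch directly.
        # (An irrefutable capture like this must be the last arm.)
        case batch:
            pass

    if batch is None:
        raise ValueError("None returned from dataloader.")
    # move_data_to_device handles tensors, dicts of tensors, and
    # sequences of tensors alike, which is the point of the PR.
    return move_data_to_device(batch, torch.cuda.current_device())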
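
Because the default step refuses to run when context parallelism is active, CP users must supply their own data step, as the docstring added in patch 1 instructs. Below is one hypothetical shape for such an override; the name cp_data_step, the dict-batch assumption, and the choice of dimension 1 as the sequence dimension are all ours, not the PR's, and the plain contiguous per-rank slice is simpler than what production CP recipes use.

from typing import Dict, Iterator

import torch
from megatron.core import parallel_state
from pytorch_lightning.utilities import move_data_to_device


def cp_data_step(dataloader_iter: Iterator) -> Dict[str, torch.Tensor]:
    batch = next(dataloader_iter)
    # Unwrap the same (batch, batch_idx, dataloader_idx) triple the
    # default step handles.
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]

    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = parallel_state.get_context_parallel_rank()
    if cp_size > 1:
        sliced = {}
        for key, value in batch.items():
            if torch.is_tensor(value) and value.ndim >= 2:
                # Give each CP rank its own contiguous chunk of the
                # sequence dimension (assumed to be dim 1 here).
                chunk = value.shape[1] // cp_size
                sliced[key] = value[:, cp_rank * chunk : (cp_rank + 1) * chunk]
            else:
                sliced[key] = value
        batch = sliced

    return move_data_to_device(batch, torch.cuda.current_device())

A contiguous split is the easiest scheme to read, but with causal attention it leaves later ranks doing more work; Megatron's own context-parallel recipes typically assign each rank non-adjacent chunks to balance the load, so treat the slice above as illustrative only.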