Revert "enables default data step in megatron parallel to operate on a wider variety of tensors" #9666

Merged · 1 commit · Jul 10, 2024
42 changes: 6 additions & 36 deletions nemo/lightning/megatron_parallel.py
@@ -25,11 +25,9 @@

 import torch
 import torch.distributed
-from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
-from pytorch_lightning.utilities import move_data_to_device
 from torch import Tensor, nn
 from typing_extensions import override

@@ -45,43 +43,15 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...


 def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
-    """
-    Moves the data to a device.
-
-    In this case we utilize the match function to unpack the dataloader iterator. There may be a wrapper on the dataloader
-    iter from here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
+    batch = next(dataloader_iter)

-    This will not subset the data for your with context parallel so please override this function if you
-    want to use context parallel.
+    if isinstance(batch, tuple) and len(batch) == 3:
+        batch = batch[0]

-    Examples:
-        If the dataloader_iter returns: [Tuple[<tensor>, <int>, <int>]] -> move to device
-        If the dataloader_iter returns: [<tensor>, <tensor>] -> move to device
+    if isinstance(batch, dict):
+        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}

-    Returns:
-        DataT: The data moved to the device.
-    """
-    if parallel_state.get_context_parallel_world_size() > 1:
-        raise ValueError(
-            "Default data step is being used in a context parallel environment."
-            "Please define your own data step that appropriately slices the data for context parallel."
-        )
-
-    match next(dataloader_iter):
-        # If its wrapped in a tuple, unpack it.
-        case (batch, int(_), int(_)):
-            pass
-        # Canonical case.
-        case batch:
-            pass
-        # If the dataloader_iter is empty, return a ValueError.
-        case _:
-            batch = None
-
-    if batch is not None:
-        return move_data_to_device(batch, torch.cuda.current_device())
-    else:
-        raise ValueError("None returned from dataloader.")
+    return batch


 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
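For context, here is a minimal sketch (not part of the PR) that re-declares the restored default_data_step from the diff above and exercises the batch shapes it handles. The toy iterators, the "tokens" key, and the assertions are illustrative assumptions, and a CUDA device is assumed to be available.

# Illustrative sketch only; mirrors the restored function in the diff above.
from typing import Dict, Iterator, Tuple, Union

import torch

Batch = Union[torch.Tensor, Dict[str, torch.Tensor], Tuple]


def default_data_step(dataloader_iter: Iterator[Batch]) -> Batch:
    # Same logic as the restored implementation.
    batch = next(dataloader_iter)

    # A (batch, batch_idx, dataloader_idx)-style 3-tuple is unwrapped to its first element.
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]

    # A dict batch has each tensor copied to the current CUDA device.
    if isinstance(batch, dict):
        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}

    return batch


if torch.cuda.is_available():
    # Tuple-wrapped dict: unwrapped, then its tensors land on the GPU.
    wrapped = iter([({"tokens": torch.zeros(2, 8)}, 0, 0)])
    assert default_data_step(wrapped)["tokens"].is_cuda

    # Bare dict batch: moved to the GPU directly.
    bare = iter([{"tokens": torch.zeros(2, 8)}])
    assert default_data_step(bare)["tokens"].is_cuda

    # Anything else (e.g. a plain tensor) is returned unchanged; the restored
    # default does not move non-dict batches, which is the behavior the
    # reverted change had broadened via move_data_to_device.
    plain = iter([torch.zeros(2, 8)])
    assert not default_data_step(plain).is_cuda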