diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index af7126357a7a1..e63e3fc09325b 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -62,9 +62,10 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         DataT: The data moved to the device.
     """
     if parallel_state.get_context_parallel_world_size() > 1:
-        raise ValueError("Default data step is being used in a context parallel environment."
-            "Please define your own data step that appropriately slices the data for context parallel."
-        )
+        raise ValueError(
+            "Default data step is being used in a context parallel environment."
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
 
     match next(dataloader_iter):
         # If its wrapped in a tuple, unpack it.
@@ -74,7 +75,7 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         case batch:
             pass
 
-    return move_data_to_device(batch, torch.cuda.current_device())
+    return move_data_to_device(batch, torch.cuda.current_device())
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
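
For context, the error reformatted above tells users to supply their own data step whenever context parallelism is active. Below is a minimal sketch of what such an override might look like; the function name `cp_data_step`, the dict-of-tensors batch layout, and the naive even split along the sequence dimension are assumptions made for illustration, not part of this change.

```python
# A minimal sketch of a user-defined data step for context parallelism.
# Assumptions (not from this PR): the batch is a dict of tensors laid out
# as [batch, seq, ...], seq is divisible by the CP world size, and a plain
# even split along the sequence dimension is acceptable. Production CP
# setups often use a load-balanced (e.g. zig-zag) split instead.
from typing import Iterator

import torch
from megatron.core import parallel_state
from pytorch_lightning.utilities import move_data_to_device


def cp_data_step(dataloader_iter: Iterator) -> dict:
    batch = next(dataloader_iter)
    # Some strategies wrap the batch as (batch, batch_idx, dataloader_idx).
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size > 1:
        cp_rank = parallel_state.get_context_parallel_rank()
        # Keep only this rank's shard of the sequence dimension (dim=1);
        # 1-D tensors (e.g. scalars per sample) are passed through untouched.
        batch = {
            key: val.chunk(cp_size, dim=1)[cp_rank]
            if torch.is_tensor(val) and val.dim() > 1
            else val
            for key, val in batch.items()
        }

    return move_data_to_device(batch, torch.cuda.current_device())
```

A step like this would be wired in wherever `default_data_step` would otherwise be used, so each context-parallel rank only receives its own slice of the sequence.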