From 3f536828e042abea4142cc9163c742fdabb1ea72 Mon Sep 17 00:00:00 2001
From: jomitchellnv
Date: Mon, 8 Jul 2024 18:36:04 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: jomitchellnv
---
 nemo/lightning/megatron_parallel.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index af7126357a7a1..e63e3fc09325b 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -62,9 +62,10 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         DataT: The data moved to the device.
     """
     if parallel_state.get_context_parallel_world_size() > 1:
-        raise ValueError("Default data step is being used in a context parallel environment."
-            "Please define your own data step that appropriately slices the data for context parallel."
-        )
+        raise ValueError(
+            "Default data step is being used in a context parallel environment."
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
 
     match next(dataloader_iter):
         # If its wrapped in a tuple, unpack it.
@@ -74,7 +75,7 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
         case batch:
             pass
 
-        return move_data_to_device(batch, torch.cuda.current_device())
+    return move_data_to_device(batch, torch.cuda.current_device())
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
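
For readers hitting the ValueError introduced above: it asks you to define your own data step that slices data for context parallelism. The sketch below is one minimal way to do that, not NeMo's implementation; the function name cp_sliced_data_step, the explicit cp_rank/cp_world_size parameters, and the dict-of-[batch, seq, ...]-tensors batch layout are all assumptions. Each rank here takes a single contiguous sequence shard, whereas Megatron's context parallelism actually load-balances causal attention by assigning each rank two sequence chunks.

# Illustrative sketch only; not part of the patch or of NeMo's API.
from typing import Dict, Iterator

import torch


def cp_sliced_data_step(
    dataloader_iter: Iterator[Dict[str, torch.Tensor]],
    cp_rank: int,
    cp_world_size: int,
) -> Dict[str, torch.Tensor]:
    """Fetch one batch and keep only this context-parallel rank's sequence shard.

    Assumes every tensor in the batch is laid out as [batch, seq, ...] and that
    the sequence length divides evenly by cp_world_size.
    """
    batch = next(dataloader_iter)
    if isinstance(batch, tuple):  # unwrap (batch, num_batches)-style tuples
        batch = batch[0]

    sliced = {}
    for key, tensor in batch.items():
        seq_len = tensor.shape[1]
        if seq_len % cp_world_size != 0:
            raise ValueError(f"{key}: sequence length {seq_len} is not divisible by {cp_world_size}")
        shard = seq_len // cp_world_size
        # Keep the contiguous chunk of the sequence owned by this rank.
        sliced[key] = tensor[:, cp_rank * shard : (cp_rank + 1) * shard]
    return sliced


# Example: with a context-parallel world size of 2, rank 0 keeps the first half
# of every sequence in the batch.
batch = {"tokens": torch.arange(16).reshape(1, 16)}
print(cp_sliced_data_step(iter([batch]), cp_rank=0, cp_world_size=2)["tokens"].shape)
# torch.Size([1, 8])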