Wrap batch_sampler with _IndexBatchSamplerWrapper (NVIDIA#10934)
* wrap batch_sampler

Signed-off-by: Farhad Ramezanghorbani <[email protected]>

* Apply isort and black reformatting

Signed-off-by: farhadrgh <[email protected]>

* pass dataloader mode

* Apply isort and black reformatting

Signed-off-by: farhadrgh <[email protected]>

* pass dataloader mode

Signed-off-by: Farhad Ramezanghorbani <[email protected]>

* pass dataloader mode

Signed-off-by: Farhad Ramezanghorbani <[email protected]>

* resolve conflict

Signed-off-by: Farhad Ramezanghorbani <[email protected]>

* change import

Signed-off-by: Farhad Ramezanghorbani <[email protected]>

---------

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
Signed-off-by: farhadrgh <[email protected]>
Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
2 people authored and Hainan Xu committed Nov 5, 2024
1 parent 888b50c commit 697abdf
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 6 additions & 0 deletions nemo/lightning/data.py
```diff
@@ -19,6 +19,7 @@
 from typing import List, Literal, Optional
 
 import torch
+from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper
 from torch.utils.data import DataLoader, Dataset
 
 
@@ -139,6 +140,7 @@ def add_megatron_sampler(
     dataloader_type: Literal["single", "cyclic", "batch"] = "single",
     drop_last: bool = True,
     pad_samples_to_global_batch_size: bool = False,
+    dataloader_mode: Literal["train", "validation", "test", "predict"] = "train",
     rank: int = 0,
     world_size: int = 1,
     # data_sharding: bool = False
@@ -170,6 +172,7 @@ def add_megatron_sampler(
         pad_samples_to_global_batch_size (bool, optional): Whether to pad the last incomplete
             batch to the `global_batch_size` (defaults to False, only applies when
             `drop_last` is False).
+        dataloader_mode (Literal["train", "validation", "test", "predict"]): The mode of dataloader.
 
     Returns:
         DataLoader: A new DataLoader instance with the configured Megatron sampler.
@@ -214,6 +217,9 @@ def add_megatron_sampler(
     else:
         raise Exception(f'{dataloader_type} dataloader type is not supported.')
 
+    if dataloader_mode in ["test", "predict"]:
+        batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)  # BatchSampler wrapper to capture its indices
+
     return DataLoader(
         dataloader.dataset,
         batch_sampler=batch_sampler,
```
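Why the wrapper matters: `_IndexBatchSamplerWrapper` records the dataset indices of every batch it yields (Lightning exposes them through a `seen_batch_indices` attribute), which is what lets prediction writers map model outputs back to dataset rows. The sketch below is a minimal illustration of that idea under those assumptions, not Lightning's actual implementation; `IndexRecordingBatchSampler` is a hypothetical name.

```python
from typing import Iterator, List

from torch.utils.data import BatchSampler, SequentialSampler


class IndexRecordingBatchSampler:
    """Illustrative stand-in for Lightning's _IndexBatchSamplerWrapper."""

    def __init__(self, batch_sampler: BatchSampler) -> None:
        self._batch_sampler = batch_sampler
        self.seen_batch_indices: List[List[int]] = []

    def __iter__(self) -> Iterator[List[int]]:
        self.seen_batch_indices = []  # reset on every pass over the data
        for batch in self._batch_sampler:
            self.seen_batch_indices.append(batch)  # record before yielding
            yield batch

    def __len__(self) -> int:
        return len(self._batch_sampler)


# 10 samples, batch size 4, drop_last=False -> the short final batch [8, 9] is kept
wrapped = IndexRecordingBatchSampler(
    BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
)
for _ in wrapped:
    pass
print(wrapped.seen_batch_indices)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

After a predict pass, the recorded indices let downstream code align each output with the dataset row that produced it, even when batches arrive out of order or the last batch is short.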
5 changes: 2 additions & 3 deletions nemo/lightning/pytorch/plugins/data_sampler.py
```diff
@@ -44,7 +44,6 @@ def __init__(
         init_consumed_samples: int = 0,
         init_global_step: int = 0,
         output_log: bool = True,
-        drop_last: bool = True,
     ):
         self.seq_len = seq_len
         self.output_log = output_log
@@ -57,7 +56,6 @@ def __init__(
         self.if_first_step = 0
         self.prev_global_batch_size = None
         self.init_global_step = init_global_step
-        self.drop_last = drop_last
 
     def setup(self, global_rank: int) -> None:
         from nemo.lightning.data import setup_microbatch_calculator
@@ -80,7 +78,8 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0
             rampup_batch_size=self.rampup_batch_size,
             consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
             dataloader_type=self.dataloader_type,
-            drop_last=self.drop_last,
+            drop_last=mode not in ["test", "predict"],  # don't drop the incomplete batch in test and predict methods
+            dataloader_mode=mode,  # dataloader wrapped with nemo.lightning.data.WrappedDataLoader has mode attribute
             rank=data_parallel_rank,
             world_size=data_parallel_size,
         )
```
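The second file removes the user-facing `drop_last` constructor argument and derives it from the dataloader mode instead. A standalone, runnable sketch of that rule (plain PyTorch, not NeMo code): with 10 samples and micro-batches of 4, train/validation drop the trailing batch of 2, while test/predict keep it so every sample is covered and its indices can be captured by the wrapper added above.

```python
from torch.utils.data import BatchSampler, SequentialSampler

for mode in ["train", "validation", "test", "predict"]:
    drop_last = mode not in ["test", "predict"]  # the rule introduced by this commit
    batches = list(BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=drop_last))
    covered = sum(len(batch) for batch in batches)
    print(f"{mode:>10}: drop_last={drop_last}, samples covered={covered}/10")
```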
