[NeMo-UX] Fix some dataloading bugs (NVIDIA#9807) (NVIDIA#9850)
* reconfigure limit batches

* hacky WAR to modify num_consumed_samples depending on dataset type

* fix typo

* improve design

* Apply isort and black reformatting

* minor improvements

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: Anna Shors <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
3 people authored and Vivian Chen committed Aug 1, 2024
1 parent e12d357 commit 5518c72
Showing 4 changed files with 71 additions and 16 deletions.
74 changes: 59 additions & 15 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils import data
-from torch.utils.data import DataLoader
+from nemo.lightning.data import WrappedDataLoader

 from nemo.lightning.pytorch.plugins import MegatronDataSampler

@@ -121,24 +121,26 @@ def setup(self, stage: str = "") -> None:
         # ).build()

     def train_dataloader(self) -> TRAIN_DATALOADERS:
-        return self._create_dataloader(self._train_ds)
+        return self._create_dataloader(self._train_ds, mode='train')

     def val_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._validation_ds)
+        return self._create_dataloader(self._validation_ds, mode='validation')

     def test_dataloader(self) -> EVAL_DATALOADERS:
-        return self._create_dataloader(self._test_ds)
+        return self._create_dataloader(self._test_ds, mode='test')

-    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
+    def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader:
         self.init_global_step = self.trainer.global_step
-        return DataLoader(
-            dataset,
+        dataloader = WrappedDataLoader(
+            mode=mode,
+            dataset=dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
             collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate),
             **kwargs,
         )
+        return dataloader

     @property
     def gpt_dataset_config(self) -> "GPTDatasetConfig":
@@ -185,11 +187,53 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             consistency_check=False,
         )
         current_global_batch_size = num_microbatch_calculator.current_global_batch_size
-        '''pl_module.log(
-            "global_batch_size",
-            current_global_batch_size,
-            prog_bar=True,
-            rank_zero_only=True,
-            batch_size=1,
-        )'''
-        self.if_first_step = 1
+        self.data_sampler.if_first_step = 1
+
+    def reconfigure_limit_batches(self):
+        # Override limit_train_batches in terms of num of microbatches
+        self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train')
+        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+        self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val')
+
+    def _reconfigure_limit_batches(self, limit_batches, dataloader, mode):
+        """
+        Reconfigure trainer.limit_val_batches for pretraining
+        """
+        # Override limit_batches in terms of num microbatches and so there are limit_batches//num_micro_batches num of global batches
+        from megatron.core.num_microbatches_calculator import get_num_microbatches
+
+        if isinstance(limit_batches, int):
+            limit_batches *= get_num_microbatches()
+        else:
+            assert isinstance(limit_batches, float)
+            # Don't reconfigure if limit_batches is 0.0 or if there's no dataloader
+            if limit_batches == 0.0 or dataloader is None:
+                return
+            # len(dataloader) returns len as num of microbatches
+            dl_len_in_micro_batches = len(dataloader)
+            if len(dataloader) != float("inf"):
+                if limit_batches == 1.0:
+                    limit_batches = dl_len_in_micro_batches
+                else:
+                    limit_micro_batches = int(dl_len_in_micro_batches * limit_batches)
+                    if limit_micro_batches == 0 and limit_batches > 0.0:
+                        min_percentage = 1.0 / len(dataloader)
+                        raise MisconfigurationException(
+                            f"You requested to check {limit_batches} of the val_dataloader but"
+                            f" {limit_batches} * {len(dataloader)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_micro_batches < get_num_microbatches():
+                        limit_batches = get_num_microbatches()
+                    else:
+                        limit_batches = limit_batches - limit_batches % get_num_microbatches()
+
+        if mode == 'train':
+            self.trainer.limit_train_batches = limit_batches
+        else:
+            self.trainer.limit_val_batches = limit_batches
+
+        # Override num sanity steps to be a multiple of num of microbatches
+        self.trainer.num_sanity_val_steps *= get_num_microbatches()
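
To make the limit-batch conversion concrete, here is a small standalone sketch, not part of the commit, that illustrates the intended arithmetic with made-up numbers. num_microbatches stands in for Megatron's get_num_microbatches(), and the float branch rounds the resulting microbatch count down to a multiple of num_microbatches, which is what the in-diff comment describes.

    # Illustrative sketch only: mirrors the intent of _reconfigure_limit_batches
    # with hard-coded example values instead of Megatron's calculator.
    def reconfigured_limit(limit_batches, dl_len_in_micro_batches, num_microbatches):
        if isinstance(limit_batches, int):
            # integer limits count global batches, so scale up to microbatches
            return limit_batches * num_microbatches
        if limit_batches == 1.0:
            return dl_len_in_micro_batches
        limit_micro_batches = int(dl_len_in_micro_batches * limit_batches)
        if limit_micro_batches < num_microbatches:
            return num_microbatches
        # round down to a multiple of num_microbatches so validation never
        # stops partway through a global batch
        return limit_micro_batches - limit_micro_batches % num_microbatches

    print(reconfigured_limit(32, 1000, 16))    # 512 microbatches = 32 global batches
    print(reconfigured_limit(0.05, 1000, 16))  # 48 (50 rounded down to a multiple of 16)
    print(reconfigured_limit(1.0, 1000, 16))   # 1000 (use the whole dataloader)
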
9 changes: 9 additions & 0 deletions nemo/lightning/data.py
@@ -8,6 +8,7 @@
 from torch.utils.data import DataLoader, Dataset


+## TODO: remove? unused
 def create_dataloader(
     dataset: "Dataset", drop_last: bool = True, pad_samples_to_global_batch_size=False, **kwargs
 ) -> DataLoader:
@@ -127,6 +128,14 @@ def add_megatron_sampler(
     )


+class WrappedDataLoader(DataLoader):
+    """Wrapper around torch DataLoader which stores the dataloader mode"""
+
+    def __init__(self, mode="train", **dataloader_kwargs):
+        super().__init__(**dataloader_kwargs)
+        self.mode = mode
+
+
 # TODO: Replace this with megatron.core.data.data_samplers after we upgrade
 class BaseMegatronSampler:
     def __init__(
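
A quick usage sketch, not taken from the commit: WrappedDataLoader behaves exactly like torch.utils.data.DataLoader and simply remembers which split it serves, so downstream code can branch on loader.mode.

    import torch
    from torch.utils.data import TensorDataset
    from nemo.lightning.data import WrappedDataLoader

    ds = TensorDataset(torch.arange(8).float())
    dl = WrappedDataLoader(mode="validation", dataset=ds, batch_size=4)

    print(dl.mode)    # "validation"
    print(len(dl))    # 2, standard DataLoader behaviour is unchanged
    for batch in dl:  # iterates exactly like a normal DataLoader
        pass
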
3 changes: 2 additions & 1 deletion nemo/lightning/pytorch/plugins/data_sampler.py
@@ -43,12 +43,13 @@ def setup(self, global_rank: int) -> None:
     def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
         from nemo.lightning.data import add_megatron_sampler

+        mode = getattr(dataloader, 'mode', 'train')
         return add_megatron_sampler(
             dataloader,
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             rampup_batch_size=self.rampup_batch_size,
-            consumed_samples=self.init_consumed_samples,
+            consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
             dataloader_type=self.dataloader_type,
         )

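
The getattr fallback means an unwrapped DataLoader is still treated as a training loader and keeps its resume offset, while validation and test loaders report their mode and start from sample 0. A minimal sketch of that behaviour, assuming the WrappedDataLoader added in nemo/lightning/data.py above:

    from torch.utils.data import DataLoader
    from nemo.lightning.data import WrappedDataLoader

    plain = DataLoader(list(range(10)))
    wrapped = WrappedDataLoader(mode="validation", dataset=list(range(10)))

    # mirrors the getattr call in transform_dataloader
    print(getattr(plain, "mode", "train"))    # "train": init_consumed_samples is applied
    print(getattr(wrapped, "mode", "train"))  # "validation": consumed_samples forced to 0
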
1 change: 1 addition & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -207,6 +207,7 @@ def setup(self, trainer: pl.Trainer) -> None:
         if not self.data_sampler and hasattr(datamodule, "data_sampler"):
             self.data_sampler = datamodule.data_sampler
             self.data_sampler.setup(self.cluster_environment.global_rank())
+            datamodule.reconfigure_limit_batches()

         if self.data_sampler:
             self.data_sampler.connect(trainer)
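
Note that the new call assumes the attached datamodule implements reconfigure_limit_batches; a custom datamodule that exposes a data_sampler but not this method would raise an AttributeError here. A sketch of a more defensive variant, which is not what the commit does:

    # hypothetical guard, not part of the commit
    if hasattr(datamodule, "reconfigure_limit_batches"):
        datamodule.reconfigure_limit_batches()
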
