Allow MegatronPretrainingRandomSampler to do multi-epoch training (#8239)

* Initial commit of sampler updates and bug fixes

Signed-off-by: Daniel Egert <[email protected]>

* Added extra assert and some code cleanups

Signed-off-by: Daniel Egert <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor formatting changes

Signed-off-by: Daniel Egert <[email protected]>

---------

Signed-off-by: Daniel Egert <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
2 people authored and akoumpa committed Jan 30, 2024
1 parent b714028 commit 3d2ce15
Showing 2 changed files with 39 additions and 37 deletions.
@@ -134,6 +134,7 @@ def __init__(
drop_last: bool = True,
global_batch_size: Optional[int] = None,
pad_samples_to_global_batch_size: Optional[bool] = False,
seed: int = 0,
) -> None:
super().__init__(
total_samples=total_samples,
@@ -146,12 +147,19 @@ def __init__(
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
assert (
pad_samples_to_global_batch_size == False
not pad_samples_to_global_batch_size
), "`MegatronPretrainingRandomSampler` does not support sample padding"
if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
raise RuntimeError(
"`MegatronPretrainingRandomSampler` does not support drop_last=False when micro_batch_size * data_parallel_size > 1. \
please reduce your MBS and data parallelism to 1 if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
)
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
self.seed = seed

def __len__(self):
num_available_samples: int = self.total_samples
active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
if self.global_batch_size is not None:
if self.drop_last:
return num_available_samples // self.global_batch_size
@@ -175,7 +183,7 @@ def __iter__(self):
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
g.manual_seed(self.seed + self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

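The hunk above replaces the bare `self.epoch` seed with `self.seed + self.epoch`. As a quick illustration, here is a standalone sketch (not NeMo code; `epoch_permutation` and its arguments are hypothetical names that mirror the sampler's attributes) of why this gives a shuffle that is reproducible for a given `(seed, epoch)` pair yet different from one epoch to the next:

```python
import torch

def epoch_permutation(bucket_size: int, seed: int, epoch: int) -> list:
    # Seed a private generator with seed + epoch, once per epoch, so that
    # resuming training mid-epoch replays the same shuffle order.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(bucket_size, generator=g).tolist()

print(epoch_permutation(8, seed=0, epoch=0))  # deterministic for a given (seed, epoch)
print(epoch_permutation(8, seed=0, epoch=0))  # identical to the line above
print(epoch_permutation(8, seed=0, epoch=1))  # reshuffled for the next epoch
```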
@@ -94,19 +94,20 @@ def __init__(
self.data_parallel_size: int = data_parallel_size
self.drop_last: bool = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size

self.update_global_batch_size(global_batch_size)

def update_global_batch_size(self, new_global_batch_size: int) -> None:
"""Update the global batch size."""
self._global_batch_size = new_global_batch_size
if self._global_batch_size % (self.micro_batch_size * self.data_parallel_size) != 0:
if self._global_batch_size % self.micro_batch_times_data_parallel_size != 0:
raise RuntimeError(
f"`global_batch_size` ({self._global_batch_size}) is not divisible by "
f"`micro_batch_size ({self.micro_batch_size}) x data_parallel_size "
f"({self.data_parallel_size})`"
)
self._num_micro_batches = self._global_batch_size // (self.micro_batch_size * self.data_parallel_size)
self._num_micro_batches = self._global_batch_size // self.micro_batch_times_data_parallel_size
self._global_batch_size_on_this_data_parallel_rank = self._num_micro_batches * self.micro_batch_size

@property
@@ -176,10 +177,11 @@ class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
# I omit those two arguments.
# commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
#
# NOTE (degert): I have re-written this class somewhat as previous implementation relied on the
# base class constructor which would have thrown in the case of consumed_samples >= total_samples
# which this class was designed to do, as that is how it implicitly calculates the current epoch
# NOTE (degert): I have re-written this class somewhat to give the length correctly when consumed_samples
# are larger than total_samples, which happens with epochs > 1 training when using this Sampler
# I have also added an explicit seed which allows us to remove Dataset-side shuffling in Nemo-Aligner
#
# This class does not currently work with pad_samples_to_global_batch_size=True
def __init__(
self,
total_samples: int,
@@ -192,32 +194,26 @@ def __init__(
pad_samples_to_global_batch_size: bool = False,
seed: int = 0,
) -> None:

# Sanity checks.
if total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(total_samples))
if micro_batch_size <= 0:
raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
global_batch_size=global_batch_size,
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
assert (
not pad_samples_to_global_batch_size
), "`MegatronPretrainingRandomBatchSampler` does not support sample padding"
if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
raise RuntimeError(
"data_parallel_rank should be smaller than data size, but {} >= {}".format(
data_parallel_rank, data_parallel_size
)
"`MegatronPretrainingRandomBatchSampler` does not support drop_last=False when micro_batch_size * data_parallel_size > 1. \
please reduce your MBS and data parallelism to 1 if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
)

self.total_samples: int = total_samples
self.consumed_samples: int = consumed_samples
self.micro_batch_size: int = micro_batch_size
self.data_parallel_rank: int = data_parallel_rank
self.data_parallel_size: int = data_parallel_size
self.drop_last: bool = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.seed = seed

self.update_global_batch_size(global_batch_size)
self.last_batch_size = self.total_samples % self._global_batch_size
self.seed = seed

def __len__(self) -> int:
"""Length of Random Batch Sampler.
@@ -226,10 +222,8 @@ def __len__(self) -> int:
When `rampup_batch_size` is enabled, the return value can be not exactly precise.
"""
active_total_samples = self.total_samples - self.last_batch_size
num_available_samples = (
active_total_samples * (1 + (self.consumed_samples // active_total_samples))
) - self.consumed_samples
active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
@@ -239,10 +233,10 @@ def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % (self.micro_batch_size * self.data_parallel_size) == 0
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

# data sharding and random sampling
bucket_size = (self.total_samples // (self.micro_batch_size * self.data_parallel_size)) * self.micro_batch_size
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

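For readers following the new length and epoch bookkeeping, here is a hedged, self-contained walk-through of the arithmetic the updated `__len__` and `__iter__` perform (plain Python with made-up numbers; the real batch sampler additionally requires `current_epoch_samples` to be divisible by `micro_batch_size * data_parallel_size`):

```python
# Illustrative numbers only: 1000 samples, global batch 64, drop_last=True,
# and consumed_samples already past the first epoch.
total_samples = 1000
global_batch_size = 64
consumed_samples = 1152                                    # 18 full global batches drawn so far

last_batch_size = total_samples % global_batch_size        # 1000 % 64 = 40
active_total_samples = total_samples - last_batch_size     # 960 usable samples per epoch

epoch = consumed_samples // active_total_samples                       # 1152 // 960 = 1 -> second epoch
current_epoch_samples = consumed_samples % active_total_samples        # 1152 % 960 = 192
num_available_samples = active_total_samples - current_epoch_samples   # 960 - 192 = 768

print(epoch, num_available_samples // global_batch_size)   # 1, 12 global batches left this epoch
```

The `if self.drop_last else 0` branch in the new `__len__` is what keeps the trailing partial batch counted when `drop_last=False`; with `drop_last=True`, as sketched here, the remainder is simply excluded from each epoch.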
