Add config option for FP32 embedding grads (#8953)
* Add config option for FP32 embedding grads (#8946)

Signed-off-by: Tim Moon <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ericharper <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: ericharper <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: ericharper <[email protected]>
Signed-off-by: Marc Romeyn <[email protected]>
4 people authored and marcromeyn committed Jun 7, 2024
1 parent 4ef4b75 commit f56141b
Showing 2 changed files with 56 additions and 37 deletions.
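
The change reads a single new key, `with_fp32_embedding_grads`, which defaults to `True`. A minimal sketch of the lookup pattern used in the hunks below, with a hypothetical config fragment (only the key name and its default are taken from this commit; where the key sits in a full NeMo YAML file is an assumption):

```python
from omegaconf import OmegaConf

# Hypothetical model-config fragment. Only the key name and its default of True
# come from the diff below; its exact location in a full NeMo config is assumed.
model_cfg = OmegaConf.create(
    {
        "share_embeddings_and_output_weights": True,
        "with_fp32_embedding_grads": False,  # opt out of FP32 embedding-grad handling
    }
)

# Same lookup pattern as the diff: a missing key falls back to True.
print(model_cfg.get("with_fp32_embedding_grads", True))  # prints: False
```
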
File 1 of 2:
@@ -549,6 +549,7 @@ def configure_optimizers(self):
if self.with_distributed_adam and not self.use_mcore_dist_optim:

# Special handling for embedding grads
+ with_fp32_embedding_grads = self.cfg.get('with_fp32_embedding_grads', True)
modules = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
module = modules[0] # first virtual rank has the embeddings
@@ -558,7 +559,7 @@ def configure_optimizers(self):
word_embeddings = (
module.shared_embedding_or_output_weight() if self.mcore_gpt else module.word_embeddings_weight()
)
- word_embeddings._with_fp32_optimizer = True
+ word_embeddings._with_fp32_optimizer = with_fp32_embedding_grads
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.get(
'share_embeddings_and_output_weights', True
):
@@ -573,7 +574,7 @@ def configure_optimizers(self):
else:
position_embeddings = module.position_embeddings_weight()
if position_embeddings is not None:
- position_embeddings._with_fp32_optimizer = True
+ position_embeddings._with_fp32_optimizer = with_fp32_embedding_grads

# Handle case where embeddings are used in output layer
if parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.cfg.get(
@@ -583,7 +584,7 @@ def configure_optimizers(self):
word_embeddings = (
module.shared_embedding_or_output_weight() if self.mcore_gpt else module.word_embeddings_weight()
)
- word_embeddings._with_fp32_optimizer = True
+ word_embeddings._with_fp32_optimizer = with_fp32_embedding_grads
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
word_embeddings._disable_greedy_grad_copy = not self.megatron_amp_O2
word_embeddings._disable_overlap_grad_sync = True
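
The hunks above only tag the embedding (and shared output) weights with a private attribute; the component that actually reads `_with_fp32_optimizer` is the distributed Adam optimizer, which is not part of this diff. A hedged sketch of how such a per-parameter flag could be consumed (the reducer below is hypothetical, not NeMo's implementation):

```python
import torch

def grad_reduce_dtype(param: torch.nn.Parameter) -> torch.dtype:
    # Hypothetical consumer of the flag set in the hunks above: parameters tagged
    # with _with_fp32_optimizer keep FP32 gradient handling, everything else
    # follows the usual reduced-precision path.
    if getattr(param, "_with_fp32_optimizer", False):
        return torch.float32
    return torch.bfloat16

embedding = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)
embedding.weight._with_fp32_optimizer = True  # same tagging pattern as the diff
print(grad_reduce_dtype(embedding.weight))    # torch.float32
```
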
File 2 of 2:
@@ -190,9 +190,10 @@ def configure_optimizers(self):
param._disable_overlap_grad_sync = True

# Make sure embedding grads are reduced in FP32
+ with_fp32_embedding_grads = self.cfg.get('with_fp32_embedding_grads', True)
for name, param in self.named_parameters():
if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
- param._with_fp32_optimizer = True
+ param._with_fp32_optimizer = with_fp32_embedding_grads

return super().configure_optimizers()
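
For context on why embedding and output-layer grads get this special treatment: their gradients are sums of many small per-token contributions, and low-precision accumulation can stall once the running total dwarfs each addend. A self-contained illustration of that effect in BF16 (not taken from this commit):

```python
import torch

# 20,000 small gradient-like contributions of 1e-4 each.
contribs = torch.full((20_000,), 1e-4)

# Naive BF16 running sum: once the total is large enough, each 1e-4 addend falls
# below half a ULP and the sum stops growing.
bf16_sum = torch.zeros((), dtype=torch.bfloat16)
for x in contribs.to(torch.bfloat16):
    bf16_sum = bf16_sum + x

fp32_sum = contribs.sum()  # FP32 reference, roughly 2.0

print(f"fp32: {fp32_sum.item():.4f}  bf16: {bf16_sum.float().item():.4f}")
# The FP32 sum lands near 2.0 while the naive BF16 sum stalls far below it.
```
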

@@ -346,8 +347,8 @@ def _execute_fwd_bwd_function(self, data_iterator, forward_only, tensor_shape, d

def fwd_bwd_step(self, dataloader_iter, forward_only):
"""
- Dataloader produces a global batch which is turned into a list of microbatches.
- The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
+ Dataloader produces a global batch which is turned into a list of microbatches.
+ The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""
# Get seq length of batch
tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
@@ -361,12 +362,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
- Our dataloaders produce a micro-batch and then we fetch
- a number of microbatches depending on the global batch size and model parallel size
- from the dataloader to produce a list of microbatches.
- Batch should be a list of microbatches and those microbatches should on CPU.
- Microbatches are then moved to GPU during the pipeline.
- The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
+ Our dataloaders produce a micro-batch and then we fetch
+ a number of microbatches depending on the global batch size and model parallel size
+ from the dataloader to produce a list of microbatches.
+ Batch should be a list of microbatches and those microbatches should on CPU.
+ Microbatches are then moved to GPU during the pipeline.
+ The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""
# we zero grads here because we also call backward in the megatron fwd/bwd functions
self._optimizer.zero_grad()
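
The docstring above describes turning a global batch from the dataloader into a list of CPU-resident microbatches that the megatron-core fwd/bwd functions then consume. A minimal sketch of that split, with illustrative names and sizes (none of these identifiers come from this file):

```python
import torch

def split_into_microbatches(global_batch: dict, micro_batch_size: int) -> list:
    """Slice every tensor in a CPU-resident global batch into micro-batches."""
    num_samples = next(iter(global_batch.values())).size(0)
    return [
        {name: t[start : start + micro_batch_size] for name, t in global_batch.items()}
        for start in range(0, num_samples, micro_batch_size)
    ]

# Global batch of 8 samples -> 4 microbatches of 2, all still on CPU; the
# pipeline moves each microbatch to GPU only when it is actually consumed.
global_batch = {
    "text_enc": torch.zeros(8, 512, dtype=torch.long),
    "text_dec": torch.zeros(8, 128, dtype=torch.long),
}
microbatches = split_into_microbatches(global_batch, micro_batch_size=2)
print(len(microbatches), microbatches[0]["text_enc"].shape)  # 4 torch.Size([2, 512])
```
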
@@ -408,7 +409,11 @@ def training_step(self, dataloader_iter):
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, rank_zero_only=True, batch_size=1)
self.log(
- 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
+ 'global_step',
+ self.trainer.global_step,
+ prog_bar=True,
+ rank_zero_only=True,
+ batch_size=1,
)
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
@@ -432,21 +437,21 @@ def max_encoder_seq_length(self) -> int:
return self.cfg.seq_length

def backward(self, *args, **kwargs):
- """ LightningModule hook to do backward.
- We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core.
- No need to call it here.
+ """LightningModule hook to do backward.
+ We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core.
+ No need to call it here.
"""
return

def optimizer_zero_grad(self, *args, **kwargs):
- """ LightningModule hook to zero grad.
- We want this to do nothing as we are zeroing grads during the training_step.
+ """LightningModule hook to zero grad.
+ We want this to do nothing as we are zeroing grads during the training_step.
"""
return

def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks.
- Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188
+ Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/model/distributed.py#L188
"""
# Bucketize and all-reduce
buckets = {}
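
The hunk is cut off right after `buckets = {}`. In the Megatron-LM code referenced in the docstring, gradients are grouped into buckets, flattened, all-reduced once per bucket over the data-parallel group, and copied back. A hedged sketch of that pattern follows; the flatten/unflatten helpers are real `torch._utils` functions used by Megatron-LM, but this is not the code from this file, and the real reduction runs over the data-parallel process group rather than the default group:

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients_sketch(model: torch.nn.Module) -> None:
    # Assumes torch.distributed has already been initialized.
    # Bucketize grads by dtype so each bucket can be flattened into one tensor.
    buckets: dict = {}
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            buckets.setdefault(param.grad.dtype, []).append(param.grad.data)

    # One coalesced all-reduce per bucket, then scatter the results back.
    for grads in buckets.values():
        coalesced = _flatten_dense_tensors(grads)
        coalesced /= dist.get_world_size()  # the real code averages over the data-parallel group
        dist.all_reduce(coalesced)
        for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(synced)
```
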
@@ -768,10 +773,16 @@ def _test_validation_epoch_end(self, step_outputs, prefix):
def on_validation_epoch_end(self):
# FIXME: do we need this? 'global_step' is logged in training_step
self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1)
- return self._test_validation_epoch_end(step_outputs=self.validation_step_outputs, prefix="val",)
+ return self._test_validation_epoch_end(
+ step_outputs=self.validation_step_outputs,
+ prefix="val",
+ )

def on_test_epoch_end(self):
- return self._test_validation_epoch_end(step_outputs=self.test_step_outputs, prefix="test",)
+ return self._test_validation_epoch_end(
+ step_outputs=self.test_step_outputs,
+ prefix="test",
+ )

def loss_func(self, loss_mask, tokens_loss):
"""
@@ -784,7 +795,7 @@ def loss_func(self, loss_mask, tokens_loss):
return loss

def process_micro_batch(self, micro_batch):
- """ Micro batch returned by MegatronT5 dataloader"""
+ """Micro batch returned by MegatronT5 dataloader"""

data_b = micro_batch

@@ -800,8 +811,8 @@ def process_micro_batch(self, micro_batch):
return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask

def _process_global_batch_without_megatron_batch_sampler(self, global_batch, tokenizer=None):
- """ Prepares the global batch for megatron-core fwd/bwd functions.
- Global batch is a list of micro batches.
+ """Prepares the global batch for megatron-core fwd/bwd functions.
+ Global batch is a list of micro batches.
"""
tokenizer = self.tokenizer if tokenizer is None else tokenizer
text_enc_list = []
@@ -1076,7 +1087,11 @@ def dummy():
# Setting it to anything else will cause hanging due to tensor shape mismatches.
output_tensor = fwd_bwd_func(
forward_step_func=forward_step_func,
- data_iterator=iter([batch_for_pipeline,]),
+ data_iterator=iter(
+ [
+ batch_for_pipeline,
+ ]
+ ),
model=[self.enc_dec_model],
forward_only=True,
num_microbatches=1,
@@ -1242,7 +1257,11 @@ def dummy():

output_tensor = fwd_bwd_func(
forward_step_func=forward_step_func,
- data_iterator=iter([batch_for_pipeline,]),
+ data_iterator=iter(
+ [
+ batch_for_pipeline,
+ ]
+ ),
model=[self.enc_dec_model],
forward_only=True,
num_microbatches=1,
@@ -1322,21 +1341,21 @@ def dummy():
# choose top-k hypotheses with length penalty applied
len_penalties = compute_beam_search_len_penalty(decoder_seq_lengths, beam_alpha)
scores = scores / len_penalties
- scores, indices = sample_token_fn(scores.view(-1, beam_size ** 2), dim=1, log_softmax=False)
+ scores, indices = sample_token_fn(scores.view(-1, beam_size**2), dim=1, log_softmax=False)
scores = scores.view(-1, 1) * len_penalties

# select predicted sequences which correspond to the chosen hypotheses
predicted_tokens_dec = predicted_tokens_dec.unsqueeze(1).repeat(1, beam_size, 1)
predicted_tokens_dec = torch.cat((predicted_tokens_dec, token_ids.unsqueeze(2)), dim=2)
- predicted_tokens_dec = predicted_tokens_dec.view(batch_size, beam_size ** 2, -1)
+ predicted_tokens_dec = predicted_tokens_dec.view(batch_size, beam_size**2, -1)
p_len = predicted_tokens_dec.size(2)
predicted_tokens_dec_ids = indices.unsqueeze(2).repeat(1, 1, p_len)
predicted_tokens_dec = predicted_tokens_dec.gather(1, predicted_tokens_dec_ids).view(-1, p_len)

# select logits which correspond to the chosen hypotheses
predicted_log_probs = predicted_log_probs.unsqueeze(1).repeat(1, beam_size, 1)
predicted_log_probs = torch.cat((predicted_log_probs, log_probs.unsqueeze(2)), dim=2)
- predicted_log_probs = predicted_log_probs.view(batch_size, beam_size ** 2, -1)
+ predicted_log_probs = predicted_log_probs.view(batch_size, beam_size**2, -1)
predicted_log_probs = predicted_log_probs.gather(1, predicted_tokens_dec_ids[:, :, 1:]).view(
-1, p_len - 1
)
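
`compute_beam_search_len_penalty` is not defined in this diff; helpers like it typically implement the GNMT-style penalty lp(L) = ((5 + L) / 6) ** alpha, which is an assumption here. Dividing the (negative) log-prob scores by it before the top-k over the beam_size**2 candidates offsets the bias toward short hypotheses:

```python
import torch

def gnmt_length_penalty(lengths: torch.Tensor, alpha: float) -> torch.Tensor:
    # Assumed GNMT-style formula; the diff only shows the penalty being applied,
    # not how the NeMo helper computes it.
    return ((5.0 + lengths.float()) / 6.0) ** alpha

decoder_seq_lengths = torch.tensor([4.0, 8.0, 16.0])
scores = torch.tensor([-2.0, -3.5, -6.0])  # summed log-probs; longer outputs score lower

for alpha in (0.0, 0.6, 1.0):
    penalized = scores / gnmt_length_penalty(decoder_seq_lengths, alpha)
    print(alpha, [round(s, 3) for s in penalized.tolist()])
# alpha = 0 leaves the scores unchanged; larger alpha shrinks the magnitude of the
# long hypotheses' scores more, making them more competitive in the top-k step.
```
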
@@ -1482,16 +1501,16 @@ def complete(self, request: Dict):
return response

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
- """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
- When using pipeline parallelism, we need the global batch to remain on the CPU,
- since the memory overhead will be too high when using a large number of microbatches.
- Microbatches are transferred from CPU to GPU inside the pipeline.
+ """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
+ When using pipeline parallelism, we need the global batch to remain on the CPU,
+ since the memory overhead will be too high when using a large number of microbatches.
+ Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
- """ Certain trainer configurations can break training.
- Here we try to catch them and raise an error.
+ """Certain trainer configurations can break training.
+ Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(
@@ -1502,8 +1521,7 @@ def list_available_models(self):
pass

def build_model_parallel_config(self):
- """ Hidden size needs to be set from the cfg.encoder for the pipeline schedule.
- """
+ """Hidden size needs to be set from the cfg.encoder for the pipeline schedule."""

model_parallel_config = super().build_model_parallel_config()
try: