Commit

Apply isort and black reformatting
Signed-off-by: dimapihtar <[email protected]>
dimapihtar committed May 15, 2024
1 parent 57d7672 commit d4eec44
Showing 1 changed file with 97 additions and 71 deletions.
168 changes: 97 additions & 71 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
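For context, most of the churn below is mechanical: black's "magic trailing comma" explodes any call that already ends in a trailing comma into one argument per line, and recent black releases (when the target Python version allows the syntax) rewrite stacked context managers as a single parenthesized with-block; most of the remaining hunks are whitespace-only docstring normalization. The snippet below is an illustrative sketch of the two main rewrites, not code from this commit, and the names are made up.

# Before (hypothetical input; the with-statement is assumed to exceed the configured line length):
log('peak_memory_usage', mem_reserved, prog_bar=True, rank_zero_only=True, batch_size=1,)
with torch.no_grad(), torch.autocast('cuda', dtype=dtype), warnings.catch_warnings():
    do_export(model)

# After black: the magic trailing comma forces one argument per line,
# and the stacked context managers become a parenthesized with-block.
log(
    'peak_memory_usage',
    mem_reserved,
    prog_bar=True,
    rank_zero_only=True,
    batch_size=1,
)
with (
    torch.no_grad(),
    torch.autocast('cuda', dtype=dtype),
    warnings.catch_warnings(),
):
    do_export(model)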
@@ -174,7 +174,7 @@ def forward(self, **kwargs):
the superclass by the square root of the hidden size specified in the configuration.
"""
embeddings = super().forward(**kwargs)
- return embeddings * torch.tensor(self.config.hidden_size ** 0.5, dtype=embeddings.dtype)
+ return embeddings * torch.tensor(self.config.hidden_size**0.5, dtype=embeddings.dtype)


class MegatronGPTExportableModel(torch.nn.Module, Exportable):
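The one-line change in the hunk above is black's power-operator rule: spaces around `**` are dropped when both operands are simple (names, numbers, attribute access) and kept otherwise. Roughly, as a standalone sketch (not from this commit):

hidden_size = 1024
scale = hidden_size**0.5        # simple operands: black hugs the operator
alt = (2 * hidden_size) ** 0.5  # composite operand: spaces are kept

The surrounding method itself just scales the superclass embeddings by sqrt(hidden_size), as its docstring says.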
@@ -196,11 +196,14 @@ def __init__(self, model):

def forward(self, tokens, position_ids, attention_mask):
if self.fp8_enabled and HAVE_TE:
- with transformer_engine.pytorch.onnx_export(self.fp8_enabled), transformer_engine.pytorch.fp8_autocast(
- enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe
- ), torch.no_grad(), torch.inference_mode(), torch.autocast(
- 'cuda', dtype=self.dtype
- ), warnings.catch_warnings():
+ with (
+ transformer_engine.pytorch.onnx_export(self.fp8_enabled),
+ transformer_engine.pytorch.fp8_autocast(enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe),
+ torch.no_grad(),
+ torch.inference_mode(),
+ torch.autocast('cuda', dtype=self.dtype),
+ warnings.catch_warnings(),
+ ):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
@@ -211,9 +214,12 @@ def forward(self, tokens, position_ids, attention_mask):
labels=None,
)
else:
- with torch.no_grad(), torch.inference_mode(), torch.autocast(
- 'cuda', dtype=self.dtype
- ), warnings.catch_warnings():
+ with (
+ torch.no_grad(),
+ torch.inference_mode(),
+ torch.autocast('cuda', dtype=self.dtype),
+ warnings.catch_warnings(),
+ ):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
@@ -315,7 +321,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
raise ValueError(
'Expert parallelism is currently not supporting Apex distributed optimizer, use Mcore distributed optimizer instead'
)

if self.cfg.get('num_layers', 12) % self.cfg.get('pipeline_model_parallel_size', 1) != 0:
raise ValueError(
f"num_layers ({self.cfg.get('num_layers', 12)}) should be divisible by "
@@ -515,7 +521,7 @@ def setup_optimizer_param_groups(self):
self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)

def setup_mcore_distributed_parallel(self):
"""Set up mcore distributed data parallel """
"""Set up mcore distributed data parallel"""
if self.with_distributed_adam and self.use_mcore_dist_optim:
config = get_model_config(self.model[0])
ddp_config = DistributedDataParallelConfig(
@@ -647,7 +653,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
if self.validation_param_sync_overlap:
param_sync_func = self.sync_overlap_parameters
elif not self.use_mcore_dist_optim:
- no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
+ no_sync_func = partial(
+ self._optimizer.no_sync,
+ greedy_grad_copy=self.megatron_amp_O2,
+ )
grad_sync_func = self.reduce_overlap_gradients
param_sync_func = self.sync_overlap_parameters
else:
@@ -750,9 +759,9 @@ def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
- We pass the dataloader iterator function to the micro-batch scheduler.
- The input batch to each micro-batch is fetched using the dataloader function
- in the micro-batch fwd function.
+ We pass the dataloader iterator function to the micro-batch scheduler.
+ The input batch to each micro-batch is fetched using the dataloader function
+ in the micro-batch fwd function.
"""
# Initialize userbuffer communicators.
if self.initialize_ub:
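The docstring above describes the control flow: the scheduler is handed the dataloader iterator itself, and each micro-batch's forward step pulls its own batch from that iterator. A framework-free sketch of that pattern, assuming a user-supplied forward_step (this is not the megatron-core scheduler):

def run_microbatches(forward_step, data_iterator, model, num_microbatches):
    # The scheduler never touches a batch directly; forward_step calls
    # next(data_iterator) internally for each micro-batch.
    losses = []
    for _ in range(num_microbatches):
        losses.append(forward_step(data_iterator, model))
    return sum(losses) / num_microbatches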
@@ -883,7 +892,11 @@ def training_step(self, dataloader_iter):
if self.log_memory_usage:
mem_reserved = torch.cuda.max_memory_reserved()
self.log(
- 'peak_memory_usage', mem_reserved, prog_bar=True, rank_zero_only=True, batch_size=1,
+ 'peak_memory_usage',
+ mem_reserved,
+ prog_bar=True,
+ rank_zero_only=True,
+ batch_size=1,
)

## logging
@@ -907,20 +920,29 @@ def training_step(self, dataloader_iter):
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, rank_zero_only=True, batch_size=1)
self.log(
- 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
+ 'global_step',
+ self.trainer.global_step,
+ prog_bar=True,
+ rank_zero_only=True,
+ batch_size=1,
)

consumed_samples = self._compute_consumed_samples_after_training_step()
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
- 'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1,
+ 'consumed_samples',
+ consumed_samples,
+ prog_bar=True,
+ rank_zero_only=True,
+ batch_size=1,
)

if self.rampup_batch_size:
self.prev_global_batch_size = current_global_batch_size
self.prev_consumed_samples = consumed_samples
num_microbatch_calculator.update(
- consumed_samples=consumed_samples, consistency_check=False,
+ consumed_samples=consumed_samples,
+ consistency_check=False,
)
current_global_batch_size = num_microbatch_calculator.current_global_batch_size
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)
@@ -929,20 +951,20 @@ def training_step(self, dataloader_iter):
return loss_mean

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core.
No need to call it here.
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core.
No need to call it here.
"""
return

def optimizer_zero_grad(self, *args, **kwargs):
""" LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
return

def _append_sequence_parallel_module_grads(self, module, grads):
""" Helper method for allreduce_sequence_parallel_gradients"""
"""Helper method for allreduce_sequence_parallel_gradients"""

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False) or getattr(
@@ -960,9 +982,9 @@ def _append_sequence_parallel_module_grads(self, module, grads):
grads.append(grad.data)

def allreduce_sequence_parallel_gradients(self):
""" All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used.
Modified from megatron-lm:
https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425
"""All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used.
Modified from megatron-lm:
https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425
"""

grads = []
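The docstring above gives the idea: parameters tagged `sequence_parallel` (layernorm weights, for instance) accumulate partial gradients on each tensor-parallel rank, so those gradients have to be summed across the tensor-model-parallel group before the optimizer step. A simplified sketch assuming a process group handle is available; the actual code coalesces the gradients into one buffer before reducing:

import torch.distributed as dist

def allreduce_sequence_parallel_grads(parameters, group):
    # Sum gradients of sequence-parallel parameters across the tensor-parallel group.
    for param in parameters:
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, group=group)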
@@ -980,8 +1002,7 @@ def allreduce_sequence_parallel_gradients(self):
buf.copy_(synced)

def allreduce_fsdp_sharding_omitted_gradients(self):
""" All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain).
"""
"""All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain)."""
assert isinstance(self.model, torch.nn.Module)
grads = []
for param in self.model.parameters():
@@ -1028,16 +1049,16 @@ def allreduce_first_last_embeddings(self):
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())

def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]:
""" Convert data iterator into form expected by Megatron
With interleaved pipeline parallelism, Megatron expects a
list of one data iterator per model chunk. Each model
chunk independently gets data from its data iterator, so
we need to interact with the data iterator multiple times
for each microbatch step. Instead of incorporating this
logic into the data loader, we cache the iterator's output
to the first model chunk and reuse it in the other model
chunks.
"""Convert data iterator into form expected by Megatron
With interleaved pipeline parallelism, Megatron expects a
list of one data iterator per model chunk. Each model
chunk independently gets data from its data iterator, so
we need to interact with the data iterator multiple times
for each microbatch step. Instead of incorporating this
logic into the data loader, we cache the iterator's output
to the first model chunk and reuse it in the other model
chunks.
"""

if not isinstance(self.model, list) or len(self.model) == 1:
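The docstring above covers the interleaved-pipeline case: every virtual model chunk needs its own iterator, but only the first chunk should actually consume the dataloader, with the other chunks replaying exactly the same micro-batches. A toy sketch of that caching pattern (illustrative only, not the NeMo implementation; it assumes the schedule always advances chunk 0 first, as the interleaved schedule does):

import queue
from typing import Iterator, List

def make_chunk_iterators(source: Iterator, num_chunks: int) -> List[Iterator]:
    # Chunk 0 drains the real iterator and copies each item into per-chunk
    # queues; chunks 1..N-1 replay the cached items in the same order.
    caches = [queue.Queue() for _ in range(num_chunks - 1)]

    def primary() -> Iterator:
        for item in source:
            for cache in caches:
                cache.put(item)
            yield item

    def replica(idx: int) -> Iterator:
        while True:
            yield caches[idx].get()

    return [primary()] + [replica(i) for i in range(num_chunks - 1)]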
@@ -1329,10 +1350,10 @@ def id_func(output_tensor):

def validation_step(self, dataloader_iter, dataloader_idx=0):
"""
- Our dataloaders produce a micro-batch and then we fetch
- a number of microbatches depending on the global batch size and model parallel size
- from the dataloader to produce a list of microbatches.
- The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
+ Our dataloaders produce a micro-batch and then we fetch
+ a number of microbatches depending on the global batch size and model parallel size
+ from the dataloader to produce a list of microbatches.
+ The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
"""
mode = 'test' if self.trainer.testing else 'val'
# Initialize userbuffer communicators.
@@ -1393,7 +1414,9 @@ def on_validation_epoch_end(self):
if self.loss_broadcast_src_rank is None:
self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank()
torch.distributed.broadcast(
- averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
+ averaged_loss,
+ self.loss_broadcast_src_rank,
+ group=parallel_state.get_pipeline_model_parallel_group(),
)

self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1)
@@ -1498,7 +1521,10 @@ def build_train_valid_test_datasets(self):
dataset_type = MockGPTDataset if mock_dataset else GPTDataset

self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
- dataset_type, train_valid_test_num_samples, is_dataset_built_on_rank, dataset_config,
+ dataset_type,
+ train_valid_test_num_samples,
+ is_dataset_built_on_rank,
+ dataset_config,
).build()

if self._train_ds is not None:
@@ -1708,16 +1734,16 @@ def list_available_models(self):
return None

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
""" PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
""" Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(
@@ -1794,9 +1820,9 @@ def on_load_checkpoint(self, checkpoint) -> None:

def on_validation_model_zero_grad(self) -> None:
"""
- Skip gradient zeroing at the beginning of validation routine.
- This is needed when overlapping the AllGather of the updated parameters with the following valdation step.
- """
+ Skip gradient zeroing at the beginning of validation routine.
+ This is needed when overlapping the AllGather of the updated parameters with the following valdation step.
+ """
if not self.validation_param_sync_overlap:
super().on_validation_model_zero_grad()

@@ -1865,9 +1891,9 @@ def initialize_last_rank_embeddings(self):
parallel_state.set_virtual_pipeline_model_parallel_rank(0)

def _reset_activation_checkpointing_args(self):
""" Disables activation checkpointing completely and saves the values so that
_restore_activation_checkpointing_args can restore them later. This function must always be
called before _restore_activation_checkpointing_args.
"""Disables activation checkpointing completely and saves the values so that
_restore_activation_checkpointing_args can restore them later. This function must always be
called before _restore_activation_checkpointing_args.
"""
# Store values to restore them later.
self.last_activations_checkpoint_granularity = self.cfg.activations_checkpoint_granularity
@@ -1894,9 +1920,9 @@ def _reset_activation_checkpointing_args(self):
module.language_model.encoder.activations_checkpoint_layers_per_pipeline = None

def _restore_activation_checkpointing_args(self):
""" Restores the activation checkpointing parameters using the values saved by
_reset_activation_checkpointing_args. This function must never be called before
_reset_activation_checkpointing_args.
"""Restores the activation checkpointing parameters using the values saved by
_reset_activation_checkpointing_args. This function must never be called before
_reset_activation_checkpointing_args.
"""
# Restore config values.
self.cfg.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity
@@ -1923,9 +1949,9 @@ def _restore_activation_checkpointing_args(self):
)

def _reset_sequence_parallelism_args(self):
""" Disables sequence parallelism completely and saves the values so that
_restore_sequence_parallelism_args can restore them later. This function must always be
called before _restore_sequence_parallelism_args.
"""Disables sequence parallelism completely and saves the values so that
_restore_sequence_parallelism_args can restore them later. This function must always be
called before _restore_sequence_parallelism_args.
"""
# Store values to restore them later.
self.last_sequence_parallel = self.cfg.sequence_parallel
@@ -1942,9 +1968,9 @@ def _reset_sequence_parallelism_args(self):
mod.sequence_parallel = False

def _restore_sequence_parallelism_args(self):
""" Restores the sequence parallelism parameters using the values saved by
_reset_sequence_parallelism_args. This function must never be called before
_reset_sequence_parallelism_args.
"""Restores the sequence parallelism parameters using the values saved by
_reset_sequence_parallelism_args. This function must never be called before
_reset_sequence_parallelism_args.
"""
# Restore config values.
self.cfg.sequence_parallel = self.last_sequence_parallel
@@ -1958,10 +1984,10 @@ def _restore_sequence_parallelism_args(self):
mod.sequence_parallel = self.last_sequence_parallel

def build_transformer_config(self) -> TransformerConfig:
""" Builds the megatron core gpt transformer config for the model.
For attributes in the nemo model config that are the same
as the megatron core TransformerConfig, we will use the value from the nemo model config.
For attributes in TransformerConfig that are not in the nemo model config, we add custom logic.
"""Builds the megatron core gpt transformer config for the model.
For attributes in the nemo model config that are the same
as the megatron core TransformerConfig, we will use the value from the nemo model config.
For attributes in TransformerConfig that are not in the nemo model config, we add custom logic.
"""

normalization = self.cfg.get('normalization', 'layernorm').lower()
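The docstring of build_transformer_config describes a name-matching scheme: attributes that exist both in the NeMo model config and in megatron-core's TransformerConfig are copied straight from the NeMo config, and the remaining TransformerConfig fields get custom handling (such as the normalization line above). A hedged sketch of the matching step, treating the config as a plain dict (generic helper, not the NeMo implementation):

from dataclasses import fields

def overlapping_config_kwargs(cfg: dict, config_cls) -> dict:
    # Keep only the cfg keys that are also dataclass fields of config_cls
    # (e.g. megatron-core's TransformerConfig); the rest needs custom logic.
    field_names = {f.name for f in fields(config_cls)}
    return {name: value for name, value in cfg.items() if name in field_names}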
