From d4eec446e012eb1ac74e4e2f93c1c9d5162e6735 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Wed, 15 May 2024 20:25:18 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: dimapihtar --- .../language_modeling/megatron_gpt_model.py | 168 ++++++++++-------- 1 file changed, 97 insertions(+), 71 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 040569b1c411..823161e3427b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -174,7 +174,7 @@ def forward(self, **kwargs): the superclass by the square root of the hidden size specified in the configuration. """ embeddings = super().forward(**kwargs) - return embeddings * torch.tensor(self.config.hidden_size ** 0.5, dtype=embeddings.dtype) + return embeddings * torch.tensor(self.config.hidden_size**0.5, dtype=embeddings.dtype) class MegatronGPTExportableModel(torch.nn.Module, Exportable): @@ -196,11 +196,14 @@ def __init__(self, model): def forward(self, tokens, position_ids, attention_mask): if self.fp8_enabled and HAVE_TE: - with transformer_engine.pytorch.onnx_export(self.fp8_enabled), transformer_engine.pytorch.fp8_autocast( - enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe - ), torch.no_grad(), torch.inference_mode(), torch.autocast( - 'cuda', dtype=self.dtype - ), warnings.catch_warnings(): + with ( + transformer_engine.pytorch.onnx_export(self.fp8_enabled), + transformer_engine.pytorch.fp8_autocast(enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe), + torch.no_grad(), + torch.inference_mode(), + torch.autocast('cuda', dtype=self.dtype), + warnings.catch_warnings(), + ): warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*') assert tokens.shape == position_ids.shape assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1] @@ -211,9 +214,12 @@ def forward(self, tokens, position_ids, attention_mask): labels=None, ) else: - with torch.no_grad(), torch.inference_mode(), torch.autocast( - 'cuda', dtype=self.dtype - ), warnings.catch_warnings(): + with ( + torch.no_grad(), + torch.inference_mode(), + torch.autocast('cuda', dtype=self.dtype), + warnings.catch_warnings(), + ): warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*') assert tokens.shape == position_ids.shape assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1] @@ -315,7 +321,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): raise ValueError( 'Expert parallelism is currently not supporting Apex distributed optimizer, use Mcore distributed optimizer instead' ) - + if self.cfg.get('num_layers', 12) % self.cfg.get('pipeline_model_parallel_size', 1) != 0: raise ValueError( f"num_layers ({self.cfg.get('num_layers', 12)}) should be divisible by " @@ -515,7 +521,7 @@ def setup_optimizer_param_groups(self): self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model) def setup_mcore_distributed_parallel(self): - """Set up mcore distributed data parallel """ + """Set up mcore distributed data parallel""" if self.with_distributed_adam and self.use_mcore_dist_optim: config = get_model_config(self.model[0]) ddp_config = DistributedDataParallelConfig( @@ -647,7 +653,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): if 
self.validation_param_sync_overlap: param_sync_func = self.sync_overlap_parameters elif not self.use_mcore_dist_optim: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters else: @@ -750,9 +759,9 @@ def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - We pass the dataloader iterator function to the micro-batch scheduler. - The input batch to each micro-batch is fetched using the dataloader function - in the micro-batch fwd function. + We pass the dataloader iterator function to the micro-batch scheduler. + The input batch to each micro-batch is fetched using the dataloader function + in the micro-batch fwd function. """ # Initialize userbuffer communicators. if self.initialize_ub: @@ -883,7 +892,11 @@ def training_step(self, dataloader_iter): if self.log_memory_usage: mem_reserved = torch.cuda.max_memory_reserved() self.log( - 'peak_memory_usage', mem_reserved, prog_bar=True, rank_zero_only=True, batch_size=1, + 'peak_memory_usage', + mem_reserved, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) ## logging @@ -907,20 +920,29 @@ def training_step(self, dataloader_iter): lr = self._optimizer.param_groups[0]['lr'] self.log('lr', lr, rank_zero_only=True, batch_size=1) self.log( - 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1, + 'global_step', + self.trainer.global_step, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) consumed_samples = self._compute_consumed_samples_after_training_step() # TODO: make sure compute_consumed_samples works for pipeline parallelism self.log( - 'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1, + 'consumed_samples', + consumed_samples, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) if self.rampup_batch_size: self.prev_global_batch_size = current_global_batch_size self.prev_consumed_samples = consumed_samples num_microbatch_calculator.update( - consumed_samples=consumed_samples, consistency_check=False, + consumed_samples=consumed_samples, + consistency_check=False, ) current_global_batch_size = num_microbatch_calculator.current_global_batch_size self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1) @@ -929,20 +951,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. """ return def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. 
""" return def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) or getattr( @@ -960,9 +982,9 @@ def _append_sequence_parallel_module_grads(self, module, grads): grads.append(grad.data) def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. + Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 """ grads = [] @@ -980,8 +1002,7 @@ def allreduce_sequence_parallel_gradients(self): buf.copy_(synced) def allreduce_fsdp_sharding_omitted_gradients(self): - """ All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain). - """ + """All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain).""" assert isinstance(self.model, torch.nn.Module) grads = [] for param in self.model.parameters(): @@ -1028,16 +1049,16 @@ def allreduce_first_last_embeddings(self): torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]: - """ Convert data iterator into form expected by Megatron - - With interleaved pipeline parallelism, Megatron expects a - list of one data iterator per model chunk. Each model - chunk independently gets data from its data iterator, so - we need to interact with the data iterator multiple times - for each microbatch step. Instead of incorporating this - logic into the data loader, we cache the iterator's output - to the first model chunk and reuse it in the other model - chunks. + """Convert data iterator into form expected by Megatron + + With interleaved pipeline parallelism, Megatron expects a + list of one data iterator per model chunk. Each model + chunk independently gets data from its data iterator, so + we need to interact with the data iterator multiple times + for each microbatch step. Instead of incorporating this + logic into the data loader, we cache the iterator's output + to the first model chunk and reuse it in the other model + chunks. """ if not isinstance(self.model, list) or len(self.model) == 1: @@ -1329,10 +1350,10 @@ def id_func(output_tensor): def validation_step(self, dataloader_iter, dataloader_idx=0): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ mode = 'test' if self.trainer.testing else 'val' # Initialize userbuffer communicators. 
@@ -1393,7 +1414,9 @@ def on_validation_epoch_end(self): if self.loss_broadcast_src_rank is None: self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank() torch.distributed.broadcast( - averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), + averaged_loss, + self.loss_broadcast_src_rank, + group=parallel_state.get_pipeline_model_parallel_group(), ) self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) @@ -1498,7 +1521,10 @@ def build_train_valid_test_datasets(self): dataset_type = MockGPTDataset if mock_dataset else GPTDataset self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( - dataset_type, train_valid_test_num_samples, is_dataset_built_on_rank, dataset_config, + dataset_type, + train_valid_test_num_samples, + is_dataset_built_on_rank, + dataset_config, ).build() if self._train_ds is not None: @@ -1708,16 +1734,16 @@ def list_available_models(self): return None def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -1794,9 +1820,9 @@ def on_load_checkpoint(self, checkpoint) -> None: def on_validation_model_zero_grad(self) -> None: """ - Skip gradient zeroing at the beginning of validation routine. - This is needed when overlapping the AllGather of the updated parameters with the following valdation step. - """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ if not self.validation_param_sync_overlap: super().on_validation_model_zero_grad() @@ -1865,9 +1891,9 @@ def initialize_last_rank_embeddings(self): parallel_state.set_virtual_pipeline_model_parallel_rank(0) def _reset_activation_checkpointing_args(self): - """ Disables activation checkpointing completely and saves the values so that - _restore_activation_checkpointing_args can restore them later. This function must always be - called before _restore_activation_checkpointing_args. + """Disables activation checkpointing completely and saves the values so that + _restore_activation_checkpointing_args can restore them later. This function must always be + called before _restore_activation_checkpointing_args. """ # Store values to restore them later. 
self.last_activations_checkpoint_granularity = self.cfg.activations_checkpoint_granularity @@ -1894,9 +1920,9 @@ def _reset_activation_checkpointing_args(self): module.language_model.encoder.activations_checkpoint_layers_per_pipeline = None def _restore_activation_checkpointing_args(self): - """ Restores the activation checkpointing parameters using the values saved by - _reset_activation_checkpointing_args. This function must never be called before - _reset_activation_checkpointing_args. + """Restores the activation checkpointing parameters using the values saved by + _reset_activation_checkpointing_args. This function must never be called before + _reset_activation_checkpointing_args. """ # Restore config values. self.cfg.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity @@ -1923,9 +1949,9 @@ def _restore_activation_checkpointing_args(self): ) def _reset_sequence_parallelism_args(self): - """ Disables sequence parallelism completely and saves the values so that - _restore_sequence_parallelism_args can restore them later. This function must always be - called before _restore_sequence_parallelism_args. + """Disables sequence parallelism completely and saves the values so that + _restore_sequence_parallelism_args can restore them later. This function must always be + called before _restore_sequence_parallelism_args. """ # Store values to restore them later. self.last_sequence_parallel = self.cfg.sequence_parallel @@ -1942,9 +1968,9 @@ def _reset_sequence_parallelism_args(self): mod.sequence_parallel = False def _restore_sequence_parallelism_args(self): - """ Restores the sequence parallelism parameters using the values saved by - _reset_sequence_parallelism_args. This function must never be called before - _reset_sequence_parallelism_args. + """Restores the sequence parallelism parameters using the values saved by + _reset_sequence_parallelism_args. This function must never be called before + _reset_sequence_parallelism_args. """ # Restore config values. self.cfg.sequence_parallel = self.last_sequence_parallel @@ -1958,10 +1984,10 @@ def _restore_sequence_parallelism_args(self): mod.sequence_parallel = self.last_sequence_parallel def build_transformer_config(self) -> TransformerConfig: - """ Builds the megatron core gpt transformer config for the model. - For attributes in the nemo model config that are the same - as the megatron core TransformerConfig, we will use the value from the nemo model config. - For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. """ normalization = self.cfg.get('normalization', 'layernorm').lower()
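
Note: the hunks above are formatting-only output from black (the isort half of the commit touches the import block, which falls outside this excerpt); no runtime behavior changes. As a rough guide to what black is doing here, the short sketch below reproduces the three rules that account for most of the diff. It is an illustration only: the helper names in it are invented for this sketch and are not taken from megatron_gpt_model.py.

# Illustration only: scale_embeddings, run_inference, and log_metric are
# hypothetical helpers, not part of the NeMo code base.
import torch


def scale_embeddings(embeddings: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Docstrings also lose their leading/trailing padding spaces under black."""
    # Spaces around ** are dropped when both operands are simple expressions:
    # `hidden_size ** 0.5` becomes `hidden_size**0.5`.
    return embeddings * torch.tensor(hidden_size**0.5, dtype=embeddings.dtype)


def run_inference(model: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    # Multiple context managers get wrapped in parentheses, one per line, instead
    # of the old backslash/implicit-continuation style. This syntax only parses on
    # newer Pythons (officially 3.10+), so black emits it only when the configured
    # target versions support it.
    with (
        torch.no_grad(),
        torch.inference_mode(),
        torch.autocast("cpu", dtype=torch.bfloat16),
    ):
        return model(tokens)


def log_metric(value: float) -> None:
    # A pre-existing "magic trailing comma" forces the call to be exploded onto
    # one argument per line, which is why the self.log(...) and partial(...) calls
    # in the hunks above grew taller without changing their arguments.
    print(
        "peak_memory_usage",
        value,
        sep=": ",
    )

Because every hunk is a pure reformat, reviewing this patch mostly reduces to confirming that black/isort were run with the repository's configured settings and that the diff contains no hand edits.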