diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index 02858b119bfa..6cce2b42be9c 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -149,7 +149,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor = None,
         encoder_output: Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
@@ -169,7 +169,7 @@ def forward(
         with torch.autocast(device_type="cuda", dtype=self.dtype):
             return super().forward(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 encoder_output=encoder_output,
                 enc_dec_attn_mask=enc_dec_attn_mask,
                 inference_params=inference_params,
@@ -242,25 +242,30 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
     def forward(
         self,
         hidden_states,
-        attention_mask,
+        is_first_microbatch=None,
+        attention_mask=None,
         context=None,
         context_mask=None,
         rotary_pos_emb=None,
         inference_params=None,
         packed_seq_params=None,  # TODO: handle this
     ):
+        # Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise.
         hidden_states = super().forward(
             hidden_states,
             attention_mask=attention_mask,
             encoder_output=context,
             enc_dec_attn_mask=context_mask,
             inference_params=inference_params,
-            is_first_microbatch=self.is_first_microbatch,
+            is_first_microbatch=is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch,
             # checkpoint_core_attention,
         )
         self.is_first_microbatch = False
         context = None
 
+        # CUDA graph requires returned values to be Tensors
+        if self.config.enable_cuda_graph and self.training:
+            return hidden_states
         return hidden_states, context
 
     def _get_layer_offset(self):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index f431d43716b9..31b2809476be 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -204,6 +204,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
             init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
             seed=self.cfg.get('seed', 1234),
             apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
+            use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False),
         )
 
         # This must be called after initialize model parallel since it needs to know the data parallel size
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 5d5b65b360ee..496229f40da3 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -76,6 +76,7 @@ def initialize_model_parallel_for_nemo(
     seed=1234,
     apex_transformer_log_level=30,
     use_tp_pp_dp_mapping=False,
+    use_te_rng_tracker=False,
 ):
 
     if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED:
@@ -128,6 +129,7 @@ def initialize_model_parallel_for_nemo(
         set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
         set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
 
+    tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
     if seed is not None:
         # @chcui not setting seed is for model conversion. always set seed for training/inference.
         _set_random_seed(seed)
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index 32bd7e6c1154..750fd0fcd93d 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -487,8 +487,6 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet
         for param in params:
            if is_float8tensor(param):
                 param._reset_caches()
-                param.transpose(update_cache=True)
-                param._lazy_transpose_cache = True
 
     @torch.no_grad()
     def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
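
Note on usage: the new `use_te_rng_tracker` kwarg is read from the model config by MegatronBaseModel via `cfg.get('use_te_rng_tracker', False)` and forwarded to `initialize_model_parallel_for_nemo()`, which now calls `tensor_parallel.random.initialize_rng_tracker(...)` before seeding. The sketch below is only an illustration of how a user might flip the flag in a config object; the config layout and the placement of `enable_cuda_graph` as a sibling key are assumptions, not part of this diff.

    # Minimal sketch, assuming a Hydra/OmegaConf model config as commonly used in NeMo.
    # Keys shown here mirror the cfg.get(...) lookups in the diff; anything else is illustrative.
    from omegaconf import OmegaConf

    model_overrides = OmegaConf.create(
        {
            # Picked up by MegatronBaseModel.__init__ and passed through to
            # initialize_model_parallel_for_nemo(use_te_rng_tracker=True), which
            # initializes Megatron-Core's RNG tracker with the Transformer Engine variant.
            "use_te_rng_tracker": True,
            # Hypothetical companion setting: the layer forward() above returns a single
            # tensor only when self.config.enable_cuda_graph is set during training.
            "enable_cuda_graph": True,
        }
    )
    print(OmegaConf.to_yaml(model_overrides))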