diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index 02858b119bfa..6cce2b42be9c 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -149,7 +149,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor = None,
         encoder_output: Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
@@ -169,7 +169,7 @@ def forward(
         with torch.autocast(device_type="cuda", dtype=self.dtype):
             return super().forward(
                 hidden_states,
-                attention_mask,
+                attention_mask=attention_mask,
                 encoder_output=encoder_output,
                 enc_dec_attn_mask=enc_dec_attn_mask,
                 inference_params=inference_params,
@@ -242,25 +242,30 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
     def forward(
         self,
         hidden_states,
-        attention_mask,
+        is_first_microbatch=None,
+        attention_mask=None,
         context=None,
         context_mask=None,
         rotary_pos_emb=None,
         inference_params=None,
         packed_seq_params=None,  # TODO: handle this
     ):
+        # Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise.
         hidden_states = super().forward(
             hidden_states,
             attention_mask=attention_mask,
             encoder_output=context,
             enc_dec_attn_mask=context_mask,
             inference_params=inference_params,
-            is_first_microbatch=self.is_first_microbatch,
+            is_first_microbatch=is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch,
             # checkpoint_core_attention,
         )
         self.is_first_microbatch = False
         context = None

+        # CUDA graph requires returned values to be Tensors
+        if self.config.enable_cuda_graph and self.training:
+            return hidden_states
         return hidden_states, context

     def _get_layer_offset(self):
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index a27f9fd5e5e4..03f494732337 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -206,6 +206,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
             init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
             seed=self.cfg.get('seed', 1234),
             apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
+            use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False),
         )

         # This must be called after initialize model parallel since it needs to know the data parallel size
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
index 341e534bcd89..55e386bb22e5 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -76,6 +76,7 @@ def initialize_model_parallel_for_nemo(
     seed=1234,
     apex_transformer_log_level=30,
     use_tp_pp_dp_mapping=False,
+    use_te_rng_tracker=False,
 ):

     if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED:
@@ -128,6 +129,7 @@ def initialize_model_parallel_for_nemo(
     set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
     set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

+    tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
     if seed is not None:
         # @chcui not setting seed is for model conversion. always set seed for training/inference.
         _set_random_seed(seed)
diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index 32bd7e6c1154..c7ade1c62ae1 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -136,7 +136,8 @@ def hook(*unused):
                     self._grad_copy(param)
                     if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
                         self._try_start_bucket_grad_sync(
-                            params=[param], ignore_last_bucket=need_to_initialize,
+                            params=[param],
+                            ignore_last_bucket=need_to_initialize,
                         )

         return hook
@@ -167,10 +168,14 @@ def init_params(
         # Initialize FP8 and non-FP8 tensors separately
         if any(is_float8tensor(param) for param in params):
             super().init_params(
-                filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs,
+                filter(is_float8tensor, params),
+                param_sync_dtype=torch.uint8,
+                **kwargs,
             )
         super().init_params(
-            params, param_sync_dtype=param_sync_dtype, **kwargs,
+            params,
+            param_sync_dtype=param_sync_dtype,
+            **kwargs,
         )

     def init_params_bucket(
@@ -200,7 +205,10 @@ def init_params_bucket(
             params = remaining_params
         start_bucket_id = len(self.state["buckets"])
         super().init_params_bucket(
-            fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs,
+            fp32_params,
+            grad_sync_dtype=torch.float32,
+            param_sync_dtype=param_sync_dtype,
+            **kwargs,
         )
         end_bucket_id = len(self.state["buckets"])
         fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
@@ -216,7 +224,10 @@ def init_params_bucket(
             params = remaining_params
         start_bucket_id = len(self.state["buckets"])
         super().init_params_bucket(
-            fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs,
+            fp8_params,
+            grad_sync_dtype=grad_sync_dtype,
+            param_sync_dtype=torch.uint8,
+            **kwargs,
         )
         end_bucket_id = len(self.state["buckets"])
         fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
@@ -225,12 +236,18 @@ def init_params_bucket(
         normal_buckets = []
         start_bucket_id = len(self.state["buckets"])
         super().init_params_bucket(
-            params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs,
+            params,
+            grad_sync_dtype=grad_sync_dtype,
+            param_sync_dtype=param_sync_dtype,
+            **kwargs,
         )
         end_bucket_id = len(self.state["buckets"])
         normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

-        def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None:
+        def add_param_to_bucket(
+            param: torch.nn.Parameter,
+            bucket: self.StateBucket,
+        ) -> None:
             """Add trivial param fragment to bucket"""
             param_fragments = self.state[param]["fragments"]
             param_group_id = param_fragments[0].param_group_id
@@ -283,7 +300,11 @@ def _init_param_state(
         # Initialize non-FP8 params as usual
         if not is_float8tensor(param):
             super()._init_param_state(
-                param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs,
+                param,
+                param_group_id,
+                param_id,
+                param_sync_dtype=param_sync_dtype,
+                **kwargs,
             )

         # Return immediately if already initialized
@@ -293,7 +314,11 @@ def _init_param_state(
         # Initialize with FP32 copy of param
         fp32_param = param.float()
         super()._init_param_state(
-            fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs,
+            fp32_param,
+            param_group_id,
+            param_id,
+            param_sync_dtype=torch.uint8,
+            **kwargs,
         )
         self.state[param].update(self.state[fp32_param])
         del self.state[fp32_param]
@@ -360,7 +385,9 @@ def init_param_buffer(self) -> None:

         # Copy values into param buffer
         _multi_tensor_copy(
-            param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf,
+            param_flat_views,
+            param_buffer_views,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )

         # Make all params a view into the param buffer
@@ -393,7 +420,10 @@ def zero_grad(self, *args, **kwargs) -> None:
                 param.main_grad = self.grad_buffer_view(param)

     def grad_norm(
-        self, parameters: Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False,
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+        norm_type: float = 2.0,
+        force: bool = False,
     ) -> torch.Tensor:
         assert norm_type == 2

@@ -411,7 +441,8 @@ def grad_norm(

             # Sum over all procs to get grad norm
             torch.distributed.all_reduce(
-                grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
+                grad_norm_sq,
+                op=torch.distributed.ReduceOp.SUM,
             )
             self._grad_norm = grad_norm_sq.sqrt()

@@ -479,7 +510,9 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet

         # Copy data from parameter buckets to parameters
         _multi_tensor_copy(
-            buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf,
+            buffers_in,
+            buffers_out,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )

         # Update transpose caches
@@ -487,8 +520,6 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet
         for param in params:
             if is_float8tensor(param):
                 param._reset_caches()
-                param.transpose(update_cache=True)
-                param._lazy_transpose_cache = True

     @torch.no_grad()
     def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
@@ -570,11 +601,15 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
         packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
         _multi_tensor_copy(
-            scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf,
+            scales,
+            packed_scale_views,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )
         torch.reciprocal(packed_scales, out=packed_scales)
         _multi_tensor_copy(
-            packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf,
+            packed_scale_views,
+            scale_invs,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )

         # Reduce amaxes
@@ -582,13 +617,19 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
         packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
         packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
         _multi_tensor_copy(
-            amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
+            amaxes,
+            packed_amax_views,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )
         torch.distributed.all_reduce(
-            packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
+            packed_amaxes,
+            op=torch.distributed.ReduceOp.MAX,
+            group=self.distributed_process_group,
         )
         _multi_tensor_copy(
-            packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
+            packed_amax_views,
+            amaxes,
+            dummy_overflow_buf=self._dummy_overflow_buf,
         )

         # Reset
@@ -602,7 +643,8 @@ def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None
             optimizer_state_dict = self.state_dict()

         id_to_sharded_param_map = get_param_id_to_sharded_param_map(
-            model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(),
+            model_sharded_state_dict=model_sharded_state_dict,
+            optim_params_iter=self.parameters(),
         )
         # Convert state
         step = optimizer_state_dict['state'].pop('step')
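
Note on the gpt_full_te_layer_autocast_spec.py hunks: the layer now accepts is_first_microbatch as an explicit argument (used during CUDA graph capture, falling back to the flag it tracks itself) and returns a bare Tensor on the CUDA-graph training path, since captured callables need Tensor-only outputs. A standalone sketch of that dispatch pattern, using a placeholder module rather than the real TransformerEngine-backed layer:

# Standalone sketch (placeholder module, not the NeMo/TransformerEngine layer) of the
# is_first_microbatch dispatch and Tensor-only return introduced by the patch above.
import torch


class LayerSketch(torch.nn.Module):
    def __init__(self, enable_cuda_graph: bool = False):
        super().__init__()
        self.enable_cuda_graph = enable_cuda_graph
        self.is_first_microbatch = True
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states, is_first_microbatch=None, attention_mask=None):
        # Prefer the explicit argument (passed by the CUDA graph capture path);
        # otherwise fall back to the flag the layer tracks across microbatches.
        first = is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch
        # In the real wrapper `first` is forwarded to the TE layer's forward(); here it is unused.
        out = self.proj(hidden_states)
        self.is_first_microbatch = False
        context = None
        # CUDA graph capture requires Tensor-only outputs, so drop the (always-None) context.
        if self.enable_cuda_graph and self.training:
            return out
        return out, context


# Tensor-only output on the graph-capture path, (Tensor, None) otherwise:
graph_out = LayerSketch(enable_cuda_graph=True).train()(torch.randn(2, 16))
eager_out, ctx = LayerSketch().eval()(torch.randn(2, 16))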
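
For the use_te_rng_tracker plumbing (megatron_base_model.py and megatron_init.py hunks), the flag is read from the model config with a False default and handed to Megatron-core before the seed is set, so tensor-parallel RNG state can be managed by TransformerEngine's tracker. A minimal sketch of the call chain, assuming megatron-core is installed and an OmegaConf-style model config; only the use_te_rng_tracker key comes from this patch, the rest is illustrative:

# Minimal sketch, assuming an OmegaConf-style model cfg and an installed megatron-core.
from megatron.core import tensor_parallel
from omegaconf import OmegaConf

cfg = OmegaConf.create({"use_te_rng_tracker": True, "seed": 1234})

# megatron_base_model.py reads the flag with a False default ...
use_te_rng_tracker = cfg.get("use_te_rng_tracker", False)

# ... and initialize_model_parallel_for_nemo passes it through to Megatron-core,
# selecting TransformerEngine's RNG tracker before the random seed is set.
tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)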
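
The distributed_adam.py hunks are mostly argument reflowing, plus removing the explicit transpose-cache update for FP8 params. The FP8 bookkeeping they touch packs per-parameter scale/amax scalars into one buffer so a single collective covers all parameters. A simplified sketch of that pack, all-reduce(MAX), unpack pattern, using plain loops in place of the fused _multi_tensor_copy helper (assumes an initialized torch.distributed process group):

# Simplified sketch of the amax reduction pattern; the optimizer packs/unpacks with
# the fused _multi_tensor_copy helper instead of the Python loops shown here.
import torch
import torch.distributed


def reduce_fp8_amaxes(amaxes, group=None):
    """Reduce a list of single-element amax tensors across ranks with one collective."""
    # Pack the per-parameter scalars into one contiguous buffer.
    packed = torch.cat([a.reshape(1) for a in amaxes])
    # One MAX all-reduce for all parameters instead of one collective per parameter.
    torch.distributed.all_reduce(packed, op=torch.distributed.ReduceOp.MAX, group=group)
    # Unpack the reduced maxima back into the original per-parameter tensors.
    for a, reduced in zip(amaxes, packed):
        a.copy_(reduced)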