diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index b07093ceafe9..d61eb839b79d 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -83,6 +83,8 @@
     from megatron.core.models.gpt import GPTModel as MCoreGPTModel
     from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
+    from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
+    from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
 
     HAVE_MEGATRON_CORE = True
 
@@ -542,6 +544,16 @@
     def _load_model_weights(self, nemo_path):
         sharded_state_dict = None
         if getattr(self, "sharded_state_dict", None) is not None:
             sharded_state_dict = self.sharded_state_dict(prefix="model.")
+
+            # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention
+            def skip_fp8_load(x):
+                if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
+                    x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
+                return x
+
+            if getattr(self.config, 'skip_fp8_attention_checkpoint_load', True):
+                dict_list_map_inplace(skip_fp8_load, sharded_state_dict)
+
         state_dict, self.is_dist_ckpt = load_nemo_model_weights(nemo_path, sharded_state_dict)
         return state_dict
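
Not part of the patch: a minimal, self-contained sketch of the mechanism the hunk relies on, using simplified stand-ins for megatron.core's `dict_list_map_inplace`, `ShardedObject`, and `LocalNonpersitentObject` (the real implementations live under `megatron.core.dist_checkpointing`). The example keys are made up; the point is only to show how mapping a filter function over the sharded state dict swaps fused-attention `_extra_state` entries for non-persistent placeholders so the in-memory FP8 state survives checkpoint load.

```python
from typing import Any, Callable


def dict_list_map_inplace(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply ``fn`` to every leaf of nested dicts/lists, replacing values in place
    (simplified stand-in for megatron.core.dist_checkpointing.dict_utils)."""
    if isinstance(obj, dict):
        for k, v in obj.items():
            obj[k] = dict_list_map_inplace(fn, v)
        return obj
    if isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = dict_list_map_inplace(fn, v)
        return obj
    return fn(obj)


class ShardedObject:
    """Stand-in: a checkpoint entry addressed by a key."""

    def __init__(self, key: str, data: Any):
        self.key, self.data = key, data


class LocalNonpersitentObject:
    """Stand-in: wraps an in-memory value so the loader keeps it
    instead of reading the corresponding entry from the checkpoint."""

    def __init__(self, data: Any):
        self.data = data


def skip_fp8_load(x):
    # Keep the freshly initialized FP8 extra state for fused attention
    # rather than restoring it from the checkpoint.
    if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
        x = LocalNonpersitentObject(x.data)
    return x


# Hypothetical keys, for illustration only.
sharded_state_dict = {
    'model.decoder.self_attention.core_attention.fused_attention._extra_state':
        ShardedObject('model.decoder.self_attention.core_attention.fused_attention._extra_state', b'fp8-meta'),
    'model.decoder.self_attention.linear_qkv.weight':
        ShardedObject('model.decoder.self_attention.linear_qkv.weight', 'qkv-weight'),
}

dict_list_map_inplace(skip_fp8_load, sharded_state_dict)
for key, value in sharded_state_dict.items():
    # The _extra_state entry becomes LocalNonpersitentObject; the weight stays a ShardedObject.
    print(type(value).__name__, key)
```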