fix for neva model sharded state dict to skip loading fp8 params
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Jul 8, 2024
1 parent 055aae4 commit d7d8ed4
Showing 1 changed file with 12 additions and 0 deletions.
@@ -83,6 +83,8 @@
 from megatron.core.models.gpt import GPTModel as MCoreGPTModel
 from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
+from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
+from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject

 HAVE_MEGATRON_CORE = True

@@ -542,6 +544,16 @@ def _load_model_weights(self, nemo_path):
         sharded_state_dict = None
         if getattr(self, "sharded_state_dict", None) is not None:
             sharded_state_dict = self.sharded_state_dict(prefix="model.")
+
+        # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention
+        def skip_fp8_load(x):
+            if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
+                x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
+            return x
+
+        if getattr(self.config, 'skip_fp8_attention_checkpoint_load', True):
+            dict_list_map_inplace(skip_fp8_load, sharded_state_dict)
+
         state_dict, self.is_dist_ckpt = load_nemo_model_weights(nemo_path, sharded_state_dict)

         return state_dict

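For reference, a minimal standalone sketch of the replacement pattern added in this commit. FakeShardedObject, FakeLocalObject, and the simplified dict_list_map_inplace below are hypothetical stand-ins for megatron.core's ShardedObject, LocalNonpersitentObject, and dict_list_map_inplace, so the snippet runs without Megatron-LM installed; it illustrates the filter logic only, not the production code path.

# Hypothetical stand-ins so this sketch runs without megatron.core installed.
class FakeShardedObject:
    def __init__(self, key, data):
        self.key = key
        self.data = data

class FakeLocalObject:
    """Wraps data that should be kept from initialization rather than loaded from the checkpoint."""
    def __init__(self, data):
        self.data = data

def dict_list_map_inplace(fn, x):
    # Simplified version of the megatron.core helper: apply fn to every leaf
    # of nested dicts/lists, modifying the containers in place.
    if isinstance(x, dict):
        for k, v in x.items():
            x[k] = dict_list_map_inplace(fn, v)
    elif isinstance(x, list):
        x[:] = [dict_list_map_inplace(fn, v) for v in x]
    else:
        x = fn(x)
    return x

def skip_fp8_load(x):
    # Same filter as the commit: fused-attention FP8 extra state gets wrapped
    # so the value from the checkpoint is ignored.
    if isinstance(x, FakeShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
        x = FakeLocalObject(x.data)
    return x

sharded_state_dict = {
    'attn_extra': FakeShardedObject(
        'model.decoder.layers.0.self_attention.core_attention.fused_attention._extra_state', b'fp8-state'),
    'mlp_weight': FakeShardedObject(
        'model.decoder.layers.0.mlp.linear_fc1.weight', 'tensor'),
}
dict_list_map_inplace(skip_fp8_load, sharded_state_dict)

# Only the fused-attention extra state is wrapped; regular weights are untouched.
assert isinstance(sharded_state_dict['attn_extra'], FakeLocalObject)
assert isinstance(sharded_state_dict['mlp_weight'], FakeShardedObject)

Because the matching entries are wrapped in a non-persistent object, the subsequent load keeps the FP8 extra state produced at model initialization instead of requiring it to exist in the checkpoint; the skip_fp8_attention_checkpoint_load config flag (defaulting to True in this diff) toggles the behavior.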