diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 939f2067d7eb..73263896af82 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -382,7 +382,7 @@ def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
     ) -> None:
         app_state = AppState()
-        """ PTL method which we override to accomodate distributed checkpoints and 
+        """ PTL method which we override to accomodate distributed checkpoints and
         the legacy model parallel checkpoints.

         When using megatron core, the distributed checkpointing library expects save functions to be
@@ -1275,6 +1275,7 @@ def restore_from(
         return_config: bool = False,
         trainer: Trainer = None,
         validate_access_integrity: bool = True,
+        replace_sharded_tensor_key: Optional[str] = None,
     ):
         """
         Restores model instance (weights and configuration) into .nemo file
@@ -1362,6 +1363,9 @@ def dummy():
                 checkpoint = {}
                 sharded_state_dict = instance.sharded_state_dict()
                 checkpoint['state_dict'] = sharded_state_dict
+                if replace_sharded_tensor_key:
+                    for v in checkpoint["state_dict"].values():
+                        v.key = v.key.replace("model", replace_sharded_tensor_key)

                 checkpoint_io = DistributedCheckpointIO.from_config(conf)
                 checkpoint = checkpoint_io.load_checkpoint(
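
A standalone sketch of the key-remapping behavior introduced above, shown in isolation. This is not part of the patch: FakeShardedTensor is a hypothetical stand-in for the megatron.core sharded-tensor objects populated by instance.sharded_state_dict(), which carry a dotted `key` locating each tensor inside the distributed checkpoint; "model0" is an assumed example value for replace_sharded_tensor_key.

    from dataclasses import dataclass

    @dataclass
    class FakeShardedTensor:
        # Stand-in for a sharded tensor; only the `key` attribute matters here.
        key: str

    checkpoint = {
        "state_dict": {
            "embedding": FakeShardedTensor("model.embedding.word_embeddings.weight"),
            "fc1": FakeShardedTensor("model.decoder.layers.0.mlp.linear_fc1.weight"),
        }
    }

    replace_sharded_tensor_key = "model0"  # assumed example value
    if replace_sharded_tensor_key:
        for v in checkpoint["state_dict"].values():
            # As in the patch, str.replace rewrites every occurrence of the
            # substring "model" in the key with the new prefix.
            v.key = v.key.replace("model", replace_sharded_tensor_key)

    print(checkpoint["state_dict"]["embedding"].key)
    # -> model0.embedding.word_embeddings.weight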