Commit

Add support for restoring from 2.0 checkpoint in 1.0 (#11347)
* Add support for restoring from 2.0 checkpoint in 1.0

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
hemildesai authored Nov 20, 2024
1 parent cbb7b17 commit 05b7d4f
Showing 1 changed file with 5 additions and 1 deletion.
`nemo/collections/nlp/parts/nlp_overrides.py` (5 additions, 1 deletion):

```diff
@@ -382,7 +382,7 @@ def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
     ) -> None:
         app_state = AppState()
-        """ PTL method which we override to accomodate distributed checkpoints and
+        """ PTL method which we override to accomodate distributed checkpoints and
         the legacy model parallel checkpoints.

         When using megatron core, the distributed checkpointing library expects save functions to be
@@ -1275,6 +1275,7 @@ def restore_from(
         return_config: bool = False,
         trainer: Trainer = None,
         validate_access_integrity: bool = True,
+        replace_sharded_tensor_key: Optional[str] = None,
     ):
         """
         Restores model instance (weights and configuration) into .nemo file
@@ -1362,6 +1363,9 @@ def dummy():
             checkpoint = {}
             sharded_state_dict = instance.sharded_state_dict()
             checkpoint['state_dict'] = sharded_state_dict
+            if replace_sharded_tensor_key:
+                for v in checkpoint["state_dict"].values():
+                    v.key = v.key.replace("model", replace_sharded_tensor_key)

             checkpoint_io = DistributedCheckpointIO.from_config(conf)
             checkpoint = checkpoint_io.load_checkpoint(
```
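The new `replace_sharded_tensor_key` branch rewrites the `key` attribute of each sharded tensor in the 1.0 model's sharded state dict before the distributed checkpoint is loaded, so tensors saved under a different top-level prefix by NeMo 2.0 can still be found. A minimal sketch of that remapping, using a stand-in `ShardedTensor` dataclass in place of the Megatron-Core class, and `"module"` as a purely illustrative 2.0 prefix:

```python
from dataclasses import dataclass


@dataclass
class ShardedTensor:
    # Stand-in for the Megatron-Core sharded tensor: only the `key`
    # attribute, which names the tensor inside the checkpoint, matters here.
    key: str


def remap_keys(state_dict, replace_sharded_tensor_key=None):
    """Mirror the new branch in restore_from: rewrite each sharded
    tensor's key so it matches the prefix used in the saved checkpoint."""
    if replace_sharded_tensor_key:
        for v in state_dict.values():
            v.key = v.key.replace("model", replace_sharded_tensor_key)
    return state_dict


sd = {
    "model.decoder.weight": ShardedTensor(key="model.decoder.weight"),
    "model.decoder.bias": ShardedTensor(key="model.decoder.bias"),
}
remap_keys(sd, replace_sharded_tensor_key="module")
print(sd["model.decoder.weight"].key)  # module.decoder.weight
```

Note that the dict keys themselves are untouched; only the `key` field consulted by the distributed checkpoint loader is rewritten, which is exactly the scope of the change in the diff above.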
