diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index ebd0802180e7..e878a24af0a1 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -307,7 +307,7 @@ def drop_optimizer_states(self, path): from megatron.core import dist_checkpointing dist_checkpointing.remove_sharded_tensors(path, key_prefix="optimizer") - + def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): from megatron.core import dist_checkpointing from megatron.core.dist_checkpointing.dict_utils import extract_matching_values