diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index d7eaaa90e536..a116c8e60299 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -740,7 +740,8 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
             peft_state_dict = torch.load(model_weights_path, map_location)['state_dict']
         else:
             peft_state_dict = {}
-        base_model_state_dict.update(peft_state_dict)  # add the PEFT state_dict into the base model's state_dict
+        if base_model_state_dict:
+            base_model_state_dict.update(peft_state_dict)  # add the PEFT state_dict into the base model's state_dict
         return base_model_state_dict

     def restore_from(
@@ -765,13 +766,61 @@ def restore_from(
             return loaded_params
         conf, instance, state_dict = loaded_params

-        if (
-            self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
-        ):  # we have this check only for training PEFT from scratch
-            peft_state_dict = instance.get_peft_state_dict()
-            state_dict.update(peft_state_dict)
-        state_dict = self.modify_state_dict(conf, state_dict)
-        self.load_instance_with_state_dict(instance, state_dict, strict)
+        # if we're using dist checkpointing then state_dict will be None
+        if state_dict is None:
+            # dist checkpointing needs torch.distributed to load the checkpoint
+            if parallel_state.is_unitialized():
+
+                def dummy():
+                    return
+
+                if trainer.strategy.launcher is not None:
+                    trainer.strategy.launcher.launch(dummy, trainer=trainer)
+                trainer.strategy.setup_environment()
+
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Check if self.model_extracted_dir is set, and is a valid path
+                if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
+                    # Log that NeMo will use the provided `model_extracted_dir`
+                    logging.info(
+                        f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
+                    )
+
+                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
+                    tmpdir = self.model_extracted_dir
+
+                else:
+                    # Extract the nemo file into the temporary directory
+                    self._unpack_nemo_file(
+                        path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True
+                    )
+                checkpoint = {}
+                sharded_state_dict = instance.sharded_state_dict()
+                peft_state_dict = instance.get_peft_state_dict()
+                for k in peft_state_dict.keys():
+                    sharded_state_dict.pop(k)
+                checkpoint['state_dict'] = sharded_state_dict
+                # remove model weights extension
+                tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt)
+                tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
+                assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
+                checkpoint = dist_checkpointing.load(
+                    sharded_state_dict=checkpoint, checkpoint_dir=tmp_model_weights_dir
+                )
+                checkpoint['state_dict'].update(peft_state_dict)
+                instance.on_load_checkpoint(checkpoint)
+                if hasattr(instance, 'setup_transformer_engine_tp_groups'):
+                    instance.setup_transformer_engine_tp_groups()
+
+        else:
+            if (
+                self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
+            ):  # we have this check only for training PEFT from scratch
+                peft_state_dict = instance.get_peft_state_dict()
+                state_dict.update(peft_state_dict)
+            state_dict = self.modify_state_dict(conf, state_dict)
+            self.load_instance_with_state_dict(instance, state_dict, strict)
+
         logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
         return instance
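
Note: the sketch below (not part of the diff) condenses the load-then-merge pattern the second hunk implements: PEFT adapter keys are popped from the sharded state dict before dist_checkpointing.load so only base-model shards are read from disk, then the locally held adapter weights are merged back before the checkpoint is handed to the model. `model`, `ckpt_dir`, and `restore_base_then_peft` are illustrative placeholders; the callables used on `model` are the ones the diff itself relies on.

    # Minimal sketch, assuming `model` exposes NeMo's sharded_state_dict() /
    # get_peft_state_dict() / on_load_checkpoint() interface, `ckpt_dir` is an
    # already-extracted distributed-checkpoint directory, and torch.distributed
    # has been initialized (the diff launches a dummy job to ensure this).
    from megatron.core import dist_checkpointing

    def restore_base_then_peft(model, ckpt_dir):
        # Request only the *base* weights from disk: pop the PEFT adapter keys
        # so dist_checkpointing does not look for shards that were never saved.
        sharded_state_dict = model.sharded_state_dict()
        peft_state_dict = model.get_peft_state_dict()
        for key in peft_state_dict:
            sharded_state_dict.pop(key)

        # Each rank loads its own shards of the base weights.
        checkpoint = dist_checkpointing.load(
            sharded_state_dict={'state_dict': sharded_state_dict},
            checkpoint_dir=ckpt_dir,
        )

        # Merge the PEFT weights back before handing the combined
        # state dict to the model.
        checkpoint['state_dict'].update(peft_state_dict)
        model.on_load_checkpoint(checkpoint)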