diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 2ed6ede6b..a31421a75 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -160,7 +160,11 @@ def __init__(
                 "%j", self.config["slurm"]["job_id"]
             )
             if distutils.is_master():
-                add_timestamp_id_to_submission_pickle(self.config["slurm"]["folder"], self.config["slurm"]["job_id"], self.timestamp_id)
+                add_timestamp_id_to_submission_pickle(
+                    self.config["slurm"]["folder"],
+                    self.config["slurm"]["job_id"],
+                    self.timestamp_id,
+                )

         # Define datasets
         if isinstance(dataset, list):
@@ -426,9 +430,11 @@ def load_references_and_normalizers(self):
                 elementref_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
             )

         if norms_config is not None:
@@ -436,9 +442,11 @@ def load_references_and_normalizers(self):
                 norms_config,
                 dataset=self.train_dataset,
                 seed=self.config["cmd"]["seed"],
-                checkpoint_dir=self.config["cmd"]["checkpoint_dir"]
-                if not self.is_debug
-                else None,
+                checkpoint_dir=(
+                    self.config["cmd"]["checkpoint_dir"]
+                    if not self.is_debug
+                    else None
+                ),
                 element_references=elementrefs,
             )

@@ -487,15 +495,15 @@ def load_task(self):
                        ][target_name].get("level", "system")
                    if "train_on_free_atoms" not in self.output_targets[subtarget]:
                        self.output_targets[subtarget]["train_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("train_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "train_on_free_atoms", True
+                            )
                        )
                    if "eval_on_free_atoms" not in self.output_targets[subtarget]:
                        self.output_targets[subtarget]["eval_on_free_atoms"] = (
-                            self.config[
-                                "outputs"
-                            ][target_name].get("eval_on_free_atoms", True)
+                            self.config["outputs"][target_name].get(
+                                "eval_on_free_atoms", True
+                            )
                        )

         # TODO: Assert that all targets, loss fn, metrics defined are consistent
@@ -551,13 +559,13 @@ def _unwrapped_model(self):
     def load_checkpoint(
         self, checkpoint_path: str, checkpoint: dict | None = None
     ) -> None:
+        map_location = torch.device("cpu") if self.cpu else self.device
         if checkpoint is None:
             if not os.path.isfile(checkpoint_path):
                 raise FileNotFoundError(
                     errno.ENOENT, "Checkpoint file not found", checkpoint_path
                 )
             logging.info(f"Loading checkpoint from: {checkpoint_path}")
-            map_location = torch.device("cpu") if self.cpu else self.device
             checkpoint = torch.load(checkpoint_path, map_location=map_location)

         self.epoch = checkpoint.get("epoch", 0)
@@ -600,13 +608,14 @@ def load_checkpoint(
                 mkeys = self.normalizers[target_key].load_state_dict(
                     checkpoint["normalizers"][key]
                 )
+                self.normalizers[target_key].to(map_location)
                 assert len(mkeys.missing_keys) == 0
                 assert len(mkeys.unexpected_keys) == 0

         for key, state_dict in checkpoint.get("elementrefs", {}).items():
             elementrefs = LinearReferences(
                 max_num_elements=len(state_dict["element_references"]) - 1
-            )
+            ).to(map_location)
             mkeys = elementrefs.load_state_dict(state_dict)
             self.elementrefs[key] = elementrefs
             assert len(mkeys.missing_keys) == 0
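
Note: the behavioral change in `load_checkpoint` (everything else above is formatting) is that `map_location` is now computed before the `checkpoint is None` branch, so it is also defined when a checkpoint dict is passed in directly, and restored normalizers and element references are explicitly moved to that device. Below is a minimal sketch of why the added `.to(map_location)` calls matter, using a hypothetical `Normalizer` module with registered buffers, not fairchem's actual class:

```python
import torch


class Normalizer(torch.nn.Module):
    """Hypothetical stand-in for a normalizer with tensor state in buffers."""

    def __init__(self, mean: float = 0.0, rmsd: float = 1.0):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean))
        self.register_buffer("rmsd", torch.tensor(rmsd))


# Hoisting map_location above the `checkpoint is None` branch means it is
# defined on both paths: loading from a file or from an in-memory dict.
cpu = not torch.cuda.is_available()
map_location = torch.device("cpu") if cpu else torch.device("cuda")

normalizer = Normalizer()  # freshly constructed modules start on CPU
normalizer.load_state_dict({"mean": torch.tensor(1.5), "rmsd": torch.tensor(2.0)})
# load_state_dict copies values in place and does not change device, so the
# explicit move below is what keeps normalizer math on the model's device.
normalizer.to(map_location)
assert normalizer.mean.device.type == map_location.type
```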