From f8b88e816a2b4c42e0300687302c9f4e1e192dcd Mon Sep 17 00:00:00 2001
From: rayg1234 <7001989+rayg1234@users.noreply.github.com>
Date: Tue, 17 Sep 2024 15:38:57 -0700
Subject: [PATCH] fix bug where trainer state not being loaded (#863)

* fix bug where trainer state not being loaded

* lint

* fix ase
---
 src/fairchem/core/common/relaxation/ase_utils.py | 4 +++-
 src/fairchem/core/trainers/base_trainer.py       | 7 +++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py
index 9a7964dc7..2dacce2cb 100644
--- a/src/fairchem/core/common/relaxation/ase_utils.py
+++ b/src/fairchem/core/common/relaxation/ase_utils.py
@@ -219,7 +219,9 @@ def load_checkpoint(
                 Path to trained model
         """
         try:
-            self.trainer.load_checkpoint(checkpoint_path, checkpoint)
+            self.trainer.load_checkpoint(
+                checkpoint_path, checkpoint, inference_only=True
+            )
         except NotImplementedError:
             logging.warning("Unable to load checkpoint!")
 
diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 1f450bfcd..9da73f723 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -578,7 +578,7 @@ def load_checkpoint(
         self,
         checkpoint_path: str,
         checkpoint: dict | None = None,
-        inference_only: bool | None = None,
+        inference_only: bool = False,
     ) -> None:
         map_location = torch.device("cpu") if self.cpu else self.device
         if checkpoint is None:
@@ -590,7 +590,6 @@ def load_checkpoint(
             checkpoint = torch.load(checkpoint_path, map_location=map_location)
 
         # attributes that are necessary for training and validation
-        inference_only = self.train_dataset is None or inference_only
         if inference_only is False:
             self.epoch = checkpoint.get("epoch", 0)
             self.step = checkpoint.get("step", 0)
@@ -601,6 +600,10 @@ def load_checkpoint(
                 self.optimizer.load_state_dict(checkpoint["optimizer"])
             if "scheduler" in checkpoint and checkpoint["scheduler"] is not None:
                 self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
+        else:
+            logging.info(
+                "Loading checkpoint in inference-only mode, not loading keys associated with trainer state!"
+            )
 
         if "ema" in checkpoint and checkpoint["ema"] is not None:
             self.ema.load_state_dict(checkpoint["ema"])