diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 8e9e3ceab..32ecd2b2d 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -1478,3 +1478,13 @@ def get_weight_table(model: torch.nn.Module) -> tuple[list, list]: row_grad = [None] * len(row_weight) data.append([param_name] + [params.shape] + row_weight + row_grad) # noqa return columns, data + + +def get_checkpoint_format(config: dict) -> str: + # a temporary function to retrieve the checkpoint format from old configs + ckpt_format = config.get("optim", {}).get("checkpoint_format", "pt") + if ckpt_format not in ("pt", "dcp"): + raise ValueError( + f"checkpoint format can only be pt or dcp, found {ckpt_format}" + ) + return ckpt_format diff --git a/src/fairchem/core/tasks/task.py b/src/fairchem/core/tasks/task.py index 265220114..3b69e2108 100644 --- a/src/fairchem/core/tasks/task.py +++ b/src/fairchem/core/tasks/task.py @@ -11,6 +11,7 @@ import os from fairchem.core.common.registry import registry +from fairchem.core.common.utils import get_checkpoint_format from fairchem.core.trainers import OCPTrainer @@ -21,10 +22,13 @@ def __init__(self, config) -> None: def setup(self, trainer) -> None: self.trainer = trainer - # TODO: make checkpoint.pt a constant so we don't pass this string around everywhere - self.chkpt_path = os.path.join( - self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt" - ) + ckpt_format = get_checkpoint_format(self.config) + if ckpt_format == "pt": + self.chkpt_path = os.path.join( + self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt" + ) + else: + self.chkpt_path = self.trainer.config["cmd"]["checkpoint_dir"] # if the supplied checkpoint exists, then load that, ie: when user specifies the --checkpoint option # OR if the a job was preempted correctly and the submitit checkpoint function was called @@ -38,7 +42,10 @@ def setup(self, trainer) -> None: # if the supplied checkpoint doesn't exist and there exists a previous checkpoint in the checkpoint 
path, this # means that the previous job didn't terminate "nicely" (due to node failures, crashes etc), then attempt # to load the last found checkpoint - elif os.path.exists(self.chkpt_path): + elif ( + os.path.isfile(self.chkpt_path) + or (os.path.isdir(self.chkpt_path) and len(os.listdir(self.chkpt_path)) > 0) + ): logging.info( f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkecpoint" ) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 5c21a743a..edb5c800d 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -323,11 +323,14 @@ def get_sampler( seed=self.config["cmd"]["seed"], ) - def get_dataloader(self, dataset, sampler) -> DataLoader: + def get_dataloader(self, dataset, sampler, workers=None) -> DataLoader: + num_workers = ( + self.config["optim"]["num_workers"] if workers is None else workers + ) return DataLoader( dataset, collate_fn=self.collater, - num_workers=self.config["optim"]["num_workers"], + num_workers=num_workers, pin_memory=True, batch_sampler=sampler, )