From 8143ccbcdb7c3d348fae2b1e6e39cc666d2dc366 Mon Sep 17 00:00:00 2001
From: rayg1234 <7001989+rayg1234@users.noreply.github.com>
Date: Tue, 13 Aug 2024 15:35:20 -0700
Subject: [PATCH] Fix job resume on non-preemption type failures (#803)

* simplest way to fix resume on node failure

* lint

* add slurm util [y

* add slurm.py

* fix typo

* add print

* only modify pickle on master

* fix comment

* use submit to get pickle path

* str paths

* typo

* typo

* timestamp_id was already in cmd

---
 src/fairchem/core/_cli.py                  |  1 +
 src/fairchem/core/common/slurm.py          | 26 ++++++++++++++++++++++
 src/fairchem/core/tasks/task.py            | 18 ++++++++++++---
 src/fairchem/core/trainers/base_trainer.py |  5 +++++
 4 files changed, 47 insertions(+), 3 deletions(-)
 create mode 100644 src/fairchem/core/common/slurm.py

diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py
index 69204a8ced..f1270496a9 100644
--- a/src/fairchem/core/_cli.py
+++ b/src/fairchem/core/_cli.py
@@ -48,6 +48,7 @@ def checkpoint(self, *args, **kwargs):
         self.config["timestamp_id"] = self.trainer.timestamp_id
         if self.trainer.logger is not None:
             self.trainer.logger.mark_preempting()
+        logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
         return DelayedSubmission(new_runner, self.config)
diff --git a/src/fairchem/core/common/slurm.py b/src/fairchem/core/common/slurm.py
new file mode 100644
index 0000000000..37023de1bc
--- /dev/null
+++ b/src/fairchem/core/common/slurm.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import logging
+import pickle
+
+from submitit.core.utils import JobPaths
+
+
+def add_timestamp_id_to_submission_pickle(slurm_folder: str, slurm_job_id: str, timestamp_id: str):
+    # Try to put the timestamp-id into the original submission pickle's config
+    # so that if the node crashes, the resubmitted job can pick up the correct run to resume
+    #
+    # we need to do this after the job has started because the timestamp-id is generated at runtime
+    # rather than a priori before the submission starts (ie: it could be assigned up front if we had a db to store a globally unique job id)
+    submission_pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
+    try:
+        with open(str(submission_pickle_path), "rb") as f:
+            pkl = pickle.load(f)
+            # args passed to the runner function (0 is the input_dict): https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/_cli.py#L36
+            pkl.args[0]["timestamp_id"] = timestamp_id
+        with open(str(submission_pickle_path), "wb") as f:
+            pickle.dump(pkl, f)
+    except Exception as e:
+        # Since this only affects the ability to resume jobs, if the pickle doesn't exist
+        # or the format changed, log a warning instead of failing the job here
+        logging.warning(f"Couldn't modify the submission pickle with error: {e}")
diff --git a/src/fairchem/core/tasks/task.py b/src/fairchem/core/tasks/task.py
index b74c4d88c8..c10da91239 100644
--- a/src/fairchem/core/tasks/task.py
+++ b/src/fairchem/core/tasks/task.py
@@ -20,14 +20,26 @@ def __init__(self, config) -> None:
     def setup(self, trainer) -> None:
         self.trainer = trainer
-        if self.config["checkpoint"] is not None:
-            self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"])
-        # save checkpoint path to runner state for slurm resubmissions
+        # TODO: make checkpoint.pt a constant so we don't pass this string around everywhere
         self.chkpt_path = os.path.join(
             self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
         )
+        # if the supplied checkpoint exists, then load that, ie: when the user specifies the --checkpoint option
+        # OR if a job was preempted correctly and the submitit checkpoint function was called
+        # (https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/_cli.py#L44), then we should attempt to
+        # load that checkpoint
+        if self.config["checkpoint"] is not None:
+            logging.info(f"Attempting to load user specified checkpoint at {self.config['checkpoint']}")
+            self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"])
+        # if the supplied checkpoint doesn't exist but a previous checkpoint exists in the checkpoint path, it
+        # means the previous job didn't terminate "nicely" (due to node failures, crashes etc), so attempt
+        # to load the last found checkpoint
+        elif os.path.exists(self.chkpt_path):
+            logging.info(f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkpoint")
+            self.trainer.load_checkpoint(checkpoint_path=self.chkpt_path)
+
     def run(self):
         raise NotImplementedError
diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 40c7e65de6..4c1d4c0d63 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -30,6 +30,9 @@
 from fairchem.core.common import distutils, gp_utils
 from fairchem.core.common.data_parallel import BalancedBatchSampler, OCPCollater
 from fairchem.core.common.registry import registry
+from fairchem.core.common.slurm import (
+    add_timestamp_id_to_submission_pickle,
+)
 from fairchem.core.common.typing import assert_is_instance as aii
 from fairchem.core.common.typing import none_throws
 from fairchem.core.common.utils import (
@@ -155,6 +158,8 @@ def __init__(
             self.config["slurm"]["folder"] = self.config["slurm"]["folder"].replace(
                 "%j", self.config["slurm"]["job_id"]
             )
+        if distutils.is_master():
+            add_timestamp_id_to_submission_pickle(self.config["slurm"]["folder"], self.config["slurm"]["job_id"], self.timestamp_id)

         # Define datasets
         if isinstance(dataset, list):
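
Note (not part of the patch): a minimal sketch of how the injected timestamp_id can be read back from the submitit submission pickle that a requeued job unpickles. It assumes a job has already been submitted through submitit so that a submitted pickle exists; the slurm folder, job id, and timestamp value below are hypothetical placeholders, while JobPaths, submitted_pickle, and the args[0] layout are the same ones the patch itself relies on.

# Sketch only: round-trip the timestamp_id written by add_timestamp_id_to_submission_pickle.
import pickle

from submitit.core.utils import JobPaths

from fairchem.core.common.slurm import add_timestamp_id_to_submission_pickle

slurm_folder = "/checkpoint/username/12345"  # hypothetical submitit folder
slurm_job_id = "12345"                       # hypothetical slurm job id

# Inject the runtime-generated id into the already-written submission pickle.
add_timestamp_id_to_submission_pickle(slurm_folder, slurm_job_id, timestamp_id="2024-08-13-15-35-20")

# JobPaths is the same helper the patch uses to locate the submitted pickle.
pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
with open(str(pickle_path), "rb") as f:
    submission = pickle.load(f)

# args[0] is the config dict passed to the runner (see the _cli.py link in slurm.py);
# the injected id is what a crashed-and-requeued job uses to find the run to resume.
print(submission.args[0]["timestamp_id"])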
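
Note (not part of the patch): the snippet below restates the checkpoint-resolution order introduced in tasks/task.py as a standalone helper. The function name resolve_checkpoint is hypothetical; only the precedence, a user-supplied or preemption-written checkpoint first, then a leftover checkpoint.pt from a crashed job, then nothing, mirrors the patched setup().

# Sketch only: precedence of checkpoint sources for resuming a run.
from __future__ import annotations

import os


def resolve_checkpoint(user_checkpoint: str | None, checkpoint_dir: str) -> str | None:
    chkpt_path = os.path.join(checkpoint_dir, "checkpoint.pt")
    if user_checkpoint is not None:
        # --checkpoint flag, or the path written back by the submitit preemption callback
        return user_checkpoint
    if os.path.exists(chkpt_path):
        # previous job died without a clean preemption (node failure, crash, ...)
        return chkpt_path
    # fresh run, nothing to resume
    return None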