Skip to content

Commit

Permalink
Fix job resume on non-preemption type failures (#803)
Browse files Browse the repository at this point in the history
* simplest way to fix resume on node failure

* lint

* add slurm util [y

* add slurm.py

* fix typo

* add print

* only modify pickle on master

* fix comment

* use submit to get pickle path

* str paths

* typo

* typo

* timestamp_id was already in cmd
  • Loading branch information
rayg1234 authored Aug 13, 2024
1 parent 6cff5a5 commit 8143ccb
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/fairchem/core/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def checkpoint(self, *args, **kwargs):
self.config["timestamp_id"] = self.trainer.timestamp_id
if self.trainer.logger is not None:
self.trainer.logger.mark_preempting()
logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
return DelayedSubmission(new_runner, self.config)


Expand Down
26 changes: 26 additions & 0 deletions src/fairchem/core/common/slurm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import logging
import pickle

from submitit.core.utils import JobPaths


def add_timestamp_id_to_submission_pickle(
    slurm_folder: str, slurm_job_id: str, timestamp_id: str
) -> None:
    """Inject the runtime-generated timestamp_id into the job's submitit pickle.

    Rewrites the original submission pickle's config so that, if the node
    crashes and slurm requeues the job, the resumed run picks up the correct
    timestamp_id. This must happen after the job has started because the
    timestamp_id is generated at runtime rather than a priori at submission
    time (ie: there is no db to store a globally unique job id).

    Best-effort: any failure is logged as a warning instead of raised, since a
    missing or reformatted pickle only affects the ability to resume the job,
    not the currently running one.
    """
    submission_pickle_path = JobPaths(
        folder=slurm_folder, job_id=slurm_job_id
    ).submitted_pickle
    try:
        with open(str(submission_pickle_path), "rb") as f:
            pkl = pickle.load(f)
        # args passed to the runner function (0 is the input_dict): https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/_cli.py#L36
        pkl.args[0]["timestamp_id"] = timestamp_id
        with open(str(submission_pickle_path), "wb") as f:
            pickle.dump(pkl, f)
    except Exception as e:
        # logging.warn is a deprecated alias (removed in Python 3.13);
        # logging.warning is the supported spelling.
        logging.warning(f"Couldn't modify the submission pickle with error: {e}")
18 changes: 15 additions & 3 deletions src/fairchem/core/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,26 @@ def __init__(self, config) -> None:

def setup(self, trainer) -> None:
self.trainer = trainer
if self.config["checkpoint"] is not None:
self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"])

# save checkpoint path to runner state for slurm resubmissions
# TODO: make checkpoint.pt a constant so we don't pass this string around everywhere
self.chkpt_path = os.path.join(
self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
)

# if the supplied checkpoint exists, then load that, ie: when user specifies the --checkpoint option
# OR if the a job was preempted correctly and the submitit checkpoint function was called
# (https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/_cli.py#L44), then we should attempt to
# load that checkpoint
if self.config["checkpoint"] is not None:
logging.info(f"Attemping to load user specified checkpoint at {self.config['checkpoint']}")
self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"])
# if the supplied checkpoint doesn't exist and there exists a previous checkpoint in the checkpoint path, this
# means that the previous job didn't terminate "nicely" (due to node failures, crashes etc), then attempt
# to load the last found checkpoint
elif os.path.exists(self.chkpt_path):
logging.info(f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkecpoint")
self.trainer.load_checkpoint(checkpoint_path=self.chkpt_path)

def run(self):
    """Execute the task. Abstract: subclasses must override this method."""
    raise NotImplementedError

Expand Down
5 changes: 5 additions & 0 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.data_parallel import BalancedBatchSampler, OCPCollater
from fairchem.core.common.registry import registry
from fairchem.core.common.slurm import (
add_timestamp_id_to_submission_pickle,
)
from fairchem.core.common.typing import assert_is_instance as aii
from fairchem.core.common.typing import none_throws
from fairchem.core.common.utils import (
Expand Down Expand Up @@ -155,6 +158,8 @@ def __init__(
self.config["slurm"]["folder"] = self.config["slurm"]["folder"].replace(
"%j", self.config["slurm"]["job_id"]
)
if distutils.is_master():
add_timestamp_id_to_submission_pickle(self.config["slurm"]["folder"], self.config["slurm"]["job_id"], self.timestamp_id)

# Define datasets
if isinstance(dataset, list):
Expand Down

0 comments on commit 8143ccb

Please sign in to comment.