demo of what checkpointing plugins might look like #3535

Draft
wants to merge 13 commits into master
3 changes: 1 addition & 2 deletions docs/userguide/checkpoints.rst
@@ -264,8 +264,7 @@ of the ``slow_double`` app.
# Wait for the results
[i.result() for i in d]

cpt_dir = dfk.checkpoint()
print(cpt_dir) # Prints the checkpoint dir
dfk.checkpoint()


Resuming from a checkpoint
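For context, usage after this docs change would look roughly like the sketch below. This is a hedged illustration, not text from the PR: it assumes a DFK has already been loaded as in the surrounding docs section, and that `get_all_checkpoints` lives in `parsl.utils` as it does elsewhere in Parsl. The key point is that `dfk.checkpoint()` is now called purely for its side effect rather than for a returned directory.

```python
import parsl
from parsl.config import Config
from parsl.utils import get_all_checkpoints  # assumed import path, matching existing Parsl usage

# First run: trigger a manual checkpoint. On this branch, checkpoint()
# no longer returns the checkpoint directory.
dfk = parsl.dfk()
dfk.checkpoint()

# A later run can resume by pointing its config at earlier checkpoints,
# for example together with task_exit checkpointing.
resume_config = Config(checkpoint_files=get_all_checkpoints(),
                       checkpoint_mode='task_exit')
```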
3 changes: 3 additions & 0 deletions parsl/config.py
@@ -5,6 +5,7 @@
from typing_extensions import Literal

from parsl.dataflow.dependency_resolvers import DependencyResolver
from parsl.dataflow.memoization import Memoizer
from parsl.dataflow.taskrecord import TaskRecord
from parsl.errors import ConfigurationError
from parsl.executors.base import ParslExecutor
@@ -98,6 +99,7 @@ class Config(RepresentationMixin, UsageInformation):
def __init__(self,
executors: Optional[Iterable[ParslExecutor]] = None,
app_cache: bool = True,
memoizer: Optional[Memoizer] = None,
checkpoint_files: Optional[Sequence[str]] = None,
checkpoint_mode: Union[None,
Literal['task_exit'],
@@ -127,6 +129,7 @@ def __init__(self,
self._executors: Sequence[ParslExecutor] = executors
self._validate_executors()

self.memoizer = memoizer
self.app_cache = app_cache
self.checkpoint_files = checkpoint_files
self.checkpoint_mode = checkpoint_mode
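The new `memoizer` parameter lets a user hand the DFK a `Memoizer` instance directly instead of relying on the default. A minimal sketch of what that might look like, assuming `BasicMemoizer()` takes no constructor arguments (as in the DFK wiring further down) and using `ThreadPoolExecutor` purely as an illustrative executor:

```python
from parsl.config import Config
from parsl.dataflow.memoization import BasicMemoizer
from parsl.executors import ThreadPoolExecutor  # executor choice is illustrative only

# Explicitly select the default memoizer/checkpointer implementation;
# a third-party checkpointing plugin would pass its own Memoizer here instead.
config = Config(executors=[ThreadPoolExecutor()],
                memoizer=BasicMemoizer(),
                checkpoint_mode='task_exit')
```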
190 changes: 39 additions & 151 deletions parsl/dataflow/dflow.py
@@ -7,7 +7,6 @@
import logging
import os
import pathlib
import pickle
import random
import sys
import threading
@@ -30,9 +29,9 @@
from parsl.data_provider.data_manager import DataManager
from parsl.data_provider.files import File
from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER
from parsl.dataflow.errors import BadCheckpoint, DependencyError, JoinError
from parsl.dataflow.errors import DependencyError, JoinError
from parsl.dataflow.futures import AppFuture
from parsl.dataflow.memoization import Memoizer
from parsl.dataflow.memoization import BasicMemoizer, Memoizer
from parsl.dataflow.rundirs import make_rundir
from parsl.dataflow.states import FINAL_FAILURE_STATES, FINAL_STATES, States
from parsl.dataflow.taskrecord import TaskRecord
@@ -99,8 +98,6 @@ def __init__(self, config: Config) -> None:

logger.info("Parsl version: {}".format(get_version()))

self.checkpoint_lock = threading.Lock()

self.usage_tracker = UsageTracker(self)
self.usage_tracker.send_start_message()

@@ -165,17 +162,29 @@ def __init__(self, config: Config) -> None:
self.monitoring.send(MessageType.WORKFLOW_INFO,
workflow_info)

# TODO: this configuration should become part of the particular memoizer code
# - this is a checkpoint-implementation-specific parameter
if config.checkpoint_files is not None:
checkpoints = self.load_checkpoints(config.checkpoint_files)
checkpoint_files = config.checkpoint_files
elif config.checkpoint_files is None and config.checkpoint_mode is not None:
checkpoints = self.load_checkpoints(get_all_checkpoints(self.run_dir))
checkpoint_files = get_all_checkpoints(self.run_dir)
else:
checkpoint_files = []

# self.memoizer: Memoizer = BasicMemoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
# the memoize flag might turn into the user choosing different instances
# of the Memoizer interface
self.memoizer: Memoizer
if config.memoizer is not None:
self.memoizer = config.memoizer
else:
checkpoints = {}
self.memoizer = BasicMemoizer()

self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint=checkpoints)
self.checkpointed_tasks = 0
self.memoizer.start(dfk=self, memoize=config.app_cache, checkpoint_files=checkpoint_files, run_dir=self.run_dir)
self._checkpoint_timer = None
self.checkpoint_mode = config.checkpoint_mode

self._modify_checkpointable_tasks_lock = threading.Lock()
self.checkpointable_tasks: List[TaskRecord] = []

# this must be set before executors are added since add_executors calls
@@ -191,6 +200,10 @@ def __init__(self, config: Config) -> None:
self.add_executors(config.executors)
self.add_executors([parsl_internal_executor])

# TODO: these checkpoint modes should move into the memoizer implementation
# they're (probably?) checkpointer specific: for example the sqlite3-pure-memoizer
# doesn't have a notion of building up an in-memory checkpoint table that needs to be
# flushed on a separate policy
if self.checkpoint_mode == "periodic":
if config.checkpoint_period is None:
raise ConfigurationError("Checkpoint period must be specified with periodic checkpoint mode")
@@ -200,7 +213,7 @@
except Exception:
raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(config.checkpoint_period))
checkpoint_period = (h * 3600) + (m * 60) + s
self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint")
self._checkpoint_timer = Timer(self.invoke_checkpoint, interval=checkpoint_period, name="Checkpoint")

self.task_count = 0
self.tasks: Dict[int, TaskRecord] = {}
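From the user's side, periodic checkpointing is configured as before, even though the timer now drives `invoke_checkpoint` and, through it, the memoizer. A hedged sketch, assuming the existing `HH:MM:SS` format for `checkpoint_period`:

```python
from parsl.config import Config

# Flush queued checkpointable tasks once per hour: the Timer created above
# calls invoke_checkpoint(), which hands the queued task records to the memoizer.
config = Config(checkpoint_mode='periodic',
                checkpoint_period='01:00:00')  # HH:MM:SS, as parsed above
```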
@@ -569,9 +582,9 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
# Do we need to checkpoint now, or queue for later,
# or do nothing?
if self.checkpoint_mode == 'task_exit':
self.checkpoint(tasks=[task_record])
self.memoizer.checkpoint(tasks=[task_record])
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
with self.checkpoint_lock:
with self._modify_checkpointable_tasks_lock:
self.checkpointable_tasks.append(task_record)
elif self.checkpoint_mode is None:
pass
@@ -1247,15 +1260,23 @@ def cleanup(self) -> None:

self.log_task_states()

# TODO: do this in the basic memoizer
# Checkpointing takes priority over the rest of the tasks
# checkpoint if any valid checkpoint method is specified
if self.checkpoint_mode is not None:
self.checkpoint()

# TODO: accesses to self.checkpointable_tasks should happen
# under a lock?
self.memoizer.checkpoint(self.checkpointable_tasks)

if self._checkpoint_timer:
logger.info("Stopping checkpoint timer")
self._checkpoint_timer.close()

logger.info("Closing memoizer")
self.memoizer.close()
logger.info("Closed memoizer")

# Send final stats
logger.info("Sending end message for usage tracking")
self.usage_tracker.send_end_message()
@@ -1323,143 +1344,10 @@ def cleanup(self) -> None:

logger.info("DFK cleanup complete")

def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> str:
"""Checkpoint the dfk incrementally to a checkpoint file.

When called, every task that has been completed yet not
checkpointed is checkpointed to a file.

Kwargs:
- tasks (List of task records) : List of task ids to checkpoint. Default=None
if set to None, we iterate over all tasks held by the DFK.

.. note::
Checkpointing only works if memoization is enabled

Returns:
Checkpoint dir if checkpoints were written successfully.
By default the checkpoints are written to the RUNDIR of the current
run under RUNDIR/checkpoints/{tasks.pkl, dfk.pkl}
"""
with self.checkpoint_lock:
if tasks:
checkpoint_queue = tasks
else:
checkpoint_queue = self.checkpointable_tasks
self.checkpointable_tasks = []

checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
checkpoint_dfk = checkpoint_dir + '/dfk.pkl'
checkpoint_tasks = checkpoint_dir + '/tasks.pkl'

if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)

with open(checkpoint_dfk, 'wb') as f:
state = {'rundir': self.run_dir,
'task_count': self.task_count
}
pickle.dump(state, f)

count = 0

with open(checkpoint_tasks, 'ab') as f:
for task_record in checkpoint_queue:
task_id = task_record['id']

app_fu = task_record['app_fu']

if app_fu.done() and app_fu.exception() is None:
hashsum = task_record['hashsum']
if not hashsum:
continue
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}

# We are using pickle here since pickle dumps to a file in 'ab'
# mode behave like a incremental log.
pickle.dump(t, f)
count += 1
logger.debug("Task {} checkpointed".format(task_id))

self.checkpointed_tasks += count

if count == 0:
if self.checkpointed_tasks == 0:
logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
else:
logger.debug("No tasks checkpointed in this pass.")
else:
logger.info("Done checkpointing {} tasks".format(count))

return checkpoint_dir

def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
"""Load a checkpoint file into a lookup table.

The data being loaded from the pickle file mostly contains input
attributes of the task: func, args, kwargs, env...
To simplify the check of whether the exact task has been completed
in the checkpoint, we hash these input params and use it as the key
for the memoized lookup table.

Args:
- checkpointDirs (list) : List of filepaths to checkpoints
Eg. ['runinfo/001', 'runinfo/002']

Returns:
- memoized_lookup_table (dict)
"""
memo_lookup_table = {}

for checkpoint_dir in checkpointDirs:
logger.info("Loading checkpoints from {}".format(checkpoint_dir))
checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
try:
with open(checkpoint_file, 'rb') as f:
while True:
try:
data = pickle.load(f)
# Copy and hash only the input attributes
memo_fu: Future = Future()
assert data['exception'] is None
memo_fu.set_result(data['result'])
memo_lookup_table[data['hash']] = memo_fu

except EOFError:
# Done with the checkpoint file
break
except FileNotFoundError:
reason = "Checkpoint file was not found: {}".format(
checkpoint_file)
logger.error(reason)
raise BadCheckpoint(reason)
except Exception:
reason = "Failed to load checkpoint: {}".format(
checkpoint_file)
logger.error(reason)
raise BadCheckpoint(reason)

logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
len(memo_lookup_table.keys())))
return memo_lookup_table

@typeguard.typechecked
def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
"""Load checkpoints from the checkpoint files into a dictionary.

The results are used to pre-populate the memoizer's lookup_table

Kwargs:
- checkpointDirs (list) : List of run folder to use as checkpoints
Eg. ['runinfo/001', 'runinfo/002']

Returns:
- dict containing, hashed -> future mappings
"""
if checkpointDirs:
return self._load_checkpoints(checkpointDirs)
else:
return {}
def invoke_checkpoint(self) -> None:
with self._modify_checkpointable_tasks_lock:
self.memoizer.checkpoint(self.checkpointable_tasks)
self.checkpointable_tasks = []

@staticmethod
def _log_std_streams(task_record: TaskRecord) -> None:
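Taken together, the DFK changes above call only a small surface on the memoizer: `start(...)` during DFK init, `checkpoint(tasks)` on task exit, from the periodic timer via `invoke_checkpoint()`, and once more during `cleanup()`, and finally `close()` during cleanup. A rough sketch of what a checkpointing plugin might therefore need to provide, inferred from those call sites only; the class name is hypothetical, and any further methods on the real `Memoizer` base class (for example memo-table lookups) are not visible in this diff:

```python
from typing import Sequence

from parsl.dataflow.taskrecord import TaskRecord


class NullMemoizer:  # hypothetical plugin, for illustration only
    """A do-nothing checkpointer showing the calls the DFK makes in this PR."""

    def start(self, *, dfk, memoize: bool, checkpoint_files: Sequence[str], run_dir: str) -> None:
        # Called once from DataFlowKernel.__init__; a real plugin would load
        # any previous checkpoints from checkpoint_files here.
        self.run_dir = run_dir

    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
        # Called with completed task records on task_exit, from the periodic
        # timer via invoke_checkpoint(), and once more during cleanup().
        pass

    def close(self) -> None:
        # Called from cleanup() after the final checkpoint pass.
        pass
```

Such a class would be supplied via `Config(memoizer=NullMemoizer())`, as enabled by the config.py change above.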