Skip to content

Commit

Permalink
Make ProcessListener instances persistable (#277)
Browse files Browse the repository at this point in the history
The `ProcessListener` is made persistable by deriving it, as well as the
`EventHelper` class from `persistence.Savable`. The class `EventHelper`
is moved to a new file because of a circular import that would result
between the `utils` and `persistence` modules.

There was a circular reference issue in the test listener that was
storing a reference to the process inside it, making its serialization
impossible. To fix the tests an ugly hack was used: storing the reference
to the process outside the class in a global dict using id as keys. Some
more ugly hacks were needed to correctly check the equality of two
processes. Instances having different listeners should be ignored.

Cherry-pick: 98a375f
  • Loading branch information
rikigigi authored and sphuber committed Nov 13, 2023
1 parent b447fb0 commit 430d63c
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 63 deletions.
2 changes: 2 additions & 0 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
wrapped(self, *args, **kwargs)
self._called -= 1

# Forward wrapped function name to the decorator to show the correct name in the ``call_with_super_check``
wrapper.__name__ = wrapped.__name__
return wrapper


Expand Down
54 changes: 54 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable

from . import persistence

if TYPE_CHECKING:
from typing import Set, Type

from .process_listener import ProcessListener # pylint: disable=cyclic-import

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
17 changes: 10 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand Down Expand Up @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:
"""
assert (listener != self), 'Cannot listen to yourself!' # type: ignore
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +894,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down
5 changes: 3 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,8 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Remove the ``ProcessSaver`` instance that is only used for testing
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +876,7 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
40 changes: 40 additions & 0 deletions test/test_workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,46 @@ def test_checkpointing(self):
if step not in ['isA', 's2', 'isB', 's3']:
self.assertTrue(finished, f'Step {step} was not called by workflow')

def test_listener_persistence(self):
persister = plumpy.InMemoryPersister()
process_finished_count = 0

class TestListener(plumpy.ProcessListener):

def on_process_finished(self, process, output):
nonlocal process_finished_count
process_finished_count += 1

class SimpleWorkChain(plumpy.WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.step1,
cls.step2,
)

def step1(self):
persister.save_checkpoint(self, 'step1')

def step2(self):
persister.save_checkpoint(self, 'step2')

# add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint
globals()['SimpleWorkChain'] = SimpleWorkChain
globals()['TestListener'] = TestListener

workchain = SimpleWorkChain()
workchain.add_process_listener(TestListener())
output = workchain.execute()

self.assertEqual(process_finished_count, 1)

workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle()
workchain_checkpoint.execute()
self.assertEqual(process_finished_count, 2)

def test_return_in_outline(self):

class WcWithReturn(WorkChain):
Expand Down
Loading

0 comments on commit 430d63c

Please sign in to comment.