Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plumpy.ProcessListener made persistent support/0.21.x #277

Merged
merged 4 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
wrapped(self, *args, **kwargs)
self._called -= 1

#the following is to show the correct name later in the call_with_super_check error message
sphuber marked this conversation as resolved.
Show resolved Hide resolved
wrapper.__name__ = wrapped.__name__
return wrapper


Expand Down
52 changes: 52 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable, Set, Type
sphuber marked this conversation as resolved.
Show resolved Hide resolved

from . import persistence

if TYPE_CHECKING:
sphuber marked this conversation as resolved.
Show resolved Hide resolved
from .process_listener import ProcessListener # pylint: disable=cyclic-import

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.

:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method

"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
17 changes: 10 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand Down Expand Up @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:

"""
assert (listener != self), 'Cannot listen to yourself!'
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +894,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.

:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method

"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down
7 changes: 5 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
#we remove the ProcessSaver instance that is an object used only for testing
sphuber marked this conversation as resolved.
Show resolved Hide resolved
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})
#self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
sphuber marked this conversation as resolved.
Show resolved Hide resolved

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
#self.assertDictEqual(bundle1, bundle2)
sphuber marked this conversation as resolved.
Show resolved Hide resolved
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
43 changes: 43 additions & 0 deletions test/test_workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,49 @@ def test_checkpointing(self):
if step not in ['isA', 's2', 'isB', 's3']:
self.assertTrue(finished, f'Step {step} was not called by workflow')

def test_listener_persistence(self):
persister = plumpy.InMemoryPersister()
process_finished_count = 0

class TestListener(plumpy.ProcessListener):

def on_process_finished(self, process, output):
nonlocal process_finished_count
process_finished_count += 1

class SimpleWorkChain(plumpy.WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.step1,
cls.step2,
)

def step1(self):
print('step1')
sphuber marked this conversation as resolved.
Show resolved Hide resolved
persister.save_checkpoint(self, 'step1')

def step2(self):
print('step2')
sphuber marked this conversation as resolved.
Show resolved Hide resolved
persister.save_checkpoint(self, 'step2')

# add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint
globals()['SimpleWorkChain'] = SimpleWorkChain
globals()['TestListener'] = TestListener

workchain = SimpleWorkChain()
workchain.add_process_listener(TestListener())
output = workchain.execute()

self.assertEqual(process_finished_count, 1)

print('reload persister checkpoint:')
sphuber marked this conversation as resolved.
Show resolved Hide resolved
workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle()
workchain_checkpoint.execute()
self.assertEqual(process_finished_count, 2)

def test_return_in_outline(self):

class WcWithReturn(WorkChain):
Expand Down
Loading
Loading