diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index c4820f1b..232c5d26 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: wrapped(self, *args, **kwargs) self._called -= 1 + # Forward wrapped function name to the decorator to show the correct name in the ``call_with_super_check`` + wrapper.__name__ = wrapped.__name__ return wrapper diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py new file mode 100644 index 00000000..3a342321 --- /dev/null +++ b/src/plumpy/event_helper.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import logging +from typing import TYPE_CHECKING, Any, Callable + +from . import persistence + +if TYPE_CHECKING: + from typing import Set, Type + + from .process_listener import ProcessListener # pylint: disable=cyclic-import + +_LOGGER = logging.getLogger(__name__) + + +@persistence.auto_persist('_listeners', '_listener_type') +class EventHelper(persistence.Savable): + + def __init__(self, listener_type: 'Type[ProcessListener]'): + assert listener_type is not None, 'Must provide valid listener type' + + self._listener_type = listener_type + self._listeners: 'Set[ProcessListener]' = set() + + def add_listener(self, listener: 'ProcessListener') -> None: + assert isinstance(listener, self._listener_type), 'Listener is not of right type' + self._listeners.add(listener) + + def remove_listener(self, listener: 'ProcessListener') -> None: + self._listeners.discard(listener) + + def remove_all_listeners(self) -> None: + self._listeners.clear() + + @property + def listeners(self) -> 'Set[ProcessListener]': + return self._listeners + + def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Call an event method on all listeners. + + :param event_function: the method of the ProcessListener + :param args: arguments to pass to the method + :param kwargs: keyword arguments to pass to the method + + """ + if event_function is None: + raise ValueError('Must provide valid event method') + + # Make a copy of the list for iteration just in case it changes in a callback + for listener in list(self.listeners): + try: + getattr(listener, event_function.__name__)(*args, **kwargs) + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index c0a49d9a..110394a2 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Optional + +from . import persistence +from .utils import SAVED_STATE_TYPE, protected __all__ = ['ProcessListener'] @@ -8,7 +11,26 @@ from .processes import Process # pylint: disable=cyclic-import -class ProcessListener(metaclass=abc.ABCMeta): +@persistence.auto_persist('_params') +class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): + + # region Persistence methods + + def __init__(self) -> None: + super().__init__() + self._params: Dict[str, Any] = {} + + def init(self, **kwargs: Any) -> None: + self._params = kwargs + + @protected + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] + ) -> None: + super().load_instance_state(saved_state, load_context) + self.init(**saved_state['_params']) + + # endregion def on_process_created(self, process: 'Process') -> None: """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index d005450d..5cb64e1b 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -42,6 +42,7 @@ from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check +from .event_helper import EventHelper from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return func_wrapper -@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status') +@persistence.auto_persist( + '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' +) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -289,7 +292,7 @@ def __init__( # Runtime variables self._future = persistence.SavableFuture(loop=self._loop) - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = logger self._communicator = communicator @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi # Runtime variables, set initial states self._future = persistence.SavableFuture() - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = None self._communicator = None @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None: """ assert (listener != self), 'Cannot listen to yourself!' # type: ignore - self.__event_helper.add_listener(listener) + self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: """Remove a process listener from the process.""" - self.__event_helper.remove_listener(listener) + self._event_helper.remove_listener(listener) @protected def set_logger(self, logger: logging.Logger) -> None: @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -891,7 +894,7 @@ def on_close(self) -> None: self._closed = True def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - self.__event_helper.fire_event(evt, self, *args, **kwargs) + self._event_helper.fire_event(evt, self, *args, **kwargs) # endregion @@ -1315,7 +1318,7 @@ def out(self, output_port: str, value: Any) -> None: self.on_output_emitted(output_port, value, dynamic) @protected - def encode_input_args(self, inputs: Any) -> Any: + def encode_input_args(self, inputs: Any) -> Any: # pylint: disable=no-self-use """ Encode input arguments such that they may be saved in a :class:`plumpy.persistence.Bundle`. The encoded inputs should contain no reference to the inputs that were passed in. @@ -1327,7 +1330,7 @@ def encode_input_args(self, inputs: Any) -> Any: return copy.deepcopy(inputs) @protected - def decode_input_args(self, encoded: Any) -> Any: + def decode_input_args(self, encoded: Any) -> Any: # pylint: disable=no-self-use """ Decode saved input arguments as they came from the saved instance state :class:`plumpy.persistence.Bundle`. The decoded inputs should contain no reference to the encoded inputs that were passed in. diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 0ba2b910..87ef56f8 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -27,47 +27,6 @@ PID_TYPE = Hashable # pylint: disable=invalid-name -class EventHelper: - - def __init__(self, listener_type: 'Type[ProcessListener]'): - assert listener_type is not None, 'Must provide valid listener type' - - self._listener_type = listener_type - self._listeners: 'Set[ProcessListener]' = set() - - def add_listener(self, listener: 'ProcessListener') -> None: - assert isinstance(listener, self._listener_type), 'Listener is not of right type' - self._listeners.add(listener) - - def remove_listener(self, listener: 'ProcessListener') -> None: - self._listeners.discard(listener) - - def remove_all_listeners(self) -> None: - self._listeners.clear() - - @property - def listeners(self) -> 'Set[ProcessListener]': - return self._listeners - - def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - """Call an event method on all listeners. - - :param event_function: the method of the ProcessListener - :param args: arguments to pass to the method - :param kwargs: keyword arguments to pass to the method - - """ - if event_function is None: - raise ValueError('Must provide valid event method') - - # Make a copy of the list for iteration just in case it changes in a callback - for listener in list(self.listeners): - try: - getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except - _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) - - class Frozendict(Mapping): """ An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping` @@ -171,7 +130,7 @@ def load_function(name: str, instance: Optional[Any] = None) -> Callable[..., An obj = load_object(name) if inspect.ismethod(obj): if instance is not None: - return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] # pylint: disable=unnecessary-dunder-call + return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] return obj diff --git a/test/test_processes.py b/test/test_processes.py index 737b463d..158cce66 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -800,7 +800,8 @@ def test_instance_state_with_outputs(self): # Check that it is a copy self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {})) # Check the contents are the same - self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) + # Remove the ``ProcessSaver`` instance that is only used for testing + utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'}) self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {})) @@ -875,7 +876,7 @@ def _check_round_trip(self, proc1): bundle2 = plumpy.Bundle(proc2) self.assertEqual(proc1.pid, proc2.pid) - self.assertDictEqual(bundle1, bundle2) + utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'}) class TestProcessNamespace(unittest.TestCase): diff --git a/test/test_workchains.py b/test/test_workchains.py index 71cd0f6a..c698aff9 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -283,6 +283,46 @@ def test_checkpointing(self): if step not in ['isA', 's2', 'isB', 's3']: self.assertTrue(finished, f'Step {step} was not called by workflow') + def test_listener_persistence(self): + persister = plumpy.InMemoryPersister() + process_finished_count = 0 + + class TestListener(plumpy.ProcessListener): + + def on_process_finished(self, process, output): + nonlocal process_finished_count + process_finished_count += 1 + + class SimpleWorkChain(plumpy.WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.outline( + cls.step1, + cls.step2, + ) + + def step1(self): + persister.save_checkpoint(self, 'step1') + + def step2(self): + persister.save_checkpoint(self, 'step2') + + # add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint + globals()['SimpleWorkChain'] = SimpleWorkChain + globals()['TestListener'] = TestListener + + workchain = SimpleWorkChain() + workchain.add_process_listener(TestListener()) + output = workchain.execute() + + self.assertEqual(process_finished_count, 1) + + workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle() + workchain_checkpoint.execute() + self.assertEqual(process_finished_count, 2) + def test_return_in_outline(self): class WcWithReturn(WorkChain): diff --git a/test/utils.py b/test/utils.py index feb3d1c8..1f7408f6 100644 --- a/test/utils.py +++ b/test/utils.py @@ -185,7 +185,7 @@ def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) - def middle_step(self,): + def middle_step(self): return process_states.Continue(self.last_step) def last_step(self): @@ -260,25 +260,72 @@ def _save(self, p): self.outputs.append(p.outputs.copy()) -class ProcessSaver(plumpy.ProcessListener, Saver): +_ProcessSaverProcReferences = {} +_ProcessSaver_Saver = {} + + +class ProcessSaver(plumpy.ProcessListener): """ - Save the instance state of a process each time it is about to enter a new state + Save the instance state of a process each time it is about to enter a new state. + NB: this is not a general purpose saver, it is only intended to be used for testing + The listener instances inside a process are persisted, so if we store a process + reference in the ProcessSaver instance, we will have a circular reference that cannot be + persisted. So we store the Saver instance in a global dictionary with the key the id of the + ProcessSaver instance. + In the init_not_persistent method we initialize the instances that cannot be persisted, + like the saver instance. The __del__ method is used to clean up the global dictionaries + (note there is no guarantee that __del__ will be called) + """ + def __del__(self): + global _ProcessSaver_Saver + global _ProcessSaverProcReferences + if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences: + del _ProcessSaverProcReferences[id(self)] + if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver: + del _ProcessSaver_Saver[id(self)] + + def get_process(self): + global _ProcessSaverProcReferences + return _ProcessSaverProcReferences[id(self)] + + def _save(self, p): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)]._save(p) + + def set_process(self, process): + global _ProcessSaverProcReferences + _ProcessSaverProcReferences[id(self)] = process + def __init__(self, proc): - plumpy.ProcessListener.__init__(self) - Saver.__init__(self) - self.process = proc + super().__init__() proc.add_process_listener(self) + self.init_not_persistent(proc) + + def init_not_persistent(self, proc): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)] = Saver() + self.set_process(proc) def capture(self): - self._save(self.process) - if not self.process.has_terminated(): + self._save(self.get_process()) + if not self.get_process().has_terminated(): try: - self.process.execute() + self.get_process().execute() except Exception: pass + @property + def snapshots(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].snapshots + + @property + def outputs(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].outputs + @utils.override def on_process_running(self, process): self._save(process) @@ -335,7 +382,13 @@ def check_process_against_snapshots(loop, proc_class, snapshots): """ for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) - saver = ProcessSaver(loaded) + # the process listeners are persisted + saver = list(loaded._event_helper._listeners)[0] + assert isinstance(saver, ProcessSaver) + # the process reference inside this particular implementation of process listener + # cannot be persisted because of a circular reference. So we load it there + # also the saver is not persisted for the same reason. We load it manually + saver.init_not_persistent(loaded) saver.capture() # Now check going backwards until running that the saved states match @@ -345,7 +398,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots): break compare_dictionaries( - snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'} + snapshots[-j], + saver.snapshots[-j], + snapshots[-j], + saver.snapshots[-j], + exclude={'exception', '_listeners'} ) j += 1 @@ -376,6 +433,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): elif isinstance(v1, list) and isinstance(v2, list): for vv1, vv2 in zip(v1, v2): compare_value(bundle1, bundle2, vv1, vv2, exclude) + elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1: + # TODO: implement sets with more than one element + compare_value(bundle1, bundle2, list(v1), list(v2), exclude) + elif isinstance(v1, set) and isinstance(v2, set): + raise NotImplementedError('Comparison between sets not implemented') else: if v1 != v2: raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}')