diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py new file mode 100644 index 00000000..ed54f495 --- /dev/null +++ b/src/plumpy/event_helper.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +import logging +from typing import TYPE_CHECKING, Any, Callable, Set, Type + +from . import persistence + +if TYPE_CHECKING: + from .process_listener import ProcessListener # pylint: disable=cyclic-import + +_LOGGER = logging.getLogger(__name__) + + +@persistence.auto_persist('_listeners') +class EventHelper(persistence.Savable): + + def __init__(self, listener_type: 'Type[ProcessListener]'): + assert listener_type is not None, 'Must provide valid listener type' + + self._listener_type = listener_type + self._listeners: 'Set[ProcessListener]' = set() + + def add_listener(self, listener: 'ProcessListener') -> None: + assert isinstance(listener, self._listener_type), 'Listener is not of right type' + self._listeners.add(listener) + + def remove_listener(self, listener: 'ProcessListener') -> None: + self._listeners.discard(listener) + + def remove_all_listeners(self) -> None: + self._listeners.clear() + + @property + def listeners(self) -> 'Set[ProcessListener]': + return self._listeners + + def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Call an event method on all listeners. + + :param event_function: the method of the ProcessListener + :param args: arguments to pass to the method + :param kwargs: keyword arguments to pass to the method + + """ + if event_function is None: + raise ValueError('Must provide valid event method') + + # Make a copy of the list for iteration just in case it changes in a callback + for listener in list(self.listeners): + try: + getattr(listener, event_function.__name__)(*args, **kwargs) + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index c0a49d9a..110394a2 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Optional + +from . import persistence +from .utils import SAVED_STATE_TYPE, protected __all__ = ['ProcessListener'] @@ -8,7 +11,26 @@ from .processes import Process # pylint: disable=cyclic-import -class ProcessListener(metaclass=abc.ABCMeta): +@persistence.auto_persist('_params') +class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): + + # region Persistence methods + + def __init__(self) -> None: + super().__init__() + self._params: Dict[str, Any] = {} + + def init(self, **kwargs: Any) -> None: + self._params = kwargs + + @protected + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] + ) -> None: + super().load_instance_state(saved_state, load_context) + self.init(**saved_state['_params']) + + # endregion def on_process_created(self, process: 'Process') -> None: """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index d005450d..2e3d84f5 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -42,6 +42,7 @@ from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check +from .event_helper import EventHelper from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return func_wrapper -@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status') +@persistence.auto_persist( + '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' +) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -289,7 +292,7 @@ def __init__( # Runtime variables self._future = persistence.SavableFuture(loop=self._loop) - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = logger self._communicator = communicator @@ -597,6 +600,9 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) + print(type(out_state)) + print(out_state) + @protected def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. @@ -612,7 +618,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi # Runtime variables, set initial states self._future = persistence.SavableFuture() - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = None self._communicator = None @@ -621,6 +627,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self._loop = asyncio.get_event_loop() + print('saved_state') + print(saved_state) + self._state: process_states.State = self.recreate_state(saved_state['_state']) if 'communicator' in load_context: @@ -661,11 +670,11 @@ def add_process_listener(self, listener: ProcessListener) -> None: """ assert (listener != self), 'Cannot listen to yourself!' # type: ignore - self.__event_helper.add_listener(listener) + self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: """Remove a process listener from the process.""" - self.__event_helper.remove_listener(listener) + self._event_helper.remove_listener(listener) @protected def set_logger(self, logger: logging.Logger) -> None: @@ -778,7 +787,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -891,7 +900,7 @@ def on_close(self) -> None: self._closed = True def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - self.__event_helper.fire_event(evt, self, *args, **kwargs) + self._event_helper.fire_event(evt, self, *args, **kwargs) # endregion diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 0ba2b910..4eab8efe 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -27,47 +27,6 @@ PID_TYPE = Hashable # pylint: disable=invalid-name -class EventHelper: - - def __init__(self, listener_type: 'Type[ProcessListener]'): - assert listener_type is not None, 'Must provide valid listener type' - - self._listener_type = listener_type - self._listeners: 'Set[ProcessListener]' = set() - - def add_listener(self, listener: 'ProcessListener') -> None: - assert isinstance(listener, self._listener_type), 'Listener is not of right type' - self._listeners.add(listener) - - def remove_listener(self, listener: 'ProcessListener') -> None: - self._listeners.discard(listener) - - def remove_all_listeners(self) -> None: - self._listeners.clear() - - @property - def listeners(self) -> 'Set[ProcessListener]': - return self._listeners - - def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - """Call an event method on all listeners. - - :param event_function: the method of the ProcessListener - :param args: arguments to pass to the method - :param kwargs: keyword arguments to pass to the method - - """ - if event_function is None: - raise ValueError('Must provide valid event method') - - # Make a copy of the list for iteration just in case it changes in a callback - for listener in list(self.listeners): - try: - getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except - _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) - - class Frozendict(Mapping): """ An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`