Skip to content

Commit

Permalink
plumpy.ProcessListener made persistent
Browse files Browse the repository at this point in the history
solves aiidateam#273

We implement the persistence of ProcessListener by deriving the class
ProcessListener and EventHelper from persistence.Savable.
The class EventHelper is moved to a new file because of a circular
import with utils and persistence
  • Loading branch information
rikigigi committed Aug 29, 2023
1 parent 44d27d1 commit 9b738f8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 50 deletions.
50 changes: 50 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, Set, MutableMapping, Optional, Tuple, Type
if TYPE_CHECKING:
from .process_listener import ProcessListener # pylint: disable=cyclic-import


import logging
_LOGGER = logging.getLogger(__name__)

from . import persistence

@persistence.auto_persist('_listeners')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)
29 changes: 27 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from . import persistence
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected

from typing import TYPE_CHECKING, Any, Optional

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self):
super().__init__()
self._params = {}

@abc.abstractmethod
def init(self, **kwargs):
self._params = kwargs

def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext]) -> None:
super().save_instance_state(out_state, save_context)

@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])


# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
21 changes: 14 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
from .event_helper import EventHelper

# pylint: disable=too-many-lines

Expand Down Expand Up @@ -91,7 +92,7 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper')
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +290,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -597,6 +598,9 @@ def save_instance_state(
if self.outputs:
out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs)

print(type(out_state))
print(out_state)

@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
"""Load the process from its saved instance state.
Expand All @@ -612,7 +616,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand All @@ -621,6 +625,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
else:
self._loop = asyncio.get_event_loop()

print('saved_state')
print(saved_state)

self._state: process_states.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
Expand Down Expand Up @@ -661,11 +668,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:
"""
assert (listener != self), 'Cannot listen to yourself!' # type: ignore
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +785,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +898,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down

0 comments on commit 9b738f8

Please sign in to comment.