Skip to content

Commit

Permalink
Amend
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 1, 2024
1 parent d7078fc commit 96c5842
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 280 deletions.
84 changes: 30 additions & 54 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -72,27 +72,27 @@ def __init__(
super().__init__(self._format_msg())

def _format_msg(self) -> str:
msg = [f"{self.initial_state} -> {self.final_state}"]
msg = [f'{self.initial_state} -> {self.final_state}']
if self.traceback_str is not None:
msg.append(self.traceback_str)
return "\n".join(msg)
return '\n'.join(msg)


def event(
from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*',
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""A decorator to check for correct transitions, raising ``EventError`` on invalid transitions."""
if from_states != "*":
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
raise TypeError(f"from_states: {from_states}")
if to_states != "*":
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
raise TypeError(f"to_states: {to_states}")
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
evt_label = wrapped.__name__
Expand All @@ -101,20 +101,14 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
def transition(self: Any, *a: Any, **kw: Any) -> Any:
initial = self._state

if from_states != "*" and not any(
isinstance(self._state, state) for state in from_states
): # type: ignore
raise EventError(
evt_label, f"Event {evt_label} invalid in state {initial.LABEL}"
)
if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore
raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}')

result = wrapped(self, *a, **kw)
if not (result is False or isinstance(result, Future)):
if to_states != "*" and not any(
isinstance(self._state, state) for state in to_states
): # type: ignore
if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore
if self._state == initial:
raise EventError(evt_label, "Machine did not transition")
raise EventError(evt_label, 'Machine did not transition')

raise EventError(
evt_label,
Expand Down Expand Up @@ -160,7 +154,7 @@ def label(self) -> LABEL_TYPE:
def enter(self) -> None:
"""Entering the state"""

def execute(self) -> Optional["State"]:
def execute(self) -> Optional['State']:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
Expand All @@ -170,9 +164,9 @@ def execute(self) -> Optional["State"]:
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}")
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State":
def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

def do_enter(self) -> None:
Expand Down Expand Up @@ -229,7 +223,7 @@ def get_states(cls) -> Sequence[Type[State]]:
if cls.STATES is not None:
return cls.STATES

raise RuntimeError("States not defined")
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
Expand All @@ -247,7 +241,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def __ensure_built(cls) -> None:
try:
# Check if it's already been built (and therefore sealed)
if cls.__getattribute__(cls, "sealed"):
if cls.__getattribute__(cls, 'sealed'):
return
except AttributeError:
pass
Expand All @@ -271,9 +265,7 @@ def __init__(self) -> None:
self.__ensure_built()
self._state: Optional[State] = None
self._exception_handler = None # Note this appears to never be used
self.set_debug(
(not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG")))
)
self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
self._transitioning = False
self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {}

Expand All @@ -282,7 +274,7 @@ def init(self) -> None:
"""Called after entering initial state in `__call__` method of `StateMachineMeta`"""

def __str__(self) -> str:
return f"<{self.__class__.__name__}> ({self.state})"
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
Expand All @@ -293,9 +285,7 @@ def state(self) -> Optional[LABEL_TYPE]:
return None
return self._state.LABEL

def add_state_event_callback(
self, hook: Hashable, callback: EVENT_CALLBACK_TYPE
) -> None:
def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
"""
Add a callback to be called on a particular state event hook.
The callback should have form fn(state_machine, hook, state)
Expand All @@ -305,10 +295,8 @@ def add_state_event_callback(
"""
self._event_callbacks.setdefault(hook, []).append(callback)

def remove_state_event_callback(
self, hook: Hashable, callback: EVENT_CALLBACK_TYPE
) -> None:
if getattr(self, "_closed", False):
def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None:
if getattr(self, '_closed', False):
# if the process is closed, then all callbacks have already been removed
return None
try:
Expand All @@ -324,19 +312,15 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(
self, new_state: Union[State, Type[State]], **kwargs: Any
) -> None:
def transition_to(self, new_state: Union[State, Type[State]], **kwargs: Any) -> None:
"""Transite to the new state.
The new target state will be create lazily when the state is not yet instantiated,
The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
assert (
not self._transitioning
), "Cannot call transition_to when already transitioning state"
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

initial_state_label = self._state.LABEL if self._state is not None else None
label = None
Expand All @@ -358,9 +342,7 @@ def transition_to(
except StateEntryFailed as exception:
# Make sure we have a state instance
if not isinstance(exception.state, State):
new_state = self._create_state_instance(
exception.state, **exception.kwargs
)
new_state = self._create_state_instance(exception.state, **exception.kwargs)
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand Down Expand Up @@ -406,7 +388,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
raise ValueError(f"{state_label} is not a valid state")
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""
Expand All @@ -415,15 +397,11 @@ def _exit_current_state(self, next_state: State) -> None:
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
raise RuntimeError(
f"Cannot enter state '{next_state}' as the initial state"
)
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(
f"Cannot transition from {self._state.LABEL} to {next_state.label}"
)
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()

Expand All @@ -435,10 +413,8 @@ def _enter_next_state(self, next_state: State) -> None:
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(
self, state_cls: type[State], **kwargs: Any
) -> State:
def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f"{state_cls.LABEL} is not a valid state")
raise ValueError(f'{state_cls.LABEL} is not a valid state')

Check warning on line 418 in src/plumpy/base/state_machine.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/base/state_machine.py#L418

Added line #L418 was not covered by tests

return state_cls(self, **kwargs)
4 changes: 3 additions & 1 deletion src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MESSAGE_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


class Intent:
"""Intent constants for a process message"""

Expand All @@ -41,9 +42,10 @@ class Intent:
KILL: str = 'kill'
STATUS: str = 'status'


MessageType = dict[str, Any]

PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None}
PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None}
PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None}
KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False}
STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None}
Expand Down
Loading

0 comments on commit 96c5842

Please sign in to comment.