Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Make Process.run async #272

Open
wants to merge 1 commit into
base: support/0.21.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
"source": [
"class SimpleProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.state.name)\n",
" \n",
"process = SimpleProcess()\n",
Expand Down Expand Up @@ -219,7 +219,7 @@
" spec.output('output2.output2a')\n",
" spec.output('output2.output2b')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" self.out('output1', self.inputs.input1)\n",
" self.out('output2.output2a', self.inputs.input2.input2a)\n",
" self.out('output2.output2b', self.inputs.input2.input2b)\n",
Expand Down Expand Up @@ -277,7 +277,7 @@
"source": [
"class ContinueProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(\"running\")\n",
" return plumpy.Continue(self.continue_fn)\n",
" \n",
Expand Down Expand Up @@ -340,7 +340,7 @@
"\n",
"class WaitProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" return plumpy.Wait(self.resume_fn)\n",
" \n",
" def resume_fn(self):\n",
Expand Down Expand Up @@ -405,7 +405,7 @@
" super().define(spec)\n",
" spec.input('name')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.inputs.name, \"run\")\n",
" return plumpy.Continue(self.continue_fn)\n",
"\n",
Expand Down Expand Up @@ -469,12 +469,12 @@
"source": [
"class SimpleProcess(plumpy.Process):\n",
" \n",
" def run(self):\n",
" async def run(self):\n",
" print(self.get_name())\n",
" \n",
"class PauseProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(f\"{self.get_name()}: pausing\")\n",
" self.pause()\n",
" print(f\"{self.get_name()}: continue step\")\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" spec.input('name', valid_type=str, default='process')\n",
" spec.output('value')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.inputs.name)\n",
" self.out('value', 'value')\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/process_helloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def define(cls, spec):
spec.input('name', default='World', required=True)
spec.output('greeting', valid_type=str)

def run(self):
async def run(self):
self.out('greeting', f'Hello {self.inputs.name}!')
return plumpy.Stop(None, True)

Expand Down
2 changes: 1 addition & 1 deletion examples/process_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def define(cls, spec):
spec.outputs.dynamic = True
spec.output('default', valid_type=int)

def run(self):
async def run(self):
self.out('default', 5)


Expand Down
2 changes: 1 addition & 1 deletion examples/process_wait_and_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class WaitForResumeProc(plumpy.Process):

def run(self):
async def run(self):
print(f'Now I am running: {self.state}')
return plumpy.Wait(self.after_resume_and_exec)

Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Module containing future related methods and classes
"""
import asyncio
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Awaitable, Callable, Optional

import kiwipy

Expand Down Expand Up @@ -54,7 +54,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
self._action = None # type: ignore


def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
"""
Schedule a call to a coro in the event loop and wrap the outcome
in a future.
Expand Down
14 changes: 8 additions & 6 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import traceback
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast

import yaml
from yaml.loader import Loader
Expand All @@ -20,7 +20,7 @@
from .base import state_machine
from .lang import NULL
from .persistence import auto_persist
from .utils import SAVED_STATE_TYPE
from .utils import SAVED_STATE_TYPE, ensure_coroutine

__all__ = [
'ProcessState',
Expand Down Expand Up @@ -195,10 +195,12 @@ class Running(State):
_running: bool = False
_run_handle = None

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def __init__(
self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
) -> None:
super().__init__(process)
assert run_fn is not None
self.run_fn = run_fn
self.run_fn = ensure_coroutine(run_fn)
self.args = args
self.kwargs = kwargs
self._run_handle = None
Expand All @@ -211,7 +213,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN]))
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore

Expand All @@ -225,7 +227,7 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over
try:
try:
self._running = True
result = self.run_fn(*self.args, **self.kwargs)
result = await self.run_fn(*self.args, **self.kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

finally:
self._running = False
except Interruption:
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat

# region Execution related methods

def run(self) -> Any:
async def run(self) -> Any:
"""This function will be run when the process is triggered.
It should be overridden by a subclass.
"""
Expand Down
18 changes: 15 additions & 3 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,20 @@
import inspect
import logging
import types
from typing import Set # pylint: disable=unused-import
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type
from typing import ( # pylint: disable=unused-import
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Hashable,
Iterator,
List,
MutableMapping,
Optional,
Set,
Tuple,
Type,
)

from . import lang
from .settings import check_override, check_protected
Expand Down Expand Up @@ -221,7 +233,7 @@ def type_check(obj: Any, expected_type: Type) -> None:
raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'")


def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]:
def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]:
"""
Ensure that the given function ``fct`` is a coroutine

Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None
awaitable = awaitable.future()
self._awaitables[awaitable] = key

def run(self) -> Any:
async def run(self) -> Any:
return self._do_step()

def _do_step(self) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion test/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class Process(plumpy.Process):

def run(self):
async def run(self):
pass


Expand Down
33 changes: 17 additions & 16 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_logging(self):

class LoggerTester(Process):

def run(self, **kwargs):
async def run(self, **kwargs):
self.logger.info('Test')

# TODO: Test giving a custom logger to see if it gets used
Expand Down Expand Up @@ -442,7 +442,7 @@ def test_kill_in_run(self):
class KillProcess(Process):
after_kill = False

def run(self, **kwargs):
async def run(self, **kwargs):
self.kill('killed')
# The following line should be executed because kill will not
# interrupt execution of a method call in the RUNNING state
Expand All @@ -459,7 +459,7 @@ def test_kill_when_paused_in_run(self):

class PauseProcess(Process):

def run(self, **kwargs):
async def run(self, **kwargs):
self.pause()
self.kill()

Expand Down Expand Up @@ -513,7 +513,7 @@ def test_invalid_output(self):

class InvalidOutput(plumpy.Process):

def run(self):
async def run(self):
self.out('invalid', 5)

proc = InvalidOutput()
Expand Down Expand Up @@ -541,7 +541,7 @@ class Proc(Process):
def define(cls, spec):
super().define(spec)

def run(self):
async def run(self):
return plumpy.UnsuccessfulResult(ERROR_CODE)

proc = Proc()
Expand All @@ -555,7 +555,7 @@ def test_pause_in_process(self):

class TestPausePlay(plumpy.Process):

def run(self):
async def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)

Expand All @@ -580,7 +580,7 @@ def test_pause_play_in_process(self):

class TestPausePlay(plumpy.Process):

def run(self):
async def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)
result = self.play()
Expand All @@ -597,7 +597,7 @@ def test_process_stack(self):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
test_case.assertIs(self, Process.current())

proc = StackTest()
Expand All @@ -614,7 +614,7 @@ def test_nested(process):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
# TODO: unexpected behaviour here
# if assert error happend here not raise
# it will be handled by try except clause in process
Expand All @@ -624,7 +624,7 @@ def run(self):

class ParentProcess(plumpy.Process):

def run(self):
async def run(self):
expect_true.append(self == Process.current())
StackTest().execute()

Expand All @@ -647,12 +647,12 @@ def test_process_nested(self):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
pass

class ParentProcess(plumpy.Process):

def run(self):
async def run(self):
StackTest().execute()

ParentProcess().execute()
Expand All @@ -661,7 +661,7 @@ def test_call_soon(self):

class CallSoon(plumpy.Process):

def run(self):
async def run(self):
self.call_soon(self.do_except)

def do_except(self):
Expand Down Expand Up @@ -699,7 +699,7 @@ def test_exception_during_run(self):

class RaisingProcess(Process):

def run(self):
async def run(self):
raise RuntimeError('exception during run')

process = RaisingProcess()
Expand All @@ -719,7 +719,7 @@ def init(self):
super().init()
self.steps_ran = []

def run(self):
async def run(self):
self.pause()
self.steps_ran.append(self.run.__name__)
return plumpy.Continue(self.step2)
Expand Down Expand Up @@ -811,6 +811,7 @@ def test_saving_each_step(self):
saver = utils.ProcessSaver(proc)
saver.capture()
self.assertEqual(proc.state, ProcessState.FINISHED)
print(proc)
self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots))

def test_restart(self):
Expand Down Expand Up @@ -980,7 +981,7 @@ def define(cls, spec):
spec.output('required_bool', valid_type=bool)
spec.output_namespace(namespace, valid_type=int, dynamic=True)

def run(self):
async def run(self):
if self.inputs.output_mode == OutputMode.NONE:
pass
elif self.inputs.output_mode == OutputMode.DYNAMIC_PORT_NAMESPACE:
Expand Down
Loading