Skip to content

Commit

Permalink
♻️ Make Process.run async
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Jul 7, 2023
1 parent 49a7117 commit ff1b699
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 56 deletions.
16 changes: 8 additions & 8 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
"source": [
"class SimpleProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.state.name)\n",
" \n",
"process = SimpleProcess()\n",
Expand Down Expand Up @@ -219,7 +219,7 @@
" spec.output('output2.output2a')\n",
" spec.output('output2.output2b')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" self.out('output1', self.inputs.input1)\n",
" self.out('output2.output2a', self.inputs.input2.input2a)\n",
" self.out('output2.output2b', self.inputs.input2.input2b)\n",
Expand Down Expand Up @@ -277,7 +277,7 @@
"source": [
"class ContinueProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(\"running\")\n",
" return plumpy.Continue(self.continue_fn)\n",
" \n",
Expand Down Expand Up @@ -340,7 +340,7 @@
"\n",
"class WaitProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" return plumpy.Wait(self.resume_fn)\n",
" \n",
" def resume_fn(self):\n",
Expand Down Expand Up @@ -405,7 +405,7 @@
" super().define(spec)\n",
" spec.input('name')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.inputs.name, \"run\")\n",
" return plumpy.Continue(self.continue_fn)\n",
"\n",
Expand Down Expand Up @@ -469,12 +469,12 @@
"source": [
"class SimpleProcess(plumpy.Process):\n",
" \n",
" def run(self):\n",
" async def run(self):\n",
" print(self.get_name())\n",
" \n",
"class PauseProcess(plumpy.Process):\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(f\"{self.get_name()}: pausing\")\n",
" self.pause()\n",
" print(f\"{self.get_name()}: continue step\")\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" spec.input('name', valid_type=str, default='process')\n",
" spec.output('value')\n",
"\n",
" def run(self):\n",
" async def run(self):\n",
" print(self.inputs.name)\n",
" self.out('value', 'value')\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/process_helloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def define(cls, spec):
spec.input('name', default='World', required=True)
spec.output('greeting', valid_type=str)

def run(self):
async def run(self):
self.out('greeting', f'Hello {self.inputs.name}!')
return plumpy.Stop(None, True)

Expand Down
2 changes: 1 addition & 1 deletion examples/process_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def define(cls, spec):
spec.outputs.dynamic = True
spec.output('default', valid_type=int)

def run(self):
async def run(self):
self.out('default', 5)


Expand Down
2 changes: 1 addition & 1 deletion examples/process_wait_and_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class WaitForResumeProc(plumpy.Process):

def run(self):
async def run(self):
print(f'Now I am running: {self.state}')
return plumpy.Wait(self.after_resume_and_exec)

Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Module containing future related methods and classes
"""
import asyncio
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Awaitable, Callable, Optional

import kiwipy

Expand Down Expand Up @@ -54,7 +54,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
self._action = None # type: ignore


def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
"""
Schedule a call to a coro in the event loop and wrap the outcome
in a future.
Expand Down
14 changes: 8 additions & 6 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import traceback
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast

import yaml
from yaml.loader import Loader
Expand All @@ -20,7 +20,7 @@
from .base import state_machine
from .lang import NULL
from .persistence import auto_persist
from .utils import SAVED_STATE_TYPE
from .utils import SAVED_STATE_TYPE, ensure_coroutine

__all__ = [
'ProcessState',
Expand Down Expand Up @@ -195,10 +195,12 @@ class Running(State):
_running: bool = False
_run_handle = None

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def __init__(
self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
) -> None:
super().__init__(process)
assert run_fn is not None
self.run_fn = run_fn
self.run_fn = ensure_coroutine(run_fn)
self.args = args
self.kwargs = kwargs
self._run_handle = None
Expand All @@ -211,7 +213,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN]))
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore

Expand All @@ -225,7 +227,7 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over
try:
try:
self._running = True
result = self.run_fn(*self.args, **self.kwargs)
result = await self.run_fn(*self.args, **self.kwargs)
finally:
self._running = False
except Interruption:
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat

# region Execution related methods

def run(self) -> Any:
async def run(self) -> Any:
"""This function will be run when the process is triggered.
It should be overridden by a subclass.
"""
Expand Down
18 changes: 15 additions & 3 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,20 @@
import inspect
import logging
import types
from typing import Set # pylint: disable=unused-import
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type
from typing import ( # pylint: disable=unused-import
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Hashable,
Iterator,
List,
MutableMapping,
Optional,
Set,
Tuple,
Type,
)

from . import lang
from .settings import check_override, check_protected
Expand Down Expand Up @@ -221,7 +233,7 @@ def type_check(obj: Any, expected_type: Type) -> None:
raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'")


def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]:
def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]:
"""
Ensure that the given function ``fct`` is a coroutine
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None
awaitable = awaitable.future()
self._awaitables[awaitable] = key

def run(self) -> Any:
async def run(self) -> Any:
return self._do_step()

def _do_step(self) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion test/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class Process(plumpy.Process):

def run(self):
async def run(self):
pass


Expand Down
33 changes: 17 additions & 16 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_logging(self):

class LoggerTester(Process):

def run(self, **kwargs):
async def run(self, **kwargs):
self.logger.info('Test')

# TODO: Test giving a custom logger to see if it gets used
Expand Down Expand Up @@ -442,7 +442,7 @@ def test_kill_in_run(self):
class KillProcess(Process):
after_kill = False

def run(self, **kwargs):
async def run(self, **kwargs):
self.kill('killed')
# The following line should be executed because kill will not
# interrupt execution of a method call in the RUNNING state
Expand All @@ -459,7 +459,7 @@ def test_kill_when_paused_in_run(self):

class PauseProcess(Process):

def run(self, **kwargs):
async def run(self, **kwargs):
self.pause()
self.kill()

Expand Down Expand Up @@ -513,7 +513,7 @@ def test_invalid_output(self):

class InvalidOutput(plumpy.Process):

def run(self):
async def run(self):
self.out('invalid', 5)

proc = InvalidOutput()
Expand Down Expand Up @@ -541,7 +541,7 @@ class Proc(Process):
def define(cls, spec):
super().define(spec)

def run(self):
async def run(self):
return plumpy.UnsuccessfulResult(ERROR_CODE)

proc = Proc()
Expand All @@ -555,7 +555,7 @@ def test_pause_in_process(self):

class TestPausePlay(plumpy.Process):

def run(self):
async def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)

Expand All @@ -580,7 +580,7 @@ def test_pause_play_in_process(self):

class TestPausePlay(plumpy.Process):

def run(self):
async def run(self):
fut = self.pause()
test_case.assertIsInstance(fut, plumpy.Future)
result = self.play()
Expand All @@ -597,7 +597,7 @@ def test_process_stack(self):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
test_case.assertIs(self, Process.current())

proc = StackTest()
Expand All @@ -614,7 +614,7 @@ def test_nested(process):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
# TODO: unexpected behaviour here
# if assert error happend here not raise
# it will be handled by try except clause in process
Expand All @@ -624,7 +624,7 @@ def run(self):

class ParentProcess(plumpy.Process):

def run(self):
async def run(self):
expect_true.append(self == Process.current())
StackTest().execute()

Expand All @@ -647,12 +647,12 @@ def test_process_nested(self):

class StackTest(plumpy.Process):

def run(self):
async def run(self):
pass

class ParentProcess(plumpy.Process):

def run(self):
async def run(self):
StackTest().execute()

ParentProcess().execute()
Expand All @@ -661,7 +661,7 @@ def test_call_soon(self):

class CallSoon(plumpy.Process):

def run(self):
async def run(self):
self.call_soon(self.do_except)

def do_except(self):
Expand Down Expand Up @@ -699,7 +699,7 @@ def test_exception_during_run(self):

class RaisingProcess(Process):

def run(self):
async def run(self):
raise RuntimeError('exception during run')

process = RaisingProcess()
Expand All @@ -719,7 +719,7 @@ def init(self):
super().init()
self.steps_ran = []

def run(self):
async def run(self):
self.pause()
self.steps_ran.append(self.run.__name__)
return plumpy.Continue(self.step2)
Expand Down Expand Up @@ -811,6 +811,7 @@ def test_saving_each_step(self):
saver = utils.ProcessSaver(proc)
saver.capture()
self.assertEqual(proc.state, ProcessState.FINISHED)
print(proc)
self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots))

def test_restart(self):
Expand Down Expand Up @@ -980,7 +981,7 @@ def define(cls, spec):
spec.output('required_bool', valid_type=bool)
spec.output_namespace(namespace, valid_type=int, dynamic=True)

def run(self):
async def run(self):
if self.inputs.output_mode == OutputMode.NONE:
pass
elif self.inputs.output_mode == OutputMode.DYNAMIC_PORT_NAMESPACE:
Expand Down
Loading

0 comments on commit ff1b699

Please sign in to comment.