Skip to content

Commit

Permalink
Make Process.run async (#272)
Browse files Browse the repository at this point in the history
Co-authored-by: Ali <[email protected]>
(cherry picked from commit 4611154)
  • Loading branch information
chrisjsewell authored Dec 10, 2024
1 parent 2159fb7 commit 3d26349
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import asyncio
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Awaitable, Callable, Optional

import kiwipy

Expand Down Expand Up @@ -55,7 +55,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
self._action = None # type: ignore


def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
"""
Schedule a call to a coro in the event loop and wrap the outcome
in a future.
Expand Down
18 changes: 12 additions & 6 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast

import yaml
from yaml.loader import Loader
Expand All @@ -19,7 +19,7 @@
from .base import state_machine
from .lang import NULL
from .persistence import auto_persist
from .utils import SAVED_STATE_TYPE
from .utils import SAVED_STATE_TYPE, ensure_coroutine

__all__ = [
'Continue',
Expand Down Expand Up @@ -195,10 +195,16 @@ class Running(State):
_running: bool = False
_run_handle = None

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def __init__(
self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
) -> None:
super().__init__(process)
assert run_fn is not None
self.run_fn = run_fn
self.run_fn = ensure_coroutine(run_fn)
# We wrap `run_fn` to a coroutine so we can apply await on it,
# even it if it was not a coroutine in the first place.
# This allows the same usage of async and non-async function
# with the await syntax while not changing the program logic.
self.args = args
self.kwargs = kwargs
self._run_handle = None
Expand All @@ -211,7 +217,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN]))
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore

Expand All @@ -225,7 +231,7 @@ async def execute(self) -> State: # type: ignore
try:
try:
self._running = True
result = self.run_fn(*self.args, **self.kwargs)
result = await self.run_fn(*self.args, **self.kwargs)
finally:
self._running = False
except Interruption:
Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ def play(self) -> bool:
call_with_super_check(self.on_playing)
return True

@event(from_states=(process_states.Waiting))
@event(from_states=process_states.Waiting)
def resume(self, *args: Any) -> None:
"""Start running the process again."""
return self._state.resume(*args) # type: ignore
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat

# region Execution related methods

def run(self) -> Any:
async def run(self) -> Any:
"""This function will be run when the process is triggered.
It should be overridden by a subclass.
"""
Expand Down
3 changes: 2 additions & 1 deletion src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Mapping
from typing import (
Any,
Awaitable,
Callable,
Hashable,
Iterator,
Expand Down Expand Up @@ -185,7 +186,7 @@ def type_check(obj: Any, expected_type: Type) -> None:
raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'")


def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]:
def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]:
"""
Ensure that the given function ``fct`` is a coroutine
Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None

self._awaitables[resolved_awaitable] = key

def run(self) -> Any:
async def run(self) -> Any:
return self._do_step()

def _do_step(self) -> Any:
Expand Down

0 comments on commit 3d26349

Please sign in to comment.