From a27509e387d734a05e05398b45d921b6931ac391 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 10 Dec 2024 21:07:20 +0000 Subject: [PATCH] Make `Process.run` async (#272) Co-authored-by: Ali (cherry picked from commit 4611154c76ac0991bcf7371b21488f4390648c28) --- src/plumpy/futures.py | 4 ++-- src/plumpy/process_states.py | 18 ++++++++++++------ src/plumpy/processes.py | 4 ++-- src/plumpy/utils.py | 3 ++- src/plumpy/workchains.py | 2 +- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 161244cd..f52a0d09 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -4,7 +4,7 @@ """ import asyncio -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Awaitable, Callable, Optional import kiwipy @@ -55,7 +55,7 @@ def run(self, *args: Any, **kwargs: Any) -> None: self._action = None # type: ignore -def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: +def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: """ Schedule a call to a coro in the event loop and wrap the outcome in a future. diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..cf29973a 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -3,7 +3,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast import yaml from yaml.loader import Loader @@ -19,7 +19,7 @@ from .base import state_machine from .lang import NULL from .persistence import auto_persist -from .utils import SAVED_STATE_TYPE +from .utils import SAVED_STATE_TYPE, ensure_coroutine __all__ = [ 'Continue', @@ -195,10 +195,16 @@ class Running(State): _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None - self.run_fn = run_fn + self.run_fn = ensure_coroutine(run_fn) + # We wrap `run_fn` to a coroutine so we can apply await on it, + # even it if it was not a coroutine in the first place. + # This allows the same usage of async and non-async function + # with the await syntax while not changing the program logic. self.args = args self.kwargs = kwargs self._run_handle = None @@ -211,7 +217,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -225,7 +231,7 @@ async def execute(self) -> State: # type: ignore try: try: self._running = True - result = self.run_fn(*self.args, **self.kwargs) + result = await self.run_fn(*self.args, **self.kwargs) finally: self._running = False except Interruption: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..ffddf7b5 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1111,7 +1111,7 @@ def play(self) -> bool: call_with_super_check(self.on_playing) return True - @event(from_states=(process_states.Waiting)) + @event(from_states=process_states.Waiting) def resume(self, *args: Any) -> None: """Start running the process again.""" return self._state.resume(*args) # type: ignore @@ -1184,7 +1184,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat # region Execution related methods - def run(self) -> Any: + async def run(self) -> Any: """This function will be run when the process is triggered. It should be overridden by a subclass. """ diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 36d76bbd..bd1b70a7 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import ( Any, + Awaitable, Callable, Hashable, Iterator, @@ -185,7 +186,7 @@ def type_check(obj: Any, expected_type: Type) -> None: raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'") -def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]: +def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]: """ Ensure that the given function ``fct`` is a coroutine diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..b48b1c6b 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -171,7 +171,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None self._awaitables[resolved_awaitable] = key - def run(self) -> Any: + async def run(self) -> Any: return self._do_step() def _do_step(self) -> Any: