From 7121594f4b6ae42f292a4d6b9f827580dea59ae8 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 6 Jul 2023 18:03:50 +0200 Subject: [PATCH 1/4] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Make=20`Process.run`?= =?UTF-8?q?=20async?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/process_helloworld.py | 2 +- examples/process_launch.py | 2 +- src/plumpy/futures.py | 4 ++-- src/plumpy/process_states.py | 14 ++++++++------ src/plumpy/processes.py | 2 +- src/plumpy/workchains.py | 2 +- tests/test_workchains.py | 10 +++++----- 7 files changed, 19 insertions(+), 17 deletions(-) diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index db2eff0f..16526510 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -9,7 +9,7 @@ def define(cls, spec): spec.input('name', default='World', required=True) spec.output('greeting', valid_type=str) - def run(self): + async def run(self): self.out('greeting', f'Hello {self.inputs.name}!') return plumpy.Stop(None, True) diff --git a/examples/process_launch.py b/examples/process_launch.py index 645af0fd..b3a212f6 100644 --- a/examples/process_launch.py +++ b/examples/process_launch.py @@ -18,7 +18,7 @@ def define(cls, spec): spec.outputs.dynamic = True spec.output('default', valid_type=int) - def run(self): + async def run(self): self.out('default', 5) diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 161244cd..f52a0d09 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -4,7 +4,7 @@ """ import asyncio -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Awaitable, Callable, Optional import kiwipy @@ -55,7 +55,7 @@ def run(self, *args: Any, **kwargs: Any) -> None: self._action = None # type: ignore -def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: +def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: """ Schedule a call to a coro in the event loop and wrap the outcome in a future. diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..a73152b2 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -3,7 +3,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast import yaml from yaml.loader import Loader @@ -19,7 +19,7 @@ from .base import state_machine from .lang import NULL from .persistence import auto_persist -from .utils import SAVED_STATE_TYPE +from .utils import SAVED_STATE_TYPE, ensure_coroutine __all__ = [ 'Continue', @@ -195,10 +195,12 @@ class Running(State): _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None - self.run_fn = run_fn + self.run_fn = ensure_coroutine(run_fn) self.args = args self.kwargs = kwargs self._run_handle = None @@ -211,7 +213,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -225,7 +227,7 @@ async def execute(self) -> State: # type: ignore try: try: self._running = True - result = self.run_fn(*self.args, **self.kwargs) + result = await self.run_fn(*self.args, **self.kwargs) finally: self._running = False except Interruption: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..55710a9e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1184,7 +1184,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat # region Execution related methods - def run(self) -> Any: + async def run(self) -> Any: """This function will be run when the process is triggered. It should be overridden by a subclass. """ diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..b48b1c6b 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -171,7 +171,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None self._awaitables[resolved_awaitable] = key - def run(self) -> Any: + async def run(self) -> Any: return self._do_step() def _do_step(self) -> Any: diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..bd202adc 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -205,7 +205,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - def run(self): + async def run(self): self.out('res', A) class ReturnB(plumpy.Process): @@ -214,7 +214,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - def run(self): + async def run(self): self.out('res', B) class Wf(WorkChain): @@ -394,7 +394,7 @@ def define(cls, spec): spec.outline(cls.run, cls.check) spec.outputs.dynamic = True - def run(self): + async def run(self): return ToContext(subwc=self.launch(SubWorkChain)) def check(self): @@ -406,7 +406,7 @@ def define(cls, spec): super().define(spec) spec.outline(cls.run) - def run(self): + async def run(self): self.out('value', 5) workchain = MainWorkChain() @@ -449,7 +449,7 @@ def define(cls, spec): super().define(spec) spec.output('_return') - def run(self): + async def run(self): self.out('_return', val) class Workchain(WorkChain): From c5336be7de128e281f5e80202a1c74713ada74f1 Mon Sep 17 00:00:00 2001 From: Ali Date: Thu, 5 Dec 2024 14:50:02 +0100 Subject: [PATCH 2/4] fix conflict --- examples/process_helloworld.py | 2 +- examples/process_launch.py | 2 +- tests/test_workchains.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index 16526510..db2eff0f 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -9,7 +9,7 @@ def define(cls, spec): spec.input('name', default='World', required=True) spec.output('greeting', valid_type=str) - async def run(self): + def run(self): self.out('greeting', f'Hello {self.inputs.name}!') return plumpy.Stop(None, True) diff --git a/examples/process_launch.py b/examples/process_launch.py index b3a212f6..645af0fd 100644 --- a/examples/process_launch.py +++ b/examples/process_launch.py @@ -18,7 +18,7 @@ def define(cls, spec): spec.outputs.dynamic = True spec.output('default', valid_type=int) - async def run(self): + def run(self): self.out('default', 5) diff --git a/tests/test_workchains.py b/tests/test_workchains.py index bd202adc..08c7317a 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -205,7 +205,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - async def run(self): + def run(self): self.out('res', A) class ReturnB(plumpy.Process): @@ -214,7 +214,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - async def run(self): + def run(self): self.out('res', B) class Wf(WorkChain): @@ -394,7 +394,7 @@ def define(cls, spec): spec.outline(cls.run, cls.check) spec.outputs.dynamic = True - async def run(self): + def run(self): return ToContext(subwc=self.launch(SubWorkChain)) def check(self): @@ -406,7 +406,7 @@ def define(cls, spec): super().define(spec) spec.outline(cls.run) - async def run(self): + def run(self): self.out('value', 5) workchain = MainWorkChain() @@ -449,7 +449,7 @@ def define(cls, spec): super().define(spec) spec.output('_return') - async def run(self): + def run(self): self.out('_return', val) class Workchain(WorkChain): From 72df1edb0c12b349ba1d50f5db41d27b8f1787b6 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 10 Dec 2024 16:06:33 +0100 Subject: [PATCH 3/4] fix conflicts --- src/plumpy/processes.py | 2 +- src/plumpy/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 55710a9e..ffddf7b5 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1111,7 +1111,7 @@ def play(self) -> bool: call_with_super_check(self.on_playing) return True - @event(from_states=(process_states.Waiting)) + @event(from_states=process_states.Waiting) def resume(self, *args: Any) -> None: """Start running the process again.""" return self._state.resume(*args) # type: ignore diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 36d76bbd..bd1b70a7 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import ( Any, + Awaitable, Callable, Hashable, Iterator, @@ -185,7 +186,7 @@ def type_check(obj: Any, expected_type: Type) -> None: raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'") -def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]: +def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]: """ Ensure that the given function ``fct`` is a coroutine From 57322cf3a77b491a418eb4dfa39b0e0db4a51d84 Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 10 Dec 2024 16:20:09 +0100 Subject: [PATCH 4/4] review applied --- src/plumpy/process_states.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index a73152b2..cf29973a 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -201,6 +201,10 @@ def __init__( super().__init__(process) assert run_fn is not None self.run_fn = ensure_coroutine(run_fn) + # We wrap `run_fn` to a coroutine so we can apply await on it, + # even it if it was not a coroutine in the first place. + # This allows the same usage of async and non-async function + # with the await syntax while not changing the program logic. self.args = args self.kwargs = kwargs self._run_handle = None