Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Make Process.run async #272

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/plumpy/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import asyncio
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Awaitable, Callable, Optional

import kiwipy

Expand Down Expand Up @@ -55,7 +55,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
self._action = None # type: ignore


def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future:
"""
Schedule a call to a coro in the event loop and wrap the outcome
in a future.
Expand Down
18 changes: 12 additions & 6 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast

import yaml
from yaml.loader import Loader
Expand All @@ -19,7 +19,7 @@
from .base import state_machine
from .lang import NULL
from .persistence import auto_persist
from .utils import SAVED_STATE_TYPE
from .utils import SAVED_STATE_TYPE, ensure_coroutine

__all__ = [
'Continue',
Expand Down Expand Up @@ -195,10 +195,16 @@ class Running(State):
_running: bool = False
_run_handle = None

def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
def __init__(
self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any
) -> None:
super().__init__(process)
assert run_fn is not None
self.run_fn = run_fn
self.run_fn = ensure_coroutine(run_fn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am quite worried about this. I think this run_fn will become the continue_fn when it is recover from the Waiting state, which means all such xx_fn should be coroutines along the way. I need to take a close look to see how this change will make things different. Will do it next week.

Copy link

@agoscinski agoscinski Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chrisjsewell suggested to do await_me_maybe https://github.com/aiidateam/plumpy/pull/272/files#r1257558025 that would avoid this but I did not read about any technical reason do it. The blog post he links is only arguing for it cause of cleanness of the code. I assumed wrapping a blocking function is like writing async to the function, it then is just executed like a blocking function when used with await. So

def two():
    # blocking function
    time.sleep(1)
    print("Two")

async def blocking():
    print("One")
    two()
    print("Three")

async def also_blocking():
    # does the same as blocking
    coro = ensure_coroutine(two)
    print("One")
    await coro()
    print("Three")

async def not_blocking():
    coro = ensure_coroutine(two)
    print("One")
    coro() # Runtime warning 
    print("Three")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed wrapping a blocking function is like writing async to the function, it then is just executed like a blocking function when used with await.

That's true, the blog post mentions the "maybe await" pattern is mostly for async framework that can support the downstream app can write block function. If the operation is block function, then it is run in block manner.

I think this run_fn will become the continue_fn when it is recover from the Waiting state, which means all such xx_fn should be coroutines along the way.

From aiida-core point of view, this never happened, since the continue_fn is never set in the aiida-core Waiting class. The def run is used to create the initial Created state and used to transfer the aiida Process
into its own Waiting state(s).

# We wrap `run_fn` to a coroutine so we can apply await on it,
# even it if it was not a coroutine in the first place.
# This allows the same usage of async and non-async function
# with the await syntax while not changing the program logic.
self.args = args
self.kwargs = kwargs
self._run_handle = None
Expand All @@ -211,7 +217,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist

def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
super().load_instance_state(saved_state, load_context)
self.run_fn = getattr(self.process, saved_state[self.RUN_FN])
self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN]))
if self.COMMAND in saved_state:
self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore

Expand All @@ -225,7 +231,7 @@ async def execute(self) -> State: # type: ignore
try:
try:
self._running = True
result = self.run_fn(*self.args, **self.kwargs)
result = await self.run_fn(*self.args, **self.kwargs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is mentioned in the blog post that @chrisjsewell mentioned in https://github.com/aiidateam/plumpy/pull/272/files#r1257558025 with await_me_maybe would be the better design IMO. I think but we already do it the same way in the Process class so this is at least consistent. Maybe we can instead improve the doc a bit on the places where we use ensure_coroutine. Something like

We wrap `run_fn` to a coroutine so we can apply await on it, even it if it was not a coroutine in the first place. This allows the same usage of async and non-async function with the await syntax while not changing the program logic.

At least this is my understanding why we do this. (Also would add something like this to _run_task in class Process)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @agoscinski , I've added your comment.

finally:
self._running = False
except Interruption:
Expand Down
4 changes: 2 additions & 2 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ def play(self) -> bool:
call_with_super_check(self.on_playing)
return True

@event(from_states=(process_states.Waiting))
@event(from_states=process_states.Waiting)
def resume(self, *args: Any) -> None:
"""Start running the process again."""
return self._state.resume(*args) # type: ignore
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat

# region Execution related methods

def run(self) -> Any:
async def run(self) -> Any:
"""This function will be run when the process is triggered.
It should be overridden by a subclass.
"""
Expand Down
3 changes: 2 additions & 1 deletion src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Mapping
from typing import (
Any,
Awaitable,
Callable,
Hashable,
Iterator,
Expand Down Expand Up @@ -185,7 +186,7 @@ def type_check(obj: Any, expected_type: Type) -> None:
raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'")


def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]:
def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]:
"""
Ensure that the given function ``fct`` is a coroutine

Expand Down
2 changes: 1 addition & 1 deletion src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None

self._awaitables[resolved_awaitable] = key

def run(self) -> Any:
async def run(self) -> Any:
return self._do_step()

def _do_step(self) -> Any:
Expand Down
Loading