From b8dae694f23dcb61e3aea5ec92885c137e31cdd6 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Fri, 25 Sep 2020 17:47:44 +0900 Subject: [PATCH] add @trio_async_generator --- setup.py | 5 +- src/trio_util/__init__.py | 1 + src/trio_util/_trio_async_generator.py | 73 ++++++++++++++++++++++++++ test-requirements.txt | 4 +- tests/test_trio_async_generator.py | 64 ++++++++++++++++++++++ 5 files changed, 143 insertions(+), 4 deletions(-) create mode 100644 src/trio_util/_trio_async_generator.py create mode 100644 tests/test_trio_async_generator.py diff --git a/setup.py b/setup.py index 981c748..40a4ebe 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,10 @@ license='MIT', packages=[pkg_name], package_dir={'': 'src'}, - install_requires=['trio >= 0.11.0'], + install_requires=[ + 'async_generator', + 'trio >= 0.11.0' + ], python_requires='>=3.7', classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/src/trio_util/__init__.py b/src/trio_util/__init__.py index a663d20..1466ace 100644 --- a/src/trio_util/__init__.py +++ b/src/trio_util/__init__.py @@ -8,6 +8,7 @@ from ._periodic import periodic from ._repeated_event import UnqueuedRepeatedEvent, MailboxRepeatedEvent from ._task_stats import TaskStats +from ._trio_async_generator import trio_async_generator def _metadata_fix(): # don't do this for Sphinx case because it breaks "bysource" member ordering diff --git a/src/trio_util/_trio_async_generator.py b/src/trio_util/_trio_async_generator.py new file mode 100644 index 0000000..dc2de6d --- /dev/null +++ b/src/trio_util/_trio_async_generator.py @@ -0,0 +1,73 @@ +import functools +import sys +from contextlib import asynccontextmanager + +import trio +from async_generator import aclosing + + +def trio_async_generator(wrapped): + """async generator pattern which supports Trio nurseries and cancel scopes + + Decorator which allows async generators to use a Trio nursery or + cancel scope internally. (Normally, it's not allowed to yield from + these Trio constructs in an async generator.) + + Though the wrapped function is written as a normal async generator, usage + of the wrapper is different: the wrapper is an async context manager + providing the async generator to be iterated. + + Synopsis:: + + >>> @trio_async_generator + >>> async def my_generator(): + >>> # yield values, possibly from a nursery or cancel scope + >>> # ... + >>> + >>> + >>> async with my_generator() as agen: + >>> async for value in agen: + >>> print(value) + + Implementation: "The idea is that instead of pushing and popping the + generator from the stack of the task that's consuming it, you instead run + the generator code as a second task that feeds the consumer task values." + See https://github.com/python-trio/trio/issues/638#issuecomment-431954073 + + ISSUE: pylint is confused by this implementation, and every use will + trigger not-async-context-manager + """ + @asynccontextmanager + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + send_channel, receive_channel = trio.open_memory_channel(0) + async with trio.open_nursery() as nursery: + async def adapter(): + async with send_channel, aclosing(wrapped(*args, **kwargs)) as agen: + while True: + try: + # Advance underlying async generator to next yield + value = await agen.__anext__() + except StopAsyncIteration: + break + while True: + try: + # Forward the yielded value into the send channel + try: + await send_channel.send(value) + except trio.BrokenResourceError: + return + break + except BaseException: # pylint: disable=broad-except + # If send_channel.send() raised (e.g. Cancelled), + # throw the raised exception back into the generator, + # and get the next yielded value to forward. + try: + value = await agen.athrow(*sys.exc_info()) + except StopAsyncIteration: + return + + nursery.start_soon(adapter, name=wrapped) + async with receive_channel: + yield receive_channel + return wrapper diff --git a/test-requirements.txt b/test-requirements.txt index 852b9be..402050c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,11 +5,10 @@ # pip-compile --output-file=test-requirements.txt setup.py test-requirements.in # astroid==2.4.1 # via pylint -async-generator==1.10 # via pytest-trio, trio +async-generator==1.10 # via pytest-trio, trio, trio_util (setup.py) attrs==19.3.0 # via outcome, pytest, trio coverage==5.1 # via pytest-cov idna==2.9 # via trio -importlib-metadata==1.6.0 # via pluggy, pytest isort==4.3.21 # via pylint lazy-object-proxy==1.4.3 # via astroid mccabe==0.6.1 # via pylint @@ -34,4 +33,3 @@ typed-ast==1.4.1 # via astroid, mypy typing-extensions==3.7.4.2 # via mypy wcwidth==0.1.9 # via pytest wrapt==1.12.1 # via astroid -zipp==3.1.0 # via importlib-metadata diff --git a/tests/test_trio_async_generator.py b/tests/test_trio_async_generator.py new file mode 100644 index 0000000..22fab56 --- /dev/null +++ b/tests/test_trio_async_generator.py @@ -0,0 +1,64 @@ +from math import inf + +import trio + +from trio_util._trio_async_generator import trio_async_generator + +# pylint: disable=not-async-context-manager + + +@trio_async_generator +async def squares_in_range(start, stop, timeout=inf, max_timeout_count=1): + timeout_count = 0 + for i in range(start, stop): + with trio.move_on_after(timeout) as cancel_scope: + yield i ** 2 + await trio.sleep(0) + if cancel_scope.cancelled_caught: + timeout_count += 1 + if timeout_count == max_timeout_count: + break + + +async def test_trio_agen_full_iteration(): + last = None + async with squares_in_range(0, 50) as squares: + async for square in squares: + last = square + assert last == 49 ** 2 + + +async def test_trio_agen_caller_exits(): + async with squares_in_range(0, 50) as squares: + async for square in squares: + if square >= 400: + return + assert False + + +async def test_trio_agen_caller_cancelled(autojump_clock): + with trio.move_on_after(1): + async with squares_in_range(0, 50) as squares: + async for square in squares: + assert square == 0 + # the sleep will be cancelled by move_on_after above + await trio.sleep(10) + + +async def test_trio_agen_aborts_yield(autojump_clock): + async with squares_in_range(0, 50, timeout=.5, max_timeout_count=1) as squares: + async for square in squares: + assert square == 0 + # timeout in the generator will be triggered and it will abort iteration + await trio.sleep(1) + + +async def test_trio_agen_aborts_yield_and_continues(autojump_clock): + async with squares_in_range(0, 50, timeout=.5, max_timeout_count=99) as squares: + _sum = 0 + async for square in squares: + _sum += square + if square == 5 ** 2: + # this will cause the next iteration (6 ** 2) to time out + await trio.sleep(.6) + assert _sum == sum(i ** 2 for i in range(0, 50)) - 6 ** 2