-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
143 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import functools | ||
import sys | ||
from contextlib import asynccontextmanager | ||
|
||
import trio | ||
from async_generator import aclosing | ||
|
||
|
||
def trio_async_generator(wrapped): | ||
"""async generator pattern which supports Trio nurseries and cancel scopes | ||
Decorator which allows async generators to use a Trio nursery or | ||
cancel scope internally. (Normally, it's not allowed to yield from | ||
these Trio constructs in an async generator.) | ||
Though the wrapped function is written as a normal async generator, usage | ||
of the wrapper is different: the wrapper is an async context manager | ||
providing the async generator to be iterated. | ||
Synopsis:: | ||
>>> @trio_async_generator | ||
>>> async def my_generator(): | ||
>>> # yield values, possibly from a nursery or cancel scope | ||
>>> # ... | ||
>>> | ||
>>> | ||
>>> async with my_generator() as agen: | ||
>>> async for value in agen: | ||
>>> print(value) | ||
Implementation: "The idea is that instead of pushing and popping the | ||
generator from the stack of the task that's consuming it, you instead run | ||
the generator code as a second task that feeds the consumer task values." | ||
See https://github.com/python-trio/trio/issues/638#issuecomment-431954073 | ||
ISSUE: pylint is confused by this implementation, and every use will | ||
trigger not-async-context-manager | ||
""" | ||
@asynccontextmanager | ||
@functools.wraps(wrapped) | ||
async def wrapper(*args, **kwargs): | ||
send_channel, receive_channel = trio.open_memory_channel(0) | ||
async with trio.open_nursery() as nursery: | ||
async def adapter(): | ||
async with send_channel, aclosing(wrapped(*args, **kwargs)) as agen: | ||
while True: | ||
try: | ||
# Advance underlying async generator to next yield | ||
value = await agen.__anext__() | ||
except StopAsyncIteration: | ||
break | ||
while True: | ||
try: | ||
# Forward the yielded value into the send channel | ||
try: | ||
await send_channel.send(value) | ||
except trio.BrokenResourceError: | ||
return | ||
break | ||
except BaseException: # pylint: disable=broad-except | ||
# If send_channel.send() raised (e.g. Cancelled), | ||
# throw the raised exception back into the generator, | ||
# and get the next yielded value to forward. | ||
try: | ||
value = await agen.athrow(*sys.exc_info()) | ||
except StopAsyncIteration: | ||
return | ||
|
||
nursery.start_soon(adapter, name=wrapped) | ||
async with receive_channel: | ||
yield receive_channel | ||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from math import inf | ||
|
||
import trio | ||
|
||
from trio_util._trio_async_generator import trio_async_generator | ||
|
||
# pylint: disable=not-async-context-manager | ||
|
||
|
||
@trio_async_generator | ||
async def squares_in_range(start, stop, timeout=inf, max_timeout_count=1): | ||
timeout_count = 0 | ||
for i in range(start, stop): | ||
with trio.move_on_after(timeout) as cancel_scope: | ||
yield i ** 2 | ||
await trio.sleep(0) | ||
if cancel_scope.cancelled_caught: | ||
timeout_count += 1 | ||
if timeout_count == max_timeout_count: | ||
break | ||
|
||
|
||
async def test_trio_agen_full_iteration(): | ||
last = None | ||
async with squares_in_range(0, 50) as squares: | ||
async for square in squares: | ||
last = square | ||
assert last == 49 ** 2 | ||
|
||
|
||
async def test_trio_agen_caller_exits(): | ||
async with squares_in_range(0, 50) as squares: | ||
async for square in squares: | ||
if square >= 400: | ||
return | ||
assert False | ||
|
||
|
||
async def test_trio_agen_caller_cancelled(autojump_clock): | ||
with trio.move_on_after(1): | ||
async with squares_in_range(0, 50) as squares: | ||
async for square in squares: | ||
assert square == 0 | ||
# the sleep will be cancelled by move_on_after above | ||
await trio.sleep(10) | ||
|
||
|
||
async def test_trio_agen_aborts_yield(autojump_clock): | ||
async with squares_in_range(0, 50, timeout=.5, max_timeout_count=1) as squares: | ||
async for square in squares: | ||
assert square == 0 | ||
# timeout in the generator will be triggered and it will abort iteration | ||
await trio.sleep(1) | ||
|
||
|
||
async def test_trio_agen_aborts_yield_and_continues(autojump_clock): | ||
async with squares_in_range(0, 50, timeout=.5, max_timeout_count=99) as squares: | ||
_sum = 0 | ||
async for square in squares: | ||
_sum += square | ||
if square == 5 ** 2: | ||
# this will cause the next iteration (6 ** 2) to time out | ||
await trio.sleep(.6) | ||
assert _sum == sum(i ** 2 for i in range(0, 50)) - 6 ** 2 |