diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 6808f930c6..37c5c05feb 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1220,6 +1220,12 @@ inside a single process, and for that you can use .. autofunction:: open_memory_channel(max_buffer_size) +Assigning the send and receive channels to separate variables usually +produces the most readable code. However, in situations where the pair +is preserved-- such as a collection of memory channels-- prefer named tuple +access (``pair.send_channel``, ``pair.receive_channel``) over indexed access +(``pair[0]``, ``pair[1]``). + .. note:: If you've used the :mod:`threading` or :mod:`asyncio` modules, you may be familiar with :class:`queue.Queue` or :class:`asyncio.Queue`. In Trio, :func:`open_memory_channel` is diff --git a/newsfragments/1771.feature.rst b/newsfragments/1771.feature.rst new file mode 100644 index 0000000000..ec4b528aa0 --- /dev/null +++ b/newsfragments/1771.feature.rst @@ -0,0 +1,5 @@ +open_memory_channel() now returns a named tuple with attributes ``send_channel`` +and ``receive_channel``. This can be used to avoid indexed access of the +channel halves in some scenarios such as a collection of channels. (Note: when +dealing with a single memory channel, assigning the send and receive halves +to separate variables via destructuring is still considered more readable.) diff --git a/src/trio/_channel.py b/src/trio/_channel.py index f5ed4004d7..1f328daea0 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -2,10 +2,12 @@ from collections import OrderedDict, deque from math import inf +from operator import itemgetter from typing import ( TYPE_CHECKING, Generic, Tuple, # only needed for typechecking on <3.9 + TypeVar, ) import attrs @@ -13,19 +15,23 @@ import trio -from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T +from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType from ._core import Abort, RaiseCancelT, Task, enable_ki_protection from ._util import NoPublicConstructor, final, generic_function if TYPE_CHECKING: + from collections.abc import Iterable from types import TracebackType from typing_extensions import Self +T = TypeVar("T") + + def _open_memory_channel( max_buffer_size: int | float, # noqa: PYI041 -) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: +) -> MemoryChannelPair[T]: """Open a channel for passing objects between tasks within a process. Memory channels are lightweight, cheap to allocate, and entirely @@ -53,9 +59,8 @@ def _open_memory_channel( see :ref:`channel-buffering` for more details. If in doubt, use 0. Returns: - A pair ``(send_channel, receive_channel)``. If you have - trouble remembering which order these go in, remember: data - flows from left → right. + A named tuple ``(send_channel, receive_channel)``. The tuple ordering is + intended to match the image of data flowing from left → right. In addition to the standard channel methods, all memory channel objects provide a ``statistics()`` method, which returns an object with the @@ -82,33 +87,12 @@ def _open_memory_channel( if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size) - return ( + return MemoryChannelPair( MemorySendChannel[T]._create(state), MemoryReceiveChannel[T]._create(state), ) -# This workaround requires python3.9+, once older python versions are not supported -# or there's a better way of achieving type-checking on a generic factory function, -# it could replace the normal function header -if TYPE_CHECKING: - # written as a class so you can say open_memory_channel[int](5) - # Need to use Tuple instead of tuple due to CI check running on 3.8 - class open_memory_channel(Tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]): - def __new__( # type: ignore[misc] # "must return a subtype" - cls, max_buffer_size: int | float # noqa: PYI041 - ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: - return _open_memory_channel(max_buffer_size) - - def __init__(self, max_buffer_size: int | float): # noqa: PYI041 - ... - -else: - # apply the generic_function decorator to make open_memory_channel indexable - # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime - open_memory_channel = generic_function(_open_memory_channel) - - @attrs.frozen class MemoryChannelStats: current_buffer_used: int @@ -144,9 +128,12 @@ def statistics(self) -> MemoryChannelStats: @final @attrs.define(eq=False, repr=False, slots=False) -class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor): +class MemorySendChannel( + SendChannel[SendType], + Generic[SendType], + metaclass=NoPublicConstructor, +): _state: MemoryChannelState[SendType] - _closed: bool = False # This is just the tasks waiting on *this* object. As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. @@ -287,7 +274,11 @@ async def aclose(self) -> None: @final @attrs.define(eq=False, repr=False, slots=False) -class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor): +class MemoryReceiveChannel( + ReceiveChannel[ReceiveType], + Generic[ReceiveType], + metaclass=NoPublicConstructor, +): _state: MemoryChannelState[ReceiveType] _closed: bool = False _tasks: set[trio._core._run.Task] = attrs.Factory(set) @@ -431,3 +422,102 @@ def close(self) -> None: async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() + + +# We cannot use generic named tuples before Py 3.11, manually define it. +class MemoryChannelPair( + Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]], + Generic[T], +): + """Named tuple of send/receive memory channels.""" + + __slots__ = () + _fields = ("send_channel", "receive_channel") + + if TYPE_CHECKING: + + @property + def send_channel(self) -> MemorySendChannel[T]: + """Returns the sending channel half.""" + return self[0] + + @property + def receive_channel(self) -> MemoryReceiveChannel[T]: + """Returns the receiving channel half.""" + return self[1] + + else: # More efficient + send_channel = property(itemgetter(0), doc="Returns the sending channel half.") + receive_channel = property( + itemgetter(1), doc="Returns the receiving channel half." + ) + + def __new__( + cls, + send_channel: MemorySendChannel[T], + receive_channel: MemoryReceiveChannel[T], + ) -> Self: + """Create new instance of MemoryChannelPair(send_channel, receive_channel)""" + return tuple.__new__(cls, (send_channel, receive_channel)) # type: ignore[type-var] + + @classmethod + def _make( + cls, + iterable: Iterable[MemorySendChannel[T] | MemoryReceiveChannel[T]], + ) -> Self: + """Make a new MemoryChannelPair object from a sequence or iterable""" + send, rec = iterable + if isinstance(send, MemoryReceiveChannel) or isinstance(rec, MemorySendChannel): + raise TypeError("Channel order passed incorrectly.") + return tuple.__new__(cls, (send, rec)) # type: ignore[type-var] + + def _replace( + self, + *, + send_channel: MemorySendChannel[T] | None = None, + receive_channel: MemoryReceiveChannel[T] | None = None, + ) -> MemoryChannelPair[T]: + """Return a new MemoryChannelPair object replacing specified fields with new values""" + if send_channel is None: + send_channel = self.send_channel + if receive_channel is None: + receive_channel = self.receive_channel + return tuple.__new__( + MemoryChannelPair, + (send_channel, receive_channel), + ) # type: ignore[type-var] + + def __repr__(self) -> str: + """Return a nicely formatted representation string""" + return f"{self.__class__.__name__}(send_channel={self[0]!r}, receive_channel={self[1]!r})" + + def _asdict( + self, + ) -> OrderedDict[str, MemorySendChannel[T] | MemoryReceiveChannel[T]]: + """Return a new OrderedDict which maps field names to their values.""" + return OrderedDict(zip(self._fields, self)) + + def __getnewargs__(self) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + """Return self as a plain tuple. Used by copy and pickle.""" + return (self[0], self[1]) + + +# This workaround requires python3.9+, once older python versions are not supported +# or there's a better way of achieving type-checking on a generic factory function, +# it could replace the normal function header +if TYPE_CHECKING: + # written as a class so that you can say open_memory_channel[int](5) + # Need to use Tuple instead of tuple due to CI check running on 3.8 + class open_memory_channel(MemoryChannelPair[T]): + def __new__( # type: ignore[misc] # "must return a subtype" + cls, max_buffer_size: int | float # noqa: PYI041 + ) -> MemoryChannelPair[T]: + return _open_memory_channel(max_buffer_size) + + def __init__(self, max_buffer_size: int | float): # noqa: PYI041 + ... + +else: + # apply the generic_function decorator to make open_memory_channel indexable + # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime + open_memory_channel = generic_function(_open_memory_channel) diff --git a/src/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py index 1271f6b765..c82c6767ed 100644 --- a/src/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -409,3 +409,8 @@ async def do_send(s: trio.MemorySendChannel[int], v: int) -> None: assert await r.receive() == 1 with pytest.raises(trio.WouldBlock): r.receive_nowait() + + +def test_named_tuple(): + pair = open_memory_channel(0) + assert pair.send_channel, pair.receive_channel == pair diff --git a/src/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py index 1ce886eddb..a1457de3d8 100644 --- a/src/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -42,7 +42,7 @@ class MemoryListener(trio.abc.Listener[StapledMemoryStream]): async def connect(self) -> StapledMemoryStream: assert not self.closed client, server = memory_stream_pair() - await self.queued_streams[0].send(server) + await self.queued_streams.send_channel.send(server) return client async def accept(self) -> StapledMemoryStream: @@ -50,7 +50,7 @@ async def accept(self) -> StapledMemoryStream: assert not self.closed if self.accept_hook is not None: await self.accept_hook() - stream = await self.queued_streams[1].receive() + stream = await self.queued_streams.receive_channel.receive() self.accepted_streams.append(stream) return stream