Skip to content

Commit

Permalink
Handle Redis pub/sub subscribe errors
Browse files Browse the repository at this point in the history
  • Loading branch information
JoseKilo committed Oct 25, 2024
1 parent 69cf29a commit c4dbea7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
14 changes: 12 additions & 2 deletions broadcaster/_backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ def __init__(self, url: str):
self._conn = redis.Redis.from_url(url)
self._pubsub = self._conn.pubsub()
self._ready = asyncio.Event()
self._queue: asyncio.Queue[Event] = asyncio.Queue()
self._queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()
self._listener: asyncio.Task[None] | None = None

async def connect(self) -> None:
self._listener = asyncio.create_task(self._pubsub_listener())
self._listener.add_done_callback(self.drop)
await self._pubsub.connect()

async def disconnect(self) -> None:
Expand All @@ -27,6 +28,10 @@ async def disconnect(self) -> None:
if self._listener is not None:
self._listener.cancel()

def drop(self, task: asyncio.Task[None]) -> None:
exc = task.exception()
self._queue.put_nowait(exc)

async def subscribe(self, channel: str) -> None:
self._ready.set()
await self._pubsub.subscribe(channel)
Expand All @@ -38,7 +43,12 @@ async def publish(self, channel: str, message: typing.Any) -> None:
await self._conn.publish(channel, message)

async def next_published(self) -> Event:
return await self._queue.get()
result = await self._queue.get()
if result is None:
raise RuntimeError
if isinstance(result, BaseException):
raise result
return result

async def _pubsub_listener(self) -> None:
# redis-py does not listen to the pubsub connection if there are no channels subscribed
Expand Down
19 changes: 15 additions & 4 deletions broadcaster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Broadcast:
def __init__(self, url: str | None = None, *, backend: BroadcastBackend | None = None) -> None:
assert url or backend, "Either `url` or `backend` must be provided."
self._backend = backend or self._create_backend(cast(str, url))
self._subscribers: dict[str, set[asyncio.Queue[Event | None]]] = {}
self._subscribers: dict[str, set[asyncio.Queue[Event | BaseException | None]]] = {}

def _create_backend(self, url: str) -> BroadcastBackend:
parsed_url = urlparse(url)
Expand Down Expand Up @@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
async def connect(self) -> None:
await self._backend.connect()
self._listener_task = asyncio.create_task(self._listener())
self._listener_task.add_done_callback(self.drop)

def drop(self, task: asyncio.Task[None]) -> None:
exc = task.exception()
for queues in self._subscribers.values():
for queue in queues:
queue.put_nowait(exc)

async def disconnect(self) -> None:
if self._listener_task.done():
self._listener_task.result()
exc = self._listener_task.exception()
if exc is None:
self._listener_task.result()
else:
self._listener_task.cancel()
await self._backend.disconnect()
Expand All @@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None:

@asynccontextmanager
async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
queue: asyncio.Queue[Event | None] = asyncio.Queue()
queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()

try:
if not self._subscribers.get(channel):
Expand All @@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:


class Subscriber:
def __init__(self, queue: asyncio.Queue[Event | None]) -> None:
def __init__(self, queue: asyncio.Queue[Event | BaseException | None]) -> None:
self._queue = queue

async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
Expand All @@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:

async def get(self) -> Event:
item = await self._queue.get()
if isinstance(item, BaseException):
raise item
if item is None:
raise Unsubscribed()
return item
17 changes: 17 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

import pytest
import redis

from broadcaster import Broadcast, BroadcastBackend, Event
from broadcaster._backends.kafka import KafkaBackend
Expand Down Expand Up @@ -56,6 +57,22 @@ async def test_redis():
assert event.message == "hello"


@pytest.mark.asyncio
async def test_redis_disconnect():
with pytest.raises(redis.ConnectionError) as exc:
async with Broadcast("redis://localhost:6379") as broadcast:
async with broadcast.subscribe("chatroom") as subscriber:
await broadcast.publish("chatroom", "hello")
await broadcast._backend._conn.connection_pool.aclose() # type: ignore[attr-defined]
event = await subscriber.get()
assert event.channel == "chatroom"
assert event.message == "hello"
await subscriber.get()
assert False

assert exc.value.args == ("Connection closed by server.",)


@pytest.mark.asyncio
async def test_redis_stream():
async with Broadcast("redis-stream://localhost:6379") as broadcast:
Expand Down

0 comments on commit c4dbea7

Please sign in to comment.