From c28f6c52627732110e4063a3f8d6786d60bde871 Mon Sep 17 00:00:00 2001 From: Rui Catarino <55796280+ruitcatarino@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:02:10 +0000 Subject: [PATCH] Replace RuntimeError with WebSocketDisconnected --- starlette/websockets.py | 12 ++++++++++-- tests/test_websockets.py | 6 +++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/starlette/websockets.py b/starlette/websockets.py index b7acaa3f0..fd4050d5b 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -22,6 +22,14 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.reason = reason or "" +class WebSocketDisconnected(RuntimeError): + """ + Raised when attempting to use a disconnected WebSocket. + """ + + pass + + class WebSocket(HTTPConnection): def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: super().__init__(scope) @@ -53,7 +61,7 @@ async def receive(self) -> Message: self.client_state = WebSocketState.DISCONNECTED return message else: - raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') + raise WebSocketDisconnected('Cannot call "receive" once a disconnect message has been received.') async def send(self, message: Message) -> None: """ @@ -94,7 +102,7 @@ async def send(self, message: Message) -> None: self.application_state = WebSocketState.DISCONNECTED await self._send(message) else: - raise RuntimeError('Cannot call "send" once a close message has been sent.') + raise WebSocketDisconnected('Cannot call "send" once a close message has been sent.') async def accept( self, diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 8ecf6304c..b8c1ad23b 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -9,7 +9,7 @@ from starlette.responses import Response from starlette.testclient import WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send -from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketDisconnected, WebSocketState from tests.types import TestClientFactory @@ -448,7 +448,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.close() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/"): pass # pragma: no cover @@ -462,7 +462,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: message = await websocket.receive() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/") as websocket: websocket.close()