Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to specify the buffer size. #186

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def test_client_open_url_options(open_websocket_mock):
'extra_headers': [(b'X-Test-Header', b'My test header')],
'message_queue_size': 9,
'max_message_size': 333,
'receive_buffer_size': 999,
'connect_timeout': 36,
'disconnect_timeout': 37,
}
Expand Down
126 changes: 88 additions & 38 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,19 @@ def __exit__(self, ty, value, tb):

@asynccontextmanager
async def open_websocket(
host: str,
port: int,
resource: str,
*,
use_ssl: Union[bool, ssl.SSLContext],
subprotocols: Optional[Iterable[str]] = None,
extra_headers: Optional[list[tuple[bytes,bytes]]] = None,
message_queue_size: int = MESSAGE_QUEUE_SIZE,
max_message_size: int = MAX_MESSAGE_SIZE,
connect_timeout: float = CONN_TIMEOUT,
disconnect_timeout: float = CONN_TIMEOUT
):
host: str,
port: int,
resource: str,
*,
use_ssl: Union[bool, ssl.SSLContext],
subprotocols: Optional[Iterable[str]] = None,
extra_headers: Optional[list[tuple[bytes,bytes]]] = None,
message_queue_size: int = MESSAGE_QUEUE_SIZE,
max_message_size: int = MAX_MESSAGE_SIZE,
receive_buffer_size: int = RECEIVE_BYTES,
connect_timeout: float = CONN_TIMEOUT,
disconnect_timeout: float = CONN_TIMEOUT
):
Comment on lines +97 to +109
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to avoid the reformatting in this PR

'''
Open a WebSocket client connection to a host.

Expand All @@ -130,6 +131,9 @@ async def open_websocket(
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:param float connect_timeout: The number of seconds to wait for the
connection before timing out.
:param float disconnect_timeout: The number of seconds to wait when closing
Expand Down Expand Up @@ -168,7 +172,8 @@ async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
resource, use_ssl=use_ssl, subprotocols=subprotocols,
extra_headers=extra_headers,
message_queue_size=message_queue_size,
max_message_size=max_message_size)
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size)
except trio.TooSlowError:
raise ConnectionTimeout from None
except OSError as e:
Expand Down Expand Up @@ -287,9 +292,12 @@ def _raise(exc: BaseException) -> NoReturn:


async def connect_websocket(nursery, host, port, resource, *, use_ssl,
subprotocols=None, extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE
) -> WebSocketConnection:
subprotocols=None,
extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE,
max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES,
) -> WebSocketConnection:
'''
Return an open WebSocket client connection to a host.

Expand Down Expand Up @@ -317,6 +325,9 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:rtype: WebSocketConnection
'''
if use_ssl is True:
Expand Down Expand Up @@ -346,7 +357,8 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl,
path=resource,
client_subprotocols=subprotocols, client_extra_headers=extra_headers,
message_queue_size=message_queue_size,
max_message_size=max_message_size)
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size)
nursery.start_soon(connection._reader_task)
await connection._open_handshake.wait()
return connection
Expand All @@ -355,6 +367,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl,
def open_websocket_url(url, ssl_context=None, *, subprotocols=None,
extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES,
connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT):
'''
Open a WebSocket client connection to a URL.
Expand All @@ -378,6 +391,9 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:param float connect_timeout: The number of seconds to wait for the
connection before timing out.
:param float disconnect_timeout: The number of seconds to wait when closing
Expand All @@ -391,12 +407,14 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None,
subprotocols=subprotocols, extra_headers=extra_headers,
message_queue_size=message_queue_size,
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size,
connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout)


async def connect_websocket_url(nursery, url, ssl_context=None, *,
subprotocols=None, extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES):
'''
Return an open WebSocket client connection to a URL.

Expand All @@ -421,13 +439,17 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:rtype: WebSocketConnection
'''
host, port, resource, ssl_context = _url_to_host(url, ssl_context)
return await connect_websocket(nursery, host, port, resource,
use_ssl=ssl_context, subprotocols=subprotocols,
extra_headers=extra_headers, message_queue_size=message_queue_size,
max_message_size=max_message_size)
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size)


def _url_to_host(url, ssl_context):
Expand Down Expand Up @@ -468,7 +490,8 @@ def _url_to_host(url, ssl_context):

async def wrap_client_stream(nursery, stream, host, resource, *,
subprotocols=None, extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES):
'''
Wrap an arbitrary stream in a WebSocket connection.

Expand All @@ -492,21 +515,26 @@ async def wrap_client_stream(nursery, stream, host, resource, *,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:rtype: WebSocketConnection
'''
connection = WebSocketConnection(stream,
WSConnection(ConnectionType.CLIENT),
host=host, path=resource,
client_subprotocols=subprotocols, client_extra_headers=extra_headers,
message_queue_size=message_queue_size,
max_message_size=max_message_size)
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size)
nursery.start_soon(connection._reader_task)
await connection._open_handshake.wait()
return connection


async def wrap_server_stream(nursery, stream,
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES):
'''
Wrap an arbitrary stream in a server-side WebSocket.

Expand All @@ -520,21 +548,26 @@ async def wrap_server_stream(nursery, stream,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:type stream: trio.abc.Stream
:rtype: WebSocketRequest
'''
connection = WebSocketConnection(stream,
WSConnection(ConnectionType.SERVER),
message_queue_size=message_queue_size,
max_message_size=max_message_size)
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size)
nursery.start_soon(connection._reader_task)
request = await connection._get_request()
return request


async def serve_websocket(handler, host, port, ssl_context, *,
handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE,
max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT,
max_message_size=MAX_MESSAGE_SIZE, receive_buffer_size=RECEIVE_BYTES,
connect_timeout=CONN_TIMEOUT,
disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED):
'''
Serve a WebSocket over TCP.
Expand Down Expand Up @@ -564,6 +597,9 @@ async def serve_websocket(handler, host, port, ssl_context, *,
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:param float connect_timeout: The number of seconds to wait for a client
to finish connection handshake before timing out.
:param float disconnect_timeout: The number of seconds to wait for a client
Expand All @@ -579,7 +615,9 @@ async def serve_websocket(handler, host, port, ssl_context, *,
listeners = await open_tcp_listeners()
server = WebSocketServer(handler, listeners,
handler_nursery=handler_nursery, message_queue_size=message_queue_size,
max_message_size=max_message_size, connect_timeout=connect_timeout,
max_message_size=max_message_size,
receive_buffer_size=receive_buffer_size,
connect_timeout=connect_timeout,
disconnect_timeout=disconnect_timeout)
await server.run(task_status=task_status)

Expand Down Expand Up @@ -837,16 +875,18 @@ class WebSocketConnection(trio.abc.AsyncResource):
CONNECTION_ID = itertools.count()

def __init__(
self,
stream: trio.SocketStream | trio.SSLStream[trio.SocketStream],
ws_connection: wsproto.WSConnection,
*,
host=None,
path=None,
client_subprotocols=None, client_extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE,
max_message_size=MAX_MESSAGE_SIZE
):
self,
stream: trio.SocketStream | trio.SSLStream[trio.SocketStream],
ws_connection: wsproto.WSConnection,
*,
host=None,
path=None,
client_subprotocols=None,
client_extra_headers=None,
message_queue_size=MESSAGE_QUEUE_SIZE,
max_message_size=MAX_MESSAGE_SIZE,
receive_buffer_size=RECEIVE_BYTES,
) -> None:
'''
Constructor.

Expand All @@ -872,6 +912,9 @@ def __init__(
:param int max_message_size: The maximum message size as measured by
``len()``. If a message is received that is larger than this size,
then the connection is closed with code 1009 (Message Too Big).
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
'''
# NOTE: The implementation uses _close_reason for more than an advisory
# purpose. It's critical internal state, indicating when the
Expand All @@ -884,6 +927,7 @@ def __init__(
self._message_size = 0
self._message_parts: List[Union[bytes, str]] = []
self._max_message_size = max_message_size
self._receive_buffer_size: Optional[int] = receive_buffer_size
self._reader_running = True
if ws_connection.client:
self._initial_request: Optional[Request] = Request(host=host, target=path,
Expand Down Expand Up @@ -1392,7 +1436,7 @@ async def _reader_task(self):

# Get network data.
try:
data = await self._stream.receive_some(RECEIVE_BYTES)
data = await self._stream.receive_some(self._receive_buffer_size)
except (trio.BrokenResourceError, trio.ClosedResourceError):
await self._abort_web_socket()
break
Expand Down Expand Up @@ -1476,7 +1520,8 @@ class WebSocketServer:

def __init__(self, handler, listeners, *, handler_nursery=None,
message_queue_size=MESSAGE_QUEUE_SIZE,
max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT,
max_message_size=MAX_MESSAGE_SIZE, receive_buffer_size=RECEIVE_BYTES,
connect_timeout=CONN_TIMEOUT,
disconnect_timeout=CONN_TIMEOUT):
'''
Constructor.
Expand All @@ -1493,6 +1538,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None,
:param handler_nursery: An optional nursery to spawn connection tasks
inside of. If ``None``, then a new nursery will be created
internally.
:param Optional[int] receive_buffer_size: The buffer size we use to
receive messages internally. None to let trio choose. Defaults
to 4 KiB.
:param float connect_timeout: The number of seconds to wait for a client
to finish connection handshake before timing out.
:param float disconnect_timeout: The number of seconds to wait for a client
Expand All @@ -1505,6 +1553,7 @@ def __init__(self, handler, listeners, *, handler_nursery=None,
self._listeners = listeners
self._message_queue_size = message_queue_size
self._max_message_size = max_message_size
self._receive_buffer_size = receive_buffer_size
self._connect_timeout = connect_timeout
self._disconnect_timeout = disconnect_timeout

Expand Down Expand Up @@ -1587,7 +1636,8 @@ async def _handle_connection(self, stream):
connection = WebSocketConnection(stream,
WSConnection(ConnectionType.SERVER),
message_queue_size=self._message_queue_size,
max_message_size=self._max_message_size)
max_message_size=self._max_message_size,
receive_buffer_size=self._receive_buffer_size)
nursery.start_soon(connection._reader_task)
with trio.move_on_after(self._connect_timeout) as connect_scope:
request = await connection._get_request()
Expand Down
Loading